├── data └── put data here.txt ├── checkpoints └── put model checkpoints here.txt ├── assets └── poster.pdf ├── requirements.txt ├── scripts ├── evaluate.sh ├── inference_generate_data.sh ├── inference.sh ├── train_cvae.sh ├── train_diffharmony.sh ├── train_refinement_stage.sh ├── misc │ └── classify_cand_gen_data.py ├── inference │ ├── inverse.py │ └── main.py ├── evaluate │ └── main.py └── train │ └── cvae_with_gen_data.py ├── configs ├── acc_configs │ ├── multi_deepspeed.yaml │ ├── multi_default.yaml │ └── single_default.yaml └── stage2_configs │ ├── small.json │ ├── base.json │ └── large.json ├── src ├── dataset │ ├── harmony_gen.py │ └── ihd_dataset.py ├── utils.py └── models │ ├── unet_2d_blocks.py │ ├── unet_2d.py │ ├── vae.py │ └── condition_vae.py ├── README_zh.md ├── README.md └── .gitignore /data/put data here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/put model checkpoints here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nicecv/DiffHarmony/HEAD/assets/poster.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.1 2 | bitsandbytes==0.43.1 3 | deepspeed==0.14.4 4 | diffusers==0.26.3 5 | einops==0.8.0 6 | huggingface-hub==0.23.4 7 | opencv-python-headless==4.10.0.84 8 | pandas==2.2.2 9 | safetensors==0.4.3 10 | sentencepiece==0.2.0 11 | scikit-image==0.24.0 12 | tensorboard==2.16.2 13 | tqdm==4.66.4 14 | transformers==4.41.2 15 | wandb==0.17.3 -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | 3 | OUTPUT_DIR="" 4 | DATA_DIR=data/iHarmony4 5 | TEST_FILE=test.jsonl 6 | 7 | python scripts/evaluate/main.py \ 8 | --input_dir $OUTPUT_DIR \ 9 | --output_dir $OUTPUT_DIR-evaluation \ 10 | --data_dir $DATA_DIR \ 11 | --json_file_path $TEST_FILE \ 12 | --resolution=256 \ 13 | --use_gt_bg \ 14 | --num_processes=32 -------------------------------------------------------------------------------- /configs/acc_configs/multi_deepspeed.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | fsdp_config: {} 11 | machine_rank: 0 12 | main_process_ip: null 13 | main_process_port: null 14 | main_training_function: main 15 | mixed_precision: fp16 16 | num_machines: 1 17 | num_processes: 2 18 | use_cpu: false -------------------------------------------------------------------------------- /configs/acc_configs/multi_default.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: MULTI_GPU 6 | downcast_bf16: 'no' 7 | dynamo_backend: 'NO' 8 
| fsdp_config: {} 9 | gpu_ids: 10 | machine_rank: 0 11 | main_process_ip: null 12 | main_process_port: null 13 | main_training_function: main 14 | megatron_lm_config: {} 15 | mixed_precision: 'no' 16 | num_machines: 1 17 | num_processes: 2 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_name: null 21 | tpu_zone: null 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/acc_configs/single_default.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: 'NO' 6 | downcast_bf16: 'no' 7 | dynamo_backend: 'NO' 8 | fsdp_config: {} 9 | gpu_ids: 10 | machine_rank: 0 11 | main_process_ip: null 12 | main_process_port: null 13 | main_training_function: main 14 | megatron_lm_config: {} 15 | mixed_precision: 'no' 16 | num_machines: 1 17 | num_processes: 1 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_name: null 21 | tpu_zone: null 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/stage2_configs/small.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DCustom", 3 | "sample_size": 256, 4 | "in_channels": 3, 5 | "out_channels": 3, 6 | 7 | "down_block_types": [ 8 | "DownBlock2D", 9 | "DownBlock2D", 10 | "DownBlock2D", 11 | "DownBlock2D" 12 | ], 13 | "up_block_types": [ 14 | "UpBlock2D", 15 | "UpBlock2D", 16 | "UpBlock2D", 17 | "UpBlock2D" 18 | ], 19 | "block_out_channels": [64, 128, 256, 256], 20 | 21 | "layers_per_block": 2, 22 | "mid_block_scale_factor": 1, 23 | "downsample_padding": 0, 24 | "downsample_type": "conv", 25 | "upsample_type": "conv", 26 | "act_fn": "silu", 27 | "attention_head_dim": null, 28 | "norm_num_groups": 32, 29 | "norm_eps": 1e-6, 30 | "add_attention": true, 31 | "input_residual": true 32 | } -------------------------------------------------------------------------------- /configs/stage2_configs/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DCustom", 3 | "sample_size": 256, 4 | "in_channels": 3, 5 | "out_channels": 3, 6 | 7 | "down_block_types": [ 8 | "DownBlock2D", 9 | "DownBlock2D", 10 | "DownBlock2D", 11 | "DownBlock2D" 12 | ], 13 | "up_block_types": [ 14 | "UpBlock2D", 15 | "UpBlock2D", 16 | "UpBlock2D", 17 | "UpBlock2D" 18 | ], 19 | "block_out_channels": [128, 256, 512, 512], 20 | 21 | "layers_per_block": 2, 22 | "mid_block_scale_factor": 1, 23 | "downsample_padding": 0, 24 | "downsample_type": "conv", 25 | "upsample_type": "conv", 26 | "act_fn": "silu", 27 | "attention_head_dim": null, 28 | "norm_num_groups": 32, 29 | "norm_eps": 1e-6, 30 | "add_attention": true, 31 | "input_residual": true 32 | } 33 | -------------------------------------------------------------------------------- /configs/stage2_configs/large.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DCustom", 3 | "sample_size": 256, 4 | "in_channels": 3, 5 | "out_channels": 3, 6 | 7 | "down_block_types": [ 8 | "DownBlock2D", 9 | "DownBlock2D", 10 | "DownBlock2D", 11 | "DownBlock2D" 12 | ], 13 | "up_block_types": [ 14 | "UpBlock2D", 15 | "UpBlock2D", 16 | "UpBlock2D", 17 | "UpBlock2D" 18 | ], 19 | "block_out_channels": [256, 512, 1024, 1024], 20 | 21 | "layers_per_block": 2, 22 | "mid_block_scale_factor": 
1, 23 | "downsample_padding": 0, 24 | "downsample_type": "conv", 25 | "upsample_type": "conv", 26 | "act_fn": "silu", 27 | "attention_head_dim": null, 28 | "norm_num_groups": 32, 29 | "norm_eps": 1e-6, 30 | "add_attention": true, 31 | "input_residual": true 32 | } 33 | -------------------------------------------------------------------------------- /scripts/inference_generate_data.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml" 3 | export CUDA_VISIBLE_DEVICES="0,1" 4 | NUM_PROCESSES=2 5 | MASTER_PORT=29500 6 | 7 | DATASET_ROOT= 8 | OUTPUT_DIR=$DATASET_ROOT/cand_composite_images 9 | TEST_FILE="all_mask_metadata.jsonl" 10 | 11 | 12 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \ 13 | scripts/inference/inverse.py \ 14 | --pretrained_model_name_or_path checkpoints/stable-diffusion-inpainting \ 15 | --pretrained_vae_model_name_or_path \ 16 | --pretrained_unet_model_name_or_path \ 17 | --dataset_root $DATASET_ROOT \ 18 | --test_file $TEST_FILE \ 19 | --output_dir $OUTPUT_DIR \ 20 | --seed=0 \ 21 | --resolution=1024 \ 22 | --output_resolution=1024 \ 23 | --eval_batch_size= \ 24 | --dataloader_num_workers= \ 25 | --mixed_precision="fp16" \ 26 | --rounds=10 -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml" 3 | export CUDA_VISIBLE_DEVICES="0,1" 4 | NUM_PROCESSES=2 5 | MASTER_PORT=29500 6 | 7 | OUTPUT_DIR="" 8 | mkdir -p $OUTPUT_DIR 9 | cat "$0" >> $OUTPUT_DIR/run_script.sh 10 | 11 | DATA_DIR=data/iHarmony4 12 | TEST_FILE=test.jsonl 13 | 14 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \ 15 | scripts/inference/main.py \ 16 | --pretrained_model_name_or_path checkpoints/stable-diffusion-inpainting \ 17 | --pretrained_vae_model_name_or_path checkpoints/sd-vae-ft-mse \ 18 | --pretrained_unet_model_name_or_path "" \ 19 | --dataset_root $DATA_DIR \ 20 | --test_file $TEST_FILE \ 21 | --output_dir $OUTPUT_DIR \ 22 | --seed=0 \ 23 | --resolution=512 \ 24 | --output_resolution=256 \ 25 | --eval_batch_size= \ 26 | --dataloader_num_workers= \ 27 | --mixed_precision="fp16" 28 | 29 | # --stage2_model_name_or_path "" -------------------------------------------------------------------------------- /scripts/train_cvae.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml" 3 | export CUDA_VISIBLE_DEVICES="0,1" 4 | NUM_PROCESSES=2 5 | MASTER_PORT=29500 6 | 7 | OUTPUT_DIR="" 8 | mkdir -p $OUTPUT_DIR 9 | cat "$0" >> $OUTPUT_DIR/run_script.sh 10 | 11 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \ 12 | scripts/train/cvae.py \ 13 | --pretrained_vae_model_name_or_path checkpoints/sd-vae-ft-mse \ 14 | --output_dir $OUTPUT_DIR \ 15 | --seed= \ 16 | --train_batch_size= \ 17 | --eval_batch_size= \ 18 | --dataloader_num_workers= \ 19 | --num_train_epochs= \ 20 | --gradient_accumulation_steps= \ 21 | --learning_rate= \ 22 | --lr_scheduler "" \ 23 | --lr_warmup_ratio= \ 24 | --use_ema \ 25 | --adam_weight_decay= \ 26 | --ema_decay= \ 27 | 
--mixed_precision="fp16" \ 28 | --checkpointing_epochs= \ 29 | --checkpoints_total_limit= \ 30 | --image_logging_epochs= \ 31 | --dataset_root "data/iHarmony4" \ 32 | --train_file "train.jsonl" \ 33 | --test_file "test.jsonl" \ 34 | --resolution=256 \ 35 | --additional_in_channels=1 \ 36 | --gradient_checkpointing -------------------------------------------------------------------------------- /scripts/train_diffharmony.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | 3 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml" 4 | export CUDA_VISIBLE_DEVICES="0,1" 5 | NUM_PROCESSES=2 6 | MASTER_PORT=29500 7 | 8 | OUTPUT_DIR="" 9 | mkdir -p $OUTPUT_DIR 10 | cat "$0" >> $OUTPUT_DIR/run_script.sh 11 | 12 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \ 13 | scripts/train/diffharmony.py \ 14 | --pretrained_model_name_or_path "checkpoints/stable-diffusion-inpainting" \ 15 | --vae_path "checkpoints/sd-vae-ft-mse" \ 16 | --output_dir $OUTPUT_DIR \ 17 | --seed= \ 18 | --train_batch_size= \ 19 | --eval_batch_size= \ 20 | --dataloader_num_workers= \ 21 | --num_train_epochs= \ 22 | --gradient_accumulation_steps= \ 23 | --learning_rate= \ 24 | --lr_scheduler "" \ 25 | --lr_warmup_ratio= \ 26 | --use_ema \ 27 | --adam_weight_decay= \ 28 | --mixed_precision="fp16" \ 29 | --checkpointing_epochs= \ 30 | --checkpoints_total_limit= \ 31 | --dataset_root "data/iHarmony4" \ 32 | --train_file "train.jsonl" \ 33 | --test_file "test.jsonl" \ 34 | --resolution=512 \ 35 | --infer_resolution=512 \ 36 | --image_log_interval= \ 37 | --gradient_checkpointing \ 38 | --enable_xformers_memory_efficient_attention -------------------------------------------------------------------------------- /scripts/train_refinement_stage.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=.:$PYTHONPATH 2 | ACC_CONFIG_FILE="configs/acc_configs/multi_default.yaml" 3 | export CUDA_VISIBLE_DEVICES="0,1" 4 | NUM_PROCESSES=2 5 | MASTER_PORT=29500 6 | 7 | OUTPUT_DIR="" 8 | mkdir -p $OUTPUT_DIR 9 | cat "$0" >> $OUTPUT_DIR/run_script.sh 10 | 11 | accelerate launch --config_file $ACC_CONFIG_FILE --num_processes $NUM_PROCESSES --main_process_port $MASTER_PORT \ 12 | scripts/train/refinement_stage.py \ 13 | --pipeline_path "checkpoints/stable-diffusion-inpainting" \ 14 | --pretrained_vae_path "checkpoints/sd-vae-ft-mse" \ 15 | --pretrained_unet_path "checkpoints/base/unet" \ 16 | --model_path configs/stage2_configs/base.json \ 17 | --output_dir $OUTPUT_DIR \ 18 | --seed= \ 19 | --dataloader_num_workers= \ 20 | --train_batch_size= \ 21 | --num_train_epochs= \ 22 | --gradient_accumulation_steps= \ 23 | --learning_rate= \ 24 | --lr_scheduler "" \ 25 | --lr_warmup_ratio= \ 26 | --use_ema \ 27 | --ema_decay= \ 28 | --adam_weight_decay= \ 29 | --mixed_precision="fp16" \ 30 | --checkpointing_epochs= \ 31 | --checkpoints_total_limit= \ 32 | --infer_resolution=512 \ 33 | --resolution=256 \ 34 | --in_channels=7 \ 35 | --dataset_root "data/iHarmony4" \ 36 | --train_file "train.jsonl" \ 37 | --gradient_checkpointing \ 38 | --enable_xformers_memory_efficient_attention 39 | 40 | # --kl_div_weight=1e-8 \ -------------------------------------------------------------------------------- /src/dataset/harmony_gen.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from torch.utils.data import Dataset 3 | import 
torchvision.transforms as transforms 4 | from argparse import Namespace 5 | from PIL import Image 6 | import json 7 | import os 8 | 9 | 10 | class GenHarmonyDataset(Dataset): 11 | def __init__(self, dataset_root, resolution): 12 | self.metadata = [ 13 | json.loads(line) 14 | for line in open(os.path.join(dataset_root, "metadata.jsonl")).readlines() 15 | ] 16 | self.dataset_root = dataset_root 17 | self.resolution = resolution 18 | self.transforms = Namespace( 19 | resize=transforms.Resize( 20 | [self.resolution, self.resolution], 21 | interpolation=transforms.InterpolationMode.BILINEAR, 22 | antialias=True, 23 | ), 24 | convert = transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 27 | ]), 28 | ) 29 | 30 | def __len__(self): 31 | return len(self.metadata) 32 | 33 | def __getitem__(self, index): 34 | image_path = os.path.join(self.dataset_root, self.metadata[index]["sample"]) 35 | cond_path = self.metadata[index]["cond"] 36 | target_path = self.metadata[index]["target"] 37 | 38 | image = Image.open(image_path).convert("RGB") 39 | if image.size != (self.resolution, self.resolution): 40 | image = self.transforms.resize(image) 41 | 42 | cond_image = Image.open(cond_path).convert("RGB") 43 | if cond_image.size != (self.resolution, self.resolution): 44 | cond_image = self.transforms.resize(cond_image) 45 | 46 | target_image = Image.open(target_path).convert("RGB") 47 | if target_image.size != (self.resolution, self.resolution): 48 | target_image = self.transforms.resize(target_image) 49 | 50 | image= self.transforms.convert(image) 51 | cond_image = self.transforms.convert(cond_image) 52 | target_image = self.transforms.convert(target_image) 53 | 54 | return { 55 | "image": image, 56 | "cond": cond_image, 57 | "target": target_image, 58 | "image_path": image_path, 59 | "cond_path": cond_path, 60 | "target_path": target_path, 61 | } -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # DiffHarmony & DiffHarmony++ 2 | 3 | [DiffHarmony](https://arxiv.org/abs/2404.06139) 和 [DiffHarmony++](https://dl.acm.org/doi/10.1145/3664647.3681466) 的官方 Pytorch 实现。 4 | 5 | 论文DiffHarmony完整的会议海报在[这里](./assets/poster.pdf)。 6 | 7 | ## 准备 8 | 9 | ### 环境 10 | 11 | 首先,准备一个虚拟环境。你可以使用 **conda** 或其他你喜欢的工具。 12 | ```shell 13 | python 3.10 14 | pytorch 2.2.0 15 | cuda 12.1 16 | xformers 0.0.24 17 | ``` 18 | 19 | 然后,安装所需的依赖。 20 | ```shell 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## 数据集 25 | 26 | 从[这里](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4)下载 iHarmony4 数据集。 27 | 28 | 确保目录结构如下: 29 | ```shell 30 | data/iHarmony4 31 | |- HCOCO 32 | |- composite_images 33 | |- masks 34 | |- real_images 35 | |- ... 36 | |- HAdobe5k 37 | |- HFlickr 38 | |- Hday2night 39 | |- train.jsonl 40 | |- test.jsonl 41 | ``` 42 | 43 | `train.jsonl` 文件内容格式如下: 44 | ```json 45 | {"file_name": "HAdobe5k/composite_images/a0001_1_1.jpg", "text": ""} 46 | {"file_name": "HAdobe5k/composite_images/a0001_1_2.jpg", "text": ""} 47 | {"file_name": "HAdobe5k/composite_images/a0001_1_3.jpg", "text": ""} 48 | {"file_name": "HAdobe5k/composite_images/a0001_1_4.jpg", "text": ""} 49 | ... 
50 | ``` 51 | 所有 `file_name` 来自原始的 `IHD_train.txt`。`test.jsonl` 和 `IHD_test.txt` 也是如此。 52 | 53 | ## 训练 54 | ### 训练 diffharmony 模型 55 | ```shell 56 | sh scripts/train_diffharmony.sh 57 | ``` 58 | 59 | ### 训练 refinement 模型 60 | ```shell 61 | sh scripts/train_refinement_stage.sh 62 | ``` 63 | 64 | ### 训练 condition vae(cvae) 65 | ```shell 66 | sh scripts/train_cvae.sh 67 | ``` 68 | 69 | ### 训练 diffharmony-gen 和 cvae-gen 70 | 在你的训练参数中添加以下内容: 71 | ```shell 72 | $script 73 | ... 74 | --mode "inverse" 75 | ``` 76 | 基本上,它会使用真实图像作为条件,而不是合成图像。 77 | 78 | ### (可选)在线训练 condition vae 79 | 参考 `scripts/train/cvae_online.py` 80 | 81 | ### (可选)使用生成的数据训练 cvae 82 | 参考 `scripts/train/cvae_with_gen_data.py` 83 | 84 | 目的是进一步提高 cvae 在特定领域(即我们生成的数据集)的性能。 85 | 86 | ## 推理 87 | 推理 iHarmony4 数据集 88 | ```shell 89 | sh scripts/inference.sh 90 | ``` 91 | 92 | ### 使用 diffharmony-gen 和 cvae-gen 增强 HFlickr 和 Hday2night 93 | ```shell 94 | sh scripts/inference_generate_data.sh 95 | ``` 96 | `all_mask_metadata.jsonl` 文件内容格式如下: 97 | ```json 98 | {"file_name": "masks/f800_1.png", "text": ""} 99 | {"file_name": "masks/f801_1.png", "text": ""} 100 | {"file_name": "masks/f803_1.png", "text": ""} 101 | {"file_name": "masks/f804_1.png", "text": ""} 102 | ... 103 | ``` 104 | 105 | ### 制作 HumanHarmony 数据集 106 | 首先,生成一些候选的合成图像。 107 | 108 | 然后,使用和谐分类器选择最不和谐的图像。 109 | ```shell 110 | python scripts/misc/classify_cand_gen_data.py 111 | ``` 112 | 113 | ## 评估 114 | ```shell 115 | sh scripts/evaluate.sh 116 | ``` 117 | 118 | ## 预训练模型 119 | [Baidu](https://pan.baidu.com/s/1IkF6YP4C3fsEAi0_9eCESg), 提取码: aqqd 120 | 121 | [Google Drive](https://drive.google.com/file/d/1rezNdcuZbwejbC9rH9S1SUuaWTGTz_wG/view?usp=drive_link) 122 | 123 | ## 引用 124 | 如果你觉得这个工作有用,请考虑引用: 125 | ```bibtex 126 | @inproceedings{zhou2024diffharmony, 127 | title={DiffHarmony: Latent Diffusion Model Meets Image Harmonization}, 128 | author={Zhou, Pengfei and Feng, Fangxiang and Wang, Xiaojie}, 129 | booktitle={Proceedings of the 2024 International Conference on Multimedia Retrieval}, 130 | pages={1130--1134}, 131 | year={2024} 132 | } 133 | @inproceedings{zhou2024diffharmonypp, 134 | title={DiffHarmony++: Enhancing Image Harmonization with Harmony-VAE and Inverse Harmonization Model}, 135 | author={Zhou, Pengfei and Feng, Fangxiang and Liu, Guang and Li, Ruifan and Wang, Xiaojie}, 136 | booktitle={ACM MM}, 137 | year={2024} 138 | } 139 | ``` 140 | 141 | ## 联系 142 | 如果你有任何问题,请随时通过 `zhoupengfei@bupt.edu.cn` 联系我。 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [中文](./README_zh.md) 2 | 3 | # DiffHarmony & DiffHarmony++ 4 | 5 | The official pytorch implementation of [DiffHarmony](https://arxiv.org/abs/2404.06139) and [DiffHarmony++](https://dl.acm.org/doi/10.1145/3664647.3681466). 6 | 7 | Full Conference Poster of DiffHarmony is [here](./assets/poster.pdf). 8 | 9 | ## Preparation 10 | 11 | ### enviroment 12 | 13 | First, prepare a virtual env. You can use **conda** or anything you like. 14 | ```shell 15 | python 3.10 16 | pytorch 2.2.0 17 | cuda 12.1 18 | xformers 0.0.24 19 | ``` 20 | 21 | Then, install requirements. 22 | ```shell 23 | pip install -r requirements.txt 24 | ``` 25 | ## dataset 26 | 27 | Download iHarmony4 dataset from [here](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4). 
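The scripts in this repo read `train.jsonl` / `test.jsonl` rather than the original split lists; the expected directory layout and JSONL format are shown below. A minimal conversion sketch, assuming (hypothetically) that `IHD_train.txt` / `IHD_test.txt` sit at the dataset root and that each line is a composite-image path relative to that root — adjust the paths to wherever your split files actually live:

```python
import json
import os

data_root = "data/iHarmony4"

# Wrap each composite-image path from the original split lists into one JSON record per line.
for src, dst in [("IHD_train.txt", "train.jsonl"), ("IHD_test.txt", "test.jsonl")]:
    with open(os.path.join(data_root, src)) as fin, \
         open(os.path.join(data_root, dst), "w") as fout:
        for line in fin:
            name = line.strip()
            if name:
                fout.write(json.dumps({"file_name": name, "text": ""}) + "\n")
```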
28 | 29 | Make sure the structure is just like that: 30 | ```shell 31 | data/iHarmony4 32 | |- HCOCO 33 | |- composite_images 34 | |- masks 35 | |- real_images 36 | |- ... 37 | |- HAdobe5k 38 | |- HFlickr 39 | |- Hday2night 40 | |- train.jsonl 41 | |- test.jsonl 42 | ``` 43 | 44 | The content in `train.jsonl` fit the following format 45 | ```json 46 | {"file_name": "HAdobe5k/composite_images/a0001_1_1.jpg", "text": ""} 47 | {"file_name": "HAdobe5k/composite_images/a0001_1_2.jpg", "text": ""} 48 | {"file_name": "HAdobe5k/composite_images/a0001_1_3.jpg", "text": ""} 49 | {"file_name": "HAdobe5k/composite_images/a0001_1_4.jpg", "text": ""} 50 | ... 51 | ``` 52 | All `file_name` are from the original `IHD_train.txt`. Same way with `test.jsonl` and `IHD_test.txt`. 53 | 54 | 55 | ## Training 56 | ### Train diffharmony model 57 | ```shell 58 | sh scripts/train_diffharmony.sh 59 | ``` 60 | 61 | ### Train refinement model 62 | ```shell 63 | sh scripts/train_refinement_stage.sh 64 | ``` 65 | 66 | ### Train condition vae (cvae) 67 | ```shell 68 | sh scripts/train_cvae.sh 69 | ``` 70 | 71 | ### Train diffharmony-gen and cvae-gen 72 | Just add this in your training args: 73 | ```shell 74 | $script 75 | ... 76 | --mode "inverse" 77 | ``` 78 | Basically it will use ground truth images as condition instead of composite images. 79 | 80 | 81 | ### (optional) online training of condition vae 82 | refer to `scripts/train/cvae_online.py` 83 | 84 | ### (optional) train cvae with generated data 85 | refer to `scripts/train/cvae_with_gen_data.py` 86 | 87 | Purpose here is trying to improve cvae performance further on specific domain, i.e. our generated dataset. 88 | 89 | ## Inference 90 | Inference iHarmony4 dataset 91 | ```shell 92 | sh scripts/inference.sh 93 | ``` 94 | 95 | ### use diffharmony-gen and cvae-gen to augment HFlickr and Hday2night 96 | ```shell 97 | sh scripts/inference_generate_data.sh 98 | ``` 99 | The `all_mask_metadata.jsonl` file as its name fits following format: 100 | ```json 101 | {"file_name": "masks/f800_1.png", "text": ""} 102 | {"file_name": "masks/f801_1.png", "text": ""} 103 | {"file_name": "masks/f803_1.png", "text": ""} 104 | {"file_name": "masks/f804_1.png", "text": ""} 105 | ... 106 | ``` 107 | 108 | ### Make HumanHarmony dataset 109 | First, generate some candidate composite images. 110 | 111 | Then, use harmony classifier to select the most unharmonious images. 
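Concretely, `scripts/misc/classify_cand_gen_data.py` loads a ResNet-50 with a 2-way head (`checkpoints/harmony_classifier/model_state_dict.pth`, class index 1 = unharmonious), scores every candidate composite of each real image, and copies the candidate with the highest unharmonious probability into `composite_images`. A minimal single-image sketch of that selection step:

```python
import torch
from torchvision import models

# Harmony classifier as built in scripts/misc/classify_cand_gen_data.py.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("checkpoints/harmony_classifier/model_state_dict.pth"))
model.eval()

@torch.inference_mode()
def pick_most_unharmonious(candidates: torch.Tensor) -> int:
    """candidates: (t, 3, 256, 256) ImageNet-normalized candidate composites of one real image."""
    probs = torch.softmax(model(candidates), dim=-1)[:, 1]  # unharmonious probability per candidate
    return int(probs.argmax())  # index of the candidate to keep

```

The full batched, multi-process version is what the command below runs: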
112 | ```shell 113 | python scripts/misc/classify_cand_gen_data.py 114 | ``` 115 | 116 | ## Evaluation 117 | ```shell 118 | sh scripts/evaluate.sh 119 | ``` 120 | ## Pretrained Models 121 | [Baidu](https://pan.baidu.com/s/1IkF6YP4C3fsEAi0_9eCESg), code: aqqd 122 | 123 | [Google Drive](https://drive.google.com/file/d/1rezNdcuZbwejbC9rH9S1SUuaWTGTz_wG/view?usp=drive_link) 124 | 125 | ## Citation 126 | If you find this work useful, please consider citing: 127 | ```bibtex 128 | @inproceedings{zhou2024diffharmony, 129 | title={DiffHarmony: Latent Diffusion Model Meets Image Harmonization}, 130 | author={Zhou, Pengfei and Feng, Fangxiang and Wang, Xiaojie}, 131 | booktitle={Proceedings of the 2024 International Conference on Multimedia Retrieval}, 132 | pages={1130--1134}, 133 | year={2024} 134 | } 135 | @inproceedings{zhou2024diffharmonypp, 136 | title={DiffHarmony++: Enhancing Image Harmonization with Harmony-VAE and Inverse Harmonization Model}, 137 | author={Zhou, Pengfei and Feng, Fangxiang and Liu, Guang and Li, Ruifan and Wang, Xiaojie}, 138 | booktitle={ACM MM}, 139 | year={2024} 140 | } 141 | ``` 142 | ## Contact 143 | If you have any questions, please feel free to contact me via `zhoupengfei@bupt.edu.cn` . 144 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | output 5 | .vscode 6 | 7 | ### Python ### 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ 168 | 169 | ### Python Patch ### 170 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 171 | poetry.toml 172 | 173 | # ruff 174 | .ruff_cache/ 175 | 176 | # LSP config files 177 | pyrightconfig.json 178 | 179 | # End of https://www.toptal.com/developers/gitignore/api/python 180 | 181 | -------------------------------------------------------------------------------- /scripts/misc/classify_cand_gen_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import os 4 | from torchvision import models, transforms 5 | from torchvision.transforms import InterpolationMode 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import glob 9 | from collections import defaultdict 10 | from accelerate import Accelerator 11 | from einops import rearrange 12 | from accelerate.utils import set_seed 13 | 14 | 15 | class CustomDataset(Dataset): 16 | def __init__(self, dataset_root): 17 | cand_composite_paths = sorted(glob.glob( 18 | os.path.join(dataset_root, "cand_composite_images", "*.jpg") 19 | )) 20 | data = defaultdict(list) 21 | for cand_composite_path in tqdm(cand_composite_paths): 22 | image_name = os.path.basename(cand_composite_path).split("_")[0] 23 | data[image_name].append(cand_composite_path) 24 | self.data = [ 25 | { 26 | "real_path": os.path.join(dataset_root, "real_images", k + ".jpg"), 27 | "cand_composite_paths": v, 28 | } 29 | for k, v in data.items() 30 | ] 31 | 32 | self.transform = transforms.Compose( 33 | [ 34 | transforms.Resize( 35 | (256, 256), interpolation=InterpolationMode.BILINEAR, antialias=True 36 | ), 37 | transforms.ToTensor(), 38 | transforms.Normalize( 39 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 40 | ), 41 | ] 42 | ) 43 | 44 | def __len__(self): 45 | return len(self.data) 46 | 47 | def __getitem__(self, idx): 48 | item = self.data[idx] 49 | cand_composite_images = [ 50 | self.transform(Image.open(cand_composite_path).convert("RGB")) 51 | for cand_composite_path in item["cand_composite_paths"] 52 | ] 53 | cand_composite_images = torch.stack( 54 | cand_composite_images, dim=0 55 | ) # (t, c, h, w) 56 | return { 57 | "cand_composite_images": cand_composite_images, 58 | "cand_composite_paths": item["cand_composite_paths"], 59 | } 60 | 61 | def collate_fn(batch): 62 | cand_composite_images = [item["cand_composite_images"] for item in batch] 63 | cand_composite_paths = [item["cand_composite_paths"] for item in batch] 64 | return { 65 | "cand_composite_images": torch.stack(cand_composite_images, dim=0), # (b, t, c, h, w) 66 | "cand_composite_paths": cand_composite_paths, 67 | } 68 | 69 | 70 | if __name__ == "__main__": 71 | seed = 0 72 | dataset_root = "" 73 | model_path = "checkpoints/harmony_classifier/model_state_dict.pth" 74 | batch_size = 16 75 | 76 | set_seed(seed) 77 | accelerator = Accelerator() 78 | model = models.resnet50() 79 | num_ftrs = model.fc.in_features 80 | model.fc = torch.nn.Linear(num_ftrs, 2) 81 | model.load_state_dict(torch.load(model_path)) 82 | model.eval() 83 | model.to(accelerator.device) 84 | 85 | dataset = CustomDataset(dataset_root) 86 | dataloader = DataLoader( 87 | dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn 88 | ) 89 | dataloader = accelerator.prepare(dataloader) 90 | 91 | progress_bar = tqdm( 92 | range(0, len(dataloader)), 93 | initial=0, 94 | desc="Batches", 95 | # Only show the progress bar once on each machine. 
96 | disable=not accelerator.is_local_main_process, 97 | ) 98 | for step, batch in enumerate(dataloader): 99 | imgs = batch["cand_composite_images"] 100 | paths = batch["cand_composite_paths"] 101 | bs = imgs.size(0) 102 | imgs = ( 103 | rearrange(imgs, "b t c h w -> (b t) c h w") 104 | .contiguous() 105 | .to(accelerator.device) 106 | ) 107 | with torch.inference_mode(): 108 | logits = model(imgs) # (b*t, 2) 109 | probabilities = torch.nn.functional.softmax(logits, dim=-1) 110 | unharmony_probs = rearrange( 111 | probabilities[:, 1], "(b t) -> b t", b=bs 112 | ).contiguous() 113 | _, max_index = torch.max(unharmony_probs, dim=1) # (b,) 114 | # 复制 max index 指定的图片 115 | for i, (p, idx) in enumerate(zip(paths, max_index)): 116 | os.system( 117 | f"cp {p[idx.item()]} {os.path.join(dataset_root, 'composite_images')}" 118 | ) 119 | progress_bar.close() 120 | accelerator.wait_for_everyone() 121 | -------------------------------------------------------------------------------- /scripts/inference/inverse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from accelerate import Accelerator 8 | from accelerate.utils import set_seed 9 | from torchvision.transforms.functional import to_pil_image, resize, to_tensor 10 | from tqdm.auto import tqdm 11 | 12 | from diffusers import ( 13 | AutoencoderKL, 14 | UNet2DConditionModel, 15 | EulerAncestralDiscreteScheduler, 16 | ) 17 | from src.pipelines.pipeline_stable_diffusion_harmony import ( 18 | StableDiffusionHarmonyPipeline, 19 | ) 20 | from src.dataset.ihd_dataset import IhdDatasetMultiRes as Dataset 21 | from src.models.condition_vae import ConditionVAE 22 | 23 | def parse_args(input_args=None): 24 | parser = argparse.ArgumentParser( 25 | description="Simple example of a inference script." 26 | ) 27 | parser.add_argument( 28 | "--pretrained_model_name_or_path", 29 | type=str, 30 | default=None, 31 | required=True, 32 | help="Path to pretrained model or model identifier from huggingface.co/models.", 33 | ) 34 | parser.add_argument( 35 | "--pretrained_vae_model_name_or_path", 36 | type=str, 37 | default=None, 38 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", 39 | ) 40 | parser.add_argument( 41 | "--pretrained_unet_model_name_or_path", 42 | type=str, 43 | default=None, 44 | ) 45 | parser.add_argument( 46 | "--stage2_model_name_or_path", 47 | type=str, 48 | default=None, 49 | ) 50 | parser.add_argument( 51 | "--dataset_root", 52 | type=str, 53 | default=None, 54 | help=( 55 | "A folder containing the training data. Folder contents must follow the structure described in" 56 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 57 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 58 | ), 59 | ) 60 | parser.add_argument( 61 | "--train_file", 62 | type=str, 63 | default=None, 64 | ) 65 | parser.add_argument( 66 | "--test_file", 67 | type=str, 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--output_dir", 72 | type=str, 73 | required=True, 74 | ) 75 | parser.add_argument( 76 | "--seed", type=int, default=42, help="A seed for reproducible training." 
77 | ) 78 | parser.add_argument( 79 | "--resolution", 80 | type=int, 81 | default=1024, 82 | ) 83 | parser.add_argument( 84 | "--output_resolution", 85 | type=int, 86 | default=256, 87 | ) 88 | parser.add_argument( 89 | "--random_crop", 90 | default=False, 91 | action="store_true", 92 | ) 93 | parser.add_argument( 94 | "--random_flip", 95 | default=False, 96 | action="store_true", 97 | ) 98 | parser.add_argument( 99 | "--mask_dilate", 100 | type=int, 101 | default=0, 102 | ) 103 | parser.add_argument( 104 | "--eval_batch_size", 105 | type=int, 106 | default=4, 107 | help="The number of images to generate for evaluation.", 108 | ) 109 | parser.add_argument( 110 | "--dataloader_num_workers", 111 | type=int, 112 | default=0, 113 | help=( 114 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 115 | ), 116 | ) 117 | parser.add_argument( 118 | "--mixed_precision", 119 | type=str, 120 | default=None, 121 | choices=["no", "fp16", "bf16"], 122 | help=( 123 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 124 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 125 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 126 | ), 127 | ) 128 | parser.add_argument( 129 | "--rounds", 130 | type=int, 131 | default=1, 132 | ) 133 | if input_args is not None: 134 | args = parser.parse_args(input_args) 135 | else: 136 | args = parser.parse_args() 137 | return args 138 | 139 | def list_in_str(input_list, target_str): 140 | for item in input_list: 141 | if item in target_str: 142 | return True 143 | return False 144 | 145 | def replace_background(fake_image:torch.Tensor, real_image:torch.Tensor, mask:torch.Tensor): 146 | real_image = real_image.to(fake_image) 147 | mask = mask.to(fake_image) 148 | 149 | fake_image = fake_image * mask + real_image * (1 - mask) 150 | fake_image = to_pil_image(fake_image) 151 | return fake_image 152 | 153 | 154 | def main(args): 155 | accelerator = Accelerator( 156 | mixed_precision=args.mixed_precision, 157 | ) 158 | if args.seed is not None: 159 | set_seed(args.seed) 160 | if accelerator.is_main_process: 161 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 162 | weight_dtype = torch.float32 163 | if accelerator.mixed_precision == "fp16": 164 | weight_dtype = torch.float16 165 | elif accelerator.mixed_precision == "bf16": 166 | weight_dtype = torch.bfloat16 167 | 168 | if list_in_str(["condition_vae", "cvae"], args.pretrained_vae_model_name_or_path): 169 | vae_cls = ConditionVAE 170 | else: 171 | vae_cls = AutoencoderKL 172 | 173 | vae = vae_cls.from_pretrained( 174 | args.pretrained_vae_model_name_or_path, 175 | torch_dtype=weight_dtype, 176 | ) 177 | unet = UNet2DConditionModel.from_pretrained( 178 | args.pretrained_unet_model_name_or_path, 179 | torch_dtype=weight_dtype, 180 | ) 181 | pipeline = StableDiffusionHarmonyPipeline.from_pretrained( 182 | args.pretrained_model_name_or_path, 183 | vae=vae, 184 | unet=unet, 185 | torch_dtype=weight_dtype, 186 | ) 187 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( 188 | pipeline.scheduler.config 189 | ) 190 | pipeline.to(accelerator.device) 191 | pipeline.enable_xformers_memory_efficient_attention() 192 | pipeline.set_progress_bar_config(disable=True) 193 | 194 | dataset = Dataset( 195 | split="test", 196 | tokenizer=None, 197 | resolutions=[args.resolution, args.output_resolution], 198 | 
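        # IhdDatasetMultiRes yields one entry per requested resolution; batches below are indexed as batch[resolution][key].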
opt=args, 199 | ) 200 | dataloader = torch.utils.data.DataLoader( 201 | dataset, 202 | batch_size=args.eval_batch_size, 203 | shuffle=True, 204 | num_workers=args.dataloader_num_workers, 205 | ) 206 | 207 | dataloader = accelerator.prepare(dataloader) 208 | 209 | for num_round in range(1, args.rounds+1): 210 | progress_bar = tqdm( 211 | range(0, len(dataloader)), 212 | initial=0, 213 | desc="Batches", 214 | # Only show the progress bar once on each machine. 215 | disable=not accelerator.is_local_main_process, 216 | ) 217 | for step, batch in enumerate(dataloader): 218 | eval_mask_images = batch[args.resolution]["mask"] 219 | eval_gt_images = batch[args.resolution]["real"] 220 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 221 | with torch.inference_mode(): 222 | samples = pipeline( 223 | prompt=batch[args.resolution]["caption"], 224 | image=eval_gt_images, 225 | mask_image=eval_mask_images, 226 | height=args.resolution, 227 | width=args.resolution, 228 | num_inference_steps=5, 229 | guidance_scale=1.0, 230 | generator=generator, 231 | output_type="pt", 232 | ).images # [0,1] torch tensor 233 | 234 | samples = samples.clamp(0, 1) 235 | 236 | for i, sample in enumerate(samples): 237 | sample = to_pil_image(sample) 238 | output_shape = (args.output_resolution, args.output_resolution) 239 | if sample.size != output_shape: 240 | sample = resize(sample, output_shape, antialias=True) 241 | 242 | sample = replace_background(to_tensor(sample), batch[args.output_resolution]["real"][i].add(1).div(2), batch[args.output_resolution]["mask"][i]) 243 | 244 | save_name = ( 245 | batch[args.output_resolution]["mask_path"][i] 246 | .split("/")[-1] 247 | .split(".")[0] 248 | + f"_{num_round}.jpg" 249 | ) 250 | sample.save( 251 | os.path.join(args.output_dir, save_name), quality=100 252 | ) 253 | progress_bar.update(1) 254 | progress_bar.close() 255 | accelerator.wait_for_everyone() 256 | 257 | 258 | if __name__ == "__main__": 259 | args = parse_args() 260 | main(args) 261 | -------------------------------------------------------------------------------- /scripts/inference/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from accelerate import Accelerator 8 | from accelerate.utils import set_seed 9 | from torchvision.transforms.functional import to_pil_image, resize 10 | from tqdm.auto import tqdm 11 | 12 | from diffusers import ( 13 | AutoencoderKL, 14 | UNet2DConditionModel, 15 | EulerAncestralDiscreteScheduler, 16 | ) 17 | from src.pipelines.pipeline_stable_diffusion_harmony import ( 18 | StableDiffusionHarmonyPipeline, 19 | ) 20 | from src.dataset.ihd_dataset import IhdDatasetMultiRes as Dataset 21 | from src.models.condition_vae import ConditionVAE 22 | from src.models.unet_2d import UNet2DCustom 23 | from src.utils import make_stage2_input 24 | 25 | 26 | def parse_args(input_args=None): 27 | parser = argparse.ArgumentParser( 28 | description="Simple example of a inference script." 29 | ) 30 | parser.add_argument( 31 | "--pretrained_model_name_or_path", 32 | type=str, 33 | default=None, 34 | required=True, 35 | help="Path to pretrained model or model identifier from huggingface.co/models.", 36 | ) 37 | parser.add_argument( 38 | "--pretrained_vae_model_name_or_path", 39 | type=str, 40 | default=None, 41 | help="Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.", 42 | ) 43 | parser.add_argument( 44 | "--pretrained_unet_model_name_or_path", 45 | type=str, 46 | default=None, 47 | ) 48 | parser.add_argument( 49 | "--stage2_model_name_or_path", 50 | type=str, 51 | default=None, 52 | ) 53 | parser.add_argument( 54 | "--dataset_root", 55 | type=str, 56 | default=None, 57 | help=( 58 | "A folder containing the training data. Folder contents must follow the structure described in" 59 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 60 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 61 | ), 62 | ) 63 | parser.add_argument( 64 | "--train_file", 65 | type=str, 66 | default=None, 67 | ) 68 | parser.add_argument( 69 | "--test_file", 70 | type=str, 71 | default=None, 72 | ) 73 | parser.add_argument( 74 | "--output_dir", 75 | type=str, 76 | required=True, 77 | ) 78 | parser.add_argument( 79 | "--seed", type=int, default=42, help="A seed for reproducible training." 80 | ) 81 | parser.add_argument( 82 | "--resolution", 83 | type=int, 84 | default=1024, 85 | ) 86 | parser.add_argument( 87 | "--output_resolution", 88 | type=int, 89 | default=256, 90 | ) 91 | parser.add_argument( 92 | "--random_crop", 93 | default=False, 94 | action="store_true", 95 | ) 96 | parser.add_argument( 97 | "--random_flip", 98 | default=False, 99 | action="store_true", 100 | ) 101 | parser.add_argument( 102 | "--mask_dilate", 103 | type=int, 104 | default=0, 105 | ) 106 | parser.add_argument( 107 | "--eval_batch_size", 108 | type=int, 109 | default=4, 110 | help="The number of images to generate for evaluation.", 111 | ) 112 | parser.add_argument( 113 | "--dataloader_num_workers", 114 | type=int, 115 | default=0, 116 | help=( 117 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 118 | ), 119 | ) 120 | parser.add_argument( 121 | "--mixed_precision", 122 | type=str, 123 | default=None, 124 | choices=["no", "fp16", "bf16"], 125 | help=( 126 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 127 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 128 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
129 | ), 130 | ) 131 | parser.add_argument( 132 | "--strict_mode", 133 | default=False, 134 | action="store_true", 135 | ) 136 | if input_args is not None: 137 | args = parser.parse_args(input_args) 138 | else: 139 | args = parser.parse_args() 140 | return args 141 | 142 | def list_in_str(input_list, target_str): 143 | for item in input_list: 144 | if item in target_str: 145 | return True 146 | return False 147 | 148 | def main(args): 149 | accelerator = Accelerator( 150 | mixed_precision=args.mixed_precision, 151 | ) 152 | if args.seed is not None: 153 | set_seed(args.seed) 154 | if accelerator.is_main_process: 155 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 156 | weight_dtype = torch.float32 157 | if accelerator.mixed_precision == "fp16": 158 | weight_dtype = torch.float16 159 | elif accelerator.mixed_precision == "bf16": 160 | weight_dtype = torch.bfloat16 161 | 162 | if list_in_str(["condition_vae", "cvae"], args.pretrained_vae_model_name_or_path): 163 | vae_cls = ConditionVAE 164 | else: 165 | vae_cls = AutoencoderKL 166 | 167 | vae = vae_cls.from_pretrained( 168 | args.pretrained_vae_model_name_or_path, 169 | torch_dtype=weight_dtype, 170 | ) 171 | unet = UNet2DConditionModel.from_pretrained( 172 | args.pretrained_unet_model_name_or_path, 173 | torch_dtype=weight_dtype, 174 | ) 175 | pipeline = StableDiffusionHarmonyPipeline.from_pretrained( 176 | args.pretrained_model_name_or_path, 177 | vae=vae, 178 | unet=unet, 179 | torch_dtype=weight_dtype, 180 | ) 181 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( 182 | pipeline.scheduler.config 183 | ) 184 | pipeline.to(accelerator.device) 185 | # pipeline.enable_model_cpu_offload(device=accelerator.device) 186 | pipeline.enable_xformers_memory_efficient_attention() 187 | pipeline.set_progress_bar_config(disable=True) 188 | 189 | use_stage2 = args.stage2_model_name_or_path is not None 190 | if use_stage2: 191 | stage2_model = UNet2DCustom.from_pretrained( 192 | args.stage2_model_name_or_path, 193 | torch_dtype=weight_dtype, 194 | ) 195 | stage2_model.to(accelerator.device) 196 | stage2_model.eval() 197 | stage2_model.requires_grad_(False) 198 | in_channels = stage2_model.config.in_channels 199 | stage2_model.enable_xformers_memory_efficient_attention() 200 | 201 | dataset = Dataset( 202 | split="test", 203 | tokenizer=None, 204 | resolutions=[args.resolution, args.output_resolution], 205 | opt=args, 206 | ) 207 | dataloader = torch.utils.data.DataLoader( 208 | dataset, 209 | batch_size=args.eval_batch_size, 210 | shuffle=True, 211 | num_workers=args.dataloader_num_workers, 212 | ) 213 | 214 | dataloader = accelerator.prepare(dataloader) 215 | progress_bar = tqdm( 216 | range(0, len(dataloader)), 217 | initial=0, 218 | desc="Batches", 219 | # Only show the progress bar once on each machine. 
220 | disable=not accelerator.is_local_main_process, 221 | ) 222 | for step, batch in enumerate(dataloader): 223 | if args.strict_mode: 224 | eval_mask_images = batch[args.output_resolution]["mask"] 225 | eval_composite_images = batch[args.output_resolution]["comp"] 226 | if args.output_resolution != args.resolution: 227 | tgt_size = [args.resolution, args.resolution] 228 | eval_mask_images = resize( 229 | eval_mask_images, 230 | size=tgt_size, 231 | antialias=True, 232 | ).clamp(0, 1) 233 | eval_composite_images = resize( 234 | eval_composite_images, 235 | size=tgt_size, 236 | antialias=True, 237 | ).clamp(-1, 1) 238 | else: 239 | eval_mask_images = batch[args.resolution]["mask"] 240 | eval_composite_images = batch[args.resolution]["comp"] 241 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 242 | with torch.inference_mode(): 243 | samples = pipeline( 244 | prompt=batch[args.resolution]["caption"], 245 | image=eval_composite_images, 246 | mask_image=eval_mask_images, 247 | height=args.resolution, 248 | width=args.resolution, 249 | num_inference_steps=5, 250 | guidance_scale=1.0, 251 | generator=generator, 252 | output_type="pt", 253 | ).images # [0,1] torch tensor 254 | 255 | if use_stage2: 256 | samples = 2 * samples - 1 # [-1,1] torch tensor 257 | if tuple(samples.shape[-2:]) != ( 258 | args.output_resolution, 259 | args.output_resolution, 260 | ): 261 | samples = resize( 262 | samples, 263 | size=[args.output_resolution, args.output_resolution], 264 | antialias=True, 265 | ).clamp(-1, 1) 266 | stage2_input = make_stage2_input( 267 | samples, 268 | batch[args.output_resolution]["comp"], 269 | batch[args.output_resolution]["mask"], 270 | in_channels, 271 | ) 272 | samples = ( 273 | stage2_model( 274 | stage2_input.to(device=accelerator.device, dtype=stage2_model.dtype) 275 | ) 276 | .sample.cpu() 277 | .clamp(-1, 1) 278 | ) 279 | samples = (samples + 1) / 2 280 | 281 | samples = samples.clamp(0, 1) 282 | 283 | for i, sample in enumerate(samples): 284 | sample = to_pil_image(sample) 285 | output_shape = (args.output_resolution, args.output_resolution) 286 | if sample.size != output_shape: 287 | sample = resize(sample, output_shape, antialias=True) 288 | save_name = ( 289 | batch[args.output_resolution]["comp_path"][i] 290 | .split("/")[-1] 291 | .split(".")[0] 292 | + ".png" 293 | ) 294 | sample.save( 295 | os.path.join(args.output_dir, save_name), compression=None, quality=100 296 | ) 297 | progress_bar.update(1) 298 | progress_bar.close() 299 | accelerator.wait_for_everyone() 300 | 301 | 302 | if __name__ == "__main__": 303 | args = parse_args() 304 | main(args) 305 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | def tensor_to_pil(tensor: torch.Tensor, mode="RGB"): 5 | if mode == "RGB": 6 | image_np = ( 7 | tensor.permute(1, 2, 0).add(1).multiply(127.5).numpy().astype(np.uint8) 8 | ) 9 | image = Image.fromarray(image_np, mode=mode) 10 | elif mode == "1": 11 | image_np = tensor.squeeze().multiply(255).numpy().astype(np.uint8) 12 | image = Image.fromarray(image_np).convert("1") 13 | else: 14 | raise ValueError(f"not supported mode {mode}") 15 | return image 16 | 17 | import os 18 | def get_paths(comp_path): 19 | # .../ds_name/composite_images/xxx_x_x.jpg 20 | parts = comp_path.split("/") 21 | img_name_parts = parts[-1].split("_") 22 | real_path = 
os.path.join(*parts[:-2], "real_images", f"{img_name_parts[0]}.jpg") 23 | mask_path = os.path.join( 24 | *parts[:-2], "masks", f"{img_name_parts[0]}_{img_name_parts[1]}.png" 25 | ) 26 | 27 | return { 28 | "real": real_path, 29 | "mask": mask_path, 30 | "comp": comp_path, 31 | } 32 | 33 | comp_to_paths = get_paths 34 | 35 | 36 | def harm_to_names(harm_name): 37 | # xxx_x_x_harmonized.jpg 38 | img_name_parts = harm_name.split("_") 39 | real_name = f"{img_name_parts[0]}.jpg" 40 | mask_name = f"{img_name_parts[0]}_{img_name_parts[1]}.png" 41 | comp_name = f"{img_name_parts[0]}_{img_name_parts[1]}_{img_name_parts[2]}.jpg" 42 | 43 | return { 44 | "real": real_name, 45 | "mask": mask_name, 46 | "comp": comp_name, 47 | "harm": harm_name, 48 | } 49 | 50 | 51 | def comp_to_harm_path(output_dir, comp_path, suffix="jpg"): 52 | img_name = comp_path.split("/")[-1] 53 | prefix, _ = img_name.split(".") 54 | out_path = os.path.join(output_dir, f"{prefix}_harmonized.{suffix}") 55 | return out_path 56 | 57 | 58 | import torch 59 | import torch.nn.functional as F 60 | 61 | 62 | def get_mask_shift_with_corresponding_center( 63 | masks: torch.Tensor, cand_masks: torch.Tensor 64 | ): 65 | # masks : (b,c,h,w) 66 | h, w = masks.shape[-2:] 67 | 68 | # mask boundaries 69 | true_rows = torch.any(masks, dim=3) 70 | true_cols = torch.any(masks, dim=2) 71 | true_cand_rows = torch.any(cand_masks, dim=3) 72 | true_cand_cols = torch.any(cand_masks, dim=2) 73 | bs = len(masks) 74 | dy, dx = [], [] 75 | 76 | for i in range(bs): 77 | try: 78 | row_min, row_max = torch.nonzero(true_rows[i], as_tuple=True)[-1][ 79 | [0, -1] 80 | ].tolist() 81 | col_min, col_max = torch.nonzero(true_cols[i], as_tuple=True)[-1][ 82 | [0, -1] 83 | ].tolist() 84 | cand_row_min, cand_row_max = torch.nonzero( 85 | true_cand_rows[i], as_tuple=True 86 | )[-1][[0, -1]].tolist() 87 | cand_col_min, cand_col_max = torch.nonzero( 88 | true_cand_cols[i], as_tuple=True 89 | )[-1][[0, -1]].tolist() 90 | center_y, center_x = (row_min + row_max) / 2, (col_min + col_max) / 2 91 | cand_center_y, cand_center_x = (cand_row_min + cand_row_max) / 2, ( 92 | cand_col_min + cand_col_max 93 | ) / 2 94 | dy_i = ( 95 | torch.tensor([cand_center_y - center_y]) 96 | .float() 97 | .clamp(-row_min, (h - 1) - row_max) 98 | ) 99 | dx_i = ( 100 | torch.tensor([cand_center_x - center_x]) 101 | .float() 102 | .clamp(-col_min, (w - 1) - col_max) 103 | ) 104 | dy.append(dy_i) 105 | dx.append(dx_i) 106 | except: 107 | dy.append(torch.tensor([0], dtype=torch.float32)) 108 | dx.append(torch.tensor([0], dtype=torch.float32)) 109 | dy = torch.cat(dy, dim=0)[..., None, None] 110 | dx = torch.cat(dx, dim=0)[..., None, None] 111 | 112 | shift = torch.stack([dy, dx], dim=0).to(masks.device) 113 | # (2,b,1,1) 114 | return shift 115 | 116 | 117 | def get_random_mask_shift( 118 | masks, 119 | ): 120 | """ """ 121 | # masks : (b,c,h,w) 122 | h, w = masks.shape[-2:] 123 | 124 | # mask boundaries 125 | true_rows = torch.any(masks, dim=3) 126 | true_cols = torch.any(masks, dim=2) 127 | bs = len(masks) 128 | dy, dx = [], [] 129 | for i in range(bs): 130 | try: 131 | row_min, row_max = torch.nonzero(true_rows[i], as_tuple=True)[-1][ 132 | [0, -1] 133 | ].tolist() 134 | col_min, col_max = torch.nonzero(true_cols[i], as_tuple=True)[-1][ 135 | [0, -1] 136 | ].tolist() 137 | dy.append(torch.randint(-row_min, h - row_max, (1,)).float()) 138 | dx.append(torch.randint(-col_min, w - col_max, (1,)).float()) 139 | except: 140 | dy.append(torch.tensor([0], dtype=torch.float32)) 141 | dx.append(torch.tensor([0], 
dtype=torch.float32)) 142 | dy = torch.cat(dy, dim=0)[..., None, None] 143 | dx = torch.cat(dx, dim=0)[..., None, None] 144 | 145 | shift = torch.stack([dy, dx], dim=0).to(masks.device) 146 | # (2,b,1,1) 147 | return shift 148 | 149 | 150 | def shift_grid(shape, shift): 151 | h, w = shape 152 | dy, dx = shift 153 | 154 | # make grid for space transformation 155 | y, x = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij') 156 | y, x = y.to(shift.device)[None, ...], x.to(shift.device)[None, ...] # (1,h,w) 157 | 158 | # xy shift 159 | y = y + 2 * (dy.float() / (h - 1)) 160 | x = x + 2 * (dx.float() / (w - 1)) 161 | 162 | grid = torch.stack((x, y), dim=-1) 163 | return grid 164 | 165 | 166 | def make_stage2_input( 167 | harm: torch.Tensor, comp: torch.Tensor, mask: torch.Tensor, in_channels 168 | ): 169 | if harm.dim()==4: 170 | cat_dim=1 171 | elif harm.dim()==3: 172 | cat_dim=0 173 | else: 174 | raise ValueError(f"image dims should be 3 or 4 but got {harm.dim()}") 175 | if in_channels == 3: 176 | stage2_input = harm 177 | elif in_channels == 4: 178 | stage2_input = torch.cat([harm, mask.to(harm)], dim=cat_dim) 179 | elif in_channels == 7: 180 | stage2_input = torch.cat([harm, mask.to(harm), comp.to(harm)], dim=cat_dim) 181 | else: 182 | raise ValueError( 183 | f"unsupported stage2 input type : got in channels {in_channels}" 184 | ) 185 | return stage2_input 186 | 187 | 188 | import random 189 | import torch.nn.functional as F 190 | from typing import Union 191 | 192 | METHOD_VERSION=("v1", "v2-fgfg", "v2-fgbg") 193 | 194 | def select_cand(real:torch.Tensor, mask:torch.Tensor, method_version:str): 195 | assert method_version in METHOD_VERSION , f"method_version {METHOD_VERSION} not implemented" 196 | if method_version=='v1': # * random selection 197 | cand_indices = [] 198 | bs = len(real) 199 | range_indices = list(range(bs)) 200 | while True: 201 | shuffled_indices = list(range(bs)) 202 | random.shuffle(shuffled_indices) 203 | if any([i == j for i, j in zip(range_indices, shuffled_indices)]): 204 | continue 205 | break 206 | cand_indices = torch.tensor(shuffled_indices, dtype=torch.long) 207 | elif method_version.startswith("v2"): # * select based on lightness difference 208 | fg_images = real * mask 209 | h, w = real.shape[-2:] 210 | EPS = 1e-6 211 | fg_lightness = ( 212 | ((fg_images.max(dim=1).values + fg_images.min(dim=1).values) / 2).mean( 213 | dim=[1, 2] 214 | ) 215 | * (h * w) 216 | / (mask.sum(dim=[1, 2, 3]) + EPS) 217 | ) # (b,c,h,w) -> (b,h,w) -> (b,) ; average lightness of foreground region 218 | if method_version == "v2-fgfg": 219 | lightness_diff = fg_lightness.view(-1, 1) - fg_lightness.view(1, -1) # (b,b) 220 | elif method_version == "v2-fgbg": 221 | bg_images = real * (1 - mask) 222 | bg_lightness = ( 223 | ((bg_images.max(dim=1).values + bg_images.min(dim=1).values) / 2).mean( 224 | dim=[1, 2] 225 | ) 226 | * (h * w) 227 | / ((1 - mask).sum(dim=[1, 2, 3]) + EPS) 228 | ) # (b,c,h,w) -> (b,h,w) -> (b,) ; average lightness of background region 229 | 230 | lightness_diff = fg_lightness.view(-1, 1) - bg_lightness.view(1, -1) # (b,b) 231 | cand_indices = lightness_diff.abs().max(dim=1).indices 232 | return cand_indices 233 | 234 | @torch.inference_mode() 235 | def make_comp( 236 | real: torch.Tensor, 237 | cand: torch.Tensor, 238 | mask: torch.Tensor, 239 | method_version:str, 240 | pipeline, 241 | stage2_model = None, 242 | cand_mask: torch.Tensor = None, 243 | return_dict=False, 244 | ): 245 | assert method_version in METHOD_VERSION , 
f"method_version {method_version} not implemented" 246 | if method_version=='v1': 247 | shift = get_random_mask_shift(mask) 248 | elif method_version.startswith('v2'): 249 | shift = get_mask_shift_with_corresponding_center(mask, cand_mask) 250 | 251 | grid = shift_grid(mask.shape[-2:], -shift) 252 | mask_cand = F.grid_sample( 253 | input=mask.float(), grid=grid, mode="nearest", padding_mode="zeros", align_corners=False 254 | ).to( 255 | mask 256 | ) # (b,c,h,w) 257 | infer_input = ( 258 | mask_cand 259 | * F.grid_sample( 260 | input=real.float(), grid=grid, mode="nearest", padding_mode="zeros", align_corners=False 261 | ).to(real) 262 | + (1 - mask_cand) * cand 263 | ) 264 | 265 | if return_dict: 266 | dict_output={ 267 | "infer_input":infer_input, 268 | } 269 | 270 | bs=len(real) 271 | h, w = infer_input.shape[-2:] 272 | infer_output = pipeline( 273 | prompt=[""] * bs, 274 | image=infer_input, 275 | mask_image=mask_cand, 276 | # generator=generator, 277 | height=h, 278 | width=w, 279 | guidance_scale=1.0, 280 | num_inference_steps=5, 281 | output_type="numpy", 282 | ).images 283 | # output (b,h,w,c) numpy ndarray , [0,1] range value 284 | infer_output = ( 285 | torch.tensor(infer_output).to(real).permute(0, 3, 1, 2).sub(0.5).multiply(2) 286 | ) # convert to [-1,1] (b,c,h,w,) pytorch tensor 287 | 288 | if stage2_model is not None: 289 | stage2_input = make_stage2_input( 290 | infer_output, 291 | infer_input, 292 | mask, 293 | stage2_model.config.in_channels, 294 | ) 295 | infer_output = stage2_model(stage2_input).sample.clamp(-1, 1) 296 | 297 | if return_dict: 298 | dict_output['infer_output']=infer_output 299 | 300 | inverse_grid = shift_grid(mask.shape[-2:], shift) 301 | comp = (1 - mask) * real + mask * F.grid_sample( 302 | input=infer_output.float(), 303 | grid=inverse_grid, 304 | mode="nearest", 305 | padding_mode="zeros", 306 | align_corners=False, 307 | ).to(infer_output) 308 | 309 | if return_dict: 310 | dict_output['comp']=comp 311 | return dict_output 312 | 313 | return comp -------------------------------------------------------------------------------- /src/models/unet_2d_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.utils import logging 7 | from diffusers.models.resnet import ( 8 | Downsample2D, 9 | ResnetBlock2D, 10 | ResnetBlockCondNorm2D, 11 | Upsample2D, 12 | ) 13 | 14 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 15 | 16 | 17 | def get_down_block( 18 | down_block_type: str, 19 | num_layers: int, 20 | in_channels: int, 21 | out_channels: int, 22 | temb_channels: int, 23 | add_downsample: bool, 24 | resnet_eps: float, 25 | resnet_act_fn: str, 26 | transformer_layers_per_block: int = 1, 27 | num_attention_heads: Optional[int] = None, 28 | resnet_groups: Optional[int] = None, 29 | cross_attention_dim: Optional[int] = None, 30 | downsample_padding: Optional[int] = None, 31 | dual_cross_attention: bool = False, 32 | use_linear_projection: bool = False, 33 | only_cross_attention: bool = False, 34 | upcast_attention: bool = False, 35 | resnet_time_scale_shift: str = "default", 36 | attention_type: str = "default", 37 | resnet_skip_time_act: bool = False, 38 | resnet_out_scale_factor: float = 1.0, 39 | cross_attention_norm: Optional[str] = None, 40 | attention_head_dim: Optional[int] = None, 41 | downsample_type: Optional[str] = None, 42 | dropout: float = 0.0, 43 | ): 44 | # If attn 
head dim is not defined, we default it to the number of heads 45 | if attention_head_dim is None: 46 | logger.warn( 47 | f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 48 | ) 49 | attention_head_dim = num_attention_heads 50 | 51 | down_block_type = ( 52 | down_block_type[7:] 53 | if down_block_type.startswith("UNetRes") 54 | else down_block_type 55 | ) 56 | 57 | if down_block_type == "DownEncoderBlock2D": 58 | return DownEncoderSkipBlock2D( 59 | num_layers=num_layers, 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | dropout=dropout, 63 | add_downsample=add_downsample, 64 | resnet_eps=resnet_eps, 65 | resnet_act_fn=resnet_act_fn, 66 | resnet_groups=resnet_groups, 67 | downsample_padding=downsample_padding, 68 | resnet_time_scale_shift=resnet_time_scale_shift, 69 | ) 70 | 71 | raise ValueError(f"{down_block_type} does not exist.") 72 | 73 | 74 | def get_up_block( 75 | up_block_type: str, 76 | num_layers: int, 77 | in_channels: int, 78 | out_channels: int, 79 | prev_output_channel: int, 80 | temb_channels: int, 81 | add_upsample: bool, 82 | resnet_eps: float, 83 | resnet_act_fn: str, 84 | resolution_idx: Optional[int] = None, 85 | transformer_layers_per_block: int = 1, 86 | num_attention_heads: Optional[int] = None, 87 | resnet_groups: Optional[int] = None, 88 | cross_attention_dim: Optional[int] = None, 89 | dual_cross_attention: bool = False, 90 | use_linear_projection: bool = False, 91 | only_cross_attention: bool = False, 92 | upcast_attention: bool = False, 93 | resnet_time_scale_shift: str = "default", 94 | attention_type: str = "default", 95 | resnet_skip_time_act: bool = False, 96 | resnet_out_scale_factor: float = 1.0, 97 | cross_attention_norm: Optional[str] = None, 98 | attention_head_dim: Optional[int] = None, 99 | upsample_type: Optional[str] = None, 100 | dropout: float = 0.0, 101 | ) -> nn.Module: 102 | # If attn head dim is not defined, we default it to the number of heads 103 | if attention_head_dim is None: 104 | logger.warn( 105 | f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
106 | ) 107 | attention_head_dim = num_attention_heads 108 | 109 | up_block_type = ( 110 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 111 | ) 112 | 113 | if up_block_type == "UpDecoderBlock2D": 114 | return UpDecoderSkipBlock2D( 115 | num_layers=num_layers, 116 | in_channels=in_channels, 117 | out_channels=out_channels, 118 | resolution_idx=resolution_idx, 119 | dropout=dropout, 120 | add_upsample=add_upsample, 121 | resnet_eps=resnet_eps, 122 | resnet_act_fn=resnet_act_fn, 123 | resnet_groups=resnet_groups, 124 | resnet_time_scale_shift=resnet_time_scale_shift, 125 | temb_channels=temb_channels, 126 | ) 127 | raise ValueError(f"{up_block_type} does not exist.") 128 | 129 | 130 | class DownEncoderSkipBlock2D(nn.Module): 131 | def __init__( 132 | self, 133 | in_channels: int, 134 | out_channels: int, 135 | dropout: float = 0.0, 136 | num_layers: int = 1, 137 | resnet_eps: float = 1e-6, 138 | resnet_time_scale_shift: str = "default", 139 | resnet_act_fn: str = "swish", 140 | resnet_groups: int = 32, 141 | resnet_pre_norm: bool = True, 142 | output_scale_factor: float = 1.0, 143 | add_downsample: bool = True, 144 | downsample_padding: int = 1, 145 | ): 146 | super().__init__() 147 | resnets = [] 148 | 149 | for i in range(num_layers): 150 | in_channels = in_channels if i == 0 else out_channels 151 | if resnet_time_scale_shift == "spatial": 152 | resnets.append( 153 | ResnetBlockCondNorm2D( 154 | in_channels=in_channels, 155 | out_channels=out_channels, 156 | temb_channels=None, 157 | eps=resnet_eps, 158 | groups=resnet_groups, 159 | dropout=dropout, 160 | time_embedding_norm="spatial", 161 | non_linearity=resnet_act_fn, 162 | output_scale_factor=output_scale_factor, 163 | ) 164 | ) 165 | else: 166 | resnets.append( 167 | ResnetBlock2D( 168 | in_channels=in_channels, 169 | out_channels=out_channels, 170 | temb_channels=None, 171 | eps=resnet_eps, 172 | groups=resnet_groups, 173 | dropout=dropout, 174 | time_embedding_norm=resnet_time_scale_shift, 175 | non_linearity=resnet_act_fn, 176 | output_scale_factor=output_scale_factor, 177 | pre_norm=resnet_pre_norm, 178 | ) 179 | ) 180 | 181 | self.resnets = nn.ModuleList(resnets) 182 | 183 | if add_downsample: 184 | self.downsamplers = nn.ModuleList( 185 | [ 186 | Downsample2D( 187 | out_channels, 188 | use_conv=True, 189 | out_channels=out_channels, 190 | padding=downsample_padding, 191 | name="op", 192 | ) 193 | ] 194 | ) 195 | else: 196 | self.downsamplers = None 197 | 198 | def forward( 199 | self, hidden_states: torch.FloatTensor, scale: float = 1.0 200 | ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: 201 | output_states = () 202 | for resnet in self.resnets: 203 | hidden_states = resnet(hidden_states, temb=None, scale=scale) 204 | output_states = output_states + (hidden_states,) 205 | 206 | if self.downsamplers is not None: 207 | for downsampler in self.downsamplers: 208 | hidden_states = downsampler(hidden_states, scale) 209 | output_states = output_states + (hidden_states,) 210 | 211 | return hidden_states, output_states 212 | 213 | 214 | class UpDecoderSkipBlock2D(nn.Module): 215 | def __init__( 216 | self, 217 | in_channels: int, 218 | out_channels: int, 219 | resolution_idx: Optional[int] = None, 220 | dropout: float = 0.0, 221 | num_layers: int = 1, 222 | resnet_eps: float = 1e-6, 223 | resnet_time_scale_shift: str = "default", # default, spatial 224 | resnet_act_fn: str = "swish", 225 | resnet_groups: int = 32, 226 | resnet_pre_norm: bool = True, 227 | output_scale_factor: float = 1.0, 228 
| add_upsample: bool = True, 229 | temb_channels: Optional[int] = None, 230 | ): 231 | super().__init__() 232 | resnets = [] 233 | 234 | for i in range(num_layers): 235 | input_channels = in_channels if i == 0 else out_channels 236 | 237 | if resnet_time_scale_shift == "spatial": 238 | resnets.append( 239 | ResnetBlockCondNorm2D( 240 | in_channels=input_channels, 241 | out_channels=out_channels, 242 | temb_channels=temb_channels, 243 | eps=resnet_eps, 244 | groups=resnet_groups, 245 | dropout=dropout, 246 | time_embedding_norm="spatial", 247 | non_linearity=resnet_act_fn, 248 | output_scale_factor=output_scale_factor, 249 | ) 250 | ) 251 | else: 252 | resnets.append( 253 | ResnetBlock2D( 254 | in_channels=input_channels, 255 | out_channels=out_channels, 256 | temb_channels=temb_channels, 257 | eps=resnet_eps, 258 | groups=resnet_groups, 259 | dropout=dropout, 260 | time_embedding_norm=resnet_time_scale_shift, 261 | non_linearity=resnet_act_fn, 262 | output_scale_factor=output_scale_factor, 263 | pre_norm=resnet_pre_norm, 264 | ) 265 | ) 266 | 267 | self.resnets = nn.ModuleList(resnets) 268 | 269 | if add_upsample: 270 | self.upsamplers = nn.ModuleList( 271 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 272 | ) 273 | else: 274 | self.upsamplers = None 275 | 276 | self.resolution_idx = resolution_idx 277 | 278 | def forward( 279 | self, 280 | hidden_states: torch.FloatTensor, 281 | res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, 282 | temb: Optional[torch.FloatTensor] = None, 283 | scale: float = 1.0, 284 | ) -> torch.FloatTensor: 285 | for resnet in self.resnets: 286 | hidden_states = resnet(hidden_states, temb=temb, scale=scale) 287 | if res_hidden_states_tuple is not None: 288 | res_hidden_states = res_hidden_states_tuple[-1] 289 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 290 | hidden_states = hidden_states + res_hidden_states 291 | 292 | if self.upsamplers is not None: 293 | for upsampler in self.upsamplers: 294 | hidden_states = upsampler(hidden_states) 295 | 296 | return hidden_states 297 | -------------------------------------------------------------------------------- /scripts/evaluate/main.py: -------------------------------------------------------------------------------- 1 | from os.path import join as pjoin 2 | from skimage.metrics import mean_squared_error as mse 3 | from skimage.metrics import peak_signal_noise_ratio as psnr 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import json 8 | from collections import defaultdict 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | import torchvision.transforms.functional as TF 12 | import argparse 13 | from concurrent.futures import ThreadPoolExecutor, as_completed 14 | from threading import Lock 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--input_dir", type=str, required=True) 18 | parser.add_argument("--output_dir", type=str, required=True) 19 | parser.add_argument("--data_dir", type=str, default="./data/iHarmony4") 20 | parser.add_argument("--json_file_path", type=str, default="DataSpecs/all_test.jsonl") 21 | parser.add_argument("--resolution", type=int, default=256) 22 | parser.add_argument("--num_processes", type=int, default=4) 23 | parser.add_argument("--use_gt_bg", action="store_true") 24 | args = parser.parse_args() 25 | 26 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 27 | (Path(args.output_dir) / "imgs").mkdir(parents=True, exist_ok=True) 28 | 29 | # 初始化变量 30 | count = 0 31 | psnr_total = 0 32 | 
mse_total = 0 33 | fmse_total = 0 34 | bmse_total = 0 35 | ds2psnr = defaultdict(list) 36 | ds2mse = defaultdict(list) 37 | ds2fmse = defaultdict(list) 38 | ds2bmse = defaultdict(list) 39 | ds2mask = defaultdict(list) 40 | 41 | 42 | def get_paths(comp_path): 43 | # .../ds_name/composite_images/xxx_x_x.jpg 44 | parts = comp_path.split("/") 45 | img_name_parts = parts[-1].split("_") 46 | real_path = os.path.join(*parts[:-2], "real_images", f"{img_name_parts[0]}.jpg") 47 | mask_path = os.path.join( 48 | *parts[:-2], "masks", f"{img_name_parts[0]}_{img_name_parts[1]}.png" 49 | ) 50 | 51 | return { 52 | "real": real_path, 53 | "mask": mask_path, 54 | "comp": comp_path, 55 | } 56 | 57 | 58 | comp_to_paths = get_paths 59 | 60 | 61 | # 读取jsonl文件并存储到列表中 62 | def read_jsonl_to_list(jsonl_file): 63 | data_list = [] 64 | with open(jsonl_file, "r") as file: 65 | for line in file: 66 | # 解析每一行并将其添加到列表中 67 | data = json.loads(line) 68 | data_list.append(data) 69 | return data_list 70 | 71 | 72 | def possible_resize(image: Image.Image, resolution: int): 73 | if image.size != (resolution, resolution): 74 | image = TF.resize(image, [resolution, resolution], antialias=True) 75 | return image 76 | 77 | 78 | dataset_map = { 79 | "a": "HAdobe5k", 80 | "c": "HCOCO", 81 | "d": "Hday2night", 82 | "f": "HFlickr", 83 | "s": "SAM", 84 | "g": "Gen", 85 | } 86 | 87 | 88 | def process_fn(comp_path, lock): 89 | 90 | global count 91 | global psnr_total 92 | global mse_total 93 | global fmse_total 94 | global bmse_total 95 | global ds2psnr 96 | global ds2mse 97 | global ds2fmse 98 | global ds2bmse 99 | global ds2mask 100 | 101 | _output_str = "" 102 | _htmlstr = "" 103 | 104 | all_paths = comp_to_paths(comp_path) 105 | target_path = all_paths["real"] 106 | mask_path = all_paths["mask"] 107 | comp_path = all_paths["comp"] 108 | 109 | harm_path = os.path.join( 110 | args.input_dir, comp_path.split("/")[-1].replace(".jpg", ".png") 111 | ) 112 | harm_name = harm_path.split("/")[-1] 113 | 114 | ot_file = harm_path 115 | 116 | dataset = dataset_map[harm_name[0]] 117 | 118 | if os.path.exists(ot_file): 119 | # import pdb ; pdb.set_trace() 120 | ot_img = Image.open(ot_file).convert("RGB") 121 | ot_img = possible_resize(ot_img, args.resolution) 122 | 123 | mk_img = Image.open(mask_path).convert("1") 124 | mk_img = possible_resize(mk_img, args.resolution) 125 | 126 | gt_img = Image.open(target_path).convert("RGB") 127 | gt_img = possible_resize(gt_img, args.resolution) 128 | 129 | mk_np = np.array(mk_img, dtype=np.float32)[..., np.newaxis] 130 | gt_np = np.array(gt_img, dtype=np.float32) 131 | ot_np = np.array(ot_img, dtype=np.float32) 132 | 133 | if args.use_gt_bg: 134 | ot_np = ot_np * mk_np + gt_np * (1 - mk_np) 135 | 136 | mse_score = mse(ot_np, gt_np) 137 | psnr_score = psnr(gt_np, ot_np, data_range=255) 138 | 139 | with lock: 140 | mse_total += mse_score 141 | psnr_total += psnr_score 142 | 143 | w, h = ot_img.size 144 | fscore = mse(ot_np * mk_np, gt_np * mk_np) * (h * w) / (mk_np.sum()) 145 | fmse_total += fscore 146 | 147 | bscore = ( 148 | mse(ot_np * (1 - mk_np), gt_np * (1 - mk_np)) 149 | * (h * w) 150 | / ((1 - mk_np).sum()) 151 | ) 152 | with lock: 153 | bmse_total += bscore 154 | 155 | _output_str = f""" 156 | filename : {ot_file} 157 | image size : {ot_img.size} 158 | psnr : {psnr_score:.3f} 159 | mse : {mse_score:.3f} 160 | fmse : {fscore:.3f} 161 | bmse : {bscore:.3f} 162 | \n""" 163 | if fscore > 1000: 164 | gt_save_name = ( 165 | args.output_dir + "/" + "imgs/" + "gt_" + target_path.split("/")[-1] 166 | ) 167 | 
            )
167 |             comp_save_name = (
168 |                 args.output_dir + "/" + "imgs/" + "comp_" + comp_path.split("/")[-1]
169 |             )
170 |             ot_save_name = (
171 |                 args.output_dir + "/" + "imgs/" + "out_" + ot_file.split("/")[-1]
172 |             )
173 |             mk_save_name = (
174 |                 args.output_dir + "/" + "imgs/" + "mask_" + mask_path.split("/")[-1]
175 |             )
176 | 
177 |             gt_img.save(gt_save_name)
178 | 
179 |             comp_img = Image.open(comp_path).convert("RGB")
180 |             comp_img = possible_resize(comp_img, args.resolution)
181 |             comp_img.save(comp_save_name)
182 | 
183 |             ot_img.save(ot_save_name)
184 |             mk_img.save(
185 |                 args.output_dir + "/" + "imgs/" + "mask_" + mask_path.split("/")[-1]
186 |             )
187 |             _htmlstr = (  # embed the saved previews (gt / comp / out / mask) in the HTML report
188 |                 '<img src="imgs/gt_' + target_path.split("/")[-1] + '" width="256">'
193 |             )
194 |             _htmlstr += (
195 |                 '<img src="imgs/comp_' + comp_path.split("/")[-1] + '" width="256">'
200 |             )
201 |             _htmlstr += (
202 |                 '<img src="imgs/out_' + ot_file.split("/")[-1] + '" width="256">'
207 |             )
208 |             _htmlstr += (
209 |                 '<img src="imgs/mask_' + mask_path.split("/")[-1] + '" width="256">'
214 |             )
215 | 
216 |             _htmlstr += '<br/>mse:' + "%.2f" % mse_score + "<br/>"
217 |             _htmlstr += '<br/>fmse:' + "%.2f" % fscore + "<br/>"
218 |             _htmlstr += '<br/>bmse:' + "%.2f" % bscore + "<br/>"
219 | 
220 |         with lock:
221 |             ds2psnr[dataset].append(psnr_score)
222 |             ds2mse[dataset].append(mse_score)
223 |             ds2fmse[dataset].append(fscore)
224 |             ds2bmse[dataset].append(bscore)
225 |             ds2mask[dataset].append(mk_np.sum() / (h * w))
226 |             count = count + 1
227 |     else:
228 |         _output_str = ""
229 |         _htmlstr = ""
230 |     return _output_str, _htmlstr
231 | 
232 | 
233 | if __name__ == "__main__":
234 |     json_file = read_jsonl_to_list(os.path.join(args.data_dir, args.json_file_path))
235 |     lock = Lock()
236 | 
237 |     output_str = ""
238 |     htmlstr = (
239 |         '<html><body>harmony<br/>'
240 |     )
241 |     with ThreadPoolExecutor(max_workers=args.num_processes) as executor:
242 |         futures = [
243 |             executor.submit(
244 |                 process_fn, pjoin(args.data_dir, json_data["file_name"]), lock
245 |             )
246 |             for json_data in json_file
247 |         ]
248 |         for future in tqdm(as_completed(futures), total=len(futures)):
249 |             _output_str, _htmlstr = future.result()
250 |             output_str += _output_str
251 |             htmlstr += _htmlstr
252 | 
253 |     print(count)
254 | 
255 |     htmlstr += "</body></html>"
256 |     with open(args.output_dir + "/" + "test.html", "w") as fw:
257 |         fw.write(htmlstr)
258 | 
259 |     mask_area_range_list = [
260 |         (0.0, 0.05),
261 |         (0.05, 0.15),
262 |         (0.15, 1.0),
263 |     ]
264 |     for mask_area_range in mask_area_range_list:
265 |         this_range_total_psnr = []
266 |         this_range_total_mse = []
267 |         this_range_total_fmse = []
268 |         this_range_total_bmse = []
269 |         for ds in ds2psnr:
270 |             per_ds_total_psnr = []
271 |             per_ds_total_mse = []
272 |             per_ds_total_fmse = []
273 |             per_ds_total_bmse = []
274 |             for i in range(len(ds2mask[ds])):
275 |                 if mask_area_range[0] <= ds2mask[ds][i] < mask_area_range[1]:
276 |                     per_ds_total_psnr.append(ds2psnr[ds][i])
277 |                     per_ds_total_mse.append(ds2mse[ds][i])
278 |                     per_ds_total_fmse.append(ds2fmse[ds][i])
279 |                     per_ds_total_bmse.append(ds2bmse[ds][i])
280 | 
281 |                     this_range_total_psnr.append(ds2psnr[ds][i])
282 |                     this_range_total_mse.append(ds2mse[ds][i])
283 |                     this_range_total_fmse.append(ds2fmse[ds][i])
284 |                     this_range_total_bmse.append(ds2bmse[ds][i])
285 | 
286 |             output_str += f"{mask_area_range} [PSNR] {ds} : {np.sum(per_ds_total_psnr):.3f} / {len(per_ds_total_psnr)} = {np.mean(per_ds_total_psnr):.3f}\n"
287 |             output_str += f"{mask_area_range} [MSE] {ds} : {np.sum(per_ds_total_mse):.3f} / {len(per_ds_total_mse)} = {np.mean(per_ds_total_mse):.3f}\n"
288 |             output_str += f"{mask_area_range} [FMSE] {ds} : {np.sum(per_ds_total_fmse):.3f} / {len(per_ds_total_fmse)} = {np.mean(per_ds_total_fmse):.3f}\n"
289 |             output_str += f"{mask_area_range} [BMSE] {ds} : {np.sum(per_ds_total_bmse):.3f} / {len(per_ds_total_bmse)} = {np.mean(per_ds_total_bmse):.3f}\n"
290 |             output_str += "=" * 30 + "\n"
291 | 
292 |         output_str += f"{mask_area_range} [PSNR] in total : {np.sum(this_range_total_psnr):.3f} / {len(this_range_total_psnr)} = {np.mean(this_range_total_psnr):.3f}\n"
293 |         output_str += f"{mask_area_range} [MSE] in total : {np.sum(this_range_total_mse):.3f} / {len(this_range_total_mse)} = {np.mean(this_range_total_mse):.3f}\n"
294 |         output_str += f"{mask_area_range} [FMSE] in total : {np.sum(this_range_total_fmse):.3f} / {len(this_range_total_fmse)} = {np.mean(this_range_total_fmse):.3f}\n"
295 |         output_str += f"{mask_area_range} [BMSE] in total : {np.sum(this_range_total_bmse):.3f} / {len(this_range_total_bmse)} = {np.mean(this_range_total_bmse):.3f}\n"
296 |         output_str += "=" * 30 + "\n"
297 | 
298 |     # print the average PSNR and MSE of each dataset
299 |     for ds in ds2psnr:
300 |         output_str += f"[PSNR] of {ds} : {np.sum(ds2psnr[ds]):.3f} / {len(ds2psnr[ds])} = {np.mean(ds2psnr[ds]):.3f}\n"
301 |         output_str += f"[MSE] of {ds} : {np.sum(ds2mse[ds]):.3f} / {len(ds2mse[ds])} = {np.mean(ds2mse[ds]):.3f}\n"
302 |         output_str += f"[FMSE] of {ds} : {np.sum(ds2fmse[ds]):.3f} / {len(ds2fmse[ds])} = {np.mean(ds2fmse[ds]):.3f}\n"
303 |         output_str += f"[BMSE] of {ds} : {np.sum(ds2bmse[ds]):.3f} / {len(ds2bmse[ds])} = {np.mean(ds2bmse[ds]):.3f}\n"
304 |         output_str += "=" * 30 + "\n"
305 | 
306 |     # print the overall average PSNR, MSE and FMSE
307 |     output_str += f"""metric in total:
308 |     total samples {count}
309 |     psnr : {psnr_total:.3f} / {count} = {psnr_total / count:.3f}
310 |     mse : {mse_total:.3f} / {count} = {mse_total / count:.3f}
311 |     fmse : {fmse_total:.3f} / {count} = {fmse_total / count:.3f}
312 |     bmse : {bmse_total:.3f} / {count} = {bmse_total / count:.3f}
313 |     \n"""
314 | 
315 |     with open(pjoin(args.output_dir, "output.log"), "w") as fw:
316 |         fw.write(output_str)
317 | 
--------------------------------------------------------------------------------
/src/models/unet_2d.py:
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.utils import BaseOutput 9 | from diffusers.models.modeling_utils import ModelMixin 10 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, DownBlock2D, UpBlock2D 11 | 12 | @dataclass 13 | class UNet2DOutput(BaseOutput): 14 | """ 15 | The output of [`UNet2DModel`]. 16 | 17 | Args: 18 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 19 | The hidden states output from the last layer of the model. 20 | """ 21 | 22 | sample: torch.FloatTensor 23 | 24 | 25 | class UNet2DCustom(ModelMixin, ConfigMixin): 26 | r""" 27 | A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. 28 | 29 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 30 | for all models (such as downloading or saving). 31 | 32 | Parameters: 33 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 34 | Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - 35 | 1)`. 36 | in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. 37 | out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 38 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 39 | time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. 40 | freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. 41 | flip_sin_to_cos (`bool`, *optional*, defaults to `True`): 42 | Whether to flip sin to cos for Fourier time embedding. 43 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): 44 | Tuple of downsample block types. 45 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): 46 | Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. 47 | up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): 48 | Tuple of upsample block types. 49 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): 50 | Tuple of block output channels. 51 | layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. 52 | mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. 53 | downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. 54 | downsample_type (`str`, *optional*, defaults to `conv`): 55 | The downsample type for downsampling layers. Choose between "conv" and "resnet" 56 | upsample_type (`str`, *optional*, defaults to `conv`): 57 | The upsample type for upsampling layers. Choose between "conv" and "resnet" 58 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 59 | attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. 60 | norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. 
61 | norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. 62 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 63 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 64 | class_embed_type (`str`, *optional*, defaults to `None`): 65 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 66 | `"timestep"`, or `"identity"`. 67 | num_class_embeds (`int`, *optional*, defaults to `None`): 68 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class 69 | conditioning with `class_embed_type` equal to `None`. 70 | """ 71 | 72 | _supports_gradient_checkpointing = True 73 | 74 | @register_to_config 75 | def __init__( 76 | self, 77 | sample_size: Optional[Union[int, Tuple[int, int]]] = None, 78 | in_channels: int = 3, 79 | out_channels: int = 3, 80 | down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), 81 | up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), 82 | block_out_channels: Tuple[int] = (224, 448, 672, 896), 83 | layers_per_block: int = 2, 84 | mid_block_scale_factor: float = 1, 85 | downsample_padding: int = 1, 86 | downsample_type: str = "conv", 87 | upsample_type: str = "conv", 88 | act_fn: str = "silu", 89 | attention_head_dim: Optional[int] = 8, 90 | norm_num_groups: int = 32, 91 | norm_eps: float = 1e-5, 92 | resnet_time_scale_shift: str = "default", 93 | add_attention: bool = True, 94 | input_residual = False, 95 | ): 96 | super().__init__() 97 | 98 | # no time embedding 99 | # use input residual to accelerate training 100 | self.sample_size = sample_size 101 | 102 | self.input_residual = input_residual 103 | 104 | # Check inputs 105 | if len(down_block_types) != len(up_block_types): 106 | raise ValueError( 107 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 108 | ) 109 | 110 | if len(block_out_channels) != len(down_block_types): 111 | raise ValueError( 112 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
113 | ) 114 | 115 | # input 116 | self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 117 | 118 | self.down_blocks = nn.ModuleList([]) 119 | self.mid_block = None 120 | self.up_blocks = nn.ModuleList([]) 121 | 122 | # down 123 | output_channel = block_out_channels[0] 124 | for i, down_block_type in enumerate(down_block_types): 125 | input_channel = output_channel 126 | output_channel = block_out_channels[i] 127 | is_final_block = i == len(block_out_channels) - 1 128 | 129 | down_block = get_down_block( 130 | down_block_type, 131 | num_layers=layers_per_block, 132 | in_channels=input_channel, 133 | out_channels=output_channel, 134 | temb_channels=None, 135 | add_downsample=not is_final_block, 136 | resnet_eps=norm_eps, 137 | resnet_act_fn=act_fn, 138 | resnet_groups=norm_num_groups, 139 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 140 | downsample_padding=downsample_padding, 141 | resnet_time_scale_shift=resnet_time_scale_shift, 142 | downsample_type=downsample_type, 143 | ) 144 | self.down_blocks.append(down_block) 145 | 146 | # mid 147 | self.mid_block = UNetMidBlock2D( 148 | in_channels=block_out_channels[-1], 149 | temb_channels=None, 150 | resnet_eps=norm_eps, 151 | resnet_act_fn=act_fn, 152 | output_scale_factor=mid_block_scale_factor, 153 | resnet_time_scale_shift=resnet_time_scale_shift, 154 | attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], 155 | resnet_groups=norm_num_groups, 156 | add_attention=add_attention, 157 | ) 158 | 159 | # up 160 | reversed_block_out_channels = list(reversed(block_out_channels)) 161 | output_channel = reversed_block_out_channels[0] 162 | for i, up_block_type in enumerate(up_block_types): 163 | prev_output_channel = output_channel 164 | output_channel = reversed_block_out_channels[i] 165 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 166 | 167 | is_final_block = i == len(block_out_channels) - 1 168 | 169 | up_block = get_up_block( 170 | up_block_type, 171 | num_layers=layers_per_block + 1, 172 | in_channels=input_channel, 173 | out_channels=output_channel, 174 | prev_output_channel=prev_output_channel, 175 | temb_channels=None, 176 | add_upsample=not is_final_block, 177 | resnet_eps=norm_eps, 178 | resnet_act_fn=act_fn, 179 | resnet_groups=norm_num_groups, 180 | attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, 181 | resnet_time_scale_shift=resnet_time_scale_shift, 182 | upsample_type=upsample_type, 183 | ) 184 | self.up_blocks.append(up_block) 185 | prev_output_channel = output_channel 186 | 187 | # out 188 | num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) 189 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) 190 | self.conv_act = nn.SiLU() 191 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 192 | 193 | def _set_gradient_checkpointing(self, module, value=False): 194 | if isinstance(module, (DownBlock2D, UpBlock2D)): 195 | module.gradient_checkpointing = value 196 | 197 | def forward( 198 | self, 199 | sample: torch.FloatTensor, 200 | return_dict: bool = True, 201 | ) -> Union[UNet2DOutput, Tuple]: 202 | r""" 203 | The [`UNet2DModel`] forward method. 
204 | 205 | Args: 206 | sample (`torch.FloatTensor`): 207 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 208 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 209 | class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): 210 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 211 | return_dict (`bool`, *optional*, defaults to `True`): 212 | Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. 213 | 214 | Returns: 215 | [`~models.unet_2d.UNet2DOutput`] or `tuple`: 216 | If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is 217 | returned where the first element is the sample tensor. 218 | """ 219 | 220 | emb=None 221 | 222 | # 2. pre-process 223 | if self.input_residual: 224 | # TODO : find a better way 225 | skip_sample = sample[:,:self.config.out_channels].detach().clone() 226 | # skip_sample = sample.detach().clone() 227 | else: 228 | skip_sample = None 229 | sample = self.conv_in(sample) 230 | 231 | # 3. down 232 | down_block_res_samples = (sample,) 233 | for downsample_block in self.down_blocks: 234 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 235 | down_block_res_samples += res_samples 236 | 237 | # 4. mid 238 | sample = self.mid_block(sample, emb) 239 | 240 | # 5. up 241 | for upsample_block in self.up_blocks: 242 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 243 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 244 | sample = upsample_block(sample, res_samples, emb) 245 | 246 | # 6. post-process 247 | sample = self.conv_norm_out(sample) 248 | sample = self.conv_act(sample) 249 | sample = self.conv_out(sample) 250 | 251 | if skip_sample is not None: 252 | sample += skip_sample 253 | 254 | if not return_dict: 255 | return (sample,) 256 | 257 | return UNet2DOutput(sample=sample) 258 | -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffusers.utils import is_torch_version 7 | from diffusers.models.attention_processor import SpatialNorm 8 | from diffusers.models.unets.unet_2d_blocks import ( 9 | UNetMidBlock2D, 10 | ) 11 | from .unet_2d_blocks import get_down_block, get_up_block 12 | 13 | 14 | def zero_module(module): 15 | for p in module.parameters(): 16 | nn.init.zeros_(p) 17 | return module 18 | 19 | class EncoderSkip(nn.Module): 20 | r""" 21 | The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. 22 | 23 | Args: 24 | in_channels (`int`, *optional*, defaults to 3): 25 | The number of input channels. 26 | out_channels (`int`, *optional*, defaults to 3): 27 | The number of output channels. 28 | down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): 29 | The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available 30 | options. 31 | block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): 32 | The number of output channels for each block. 33 | layers_per_block (`int`, *optional*, defaults to 2): 34 | The number of layers per block. 
35 | norm_num_groups (`int`, *optional*, defaults to 32): 36 | The number of groups for normalization. 37 | act_fn (`str`, *optional*, defaults to `"silu"`): 38 | The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 39 | double_z (`bool`, *optional*, defaults to `True`): 40 | Whether to double the number of output channels for the last block. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | in_channels: int = 3, 46 | out_channels: int = 3, 47 | down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), 48 | block_out_channels: Tuple[int, ...] = (64,), 49 | layers_per_block: int = 2, 50 | norm_num_groups: int = 32, 51 | act_fn: str = "silu", 52 | double_z: bool = True, 53 | mid_block_add_attention=True, 54 | additional_in_channels: int = 0, 55 | ): 56 | super().__init__() 57 | self.layers_per_block = layers_per_block 58 | self.in_channels = in_channels 59 | self.additional_in_channels = additional_in_channels 60 | 61 | self.conv_in = nn.Conv2d( 62 | in_channels, 63 | block_out_channels[0], 64 | kernel_size=3, 65 | stride=1, 66 | padding=1, 67 | ) 68 | if additional_in_channels>0: 69 | self.add_conv_in = nn.Conv2d( 70 | additional_in_channels, 71 | block_out_channels[0], 72 | kernel_size=3, 73 | stride=1, 74 | padding=1, 75 | ) 76 | 77 | self.mid_block = None 78 | self.down_blocks = nn.ModuleList([]) 79 | self.down_block_skip_convs = nn.ModuleList([]) 80 | 81 | # down 82 | output_channel = block_out_channels[0] 83 | skip_conv = nn.Conv2d(output_channel, output_channel, kernel_size=1) 84 | skip_conv = zero_module(skip_conv) 85 | self.down_block_skip_convs.append(skip_conv) 86 | 87 | for i, down_block_type in enumerate(down_block_types): 88 | input_channel = output_channel 89 | output_channel = block_out_channels[i] 90 | is_final_block = i == len(block_out_channels) - 1 91 | 92 | down_block = get_down_block( 93 | down_block_type, 94 | num_layers=self.layers_per_block, 95 | in_channels=input_channel, 96 | out_channels=output_channel, 97 | add_downsample=not is_final_block, 98 | resnet_eps=1e-6, 99 | downsample_padding=0, 100 | resnet_act_fn=act_fn, 101 | resnet_groups=norm_num_groups, 102 | attention_head_dim=output_channel, 103 | temb_channels=None, 104 | ) 105 | self.down_blocks.append(down_block) 106 | 107 | for _ in range(layers_per_block): 108 | skip_conv = nn.Conv2d(output_channel, output_channel, kernel_size=1) 109 | skip_conv = zero_module(skip_conv) 110 | self.down_block_skip_convs.append(skip_conv) 111 | 112 | if not is_final_block: 113 | skip_conv = nn.Conv2d(block_out_channels[i], block_out_channels[i+1], kernel_size=1) 114 | skip_conv = zero_module(skip_conv) 115 | self.down_block_skip_convs.append(skip_conv) 116 | 117 | # mid 118 | mid_block_channel = block_out_channels[-1] 119 | skip_conv = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 120 | skip_conv = zero_module(skip_conv) 121 | self.mid_block_skip_conv = skip_conv 122 | 123 | self.mid_block = UNetMidBlock2D( 124 | in_channels=block_out_channels[-1], 125 | resnet_eps=1e-6, 126 | resnet_act_fn=act_fn, 127 | output_scale_factor=1, 128 | resnet_time_scale_shift="default", 129 | attention_head_dim=block_out_channels[-1], 130 | resnet_groups=norm_num_groups, 131 | temb_channels=None, 132 | add_attention=mid_block_add_attention, 133 | ) 134 | 135 | # out 136 | self.conv_norm_out = nn.GroupNorm( 137 | num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 138 | ) 139 | self.conv_act = nn.SiLU() 140 | 141 | conv_out_channels = 2 * out_channels 
if double_z else out_channels 142 | self.conv_out = nn.Conv2d( 143 | block_out_channels[-1], conv_out_channels, 3, padding=1 144 | ) 145 | 146 | self.gradient_checkpointing = False 147 | 148 | def forward(self, sample: torch.FloatTensor) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: 149 | r"""The forward method of the `Encoder` class.""" 150 | # sample : (B, C, H, W) 151 | if hasattr(self, "add_conv_in"): 152 | res_sample = self.add_conv_in(sample[:, self.in_channels:, :, :]) 153 | else: 154 | res_sample = 0 155 | sample = res_sample + self.conv_in(sample[:, :self.in_channels, :, :]) 156 | 157 | down_block_res_samples = (sample,) 158 | 159 | if self.training and self.gradient_checkpointing: 160 | 161 | def create_custom_forward(module): 162 | def custom_forward(*inputs): 163 | return module(*inputs) 164 | 165 | return custom_forward 166 | 167 | # down 168 | if is_torch_version(">=", "1.11.0"): 169 | for down_block in self.down_blocks: 170 | sample, res_samples = torch.utils.checkpoint.checkpoint( 171 | create_custom_forward(down_block), sample, use_reentrant=False 172 | ) 173 | down_block_res_samples += res_samples 174 | # middle 175 | sample = torch.utils.checkpoint.checkpoint( 176 | create_custom_forward(self.mid_block), sample, use_reentrant=False 177 | ) 178 | mid_block_res_samples = sample.clone() 179 | else: 180 | for down_block in self.down_blocks: 181 | sample, res_samples = torch.utils.checkpoint.checkpoint( 182 | create_custom_forward(down_block), sample 183 | ) 184 | down_block_res_samples += res_samples 185 | # middle 186 | sample = torch.utils.checkpoint.checkpoint( 187 | create_custom_forward(self.mid_block), sample 188 | ) 189 | mid_block_res_samples = sample.clone() 190 | 191 | else: 192 | # down 193 | for down_block in self.down_blocks: 194 | sample, res_samples = down_block(sample) 195 | down_block_res_samples += res_samples 196 | 197 | # middle 198 | sample = self.mid_block(sample) 199 | mid_block_res_samples = sample.clone() 200 | 201 | # post-process 202 | sample = self.conv_norm_out(sample) 203 | sample = self.conv_act(sample) 204 | sample = self.conv_out(sample) 205 | 206 | skip_down_block_res_samples = () 207 | for down_block_res_sample, skip_conv in zip(down_block_res_samples, self.down_block_skip_convs): 208 | skip_down_block_res_samples = skip_down_block_res_samples + (skip_conv(down_block_res_sample),) 209 | down_block_res_samples = skip_down_block_res_samples 210 | mid_block_res_samples = self.mid_block_skip_conv(mid_block_res_samples) 211 | 212 | return sample, down_block_res_samples, mid_block_res_samples 213 | 214 | 215 | 216 | class DecoderSkip(nn.Module): 217 | r""" 218 | The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. 219 | 220 | Args: 221 | in_channels (`int`, *optional*, defaults to 3): 222 | The number of input channels. 223 | out_channels (`int`, *optional*, defaults to 3): 224 | The number of output channels. 225 | up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): 226 | The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. 227 | block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): 228 | The number of output channels for each block. 229 | layers_per_block (`int`, *optional*, defaults to 2): 230 | The number of layers per block. 231 | norm_num_groups (`int`, *optional*, defaults to 32): 232 | The number of groups for normalization. 
233 | act_fn (`str`, *optional*, defaults to `"silu"`): 234 | The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 235 | norm_type (`str`, *optional*, defaults to `"group"`): 236 | The normalization type to use. Can be either `"group"` or `"spatial"`. 237 | """ 238 | 239 | def __init__( 240 | self, 241 | in_channels: int = 3, 242 | out_channels: int = 3, 243 | up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), 244 | block_out_channels: Tuple[int, ...] = (64,), 245 | layers_per_block: int = 2, 246 | norm_num_groups: int = 32, 247 | act_fn: str = "silu", 248 | norm_type: str = "group", # group, spatial 249 | mid_block_add_attention=True, 250 | ): 251 | super().__init__() 252 | self.layers_per_block = layers_per_block 253 | 254 | self.conv_in = nn.Conv2d( 255 | in_channels, 256 | block_out_channels[-1], 257 | kernel_size=3, 258 | stride=1, 259 | padding=1, 260 | ) 261 | 262 | self.mid_block = None 263 | self.up_blocks = nn.ModuleList([]) 264 | 265 | temb_channels = in_channels if norm_type == "spatial" else None 266 | 267 | # mid 268 | self.mid_block = UNetMidBlock2D( 269 | in_channels=block_out_channels[-1], 270 | resnet_eps=1e-6, 271 | resnet_act_fn=act_fn, 272 | output_scale_factor=1, 273 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type, 274 | attention_head_dim=block_out_channels[-1], 275 | resnet_groups=norm_num_groups, 276 | temb_channels=temb_channels, 277 | add_attention=mid_block_add_attention, 278 | ) 279 | 280 | # up 281 | reversed_block_out_channels = list(reversed(block_out_channels)) 282 | output_channel = reversed_block_out_channels[0] 283 | for i, up_block_type in enumerate(up_block_types): 284 | prev_output_channel = output_channel 285 | output_channel = reversed_block_out_channels[i] 286 | 287 | is_final_block = i == len(block_out_channels) - 1 288 | 289 | up_block = get_up_block( 290 | up_block_type, 291 | num_layers=self.layers_per_block + 1, 292 | in_channels=prev_output_channel, 293 | out_channels=output_channel, 294 | prev_output_channel=None, 295 | add_upsample=not is_final_block, 296 | resnet_eps=1e-6, 297 | resnet_act_fn=act_fn, 298 | resnet_groups=norm_num_groups, 299 | attention_head_dim=output_channel, 300 | temb_channels=temb_channels, 301 | resnet_time_scale_shift=norm_type, 302 | ) 303 | self.up_blocks.append(up_block) 304 | prev_output_channel = output_channel 305 | 306 | # out 307 | if norm_type == "spatial": 308 | self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) 309 | else: 310 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 311 | self.conv_act = nn.SiLU() 312 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 313 | 314 | self.gradient_checkpointing = False 315 | 316 | def forward( 317 | self, 318 | sample: torch.FloatTensor, 319 | down_block_res_samples: Optional[Tuple[torch.FloatTensor, ...]] = None, 320 | mid_block_res_samples: Optional[torch.FloatTensor] = None, 321 | latent_embeds: Optional[torch.FloatTensor] = None, 322 | ) -> torch.FloatTensor: 323 | r"""The forward method of the `Decoder` class.""" 324 | 325 | sample = self.conv_in(sample) 326 | 327 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 328 | if self.training and self.gradient_checkpointing: 329 | 330 | def create_custom_forward(module): 331 | def custom_forward(*inputs): 332 | return module(*inputs) 333 | 334 | return custom_forward 335 | 336 | if is_torch_version(">=", "1.11.0"): 
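                # NOTE: use_reentrant=False is only passed on newer PyTorch (diffusers gates it on >= 1.11);
                # the three branches below are otherwise identical: run the mid block, add
                # mid_block_res_samples, then feed each up block its slice of
                # down_block_res_samples (taken from the end, mirroring the encoder order).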
337 | # middle 338 | sample = torch.utils.checkpoint.checkpoint( 339 | create_custom_forward(self.mid_block), 340 | sample, 341 | latent_embeds, 342 | use_reentrant=False, 343 | ) 344 | if mid_block_res_samples is not None: 345 | sample = sample + mid_block_res_samples 346 | sample = sample.to(upscale_dtype) 347 | 348 | # up 349 | for up_block in self.up_blocks: 350 | res_samples = None 351 | if down_block_res_samples is not None: 352 | res_samples = down_block_res_samples[-len(up_block.resnets) :] 353 | down_block_res_samples = down_block_res_samples[: -len(up_block.resnets)] 354 | sample = torch.utils.checkpoint.checkpoint( 355 | create_custom_forward(up_block), 356 | sample, 357 | res_samples, 358 | latent_embeds, 359 | use_reentrant=False, 360 | ) 361 | else: 362 | # middle 363 | sample = torch.utils.checkpoint.checkpoint( 364 | create_custom_forward(self.mid_block), sample, latent_embeds 365 | ) 366 | if mid_block_res_samples is not None: 367 | sample = sample + mid_block_res_samples 368 | sample = sample.to(upscale_dtype) 369 | 370 | # up 371 | for up_block in self.up_blocks: 372 | res_samples = None 373 | if down_block_res_samples is not None: 374 | res_samples = down_block_res_samples[-len(up_block.resnets) :] 375 | down_block_res_samples = down_block_res_samples[: -len(up_block.resnets)] 376 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, res_samples, latent_embeds) 377 | else: 378 | # middle 379 | sample = self.mid_block(sample, latent_embeds) 380 | if mid_block_res_samples is not None: 381 | sample = sample + mid_block_res_samples 382 | sample = sample.to(upscale_dtype) 383 | 384 | # up 385 | for up_block in self.up_blocks: 386 | res_samples = None 387 | if down_block_res_samples is not None: 388 | res_samples = down_block_res_samples[-len(up_block.resnets) :] 389 | down_block_res_samples = down_block_res_samples[: -len(up_block.resnets)] 390 | sample = up_block(sample, res_samples, latent_embeds) 391 | 392 | # post-process 393 | if latent_embeds is None: 394 | sample = self.conv_norm_out(sample) 395 | else: 396 | sample = self.conv_norm_out(sample, latent_embeds) 397 | sample = self.conv_act(sample) 398 | sample = self.conv_out(sample) 399 | 400 | return sample 401 | 402 | -------------------------------------------------------------------------------- /src/dataset/ihd_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import torchvision.transforms as T 6 | import torchvision.transforms.functional as TF 7 | from argparse import Namespace 8 | 9 | from PIL import Image 10 | import json 11 | import cv2 12 | import os 13 | from typing import List, Dict 14 | 15 | 16 | def get_paths(path) -> Dict[str,str]: 17 | parts = path.split("/") 18 | img_name_parts = parts[-1].split(".")[0].split("_") 19 | if "masks" in path: 20 | return { 21 | "gt_path": os.path.join( 22 | *parts[:-2], "real_images", f"{img_name_parts[0]}.jpg" 23 | ), 24 | "mask_path": path, 25 | "image_path": os.path.join( 26 | *parts[:-2], "real_images", f"{img_name_parts[0]}.jpg" 27 | ), 28 | } 29 | elif "composite" in path: 30 | return { 31 | "gt_path": os.path.join( 32 | *parts[:-2], "real_images", f"{img_name_parts[0]}.jpg" 33 | ), 34 | "mask_path": os.path.join( 35 | *parts[:-2], "masks", f"{img_name_parts[0]}_{img_name_parts[1]}.png" 36 | ), 37 | "image_path": path, 38 | } 39 | else: 40 | raise ValueError(f"Unknown path type: {path}") 41 
| 42 | 43 | class IhdDatasetMultiRes(Dataset): 44 | def __init__(self, split, tokenizer, resolutions: List[int], opt): 45 | 46 | self.image_paths = [] 47 | self.captions = [] 48 | self.split = split 49 | self.tokenizer = tokenizer 50 | self.resolutions = list(set(resolutions)) 51 | self.random_flip = opt.random_flip 52 | self.random_crop = opt.random_crop 53 | self.mask_dilate = opt.mask_dilate 54 | 55 | data_file = opt.train_file if split == "train" else opt.test_file 56 | if split == "test": 57 | self.random_flip = False 58 | self.random_crop = False 59 | 60 | with open(os.path.join(opt.dataset_root, data_file), "r") as f: 61 | for line in f: 62 | cont = json.loads(line.strip()) 63 | image_path = os.path.join( 64 | opt.dataset_root, 65 | cont["file_name"], 66 | ) 67 | self.image_paths.append(image_path) 68 | self.captions.append(cont.get("text", "")) 69 | 70 | self.create_image_transforms() 71 | 72 | def __len__(self): 73 | return len(self.image_paths) 74 | 75 | def create_image_transforms(self): 76 | self.rgb_normalizer = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])]) 77 | 78 | def __getitem__(self, index): 79 | paths = get_paths(self.image_paths[index]) 80 | 81 | comp = Image.open(paths["image_path"]).convert("RGB") # RGB , [0,255] 82 | mask = Image.open(paths["mask_path"]).convert("1") 83 | real = Image.open(paths["gt_path"]).convert("RGB") # RGB , [0,255] 84 | 85 | caption = self.captions[index] 86 | if self.tokenizer is not None: 87 | caption_ids = self.tokenizer( 88 | caption, 89 | max_length=self.tokenizer.model_max_length, 90 | padding="max_length", 91 | truncation=True, 92 | return_tensors="pt", 93 | ).input_ids[0] 94 | else: 95 | caption_ids = torch.empty(size=(1,), dtype=torch.long) 96 | 97 | if self.random_flip and np.random.rand() > 0.5 and self.split == "train": 98 | comp, mask, real = TF.hflip(comp), TF.hflip(mask), TF.hflip(real) 99 | if self.random_crop: 100 | for _ in range(5): 101 | mask_tensor = TF.to_tensor(mask) 102 | crop_box = T.RandomResizedCrop.get_params( 103 | mask_tensor, scale=[0.5, 1.0], ratio=[3 / 4, 4 / 3] 104 | ) 105 | cropped_mask_tensor = TF.crop(mask_tensor, *crop_box) 106 | h, w = cropped_mask_tensor.shape[-2:] 107 | if cropped_mask_tensor.sum() < 0.01 * h * w: 108 | continue 109 | break 110 | 111 | example = {} 112 | for resolution in self.resolutions: 113 | if self.random_crop: 114 | this_res_comp = TF.resize( 115 | TF.crop(comp, *crop_box), 116 | size=(resolution, resolution), 117 | ) # default : bilinear resample ; antialias for PIL Image 118 | this_res_real = TF.resize( 119 | TF.crop(real, *crop_box), 120 | size=(resolution, resolution), 121 | ) # default : bilinear resample ; antialias for PIL Image 122 | this_res_mask = TF.resize( 123 | TF.crop(mask, *crop_box), 124 | size=(resolution, resolution), 125 | ) # default : bilinear resample ; antialias for PIL Image 126 | else: 127 | if comp.size == (resolution, resolution): 128 | this_res_comp = comp 129 | this_res_mask = mask 130 | this_res_real = real 131 | else: 132 | this_res_comp = TF.resize( 133 | comp, [resolution, resolution] 134 | ) # default : bilinear resample ; antialias for PIL Image 135 | this_res_mask = TF.resize( 136 | mask, [resolution, resolution] 137 | ) # default : bilinear resample ; antialias for PIL Image 138 | this_res_real = TF.resize( 139 | real, [resolution, resolution] 140 | ) # default : bilinear resample ; antialias for PIL Image 141 | 142 | this_res_comp = self.rgb_normalizer(this_res_comp) # tensor , [-1,1] 143 | this_res_real = 
self.rgb_normalizer(this_res_real) # tensor , [-1,1] 144 | this_res_mask = TF.to_tensor(this_res_mask) 145 | this_res_mask = (this_res_mask >= 0.5).float() # mask : tensor , 0/1 146 | this_res_mask = self.dilate_mask_image(this_res_mask) 147 | example[resolution] = { 148 | "real": this_res_real, 149 | "mask": this_res_mask, 150 | "comp": this_res_comp, 151 | "real_path": paths["gt_path"], 152 | "mask_path": paths["mask_path"], 153 | "comp_path": paths["image_path"], 154 | "caption": caption, 155 | "caption_ids": caption_ids, 156 | } 157 | 158 | return example 159 | 160 | def dilate_mask_image(self, mask: torch.Tensor) -> torch.Tensor: 161 | if self.mask_dilate > 0: 162 | mask_np = (mask * 255).numpy().astype(np.uint8) 163 | mask_np = cv2.dilate( 164 | mask_np, np.ones((self.mask_dilate, self.mask_dilate), np.uint8) 165 | ) 166 | mask = torch.tensor(mask_np.astype(np.float32) / 255.0) 167 | return mask 168 | 169 | 170 | class IhdDatasetSingleRes(Dataset): 171 | def __init__(self, split, tokenizer, resolution, opt): 172 | self.resolution = resolution 173 | self.multires_ds = IhdDatasetMultiRes(split, tokenizer, [resolution], opt) 174 | 175 | def __len__(self): 176 | return len(self.multires_ds) 177 | 178 | def __getitem__(self, index): 179 | return self.multires_ds[index][self.resolution] 180 | 181 | 182 | 183 | import json 184 | import numpy as np 185 | 186 | subset_names = [ 187 | "HAdobe5k", 188 | "HCOCO", 189 | "Hday2night", 190 | "HFlickr", 191 | ] 192 | 193 | def extract_ds_name(path): 194 | for subset_name in subset_names: 195 | if subset_name in path: 196 | return subset_name 197 | return None 198 | 199 | def read_jsonl_file(filename): 200 | data = [] 201 | with open(filename, 'r') as file: 202 | for line in file: 203 | data.append(json.loads(line)) 204 | return data 205 | 206 | 207 | SUBSET_TO_MIX = [ 208 | "HCOCO", 209 | "HFlickr", 210 | "HAdobe5k", 211 | "Hday2night", 212 | ] 213 | 214 | class IhdWithRandomMaskComp(Dataset): 215 | def __init__( 216 | self, 217 | tokenizer, 218 | opt, 219 | ) -> None: 220 | super().__init__() 221 | self.resolution = opt.resolution 222 | self.random_flip = opt.random_flip 223 | self.random_crop = opt.random_crop 224 | self.center_crop = opt.center_crop 225 | self.dataset_root = opt.dataset_root 226 | self.mix_area_thres = opt.mix_area_thres 227 | 228 | if isinstance(self.dataset_root, list): 229 | self.dataset_root = self.dataset_root[0] 230 | self.tokenizer = tokenizer 231 | self.refer_method = opt.refer_method 232 | self.real_mask_mapping = json.load(open(opt.image_mask_mapping, "r")) 233 | self.mask_comp_mapping = json.load(open(opt.mask_comp_mapping, "r")) 234 | if opt.train_file is None: 235 | self.image_rel_paths = list(self.real_mask_mapping.keys()) 236 | else: 237 | train_file_content = [line['file_name'] for line in read_jsonl_file(os.path.join(self.dataset_root, opt.train_file))] 238 | self.image_rel_paths = list(set(self._convert_rel_comp_to_real(train_file_content))) 239 | self.image_rel_paths.sort() 240 | self.image_processor = self.create_image_transforms() 241 | self.image_normalizer = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])]) 242 | 243 | def _convert_rel_comp_to_real(self, rel_comp_paths:List[str]): 244 | def _comp_to_real(comp_path): 245 | parts = comp_path.split("/") 246 | img_name_parts = parts[-1].split("_") 247 | real_path = os.path.join(*parts[:-2], "real_images", f"{img_name_parts[0]}.jpg") 248 | return real_path 249 | rel_real_paths = [] 250 | for rel_comp_path in rel_comp_paths: 251 | # rel_real_path = 
_comp_to_real(insert_ds_suffix(rel_comp_path, self.imgdir_suffix)) 252 | rel_real_path = _comp_to_real(rel_comp_path) 253 | rel_real_paths.append(rel_real_path) 254 | return rel_real_paths 255 | 256 | def create_image_transforms(self): 257 | transforms = [] 258 | if self.random_flip: 259 | transforms.append(T.RandomHorizontalFlip()) 260 | if self.random_crop: 261 | transforms.append( 262 | T.RandomResizedCrop( 263 | size=[self.resolution, self.resolution], 264 | scale=(0.5, 1), 265 | antialias=True, 266 | ), 267 | ) 268 | elif self.center_crop: 269 | transforms.extend( 270 | [ 271 | T.Resize(size=self.resolution, antialias=True), 272 | T.CenterCrop(size=[self.resolution, self.resolution]), 273 | ] 274 | ) 275 | else: 276 | transforms.append( 277 | T.Resize(size=[self.resolution, self.resolution], antialias=True) 278 | ) 279 | 280 | transforms = T.Compose(transforms) 281 | return transforms 282 | 283 | def __getitem__(self, i): 284 | image_rel_path = self.image_rel_paths[i] 285 | image_path = os.path.join(self.dataset_root, image_rel_path) 286 | 287 | image = Image.open(image_path).convert("RGB") 288 | 289 | mask_rel_path = np.random.choice(self.real_mask_mapping[image_rel_path]) 290 | comp_rel_path = np.random.choice(self.mask_comp_mapping[mask_rel_path]) 291 | 292 | mask_path = os.path.join(self.dataset_root, mask_rel_path) 293 | comp_path = os.path.join(self.dataset_root, comp_rel_path) 294 | mask = Image.open(mask_path).convert("1") 295 | comp = Image.open(comp_path).convert("RGB") 296 | 297 | image = self.image_normalizer(image) 298 | comp = self.image_normalizer(comp) 299 | mask = TF.to_tensor(mask).to(dtype=torch.float32) 300 | 301 | for _ in range(5): 302 | merged = torch.cat([image, mask, comp], dim=0) 303 | if tuple(merged.shape[-2:]) == (self.resolution, self.resolution): 304 | break 305 | else: 306 | merged_processed = self.image_processor(merged) 307 | image, mask, comp = torch.split(merged_processed, [3, 1, 3], dim=0) 308 | h, w = mask.shape[-2:] 309 | if self.random_crop and mask.sum() < (0.01 * h * w): 310 | continue 311 | break 312 | 313 | mask = (mask >= 0.5).float() 314 | image = image.clamp(-1, 1) 315 | comp = comp.clamp(-1, 1) 316 | 317 | # caption = self.captions[index] 318 | caption = "" 319 | if self.tokenizer is not None: 320 | caption_ids = self.tokenizer( 321 | caption, 322 | max_length=self.tokenizer.model_max_length, 323 | padding="max_length", 324 | truncation=True, 325 | return_tensors="pt", 326 | ).input_ids[0] 327 | else: 328 | caption_ids = torch.empty(size=(1,)) 329 | 330 | example = { 331 | "image": image, 332 | "image_path": image_path, 333 | "caption_ids": caption_ids, 334 | "subset": extract_ds_name(image_path), 335 | } 336 | 337 | example["mask"] = mask 338 | example["mask_path"] = mask_path 339 | 340 | example["comp"] = comp 341 | example["comp_path"] = comp_path 342 | 343 | select_mask = torch.tensor(1) 344 | if extract_ds_name(image_path) not in SUBSET_TO_MIX: 345 | select_mask = 0 346 | h,w=mask.shape[-2:] 347 | if mask.sum() < self.mix_area_thres: 348 | select_mask=0 349 | example["select_mask"] = select_mask 350 | 351 | if self.refer_method=='batch': 352 | pass 353 | 354 | return example 355 | 356 | def __len__(self) -> int: 357 | return len(self.image_rel_paths) 358 | 359 | 360 | class IhdDatasetWithSDXLMetadata(Dataset): 361 | def __init__(self, split, resolution: int, opt): 362 | self.image_paths = [] 363 | self.captions = [] 364 | self.split = split 365 | self.resolution = resolution 366 | self.random_flip = opt.random_flip 367 | 
self.random_crop = opt.random_crop 368 | if hasattr(opt, "crop_resolution") and opt.crop_resolution is not None: 369 | self.crop_resolution = opt.crop_resolution 370 | else: 371 | self.crop_resolution = resolution 372 | 373 | data_file = opt.train_file if split == "train" else opt.test_file 374 | 375 | with open(os.path.join(opt.dataset_root, data_file), "r") as f: 376 | for line in f: 377 | cont = json.loads(line.strip()) 378 | image_path = os.path.join( 379 | opt.dataset_root, 380 | cont["file_name"], 381 | ) 382 | self.image_paths.append(image_path) 383 | self.captions.append(cont.get("text", "")) 384 | 385 | self.transforms = Namespace( 386 | resize=T.Resize( 387 | self.resolution, 388 | interpolation=T.InterpolationMode.BILINEAR, 389 | antialias=True, 390 | ), 391 | crop=( 392 | T.CenterCrop(self.crop_resolution) 393 | if not self.random_crop 394 | else T.RandomCrop(self.crop_resolution) 395 | ), 396 | flip = T.RandomHorizontalFlip(p=0.5), 397 | normalize = T.Normalize([0.5], [0.5]) 398 | ) 399 | 400 | def __len__(self): 401 | return len(self.image_paths) 402 | 403 | def __getitem__(self, index): 404 | paths = get_paths(self.image_paths[index]) 405 | 406 | comp = Image.open(paths["image_path"]).convert("RGB") # RGB , [0,255] 407 | mask = Image.open(paths["mask_path"]).convert("1") 408 | real = Image.open(paths["gt_path"]).convert("RGB") # RGB , [0,255] 409 | 410 | original_size = torch.tensor([comp.height, comp.width]) 411 | comp = TF.to_tensor(comp) 412 | real = TF.to_tensor(real) 413 | mask = TF.to_tensor(mask) 414 | all_img = torch.cat([comp, real, mask], dim=0) 415 | all_img = self.transforms.resize(all_img) 416 | 417 | if self.random_flip: 418 | # flip 419 | all_img = self.transforms.flip(all_img) 420 | 421 | if not self.random_crop: 422 | y1 = max(0, int(round((all_img.shape[0] - self.crop_resolution) / 2.0))) 423 | x1 = max(0, int(round((all_img.shape[1] - self.crop_resolution) / 2.0))) 424 | all_img = self.transforms.crop(all_img) 425 | else: 426 | y1, x1, h, w = self.transforms.crop.get_params( 427 | all_img, (self.crop_resolution, self.crop_resolution) 428 | ) 429 | all_img = TF.crop(all_img, y1, x1, h, w) 430 | 431 | crop_top_left = torch.tensor([y1, x1]) 432 | comp, real, mask = torch.split(all_img, [3, 3, 1], dim=0) 433 | comp = self.transforms.normalize(comp) # tensor , [-1,1] 434 | mask = torch.ge(mask, 0.5).float() # >= 0.5 is True 435 | # mask : tensor , 0/1 436 | real = self.transforms.normalize(real) # tensor , [-1,1] 437 | 438 | return { 439 | "real": real, 440 | "mask": mask, 441 | "comp": comp, 442 | "real_path": paths["gt_path"], 443 | "mask_path": paths["mask_path"], 444 | "comp_path": paths["image_path"], 445 | "caption": self.captions[index], 446 | "original_size": original_size, 447 | "crop_top_left": crop_top_left, 448 | } 449 | -------------------------------------------------------------------------------- /src/models/condition_vae.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | # from diffusers.loaders import FromOriginalVAEMixin 8 | from diffusers.loaders.autoencoder import FromOriginalVAEMixin 9 | from diffusers.utils.accelerate_utils import apply_forward_hook 10 | from diffusers.models.attention_processor import ( 11 | ADDED_KV_ATTENTION_PROCESSORS, 12 | CROSS_ATTENTION_PROCESSORS, 13 | Attention, 14 | AttentionProcessor, 15 | 
AttnAddedKVProcessor, 16 | AttnProcessor, 17 | ) 18 | from diffusers.models.modeling_outputs import AutoencoderKLOutput 19 | from diffusers.models.modeling_utils import ModelMixin 20 | from diffusers.models.autoencoders.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder 21 | from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL 22 | from .vae import EncoderSkip, DecoderSkip, zero_module 23 | 24 | class ConditionVAE(ModelMixin, ConfigMixin, FromOriginalVAEMixin): 25 | r""" 26 | A VAE model with KL loss for encoding images into latents and decoding latent representations into images. 27 | 28 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 29 | for all models (such as downloading or saving). 30 | 31 | Parameters: 32 | in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 33 | out_channels (int, *optional*, defaults to 3): Number of channels in the output. 34 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): 35 | Tuple of downsample block types. 36 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): 37 | Tuple of upsample block types. 38 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): 39 | Tuple of block output channels. 40 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 41 | latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. 42 | sample_size (`int`, *optional*, defaults to `32`): Sample input size. 43 | scaling_factor (`float`, *optional*, defaults to 0.18215): 44 | The component-wise standard deviation of the trained latent space computed using the first batch of the 45 | training set. This is used to scale the latent space to have unit variance when training the diffusion 46 | model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the 47 | diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 48 | / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image 49 | Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 50 | force_upcast (`bool`, *optional*, default to `True`): 51 | If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. 
VAE 52 | can be fine-tuned / trained to a lower range without loosing too much precision in which case 53 | `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix 54 | """ 55 | 56 | _supports_gradient_checkpointing = True 57 | 58 | @register_to_config 59 | def __init__( 60 | self, 61 | in_channels: int = 3, 62 | out_channels: int = 3, 63 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 64 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 65 | block_out_channels: Tuple[int] = (64,), 66 | layers_per_block: int = 1, 67 | act_fn: str = "silu", 68 | latent_channels: int = 4, 69 | norm_num_groups: int = 32, 70 | sample_size: int = 32, 71 | additional_in_channels : int =1, 72 | scaling_factor: float = 0.18215, 73 | force_upcast: float = True, 74 | ): 75 | super().__init__() 76 | 77 | # original vae encoder 78 | self.encoder = Encoder( 79 | in_channels=in_channels, 80 | out_channels=latent_channels, 81 | down_block_types=down_block_types, 82 | block_out_channels=block_out_channels, 83 | layers_per_block=layers_per_block, 84 | act_fn=act_fn, 85 | norm_num_groups=norm_num_groups, 86 | double_z=True, 87 | ) 88 | # encoder for conditional image with skip connections 89 | self.conditional_encoder = EncoderSkip( 90 | in_channels=in_channels, 91 | out_channels=latent_channels, 92 | down_block_types=down_block_types, 93 | block_out_channels=block_out_channels, 94 | layers_per_block=layers_per_block, 95 | act_fn=act_fn, 96 | norm_num_groups=norm_num_groups, 97 | double_z=True, 98 | additional_in_channels=additional_in_channels, 99 | ) 100 | 101 | # decoder with (optional) skip connections 102 | self.decoder = DecoderSkip( 103 | in_channels=latent_channels, 104 | out_channels=out_channels, 105 | up_block_types=up_block_types, 106 | block_out_channels=block_out_channels, 107 | layers_per_block=layers_per_block, 108 | norm_num_groups=norm_num_groups, 109 | act_fn=act_fn, 110 | ) 111 | 112 | self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) 113 | self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) 114 | 115 | skip_conv = nn.Conv2d(latent_channels, latent_channels, kernel_size=1) 116 | skip_conv = zero_module(skip_conv) 117 | self.latent_skip_conv = skip_conv 118 | 119 | self.use_slicing = False 120 | self.use_tiling = False 121 | 122 | # only relevant if vae tiling is enabled 123 | self.tile_sample_min_size = self.config.sample_size 124 | sample_size = ( 125 | self.config.sample_size[0] 126 | if isinstance(self.config.sample_size, (list, tuple)) 127 | else self.config.sample_size 128 | ) 129 | self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) 130 | self.tile_overlap_factor = 0.25 131 | 132 | def _set_gradient_checkpointing(self, module, value=False): 133 | if isinstance(module, (Encoder, Decoder, EncoderSkip, DecoderSkip)): 134 | module.gradient_checkpointing = value 135 | 136 | def enable_tiling(self, use_tiling: bool = True): 137 | r""" 138 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 139 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 140 | processing larger images. 141 | """ 142 | self.use_tiling = use_tiling 143 | 144 | def disable_tiling(self): 145 | r""" 146 | Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing 147 | decoding in one step. 
148 | """ 149 | self.enable_tiling(False) 150 | 151 | def enable_slicing(self): 152 | r""" 153 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 154 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 155 | """ 156 | self.use_slicing = True 157 | 158 | def disable_slicing(self): 159 | r""" 160 | Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing 161 | decoding in one step. 162 | """ 163 | self.use_slicing = False 164 | 165 | @property 166 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 167 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 168 | r""" 169 | Returns: 170 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 171 | indexed by its weight name. 172 | """ 173 | # set recursively 174 | processors = {} 175 | 176 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 177 | if hasattr(module, "get_processor"): 178 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 179 | 180 | for sub_name, child in module.named_children(): 181 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 182 | 183 | return processors 184 | 185 | for name, module in self.named_children(): 186 | fn_recursive_add_processors(name, module, processors) 187 | 188 | return processors 189 | 190 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 191 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 192 | r""" 193 | Sets the attention processor to use to compute attention. 194 | 195 | Parameters: 196 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 197 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 198 | for **all** `Attention` layers. 199 | 200 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 201 | processor. This is strongly recommended when setting trainable attention processors. 202 | 203 | """ 204 | count = len(self.attn_processors.keys()) 205 | 206 | if isinstance(processor, dict) and len(processor) != count: 207 | raise ValueError( 208 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 209 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 210 | ) 211 | 212 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 213 | if hasattr(module, "set_processor"): 214 | if not isinstance(processor, dict): 215 | module.set_processor(processor) 216 | else: 217 | module.set_processor(processor.pop(f"{name}.processor")) 218 | 219 | for sub_name, child in module.named_children(): 220 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 221 | 222 | for name, module in self.named_children(): 223 | fn_recursive_attn_processor(name, module, processor) 224 | 225 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 226 | def set_default_attn_processor(self): 227 | """ 228 | Disables custom attention processors and sets the default attention implementation. 
229 | """ 230 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 231 | processor = AttnAddedKVProcessor() 232 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 233 | processor = AttnProcessor() 234 | else: 235 | raise ValueError( 236 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 237 | ) 238 | 239 | self.set_attn_processor(processor) 240 | 241 | @apply_forward_hook 242 | def encode( 243 | self, x: torch.FloatTensor, return_dict: bool = True 244 | ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: 245 | """ 246 | Encode a batch of images into latents. 247 | 248 | Args: 249 | x (`torch.FloatTensor`): Input batch of images. 250 | return_dict (`bool`, *optional*, defaults to `True`): 251 | Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 252 | 253 | Returns: 254 | The latent representations of the encoded images. If `return_dict` is True, a 255 | [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 256 | """ 257 | x=x.to(self.device, dtype=self.dtype) 258 | if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): 259 | return self.tiled_encode(x, return_dict=return_dict) 260 | 261 | if self.use_slicing and x.shape[0] > 1: 262 | encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] 263 | h = torch.cat(encoded_slices) 264 | else: 265 | h = self.encoder(x) 266 | 267 | moments = self.quant_conv(h) 268 | posterior = DiagonalGaussianDistribution(moments) 269 | 270 | if not return_dict: 271 | return (posterior,) 272 | 273 | return AutoencoderKLOutput(latent_dist=posterior) 274 | 275 | @apply_forward_hook 276 | def encode_cond( 277 | self, cond: torch.FloatTensor 278 | ) -> Tuple[DiagonalGaussianDistribution, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: 279 | """ 280 | Encode a batch of images into latents. 281 | 282 | Args: 283 | x (`torch.FloatTensor`): Input batch of images. 284 | return_dict (`bool`, *optional*, defaults to `True`): 285 | Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 286 | 287 | Returns: 288 | The latent representations of the encoded images. If `return_dict` is True, a 289 | [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
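Note: unlike `encode`, this method takes no `return_dict` flag and always returns a plain tuple of (posterior, down-block skip features, mid-block skip features), the latter two coming from the conditional encoder.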
290 | """ 291 | cond=cond.to(self.device, dtype=self.dtype) 292 | assert not self.use_tiling 293 | # if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): 294 | # return self.tiled_encode(x, return_dict=return_dict) 295 | 296 | assert not self.use_slicing 297 | # if self.use_slicing and x.shape[0] > 1: 298 | # encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] 299 | # h = torch.cat(encoded_slices) 300 | # else: 301 | # h = self.encoder(x) 302 | 303 | self.conditional_encoder.additional_in_channels = self.config.additional_in_channels 304 | h, down_blocks_res_samples, mid_block_res_samples = self.conditional_encoder(cond) 305 | 306 | moments = self.quant_conv(h) 307 | posterior = DiagonalGaussianDistribution(moments) 308 | 309 | return (posterior, down_blocks_res_samples, mid_block_res_samples) 310 | 311 | def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: 312 | if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): 313 | return self.tiled_decode(z, return_dict=return_dict) 314 | 315 | z = self.post_quant_conv(z) 316 | dec = self.decoder(z) 317 | 318 | if not return_dict: 319 | return (dec,) 320 | 321 | return DecoderOutput(sample=dec) 322 | 323 | @apply_forward_hook 324 | def decode( 325 | self, z: torch.FloatTensor, return_dict: bool = True, generator=None 326 | ) -> Union[DecoderOutput, torch.FloatTensor]: 327 | """ 328 | Decode a batch of images. 329 | 330 | Args: 331 | z (`torch.FloatTensor`): Input batch of latent vectors. 332 | return_dict (`bool`, *optional*, defaults to `True`): 333 | Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 334 | 335 | Returns: 336 | [`~models.vae.DecoderOutput`] or `tuple`: 337 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 338 | returned. 339 | 340 | """ 341 | z=z.to(self.device, dtype=self.dtype) 342 | if self.use_slicing and z.shape[0] > 1: 343 | decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] 344 | decoded = torch.cat(decoded_slices) 345 | else: 346 | decoded = self._decode(z).sample 347 | 348 | if not return_dict: 349 | return (decoded,) 350 | 351 | return DecoderOutput(sample=decoded) 352 | 353 | @apply_forward_hook 354 | def decode_with_cond( 355 | self, z: torch.FloatTensor, cond: torch.FloatTensor, return_dict: bool = True, generator=None, sample_posterior: bool = True, 356 | ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: 357 | """ 358 | Decode a batch of images. 359 | 360 | Args: 361 | z (`torch.FloatTensor`): Input batch of latent vectors. 362 | return_dict (`bool`, *optional*, defaults to `True`): 363 | Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 364 | 365 | Returns: 366 | [`~models.vae.DecoderOutput`] or `tuple`: 367 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 368 | returned. 
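Note: `cond` is first encoded by the conditional encoder; its down-/mid-block skip features are passed to the decoder, and its latent (after `post_quant_conv` and the zero-initialized 1x1 `latent_skip_conv`) is added to `z` before decoding.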
369 | 370 | """ 371 | assert not self.use_slicing 372 | # if self.use_slicing and z.shape[0] > 1: 373 | # decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] 374 | # decoded = torch.cat(decoded_slices) 375 | # else: 376 | # decoded = self._decode(z).sample 377 | z=z.to(self.device, dtype=self.dtype) 378 | cond = cond.to(self.device, dtype=self.dtype) 379 | 380 | posterior_cond, down_block_res_samples, mid_block_res_samples = self.encode_cond(cond) 381 | if sample_posterior: 382 | z_cond = posterior_cond.sample(generator=generator) 383 | else: 384 | z_cond = posterior_cond.mode() 385 | 386 | z = self.post_quant_conv(z) 387 | z_cond = self.post_quant_conv(z_cond) 388 | 389 | z = z + self.latent_skip_conv(z_cond) 390 | 391 | decoded = self.decoder(z, down_block_res_samples, mid_block_res_samples) 392 | if not return_dict: 393 | return (decoded,) 394 | 395 | return DecoderOutput(sample=decoded) 396 | 397 | def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: 398 | blend_extent = min(a.shape[2], b.shape[2], blend_extent) 399 | for y in range(blend_extent): 400 | b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) 401 | return b 402 | 403 | def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: 404 | blend_extent = min(a.shape[3], b.shape[3], blend_extent) 405 | for x in range(blend_extent): 406 | b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) 407 | return b 408 | 409 | def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: 410 | r"""Encode a batch of images using a tiled encoder. 411 | 412 | When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several 413 | steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is 414 | different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the 415 | tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the 416 | output, but they should be much less noticeable. 417 | 418 | Args: 419 | x (`torch.FloatTensor`): Input batch of images. 420 | return_dict (`bool`, *optional*, defaults to `True`): 421 | Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. 422 | 423 | Returns: 424 | [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: 425 | If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain 426 | `tuple` is returned. 427 | """ 428 | overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) 429 | blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) 430 | row_limit = self.tile_latent_min_size - blend_extent 431 | 432 | # Split the image into 512x512 tiles and encode them separately. 
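# (The tile side length is `self.tile_sample_min_size`, i.e. the configured `sample_size`; neighbouring tiles overlap by `tile_overlap_factor` and are blended below to hide seams.)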
433 | rows = [] 434 | for i in range(0, x.shape[2], overlap_size): 435 | row = [] 436 | for j in range(0, x.shape[3], overlap_size): 437 | tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] 438 | tile = self.encoder(tile) 439 | tile = self.quant_conv(tile) 440 | row.append(tile) 441 | rows.append(row) 442 | result_rows = [] 443 | for i, row in enumerate(rows): 444 | result_row = [] 445 | for j, tile in enumerate(row): 446 | # blend the above tile and the left tile 447 | # to the current tile and add the current tile to the result row 448 | if i > 0: 449 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 450 | if j > 0: 451 | tile = self.blend_h(row[j - 1], tile, blend_extent) 452 | result_row.append(tile[:, :, :row_limit, :row_limit]) 453 | result_rows.append(torch.cat(result_row, dim=3)) 454 | 455 | moments = torch.cat(result_rows, dim=2) 456 | posterior = DiagonalGaussianDistribution(moments) 457 | 458 | if not return_dict: 459 | return (posterior,) 460 | 461 | return AutoencoderKLOutput(latent_dist=posterior) 462 | 463 | def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: 464 | r""" 465 | Decode a batch of images using a tiled decoder. 466 | 467 | Args: 468 | z (`torch.FloatTensor`): Input batch of latent vectors. 469 | return_dict (`bool`, *optional*, defaults to `True`): 470 | Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 471 | 472 | Returns: 473 | [`~models.vae.DecoderOutput`] or `tuple`: 474 | If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is 475 | returned. 476 | """ 477 | overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) 478 | blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) 479 | row_limit = self.tile_sample_min_size - blend_extent 480 | 481 | # Split z into overlapping 64x64 tiles and decode them separately. 482 | # The tiles have an overlap to avoid seams between tiles. 483 | rows = [] 484 | for i in range(0, z.shape[2], overlap_size): 485 | row = [] 486 | for j in range(0, z.shape[3], overlap_size): 487 | tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] 488 | tile = self.post_quant_conv(tile) 489 | decoded = self.decoder(tile) 490 | row.append(decoded) 491 | rows.append(row) 492 | result_rows = [] 493 | for i, row in enumerate(rows): 494 | result_row = [] 495 | for j, tile in enumerate(row): 496 | # blend the above tile and the left tile 497 | # to the current tile and add the current tile to the result row 498 | if i > 0: 499 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 500 | if j > 0: 501 | tile = self.blend_h(row[j - 1], tile, blend_extent) 502 | result_row.append(tile[:, :, :row_limit, :row_limit]) 503 | result_rows.append(torch.cat(result_row, dim=3)) 504 | 505 | dec = torch.cat(result_rows, dim=2) 506 | if not return_dict: 507 | return (dec,) 508 | 509 | return DecoderOutput(sample=dec) 510 | 511 | def forward( 512 | self, 513 | sample: torch.FloatTensor, 514 | sample_posterior: bool = False, 515 | return_dict: bool = True, 516 | generator: Optional[torch.Generator] = None, 517 | cond: Optional[torch.FloatTensor] = None, 518 | ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: 519 | r""" 520 | Args: 521 | sample (`torch.FloatTensor`): Input sample. 522 | sample_posterior (`bool`, *optional*, defaults to `False`): 523 | Whether to sample from the posterior. 
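cond (`torch.FloatTensor`, *optional*): Conditioning input. When provided, decoding goes through `decode_with_cond` so the decoder receives skip features from the conditional encoder.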
524 | return_dict (`bool`, *optional*, defaults to `True`): 525 | Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 526 | """ 527 | x = sample 528 | posterior = self.encode(x).latent_dist 529 | if sample_posterior: 530 | z = posterior.sample(generator=generator) 531 | else: 532 | z = posterior.mode() 533 | 534 | if cond is not None: 535 | dec = self.decode_with_cond(z, cond, return_dict=return_dict, generator=generator, sample_posterior=sample_posterior).sample 536 | else: 537 | dec = self.decode(z).sample 538 | 539 | if not return_dict: 540 | return (dec,) 541 | 542 | return DecoderOutput(sample=dec) 543 | 544 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 545 | def fuse_qkv_projections(self): 546 | """ 547 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, 548 | key, value) are fused. For cross-attention modules, key and value projection matrices are fused. 549 | 550 | 551 | 552 | This API is 🧪 experimental. 553 | 554 | 555 | """ 556 | self.original_attn_processors = None 557 | 558 | for _, attn_processor in self.attn_processors.items(): 559 | if "Added" in str(attn_processor.__class__.__name__): 560 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 561 | 562 | self.original_attn_processors = self.attn_processors 563 | 564 | for module in self.modules(): 565 | if isinstance(module, Attention): 566 | module.fuse_projections(fuse=True) 567 | 568 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 569 | def unfuse_qkv_projections(self): 570 | """Disables the fused QKV projection if enabled. 571 | 572 | 573 | 574 | This API is 🧪 experimental. 575 | 576 | 577 | 578 | """ 579 | if self.original_attn_processors is not None: 580 | self.set_attn_processor(self.original_attn_processors) 581 | 582 | @classmethod 583 | def from_vae( 584 | cls, 585 | vae: AutoencoderKL, 586 | load_weights: bool = True, 587 | **kwargs, 588 | ): 589 | condition_vae = cls.from_config(vae.config, **kwargs) 590 | if load_weights: 591 | condition_vae.encoder.load_state_dict(vae.encoder.state_dict()) 592 | condition_vae.decoder.load_state_dict(vae.decoder.state_dict()) 593 | condition_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict()) 594 | condition_vae.post_quant_conv.load_state_dict(vae.post_quant_conv.state_dict()) 595 | 596 | condition_vae.conditional_encoder.conv_in.load_state_dict(vae.encoder.conv_in.state_dict()) 597 | condition_vae.conditional_encoder.down_blocks.load_state_dict(vae.encoder.down_blocks.state_dict()) 598 | condition_vae.conditional_encoder.mid_block.load_state_dict(vae.encoder.mid_block.state_dict()) 599 | condition_vae.conditional_encoder.conv_norm_out.load_state_dict(vae.encoder.conv_norm_out.state_dict()) 600 | condition_vae.conditional_encoder.conv_out.load_state_dict(vae.encoder.conv_out.state_dict()) 601 | return condition_vae 602 | 603 | def requires_grad_(self, requires_grad: bool = True , freeze_decoder = False): 604 | r"""Change if autograd should record operations on parameters in this module. 605 | 606 | This method sets the parameters' :attr:`requires_grad` attributes 607 | in-place. 608 | 609 | This method is helpful for freezing part of the module for finetuning 610 | or training parts of a model individually (e.g., GAN training). 
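In this subclass all parameters are first frozen, and only the conditional encoder, the latent skip convolution and, unless `freeze_decoder` is set, the decoder are then set to the given `requires_grad` value.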
611 | 612 | See :ref:`locally-disable-grad-doc` for a comparison between 613 | `.requires_grad_()` and several similar mechanisms that may be confused with it. 614 | 615 | Args: 616 | requires_grad (bool): whether autograd should record operations on 617 | parameters in this module. Default: ``True``. 618 | 619 | Returns: 620 | Module: self 621 | """ 622 | for p in self.parameters(): 623 | p.requires_grad_(False) 624 | self.conditional_encoder.requires_grad_(requires_grad) 625 | self.latent_skip_conv.requires_grad_(requires_grad) 626 | if not freeze_decoder: 627 | self.decoder.requires_grad_(requires_grad) 628 | return self -------------------------------------------------------------------------------- /scripts/train/cvae_with_gen_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import gc 4 | import logging 5 | import math 6 | import os 7 | import random 8 | import shutil 9 | from pathlib import Path 10 | from PIL import Image 11 | 12 | import accelerate 13 | import datasets 14 | import numpy as np 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision.transforms.functional as TF 18 | from torchvision.utils import save_image, make_grid 19 | import torch.utils.checkpoint 20 | import transformers 21 | from accelerate import Accelerator, InitProcessGroupKwargs 22 | from accelerate.logging import get_logger 23 | from accelerate.utils import ProjectConfiguration, set_seed 24 | from accelerate.state import DistributedType 25 | from datasets import load_dataset 26 | from huggingface_hub import create_repo, upload_folder 27 | from packaging import version 28 | from torchvision import transforms 29 | from torchvision.transforms.functional import crop 30 | from tqdm.auto import tqdm 31 | from transformers import AutoTokenizer, PretrainedConfig 32 | 33 | import diffusers 34 | from diffusers import ( 35 | AutoencoderKL, 36 | DDPMScheduler, 37 | UNet2DConditionModel, 38 | ) 39 | from src.pipeline.pipeline_stable_diffusion_xl_harmony import ( 40 | StableDiffusionXLHarmonyPipeline, 41 | ) 42 | from diffusers.optimization import get_scheduler 43 | from diffusers.training_utils import EMAModel, compute_snr 44 | from diffusers.utils import check_min_version, is_wandb_available 45 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 46 | from diffusers.utils.import_utils import is_xformers_available 47 | from diffusers.utils.torch_utils import is_compiled_module 48 | 49 | # from src.dataset.ihd_dataset import IhdDatasetWithSDXLMetadata as Dataset 50 | from src.dataset.harmony_gen import GenHarmonyDataset as Dataset 51 | from src.models.condition_vae import ConditionVAE 52 | 53 | # # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
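# (The version check below is left commented out, so the script relies on whichever diffusers release is installed in the environment.)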
54 | # check_min_version("0.27.0.dev0") 55 | 56 | logger = get_logger(__name__) 57 | 58 | 59 | def get_trainable_parameters(model): 60 | trainable_parameters = [p for p in model.parameters() if p.requires_grad == True] 61 | return trainable_parameters 62 | 63 | 64 | def print_trainable_parameters(model): 65 | trainable_parameters = get_trainable_parameters(model) 66 | size = total_params = sum(p.numel() for p in trainable_parameters) 67 | units = ["B", "K", "M", "G"] 68 | unit_index = 0 69 | while size > 1024 and unit_index < len(units) - 1: 70 | size /= 1024 71 | unit_index += 1 72 | print(f"total trainable params : {size:.3f}{units[unit_index]}") 73 | 74 | 75 | def parse_args(input_args=None): 76 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 77 | parser.add_argument( 78 | "--pretrained_vae_model_name_or_path", 79 | type=str, 80 | default=None, 81 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", 82 | ) 83 | parser.add_argument( 84 | "--pretrained_condition_vae_model_name_or_path", 85 | type=str, 86 | default=None, 87 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", 88 | ) 89 | parser.add_argument( 90 | "--revision", 91 | type=str, 92 | default=None, 93 | required=False, 94 | help="Revision of pretrained model identifier from huggingface.co/models.", 95 | ) 96 | parser.add_argument( 97 | "--variant", 98 | type=str, 99 | default=None, 100 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 101 | ) 102 | 103 | parser.add_argument( 104 | "--dataset_root", 105 | type=str, 106 | default=None, 107 | help=( 108 | "A folder containing the training data. Folder contents must follow the structure described in" 109 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 110 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 111 | ), 112 | ) 113 | parser.add_argument( 114 | "--train_file", 115 | type=str, 116 | default=None, 117 | ) 118 | parser.add_argument( 119 | "--test_file", 120 | type=str, 121 | default=None, 122 | ) 123 | parser.add_argument( 124 | "--output_dir", 125 | type=str, 126 | default="logs/test", 127 | help="The output directory where the model predictions and checkpoints will be written.", 128 | ) 129 | parser.add_argument( 130 | "--seed", type=int, default=42, help="A seed for reproducible training." 
131 | ) 132 | parser.add_argument( 133 | "--resolution", 134 | type=int, 135 | default=1024, 136 | help=( 137 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 138 | " resolution" 139 | ), 140 | ) 141 | parser.add_argument( 142 | "--crop_resolution", 143 | type=int, 144 | default=None, 145 | ) 146 | parser.add_argument( 147 | "--random_crop", 148 | default=False, 149 | action="store_true", 150 | help="whether to randomly flip images horizontally", 151 | ) 152 | parser.add_argument( 153 | "--random_flip", 154 | default=False, 155 | action="store_true", 156 | help="whether to randomly flip images horizontally", 157 | ) 158 | parser.add_argument( 159 | "--train_batch_size", 160 | type=int, 161 | default=16, 162 | help="Batch size (per device) for the training dataloader.", 163 | ) 164 | parser.add_argument( 165 | "--eval_batch_size", 166 | type=int, 167 | default=16, 168 | help="The number of images to generate for evaluation.", 169 | ) 170 | parser.add_argument("--num_train_epochs", type=int, default=1) 171 | parser.add_argument( 172 | "--max_train_steps", 173 | type=int, 174 | default=None, 175 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 176 | ) 177 | parser.add_argument( 178 | "--checkpointing_steps", 179 | type=int, 180 | default=500, 181 | help=( 182 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 183 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 184 | " training using `--resume_from_checkpoint`." 185 | ), 186 | ) 187 | parser.add_argument( 188 | "--checkpointing_epochs", 189 | type=int, 190 | default=None, 191 | ) 192 | parser.add_argument( 193 | "--checkpoints_total_limit", 194 | type=int, 195 | default=None, 196 | help=("Max number of checkpoints to store."), 197 | ) 198 | parser.add_argument( 199 | "--image_logging_epochs", 200 | type=int, 201 | default=1, 202 | ) 203 | parser.add_argument( 204 | "--resume_from_checkpoint", 205 | type=str, 206 | default=None, 207 | help=( 208 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 209 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 210 | ), 211 | ) 212 | parser.add_argument( 213 | "--gradient_accumulation_steps", 214 | type=int, 215 | default=1, 216 | help="Number of updates steps to accumulate before performing a backward/update pass.", 217 | ) 218 | parser.add_argument( 219 | "--gradient_checkpointing", 220 | action="store_true", 221 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 222 | ) 223 | parser.add_argument( 224 | "--learning_rate", 225 | type=float, 226 | default=1e-4, 227 | help="Initial learning rate (after the potential warmup period) to use.", 228 | ) 229 | parser.add_argument( 230 | "--scale_lr", 231 | action="store_true", 232 | default=False, 233 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 234 | ) 235 | parser.add_argument( 236 | "--lr_scheduler", 237 | type=str, 238 | default="constant", 239 | help=( 240 | 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 241 | ' "constant", "constant_with_warmup"]' 242 | ), 243 | ) 244 | parser.add_argument( 245 | "--lr_warmup_steps", 246 | type=int, 247 | default=500, 248 | help="Number of steps for the warmup in the lr scheduler.", 249 | ) 250 | parser.add_argument( 251 | "--lr_warmup_ratio", 252 | type=float, 253 | default=None, 254 | ) 255 | parser.add_argument( 256 | "--use_ema", action="store_true", help="Whether to use EMA model." 257 | ) 258 | parser.add_argument( 259 | "--allow_tf32", 260 | action="store_true", 261 | help=( 262 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 263 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 264 | ), 265 | ) 266 | parser.add_argument( 267 | "--dataloader_num_workers", 268 | type=int, 269 | default=0, 270 | help=( 271 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 272 | ), 273 | ) 274 | parser.add_argument( 275 | "--use_8bit_adam", 276 | action="store_true", 277 | help="Whether or not to use 8-bit Adam from bitsandbytes.", 278 | ) 279 | parser.add_argument( 280 | "--adam_beta1", 281 | type=float, 282 | default=0.9, 283 | help="The beta1 parameter for the Adam optimizer.", 284 | ) 285 | parser.add_argument( 286 | "--adam_beta2", 287 | type=float, 288 | default=0.999, 289 | help="The beta2 parameter for the Adam optimizer.", 290 | ) 291 | parser.add_argument( 292 | "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." 293 | ) 294 | parser.add_argument( 295 | "--adam_epsilon", 296 | type=float, 297 | default=1e-08, 298 | help="Epsilon value for the Adam optimizer", 299 | ) 300 | parser.add_argument( 301 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 302 | ) 303 | parser.add_argument( 304 | "--push_to_hub", 305 | action="store_true", 306 | help="Whether or not to push the model to the Hub.", 307 | ) 308 | parser.add_argument( 309 | "--hub_token", 310 | type=str, 311 | default=None, 312 | help="The token to use to push to the Model Hub.", 313 | ) 314 | parser.add_argument( 315 | "--hub_model_id", 316 | type=str, 317 | default=None, 318 | help="The name of the repository to keep in sync with the local `output_dir`.", 319 | ) 320 | parser.add_argument( 321 | "--logging_dir", 322 | type=str, 323 | default="logs", 324 | help=( 325 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 326 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 327 | ), 328 | ) 329 | parser.add_argument( 330 | "--report_to", 331 | type=str, 332 | default="tensorboard", 333 | help=( 334 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 335 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 336 | ), 337 | ) 338 | parser.add_argument( 339 | "--mixed_precision", 340 | type=str, 341 | default=None, 342 | choices=["no", "fp16", "bf16"], 343 | help=( 344 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 345 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 346 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
347 | ), 348 | ) 349 | parser.add_argument( 350 | "--local_rank", 351 | type=int, 352 | default=-1, 353 | help="For distributed training: local_rank", 354 | ) 355 | 356 | if input_args is not None: 357 | args = parser.parse_args(input_args) 358 | else: 359 | args = parser.parse_args() 360 | 361 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 362 | if env_local_rank != -1 and env_local_rank != args.local_rank: 363 | args.local_rank = env_local_rank 364 | 365 | return args 366 | 367 | def main(args): 368 | if args.report_to == "wandb" and args.hub_token is not None: 369 | raise ValueError( 370 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 371 | " Please use `huggingface-cli login` to authenticate with the Hub." 372 | ) 373 | 374 | logging_dir = Path(args.output_dir, args.logging_dir) 375 | 376 | accelerator_project_config = ProjectConfiguration( 377 | project_dir=args.output_dir, logging_dir=logging_dir 378 | ) 379 | 380 | accelerator = Accelerator( 381 | gradient_accumulation_steps=args.gradient_accumulation_steps, 382 | mixed_precision=args.mixed_precision, 383 | log_with=args.report_to, 384 | project_config=accelerator_project_config, 385 | ) 386 | 387 | if args.report_to == "wandb": 388 | if not is_wandb_available(): 389 | raise ImportError( 390 | "Make sure to install wandb if you want to use it for logging during training." 391 | ) 392 | import wandb 393 | 394 | # Make one log on every process with the configuration for debugging. 395 | logging.basicConfig( 396 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 397 | datefmt="%m/%d/%Y %H:%M:%S", 398 | level=logging.INFO, 399 | ) 400 | logger.info(accelerator.state, main_process_only=False) 401 | if accelerator.is_local_main_process: 402 | datasets.utils.logging.set_verbosity_warning() 403 | transformers.utils.logging.set_verbosity_warning() 404 | diffusers.utils.logging.set_verbosity_info() 405 | else: 406 | datasets.utils.logging.set_verbosity_error() 407 | transformers.utils.logging.set_verbosity_error() 408 | diffusers.utils.logging.set_verbosity_error() 409 | 410 | # If passed along, set the training seed now. 411 | if args.seed is not None: 412 | set_seed(args.seed) 413 | 414 | # Handle the repository creation 415 | if accelerator.is_main_process: 416 | if args.output_dir is not None: 417 | os.makedirs(args.output_dir, exist_ok=True) 418 | 419 | if args.push_to_hub: 420 | repo_id = create_repo( 421 | repo_id=args.hub_model_id or Path(args.output_dir).name, 422 | exist_ok=True, 423 | token=args.hub_token, 424 | ).repo_id 425 | 426 | if args.pretrained_condition_vae_model_name_or_path is not None: 427 | condition_vae = ConditionVAE.from_pretrained( 428 | args.pretrained_condition_vae_model_name_or_path, 429 | ) 430 | else: 431 | vae = AutoencoderKL.from_pretrained( 432 | args.pretrained_vae_model_name_or_path, 433 | ) 434 | condition_vae = ConditionVAE.from_vae(vae, load_weights=True) 435 | condition_vae.train() 436 | condition_vae.requires_grad_(True) 437 | 438 | model = condition_vae 439 | print_trainable_parameters(model) 440 | 441 | # For mixed precision training we cast all non-trainable weights to half-precision 442 | # as these weights are only used for inference, keeping weights in full precision is not required. 
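# (In this script `weight_dtype` is only used further below to cast the fixed evaluation batch before logging sample reconstructions.)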
443 | weight_dtype = torch.float32 444 | if accelerator.mixed_precision == "fp16": 445 | weight_dtype = torch.float16 446 | elif accelerator.mixed_precision == "bf16": 447 | weight_dtype = torch.bfloat16 448 | 449 | if args.use_ema: 450 | ema_model = EMAModel( 451 | model.parameters(), 452 | model_cls=ConditionVAE, 453 | model_config=model.config, 454 | ) 455 | 456 | # `accelerate` 0.16.0 will have better support for customized saving 457 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 458 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 459 | def save_model_hook(models, weights, output_dir): 460 | if accelerator.is_main_process: 461 | if args.use_ema: 462 | ema_model.save_pretrained(os.path.join(output_dir, "condition_vae_ema")) 463 | 464 | for i, model in enumerate(models): 465 | model.save_pretrained(os.path.join(output_dir, "condition_vae")) 466 | 467 | # make sure to pop weight so that corresponding model is not saved again 468 | if weights: 469 | weights.pop() 470 | 471 | def load_model_hook(models, input_dir): 472 | if args.use_ema: 473 | load_model = EMAModel.from_pretrained( 474 | os.path.join(input_dir, "condition_vae_ema"), ConditionVAE 475 | ) 476 | ema_model.load_state_dict(load_model.state_dict()) 477 | ema_model.to(accelerator.device) 478 | del load_model 479 | 480 | for _ in range(len(models)): 481 | # pop models so that they are not loaded again 482 | model = models.pop() 483 | 484 | # load diffusers style into model 485 | load_model = ConditionVAE.from_pretrained( 486 | input_dir, subfolder="condition_vae" 487 | ) 488 | model.register_to_config(**load_model.config) 489 | 490 | model.load_state_dict(load_model.state_dict()) 491 | del load_model 492 | 493 | accelerator.register_save_state_pre_hook(save_model_hook) 494 | accelerator.register_load_state_pre_hook(load_model_hook) 495 | 496 | if args.gradient_checkpointing: 497 | model.enable_gradient_checkpointing() 498 | 499 | # Enable TF32 for faster training on Ampere GPUs, 500 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 501 | if args.allow_tf32: 502 | torch.backends.cuda.matmul.allow_tf32 = True 503 | 504 | if args.scale_lr: 505 | args.learning_rate = ( 506 | args.learning_rate 507 | * args.gradient_accumulation_steps 508 | * args.train_batch_size 509 | * accelerator.num_processes 510 | ) 511 | 512 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 513 | if args.use_8bit_adam: 514 | try: 515 | import bitsandbytes as bnb 516 | except ImportError: 517 | raise ImportError( 518 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
519 | ) 520 | 521 | optimizer_class = bnb.optim.AdamW8bit 522 | else: 523 | optimizer_class = torch.optim.AdamW 524 | 525 | # Optimizer creation 526 | optimizer = optimizer_class( 527 | get_trainable_parameters(model), 528 | lr=args.learning_rate, 529 | betas=(args.adam_beta1, args.adam_beta2), 530 | weight_decay=args.adam_weight_decay, 531 | eps=args.adam_epsilon, 532 | ) 533 | 534 | dataset = Dataset(dataset_root=args.dataset_root, resolution=args.resolution) 535 | train_dataloader = torch.utils.data.DataLoader( 536 | dataset, 537 | batch_size=args.train_batch_size, 538 | shuffle=True, 539 | num_workers=args.dataloader_num_workers, 540 | drop_last=True, 541 | ) 542 | 543 | if accelerator.is_main_process: 544 | eval_batch = next(iter(train_dataloader)) 545 | 546 | # Scheduler and math around the number of training steps. 547 | overrode_max_train_steps = False 548 | num_update_steps_per_epoch = math.ceil( 549 | len(train_dataloader) / args.gradient_accumulation_steps 550 | ) 551 | if args.max_train_steps is None: 552 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 553 | overrode_max_train_steps = True 554 | 555 | overide_lr_warmup_steps = False 556 | if args.lr_warmup_ratio is not None: 557 | overide_lr_warmup_steps = True 558 | args.lr_warmup_steps = math.ceil( 559 | args.lr_warmup_ratio * (args.max_train_steps // accelerator.num_processes) 560 | ) 561 | 562 | lr_scheduler = get_scheduler( 563 | args.lr_scheduler, 564 | optimizer=optimizer, 565 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 566 | num_training_steps=args.max_train_steps * accelerator.num_processes, 567 | ) 568 | 569 | # Prepare everything with our `accelerator`. 570 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 571 | model, optimizer, train_dataloader, lr_scheduler 572 | ) 573 | 574 | if args.use_ema: 575 | ema_model.to(accelerator.device) 576 | 577 | total_batch_size = ( 578 | args.train_batch_size 579 | * accelerator.num_processes 580 | * args.gradient_accumulation_steps 581 | ) 582 | 583 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 584 | num_update_steps_per_epoch = math.ceil( 585 | len(train_dataloader) / args.gradient_accumulation_steps 586 | ) 587 | if overrode_max_train_steps: 588 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 589 | # Afterwards we recalculate our number of training epochs 590 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 591 | 592 | if args.checkpointing_epochs is not None: 593 | args.checkpointing_steps = ( 594 | args.checkpointing_epochs * num_update_steps_per_epoch 595 | ) 596 | args.image_logging_steps = args.image_logging_epochs * num_update_steps_per_epoch 597 | 598 | # We need to initialize the trackers we use, and also store our configuration. 599 | # The trackers initializes automatically on the main process. 600 | if accelerator.is_main_process: 601 | run = os.path.split(__file__)[-1].split(".")[0] 602 | accelerator.init_trackers(run, config=vars(args)) 603 | 604 | # Function for unwrapping if torch.compile() was used in accelerate. 605 | def unwrap_model(model): 606 | model = accelerator.unwrap_model(model) 607 | model = model._orig_mod if is_compiled_module(model) else model 608 | return model 609 | 610 | # Train! 
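# Training loop below: run batch['image'] through the ConditionVAE conditioned on batch['cond'], minimize an L1 loss against batch['target'], optionally update EMA weights, and periodically checkpoint and log sample reconstructions.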
611 | 612 | logger.info("***** Running training *****") 613 | logger.info(f" Num examples = {len(dataset)}") 614 | logger.info(f" Num Epochs = {args.num_train_epochs}") 615 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 616 | logger.info( 617 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 618 | ) 619 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 620 | logger.info(f" Total optimization steps = {args.max_train_steps}") 621 | global_step = 0 622 | first_epoch = 0 623 | 624 | # Potentially load in the weights and states from a previous save 625 | if args.resume_from_checkpoint: 626 | if args.resume_from_checkpoint != "latest": 627 | path = os.path.basename(args.resume_from_checkpoint) 628 | else: 629 | # Get the most recent checkpoint 630 | dirs = os.listdir(args.output_dir) 631 | dirs = [d for d in dirs if d.startswith("checkpoint")] 632 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 633 | path = dirs[-1] if len(dirs) > 0 else None 634 | 635 | if path is None: 636 | accelerator.print( 637 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 638 | ) 639 | args.resume_from_checkpoint = None 640 | initial_global_step = 0 641 | else: 642 | accelerator.print(f"Resuming from checkpoint {path}") 643 | accelerator.load_state(os.path.join(args.output_dir, path)) 644 | global_step = int(path.split("-")[1]) 645 | 646 | initial_global_step = global_step 647 | first_epoch = global_step // num_update_steps_per_epoch 648 | 649 | else: 650 | initial_global_step = 0 651 | 652 | progress_bar = tqdm( 653 | range(0, args.max_train_steps), 654 | initial=initial_global_step, 655 | desc="Steps", 656 | # Only show the progress bar once on each machine. 657 | disable=not accelerator.is_local_main_process, 658 | ) 659 | 660 | for epoch in range(first_epoch, args.num_train_epochs): 661 | train_loss = 0.0 662 | for step, batch in enumerate(train_dataloader): 663 | with accelerator.accumulate(model): 664 | 665 | model_input = batch["image"].to(accelerator.device) 666 | input_cond = batch["cond"].to(accelerator.device) 667 | model_output = model(model_input, sample_posterior=True, cond=input_cond).sample 668 | 669 | target = batch["target"].to(accelerator.device) 670 | loss = F.l1_loss(model_output, target, reduction="mean") 671 | 672 | # Gather the losses across all processes for logging (if we use distributed training). 
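# (`loss.repeat(args.train_batch_size)` expands the scalar to one entry per sample so the gathered mean averages over the global batch.)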
673 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 674 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 675 | 676 | # Backpropagate 677 | accelerator.backward(loss) 678 | if accelerator.sync_gradients: 679 | accelerator.clip_grad_norm_( 680 | get_trainable_parameters(model), args.max_grad_norm 681 | ) 682 | optimizer.step() 683 | lr_scheduler.step() 684 | optimizer.zero_grad() 685 | 686 | # Checks if the accelerator has performed an optimization step behind the scenes 687 | if accelerator.sync_gradients: 688 | if args.use_ema: 689 | ema_model.step(model.parameters()) 690 | progress_bar.update(1) 691 | logs = { 692 | "step_loss": train_loss, 693 | "lr": lr_scheduler.get_last_lr()[0], 694 | "epoch": epoch, 695 | "internal_step": step, 696 | } 697 | if args.use_ema: 698 | logs["ema_decay"] = ema_model.cur_decay_value 699 | accelerator.log(logs, step=global_step) 700 | progress_bar.set_postfix(**logs) 701 | global_step += 1 702 | train_loss = 0.0 703 | 704 | if global_step % args.checkpointing_steps == 0: 705 | if ( 706 | accelerator.is_main_process 707 | and args.checkpoints_total_limit is not None 708 | ): 709 | checkpoints = os.listdir(args.output_dir) 710 | checkpoints = [ 711 | d for d in checkpoints if d.startswith("checkpoint") 712 | ] 713 | checkpoints = sorted( 714 | checkpoints, key=lambda x: int(x.split("-")[1]) 715 | ) 716 | 717 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 718 | if len(checkpoints) >= args.checkpoints_total_limit: 719 | num_to_remove = ( 720 | len(checkpoints) - args.checkpoints_total_limit + 1 721 | ) 722 | removing_checkpoints = checkpoints[0:num_to_remove] 723 | 724 | logger.info( 725 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 726 | ) 727 | logger.info( 728 | f"removing checkpoints: {', '.join(removing_checkpoints)}" 729 | ) 730 | 731 | for removing_checkpoint in removing_checkpoints: 732 | removing_checkpoint = os.path.join( 733 | args.output_dir, removing_checkpoint 734 | ) 735 | shutil.rmtree(removing_checkpoint) 736 | if ( 737 | accelerator.distributed_type == DistributedType.DEEPSPEED 738 | or accelerator.is_main_process 739 | ): 740 | save_path = os.path.join( 741 | args.output_dir, f"checkpoint-{global_step}" 742 | ) 743 | accelerator.save_state(save_path) 744 | logger.info(f"Saved state to {save_path}") 745 | 746 | if accelerator.is_main_process: 747 | unwrapped_model = unwrap_model(model) 748 | if args.use_ema: 749 | ema_model.copy_to(unwrapped_model.parameters()) 750 | 751 | model_to_save = unwrapped_model 752 | model_save_dir = os.path.join( 753 | args.output_dir, f"weights-{global_step}" 754 | ) 755 | os.makedirs(model_save_dir, exist_ok=True) 756 | model_to_save.save_pretrained(model_save_dir) 757 | 758 | # Generate sample images for visual inspection 759 | if (global_step % args.image_logging_steps == 0) and ( 760 | accelerator.is_main_process 761 | ): 762 | condition_vae = unwrap_model(model) 763 | if args.use_ema: 764 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
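# (Here the EMA weights track the ConditionVAE parameters; `ema_model.restore()` below puts the training weights back once the sample images have been saved.)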
765 | ema_model.store(condition_vae.parameters()) 766 | ema_model.copy_to(condition_vae.parameters()) 767 | 768 | generator = ( 769 | torch.Generator(device=accelerator.device).manual_seed( 770 | args.seed 771 | ) 772 | if args.seed 773 | else None 774 | ) 775 | 776 | eval_model_input = eval_batch["image"].to(accelerator.device, dtype=weight_dtype) 777 | eval_input_cond = eval_batch["cond"].to(accelerator.device, dtype=weight_dtype) 778 | eval_target = eval_batch["target"].to(accelerator.device, dtype=weight_dtype) 779 | 780 | with torch.inference_mode(): 781 | eval_conditioned_rec = condition_vae(eval_model_input, sample_posterior=True, cond=eval_input_cond, generator=generator).sample 782 | eval_rec = condition_vae(eval_model_input, sample_posterior=True, generator=generator).sample 783 | 784 | if args.use_ema: 785 | ema_model.restore(condition_vae.parameters()) 786 | 787 | del condition_vae 788 | torch.cuda.empty_cache() 789 | 790 | image_logging_dir = os.path.join(args.output_dir, "images") 791 | if not os.path.exists(image_logging_dir): 792 | os.makedirs(image_logging_dir) 793 | 794 | bs = len(eval_model_input) 795 | nrow = bs // int(math.sqrt(bs)) 796 | 797 | input_to_save = make_grid( 798 | eval_model_input, nrow=nrow, normalize=True, value_range=(-1, 1) 799 | ) 800 | cond_to_save = make_grid( 801 | eval_input_cond, 802 | nrow=nrow, 803 | normalize=True, 804 | value_range=(-1, 1), 805 | ) 806 | rec_to_save = make_grid( 807 | eval_rec, nrow=nrow, normalize=True, value_range=(-1, 1) 808 | ) 809 | conditioned_rec_to_save = make_grid( 810 | eval_conditioned_rec, nrow=nrow, normalize=True, value_range=(-1, 1) 811 | ) 812 | target_to_save = make_grid( 813 | eval_target, nrow=nrow, normalize=True, value_range=(-1, 1) 814 | ) 815 | 816 | save_image( 817 | input_to_save, 818 | os.path.join(image_logging_dir, f"s{global_step:08d}_input.jpg"), 819 | ) 820 | save_image( 821 | cond_to_save, 822 | os.path.join(image_logging_dir, f"s{global_step:08d}_cond.jpg"), 823 | ) 824 | save_image( 825 | rec_to_save, 826 | os.path.join(image_logging_dir, f"s{global_step:08d}_rec.jpg"), 827 | ) 828 | save_image( 829 | conditioned_rec_to_save, 830 | os.path.join(image_logging_dir, f"s{global_step:08d}_conditioned_rec.jpg"), 831 | ) 832 | save_image( 833 | target_to_save, 834 | os.path.join(image_logging_dir, f"s{global_step:08d}_target.jpg"), 835 | ) 836 | if global_step >= args.max_train_steps: 837 | break 838 | progress_bar.close() 839 | accelerator.wait_for_everyone() 840 | accelerator.end_training() 841 | 842 | 843 | if __name__ == "__main__": 844 | args = parse_args() 845 | main(args) 846 | --------------------------------------------------------------------------------
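A minimal usage sketch (not part of the repository), assuming the conditioning tensor is the composite image concatenated with its foreground mask (matching the default `additional_in_channels=1`) and starting from an off-the-shelf Stable Diffusion VAE checkpoint:

import torch
from diffusers import AutoencoderKL
from src.models.condition_vae import ConditionVAE

# Assumed checkpoint; any AutoencoderKL with a compatible config should work.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
cvae = ConditionVAE.from_vae(vae, load_weights=True).eval()

comp = torch.randn(1, 3, 256, 256)      # composite image, values in [-1, 1]
mask = torch.ones(1, 1, 256, 256)       # foreground mask, values in {0, 1}
cond = torch.cat([comp, mask], dim=1)   # assumed conditioning layout (3 + 1 channels)

with torch.no_grad():
    # forward() encodes `comp`, then decodes it with skip features from the conditional encoder
    rec = cvae(comp, sample_posterior=False, cond=cond).sample

print(rec.shape)  # torch.Size([1, 3, 256, 256])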