├── .gitignore ├── model-hf └── put hf model here.txt ├── model-sd └── put sd model here.txt ├── .gitmodules ├── scripts ├── conda-env.sh └── install.ipynb ├── accelerate.sh ├── install.sh ├── README.md ├── v1-inference.yaml ├── convert_v3.py ├── back_convert.py ├── dreambooth-aki.ipynb ├── back_convert_alt.py └── train_dreambooth.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints -------------------------------------------------------------------------------- /model-hf/put hf model here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model-sd/put sd model here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "diffusers"] 2 | path = diffusers 3 | url = https://github.com/ShivamShrirao/diffusers 4 | -------------------------------------------------------------------------------- /scripts/conda-env.sh: -------------------------------------------------------------------------------- 1 | conda create -n diffusers python=3.10 2 | conda init bash && source /root/.bashrc 3 | 4 | # 将新的Conda虚拟环境加入jupyterlab中 5 | conda activate diffusers 6 | conda install ipykernel 7 | ipython kernel install --user --name=diffusers -------------------------------------------------------------------------------- /accelerate.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/.cache/huggingface/accelerate 2 | 3 | cat > ~/.cache/huggingface/accelerate/default_config.yaml <<- EOM 4 | compute_environment: LOCAL_MACHINE 5 | deepspeed_config: {} 6 | distributed_type: 'NO' 7 | downcast_bf16: 'no' 8 | fsdp_config: {} 9 | machine_rank: 0 10 | main_process_ip: null 11 | main_process_port: null 12 | main_training_function: main 13 | mixed_precision: fp16 14 | num_machines: 1 15 | num_processes: 1 16 | use_cpu: false 17 | EOM -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | git submodule init 2 | git submodule update 3 | 4 | pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple/ 5 | pip config set global.trusted-host https://pypi.tuna.tsinghua.edu.cn/simple/ 6 | pip install ./diffusers 7 | pip install -U --pre triton 8 | pip install accelerate==0.12.0 transformers==4.24.0 ftfy==6.1.1 bitsandbytes==0.35.4 omegaconf==2.2.3 einops==0.5.0 pytorch-lightning==1.7.7 gradio 9 | pip install torchvision 10 | 11 | mkdir "instance-images" "class-images" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dreambooth-autodl 2 | 3 | dreambooth autodl 训练脚本。 4 | 修改自 [Nyanko Lepsoni 的 Colab 笔记本](https://colab.research.google.com/drive/17yM4mlPVOFdJE_81oWBz5mXH9cxvhmz8) 5 | 6 | ## 使用方法 7 | 8 | ### 直接使用autodl镜像 9 | 10 | [dreambooth-autodl](https://www.codewithgpu.com/i/Akegarasu/dreambooth-autodl/dreambooth-autodl) 11 | 12 | ### 手动部署 13 | 14 | 环境选择 Miniconda / conda3 / 3.8(ubuntu20.04) / 11.3 15 | 16 | clone本项目后,首先利用 conda 创建 python 运行环境后再运行 `install.sh` 17 | 18 | ```sh 19 | git clone https://github.com/Akegarasu/dreambooth-autodl.git 20 | cd dreambooth-audodl 21 | conda create -n diffusers python=3.10 22 | conda init bash && source /root/.bashrc 23 | conda activate diffusers 24 | conda install ipykernel 25 | ipython kernel install --user --name=diffusers 26 | bash install.sh 27 | ``` 28 | 29 | 按照提示安装依赖。 30 | 31 | 将本项目文件夹移动到 `/autodl-tmp` 后打开 `dreambooth-aki.ipynb` 运行训练 -------------------------------------------------------------------------------- /scripts/install.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "71eaac90-3419-4980-82c1-c999d742ccb6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# 安装环境\n", 11 | "## 安装环境\n", 12 | "!git clone https://github.com/ShivamShrirao/diffusers\n", 13 | "%pip config list\n", 14 | "%pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple/\n", 15 | "%pip config set global.trusted-host https://pypi.tuna.tsinghua.edu.cn/simple/\n", 16 | "# !pip install -U pip\n", 17 | "%pip install -q ./diffusers\n", 18 | "%pip install -q -U --pre triton\n", 19 | "%pip install -q accelerate==0.12.0 transformers==4.24.0 ftfy==6.1.1 bitsandbytes==0.35.4 omegaconf==2.2.3 einops==0.5.0 pytorch-lightning==1.7.7 gradio\n", 20 | "%pip install -q torchvision" 21 | ] 22 | } 23 | ], 24 | "metadata": { 25 | "kernelspec": { 26 | "display_name": "diffusers", 27 | "language": "python", 28 | "name": "diffusers" 29 | }, 30 | "language_info": { 31 | "codemirror_mode": { 32 | "name": "ipython", 33 | "version": 3 34 | }, 35 | "file_extension": ".py", 36 | "mimetype": "text/x-python", 37 | "name": "python", 38 | "nbconvert_exporter": "python", 39 | "pygments_lexer": "ipython3", 40 | "version": "3.10.6" 41 | } 42 | }, 43 | "nbformat": 4, 44 | "nbformat_minor": 5 45 | } 46 | -------------------------------------------------------------------------------- /v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /convert_v3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from io import BytesIO 4 | from typing import Optional 5 | import safetensors.torch 6 | 7 | from omegaconf import OmegaConf 8 | import requests 9 | import torch 10 | from transformers import ( 11 | CLIPTextModel, 12 | CLIPTextConfig, 13 | CLIPTokenizer 14 | ) 15 | from diffusers import ( 16 | AutoencoderKL, 17 | DDIMScheduler, 18 | UNet2DConditionModel, 19 | StableDiffusionPipeline 20 | ) 21 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( 22 | convert_ldm_vae_checkpoint, 23 | convert_open_clip_checkpoint, 24 | convert_ldm_clip_checkpoint, 25 | convert_ldm_unet_checkpoint, 26 | create_unet_diffusers_config, 27 | create_vae_diffusers_config 28 | ) 29 | 30 | 31 | def load_model(path): 32 | if path.endswith(".safetensors"): 33 | m = safetensors.torch.load_file(path, device="cpu") 34 | else: 35 | m = torch.load(path, map_location="cpu") 36 | state_dict = m["state_dict"] if "state_dict" in m else m 37 | return state_dict 38 | 39 | 40 | def convert_to_df(checkpoint, config_path="./v1-inference.yaml", return_pipe=False, extract_ema=False): 41 | # key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" 42 | # key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" 43 | # key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" 44 | 45 | global_step = None 46 | if "global_step" in checkpoint: 47 | global_step = checkpoint["global_step"] 48 | 49 | # model_type = "v1" 50 | # config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 51 | upcast_attention = None 52 | # if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: 53 | # # model_type = "v2" 54 | # config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" 55 | 56 | # if global_step == 110000: 57 | # # v2.1 needs to upcast attention 58 | # upcast_attention = True 59 | # elif key_name_sd_xl_base in checkpoint: 60 | # # only base xl has two text embedders 61 | # config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" 62 | # elif key_name_sd_xl_refiner in checkpoint: 63 | # # only refiner xl has embedder and one text embedders 64 | # config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" 65 | 66 | # original_config_file = BytesIO(requests.get(config_url).content) 67 | original_config_file = BytesIO(open(config_path, "rb").read()) 68 | original_config = OmegaConf.load(original_config_file) 69 | 70 | # Convert the text model. 71 | if ( 72 | "cond_stage_config" in original_config.model.params 73 | and original_config.model.params.cond_stage_config is not None 74 | ): 75 | model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] 76 | elif original_config.model.params.network_config is not None: 77 | if original_config.model.params.network_config.params.context_dim == 2048: 78 | model_type = "SDXL" 79 | else: 80 | model_type = "SDXL-Refiner" 81 | 82 | if ( 83 | "parameterization" in original_config["model"]["params"] 84 | and original_config["model"]["params"]["parameterization"] == "v" 85 | ): 86 | if prediction_type is None: 87 | # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` 88 | # as it relies on a brittle global step parameter here 89 | prediction_type = "epsilon" if global_step == 875000 else "v_prediction" 90 | if image_size is None: 91 | # NOTE: For stable diffusion 2 base one has to pass `image_size==512` 92 | # as it relies on a brittle global step parameter here 93 | image_size = 512 if global_step == 875000 else 768 94 | else: 95 | prediction_type = "epsilon" 96 | image_size = 512 97 | 98 | num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 99 | beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 100 | beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 101 | scheduler = DDIMScheduler( 102 | beta_end=beta_end, 103 | beta_schedule="scaled_linear", 104 | beta_start=beta_start, 105 | num_train_timesteps=num_train_timesteps, 106 | steps_offset=1, 107 | clip_sample=False, 108 | set_alpha_to_one=False, 109 | prediction_type=prediction_type, 110 | ) 111 | # make sure scheduler works correctly with DDIM 112 | scheduler.register_to_config(clip_sample=False) 113 | 114 | # Convert the UNet2DConditionModel model. 115 | unet_config = create_unet_diffusers_config(original_config, image_size=image_size) 116 | unet_config["upcast_attention"] = upcast_attention 117 | unet = UNet2DConditionModel(**unet_config) 118 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema) 119 | 120 | # Convert the VAE model. 121 | vae_config = create_vae_diffusers_config(original_config, image_size=image_size) 122 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) 123 | 124 | if model_type == "FrozenOpenCLIPEmbedder": 125 | text_model = convert_open_clip_checkpoint(checkpoint) 126 | tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") 127 | elif model_type == "FrozenCLIPEmbedder": 128 | keys = list(checkpoint.keys()) 129 | text_model_dict = {} 130 | for key in keys: 131 | if key.startswith("cond_stage_model.transformer"): 132 | dest_key = key[len("cond_stage_model.transformer."):] 133 | if "text_model" not in dest_key: 134 | dest_key = f"text_model.{dest_key}" 135 | text_model_dict[dest_key] = checkpoint[key] 136 | 137 | text_model = CLIPTextModel(CLIPTextConfig.from_pretrained("openai/clip-vit-large-patch14")) 138 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 139 | if "text_model.embeddings.position_ids" not in text_model.state_dict().keys() \ 140 | and "text_model.embeddings.position_ids" in text_model_dict.keys(): 141 | del text_model_dict["text_model.embeddings.position_ids"] 142 | 143 | if len(text_model_dict) < 10: 144 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 145 | 146 | if not return_pipe: 147 | return converted_unet_checkpoint, converted_vae_checkpoint, text_model_dict 148 | else: 149 | vae = AutoencoderKL(**vae_config) 150 | vae.load_state_dict(converted_vae_checkpoint) 151 | unet.load_state_dict(converted_unet_checkpoint) 152 | text_model.load_state_dict(text_model_dict) 153 | pipe = StableDiffusionPipeline( 154 | unet=unet, 155 | vae=vae, 156 | text_encoder=text_model, 157 | tokenizer=tokenizer, 158 | scheduler=scheduler, 159 | safety_checker=None, 160 | feature_extractor=None, 161 | requires_safety_checker=False, 162 | ) 163 | 164 | return pipe 165 | 166 | 167 | if __name__ == "__main__": 168 | parser = argparse.ArgumentParser() 169 | 170 | parser.add_argument( 171 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 172 | ) 173 | parser.add_argument( 174 | "--extract_ema", 175 | action="store_true", 176 | default=False, 177 | help=( 178 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" 179 | " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" 180 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." 181 | ), 182 | ) 183 | # parser.add_argument( 184 | # "--vae_path", default=None, type=str, help="Path to the vae to convert." 185 | # ) 186 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 187 | parser.add_argument( 188 | "--original_config_file", 189 | default=None, 190 | type=str, 191 | help="The YAML config file corresponding to the original architecture.", 192 | ) 193 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 194 | 195 | args = parser.parse_args() 196 | 197 | if args.original_config_file is None: 198 | if not os.path.exists("./v1-inference.yaml"): 199 | os.system( 200 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 201 | ) 202 | args.original_config_file = "./v1-inference.yaml" 203 | 204 | pipe = convert_to_df(load_model(args.checkpoint_path), config_path=args.original_config_file, return_pipe=True, extract_ema=args.extract_ema) 205 | pipe.save_pretrained(args.dump_path) 206 | -------------------------------------------------------------------------------- /back_convert.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | 5 | import argparse 6 | import os.path as osp 7 | import re 8 | 9 | import torch 10 | 11 | 12 | # =================# 13 | # UNet Conversion # 14 | # =================# 15 | 16 | unet_conversion_map = [ 17 | # (stable-diffusion, HF Diffusers) 18 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 19 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 20 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 21 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 22 | ("input_blocks.0.0.weight", "conv_in.weight"), 23 | ("input_blocks.0.0.bias", "conv_in.bias"), 24 | ("out.0.weight", "conv_norm_out.weight"), 25 | ("out.0.bias", "conv_norm_out.bias"), 26 | ("out.2.weight", "conv_out.weight"), 27 | ("out.2.bias", "conv_out.bias"), 28 | ] 29 | 30 | unet_conversion_map_resnet = [ 31 | # (stable-diffusion, HF Diffusers) 32 | ("in_layers.0", "norm1"), 33 | ("in_layers.2", "conv1"), 34 | ("out_layers.0", "norm2"), 35 | ("out_layers.3", "conv2"), 36 | ("emb_layers.1", "time_emb_proj"), 37 | ("skip_connection", "conv_shortcut"), 38 | ] 39 | 40 | unet_conversion_map_layer = [] 41 | # hardcoded number of downblocks and resnets/attentions... 42 | # would need smarter logic for other networks. 43 | for i in range(4): 44 | # loop over downblocks/upblocks 45 | 46 | for j in range(2): 47 | # loop over resnets/attentions for downblocks 48 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 49 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 50 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 51 | 52 | if i < 3: 53 | # no attention layers in down_blocks.3 54 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 55 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 56 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 57 | 58 | for j in range(3): 59 | # loop over resnets/attentions for upblocks 60 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 61 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 62 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 63 | 64 | if i > 0: 65 | # no attention layers in up_blocks.0 66 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 67 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 68 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 69 | 70 | if i < 3: 71 | # no downsample in down_blocks.3 72 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 73 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 74 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 75 | 76 | # no upsample in up_blocks.3 77 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 78 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 79 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 80 | 81 | hf_mid_atn_prefix = "mid_block.attentions.0." 82 | sd_mid_atn_prefix = "middle_block.1." 83 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 84 | 85 | for j in range(2): 86 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 87 | sd_mid_res_prefix = f"middle_block.{2*j}." 88 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 89 | 90 | 91 | def convert_unet_state_dict(unet_state_dict): 92 | # buyer beware: this is a *brittle* function, 93 | # and correct output requires that all of these pieces interact in 94 | # the exact order in which I have arranged them. 95 | mapping = {k: k for k in unet_state_dict.keys()} 96 | for sd_name, hf_name in unet_conversion_map: 97 | mapping[hf_name] = sd_name 98 | for k, v in mapping.items(): 99 | if "resnets" in k: 100 | for sd_part, hf_part in unet_conversion_map_resnet: 101 | v = v.replace(hf_part, sd_part) 102 | mapping[k] = v 103 | for k, v in mapping.items(): 104 | for sd_part, hf_part in unet_conversion_map_layer: 105 | v = v.replace(hf_part, sd_part) 106 | mapping[k] = v 107 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 108 | return new_state_dict 109 | 110 | 111 | # ================# 112 | # VAE Conversion # 113 | # ================# 114 | 115 | vae_conversion_map = [ 116 | # (stable-diffusion, HF Diffusers) 117 | ("nin_shortcut", "conv_shortcut"), 118 | ("norm_out", "conv_norm_out"), 119 | ("mid.attn_1.", "mid_block.attentions.0."), 120 | ] 121 | 122 | for i in range(4): 123 | # down_blocks have two resnets 124 | for j in range(2): 125 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 126 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 127 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 128 | 129 | if i < 3: 130 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 131 | sd_downsample_prefix = f"down.{i}.downsample." 132 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 133 | 134 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 135 | sd_upsample_prefix = f"up.{3-i}.upsample." 136 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 137 | 138 | # up_blocks have three resnets 139 | # also, up blocks in hf are numbered in reverse from sd 140 | for j in range(3): 141 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 142 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 143 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 144 | 145 | # this part accounts for mid blocks in both the encoder and the decoder 146 | for i in range(2): 147 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 148 | sd_mid_res_prefix = f"mid.block_{i+1}." 149 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 150 | 151 | 152 | vae_conversion_map_attn = [ 153 | # (stable-diffusion, HF Diffusers) 154 | ("norm.", "group_norm."), 155 | ("q.", "query."), 156 | ("k.", "key."), 157 | ("v.", "value."), 158 | ("proj_out.", "proj_attn."), 159 | ] 160 | 161 | 162 | def reshape_weight_for_sd(w): 163 | # convert HF linear weights to SD conv2d weights 164 | return w.reshape(*w.shape, 1, 1) 165 | 166 | 167 | def convert_vae_state_dict(vae_state_dict): 168 | mapping = {k: k for k in vae_state_dict.keys()} 169 | for k, v in mapping.items(): 170 | for sd_part, hf_part in vae_conversion_map: 171 | v = v.replace(hf_part, sd_part) 172 | mapping[k] = v 173 | for k, v in mapping.items(): 174 | if "attentions" in k: 175 | for sd_part, hf_part in vae_conversion_map_attn: 176 | v = v.replace(hf_part, sd_part) 177 | mapping[k] = v 178 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 179 | weights_to_convert = ["q", "k", "v", "proj_out"] 180 | for k, v in new_state_dict.items(): 181 | for weight_name in weights_to_convert: 182 | if f"mid.attn_1.{weight_name}.weight" in k: 183 | print(f"Reshaping {k} for SD format") 184 | new_state_dict[k] = reshape_weight_for_sd(v) 185 | return new_state_dict 186 | 187 | 188 | # =========================# 189 | # Text Encoder Conversion # 190 | # =========================# 191 | 192 | 193 | textenc_conversion_lst = [ 194 | # (stable-diffusion, HF Diffusers) 195 | ("resblocks.", "text_model.encoder.layers."), 196 | ("ln_1", "layer_norm1"), 197 | ("ln_2", "layer_norm2"), 198 | (".c_fc.", ".fc1."), 199 | (".c_proj.", ".fc2."), 200 | (".attn", ".self_attn"), 201 | ("ln_final.", "transformer.text_model.final_layer_norm."), 202 | ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), 203 | ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), 204 | ] 205 | protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst} 206 | textenc_pattern = re.compile("|".join(protected.keys())) 207 | 208 | # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp 209 | code2idx = {"q": 0, "k": 1, "v": 2} 210 | 211 | 212 | def convert_text_enc_state_dict_v20(text_enc_dict): 213 | new_state_dict = {} 214 | capture_qkv_weight = {} 215 | capture_qkv_bias = {} 216 | for k, v in text_enc_dict.items(): 217 | if ( 218 | k.endswith(".self_attn.q_proj.weight") 219 | or k.endswith(".self_attn.k_proj.weight") 220 | or k.endswith(".self_attn.v_proj.weight") 221 | ): 222 | k_pre = k[: -len(".q_proj.weight")] 223 | k_code = k[-len("q_proj.weight")] 224 | if k_pre not in capture_qkv_weight: 225 | capture_qkv_weight[k_pre] = [None, None, None] 226 | capture_qkv_weight[k_pre][code2idx[k_code]] = v 227 | continue 228 | 229 | if ( 230 | k.endswith(".self_attn.q_proj.bias") 231 | or k.endswith(".self_attn.k_proj.bias") 232 | or k.endswith(".self_attn.v_proj.bias") 233 | ): 234 | k_pre = k[: -len(".q_proj.bias")] 235 | k_code = k[-len("q_proj.bias")] 236 | if k_pre not in capture_qkv_bias: 237 | capture_qkv_bias[k_pre] = [None, None, None] 238 | capture_qkv_bias[k_pre][code2idx[k_code]] = v 239 | continue 240 | 241 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k) 242 | new_state_dict[relabelled_key] = v 243 | 244 | for k_pre, tensors in capture_qkv_weight.items(): 245 | if None in tensors: 246 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 247 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 248 | new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) 249 | 250 | for k_pre, tensors in capture_qkv_bias.items(): 251 | if None in tensors: 252 | raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") 253 | relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) 254 | new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) 255 | 256 | return new_state_dict 257 | 258 | 259 | def convert_text_enc_state_dict(text_enc_dict): 260 | return text_enc_dict 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser() 265 | 266 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 267 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 268 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 269 | 270 | args = parser.parse_args() 271 | 272 | assert args.model_path is not None, "Must provide a model path!" 273 | 274 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 275 | 276 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 277 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 278 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 279 | 280 | # Convert the UNet model 281 | unet_state_dict = torch.load(unet_path, map_location="cpu") 282 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 283 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 284 | 285 | # Convert the VAE model 286 | vae_state_dict = torch.load(vae_path, map_location="cpu") 287 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 288 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 289 | 290 | # Convert the text encoder model 291 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") 292 | 293 | # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper 294 | is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict 295 | 296 | if is_v20_model: 297 | # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm 298 | text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} 299 | text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) 300 | text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} 301 | else: 302 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 303 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} 304 | 305 | # Put together new checkpoint 306 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 307 | state_dict["cond_stage_model.transformer.text_model.embeddings.position_ids"] = torch.Tensor([list(range(77))]).to(torch.int64) 308 | 309 | if args.half: 310 | state_dict = {k: v.half() for k, v in state_dict.items()} 311 | state_dict = {"state_dict": state_dict} 312 | torch.save(state_dict, args.checkpoint_path) -------------------------------------------------------------------------------- /dreambooth-aki.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a0b34c19-4215-46f9-9def-65e73629665c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Dreambooth Stable Diffusion 一键训练" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "da847853-7fd5-4ca2-ab72-13d5885e99bd", 14 | "metadata": {}, 15 | "source": [ 16 | "## 相关说明\n", 17 | "\n", 18 | "详细使用教程 [bilibili 秋葉aaaki](https://www.bilibili.com/video/BV1SR4y1y7Lv/)\n", 19 | "\n", 20 | "修改自 [Nyanko Lepsoni 的 Colab 笔记本](https://colab.research.google.com/drive/17yM4mlPVOFdJE_81oWBz5mXH9cxvhmz8)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "aa154e14-e812-4604-8980-9762e9563b32", 26 | "metadata": { 27 | "tags": [] 28 | }, 29 | "source": [ 30 | "## 准备全局变量" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "2260f499-85cf-481e-b7a3-97bc212ff956", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import sys\n", 41 | "import os\n", 42 | "\n", 43 | "# 本镜像专属\n", 44 | "os.environ[\"PATH\"] = f'/root/miniconda3/envs/diffusers/bin:{os.environ[\"PATH\"]}'\n", 45 | "os.environ[\"HF_HOME\"] = \".cache\"\n", 46 | "DB_SCRIPT_WORK_PATH = os.getcwd() # \"/root/autodl-tmp/dreambooth-aki\"\n", 47 | "\n", 48 | "!python --version\n", 49 | "%cd $DB_SCRIPT_WORK_PATH\n", 50 | "\n", 51 | "TRAINER = \"train_dreambooth.py\"\n", 52 | "CONVERTER = \"convert_v3.py\"\n", 53 | "BACK_CONVERTER = \"back_convert.py\"\n", 54 | "\n", 55 | "SRC_PATH = \"./model-sd\"\n", 56 | "MODEL_NAME = \"./model-hf\"\n", 57 | "\n", 58 | "# 模型保存路径\n", 59 | "OUTPUT_DIR = \"./output\"\n", 60 | "!mkdir -p $OUTPUT_DIR" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "def72b19-9851-400f-8672-48023b3e95fb", 66 | "metadata": {}, 67 | "source": [ 68 | "## 转换ckpt文件\n", 69 | "\n", 70 | "镜像里,我已经帮你转换好了animefull-final-prune这个模型。并且镜像为了节省空间,并没有自带未转换的模型。\n", 71 | "**如果有model-hf这个文件夹,那就不需要运行这个转换模型了。**" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "b92e663d-7cf1-408e-911b-22270ac8a388", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# 这步骤有些慢,没准要等个几分钟\n", 82 | "SOURCE_CHECKPOINT_PATH = f\"./model-sd/model.safetensors\" # 源模型位置\n", 83 | "!python $CONVERTER --checkpoint_path $SOURCE_CHECKPOINT_PATH --dump_path $MODEL_NAME" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "746137f4-b054-43de-a2d1-6ebde5cdb3aa", 89 | "metadata": {}, 90 | "source": [ 91 | "## 配置dreambooth训练提示词" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "c0645e79-84f5-491e-8fba-25900beb7c7b", 97 | "metadata": {}, 98 | "source": [ 99 | "以训练人物为例,\n", 100 | "\n", 101 | "INSTANCE_PROMPT 中填入 bocchitherock girl\n", 102 | "这个bocchitherock需要你自己替换为要训练的tag名。你可以指定任意tag,但是需要找一个“不存在的词”。这里的bocchitherock是我做示范用的训练“孤独摇滚”中人物写的一个tag。\n", 103 | "注意:不要再用别的教程里的**sks**了,这个sks是一把枪的名字,可能会生成的时候带上这把枪\n", 104 | "\n", 105 | "CLASS_PROMPT 是让AI自动生成class image用的tag。复制一份INSTANCE_PROMPT,删掉你学习的tag即可。比如这里删掉了bocchitherock\n", 106 | "\n", 107 | "同理,下面的预览图tag设置也记得改。\n", 108 | "\n", 109 | "训练画风时,可以不需要 INSTANCE_PROMPT,直接删空引号内的内容。" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "e425bf69-1427-4f5a-858c-f0e36a42518d", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# INSTANCE_PROMPT\n", 120 | "INSTANCE_PROMPT = \"bocchitherock girl\"\n", 121 | "INSTANCE_DIR = \"./instance-images\"\n", 122 | "\n", 123 | "# class image 设置\n", 124 | "CLASS_PROMPT = \"masterpiece, best quality, 1girl\"\n", 125 | "CLASS_NEGATIVE_PROMPT = \"lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\"\n", 126 | "CLASS_DIR = \"./class-images\"\n", 127 | "\n", 128 | "# 预览图tag设置\n", 129 | "SAVE_SAMPLE_PROMPT = \"masterpiece, best quality, bocchitherock 1girl, looking at viewer\"\n", 130 | "SAVE_SAMPLE_NEGATIVE_PROMPT = \"lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\"" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "f1e7d7e3-35dc-4f8e-b444-0705556c67fa", 136 | "metadata": {}, 137 | "source": [ 138 | "## 训练数据可视化\n", 139 | "\n", 140 | "这里预留了 wandb 和 tensorboard 的可视化。如果你不知道这是什么,就不要改了,直接点运行即可。我已经默认打开了tensorboard。\n", 141 | "如果你会用 wandb,那么可以填写apikey并且将`use_wandb`改为`True`。" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "67b489a6-eff7-42ff-80f6-823457c097d7", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "use_tensorboard = True \n", 152 | "use_wandb = False\n", 153 | "save_weights_to_wandb = False\n", 154 | "wandb_apikey = \"\"\n", 155 | "\n", 156 | "if use_wandb:\n", 157 | " if wandb_apikey == \"\":\n", 158 | " raise ValueError('Invalid wandb.ai APIKey')\n", 159 | " !python -m wandb login $wandb_apikey\n", 160 | "\n", 161 | "if use_tensorboard:\n", 162 | " !rm -rf /tmp/.tensorboard-info/\n", 163 | " %load_ext tensorboard\n", 164 | " %tensorboard --logdir $OUTPUT_DIR/logs" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "b2fa5cd8-b393-4e0c-af88-72677e29830c", 170 | "metadata": {}, 171 | "source": [ 172 | "## 配置accelerate" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "fe981666-70d3-4f2d-b51a-969e12911a1a", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "!./accelerate.sh" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "327eac82-bc9f-4da5-89b6-5362d26fc72b", 188 | "metadata": {}, 189 | "source": [ 190 | "## 设置训练参数" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "id": "b0859292-f678-4ec1-aa67-4143a8f7f8e2", 196 | "metadata": {}, 197 | "source": [ 198 | "### max_train_steps\n", 199 | "训练步数\n", 200 | "\n", 201 | "### learning_rate\n", 202 | "学习率\n", 203 | "这里设置的5e-6是科学计数法的(5乘以10的-6次方)。一般就用这个值就可以了,有时候这个默认值有点大,可以小一些比如3e-6。如果你还是觉得太大可以缩小到1e-6、甚至是5e-7等等。\n", 204 | "\n", 205 | "### lr_scheduler\n", 206 | "学习率调整策略\n", 207 | "一般 lr_scheduler 就用cosine、cosine_with_restarts 就可以了。\n", 208 | "想了解更多关于 lr_scheduler 可以看看这个 [知乎](https://www.zhihu.com/question/315772308/answer/2384958925)\n", 209 | "\n", 210 | "### batch_size\n", 211 | "一般是1。我推荐不要超过3。调整 batch_size 需要同时调整学习率\n", 212 | "详情参考我的视频 [BV1A8411775m](https://www.bilibili.com/video/BV1A8411775m/)\n", 213 | "\n", 214 | "### num_class_images\n", 215 | "class image 的数量。如果 class-images 文件夹内的图片数量小于这个值,则会 AI 自动生成一些图片。\n", 216 | "如果关闭了下面的 with_prior_preservation,那么这个参数就没用了。\n", 217 | "\n", 218 | "### with_prior_preservation\n", 219 | "关闭了这个参数以后,训练将不会再用 class images,变为 native training。训练画风需要关闭这个参数\n", 220 | "\n", 221 | "更深入的细节可以参考这个 [DreamBooth讲解](https://guide.novelai.dev/advanced/finetuning/dreambooth)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "1b39775a-9bb6-4f4e-be84-d0558d25befc", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "# 常用参数\n", 232 | "## 最大训练步数\n", 233 | "max_train_steps = 3000\n", 234 | "## 学习率调整\n", 235 | "learning_rate = 5e-6\n", 236 | "## 学习率调整策略\n", 237 | "## [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\", \"cosine_with_restarts_mod\", \"cosine_mod\"]\n", 238 | "lr_scheduler = \"cosine_with_restarts\"\n", 239 | "lr_warmup_steps = 100\n", 240 | "train_batch_size = 1 # batch_size\n", 241 | "num_class_images = 20 # 自动生成的 class_images 数量\n", 242 | "\n", 243 | "with_prior_preservation = True\n", 244 | "train_text_encoder = False # 训练文本编码器\n", 245 | "use_aspect_ratio_bucket = False # 使用 ARB\n", 246 | "\n", 247 | "# 从文件名读取 prompt\n", 248 | "read_prompt_from_filename = False\n", 249 | "# 从 txt 读取prompt\n", 250 | "read_prompt_from_txt = False\n", 251 | "append_prompt = \"instance\"\n", 252 | "# 保存间隔\n", 253 | "save_interval = 500\n", 254 | "# 使用deepdanbooru\n", 255 | "use_deepdanbooru = False\n", 256 | "\n", 257 | "# 高级参数\n", 258 | "resolution = 512\n", 259 | "gradient_accumulation_steps = 1\n", 260 | "seed = 1337\n", 261 | "log_interval = 10\n", 262 | "clip_skip = 1\n", 263 | "sample_batch_size = 4\n", 264 | "prior_loss_weight = 1.0\n", 265 | "scale_lr = False\n", 266 | "scale_lr_sqrt = False\n", 267 | "gradient_checkpointing = True\n", 268 | "pad_tokens = False\n", 269 | "debug_arb = False\n", 270 | "debug_prompt = False\n", 271 | "use_ema = False\n", 272 | "#only works with _mod scheduler\n", 273 | "restart_cycle = 1\n", 274 | "last_epoch = -1" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "31fe1929-395e-4175-97b7-096b8c1133ee", 280 | "metadata": {}, 281 | "source": [ 282 | "## 如果是从中途继续训练,需要运行下面这个\n", 283 | "\n", 284 | "如果是继续训练就更改这个路径到想继续训练的模型文件夹然后运行这个" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "id": "95c1da3b-c6f2-4312-98a5-dea2dd73c8d2", 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "MODEL_NAME = \"./output/checkpoint_last\"" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "2a6845fc-799f-455e-a723-6fa68e7cb523", 300 | "metadata": {}, 301 | "source": [ 302 | "## 启动训练" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "80fe6d54-783b-4f7b-86ae-3be589e3c718", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "print(f\"[*] 模型源路径 {MODEL_NAME}\")\n", 313 | "print(f\"[*] 训练好的模型将会保存在这个路径 {OUTPUT_DIR}\")\n", 314 | "\n", 315 | "ema_arg = \"--use_ema\" if use_ema else \"\"\n", 316 | "da_arg = \"--debug_arb\" if debug_arb else \"\"\n", 317 | "db_arg = \"--debug_prompt\" if debug_prompt else \"\"\n", 318 | "pd_arg = \"--pad_tokens\" if pad_tokens else \"\"\n", 319 | "gdc_arg = \"--gradient_checkpointing\" if gradient_checkpointing else \"\"\n", 320 | "dp_arg = \"--deepdanbooru\" if use_deepdanbooru else \"\" \n", 321 | "scale_lr_arg = \"--scale_lr\" if scale_lr else \"\"\n", 322 | "wandb_arg = \"--wandb\" if use_wandb else \"\"\n", 323 | "extra_prompt_arg = \"--read_prompt_txt\" if read_prompt_from_txt else \"\"\n", 324 | "arb_arg = \"--use_aspect_ratio_bucket\" if use_aspect_ratio_bucket else \"\"\n", 325 | "tte_arg = \"--train_text_encoder\" if train_text_encoder else \"\"\n", 326 | "ppl_arg = f\"--with_prior_preservation --prior_loss_weight={prior_loss_weight}\" if with_prior_preservation else \"\"\n", 327 | "\n", 328 | "if scale_lr_sqrt:\n", 329 | " scale_lr_arg = \"--scale_lr_sqrt\"\n", 330 | "\n", 331 | "if read_prompt_from_filename:\n", 332 | " extra_prompt_arg = \"--read_prompt_filename\"\n", 333 | "\n", 334 | "if save_weights_to_wandb:\n", 335 | " wandb_arg = \"--wandb --wandb_artifact\"\n", 336 | "\n", 337 | "!python -m accelerate.commands.launch $TRAINER \\\n", 338 | " --mixed_precision=\"fp16\" \\\n", 339 | " --pretrained_model_name_or_path=$MODEL_NAME \\\n", 340 | " --instance_data_dir=$INSTANCE_DIR \\\n", 341 | " --class_data_dir=$CLASS_DIR \\\n", 342 | " --output_dir=$OUTPUT_DIR \\\n", 343 | " --instance_prompt=\"{INSTANCE_PROMPT}\" \\\n", 344 | " --class_prompt=\"{CLASS_PROMPT}\" \\\n", 345 | " --class_negative_prompt=\"{CLASS_NEGATIVE_PROMPT}\" \\\n", 346 | " --save_sample_prompt=\"{SAVE_SAMPLE_PROMPT}\" \\\n", 347 | " --save_sample_negative_prompt=\"{SAVE_SAMPLE_NEGATIVE_PROMPT}\" \\\n", 348 | " --seed=$seed \\\n", 349 | " --resolution=$resolution \\\n", 350 | " --train_batch_size=$train_batch_size \\\n", 351 | " --gradient_accumulation_steps=$gradient_accumulation_steps \\\n", 352 | " --learning_rate=$learning_rate \\\n", 353 | " --lr_scheduler=$lr_scheduler \\\n", 354 | " --lr_warmup_steps=$lr_warmup_steps \\\n", 355 | " --num_class_images=$num_class_images \\\n", 356 | " --sample_batch_size=$sample_batch_size \\\n", 357 | " --max_train_steps=$max_train_steps \\\n", 358 | " --save_interval=$save_interval \\\n", 359 | " --log_interval=$log_interval \\\n", 360 | " --clip_skip $clip_skip \\\n", 361 | " --num_cycle=$restart_cycle \\\n", 362 | " --last_epoch=$last_epoch \\\n", 363 | " --append_prompt=$append_prompt \\\n", 364 | " --use_8bit_adam --xformers $da_arg $db_arg $ema_arg \\\n", 365 | " $ppl_arg $wandb_arg $extra_prompt_arg $gdc_arg $arb_arg $tte_arg $scale_lr_arg $dp_arg $pd_arg\n", 366 | "\n", 367 | "# disabled: --not_cache_latents" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "9c6ffa47-7a1b-43aa-b3c9-4570a7de3c1c", 373 | "metadata": {}, 374 | "source": [ 375 | "## 转换训练好的模型到ckpt文件\n", 376 | "\n", 377 | "这里需要你修改model_folder_name, 比如\n", 378 | "checkpoint_1000\n", 379 | "checkpoint_2000\n", 380 | "想转换哪个模型写哪个" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "bf319a13-7e00-4d38-9e91-e50302a3f5bb", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "model_folder_name = \"checkpoint_last\"\n", 391 | "convert_model_path = f\"output/{model_folder_name}\"\n", 392 | "ckpt_path = f'{convert_model_path}/model.ckpt'\n", 393 | "save_half = True # 改为 False 保存单精度模型(4G)\n", 394 | "\n", 395 | "ckpt_convert_half_arg = \"--half\" if save_half else \"\"\n", 396 | "\n", 397 | "!python back_convert.py --model_path $convert_model_path --checkpoint_path $ckpt_path $ckpt_convert_half_arg\n", 398 | "print(f\"[*] 转换的模型保存在如下路径 {ckpt_path}\")" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "fdea851b-750e-4d01-9794-9521a85a13d0", 404 | "metadata": {}, 405 | "source": [ 406 | "# 打开生成图像界面测试用\n", 407 | "\n", 408 | "**生成效果与本地webui不太一样,仅供参考**" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "id": "6e76ef0d-7130-4f11-8086-1d1349db91a3", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "import torch\n", 419 | "import os\n", 420 | "from torch import autocast\n", 421 | "from diffusers import StableDiffusionPipeline\n", 422 | "from IPython.display import display\n", 423 | "\n", 424 | "\n", 425 | "use_checkpoint = 'checkpoint_last'\n", 426 | "ckpt_model_path = os.path.join(OUTPUT_DIR, use_checkpoint)\n", 427 | "\n", 428 | "pipe = StableDiffusionPipeline.from_pretrained(ckpt_model_path, torch_dtype=torch.float16).to(\"cuda\")\n", 429 | "g_cuda = None\n", 430 | "\n", 431 | "\n", 432 | "import gradio as gr\n", 433 | "\n", 434 | "def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):\n", 435 | " with torch.autocast(\"cuda\"), torch.inference_mode():\n", 436 | " return pipe(\n", 437 | " prompt, height=int(height), width=int(width),\n", 438 | " negative_prompt=negative_prompt,\n", 439 | " num_images_per_prompt=int(num_samples),\n", 440 | " num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,\n", 441 | " generator=g_cuda\n", 442 | " ).images\n", 443 | "\n", 444 | "with gr.Blocks() as demo:\n", 445 | " with gr.Row():\n", 446 | " with gr.Column():\n", 447 | " prompt = gr.Textbox(label=\"tag\", value=\"masterpiece, best quality,\")\n", 448 | " negative_prompt = gr.Textbox(label=\"负面tag\", value=\"lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\")\n", 449 | " num_inference_steps = gr.Slider(label=\"Steps\", value=28)\n", 450 | " with gr.Row():\n", 451 | " width = gr.Slider(minimum=64, maximum=2048, step=64, label=\"宽\", value=512)\n", 452 | " height = gr.Slider(minimum=64, maximum=2048, step=64, label=\"高\", value=512)\n", 453 | " with gr.Row():\n", 454 | " num_samples = gr.Number(label=\"批量\", value=1)\n", 455 | " guidance_scale = gr.Number(label=\"Guidance Scale\", value=7)\n", 456 | "\n", 457 | " with gr.Column():\n", 458 | " run = gr.Button(value=\"生成\")\n", 459 | " gallery = gr.Gallery()\n", 460 | "\n", 461 | " run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)\n", 462 | "\n", 463 | "demo.launch(share=True)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "id": "b1ca52f9-83fc-4a4b-8589-3f6ee31ef69a", 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [] 473 | } 474 | ], 475 | "metadata": { 476 | "kernelspec": { 477 | "display_name": "Python 3 (ipykernel)", 478 | "language": "python", 479 | "name": "python3" 480 | }, 481 | "language_info": { 482 | "codemirror_mode": { 483 | "name": "ipython", 484 | "version": 3 485 | }, 486 | "file_extension": ".py", 487 | "mimetype": "text/x-python", 488 | "name": "python", 489 | "nbconvert_exporter": "python", 490 | "pygments_lexer": "ipython3", 491 | "version": "3.8.10" 492 | } 493 | }, 494 | "nbformat": 4, 495 | "nbformat_minor": 5 496 | } 497 | -------------------------------------------------------------------------------- /back_convert_alt.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | 5 | # From: https://github.com/d8ahazard/stable-diffusion-webui/blob/a658e0457c391def18880bcf958090f8d12de7e6/modules/dreambooth/conversion.py#L205 6 | 7 | import argparse 8 | import os.path as osp 9 | 10 | import torch 11 | from diffusers import StableDiffusionPipeline 12 | 13 | 14 | KeyMap = { 15 | "model.diffusion_model.time_embed.0.weight": "time_embedding.linear_1.weight", 16 | "model.diffusion_model.time_embed.0.bias": "time_embedding.linear_1.bias", 17 | "model.diffusion_model.time_embed.2.weight": "time_embedding.linear_2.weight", 18 | "model.diffusion_model.time_embed.2.bias": "time_embedding.linear_2.bias", 19 | "model.diffusion_model.input_blocks.0.0.weight": "conv_in.weight", 20 | "model.diffusion_model.input_blocks.0.0.bias": "conv_in.bias", 21 | "model.diffusion_model.out.0.weight": "conv_norm_out.weight", 22 | "model.diffusion_model.out.0.bias": "conv_norm_out.bias", 23 | "model.diffusion_model.out.2.weight": "conv_out.weight", 24 | "model.diffusion_model.out.2.bias": "conv_out.bias", 25 | "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "down_blocks.0.resnets.0.norm1.weight", 26 | "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "down_blocks.0.resnets.0.norm1.bias", 27 | "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "down_blocks.0.resnets.0.conv1.weight", 28 | "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "down_blocks.0.resnets.0.conv1.bias", 29 | "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "down_blocks.0.resnets.0.time_emb_proj.weight", 30 | "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "down_blocks.0.resnets.0.time_emb_proj.bias", 31 | "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "down_blocks.0.resnets.0.norm2.weight", 32 | "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "down_blocks.0.resnets.0.norm2.bias", 33 | "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "down_blocks.0.resnets.0.conv2.weight", 34 | "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "down_blocks.0.resnets.0.conv2.bias", 35 | "model.diffusion_model.input_blocks.1.1.norm.weight": "down_blocks.0.attentions.0.norm.weight", 36 | "model.diffusion_model.input_blocks.1.1.norm.bias": "down_blocks.0.attentions.0.norm.bias", 37 | "model.diffusion_model.input_blocks.1.1.proj_in.weight": "down_blocks.0.attentions.0.proj_in.weight", 38 | "model.diffusion_model.input_blocks.1.1.proj_in.bias": "down_blocks.0.attentions.0.proj_in.bias", 39 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight", 40 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight", 41 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight", 42 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 43 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 44 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 45 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 46 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight", 47 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias", 48 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight", 49 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight", 50 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight", 51 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 52 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 53 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight", 54 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias", 55 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight", 56 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias", 57 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight", 58 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias", 59 | "model.diffusion_model.input_blocks.1.1.proj_out.weight": "down_blocks.0.attentions.0.proj_out.weight", 60 | "model.diffusion_model.input_blocks.1.1.proj_out.bias": "down_blocks.0.attentions.0.proj_out.bias", 61 | "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "down_blocks.0.resnets.1.norm1.weight", 62 | "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "down_blocks.0.resnets.1.norm1.bias", 63 | "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "down_blocks.0.resnets.1.conv1.weight", 64 | "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "down_blocks.0.resnets.1.conv1.bias", 65 | "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "down_blocks.0.resnets.1.time_emb_proj.weight", 66 | "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "down_blocks.0.resnets.1.time_emb_proj.bias", 67 | "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "down_blocks.0.resnets.1.norm2.weight", 68 | "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "down_blocks.0.resnets.1.norm2.bias", 69 | "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "down_blocks.0.resnets.1.conv2.weight", 70 | "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "down_blocks.0.resnets.1.conv2.bias", 71 | "model.diffusion_model.input_blocks.2.1.norm.weight": "down_blocks.0.attentions.1.norm.weight", 72 | "model.diffusion_model.input_blocks.2.1.norm.bias": "down_blocks.0.attentions.1.norm.bias", 73 | "model.diffusion_model.input_blocks.2.1.proj_in.weight": "down_blocks.0.attentions.1.proj_in.weight", 74 | "model.diffusion_model.input_blocks.2.1.proj_in.bias": "down_blocks.0.attentions.1.proj_in.bias", 75 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight", 76 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight", 77 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight", 78 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 79 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 80 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 81 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 82 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight", 83 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias", 84 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight", 85 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight", 86 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight", 87 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 88 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 89 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight", 90 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias", 91 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight", 92 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias", 93 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight", 94 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias", 95 | "model.diffusion_model.input_blocks.2.1.proj_out.weight": "down_blocks.0.attentions.1.proj_out.weight", 96 | "model.diffusion_model.input_blocks.2.1.proj_out.bias": "down_blocks.0.attentions.1.proj_out.bias", 97 | "model.diffusion_model.input_blocks.3.0.op.weight": "down_blocks.0.downsamplers.0.conv.weight", 98 | "model.diffusion_model.input_blocks.3.0.op.bias": "down_blocks.0.downsamplers.0.conv.bias", 99 | "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "down_blocks.1.resnets.0.norm1.weight", 100 | "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "down_blocks.1.resnets.0.norm1.bias", 101 | "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "down_blocks.1.resnets.0.conv1.weight", 102 | "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "down_blocks.1.resnets.0.conv1.bias", 103 | "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "down_blocks.1.resnets.0.time_emb_proj.weight", 104 | "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "down_blocks.1.resnets.0.time_emb_proj.bias", 105 | "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "down_blocks.1.resnets.0.norm2.weight", 106 | "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "down_blocks.1.resnets.0.norm2.bias", 107 | "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "down_blocks.1.resnets.0.conv2.weight", 108 | "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "down_blocks.1.resnets.0.conv2.bias", 109 | "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "down_blocks.1.resnets.0.conv_shortcut.weight", 110 | "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "down_blocks.1.resnets.0.conv_shortcut.bias", 111 | "model.diffusion_model.input_blocks.4.1.norm.weight": "down_blocks.1.attentions.0.norm.weight", 112 | "model.diffusion_model.input_blocks.4.1.norm.bias": "down_blocks.1.attentions.0.norm.bias", 113 | "model.diffusion_model.input_blocks.4.1.proj_in.weight": "down_blocks.1.attentions.0.proj_in.weight", 114 | "model.diffusion_model.input_blocks.4.1.proj_in.bias": "down_blocks.1.attentions.0.proj_in.bias", 115 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", 116 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", 117 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", 118 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 119 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 120 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 121 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 122 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", 123 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", 124 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", 125 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", 126 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", 127 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 128 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 129 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", 130 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", 131 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", 132 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", 133 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", 134 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", 135 | "model.diffusion_model.input_blocks.4.1.proj_out.weight": "down_blocks.1.attentions.0.proj_out.weight", 136 | "model.diffusion_model.input_blocks.4.1.proj_out.bias": "down_blocks.1.attentions.0.proj_out.bias", 137 | "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "down_blocks.1.resnets.1.norm1.weight", 138 | "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "down_blocks.1.resnets.1.norm1.bias", 139 | "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "down_blocks.1.resnets.1.conv1.weight", 140 | "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "down_blocks.1.resnets.1.conv1.bias", 141 | "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "down_blocks.1.resnets.1.time_emb_proj.weight", 142 | "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "down_blocks.1.resnets.1.time_emb_proj.bias", 143 | "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "down_blocks.1.resnets.1.norm2.weight", 144 | "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "down_blocks.1.resnets.1.norm2.bias", 145 | "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "down_blocks.1.resnets.1.conv2.weight", 146 | "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "down_blocks.1.resnets.1.conv2.bias", 147 | "model.diffusion_model.input_blocks.5.1.norm.weight": "down_blocks.1.attentions.1.norm.weight", 148 | "model.diffusion_model.input_blocks.5.1.norm.bias": "down_blocks.1.attentions.1.norm.bias", 149 | "model.diffusion_model.input_blocks.5.1.proj_in.weight": "down_blocks.1.attentions.1.proj_in.weight", 150 | "model.diffusion_model.input_blocks.5.1.proj_in.bias": "down_blocks.1.attentions.1.proj_in.bias", 151 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", 152 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", 153 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", 154 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 155 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 156 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 157 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 158 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", 159 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", 160 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", 161 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", 162 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", 163 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 164 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 165 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", 166 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", 167 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", 168 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", 169 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", 170 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", 171 | "model.diffusion_model.input_blocks.5.1.proj_out.weight": "down_blocks.1.attentions.1.proj_out.weight", 172 | "model.diffusion_model.input_blocks.5.1.proj_out.bias": "down_blocks.1.attentions.1.proj_out.bias", 173 | "model.diffusion_model.input_blocks.6.0.op.weight": "down_blocks.1.downsamplers.0.conv.weight", 174 | "model.diffusion_model.input_blocks.6.0.op.bias": "down_blocks.1.downsamplers.0.conv.bias", 175 | "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "down_blocks.2.resnets.0.norm1.weight", 176 | "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "down_blocks.2.resnets.0.norm1.bias", 177 | "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "down_blocks.2.resnets.0.conv1.weight", 178 | "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "down_blocks.2.resnets.0.conv1.bias", 179 | "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "down_blocks.2.resnets.0.time_emb_proj.weight", 180 | "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "down_blocks.2.resnets.0.time_emb_proj.bias", 181 | "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "down_blocks.2.resnets.0.norm2.weight", 182 | "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "down_blocks.2.resnets.0.norm2.bias", 183 | "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "down_blocks.2.resnets.0.conv2.weight", 184 | "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "down_blocks.2.resnets.0.conv2.bias", 185 | "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "down_blocks.2.resnets.0.conv_shortcut.weight", 186 | "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "down_blocks.2.resnets.0.conv_shortcut.bias", 187 | "model.diffusion_model.input_blocks.7.1.norm.weight": "down_blocks.2.attentions.0.norm.weight", 188 | "model.diffusion_model.input_blocks.7.1.norm.bias": "down_blocks.2.attentions.0.norm.bias", 189 | "model.diffusion_model.input_blocks.7.1.proj_in.weight": "down_blocks.2.attentions.0.proj_in.weight", 190 | "model.diffusion_model.input_blocks.7.1.proj_in.bias": "down_blocks.2.attentions.0.proj_in.bias", 191 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", 192 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", 193 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", 194 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 195 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 196 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 197 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 198 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", 199 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", 200 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", 201 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", 202 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", 203 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 204 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 205 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", 206 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", 207 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", 208 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", 209 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", 210 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", 211 | "model.diffusion_model.input_blocks.7.1.proj_out.weight": "down_blocks.2.attentions.0.proj_out.weight", 212 | "model.diffusion_model.input_blocks.7.1.proj_out.bias": "down_blocks.2.attentions.0.proj_out.bias", 213 | "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "down_blocks.2.resnets.1.norm1.weight", 214 | "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "down_blocks.2.resnets.1.norm1.bias", 215 | "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "down_blocks.2.resnets.1.conv1.weight", 216 | "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "down_blocks.2.resnets.1.conv1.bias", 217 | "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "down_blocks.2.resnets.1.time_emb_proj.weight", 218 | "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "down_blocks.2.resnets.1.time_emb_proj.bias", 219 | "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "down_blocks.2.resnets.1.norm2.weight", 220 | "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "down_blocks.2.resnets.1.norm2.bias", 221 | "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "down_blocks.2.resnets.1.conv2.weight", 222 | "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "down_blocks.2.resnets.1.conv2.bias", 223 | "model.diffusion_model.input_blocks.8.1.norm.weight": "down_blocks.2.attentions.1.norm.weight", 224 | "model.diffusion_model.input_blocks.8.1.norm.bias": "down_blocks.2.attentions.1.norm.bias", 225 | "model.diffusion_model.input_blocks.8.1.proj_in.weight": "down_blocks.2.attentions.1.proj_in.weight", 226 | "model.diffusion_model.input_blocks.8.1.proj_in.bias": "down_blocks.2.attentions.1.proj_in.bias", 227 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", 228 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", 229 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", 230 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 231 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 232 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 233 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 234 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", 235 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", 236 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", 237 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", 238 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", 239 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 240 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 241 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", 242 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", 243 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", 244 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", 245 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", 246 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", 247 | "model.diffusion_model.input_blocks.8.1.proj_out.weight": "down_blocks.2.attentions.1.proj_out.weight", 248 | "model.diffusion_model.input_blocks.8.1.proj_out.bias": "down_blocks.2.attentions.1.proj_out.bias", 249 | "model.diffusion_model.input_blocks.9.0.op.weight": "down_blocks.2.downsamplers.0.conv.weight", 250 | "model.diffusion_model.input_blocks.9.0.op.bias": "down_blocks.2.downsamplers.0.conv.bias", 251 | "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "down_blocks.3.resnets.0.norm1.weight", 252 | "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "down_blocks.3.resnets.0.norm1.bias", 253 | "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "down_blocks.3.resnets.0.conv1.weight", 254 | "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "down_blocks.3.resnets.0.conv1.bias", 255 | "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "down_blocks.3.resnets.0.time_emb_proj.weight", 256 | "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "down_blocks.3.resnets.0.time_emb_proj.bias", 257 | "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "down_blocks.3.resnets.0.norm2.weight", 258 | "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "down_blocks.3.resnets.0.norm2.bias", 259 | "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "down_blocks.3.resnets.0.conv2.weight", 260 | "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "down_blocks.3.resnets.0.conv2.bias", 261 | "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "down_blocks.3.resnets.1.norm1.weight", 262 | "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "down_blocks.3.resnets.1.norm1.bias", 263 | "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "down_blocks.3.resnets.1.conv1.weight", 264 | "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "down_blocks.3.resnets.1.conv1.bias", 265 | "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "down_blocks.3.resnets.1.time_emb_proj.weight", 266 | "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "down_blocks.3.resnets.1.time_emb_proj.bias", 267 | "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "down_blocks.3.resnets.1.norm2.weight", 268 | "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "down_blocks.3.resnets.1.norm2.bias", 269 | "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "down_blocks.3.resnets.1.conv2.weight", 270 | "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "down_blocks.3.resnets.1.conv2.bias", 271 | "model.diffusion_model.middle_block.0.in_layers.0.weight": "mid_block.resnets.0.norm1.weight", 272 | "model.diffusion_model.middle_block.0.in_layers.0.bias": "mid_block.resnets.0.norm1.bias", 273 | "model.diffusion_model.middle_block.0.in_layers.2.weight": "mid_block.resnets.0.conv1.weight", 274 | "model.diffusion_model.middle_block.0.in_layers.2.bias": "mid_block.resnets.0.conv1.bias", 275 | "model.diffusion_model.middle_block.0.emb_layers.1.weight": "mid_block.resnets.0.time_emb_proj.weight", 276 | "model.diffusion_model.middle_block.0.emb_layers.1.bias": "mid_block.resnets.0.time_emb_proj.bias", 277 | "model.diffusion_model.middle_block.0.out_layers.0.weight": "mid_block.resnets.0.norm2.weight", 278 | "model.diffusion_model.middle_block.0.out_layers.0.bias": "mid_block.resnets.0.norm2.bias", 279 | "model.diffusion_model.middle_block.0.out_layers.3.weight": "mid_block.resnets.0.conv2.weight", 280 | "model.diffusion_model.middle_block.0.out_layers.3.bias": "mid_block.resnets.0.conv2.bias", 281 | "model.diffusion_model.middle_block.2.in_layers.0.weight": "mid_block.resnets.1.norm1.weight", 282 | "model.diffusion_model.middle_block.2.in_layers.0.bias": "mid_block.resnets.1.norm1.bias", 283 | "model.diffusion_model.middle_block.2.in_layers.2.weight": "mid_block.resnets.1.conv1.weight", 284 | "model.diffusion_model.middle_block.2.in_layers.2.bias": "mid_block.resnets.1.conv1.bias", 285 | "model.diffusion_model.middle_block.2.emb_layers.1.weight": "mid_block.resnets.1.time_emb_proj.weight", 286 | "model.diffusion_model.middle_block.2.emb_layers.1.bias": "mid_block.resnets.1.time_emb_proj.bias", 287 | "model.diffusion_model.middle_block.2.out_layers.0.weight": "mid_block.resnets.1.norm2.weight", 288 | "model.diffusion_model.middle_block.2.out_layers.0.bias": "mid_block.resnets.1.norm2.bias", 289 | "model.diffusion_model.middle_block.2.out_layers.3.weight": "mid_block.resnets.1.conv2.weight", 290 | "model.diffusion_model.middle_block.2.out_layers.3.bias": "mid_block.resnets.1.conv2.bias", 291 | "model.diffusion_model.middle_block.1.norm.weight": "mid_block.attentions.0.norm.weight", 292 | "model.diffusion_model.middle_block.1.norm.bias": "mid_block.attentions.0.norm.bias", 293 | "model.diffusion_model.middle_block.1.proj_in.weight": "mid_block.attentions.0.proj_in.weight", 294 | "model.diffusion_model.middle_block.1.proj_in.bias": "mid_block.attentions.0.proj_in.bias", 295 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight", 296 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight", 297 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight", 298 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 299 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 300 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 301 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 302 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight", 303 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias", 304 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight", 305 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight", 306 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight", 307 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 308 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 309 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "mid_block.attentions.0.transformer_blocks.0.norm1.weight", 310 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "mid_block.attentions.0.transformer_blocks.0.norm1.bias", 311 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "mid_block.attentions.0.transformer_blocks.0.norm2.weight", 312 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "mid_block.attentions.0.transformer_blocks.0.norm2.bias", 313 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "mid_block.attentions.0.transformer_blocks.0.norm3.weight", 314 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "mid_block.attentions.0.transformer_blocks.0.norm3.bias", 315 | "model.diffusion_model.middle_block.1.proj_out.weight": "mid_block.attentions.0.proj_out.weight", 316 | "model.diffusion_model.middle_block.1.proj_out.bias": "mid_block.attentions.0.proj_out.bias", 317 | "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "up_blocks.0.resnets.0.norm1.weight", 318 | "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "up_blocks.0.resnets.0.norm1.bias", 319 | "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "up_blocks.0.resnets.0.conv1.weight", 320 | "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "up_blocks.0.resnets.0.conv1.bias", 321 | "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "up_blocks.0.resnets.0.time_emb_proj.weight", 322 | "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "up_blocks.0.resnets.0.time_emb_proj.bias", 323 | "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "up_blocks.0.resnets.0.norm2.weight", 324 | "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "up_blocks.0.resnets.0.norm2.bias", 325 | "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "up_blocks.0.resnets.0.conv2.weight", 326 | "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "up_blocks.0.resnets.0.conv2.bias", 327 | "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "up_blocks.0.resnets.0.conv_shortcut.weight", 328 | "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "up_blocks.0.resnets.0.conv_shortcut.bias", 329 | "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "up_blocks.0.resnets.1.norm1.weight", 330 | "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "up_blocks.0.resnets.1.norm1.bias", 331 | "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "up_blocks.0.resnets.1.conv1.weight", 332 | "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "up_blocks.0.resnets.1.conv1.bias", 333 | "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "up_blocks.0.resnets.1.time_emb_proj.weight", 334 | "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "up_blocks.0.resnets.1.time_emb_proj.bias", 335 | "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "up_blocks.0.resnets.1.norm2.weight", 336 | "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "up_blocks.0.resnets.1.norm2.bias", 337 | "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "up_blocks.0.resnets.1.conv2.weight", 338 | "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "up_blocks.0.resnets.1.conv2.bias", 339 | "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "up_blocks.0.resnets.1.conv_shortcut.weight", 340 | "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "up_blocks.0.resnets.1.conv_shortcut.bias", 341 | "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "up_blocks.0.resnets.2.norm1.weight", 342 | "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "up_blocks.0.resnets.2.norm1.bias", 343 | "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "up_blocks.0.resnets.2.conv1.weight", 344 | "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "up_blocks.0.resnets.2.conv1.bias", 345 | "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "up_blocks.0.resnets.2.time_emb_proj.weight", 346 | "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "up_blocks.0.resnets.2.time_emb_proj.bias", 347 | "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "up_blocks.0.resnets.2.norm2.weight", 348 | "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "up_blocks.0.resnets.2.norm2.bias", 349 | "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "up_blocks.0.resnets.2.conv2.weight", 350 | "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "up_blocks.0.resnets.2.conv2.bias", 351 | "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "up_blocks.0.resnets.2.conv_shortcut.weight", 352 | "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "up_blocks.0.resnets.2.conv_shortcut.bias", 353 | "model.diffusion_model.output_blocks.2.1.conv.weight": "up_blocks.0.upsamplers.0.conv.weight", 354 | "model.diffusion_model.output_blocks.2.1.conv.bias": "up_blocks.0.upsamplers.0.conv.bias", 355 | "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "up_blocks.1.resnets.0.norm1.weight", 356 | "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "up_blocks.1.resnets.0.norm1.bias", 357 | "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "up_blocks.1.resnets.0.conv1.weight", 358 | "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "up_blocks.1.resnets.0.conv1.bias", 359 | "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "up_blocks.1.resnets.0.time_emb_proj.weight", 360 | "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "up_blocks.1.resnets.0.time_emb_proj.bias", 361 | "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "up_blocks.1.resnets.0.norm2.weight", 362 | "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "up_blocks.1.resnets.0.norm2.bias", 363 | "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "up_blocks.1.resnets.0.conv2.weight", 364 | "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "up_blocks.1.resnets.0.conv2.bias", 365 | "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "up_blocks.1.resnets.0.conv_shortcut.weight", 366 | "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "up_blocks.1.resnets.0.conv_shortcut.bias", 367 | "model.diffusion_model.output_blocks.3.1.norm.weight": "up_blocks.1.attentions.0.norm.weight", 368 | "model.diffusion_model.output_blocks.3.1.norm.bias": "up_blocks.1.attentions.0.norm.bias", 369 | "model.diffusion_model.output_blocks.3.1.proj_in.weight": "up_blocks.1.attentions.0.proj_in.weight", 370 | "model.diffusion_model.output_blocks.3.1.proj_in.bias": "up_blocks.1.attentions.0.proj_in.bias", 371 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight", 372 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", 373 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight", 374 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 375 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 376 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 377 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 378 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight", 379 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias", 380 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight", 381 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", 382 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight", 383 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 384 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 385 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight", 386 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias", 387 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight", 388 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias", 389 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight", 390 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias", 391 | "model.diffusion_model.output_blocks.3.1.proj_out.weight": "up_blocks.1.attentions.0.proj_out.weight", 392 | "model.diffusion_model.output_blocks.3.1.proj_out.bias": "up_blocks.1.attentions.0.proj_out.bias", 393 | "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "up_blocks.1.resnets.1.norm1.weight", 394 | "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "up_blocks.1.resnets.1.norm1.bias", 395 | "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "up_blocks.1.resnets.1.conv1.weight", 396 | "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "up_blocks.1.resnets.1.conv1.bias", 397 | "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "up_blocks.1.resnets.1.time_emb_proj.weight", 398 | "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "up_blocks.1.resnets.1.time_emb_proj.bias", 399 | "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "up_blocks.1.resnets.1.norm2.weight", 400 | "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "up_blocks.1.resnets.1.norm2.bias", 401 | "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "up_blocks.1.resnets.1.conv2.weight", 402 | "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "up_blocks.1.resnets.1.conv2.bias", 403 | "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "up_blocks.1.resnets.1.conv_shortcut.weight", 404 | "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "up_blocks.1.resnets.1.conv_shortcut.bias", 405 | "model.diffusion_model.output_blocks.4.1.norm.weight": "up_blocks.1.attentions.1.norm.weight", 406 | "model.diffusion_model.output_blocks.4.1.norm.bias": "up_blocks.1.attentions.1.norm.bias", 407 | "model.diffusion_model.output_blocks.4.1.proj_in.weight": "up_blocks.1.attentions.1.proj_in.weight", 408 | "model.diffusion_model.output_blocks.4.1.proj_in.bias": "up_blocks.1.attentions.1.proj_in.bias", 409 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight", 410 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight", 411 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight", 412 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 413 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 414 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 415 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 416 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight", 417 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias", 418 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight", 419 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight", 420 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight", 421 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 422 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 423 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight", 424 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias", 425 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight", 426 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias", 427 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight", 428 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias", 429 | "model.diffusion_model.output_blocks.4.1.proj_out.weight": "up_blocks.1.attentions.1.proj_out.weight", 430 | "model.diffusion_model.output_blocks.4.1.proj_out.bias": "up_blocks.1.attentions.1.proj_out.bias", 431 | "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "up_blocks.1.resnets.2.norm1.weight", 432 | "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "up_blocks.1.resnets.2.norm1.bias", 433 | "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "up_blocks.1.resnets.2.conv1.weight", 434 | "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "up_blocks.1.resnets.2.conv1.bias", 435 | "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "up_blocks.1.resnets.2.time_emb_proj.weight", 436 | "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "up_blocks.1.resnets.2.time_emb_proj.bias", 437 | "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "up_blocks.1.resnets.2.norm2.weight", 438 | "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "up_blocks.1.resnets.2.norm2.bias", 439 | "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "up_blocks.1.resnets.2.conv2.weight", 440 | "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "up_blocks.1.resnets.2.conv2.bias", 441 | "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "up_blocks.1.resnets.2.conv_shortcut.weight", 442 | "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "up_blocks.1.resnets.2.conv_shortcut.bias", 443 | "model.diffusion_model.output_blocks.5.2.conv.weight": "up_blocks.1.upsamplers.0.conv.weight", 444 | "model.diffusion_model.output_blocks.5.2.conv.bias": "up_blocks.1.upsamplers.0.conv.bias", 445 | "model.diffusion_model.output_blocks.5.1.norm.weight": "up_blocks.1.attentions.2.norm.weight", 446 | "model.diffusion_model.output_blocks.5.1.norm.bias": "up_blocks.1.attentions.2.norm.bias", 447 | "model.diffusion_model.output_blocks.5.1.proj_in.weight": "up_blocks.1.attentions.2.proj_in.weight", 448 | "model.diffusion_model.output_blocks.5.1.proj_in.bias": "up_blocks.1.attentions.2.proj_in.bias", 449 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight", 450 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight", 451 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight", 452 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", 453 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", 454 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", 455 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", 456 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight", 457 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias", 458 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight", 459 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight", 460 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight", 461 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", 462 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", 463 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight", 464 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias", 465 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight", 466 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias", 467 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight", 468 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias", 469 | "model.diffusion_model.output_blocks.5.1.proj_out.weight": "up_blocks.1.attentions.2.proj_out.weight", 470 | "model.diffusion_model.output_blocks.5.1.proj_out.bias": "up_blocks.1.attentions.2.proj_out.bias", 471 | "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "up_blocks.2.resnets.0.norm1.weight", 472 | "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "up_blocks.2.resnets.0.norm1.bias", 473 | "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "up_blocks.2.resnets.0.conv1.weight", 474 | "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "up_blocks.2.resnets.0.conv1.bias", 475 | "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "up_blocks.2.resnets.0.time_emb_proj.weight", 476 | "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "up_blocks.2.resnets.0.time_emb_proj.bias", 477 | "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "up_blocks.2.resnets.0.norm2.weight", 478 | "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "up_blocks.2.resnets.0.norm2.bias", 479 | "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "up_blocks.2.resnets.0.conv2.weight", 480 | "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "up_blocks.2.resnets.0.conv2.bias", 481 | "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "up_blocks.2.resnets.0.conv_shortcut.weight", 482 | "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "up_blocks.2.resnets.0.conv_shortcut.bias", 483 | "model.diffusion_model.output_blocks.6.1.norm.weight": "up_blocks.2.attentions.0.norm.weight", 484 | "model.diffusion_model.output_blocks.6.1.norm.bias": "up_blocks.2.attentions.0.norm.bias", 485 | "model.diffusion_model.output_blocks.6.1.proj_in.weight": "up_blocks.2.attentions.0.proj_in.weight", 486 | "model.diffusion_model.output_blocks.6.1.proj_in.bias": "up_blocks.2.attentions.0.proj_in.bias", 487 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight", 488 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight", 489 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight", 490 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 491 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 492 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 493 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 494 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight", 495 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias", 496 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight", 497 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight", 498 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight", 499 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 500 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 501 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight", 502 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias", 503 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight", 504 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias", 505 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight", 506 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias", 507 | "model.diffusion_model.output_blocks.6.1.proj_out.weight": "up_blocks.2.attentions.0.proj_out.weight", 508 | "model.diffusion_model.output_blocks.6.1.proj_out.bias": "up_blocks.2.attentions.0.proj_out.bias", 509 | "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "up_blocks.2.resnets.1.norm1.weight", 510 | "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "up_blocks.2.resnets.1.norm1.bias", 511 | "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "up_blocks.2.resnets.1.conv1.weight", 512 | "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "up_blocks.2.resnets.1.conv1.bias", 513 | "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "up_blocks.2.resnets.1.time_emb_proj.weight", 514 | "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "up_blocks.2.resnets.1.time_emb_proj.bias", 515 | "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "up_blocks.2.resnets.1.norm2.weight", 516 | "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "up_blocks.2.resnets.1.norm2.bias", 517 | "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "up_blocks.2.resnets.1.conv2.weight", 518 | "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "up_blocks.2.resnets.1.conv2.bias", 519 | "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "up_blocks.2.resnets.1.conv_shortcut.weight", 520 | "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "up_blocks.2.resnets.1.conv_shortcut.bias", 521 | "model.diffusion_model.output_blocks.7.1.norm.weight": "up_blocks.2.attentions.1.norm.weight", 522 | "model.diffusion_model.output_blocks.7.1.norm.bias": "up_blocks.2.attentions.1.norm.bias", 523 | "model.diffusion_model.output_blocks.7.1.proj_in.weight": "up_blocks.2.attentions.1.proj_in.weight", 524 | "model.diffusion_model.output_blocks.7.1.proj_in.bias": "up_blocks.2.attentions.1.proj_in.bias", 525 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight", 526 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight", 527 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight", 528 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 529 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 530 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 531 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 532 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight", 533 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias", 534 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight", 535 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight", 536 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight", 537 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 538 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 539 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight", 540 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias", 541 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight", 542 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias", 543 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight", 544 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias", 545 | "model.diffusion_model.output_blocks.7.1.proj_out.weight": "up_blocks.2.attentions.1.proj_out.weight", 546 | "model.diffusion_model.output_blocks.7.1.proj_out.bias": "up_blocks.2.attentions.1.proj_out.bias", 547 | "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "up_blocks.2.resnets.2.norm1.weight", 548 | "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "up_blocks.2.resnets.2.norm1.bias", 549 | "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "up_blocks.2.resnets.2.conv1.weight", 550 | "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "up_blocks.2.resnets.2.conv1.bias", 551 | "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "up_blocks.2.resnets.2.time_emb_proj.weight", 552 | "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "up_blocks.2.resnets.2.time_emb_proj.bias", 553 | "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "up_blocks.2.resnets.2.norm2.weight", 554 | "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "up_blocks.2.resnets.2.norm2.bias", 555 | "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "up_blocks.2.resnets.2.conv2.weight", 556 | "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "up_blocks.2.resnets.2.conv2.bias", 557 | "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "up_blocks.2.resnets.2.conv_shortcut.weight", 558 | "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "up_blocks.2.resnets.2.conv_shortcut.bias", 559 | "model.diffusion_model.output_blocks.8.2.conv.weight": "up_blocks.2.upsamplers.0.conv.weight", 560 | "model.diffusion_model.output_blocks.8.2.conv.bias": "up_blocks.2.upsamplers.0.conv.bias", 561 | "model.diffusion_model.output_blocks.8.1.norm.weight": "up_blocks.2.attentions.2.norm.weight", 562 | "model.diffusion_model.output_blocks.8.1.norm.bias": "up_blocks.2.attentions.2.norm.bias", 563 | "model.diffusion_model.output_blocks.8.1.proj_in.weight": "up_blocks.2.attentions.2.proj_in.weight", 564 | "model.diffusion_model.output_blocks.8.1.proj_in.bias": "up_blocks.2.attentions.2.proj_in.bias", 565 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight", 566 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight", 567 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight", 568 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", 569 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", 570 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", 571 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", 572 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight", 573 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias", 574 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight", 575 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight", 576 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight", 577 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", 578 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", 579 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight", 580 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias", 581 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight", 582 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias", 583 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight", 584 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias", 585 | "model.diffusion_model.output_blocks.8.1.proj_out.weight": "up_blocks.2.attentions.2.proj_out.weight", 586 | "model.diffusion_model.output_blocks.8.1.proj_out.bias": "up_blocks.2.attentions.2.proj_out.bias", 587 | "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "up_blocks.3.resnets.0.norm1.weight", 588 | "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "up_blocks.3.resnets.0.norm1.bias", 589 | "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "up_blocks.3.resnets.0.conv1.weight", 590 | "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "up_blocks.3.resnets.0.conv1.bias", 591 | "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "up_blocks.3.resnets.0.time_emb_proj.weight", 592 | "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "up_blocks.3.resnets.0.time_emb_proj.bias", 593 | "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "up_blocks.3.resnets.0.norm2.weight", 594 | "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "up_blocks.3.resnets.0.norm2.bias", 595 | "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "up_blocks.3.resnets.0.conv2.weight", 596 | "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "up_blocks.3.resnets.0.conv2.bias", 597 | "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "up_blocks.3.resnets.0.conv_shortcut.weight", 598 | "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "up_blocks.3.resnets.0.conv_shortcut.bias", 599 | "model.diffusion_model.output_blocks.9.1.norm.weight": "up_blocks.3.attentions.0.norm.weight", 600 | "model.diffusion_model.output_blocks.9.1.norm.bias": "up_blocks.3.attentions.0.norm.bias", 601 | "model.diffusion_model.output_blocks.9.1.proj_in.weight": "up_blocks.3.attentions.0.proj_in.weight", 602 | "model.diffusion_model.output_blocks.9.1.proj_in.bias": "up_blocks.3.attentions.0.proj_in.bias", 603 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight", 604 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight", 605 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight", 606 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight", 607 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias", 608 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight", 609 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias", 610 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight", 611 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias", 612 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight", 613 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight", 614 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight", 615 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight", 616 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias", 617 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight", 618 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias", 619 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight", 620 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias", 621 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight", 622 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias", 623 | "model.diffusion_model.output_blocks.9.1.proj_out.weight": "up_blocks.3.attentions.0.proj_out.weight", 624 | "model.diffusion_model.output_blocks.9.1.proj_out.bias": "up_blocks.3.attentions.0.proj_out.bias", 625 | "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "up_blocks.3.resnets.1.norm1.weight", 626 | "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "up_blocks.3.resnets.1.norm1.bias", 627 | "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "up_blocks.3.resnets.1.conv1.weight", 628 | "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "up_blocks.3.resnets.1.conv1.bias", 629 | "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "up_blocks.3.resnets.1.time_emb_proj.weight", 630 | "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "up_blocks.3.resnets.1.time_emb_proj.bias", 631 | "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "up_blocks.3.resnets.1.norm2.weight", 632 | "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "up_blocks.3.resnets.1.norm2.bias", 633 | "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "up_blocks.3.resnets.1.conv2.weight", 634 | "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "up_blocks.3.resnets.1.conv2.bias", 635 | "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "up_blocks.3.resnets.1.conv_shortcut.weight", 636 | "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "up_blocks.3.resnets.1.conv_shortcut.bias", 637 | "model.diffusion_model.output_blocks.10.1.norm.weight": "up_blocks.3.attentions.1.norm.weight", 638 | "model.diffusion_model.output_blocks.10.1.norm.bias": "up_blocks.3.attentions.1.norm.bias", 639 | "model.diffusion_model.output_blocks.10.1.proj_in.weight": "up_blocks.3.attentions.1.proj_in.weight", 640 | "model.diffusion_model.output_blocks.10.1.proj_in.bias": "up_blocks.3.attentions.1.proj_in.bias", 641 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight", 642 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight", 643 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight", 644 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight", 645 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias", 646 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight", 647 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias", 648 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight", 649 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias", 650 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight", 651 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight", 652 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight", 653 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight", 654 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias", 655 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight", 656 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias", 657 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight", 658 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias", 659 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight", 660 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias", 661 | "model.diffusion_model.output_blocks.10.1.proj_out.weight": "up_blocks.3.attentions.1.proj_out.weight", 662 | "model.diffusion_model.output_blocks.10.1.proj_out.bias": "up_blocks.3.attentions.1.proj_out.bias", 663 | "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "up_blocks.3.resnets.2.norm1.weight", 664 | "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "up_blocks.3.resnets.2.norm1.bias", 665 | "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "up_blocks.3.resnets.2.conv1.weight", 666 | "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "up_blocks.3.resnets.2.conv1.bias", 667 | "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "up_blocks.3.resnets.2.time_emb_proj.weight", 668 | "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "up_blocks.3.resnets.2.time_emb_proj.bias", 669 | "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "up_blocks.3.resnets.2.norm2.weight", 670 | "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "up_blocks.3.resnets.2.norm2.bias", 671 | "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "up_blocks.3.resnets.2.conv2.weight", 672 | "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "up_blocks.3.resnets.2.conv2.bias", 673 | "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "up_blocks.3.resnets.2.conv_shortcut.weight", 674 | "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "up_blocks.3.resnets.2.conv_shortcut.bias", 675 | "model.diffusion_model.output_blocks.11.1.norm.weight": "up_blocks.3.attentions.2.norm.weight", 676 | "model.diffusion_model.output_blocks.11.1.norm.bias": "up_blocks.3.attentions.2.norm.bias", 677 | "model.diffusion_model.output_blocks.11.1.proj_in.weight": "up_blocks.3.attentions.2.proj_in.weight", 678 | "model.diffusion_model.output_blocks.11.1.proj_in.bias": "up_blocks.3.attentions.2.proj_in.bias", 679 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight", 680 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight", 681 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight", 682 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight", 683 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias", 684 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight", 685 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias", 686 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight", 687 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias", 688 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight", 689 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight", 690 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight", 691 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight", 692 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias", 693 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight", 694 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias", 695 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight", 696 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias", 697 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight", 698 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias", 699 | "model.diffusion_model.output_blocks.11.1.proj_out.weight": "up_blocks.3.attentions.2.proj_out.weight", 700 | "model.diffusion_model.output_blocks.11.1.proj_out.bias": "up_blocks.3.attentions.2.proj_out.bias" 701 | } 702 | 703 | 704 | if __name__ == "__main__": 705 | parser = argparse.ArgumentParser() 706 | 707 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 708 | parser.add_argument("--src_path", default=None, type=str, required=True, help="Path to the original model.") 709 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 710 | 711 | args = parser.parse_args() 712 | 713 | assert args.model_path is not None, "Must provide a model path!" 714 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 715 | assert args.src_path is not None, "Must provide a sourcecheckpoint path!" 716 | 717 | diff_pipe = StableDiffusionPipeline.from_pretrained(args.model_path) 718 | diff_pipe_unet_sd = diff_pipe.unet.state_dict() 719 | 720 | org_model = torch.load(args.src_path) 721 | org_sd = org_model["state_dict"] if "state_dict" in org_model else org_model 722 | 723 | for ckpt_key, diff_key in KeyMap.items(): 724 | org_sd[ckpt_key] = diff_pipe_unet_sd[diff_key] 725 | 726 | torch.save(org_model, args.checkpoint_path) -------------------------------------------------------------------------------- /train_dreambooth.py: -------------------------------------------------------------------------------- 1 | '''Simple script to finetune a stable-diffusion model''' 2 | 3 | import argparse 4 | import contextlib 5 | import copy 6 | import gc 7 | import hashlib 8 | import itertools 9 | import json 10 | import math 11 | import os 12 | import re 13 | import random 14 | import shutil 15 | import subprocess 16 | import time 17 | import atexit 18 | import zipfile 19 | import tempfile 20 | import multiprocessing 21 | from pathlib import Path 22 | from contextlib import nullcontext 23 | from urllib.parse import urlparse 24 | from typing import Iterable 25 | 26 | import numpy as np 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | from torch.utils.data import Dataset 31 | from torch.hub import download_url_to_file, get_dir 32 | 33 | try: 34 | # pip install git+https://github.com/KichangKim/DeepDanbooru 35 | import tensorflow as tf 36 | import deepdanbooru as dd 37 | 38 | gpus = tf.config.experimental.list_physical_devices('GPU') 39 | for gpu in gpus: 40 | tf.config.experimental.set_memory_growth(gpu, True) 41 | except ImportError: 42 | pass 43 | 44 | from accelerate import Accelerator 45 | from accelerate.utils import set_seed 46 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 47 | from diffusers.optimization import ( 48 | get_scheduler, 49 | get_cosine_with_hard_restarts_schedule_with_warmup, 50 | get_cosine_schedule_with_warmup 51 | ) 52 | from PIL import Image 53 | from torchvision import transforms 54 | from tqdm.auto import tqdm 55 | from transformers import CLIPTextModel, CLIPTokenizer 56 | 57 | torch.backends.cudnn.benchmark = True 58 | 59 | def parse_args(): 60 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 61 | parser.add_argument( 62 | "--pretrained_model_name_or_path", 63 | type=str, 64 | default=None, 65 | help="Path to pretrained model or model identifier from huggingface.co/models.", 66 | ) 67 | parser.add_argument( 68 | "--pretrained_vae_name_or_path", 69 | type=str, 70 | default=None, 71 | help="Path to pretrained vae or vae identifier from huggingface.co/models.", 72 | ) 73 | parser.add_argument( 74 | "--tokenizer_name", 75 | type=str, 76 | default=None, 77 | help="Pretrained tokenizer name or path if not the same as model_name", 78 | ) 79 | parser.add_argument( 80 | "--instance_data_dir", 81 | type=str, 82 | default=None, 83 | help="A folder containing the training data of instance images.", 84 | ) 85 | parser.add_argument( 86 | "--class_data_dir", 87 | type=str, 88 | default=None, 89 | help="A folder containing the training data of class images.", 90 | ) 91 | parser.add_argument( 92 | "--instance_prompt", 93 | type=str, 94 | default="", 95 | help="The prompt with identifier specifying the instance", 96 | ) 97 | parser.add_argument( 98 | "--class_prompt", 99 | type=str, 100 | default="", 101 | help="The prompt to specify images in the same class as provided instance images.", 102 | ) 103 | parser.add_argument( 104 | "--class_negative_prompt", 105 | type=str, 106 | default=None, 107 | help="The negative prompt to specify images in the same class as provided instance images.", 108 | ) 109 | parser.add_argument( 110 | "--save_sample_prompt", 111 | type=str, 112 | default=None, 113 | help="The prompt used to generate sample outputs to save.", 114 | ) 115 | parser.add_argument( 116 | "--save_sample_negative_prompt", 117 | type=str, 118 | default=None, 119 | help="The prompt used to generate sample outputs to save.", 120 | ) 121 | parser.add_argument( 122 | "--n_save_sample", 123 | type=int, 124 | default=4, 125 | help="The number of samples to save.", 126 | ) 127 | parser.add_argument( 128 | "--save_guidance_scale", 129 | type=float, 130 | default=7.5, 131 | help="CFG for save sample.", 132 | ) 133 | parser.add_argument( 134 | "--save_infer_steps", 135 | type=int, 136 | default=50, 137 | help="The number of inference steps for save sample.", 138 | ) 139 | parser.add_argument( 140 | "--with_prior_preservation", 141 | default=False, 142 | action="store_true", 143 | help="Flag to add prior preservation loss.", 144 | ) 145 | parser.add_argument( 146 | "--pad_tokens", 147 | default=False, 148 | action="store_true", 149 | help="Flag to pad tokens to length 77.", 150 | ) 151 | parser.add_argument( 152 | "--prior_loss_weight", 153 | type=float, 154 | default=1.0, 155 | help="The weight of prior preservation loss." 156 | ) 157 | parser.add_argument( 158 | "--num_class_images", 159 | type=int, 160 | default=100, 161 | help=( 162 | "Minimal class images for prior preservation loss. If not have enough images," 163 | "additional images will be sampled with class_prompt." 164 | ), 165 | ) 166 | parser.add_argument( 167 | "--output_dir", 168 | type=str, 169 | default="", 170 | help="The output directory where the model predictions and checkpoints will be written.", 171 | ) 172 | parser.add_argument( 173 | "--seed", 174 | type=int, 175 | default=None, 176 | help="A seed for reproducible training." 177 | ) 178 | parser.add_argument( 179 | "--resolution", 180 | type=int, 181 | default=512, 182 | help=( 183 | "The resolution for input images, all the images in the train/validation " 184 | "dataset will be resized to this resolution" 185 | ), 186 | ) 187 | parser.add_argument( 188 | "--center_crop", 189 | action="store_true", 190 | help="Whether to center crop images before resizing to resolution" 191 | ) 192 | parser.add_argument( 193 | "--train_text_encoder", 194 | action="store_true", 195 | help="Whether to train the text encoder" 196 | ) 197 | parser.add_argument( 198 | "--train_batch_size", 199 | type=int, 200 | default=4, 201 | help="Batch size (per device) for the training dataloader." 202 | ) 203 | parser.add_argument( 204 | "--sample_batch_size", 205 | type=int, 206 | default=4, 207 | help="Batch size (per device) for sampling images." 208 | ) 209 | parser.add_argument( 210 | "--num_train_epochs", 211 | type=int, 212 | default=1 213 | ) 214 | parser.add_argument( 215 | "--max_train_steps", 216 | type=int, 217 | default=None, 218 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 219 | ) 220 | parser.add_argument( 221 | "--gradient_accumulation_steps", 222 | type=int, 223 | default=1, 224 | help="Number of updates steps to accumulate before performing a backward/update pass.", 225 | ) 226 | parser.add_argument( 227 | "--gradient_checkpointing", 228 | action="store_true", 229 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 230 | ) 231 | parser.add_argument( 232 | "--learning_rate", 233 | type=float, 234 | default=5e-6, 235 | help="Initial learning rate (after the potential warmup period) to use.", 236 | ) 237 | parser.add_argument( 238 | "--scale_lr", 239 | action="store_true", 240 | default=False, 241 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 242 | ) 243 | parser.add_argument( 244 | "--scale_lr_sqrt", 245 | action="store_true", 246 | default=False, 247 | help="Scale the learning rate using sqrt instead of linear method.", 248 | ) 249 | parser.add_argument( 250 | "--lr_scheduler", 251 | type=str, 252 | default="constant", 253 | help=( 254 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 255 | ' "constant", "constant_with_warmup", "cosine_with_restarts_mod", "cosine_mod"]' 256 | ), 257 | ) 258 | parser.add_argument( 259 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 260 | ) 261 | parser.add_argument( 262 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 263 | ) 264 | parser.add_argument( 265 | "--use_deepspeed_adam", action="store_true", help="Whether or not to use deepspeed Adam." 266 | ) 267 | parser.add_argument( 268 | "--optimizer", 269 | type=str, 270 | default="adamw", 271 | choices=["adamw", "adamw_8bit", "adamw_ds", "sgdm", "sgdm_8bit"], 272 | help=( 273 | "The optimizer to use. _8bit optimizers require bitsandbytes, _ds optimizers require deepspeed." 274 | ) 275 | ) 276 | parser.add_argument( 277 | "--adam_beta1", 278 | type=float, 279 | default=0.9, 280 | help="The beta1 parameter for the Adam optimizer." 281 | ) 282 | parser.add_argument( 283 | "--adam_beta2", 284 | type=float, 285 | default=0.999, 286 | help="The beta2 parameter for the Adam optimizer." 287 | ) 288 | parser.add_argument( 289 | "--adam_epsilon", 290 | type=float, 291 | default=1e-08, 292 | help="Epsilon value for the Adam optimizer" 293 | ) 294 | parser.add_argument( 295 | "--sgd_momentum", 296 | type=float, 297 | default=0.9, 298 | help="Momentum value for the SGDM optimizer" 299 | ) 300 | parser.add_argument( 301 | "--sgd_dampening", 302 | type=float, 303 | default=0, 304 | help="Dampening value for the SGDM optimizer" 305 | ) 306 | parser.add_argument( 307 | "--max_grad_norm", 308 | default=1.0, 309 | type=float, 310 | help="Max gradient norm." 311 | ) 312 | parser.add_argument( 313 | "--weight_decay", 314 | type=float, 315 | default=1e-2, 316 | help="Weight decay to use." 317 | ) 318 | parser.add_argument( 319 | "--logging_dir", 320 | type=str, 321 | default="logs", 322 | help=( 323 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 324 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 325 | ), 326 | ) 327 | parser.add_argument( 328 | "--log_interval", 329 | type=int, 330 | default=10, 331 | help="Log every N steps." 332 | ) 333 | parser.add_argument( 334 | "--save_interval", 335 | type=int, 336 | default=10_000, 337 | help="Save weights every N steps." 338 | ) 339 | parser.add_argument( 340 | "--save_min_steps", 341 | type=int, 342 | default=10, 343 | help="Start saving weights after N steps." 344 | ) 345 | parser.add_argument( 346 | "--mixed_precision", 347 | type=str, 348 | default="no", 349 | choices=["no", "fp16", "bf16"], 350 | help=( 351 | "Whether to use mixed precision. Choose" 352 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 353 | "and an Nvidia Ampere GPU." 354 | ), 355 | ) 356 | parser.add_argument( 357 | "--not_cache_latents", 358 | action="store_true", 359 | help="Do not precompute and cache latents from VAE." 360 | ) 361 | parser.add_argument( 362 | "--local_rank", 363 | type=int, 364 | default=-1, 365 | help="For distributed training: local_rank" 366 | ) 367 | parser.add_argument( 368 | "--concepts_list", 369 | type=str, 370 | default=None, 371 | help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", 372 | ) 373 | parser.add_argument( 374 | "--wandb", 375 | default=False, 376 | action="store_true", 377 | help="Use wandb to watch training process.", 378 | ) 379 | parser.add_argument( 380 | "--wandb_artifact", 381 | default=False, 382 | action="store_true", 383 | help="Upload saved weights to wandb.", 384 | ) 385 | parser.add_argument( 386 | "--rm_after_wandb_saved", 387 | default=False, 388 | action="store_true", 389 | help="Remove saved weights from local machine after uploaded to wandb. Useful in colab.", 390 | ) 391 | parser.add_argument( 392 | "--wandb_name", 393 | type=str, 394 | default="Stable-Diffusion-Dreambooth", 395 | help="Project name in your wandb.", 396 | ) 397 | parser.add_argument( 398 | "--read_prompt_filename", 399 | default=False, 400 | action="store_true", 401 | help="Append extra prompt from filename.", 402 | ) 403 | parser.add_argument( 404 | "--read_prompt_txt", 405 | default=False, 406 | action="store_true", 407 | help="Append extra prompt from txt.", 408 | ) 409 | parser.add_argument( 410 | "--append_prompt", 411 | type=str, 412 | default="instance", 413 | choices=["class", "instance", "both"], 414 | help="Append extra prompt to which part of input.", 415 | ) 416 | parser.add_argument( 417 | "--save_unet_half", 418 | default=False, 419 | action="store_true", 420 | help="Use half precision to save unet weights, saves storage.", 421 | ) 422 | parser.add_argument( 423 | "--unet_half", 424 | default=False, 425 | action="store_true", 426 | help="Use half precision to save unet weights, saves storage.", 427 | ) 428 | parser.add_argument( 429 | "--clip_skip", 430 | type=int, 431 | default=1, 432 | help="Stop At last [n] layers of CLIP model when training." 433 | ) 434 | parser.add_argument( 435 | "--num_cycles", 436 | type=int, 437 | default=1, 438 | help="The number of hard restarts to use. Only works with --lr_scheduler=[cosine_with_restarts_mod, cosine_mod]" 439 | ) 440 | parser.add_argument( 441 | "--last_epoch", 442 | type=int, 443 | default=-1, 444 | help="The index of the last epoch when resuming training. Only works with --lr_scheduler=[cosine_with_restarts_mod, cosine_mod]" 445 | ) 446 | parser.add_argument( 447 | "--use_aspect_ratio_bucket", 448 | default=False, 449 | action="store_true", 450 | help="Use aspect ratio bucketing as image processing strategy, which may improve the quality of outputs. Use it with --not_cache_latents" 451 | ) 452 | parser.add_argument( 453 | "--debug_arb", 454 | default=False, 455 | action="store_true", 456 | help="Enable debug logging on aspect ratio bucket." 457 | ) 458 | parser.add_argument( 459 | "--save_optimizer", 460 | default=True, 461 | action="store_true", 462 | help="Save optimizer and scheduler state dict when training. Deprecated: use --save_states" 463 | ) 464 | parser.add_argument( 465 | "--save_states", 466 | default=True, 467 | action="store_true", 468 | help="Save optimizer and scheduler state dict when training." 469 | ) 470 | parser.add_argument( 471 | "--resume", 472 | default=False, 473 | action="store_true", 474 | help="Load optimizer and scheduler state dict to continue training." 475 | ) 476 | parser.add_argument( 477 | "--resume_from", 478 | type=str, 479 | default="", 480 | help="Specify checkpoint to resume. Use wandb://[artifact-full-name] for wandb artifact." 481 | ) 482 | parser.add_argument( 483 | "--config", 484 | type=str, 485 | default=None, 486 | help="Read args from config file. Command line args have higher priority and will override it.", 487 | ) 488 | parser.add_argument( 489 | "--arb_dim_limit", 490 | type=int, 491 | default=1024, 492 | help="Aspect ratio bucketing arguments: dim_limit." 493 | ) 494 | parser.add_argument( 495 | "--arb_divisible", 496 | type=int, 497 | default=64, 498 | help="Aspect ratio bucketing arguments: divisbile." 499 | ) 500 | parser.add_argument( 501 | "--arb_max_ar_error", 502 | type=int, 503 | default=4, 504 | help="Aspect ratio bucketing arguments: max_ar_error." 505 | ) 506 | parser.add_argument( 507 | "--arb_max_size", 508 | type=int, 509 | nargs="+", 510 | default=(768, 512), 511 | help="Aspect ratio bucketing arguments: max_size. example: --arb_max_size 768 512" 512 | ) 513 | parser.add_argument( 514 | "--arb_min_dim", 515 | type=int, 516 | default=256, 517 | help="Aspect ratio bucketing arguments: min_dim." 518 | ) 519 | parser.add_argument( 520 | "--deepdanbooru", 521 | default=False, 522 | action="store_true", 523 | help="Use deepdanbooru to tag images when prompt txt is not available." 524 | ) 525 | parser.add_argument( 526 | "--dd_threshold", 527 | type=float, 528 | default=0.6, 529 | help="Threshold for Deepdanbooru tag estimation" 530 | ) 531 | parser.add_argument( 532 | "--dd_alpha_sort", 533 | default=False, 534 | action="store_true", 535 | help="Sort deepbooru tags alphabetically." 536 | ) 537 | parser.add_argument( 538 | "--dd_use_spaces", 539 | default=True, 540 | action="store_true", 541 | help="Use spaces for tags in deepbooru." 542 | ) 543 | parser.add_argument( 544 | "--dd_use_escape", 545 | default=True, 546 | action="store_true", 547 | help="Use escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)" 548 | ) 549 | parser.add_argument( 550 | "--enable_rotate", 551 | default=False, 552 | action="store_true", 553 | help="Enable experimental feature to rotate image when buckets is not fit." 554 | ) 555 | parser.add_argument( 556 | "--dd_include_ranks", 557 | default=False, 558 | action="store_true", 559 | help="Include rank tag in deepdanbooru." 560 | ) 561 | parser.add_argument( 562 | "--use_ema", 563 | action="store_true", 564 | help="Whether to use EMA model." 565 | ) 566 | parser.add_argument( 567 | "--ucg", 568 | type=float, 569 | default=0.0, 570 | help="Percentage chance of dropping out the text condition per batch. \ 571 | Ranges from 0.0 to 1.0 where 1.0 means 100% text condition dropout." 572 | ) 573 | parser.add_argument( 574 | "--debug_prompt", 575 | default=False, 576 | action="store_true", 577 | help="Print input prompt when training." 578 | ) 579 | parser.add_argument( 580 | "--xformers", 581 | default=False, 582 | action="store_true", 583 | help="Enable memory efficient attention when training." 584 | ) 585 | 586 | args = parser.parse_args() 587 | resume_from = args.resume_from 588 | 589 | if resume_from.startswith("wandb://"): 590 | import wandb 591 | run = wandb.init(project=args.wandb_name, reinit=False) 592 | artifact = run.use_artifact(resume_from.replace("wandb://", ""), type='model') 593 | resume_from = artifact.download() 594 | 595 | elif args.resume_from != "": 596 | fp = os.path.join(resume_from, "state.pt") 597 | if not Path(fp).is_file(): 598 | raise ValueError(f"State_dict file {fp} not found.") 599 | 600 | elif args.resume: 601 | rx = re.compile(r'checkpoint_(\d+)') 602 | ckpts = rx.findall(" ".join([x.name for x in Path(args.output_dir).iterdir() if x.is_dir() and rx.match(x.name)])) 603 | 604 | if not any(ckpts): 605 | raise ValueError("At least one model is needed to resume training.") 606 | 607 | ckpts.sort(key=lambda e: int(e), reverse=True) 608 | for k in ckpts: 609 | fp = os.path.join(args.output_dir, f"checkpoint_{k}", "state.pt") 610 | if Path(fp).is_file(): 611 | resume_from = os.path.join(args.output_dir, f"checkpoint_{k}") 612 | break 613 | 614 | 615 | print(f"[*] Selected {resume_from}. To specify other checkpoint, use --resume-from") 616 | 617 | if resume_from: 618 | args.config = os.path.join(resume_from, "args.json") 619 | 620 | if args.config: 621 | with open(args.config, 'r') as f: 622 | config = json.load(f) 623 | parser.set_defaults(**config) 624 | args = parser.parse_args() 625 | 626 | if args.resume: 627 | args.pretrained_model_name_or_path = resume_from 628 | 629 | if not args.pretrained_model_name_or_path or not Path(args.pretrained_model_name_or_path).is_dir(): 630 | raise ValueError("A model is needed.") 631 | 632 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 633 | if env_local_rank != -1 and env_local_rank != args.local_rank: 634 | args.local_rank = env_local_rank 635 | 636 | return args 637 | 638 | 639 | class DeepDanbooru: 640 | def __init__( 641 | self, 642 | dd_threshold=0.6, 643 | dd_alpha_sort=False, 644 | dd_use_spaces=True, 645 | dd_use_escape=True, 646 | dd_include_ranks=False, 647 | **kwargs 648 | ): 649 | 650 | self.threshold = dd_threshold 651 | self.alpha_sort = dd_alpha_sort 652 | self.use_spaces = dd_use_spaces 653 | self.use_escape = dd_use_escape 654 | self.include_ranks = dd_include_ranks 655 | self.re_special = re.compile(r"([\\()])") 656 | self.new_process() 657 | 658 | def get_tags_local(self,image): 659 | self.returns["value"] = -1 660 | self.queue.put(image) 661 | while self.returns["value"] == -1: 662 | time.sleep(0.1) 663 | 664 | return self.returns["value"] 665 | 666 | def deepbooru_process(self): 667 | import tensorflow, deepdanbooru 668 | print(f"Deepdanbooru initialized using threshold: {self.threshold}") 669 | self.load_model() 670 | while True: 671 | image = self.queue.get() 672 | if image == "QUIT": 673 | break 674 | else: 675 | self.returns["value"] = self.get_tags(image) 676 | 677 | def new_process(self): 678 | context = multiprocessing.get_context("spawn") 679 | manager = context.Manager() 680 | self.queue = manager.Queue() 681 | self.returns = manager.dict() 682 | self.returns["value"] = -1 683 | self.process = context.Process(target=self.deepbooru_process) 684 | self.process.start() 685 | 686 | def kill_process(self): 687 | self.queue.put("QUIT") 688 | self.process.join() 689 | self.queue = None 690 | self.returns = None 691 | self.process = None 692 | 693 | def load_model(self): 694 | model_path = Path(tempfile.gettempdir()) / "deepbooru" 695 | if not Path(model_path / "project.json").is_file(): 696 | self.load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", model_path) 697 | 698 | with zipfile.ZipFile(model_path / "deepdanbooru-v3-20211112-sgd-e28.zip", "r") as zip_ref: 699 | zip_ref.extractall(model_path) 700 | os.remove(model_path / "deepdanbooru-v3-20211112-sgd-e28.zip") 701 | 702 | self.tags = dd.project.load_tags_from_project(model_path) 703 | self.model = dd.project.load_model_from_project(model_path, compile_model=False) 704 | 705 | def unload_model(self): 706 | self.kill_process() 707 | 708 | from tensorflow.python.framework import ops 709 | ops.reset_default_graph() 710 | tf.keras.backend.clear_session() 711 | 712 | @staticmethod 713 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 714 | if model_dir is None: # use the pytorch hub_dir 715 | hub_dir = get_dir() 716 | model_dir = os.path.join(hub_dir, 'checkpoints') 717 | 718 | os.makedirs(model_dir, exist_ok=True) 719 | 720 | parts = urlparse(url) 721 | filename = os.path.basename(parts.path) 722 | if file_name is not None: 723 | filename = file_name 724 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 725 | if not os.path.exists(cached_file): 726 | print(f'Downloading: "{url}" to {cached_file}\n') 727 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 728 | return cached_file 729 | 730 | def process_img(self, image): 731 | width = self.model.input_shape[2] 732 | height = self.model.input_shape[1] 733 | image = np.array(image) 734 | image = tf.image.resize( 735 | image, 736 | size=(height, width), 737 | method=tf.image.ResizeMethod.BICUBIC, 738 | preserve_aspect_ratio=True, 739 | ) 740 | image = image.numpy() # EagerTensor to np.array 741 | image = dd.image.transform_and_pad_image(image, width, height) 742 | image = image / 255.0 743 | image_shape = image.shape 744 | image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) 745 | return image 746 | 747 | def process_tag(self, y): 748 | result_dict = {} 749 | 750 | for i, tag in enumerate(self.tags): 751 | result_dict[tag] = y[i] 752 | 753 | unsorted_tags_in_theshold = [] 754 | result_tags_print = [] 755 | for tag in self.tags: 756 | if result_dict[tag] >= self.threshold: 757 | if tag.startswith("rating:"): 758 | continue 759 | unsorted_tags_in_theshold.append((result_dict[tag], tag)) 760 | result_tags_print.append(f"{result_dict[tag]} {tag}") 761 | 762 | # sort tags 763 | result_tags_out = [] 764 | sort_ndx = 0 765 | if self.alpha_sort: 766 | sort_ndx = 1 767 | 768 | # sort by reverse by likelihood and normal for alpha, and format tag text as requested 769 | unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not self.alpha_sort)) 770 | for weight, tag in unsorted_tags_in_theshold: 771 | tag_outformat = tag 772 | if self.use_spaces: 773 | tag_outformat = tag_outformat.replace("_", " ") 774 | if self.use_escape: 775 | tag_outformat = re.sub(self.re_special, r"\\\1", tag_outformat) 776 | if self.include_ranks: 777 | tag_outformat = f"({tag_outformat}:{weight:.3f})" 778 | 779 | result_tags_out.append(tag_outformat) 780 | 781 | # print("\n".join(sorted(result_tags_print, reverse=True))) 782 | 783 | return ", ".join(result_tags_out) 784 | 785 | def get_tags(self, image): 786 | result = self.model.predict(self.process_img(image))[0] 787 | return self.process_tag(result) 788 | 789 | 790 | class AspectRatioBucket: 791 | ''' 792 | Code from https://github.com/NovelAI/novelai-aspect-ratio-bucketing/blob/main/bucketmanager.py 793 | 794 | BucketManager impls NovelAI Aspect Ratio Bucketing, which may greatly improve the quality of outputs according to Novelai's blog (https://blog.novelai.net/novelai-improvements-on-stable-diffusion-e10d38db82ac) 795 | Requires a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use this. 796 | ''' 797 | 798 | def __init__(self, 799 | id_size_map, 800 | max_size=(768, 512), 801 | divisible=64, 802 | step_size=8, 803 | min_dim=256, 804 | base_res=(512, 512), 805 | bsz=1, 806 | world_size=1, 807 | global_rank=0, 808 | max_ar_error=4, 809 | seed=42, 810 | dim_limit=1024, 811 | debug=True, 812 | ): 813 | if global_rank == -1: 814 | global_rank = 0 815 | 816 | self.res_map = id_size_map 817 | self.max_size = max_size 818 | self.f = 8 819 | self.max_tokens = (max_size[0]/self.f) * (max_size[1]/self.f) 820 | self.div = divisible 821 | self.min_dim = min_dim 822 | self.dim_limit = dim_limit 823 | self.base_res = base_res 824 | self.bsz = bsz 825 | self.world_size = world_size 826 | self.global_rank = global_rank 827 | self.max_ar_error = max_ar_error 828 | self.prng = self.get_prng(seed) 829 | epoch_seed = self.prng.tomaxint() % (2**32-1) 830 | 831 | # separate prng for sharding use for increased thread resilience 832 | self.epoch_prng = self.get_prng(epoch_seed) 833 | self.epoch = None 834 | self.left_over = None 835 | self.batch_total = None 836 | self.batch_delivered = None 837 | 838 | self.debug = debug 839 | 840 | self.gen_buckets() 841 | self.assign_buckets() 842 | self.start_epoch() 843 | 844 | @staticmethod 845 | def get_prng(seed): 846 | return np.random.RandomState(seed) 847 | 848 | def __len__(self): 849 | return len(self.res_map) // self.bsz 850 | 851 | def gen_buckets(self): 852 | if self.debug: 853 | timer = time.perf_counter() 854 | resolutions = [] 855 | aspects = [] 856 | w = self.min_dim 857 | while (w/self.f) * (self.min_dim/self.f) <= self.max_tokens and w <= self.dim_limit: 858 | h = self.min_dim 859 | got_base = False 860 | while (w/self.f) * ((h+self.div)/self.f) <= self.max_tokens and (h+self.div) <= self.dim_limit: 861 | if w == self.base_res[0] and h == self.base_res[1]: 862 | got_base = True 863 | h += self.div 864 | if (w != self.base_res[0] or h != self.base_res[1]) and got_base: 865 | resolutions.append(self.base_res) 866 | aspects.append(1) 867 | resolutions.append((w, h)) 868 | aspects.append(float(w)/float(h)) 869 | w += self.div 870 | h = self.min_dim 871 | while (h/self.f) * (self.min_dim/self.f) <= self.max_tokens and h <= self.dim_limit: 872 | w = self.min_dim 873 | got_base = False 874 | while (h/self.f) * ((w+self.div)/self.f) <= self.max_tokens and (w+self.div) <= self.dim_limit: 875 | if w == self.base_res[0] and h == self.base_res[1]: 876 | got_base = True 877 | w += self.div 878 | resolutions.append((w, h)) 879 | aspects.append(float(w)/float(h)) 880 | h += self.div 881 | res_map = {} 882 | for i, res in enumerate(resolutions): 883 | res_map[res] = aspects[i] 884 | self.resolutions = sorted( 885 | res_map.keys(), key=lambda x: x[0] * 4096 - x[1]) 886 | self.aspects = np.array( 887 | list(map(lambda x: res_map[x], self.resolutions))) 888 | self.resolutions = np.array(self.resolutions) 889 | if self.debug: 890 | timer = time.perf_counter() - timer 891 | print(f"resolutions:\n{self.resolutions}") 892 | print(f"aspects:\n{self.aspects}") 893 | print(f"gen_buckets: {timer:.5f}s") 894 | 895 | def assign_buckets(self): 896 | if self.debug: 897 | timer = time.perf_counter() 898 | self.buckets = {} 899 | self.aspect_errors = [] 900 | skipped = 0 901 | skip_list = [] 902 | for post_id in self.res_map.keys(): 903 | w, h = self.res_map[post_id] 904 | aspect = float(w)/float(h) 905 | bucket_id = np.abs(self.aspects - aspect).argmin() 906 | if bucket_id not in self.buckets: 907 | self.buckets[bucket_id] = [] 908 | error = abs(self.aspects[bucket_id] - aspect) 909 | if error < self.max_ar_error: 910 | self.buckets[bucket_id].append(post_id) 911 | if self.debug: 912 | self.aspect_errors.append(error) 913 | else: 914 | skipped += 1 915 | skip_list.append(post_id) 916 | for post_id in skip_list: 917 | del self.res_map[post_id] 918 | if self.debug: 919 | timer = time.perf_counter() - timer 920 | self.aspect_errors = np.array(self.aspect_errors) 921 | try: 922 | print(f"skipped images: {skipped}") 923 | print(f"aspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()}") 924 | for bucket_id in reversed(sorted(self.buckets.keys(), key=lambda b: len(self.buckets[b]))): 925 | print( 926 | f"bucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])}") 927 | print(f"assign_buckets: {timer:.5f}s") 928 | except Exception as e: 929 | pass 930 | 931 | def start_epoch(self, world_size=None, global_rank=None): 932 | if self.debug: 933 | timer = time.perf_counter() 934 | if world_size is not None: 935 | self.world_size = world_size 936 | if global_rank is not None: 937 | self.global_rank = global_rank 938 | 939 | # select ids for this epoch/rank 940 | index = sorted(list(self.res_map.keys())) 941 | index_len = len(index) 942 | 943 | index = self.epoch_prng.permutation(index) 944 | index = index[:index_len - (index_len % (self.bsz * self.world_size))] 945 | # if self.debug: 946 | # print("perm", self.global_rank, index[0:16]) 947 | 948 | index = index[self.global_rank::self.world_size] 949 | self.batch_total = len(index) // self.bsz 950 | assert (len(index) % self.bsz == 0) 951 | index = set(index) 952 | 953 | self.epoch = {} 954 | self.left_over = [] 955 | self.batch_delivered = 0 956 | for bucket_id in sorted(self.buckets.keys()): 957 | if len(self.buckets[bucket_id]) > 0: 958 | self.epoch[bucket_id] = [post_id for post_id in self.buckets[bucket_id] if post_id in index] 959 | self.prng.shuffle(self.epoch[bucket_id]) 960 | self.epoch[bucket_id] = list(self.epoch[bucket_id]) 961 | overhang = len(self.epoch[bucket_id]) % self.bsz 962 | if overhang != 0: 963 | self.left_over.extend(self.epoch[bucket_id][:overhang]) 964 | self.epoch[bucket_id] = self.epoch[bucket_id][overhang:] 965 | if len(self.epoch[bucket_id]) == 0: 966 | del self.epoch[bucket_id] 967 | 968 | if self.debug: 969 | timer = time.perf_counter() - timer 970 | count = 0 971 | for bucket_id in self.epoch.keys(): 972 | count += len(self.epoch[bucket_id]) 973 | print( 974 | f"correct item count: {count == len(index)} ({count} of {len(index)})") 975 | print(f"start_epoch: {timer:.5f}s") 976 | 977 | def get_batch(self): 978 | if self.debug: 979 | timer = time.perf_counter() 980 | # check if no data left or no epoch initialized 981 | if self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered: 982 | self.start_epoch() 983 | 984 | found_batch = False 985 | batch_data = None 986 | resolution = self.base_res 987 | while not found_batch: 988 | bucket_ids = list(self.epoch.keys()) 989 | if len(self.left_over) >= self.bsz: 990 | bucket_probs = [ 991 | len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids] 992 | bucket_ids = [-1] + bucket_ids 993 | else: 994 | bucket_probs = [len(self.epoch[bucket_id]) 995 | for bucket_id in bucket_ids] 996 | bucket_probs = np.array(bucket_probs, dtype=np.float32) 997 | bucket_lens = bucket_probs 998 | bucket_probs = bucket_probs / bucket_probs.sum() 999 | if bool(self.epoch): 1000 | chosen_id = int(self.prng.choice( 1001 | bucket_ids, 1, p=bucket_probs)[0]) 1002 | else: 1003 | chosen_id = -1 1004 | 1005 | if chosen_id == -1: 1006 | # using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square image 1007 | self.prng.shuffle(self.left_over) 1008 | batch_data = self.left_over[:self.bsz] 1009 | self.left_over = self.left_over[self.bsz:] 1010 | found_batch = True 1011 | else: 1012 | if len(self.epoch[chosen_id]) >= self.bsz: 1013 | # return bucket batch and resolution 1014 | batch_data = self.epoch[chosen_id][:self.bsz] 1015 | self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:] 1016 | resolution = tuple(self.resolutions[chosen_id]) 1017 | found_batch = True 1018 | if len(self.epoch[chosen_id]) == 0: 1019 | del self.epoch[chosen_id] 1020 | else: 1021 | # can't make a batch from this, not enough images. move them to leftovers and try again 1022 | self.left_over.extend(self.epoch[chosen_id]) 1023 | del self.epoch[chosen_id] 1024 | 1025 | assert (found_batch or len(self.left_over) 1026 | >= self.bsz or bool(self.epoch)) 1027 | 1028 | if self.debug: 1029 | timer = time.perf_counter() - timer 1030 | print(f"bucket probs: " + 1031 | ", ".join(map(lambda x: f"{x:.2f}", list(bucket_probs*100)))) 1032 | print(f"chosen id: {chosen_id}") 1033 | print(f"batch data: {batch_data}") 1034 | print(f"resolution: {resolution}") 1035 | print(f"get_batch: {timer:.5f}s") 1036 | 1037 | self.batch_delivered += 1 1038 | return (batch_data, resolution) 1039 | 1040 | def generator(self): 1041 | if self.batch_delivered >= self.batch_total: 1042 | self.start_epoch() 1043 | while self.batch_delivered < self.batch_total: 1044 | yield self.get_batch() 1045 | 1046 | 1047 | class EMAModel: 1048 | """ 1049 | Maintains (exponential) moving average of a set of parameters. 1050 | Ref: https://github.com/harubaru/waifu-diffusion/diffusers_trainer.py#L478 1051 | 1052 | Args: 1053 | parameters: Iterable of `torch.nn.Parameter` (typically from model.parameters()`). 1054 | decay: The exponential decay. 1055 | """ 1056 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): 1057 | parameters = list(parameters) 1058 | self.shadow_params = [p.clone().detach() for p in parameters] 1059 | 1060 | self.decay = decay 1061 | self.optimization_step = 0 1062 | 1063 | def get_decay(self, optimization_step): 1064 | """ 1065 | Compute the decay factor for the exponential moving average. 1066 | """ 1067 | value = (1 + optimization_step) / (10 + optimization_step) 1068 | return 1 - min(self.decay, value) 1069 | 1070 | @torch.no_grad() 1071 | def step(self, parameters): 1072 | parameters = list(parameters) 1073 | 1074 | self.optimization_step += 1 1075 | self.decay = self.get_decay(self.optimization_step) 1076 | 1077 | for s_param, param in zip(self.shadow_params, parameters): 1078 | if param.requires_grad: 1079 | tmp = self.decay * (s_param - param) 1080 | s_param.sub_(tmp) 1081 | else: 1082 | s_param.copy_(param) 1083 | 1084 | torch.cuda.empty_cache() 1085 | 1086 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 1087 | """ 1088 | Copy current averaged parameters into given collection of parameters. 1089 | Args: 1090 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 1091 | updated with the stored moving averages. If `None`, the 1092 | parameters with which this `ExponentialMovingAverage` was 1093 | initialized will be used. 1094 | """ 1095 | parameters = list(parameters) 1096 | for s_param, param in zip(self.shadow_params, parameters): 1097 | param.data.copy_(s_param.data) 1098 | 1099 | # From CompVis LitEMA implementation 1100 | def store(self, parameters): 1101 | """ 1102 | Save the current parameters for restoring later. 1103 | Args: 1104 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 1105 | temporarily stored. 1106 | """ 1107 | self.collected_params = [param.clone() for param in parameters] 1108 | 1109 | def restore(self, parameters): 1110 | """ 1111 | Restore the parameters stored with the `store` method. 1112 | Useful to validate the model with EMA parameters without affecting the 1113 | original optimization process. Store the parameters before the 1114 | `copy_to` method. After validation (or model saving), use this to 1115 | restore the former parameters. 1116 | Args: 1117 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 1118 | updated with the stored parameters. 1119 | """ 1120 | for c_param, param in zip(self.collected_params, parameters): 1121 | param.data.copy_(c_param.data) 1122 | 1123 | del self.collected_params 1124 | gc.collect() 1125 | 1126 | def to(self, device=None, dtype=None) -> None: 1127 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 1128 | Args: 1129 | device: like `device` argument to `torch.Tensor.to` 1130 | """ 1131 | # .to() on the tensors handles None correctly 1132 | self.shadow_params = [ 1133 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 1134 | for p in self.shadow_params 1135 | ] 1136 | 1137 | @contextlib.contextmanager 1138 | def average_parameters(self, parameters): 1139 | r""" 1140 | Context manager for validation/inference with averaged parameters. 1141 | """ 1142 | self.store(parameters) 1143 | self.copy_to(parameters) 1144 | try: 1145 | yield 1146 | finally: 1147 | self.restore(parameters) 1148 | 1149 | 1150 | class DreamBoothDataset(Dataset): 1151 | """ 1152 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 1153 | It pre-processes the images and the tokenizes prompts. 1154 | """ 1155 | 1156 | def __init__( 1157 | self, 1158 | concepts_list, 1159 | tokenizer, 1160 | with_prior_preservation=True, 1161 | size=512, 1162 | center_crop=False, 1163 | num_class_images=None, 1164 | read_prompt_filename=False, 1165 | read_prompt_txt=False, 1166 | append_pos="", 1167 | pad_tokens=False, 1168 | deepdanbooru=False, 1169 | ucg=0, 1170 | debug_prompt=False, 1171 | **kwargs 1172 | ): 1173 | self.size = size 1174 | self.center_crop = center_crop 1175 | self.tokenizer = tokenizer 1176 | self.with_prior_preservation = with_prior_preservation 1177 | self.pad_tokens = pad_tokens 1178 | self.deepdanbooru = deepdanbooru 1179 | self.ucg = ucg 1180 | self.debug_prompt = debug_prompt 1181 | 1182 | self.instance_entries = [] 1183 | self.class_entries = [] 1184 | 1185 | if deepdanbooru: 1186 | dd = DeepDanbooru(**kwargs) 1187 | 1188 | def prompt_resolver(x, default, typ): 1189 | img_item = (x, default) 1190 | 1191 | if append_pos != typ and append_pos != "both": 1192 | return img_item 1193 | 1194 | if read_prompt_filename: 1195 | filename = Path(x).stem 1196 | pt = ''.join([i for i in filename if not i.isdigit()]) 1197 | pt = pt.replace("_", " ") 1198 | pt = pt.replace("(", "") 1199 | pt = pt.replace(")", "") 1200 | pt = pt.replace("--", "") 1201 | new_prompt = default + " " + pt 1202 | img_item = (x, new_prompt) 1203 | 1204 | elif read_prompt_txt: 1205 | fp = os.path.splitext(x)[0] 1206 | if not Path(fp + ".txt").is_file() and deepdanbooru: 1207 | print(f"Deepdanbooru: Working on {x}") 1208 | return (x, default + dd.get_tags_local(self.read_img(x))) 1209 | 1210 | with open(fp + ".txt") as f: 1211 | content = f.read() 1212 | new_prompt = default + " " + content 1213 | img_item = (x, new_prompt) 1214 | 1215 | elif deepdanbooru: 1216 | print(f"Deepdanbooru: Working on {x}") 1217 | return (x, default + dd.get_tags_local(self.read_img(x))) 1218 | 1219 | return img_item 1220 | 1221 | for concept in concepts_list: 1222 | inst_img_path = [prompt_resolver(x, concept["instance_prompt"], "instance") for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file() and x.suffix != ".txt"] 1223 | self.instance_entries.extend(inst_img_path) 1224 | 1225 | if with_prior_preservation: 1226 | class_img_path = [prompt_resolver(x, concept["class_prompt"], "class") for x in Path(concept["class_data_dir"]).iterdir() if x.is_file() and x.suffix != ".txt"] 1227 | self.class_entries.extend(class_img_path[:num_class_images]) 1228 | 1229 | if deepdanbooru: 1230 | dd.unload_model() 1231 | 1232 | random.shuffle(self.instance_entries) 1233 | self.num_instance_images = len(self.instance_entries) 1234 | self.num_class_images = len(self.class_entries) 1235 | self._length = max(self.num_class_images, self.num_instance_images) 1236 | 1237 | self.image_transforms = transforms.Compose( 1238 | [ 1239 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 1240 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 1241 | transforms.ToTensor(), 1242 | transforms.Normalize([0.5], [0.5]), 1243 | ] 1244 | ) 1245 | 1246 | def tokenize(self, prompt): 1247 | return self.tokenizer( 1248 | prompt, 1249 | padding="max_length" if self.pad_tokens else "do_not_pad", 1250 | truncation=True, 1251 | max_length=self.tokenizer.model_max_length, 1252 | ).input_ids 1253 | 1254 | @staticmethod 1255 | def read_img(filepath) -> Image: 1256 | img = Image.open(filepath) 1257 | 1258 | if not img.mode == "RGB": 1259 | img = img.convert("RGB") 1260 | return img 1261 | 1262 | @staticmethod 1263 | def process_tags(tags, min_tags=1, max_tags=32, type_dropout=0.75, keep_important=1.00, keep_jpeg_artifacts=True, sort_tags=False): 1264 | if isinstance(tags, str): 1265 | tags = tags.split(" ") 1266 | final_tags = {} 1267 | 1268 | tag_dict = {tag: True for tag in tags} 1269 | pure_tag_dict = {tag.split(":", 1)[-1]: tag for tag in tags} 1270 | for bad_tag in ["absurdres", "highres", "translation_request", "translated", "commentary", "commentary_request", "commentary_typo", "character_request", "bad_id", "bad_link", "bad_pixiv_id", "bad_twitter_id", "bad_tumblr_id", "bad_deviantart_id", "bad_nicoseiga_id", "md5_mismatch", "cosplay_request", "artist_request", "wide_image", "author_request", "artist_name"]: 1271 | if bad_tag in pure_tag_dict: 1272 | del tag_dict[pure_tag_dict[bad_tag]] 1273 | 1274 | if "rating:questionable" in tag_dict or "rating:explicit" in tag_dict: 1275 | final_tags["nsfw"] = True 1276 | 1277 | base_chosen = [] 1278 | for tag in tag_dict.keys(): 1279 | parts = tag.split(":", 1) 1280 | if parts[0] in ["artist", "copyright", "character"] and random.random() < keep_important: 1281 | base_chosen.append(tag) 1282 | if len(parts[-1]) > 1 and parts[-1][0] in ["1", "2", "3", "4", "5", "6"] and parts[-1][1:] in ["boy", "boys", "girl", "girls"]: 1283 | base_chosen.append(tag) 1284 | if parts[-1] in ["6+girls", "6+boys", "bad_anatomy", "bad_hands"]: 1285 | base_chosen.append(tag) 1286 | 1287 | tag_count = min(random.randint(min_tags, max_tags), len(tag_dict.keys())) 1288 | base_chosen_set = set(base_chosen) 1289 | chosen_tags = base_chosen + [tag for tag in random.sample(list(tag_dict.keys()), tag_count) if tag not in base_chosen_set] 1290 | if sort_tags: 1291 | chosen_tags = sorted(chosen_tags) 1292 | 1293 | for tag in chosen_tags: 1294 | tag = tag.replace(",", "").replace("_", " ") 1295 | if random.random() < type_dropout: 1296 | if tag.startswith("artist:"): 1297 | tag = tag[7:] 1298 | elif tag.startswith("copyright:"): 1299 | tag = tag[10:] 1300 | elif tag.startswith("character:"): 1301 | tag = tag[10:] 1302 | elif tag.startswith("general:"): 1303 | tag = tag[8:] 1304 | if tag.startswith("meta:"): 1305 | tag = tag[5:] 1306 | final_tags[tag] = True 1307 | 1308 | skip_image = False 1309 | for bad_tag in ["comic", "panels", "everyone", "sample_watermark", "text_focus", "tagme"]: 1310 | if bad_tag in pure_tag_dict: 1311 | skip_image = True 1312 | if not keep_jpeg_artifacts and "jpeg_artifacts" in tag_dict: 1313 | skip_image = True 1314 | 1315 | return ", ".join(list(final_tags.keys())) 1316 | 1317 | def __len__(self): 1318 | return self._length 1319 | 1320 | def __getitem__(self, index): 1321 | example = {} 1322 | instance_path, instance_prompt = self.instance_entries[index % self.num_instance_images] 1323 | 1324 | if random.random() <= self.ucg: 1325 | instance_prompt = '' 1326 | 1327 | instance_image = self.read_img(instance_path) 1328 | if self.debug_prompt: 1329 | print(f"instance prompt: {instance_prompt}") 1330 | example["instance_images"] = self.image_transforms(instance_image) 1331 | example["instance_prompt_ids"] = self.tokenize(instance_prompt) 1332 | 1333 | if self.with_prior_preservation: 1334 | class_path, class_prompt = self.class_entries[index % self.num_class_images] 1335 | class_image = self.read_img(class_path) 1336 | if self.debug_prompt: 1337 | print(f"class prompt: {class_prompt}") 1338 | example["class_images"] = self.image_transforms(class_image) 1339 | example["class_prompt_ids"] = self.tokenize(class_prompt) 1340 | 1341 | return example 1342 | 1343 | 1344 | class AspectRatioDataset(DreamBoothDataset): 1345 | def __init__(self, debug_arb=False, enable_rotate=False, **kwargs): 1346 | super().__init__(**kwargs) 1347 | self.debug = debug_arb 1348 | self.enable_rotate = enable_rotate 1349 | self.prompt_cache = {} 1350 | 1351 | # cache prompts for reading 1352 | for path, prompt in self.instance_entries + self.class_entries: 1353 | self.prompt_cache[path] = prompt 1354 | 1355 | def denormalize(self, img, mean=0.5, std=0.5): 1356 | res = transforms.Normalize((-1*mean/std), (1.0/std))(img.squeeze(0)) 1357 | res = torch.clamp(res, 0, 1) 1358 | return res 1359 | 1360 | def transformer(self, img, size, center_crop=False): 1361 | x, y = img.size 1362 | short, long = (x, y) if x <= y else (y, x) 1363 | 1364 | w, h = size 1365 | min_crop, max_crop = (w, h) if w <= h else (h, w) 1366 | ratio_src, ratio_dst = float(long / short), float(max_crop / min_crop) 1367 | 1368 | if (x>y and wh) and self.with_prior_preservation and self.enable_rotate: 1369 | # handle i/c mixed input 1370 | img = img.rotate(90, expand=True) 1371 | x, y = img.size 1372 | 1373 | if ratio_src > ratio_dst: 1374 | new_w, new_h = (min_crop, int(min_crop * ratio_src)) if xy else (int(max_crop / ratio_src), max_crop) 1377 | else: 1378 | new_w, new_h = w, h 1379 | 1380 | image_transforms = transforms.Compose([ 1381 | transforms.Resize((new_h, new_w), interpolation=transforms.InterpolationMode.BICUBIC), 1382 | transforms.CenterCrop((h, w)) if center_crop else transforms.RandomCrop((h, w)), 1383 | transforms.ToTensor(), 1384 | transforms.Normalize([0.5], [0.5]) 1385 | ]) 1386 | 1387 | new_img = image_transforms(img) 1388 | 1389 | if self.debug: 1390 | import uuid, torchvision 1391 | print(x, y, "->", new_w, new_h, "->", new_img.shape) 1392 | filename = str(uuid.uuid4()) 1393 | torchvision.utils.save_image(self.denormalize(new_img), f"/tmp/{filename}_1.jpg") 1394 | torchvision.utils.save_image(torchvision.transforms.ToTensor()(img), f"/tmp/{filename}_2.jpg") 1395 | print(f"saved: /tmp/{filename}") 1396 | 1397 | return new_img 1398 | 1399 | def build_dict(self, item_id, size, typ) -> dict: 1400 | if item_id == "": 1401 | return {} 1402 | prompt = self.prompt_cache[item_id] 1403 | image = self.read_img(item_id) 1404 | 1405 | if random.random() < self.ucg: 1406 | prompt = '' 1407 | 1408 | if self.debug_prompt: 1409 | print(f"{typ} prompt: {prompt}") 1410 | 1411 | example = { 1412 | f"{typ}_images": self.transformer(image, size), 1413 | f"{typ}_prompt_ids": self.tokenize(prompt) 1414 | } 1415 | return example 1416 | 1417 | def __getitem__(self, index): 1418 | result = [] 1419 | for item in index: 1420 | instance_dict = self.build_dict(item["instance"], item["size"], "instance") 1421 | class_dict = self.build_dict(item["class"], item["size"], "class") 1422 | result.append({**instance_dict, **class_dict}) 1423 | 1424 | return result 1425 | 1426 | 1427 | class AspectRatioSampler(torch.utils.data.Sampler): 1428 | def __init__( 1429 | self, 1430 | instance_buckets: AspectRatioBucket, 1431 | class_buckets: AspectRatioBucket, 1432 | num_replicas: int = 1, 1433 | with_prior_preservation: bool = False, 1434 | debug: bool = False, 1435 | ): 1436 | super().__init__(None) 1437 | self.instance_bucket_manager = instance_buckets 1438 | self.class_bucket_manager = class_buckets 1439 | self.num_replicas = num_replicas 1440 | self.debug = debug 1441 | self.with_prior_preservation = with_prior_preservation 1442 | self.iterator = instance_buckets if len(class_buckets) < len(instance_buckets) or \ 1443 | not with_prior_preservation else class_buckets 1444 | 1445 | def build_res_id_dict(self, iter): 1446 | base = {} 1447 | for item, res in iter.generator(): 1448 | base.setdefault(res,[]).extend([item[0]]) 1449 | return base 1450 | 1451 | def find_closest(self, size, size_id_dict, typ): 1452 | new_size = size 1453 | if size not in size_id_dict or not any(size_id_dict[size]): 1454 | kv = [(abs(s[0] / s[1] - size[0] / size[1]), s) for s in size_id_dict.keys() if any(size_id_dict[s])] 1455 | kv.sort(key=lambda e: e[0]) 1456 | 1457 | new_size = kv[0][1] 1458 | print(f"Warning: no {typ} image with {size} exists. Will use the closest ratio {new_size}.") 1459 | 1460 | return random.choice(size_id_dict[new_size]) 1461 | 1462 | def __iter__(self): 1463 | iter_is_instance = self.iterator == self.instance_bucket_manager 1464 | self.cached_ids = self.build_res_id_dict(self.class_bucket_manager if iter_is_instance else self.instance_bucket_manager) 1465 | 1466 | 1467 | for batch, size in self.iterator.generator(): 1468 | result = [] 1469 | 1470 | for item in batch: 1471 | sdict = {"size": size} 1472 | 1473 | if iter_is_instance: 1474 | rdict = {"instance": item, "class": self.find_closest(size, self.cached_ids, "class") if self.with_prior_preservation else ""} 1475 | else: 1476 | rdict = {"class": item, "instance": self.find_closest(size, self.cached_ids, "instance")} 1477 | 1478 | 1479 | result.append({**rdict, **sdict}) 1480 | 1481 | yield result 1482 | 1483 | def __len__(self): 1484 | return len(self.iterator) // self.num_replicas 1485 | 1486 | 1487 | class PromptDataset(Dataset): 1488 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 1489 | 1490 | def __init__(self, prompt, num_samples): 1491 | self.prompt = prompt 1492 | self.num_samples = num_samples 1493 | 1494 | def __len__(self): 1495 | return self.num_samples 1496 | 1497 | def __getitem__(self, index): 1498 | example = {} 1499 | example["prompt"] = self.prompt 1500 | example["index"] = index 1501 | return example 1502 | 1503 | 1504 | class LatentsDataset(Dataset): 1505 | def __init__(self, latents_cache, text_encoder_cache): 1506 | self.latents_cache = latents_cache 1507 | self.text_encoder_cache = text_encoder_cache 1508 | 1509 | def __len__(self): 1510 | return len(self.latents_cache) 1511 | 1512 | def __getitem__(self, index): 1513 | return self.latents_cache[index], self.text_encoder_cache[index] 1514 | 1515 | 1516 | class AverageMeter: 1517 | def __init__(self, name=None): 1518 | self.name = name 1519 | self.reset() 1520 | 1521 | def reset(self): 1522 | self.sum = self.count = self.avg = 0 1523 | 1524 | def update(self, val, n=1): 1525 | self.sum += val * n 1526 | self.count += n 1527 | self.avg = self.sum / self.count 1528 | 1529 | 1530 | def get_optimizer_class(optimizer_name: str): 1531 | def try_import_bnb(): 1532 | try: 1533 | import bitsandbytes as bnb 1534 | return bnb 1535 | except ImportError: 1536 | raise ImportError( 1537 | "To use 8-bit optimizers, please install the bitsandbytes library: `pip install bitsandbytes`." 1538 | ) 1539 | def try_import_ds(): 1540 | try: 1541 | import deepspeed 1542 | return deepspeed 1543 | except ImportError: 1544 | raise ImportError( 1545 | "Failed to import Deepspeed" 1546 | ) 1547 | 1548 | name = optimizer_name.lower() 1549 | 1550 | if name == "adamw": 1551 | return torch.optim.AdamW 1552 | elif name == "adamw_8bit": 1553 | return try_import_bnb().optim.AdamW8bit 1554 | elif name == "adamw_ds": 1555 | return try_import_ds().ops.adam.DeepSpeedCPUAdam 1556 | elif name == "sgdm": 1557 | return torch.optim.sgd 1558 | elif name == "sgdm_8bit": 1559 | return try_import_bnb().optim.SGD8bit 1560 | else: 1561 | raise ValueError("Unsupport optimizer") 1562 | 1563 | 1564 | def generate_class_images(args, accelerator): 1565 | pipeline = None 1566 | for concept in args.concepts_list: 1567 | class_images_dir = Path(concept["class_data_dir"]) 1568 | class_images_dir.mkdir(parents=True, exist_ok=True) 1569 | cur_class_images = len(list(class_images_dir.iterdir())) 1570 | 1571 | if cur_class_images < args.num_class_images: 1572 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 1573 | if pipeline is None: 1574 | pipeline = StableDiffusionPipeline.from_pretrained( 1575 | args.pretrained_model_name_or_path, 1576 | vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path, subfolder=None if args.pretrained_vae_name_or_path else "vae"), 1577 | torch_dtype=torch_dtype, 1578 | safety_checker=None, 1579 | ) 1580 | pipeline.set_progress_bar_config(disable=True) 1581 | pipeline.to(accelerator.device) 1582 | 1583 | num_new_images = args.num_class_images - cur_class_images 1584 | print(f"Number of class images to sample: {num_new_images}.") 1585 | 1586 | sample_dataset = PromptDataset([concept["class_prompt"], concept["class_negative_prompt"]], num_new_images) 1587 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 1588 | sample_dataloader = accelerator.prepare(sample_dataloader) 1589 | 1590 | with torch.autocast("cuda"), torch.inference_mode(): 1591 | for example in tqdm( 1592 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 1593 | ): 1594 | images = pipeline(prompt=example["prompt"][0][0], 1595 | negative_prompt=example["prompt"][1][0], 1596 | guidance_scale=args.save_guidance_scale, 1597 | num_inference_steps=args.save_infer_steps, 1598 | num_images_per_prompt=len(example["prompt"][0])).images 1599 | 1600 | for i, image in enumerate(images): 1601 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 1602 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 1603 | image.save(image_filename) 1604 | 1605 | del pipeline 1606 | if torch.cuda.is_available(): 1607 | torch.cuda.empty_cache() 1608 | 1609 | 1610 | def sizeof_fmt(num, suffix="B"): 1611 | for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: 1612 | if abs(num) < 1024.0: 1613 | return f"{num:3.1f}{unit}{suffix}" 1614 | num /= 1024.0 1615 | return f"{num:.1f}Yi{suffix}" 1616 | 1617 | 1618 | def get_gpu_ram() -> str: 1619 | """ 1620 | Returns memory usage statistics for the CPU, GPU, and Torch. 1621 | :return: 1622 | """ 1623 | devid = torch.cuda.current_device() 1624 | return f"GPU.{devid} {torch.cuda.get_device_name(devid)}" 1625 | 1626 | 1627 | def init_arb_buckets(args, accelerator): 1628 | arg_config = { 1629 | "bsz": args.train_batch_size, 1630 | "seed": args.seed, 1631 | "debug": args.debug_arb, 1632 | "base_res": (args.resolution, args.resolution), 1633 | "max_size": args.arb_max_size, 1634 | "divisible": args.arb_divisible, 1635 | "max_ar_error": args.arb_max_ar_error, 1636 | "min_dim": args.arb_min_dim, 1637 | "dim_limit": args.arb_dim_limit, 1638 | "world_size": accelerator.num_processes, 1639 | "global_rank": args.local_rank, 1640 | } 1641 | 1642 | if args.debug_arb: 1643 | print("BucketManager initialized using config:") 1644 | print(json.dumps(arg_config, sort_keys=True, indent=4)) 1645 | else: 1646 | print(f"BucketManager initialized with base_res = {arg_config['base_res']}, max_size = {arg_config['max_size']}") 1647 | 1648 | def get_id_size_dict(entries, hint): 1649 | id_size_map = {} 1650 | 1651 | for entry in tqdm(entries, desc=f"Loading resolution from {hint} images", disable=args.local_rank not in [0, -1]): 1652 | with Image.open(entry) as img: 1653 | size = img.size 1654 | id_size_map[entry] = size 1655 | 1656 | return id_size_map 1657 | 1658 | instance_entries, class_entries = [], [] 1659 | for concept in args.concepts_list: 1660 | inst_img_path = [x for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file() and x.suffix != ".txt"] 1661 | instance_entries.extend(inst_img_path) 1662 | 1663 | if args.with_prior_preservation: 1664 | class_img_path = [x for x in Path(concept["class_data_dir"]).iterdir() if x.is_file() and x.suffix != ".txt"] 1665 | class_entries.extend(class_img_path[:args.num_class_images]) 1666 | 1667 | instance_id_size_map = get_id_size_dict(instance_entries, "instance") 1668 | class_id_size_map = get_id_size_dict(class_entries, "class") 1669 | 1670 | instance_bucket_manager = AspectRatioBucket(instance_id_size_map, **arg_config) 1671 | class_bucket_manager = AspectRatioBucket(class_id_size_map, **arg_config) 1672 | 1673 | return instance_bucket_manager, class_bucket_manager 1674 | 1675 | 1676 | def main(args): 1677 | logging_dir = Path(args.output_dir, args.logging_dir) 1678 | 1679 | metrics = ["tensorboard"] 1680 | if args.wandb: 1681 | import wandb 1682 | run = wandb.init(project=args.wandb_name, reinit=False) 1683 | metrics.append("wandb") 1684 | 1685 | accelerator = Accelerator( 1686 | gradient_accumulation_steps=args.gradient_accumulation_steps, 1687 | mixed_precision=args.mixed_precision, 1688 | log_with=metrics, 1689 | project_dir=logging_dir, 1690 | ) 1691 | 1692 | print(get_gpu_ram()) 1693 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 1694 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 1695 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 1696 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 1697 | raise ValueError( 1698 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 1699 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 1700 | ) 1701 | 1702 | if args.seed is not None: 1703 | set_seed(args.seed) 1704 | 1705 | if args.concepts_list is None: 1706 | args.concepts_list = [ 1707 | { 1708 | "instance_prompt": args.instance_prompt, 1709 | "class_prompt": args.class_prompt, 1710 | "class_negative_prompt": args.class_negative_prompt, 1711 | "instance_data_dir": args.instance_data_dir, 1712 | "class_data_dir": args.class_data_dir 1713 | } 1714 | ] 1715 | else: 1716 | if type(args.concepts_list) == str: 1717 | with open(args.concepts_list, "r") as f: 1718 | args.concepts_list = json.load(f) 1719 | 1720 | if args.with_prior_preservation and accelerator.is_local_main_process: 1721 | generate_class_images(args, accelerator) 1722 | 1723 | # Load the tokenizer 1724 | if args.tokenizer_name: 1725 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 1726 | elif args.pretrained_model_name_or_path: 1727 | tokenizer = CLIPTokenizer.from_pretrained( 1728 | args.pretrained_model_name_or_path, subfolder="tokenizer") 1729 | else: 1730 | raise ValueError(args.tokenizer_name) 1731 | 1732 | # Load models and create wrapper for stable diffusion 1733 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 1734 | 1735 | def encode_tokens(tokens): 1736 | 1737 | if args.clip_skip > 1: 1738 | result = text_encoder(tokens, output_hidden_states=True, return_dict=True) 1739 | return text_encoder.text_model.final_layer_norm(result.hidden_states[-args.clip_skip]) 1740 | 1741 | return text_encoder(tokens)[0] 1742 | 1743 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 1744 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 1745 | 1746 | if args.xformers: 1747 | unet.set_use_memory_efficient_attention_xformers(True) 1748 | 1749 | vae.requires_grad_(False) 1750 | if not args.train_text_encoder: 1751 | text_encoder.requires_grad_(False) 1752 | 1753 | if args.gradient_checkpointing: 1754 | unet.enable_gradient_checkpointing() 1755 | if args.train_text_encoder: 1756 | text_encoder.gradient_checkpointing_enable() 1757 | 1758 | if args.scale_lr: 1759 | args.learning_rate = ( 1760 | args.learning_rate * args.gradient_accumulation_steps * 1761 | args.train_batch_size * accelerator.num_processes 1762 | ) 1763 | 1764 | elif args.scale_lr_sqrt: 1765 | args.learning_rate *= math.sqrt(args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes) 1766 | 1767 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 1768 | if args.use_8bit_adam: 1769 | try: 1770 | import bitsandbytes as bnb 1771 | except ImportError: 1772 | raise ImportError( 1773 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 1774 | ) 1775 | 1776 | optimizer_class = bnb.optim.AdamW8bit 1777 | elif args.use_deepspeed_adam: 1778 | try: 1779 | import deepspeed 1780 | except ImportError: 1781 | raise ImportError( 1782 | "Failed to import Deepspeed" 1783 | ) 1784 | optimizer_class = deepspeed.ops.adam.DeepSpeedCPUAdam 1785 | else: 1786 | optimizer_class = get_optimizer_class(args.optimizer) 1787 | 1788 | params_to_optimize = ( 1789 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 1790 | ) 1791 | 1792 | if "adam" in args.optimizer.lower(): 1793 | optimizer = optimizer_class( 1794 | params_to_optimize, 1795 | lr=args.learning_rate, 1796 | betas=(args.adam_beta1, args.adam_beta2), 1797 | weight_decay=args.weight_decay, 1798 | eps=args.adam_epsilon, 1799 | ) 1800 | elif "sgd" in args.optimizer.lower(): 1801 | optimizer = optimizer_class( 1802 | params_to_optimize, 1803 | lr=args.learning_rate, 1804 | momentum=args.sgd_momentum, 1805 | dampening=args.sgd_dampening, 1806 | weight_decay=args.weight_decay 1807 | ) 1808 | else: 1809 | raise ValueError(args.optimizer) 1810 | 1811 | noise_scheduler = DDIMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 1812 | dataset_class = AspectRatioDataset if args.use_aspect_ratio_bucket else DreamBoothDataset 1813 | train_dataset = dataset_class( 1814 | concepts_list=args.concepts_list, 1815 | tokenizer=tokenizer, 1816 | with_prior_preservation=args.with_prior_preservation, 1817 | size=args.resolution, 1818 | center_crop=args.center_crop, 1819 | num_class_images=args.num_class_images, 1820 | read_prompt_filename=args.read_prompt_filename, 1821 | read_prompt_txt=args.read_prompt_txt, 1822 | append_pos=args.append_prompt, 1823 | bsz=args.train_batch_size, 1824 | debug_arb=args.debug_arb, 1825 | seed=args.seed, 1826 | deepdanbooru=args.deepdanbooru, 1827 | dd_threshold=args.dd_threshold, 1828 | dd_alpha_sort=args.dd_alpha_sort, 1829 | dd_use_spaces=args.dd_use_spaces, 1830 | dd_use_escape=args.dd_use_escape, 1831 | dd_include_ranks=args.dd_include_ranks, 1832 | enable_rotate=args.enable_rotate, 1833 | ucg=args.ucg, 1834 | debug_prompt=args.debug_prompt, 1835 | ) 1836 | 1837 | def collate_fn_wrap(examples): 1838 | # workround for variable list 1839 | if len(examples) == 1: 1840 | examples = examples[0] 1841 | return collate_fn(examples) 1842 | 1843 | def collate_fn(examples): 1844 | input_ids = [example["instance_prompt_ids"] for example in examples] 1845 | pixel_values = [example["instance_images"] for example in examples] 1846 | 1847 | # Concat class and instance examples for prior preservation. 1848 | # We do this to avoid doing two forward passes. 1849 | if args.with_prior_preservation: 1850 | input_ids += [example["class_prompt_ids"] for example in examples] 1851 | pixel_values += [example["class_images"] for example in examples] 1852 | 1853 | pixel_values = torch.stack(pixel_values) 1854 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 1855 | 1856 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 1857 | 1858 | batch = { 1859 | "input_ids": input_ids, 1860 | "pixel_values": pixel_values, 1861 | } 1862 | return batch 1863 | 1864 | if args.ucg: 1865 | args.not_cache_latents = True 1866 | print("Latents cache disabled.") 1867 | 1868 | if args.use_aspect_ratio_bucket: 1869 | args.not_cache_latents = True 1870 | print("Latents cache disabled.") 1871 | instance_bucket_manager, class_bucket_manager = init_arb_buckets(args, accelerator) 1872 | sampler = AspectRatioSampler(instance_bucket_manager, class_bucket_manager, accelerator.num_processes, args.with_prior_preservation) 1873 | 1874 | train_dataloader = torch.utils.data.DataLoader( 1875 | train_dataset, collate_fn=collate_fn_wrap, num_workers=1, sampler=sampler, 1876 | ) 1877 | else: 1878 | train_dataloader = torch.utils.data.DataLoader( 1879 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=1 1880 | ) 1881 | 1882 | weight_dtype = torch.float32 1883 | if args.mixed_precision == "fp16": 1884 | weight_dtype = torch.float16 1885 | elif args.mixed_precision == "bf16": 1886 | weight_dtype = torch.bfloat16 1887 | 1888 | if args.use_ema: 1889 | ema_unet = EMAModel(unet.parameters()) 1890 | ema_unet.to(accelerator.device, dtype=weight_dtype) 1891 | 1892 | # Move text_encode and vae to gpu. 1893 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 1894 | # as these models are only used for inference, keeping weights in full precision is not required. 1895 | vae.to(accelerator.device, dtype=weight_dtype) 1896 | if not args.train_text_encoder: 1897 | text_encoder.to(accelerator.device, dtype=weight_dtype) 1898 | 1899 | if not args.not_cache_latents: 1900 | latents_cache = [] 1901 | text_encoder_cache = [] 1902 | for batch in tqdm(train_dataloader, desc="Caching latents", disable=not accelerator.is_local_main_process): 1903 | with torch.no_grad(): 1904 | batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype) 1905 | batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True) 1906 | latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) 1907 | if args.train_text_encoder: 1908 | text_encoder_cache.append(batch["input_ids"]) 1909 | else: 1910 | text_encoder_cache.append(encode_tokens(batch["input_ids"])) 1911 | 1912 | train_dataset = LatentsDataset(latents_cache, text_encoder_cache) 1913 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True) 1914 | 1915 | del vae 1916 | if not args.train_text_encoder: 1917 | del text_encoder 1918 | if torch.cuda.is_available(): 1919 | torch.cuda.empty_cache() 1920 | 1921 | # Scheduler and math around the number of training steps. 1922 | overrode_max_train_steps = False 1923 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1924 | if args.max_train_steps is None: 1925 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1926 | overrode_max_train_steps = True 1927 | 1928 | if args.lr_scheduler == "cosine_with_restarts_mod": 1929 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 1930 | optimizer=optimizer, 1931 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 1932 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 1933 | num_cycles=args.num_cycles, 1934 | last_epoch=args.last_epoch, 1935 | ) 1936 | elif args.lr_scheduler == "cosine_mod": 1937 | lr_scheduler = get_cosine_schedule_with_warmup( 1938 | optimizer=optimizer, 1939 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 1940 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 1941 | num_cycles=args.num_cycles, 1942 | last_epoch=args.last_epoch, 1943 | ) 1944 | else: 1945 | lr_scheduler = get_scheduler( 1946 | args.lr_scheduler, 1947 | optimizer=optimizer, 1948 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 1949 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 1950 | ) 1951 | 1952 | base_step = 0 1953 | base_epoch = 0 1954 | 1955 | if args.train_text_encoder: 1956 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1957 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 1958 | ) 1959 | else: 1960 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1961 | unet, optimizer, train_dataloader, lr_scheduler 1962 | ) 1963 | 1964 | if args.resume: 1965 | state_dict = torch.load(os.path.join(args.pretrained_model_name_or_path, f"state.pt"), map_location="cuda") 1966 | if "optimizer" in state_dict: 1967 | optimizer.load_state_dict(state_dict["optimizer"]) 1968 | 1969 | if "scheduler" in state_dict: 1970 | lr_scheduler.load_state_dict(state_dict["scheduler"]) 1971 | 1972 | last_lr = state_dict["scheduler"]["_last_lr"] 1973 | print(f"Loaded state_dict from '{args.pretrained_model_name_or_path}': last_lr = {last_lr}") 1974 | 1975 | base_step = state_dict["total_steps"] 1976 | base_epoch = state_dict["total_epoch"] 1977 | del state_dict 1978 | 1979 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1980 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 1981 | if overrode_max_train_steps: 1982 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1983 | # Afterwards we recalculate our number of training epochs 1984 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 1985 | 1986 | # We need to initialize the trackers we use, and also store our configuration. 1987 | # The trackers initializes automatically on the main process. 1988 | if accelerator.is_main_process: 1989 | accelerator.init_trackers("dreambooth") 1990 | 1991 | # Train! 1992 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 1993 | 1994 | if accelerator.is_main_process: 1995 | print("***** Running training *****") 1996 | print(f" Num examples = {len(train_dataset)}") 1997 | print(f" Num batches each epoch = {len(train_dataloader)}") 1998 | print(f" Num Epochs = {args.num_train_epochs}") 1999 | print(f" Instantaneous batch size per device = {args.train_batch_size}") 2000 | print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 2001 | print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 2002 | print(f" Total optimization steps = {args.max_train_steps}") 2003 | 2004 | def save_weights(interrupt=False): 2005 | # Create the pipeline using using the trained modules and save it. 2006 | if accelerator.is_main_process: 2007 | 2008 | if args.train_text_encoder: 2009 | text_enc_model = accelerator.unwrap_model(text_encoder) 2010 | else: 2011 | text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 2012 | 2013 | unet_unwrapped = accelerator.unwrap_model(unet) 2014 | 2015 | if args.save_unet_half or args.unet_half: 2016 | import copy 2017 | unet_unwrapped = copy.deepcopy(unet_unwrapped).half() 2018 | 2019 | scheduler = DDIMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 2020 | pipeline = StableDiffusionPipeline.from_pretrained( 2021 | args.pretrained_model_name_or_path, 2022 | unet=unet_unwrapped, 2023 | text_encoder=text_enc_model, 2024 | vae=AutoencoderKL.from_pretrained(args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path, subfolder=None if args.pretrained_vae_name_or_path else "vae"), 2025 | safety_checker=None, 2026 | scheduler=scheduler, 2027 | torch_dtype=weight_dtype, 2028 | ) 2029 | 2030 | output_dir = Path(args.output_dir) 2031 | output_dir.mkdir(exist_ok=True) 2032 | 2033 | save_dir = output_dir / f"checkpoint_{global_step}" 2034 | if local_step >= args.max_train_steps: 2035 | save_dir = output_dir / f"checkpoint_last" 2036 | 2037 | save_dir.mkdir(exist_ok=True) 2038 | pipeline.save_pretrained(save_dir) 2039 | print(f"[*] Weights saved at {save_dir}") 2040 | 2041 | if args.use_ema: 2042 | ema_path = save_dir / "unet_ema" 2043 | with ema_unet.average_parameters(unet_unwrapped.parameters()): 2044 | unet_unwrapped.save_pretrained(ema_path) 2045 | 2046 | ema_unet.to("cpu", dtype=weight_dtype) 2047 | torch.cuda.empty_cache() 2048 | print(f"[*] EMA Weights saved at {ema_path}") 2049 | 2050 | if args.save_states: 2051 | accelerator.save({ 2052 | 'total_epoch': global_epoch, 2053 | 'total_steps': global_step, 2054 | 'optimizer': optimizer.state_dict(), 2055 | 'scheduler': lr_scheduler.state_dict(), 2056 | 'loss': loss, 2057 | }, os.path.join(save_dir, "state.pt")) 2058 | 2059 | with open(save_dir / "args.json", "w") as f: 2060 | args.resume_from = str(save_dir) 2061 | json.dump(args.__dict__, f, indent=2) 2062 | 2063 | if interrupt: 2064 | return 2065 | 2066 | if args.save_sample_prompt: 2067 | pipeline = pipeline.to(accelerator.device) 2068 | g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed) 2069 | pipeline.set_progress_bar_config(disable=True) 2070 | sample_dir = save_dir / "samples" 2071 | sample_dir.mkdir(exist_ok=True) 2072 | with torch.autocast("cuda"), torch.inference_mode(): 2073 | for i in tqdm(range(args.n_save_sample), desc="Generating samples"): 2074 | images = pipeline( 2075 | args.save_sample_prompt, 2076 | negative_prompt=args.save_sample_negative_prompt, 2077 | guidance_scale=args.save_guidance_scale, 2078 | num_inference_steps=args.save_infer_steps, 2079 | generator=g_cuda 2080 | ).images 2081 | images[0].save(sample_dir / f"{i}.png") 2082 | 2083 | if args.wandb: 2084 | wandb.log({"samples": [wandb.Image(str(x)) for x in sample_dir.glob("*.png")]}, step=global_step) 2085 | 2086 | del pipeline 2087 | if torch.cuda.is_available(): 2088 | torch.cuda.empty_cache() 2089 | 2090 | if args.use_ema: 2091 | ema_unet.to(accelerator.device, dtype=weight_dtype) 2092 | 2093 | if args.wandb_artifact: 2094 | model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={ 2095 | 'epochs_trained': global_epoch + 1, 2096 | 'project': run.project 2097 | }) 2098 | model_artifact.add_dir(save_dir) 2099 | wandb.log_artifact(model_artifact, aliases=['latest', 'last', f'epoch {global_epoch + 1}']) 2100 | 2101 | if args.rm_after_wandb_saved: 2102 | shutil.rmtree(save_dir) 2103 | subprocess.run("wandb", "artifact", "cache", "cleanup", "1G") 2104 | 2105 | # Only show the progress bar once on each machine. 2106 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 2107 | progress_bar.set_description("Steps") 2108 | local_step = 0 2109 | loss_avg = AverageMeter() 2110 | text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad() 2111 | 2112 | @atexit.register 2113 | def on_exit(): 2114 | if 100 < local_step < args.max_train_steps and accelerator.is_local_main_process: 2115 | print("Saving model...") 2116 | save_weights(interrupt=True) 2117 | 2118 | for epoch in range(args.num_train_epochs): 2119 | unet.train() 2120 | if args.train_text_encoder: 2121 | text_encoder.train() 2122 | for _, batch in enumerate(train_dataloader): 2123 | with accelerator.accumulate(unet): 2124 | # Convert images to latent space 2125 | with torch.no_grad(): 2126 | if not args.not_cache_latents: 2127 | latent_dist = batch[0][0] 2128 | else: 2129 | latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist 2130 | latents = latent_dist.sample() * 0.18215 2131 | 2132 | # Sample noise that we'll add to the latents 2133 | noise = torch.randn_like(latents) 2134 | bsz = latents.shape[0] 2135 | # Sample a random timestep for each image 2136 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 2137 | timesteps = timesteps.long() 2138 | 2139 | # Add noise to the latents according to the noise magnitude at each timestep 2140 | # (this is the forward diffusion process) 2141 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 2142 | 2143 | # Get the text embedding for conditioning 2144 | with text_enc_context: 2145 | if not args.not_cache_latents: 2146 | if args.train_text_encoder: 2147 | encoder_hidden_states = encode_tokens(batch[0][1]) 2148 | else: 2149 | encoder_hidden_states = batch[0][1] 2150 | else: 2151 | encoder_hidden_states = encode_tokens(batch["input_ids"]) 2152 | 2153 | # Predict the noise residual 2154 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 2155 | 2156 | if args.with_prior_preservation: 2157 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 2158 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 2159 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 2160 | 2161 | # Compute instance loss 2162 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 2163 | 2164 | # Compute prior loss 2165 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 2166 | 2167 | # Add the prior loss to the instance loss. 2168 | loss = loss + args.prior_loss_weight * prior_loss 2169 | else: 2170 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 2171 | 2172 | accelerator.backward(loss) 2173 | if accelerator.sync_gradients: 2174 | params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()) 2175 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 2176 | optimizer.step() 2177 | lr_scheduler.step() 2178 | optimizer.zero_grad() 2179 | loss_avg.update(loss.detach_(), bsz) 2180 | 2181 | global_step = base_step + local_step 2182 | global_epoch = base_epoch + epoch 2183 | 2184 | if not local_step % args.log_interval: 2185 | logs = { 2186 | "epoch": global_epoch + 1, 2187 | "loss": loss_avg.avg.item(), 2188 | "lr": lr_scheduler.get_last_lr()[0] 2189 | } 2190 | progress_bar.set_postfix(**logs) 2191 | accelerator.log(logs, step=global_step) 2192 | 2193 | # Checks if the accelerator has performed an optimization step behind the scenes 2194 | # if accelerator.sync_gradients: 2195 | if accelerator.sync_gradients: 2196 | if args.use_ema: 2197 | ema_unet.step(unet.parameters()) 2198 | progress_bar.update(1) 2199 | local_step += 1 2200 | 2201 | if local_step > args.save_min_steps and not global_step % args.save_interval: 2202 | save_weights() 2203 | 2204 | if local_step >= args.max_train_steps: 2205 | break 2206 | 2207 | accelerator.wait_for_everyone() 2208 | 2209 | save_weights() 2210 | accelerator.end_training() 2211 | 2212 | 2213 | if __name__ == "__main__": 2214 | args = parse_args() 2215 | main(args) --------------------------------------------------------------------------------