├── .gitignore
├── README.md
├── back_train.sh
├── ckpt_models
│   ├── model.yaml
│   └── put_your_ckpt_models_here.txt
├── datasets
│   └── put_datasets_here.txt
├── other
│   └── something others.txt
├── test_model.py
├── test_prompts_object.txt
├── test_prompts_style.txt
├── tools
│   ├── ckpt2diffusers.py
│   ├── ckpt2diffusers_old.py
│   ├── ckpt_merge.py
│   ├── ckpt_prune.py
│   ├── deepdanbooru-models
│   │   └── put_deepdanbooru_model_here.txt
│   ├── diagnose_tensorboard.py
│   ├── diffusers2ckpt.py
│   ├── handle_images.py
│   ├── label_images.py
│   ├── test_cuda.py
│   ├── train_dreambooth.py
│   ├── train_dreambooth_rect.py
│   ├── train_textual_inversion.py
│   └── upload_cos.py
├── train_object.sh
├── train_object_rect.sh
├── train_style.sh
├── train_style_rect.sh
├── train_textual_inversion.sh
└── 运行.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .ipynb_checkpoints
3 | */.ipynb_checkpoints
4 | *.ckpt
5 | *.pt
6 | *.whl
7 | *.log
8 | *.png
9 | *.jpg
10 | nohup.out
11 | /datasets
12 | /model
13 | /new-*
14 | /log
15 | /output*
16 | /tools/deepdanbooru-models/*
17 | /tools/diffusers-models/*
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dreambooth Stable Diffusion integrated training environment
2 | If your machine is on autodl, you can create an instance directly from the prepackaged image and use it out of the box.
3 | It also works locally or on other servers; you just need to install a few pip packages by hand.
4 | 
5 | ## How to run
6 | Run the image directly on autodl: https://www.codewithgpu.com/i/CrazyBoyM/dreambooth-for-diffusion/dreambooth-for-diffusion
7 | 
8 | If you are not comfortable training from notebook code, you can also use the prepackaged webui online image (includes the stable Dreambooth and DreamArtist training extensions, already fixed):
9 | https://www.codewithgpu.com/i/CrazyBoyM/sd_dreambooth_extension_webui/dreambooth-dreamartist-for-webui
10 | 
11 | ## Note
12 | This project is intended only for studying and testing AI technology.
13 | Do not use it to train models that generate harmful or infringing image content.
14 | 
15 | ## About the project
16 | The packaged image on autodl is named dreambooth-for-diffusion;
17 | it can be selected directly as a public algorithm image when creating an instance.
18 | It was packaged on an A5000 machine in autodl's Inner Mongolia zone A; if you hit a problem you cannot solve yourself, please use the same environment.
19 | I (Baicai) tested as much as possible while writing this tutorial, but still cannot guarantee that every step is fully covered.
20 | If you hit small errors, try to fix them manually, or check the latest README at the project's git address.
21 | Project address: https://github.com/CrazyBoyM/dreambooth-for-diffusion
22 | 
23 | If you run into problems, look for the training demo video for this tutorial on my bilibili page: https://space.bilibili.com/291593914
24 | (the video had not been recorded yet at the time of writing)
25 | 
26 | ## Strongly recommended
27 | 1. Connect to the server remotely with VS Code's SSH feature for a better training experience; autodl's built-in notebook is also decent and supports file upload and download.
28 | 2. (Important) First move the whole dreambooth-for-diffusion folder from /root/ to /root/autodl-tmp/ (the data disk) to avoid filling up the system disk.
29 | 
30 | ## Enter the working directory
31 | ```
32 | cd /root/autodl-tmp/dreambooth-for-diffusion
33 | ```
34 | 
35 | ## Convert a ckpt checkpoint file to official diffusers weights
36 | Two base models are bundled; choose one according to the characteristics of your dataset.
37 | - sd_1-5.ckpt leans toward a photorealistic style
38 | - nd_lastest.ckpt leans toward an anime style
39 | Convert the anime model:
40 | ```
41 | # This step takes about one minute to run
42 | !python tools/ckpt2diffusers.py \
43 |     --checkpoint_path=./ckpt_models/nd_lastest.ckpt \
44 |     --dump_path=./model \
45 |     --vae_path=./ckpt_models/animevae.pt \
46 |     --original_config_file=./ckpt_models/model.yaml \
47 |     --scheduler_type="ddim"
48 | ```
49 | Convert the photorealistic model:
50 | ```
51 | # This step takes about one minute to run
52 | !python tools/ckpt2diffusers.py \
53 |     --checkpoint_path=./ckpt_models/sd_1-5.ckpt \
54 |     --dump_path=./model \
55 |     --original_config_file=./ckpt_models/model.yaml \
56 |     --scheduler_type="ddim"
57 | ```
58 | The two paths here are your ckpt file and the output path for the converted weights.
59 | 
60 | ## Convert official diffusers weights back to a ckpt checkpoint file
61 | ```
62 | python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half
63 | ```
64 | To save in float16 precision, add the --half flag; the checkpoint size is roughly halved.
65 | 
66 | ## Prepare the dataset
67 | Prepare a dataset that matches your training task.
68 | ### Crop images to 512*512
69 | A batch-processing script is provided in tools/handle_images.py for reference;
70 | it automatically center-crops the images and resizes them.
71 | ```
72 | python tools/handle_images.py ./datasets/test ./datasets/test2 --width=512 --height=512
73 | ```
74 | test is the folder of raw, unprocessed images; test2 is the output path for the processed images.
75 | To convert transparent-background PNGs into black/white-background JPGs, add the --png flag.
76 | 
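The core of this preprocessing step is a centered crop to the target aspect ratio followed by a resize. The sketch below is only an illustration written for this README, not the actual contents of tools/handle_images.py (which also handles the --png transparency option and its command-line flags); the helper name center_crop_resize and the Pillow dependency are assumptions.
```
# Minimal center-crop + resize sketch (illustrative; not the code in tools/handle_images.py).
from pathlib import Path
from PIL import Image

def center_crop_resize(src: Path, dst: Path, width: int = 512, height: int = 512) -> None:
    img = Image.open(src).convert("RGB")
    w, h = img.size
    # Crop the largest centered region that matches the target aspect ratio,
    # then resize it to the requested resolution.
    scale = min(w / width, h / height)
    crop_w, crop_h = int(width * scale), int(height * scale)
    left, top = (w - crop_w) // 2, (h - crop_h) // 2
    img = img.crop((left, top, left + crop_w, top + crop_h)).resize((width, height), Image.LANCZOS)
    dst.parent.mkdir(parents=True, exist_ok=True)
    img.save(dst, quality=95)

if __name__ == "__main__":
    src_dir, dst_dir = Path("./datasets/test"), Path("./datasets/test2")
    for p in src_dir.iterdir():
        if p.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}:
            center_crop_resize(p, dst_dir / (p.stem + ".jpg"))
```
Cropping to the aspect ratio before resizing avoids stretching the subject, which matters when a Dreambooth run only has a handful of images.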
77 | ### Automatic image labeling
78 | Use deepdanbooru to generate tag labels.
79 | ```
80 | !python tools/label_images.py --path=./datasets/test2
81 | ```
82 | The second argument, --path, is the path of the image folder you want to label.
83 | 
84 | Note: if deepdanbooru is reported as not found, you can build it yourself from the following repository:
85 | https://github.com/KichangKim/DeepDanbooru
86 | 
87 | A prebuilt version is also provided in the other folder:
88 | ```
89 | pip install other/deepdanbooru-1.0.0-py3-none-any.whl
90 | ```
91 | 
92 | ## Training and a summary of common commands
93 | ### Configure the training environment (optional)
94 | If you are not running directly on the prepackaged image, do the following setup first:
95 | ```
96 | pip install accelerate
97 | ```
98 | Run the following command and choose the local-machine option, then NO, NO:
99 | ```
100 | accelerate config
101 | ```
102 | 
103 | ### Start training
104 | Open the train_*.sh script you plan to use (e.g. train_object.sh or train_style.sh) and read the parameter notes inside.
105 | To train a specific person or thing:
106 | (3~5 images of the specific subject with a consistent style are recommended)
107 | 
108 | ```
109 | sh train_object.sh
110 | ```
111 | 
112 | To finetune your own large model:
113 | (3000+ images with as much diversity as possible are recommended; the data determines the quality of the trained model)
114 | ```
115 | sh train_style.sh
116 | ```
117 | Training speed on an A5000 is roughly 8 minutes per 1000 steps.
118 | 
119 | ### Test the training results
120 | Open test_model.py, modify model_path and prompt in it, then run:
121 | ```
122 | python test_model.py
123 | ```
124 | 
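To sample every prompt from test_prompts_style.txt (or test_prompts_object.txt) in one go instead of editing the prompt by hand, a small loop around the same diffusers pipeline is enough. This is only a sketch that reuses the setup from test_model.py (included further down in this repository); it is not how the train_*.sh scripts themselves consume the prompt files.
```
# Sketch: one image per line of test_prompts_style.txt, using the trained weights in ./new_model.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "./new_model", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

with open("test_prompts_style.txt", encoding="utf-8") as f:
    prompts = [line.strip() for line in f if line.strip()]

for i, prompt in enumerate(prompts):
    image = pipe(prompt, width=512, height=512, num_inference_steps=30).images[0]
    image.save(f"test-style-{i}.png")
```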
125 | ### Other common commands
126 | To run training as a background job:
127 | ```
128 | nohup sh train_style.sh &
129 | ```
130 | Recommended for overnight runs; you do not need to worry about a dropped connection stopping the training.
131 | Baicai's personal money-saving training trick:
132 | ```
133 | nohup sh back_train.sh &
134 | ```
135 | (the machine shuts down automatically as soon as training finishes)
136 | 
137 | The training log is written to nohup.out, which you can open directly in VS Code or download.
138 | Show the last ten lines of the log:
139 | ```
140 | tail -n 10 nohup.out
141 | ```
142 | 
143 | Check the current disk usage:
144 | (remember to clean up unneeded files, otherwise the disk easily fills up by tens of GB and model saving fails!!)
145 | ```
146 | df -h
147 | ```
148 | 
149 | ## If you run on another server without the integrated environment
150 | Install whatever packages it complains are missing, for example:
151 | ```
152 | pip install diffusers
153 | pip install ftfy
154 | pip install tensorflow-gpu
155 | pip install pytorch_lightning
156 | pip install OmegaConf
157 | ... and a few others
158 | ```
159 | 
160 | ## Academic network acceleration (optional)
161 | If pulling content from git is slow, the following may help.
162 | Run the command for the zone your machine is in:
163 | ```
164 | Instances in Beijing zone A
165 | export http_proxy=http://100.72.64.19:12798 && export https_proxy=http://100.72.64.19:12798
166 | 
167 | Instances in Inner Mongolia zone A
168 | export http_proxy=http://192.168.1.174:12798 && export https_proxy=http://192.168.1.174:12798
169 | 
170 | Instances in Quanzhou zone A
171 | export http_proxy=http://10.55.146.88:12798 && export https_proxy=http://10.55.146.88:12798
172 | ```
173 | 
174 | ## xformers (optional)
175 | Training and inference on the A5000 are already fast, so it is not installed.
176 | If you are using another GPU or really need it, you can build it from the address below:
177 | https://github.com/facebookresearch/xformers
178 | (I guessed you might want to try it, so a prebuilt version has already been placed in the other directory)
179 | Note: installing it requires upgrading pytorch to 1.12.x or newer (update: the image already ships with pytorch upgraded and xformers installed)
180 | 
181 | ## Upgrade pytorch to 1.12.x
182 | ```
183 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
184 | ```
185 | 
186 | # Notes on using autodl
187 | 
188 | ## Migrating data between servers
189 | After shutting a machine down, you may find its resources occupied when you try to start it again, and the only option is to open another machine.
190 | For a machine that is already shut down, the console menu has a "cross-instance data copy" feature
191 | that conveniently syncs the contents of /root/autodl-tmp to another running machine (another reason to keep your working files there).
192 | (Note: it only works between machines in the same zone.)
193 | Data migration guide: https://www.autodl.com/docs/migrate_instance/
194 | 
195 | ## Ways to transfer files
196 | ### Option 1: VS Code
197 | Drag files in and out of VS Code to upload and download them; slow, but the simplest.
198 | 
199 | ### Option 2: autodl's user netdisk
200 | Initialize a netdisk in the same zone from autodl's netdisk page, then restart the server instance;
201 | a new folder /root/autodl-nas/ will appear, and you can upload weights and data through the web page.
202 | After training, move the generated weight files into that path and you can download them from the web page.
203 | (web page: https://www.autodl.com/console/netdisk)
204 | Note: the netdisk must be initialized in the same zone as the server.
205 | 
206 | ### Option 3: object storage
207 | If you have access to it, you can also relay files through cos or oss, which is faster.
208 | A reference script for uploading to cos is included in the tools folder; use it if you are experienced with it.
209 | 
210 | The autodl docs also recommend a few approaches: https://www.autodl.com/docs/scp/
211 | 
212 | # Other notes
213 | Thanks to open-source projects such as diffusers and deepdanbooru.
214 | The style training code is modified from nbardy's PR.
215 | Part of the tagging code comes from crosstyan, Nyanko Lepsoni and AUTOMATIC1111.
216 | If you are interested, you are welcome to join the QQ group for discussion: 455521885
217 | Packaged and organized by - 白菜 (Baicai)
218 | 
--------------------------------------------------------------------------------
/back_train.sh:
--------------------------------------------------------------------------------
1 | # Money-saving training: shut the machine down after training completes normally
2 | sh train_style.sh && shutdown
--------------------------------------------------------------------------------
/ckpt_models/model.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   base_learning_rate: 1.0e-04
3 |   target: ldm.models.diffusion.ddpm.LatentDiffusion
4 |   params:
5 |     linear_start: 0.00085
6 |     linear_end: 0.0120
7 |     num_timesteps_cond: 1
8 |     log_every_t: 200
9 |     timesteps: 1000
10 |     first_stage_key: "jpg"
11 |     cond_stage_key: "txt"
12 |     image_size: 64
13 |     channels: 4
14 |     cond_stage_trainable: false   # Note: different from the one we trained before
15 |     conditioning_key: crossattn
16 |     monitor: val/loss_simple_ema
17 |     scale_factor: 0.18215
18 | 
19 |     scheduler_config: # 10000 warmup steps
20 |       target: ldm.lr_scheduler.LambdaLinearScheduler
21 |       params:
22 |         warm_up_steps: [ 10000 ]
23 |         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
24 |         f_start: [ 1.e-6 ]
25 |         f_max: [ 1. ]
26 |         f_min: [ 1. ]
27 | 
28 |     unet_config:
29 |       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
30 |       params:
31 |         image_size: 32 # unused
32 |         in_channels: 4
33 |         out_channels: 4
34 |         model_channels: 320
35 |         attention_resolutions: [ 4, 2, 1 ]
36 |         num_res_blocks: 2
37 |         channel_mult: [ 1, 2, 4, 4 ]
38 |         num_heads: 8
39 |         use_spatial_transformer: True
40 |         transformer_depth: 1
41 |         context_dim: 768
42 |         use_checkpoint: True
43 |         legacy: False
44 | 
45 |     first_stage_config:
46 |       target: ldm.models.autoencoder.AutoencoderKL
47 |       params:
48 |         embed_dim: 4
49 |         monitor: val/rec_loss
50 |         ddconfig:
51 |           double_z: true
52 |           z_channels: 4
53 |           resolution: 512
54 |           in_channels: 3
55 |           out_ch: 3
56 |           ch: 128
57 |           ch_mult:
58 |           - 1
59 |           - 2
60 |           - 4
61 |           - 4
62 |           num_res_blocks: 2
63 |           attn_resolutions: []
64 |           dropout: 0.0
65 |         lossconfig:
66 |           target: torch.nn.Identity
67 | 
68 |     cond_stage_config:
69 |       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
70 | 
--------------------------------------------------------------------------------
/ckpt_models/put_your_ckpt_models_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazyBoyM/dreambooth-for-diffusion/311717098636b9af72c4907fcc9436df9eb2c352/ckpt_models/put_your_ckpt_models_here.txt
--------------------------------------------------------------------------------
/datasets/put_datasets_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazyBoyM/dreambooth-for-diffusion/311717098636b9af72c4907fcc9436df9eb2c352/datasets/put_datasets_here.txt
--------------------------------------------------------------------------------
/other/something others.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazyBoyM/dreambooth-for-diffusion/311717098636b9af72c4907fcc9436df9eb2c352/other/something others.txt -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | from diffusers import DDIMScheduler 4 | 5 | model_path = "./new_model" 6 | prompt = "a cute girl, blue eyes, brown hair" 7 | torch.manual_seed(123123123) 8 | 9 | pipe = StableDiffusionPipeline.from_pretrained( 10 | model_path, 11 | torch_dtype=torch.float16, 12 | scheduler=DDIMScheduler( 13 | beta_start=0.00085, 14 | beta_end=0.012, 15 | beta_schedule="scaled_linear", 16 | clip_sample=False, 17 | set_alpha_to_one=True, 18 | ), 19 | safety_checker=None 20 | ) 21 | 22 | # def dummy(images, **kwargs): 23 | # return images, False 24 | # pipe.safety_checker = dummy 25 | pipe = pipe.to("cuda") 26 | images = pipe(prompt, width=512, height=512, num_inference_steps=30, num_images_per_prompt=3).images 27 | for i, image in enumerate(images): 28 | image.save(f"test-{i}.png") 29 | -------------------------------------------------------------------------------- /test_prompts_object.txt: -------------------------------------------------------------------------------- 1 | a photo of dog 2 | a photo of dog -------------------------------------------------------------------------------- /test_prompts_style.txt: -------------------------------------------------------------------------------- 1 | a cute girl, blue eyes, brown hair 2 | a cute girl, blue eyes, blue hair 3 | a cute boy, green eyes, brown hair -------------------------------------------------------------------------------- /tools/ckpt2diffusers_old.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Conversion script for the LDM checkpoints. """ 16 | 17 | import argparse, os 18 | import torch 19 | 20 | try: 21 | from omegaconf import OmegaConf 22 | except ImportError: 23 | raise ImportError("OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`.") 24 | 25 | from transformers import BertTokenizerFast, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel 26 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler 27 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel, LDMBertConfig 28 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 29 | 30 | def shave_segments(path, n_shave_prefix_segments=1): 31 | """ 32 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
33 | """ 34 | if n_shave_prefix_segments >= 0: 35 | return '.'.join(path.split('.')[n_shave_prefix_segments:]) 36 | else: 37 | return '.'.join(path.split('.')[:n_shave_prefix_segments]) 38 | 39 | 40 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 41 | """ 42 | Updates paths inside resnets to the new naming scheme (local renaming) 43 | """ 44 | mapping = [] 45 | for old_item in old_list: 46 | new_item = old_item.replace('in_layers.0', 'norm1') 47 | new_item = new_item.replace('in_layers.2', 'conv1') 48 | 49 | new_item = new_item.replace('out_layers.0', 'norm2') 50 | new_item = new_item.replace('out_layers.3', 'conv2') 51 | 52 | new_item = new_item.replace('emb_layers.1', 'time_emb_proj') 53 | new_item = new_item.replace('skip_connection', 'conv_shortcut') 54 | 55 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 56 | 57 | mapping.append({'old': old_item, 'new': new_item}) 58 | 59 | return mapping 60 | 61 | 62 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 63 | """ 64 | Updates paths inside resnets to the new naming scheme (local renaming) 65 | """ 66 | mapping = [] 67 | for old_item in old_list: 68 | new_item = old_item 69 | 70 | new_item = new_item.replace('nin_shortcut', 'conv_shortcut') 71 | 72 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 73 | 74 | mapping.append({'old': old_item, 'new': new_item}) 75 | 76 | return mapping 77 | 78 | 79 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 80 | """ 81 | Updates paths inside attentions to the new naming scheme (local renaming) 82 | """ 83 | mapping = [] 84 | for old_item in old_list: 85 | new_item = old_item 86 | 87 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 88 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 89 | 90 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 91 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 92 | 93 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 94 | 95 | mapping.append({'old': old_item, 'new': new_item}) 96 | 97 | return mapping 98 | 99 | 100 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 101 | """ 102 | Updates paths inside attentions to the new naming scheme (local renaming) 103 | """ 104 | mapping = [] 105 | for old_item in old_list: 106 | new_item = old_item 107 | 108 | new_item = new_item.replace('norm.weight', 'group_norm.weight') 109 | new_item = new_item.replace('norm.bias', 'group_norm.bias') 110 | 111 | new_item = new_item.replace('q.weight', 'query.weight') 112 | new_item = new_item.replace('q.bias', 'query.bias') 113 | 114 | new_item = new_item.replace('k.weight', 'key.weight') 115 | new_item = new_item.replace('k.bias', 'key.bias') 116 | 117 | new_item = new_item.replace('v.weight', 'value.weight') 118 | new_item = new_item.replace('v.bias', 'value.bias') 119 | 120 | new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 121 | new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 122 | 123 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 124 | 125 | mapping.append({'old': old_item, 'new': new_item}) 126 | 127 | return mapping 128 | 129 | 130 | def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None): 131 | """ 132 | This does the final conversion step: take locally converted weights and apply a global 
renaming 133 | to them. It splits attention layers, and takes into account additional replacements 134 | that may arise. 135 | 136 | Assigns the weights to the new checkpoint. 137 | """ 138 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 139 | 140 | # Splits the attention layers into three variables. 141 | if attention_paths_to_split is not None: 142 | for path, path_map in attention_paths_to_split.items(): 143 | old_tensor = old_checkpoint[path] 144 | channels = old_tensor.shape[0] // 3 145 | 146 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 147 | 148 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 149 | 150 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 151 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 152 | 153 | checkpoint[path_map['query']] = query.reshape(target_shape) 154 | checkpoint[path_map['key']] = key.reshape(target_shape) 155 | checkpoint[path_map['value']] = value.reshape(target_shape) 156 | 157 | for path in paths: 158 | new_path = path['new'] 159 | 160 | # These have already been assigned 161 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 162 | continue 163 | 164 | # Global renaming happens here 165 | new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0') 166 | new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0') 167 | new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1') 168 | 169 | if additional_replacements is not None: 170 | for replacement in additional_replacements: 171 | new_path = new_path.replace(replacement['old'], replacement['new']) 172 | 173 | # proj_attn.weight has to be converted from conv 1D to linear 174 | if "proj_attn.weight" in new_path: 175 | checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] 176 | else: 177 | checkpoint[new_path] = old_checkpoint[path['old']] 178 | 179 | 180 | def conv_attn_to_linear(checkpoint): 181 | keys = list(checkpoint.keys()) 182 | attn_keys = ["query.weight", "key.weight", "value.weight"] 183 | for key in keys: 184 | if ".".join(key.split(".")[-2:]) in attn_keys: 185 | if checkpoint[key].ndim > 2: 186 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 187 | elif "proj_attn.weight" in key: 188 | if checkpoint[key].ndim > 2: 189 | checkpoint[key] = checkpoint[key][:, :, 0] 190 | 191 | 192 | def create_unet_diffusers_config(original_config): 193 | """ 194 | Creates a config for the diffusers based on the config of the LDM model. 
195 | """ 196 | unet_params = original_config.model.params.unet_config.params 197 | 198 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] 199 | 200 | down_block_types = [] 201 | resolution = 1 202 | for i in range(len(block_out_channels)): 203 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" 204 | down_block_types.append(block_type) 205 | if i != len(block_out_channels) - 1: 206 | resolution *= 2 207 | 208 | up_block_types = [] 209 | for i in range(len(block_out_channels)): 210 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" 211 | up_block_types.append(block_type) 212 | resolution //= 2 213 | 214 | config = dict( 215 | sample_size=unet_params.image_size, 216 | in_channels=unet_params.in_channels, 217 | out_channels=unet_params.out_channels, 218 | down_block_types=tuple(down_block_types), 219 | up_block_types=tuple(up_block_types), 220 | block_out_channels=tuple(block_out_channels), 221 | layers_per_block=unet_params.num_res_blocks, 222 | cross_attention_dim=unet_params.context_dim, 223 | attention_head_dim=unet_params.num_heads, 224 | ) 225 | 226 | return config 227 | 228 | 229 | def create_vae_diffusers_config(original_config): 230 | """ 231 | Creates a config for the diffusers based on the config of the LDM model. 232 | """ 233 | vae_params = original_config.model.params.first_stage_config.params.ddconfig 234 | latent_channles = original_config.model.params.first_stage_config.params.embed_dim 235 | 236 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] 237 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 238 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 239 | 240 | config = dict( 241 | sample_size=vae_params.resolution, 242 | in_channels=vae_params.in_channels, 243 | out_channels=vae_params.out_ch, 244 | down_block_types=tuple(down_block_types), 245 | up_block_types=tuple(up_block_types), 246 | block_out_channels=tuple(block_out_channels), 247 | latent_channels=vae_params.z_channels, 248 | layers_per_block=vae_params.num_res_blocks, 249 | ) 250 | return config 251 | 252 | 253 | def create_diffusers_schedular(original_config): 254 | schedular = DDIMScheduler( 255 | num_train_timesteps=original_config.model.params.timesteps, 256 | beta_start=original_config.model.params.linear_start, 257 | beta_end=original_config.model.params.linear_end, 258 | beta_schedule="scaled_linear", 259 | ) 260 | return schedular 261 | 262 | 263 | def create_ldm_bert_config(original_config): 264 | bert_params = original_config.model.parms.cond_stage_config.params 265 | config = LDMBertConfig( 266 | d_model=bert_params.n_embed, 267 | encoder_layers=bert_params.n_layer, 268 | encoder_ffn_dim=bert_params.n_embed * 4, 269 | ) 270 | return config 271 | 272 | 273 | def convert_ldm_unet_checkpoint(checkpoint, config): 274 | """ 275 | Takes a state dict and a config, and returns a converted checkpoint. 276 | """ 277 | 278 | # extract state_dict for UNet 279 | unet_state_dict = {} 280 | unet_key = "model.diffusion_model." 
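# An original LDM / Stable Diffusion ckpt keeps all submodules in one flat state_dict,
# prefixed per component: "model.diffusion_model." (UNet), "first_stage_model." (VAE)
# and "cond_stage_model." (text encoder). The loop below strips the UNet prefix so the
# remaining keys can then be renamed to the diffusers layout.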
281 | keys = list(checkpoint.keys()) 282 | for key in keys: 283 | if key.startswith(unet_key): 284 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 285 | 286 | new_checkpoint = {} 287 | 288 | new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict['time_embed.0.weight'] 289 | new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict['time_embed.0.bias'] 290 | new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict['time_embed.2.weight'] 291 | new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict['time_embed.2.bias'] 292 | 293 | new_checkpoint['conv_in.weight'] = unet_state_dict['input_blocks.0.0.weight'] 294 | new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias'] 295 | 296 | new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight'] 297 | new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias'] 298 | new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight'] 299 | new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias'] 300 | 301 | # Retrieves the keys for the input blocks only 302 | num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'input_blocks' in layer}) 303 | input_blocks = {layer_id: [key for key in unet_state_dict if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)} 304 | 305 | # Retrieves the keys for the middle blocks only 306 | num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'middle_block' in layer}) 307 | middle_blocks = {layer_id: [key for key in unet_state_dict if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)} 308 | 309 | # Retrieves the keys for the output blocks only 310 | num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'output_blocks' in layer}) 311 | output_blocks = {layer_id: [key for key in unet_state_dict if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)} 312 | 313 | for i in range(1, num_input_blocks): 314 | block_id = (i - 1) // (config['layers_per_block'] + 1) 315 | layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1) 316 | 317 | resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key and f'input_blocks.{i}.0.op' not in key] 318 | attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key] 319 | 320 | if f'input_blocks.{i}.0.op.weight' in unet_state_dict: 321 | new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.weight'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight') 322 | new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.bias'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias') 323 | 324 | paths = renew_resnet_paths(resnets) 325 | meta_path = {'old': f'input_blocks.{i}.0', 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}'} 326 | assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) 327 | 328 | if len(attentions): 329 | paths = renew_attention_paths(attentions) 330 | meta_path = {'old': f'input_blocks.{i}.1', 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}'} 331 | assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) 332 | 333 | 334 | resnet_0 = middle_blocks[0] 335 | attentions = middle_blocks[1] 336 | resnet_1 = middle_blocks[2] 337 | 338 | resnet_0_paths = renew_resnet_paths(resnet_0) 339 | assign_to_checkpoint(resnet_0_paths, 
new_checkpoint, unet_state_dict, config=config) 340 | 341 | resnet_1_paths = renew_resnet_paths(resnet_1) 342 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 343 | 344 | attentions_paths = renew_attention_paths(attentions) 345 | meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'} 346 | assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) 347 | 348 | for i in range(num_output_blocks): 349 | block_id = i // (config['layers_per_block'] + 1) 350 | layer_in_block_id = i % (config['layers_per_block'] + 1) 351 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 352 | output_block_list = {} 353 | 354 | for layer in output_block_layers: 355 | layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1) 356 | if layer_id in output_block_list: 357 | output_block_list[layer_id].append(layer_name) 358 | else: 359 | output_block_list[layer_id] = [layer_name] 360 | 361 | if len(output_block_list) > 1: 362 | resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key] 363 | attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key] 364 | 365 | resnet_0_paths = renew_resnet_paths(resnets) 366 | paths = renew_resnet_paths(resnets) 367 | 368 | meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'} 369 | assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) 370 | 371 | if ['conv.weight', 'conv.bias'] in output_block_list.values(): 372 | index = list(output_block_list.values()).index(['conv.weight', 'conv.bias']) 373 | new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight'] 374 | new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias'] 375 | 376 | # Clear attentions as they have been attributed above. 377 | if len(attentions) == 2: 378 | attentions = [] 379 | 380 | if len(attentions): 381 | paths = renew_attention_paths(attentions) 382 | meta_path = { 383 | 'old': f'output_blocks.{i}.1', 384 | 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}' 385 | } 386 | assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) 387 | else: 388 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 389 | for path in resnet_0_paths: 390 | old_path = '.'.join(['output_blocks', str(i), path['old']]) 391 | new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']]) 392 | 393 | new_checkpoint[new_path] = unet_state_dict[old_path] 394 | 395 | return new_checkpoint 396 | 397 | 398 | def convert_ldm_vae_checkpoint(checkpoint, config): 399 | # extract state dict for VAE 400 | vae_state_dict = {} 401 | vae_key = "first_stage_model." 
402 | keys = list(checkpoint.keys()) 403 | for key in keys: 404 | if key.startswith(vae_key): 405 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 406 | 407 | new_checkpoint = {} 408 | 409 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 410 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 411 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 412 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 413 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 414 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 415 | 416 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 417 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 418 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 419 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 420 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 421 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 422 | 423 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 424 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 425 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 426 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 427 | 428 | 429 | # Retrieves the keys for the encoder down blocks only 430 | num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer}) 431 | down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)} 432 | 433 | # Retrieves the keys for the decoder up blocks only 434 | num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer}) 435 | up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)} 436 | 437 | 438 | for i in range(num_down_blocks): 439 | resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key] 440 | 441 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 442 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") 443 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") 444 | 445 | paths = renew_vae_resnet_paths(resnets) 446 | meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'} 447 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 448 | 449 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 450 | num_mid_res_blocks = 2 451 | for i in range(1, num_mid_res_blocks + 1): 452 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 453 | 454 | paths = renew_vae_resnet_paths(resnets) 455 | meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'} 456 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 457 | 458 | mid_attentions = 
[key for key in vae_state_dict if "encoder.mid.attn" in key] 459 | paths = renew_vae_attention_paths(mid_attentions) 460 | meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} 461 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 462 | conv_attn_to_linear(new_checkpoint) 463 | 464 | for i in range(num_up_blocks): 465 | block_id = num_up_blocks - 1 - i 466 | resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key] 467 | 468 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 469 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] 470 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] 471 | 472 | paths = renew_vae_resnet_paths(resnets) 473 | meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'} 474 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 475 | 476 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 477 | num_mid_res_blocks = 2 478 | for i in range(1, num_mid_res_blocks + 1): 479 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 480 | 481 | paths = renew_vae_resnet_paths(resnets) 482 | meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'} 483 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 484 | 485 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 486 | paths = renew_vae_attention_paths(mid_attentions) 487 | meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} 488 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 489 | conv_attn_to_linear(new_checkpoint) 490 | return new_checkpoint 491 | 492 | 493 | def convert_ldm_bert_checkpoint(checkpoint, config): 494 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): 495 | 496 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight 497 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight 498 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight 499 | 500 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight 501 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias 502 | 503 | 504 | def _copy_linear(hf_linear, pt_linear): 505 | hf_linear.weight = pt_linear.weight 506 | hf_linear.bias = pt_linear.bias 507 | 508 | 509 | def _copy_layer(hf_layer, pt_layer): 510 | # copy layer norms 511 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) 512 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) 513 | 514 | # copy attn 515 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) 516 | 517 | # copy MLP 518 | pt_mlp = pt_layer[1][1] 519 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) 520 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) 521 | 522 | 523 | def _copy_layers(hf_layers, pt_layers): 524 | for i, hf_layer in enumerate(hf_layers): 525 | if i != 0: i += i 526 | pt_layer = pt_layers[i:i+2] 527 | _copy_layer(hf_layer, pt_layer) 528 | 529 | hf_model = LDMBertModel(config).eval() 530 | 531 | # copy embeds 532 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight 533 | hf_model.model.embed_positions.weight.data = 
checkpoint.transformer.pos_emb.emb.weight 534 | 535 | # copy layer norm 536 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) 537 | 538 | # copy hidden layers 539 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) 540 | 541 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) 542 | 543 | return hf_model 544 | 545 | 546 | 547 | if __name__ == "__main__": 548 | parser = argparse.ArgumentParser() 549 | 550 | parser.add_argument( 551 | "checkpoint_path", default='./model.ckpt', type=str, help="Path to the checkpoint to convert." 552 | ) 553 | 554 | 555 | parser.add_argument( 556 | "dump_path", default='./model', type=str, help="Path to the output model." 557 | ) 558 | 559 | parser.add_argument( 560 | "--original_config_file", 561 | default='./ckpt_models/model.yaml', 562 | type=str, 563 | required=False, 564 | help="The YAML config file corresponding to the original architecture.", 565 | ) 566 | 567 | args = parser.parse_args() 568 | 569 | original_config = OmegaConf.load(args.original_config_file) 570 | 571 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] 572 | 573 | # Convert the UNet2DConditionModel model. 574 | unet_config = create_unet_diffusers_config(original_config) 575 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) 576 | 577 | unet = UNet2DConditionModel(**unet_config) 578 | unet.load_state_dict(converted_unet_checkpoint) 579 | 580 | # Convert the VAE model. 581 | vae_config = create_vae_diffusers_config(original_config) 582 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) 583 | 584 | vae = AutoencoderKL(**vae_config) 585 | vae.load_state_dict(converted_vae_checkpoint) 586 | 587 | 588 | 589 | # Convert the text model. 590 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] 591 | 592 | script_path = os.path.realpath(__file__) 593 | default_model_path = os.path.join(os.path.dirname(script_path), "diffusers-models") 594 | 595 | try: 596 | text_model = CLIPTextModel.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14")) 597 | tokenizer = CLIPTokenizer.from_pretrained(os.path.join(default_model_path, "clip-vit-large-patch14")) 598 | safety_checker = StableDiffusionSafetyChecker.from_pretrained(os.path.join(default_model_path, "safety-checker")) 599 | 600 | except Exception as e: 601 | print(e) 602 | print("Could not load the default text model. Auto downloading...") 603 | if text_model_type == "FrozenCLIPEmbedder": 604 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 605 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 606 | else: 607 | # TODO: update the convert function to use the state_dict without the model instance. 
608 | text_config = create_ldm_bert_config(original_config) 609 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) 610 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 611 | 612 | safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker') 613 | 614 | scheduler = create_diffusers_schedular(original_config) 615 | 616 | scheduler = create_diffusers_schedular(original_config) 617 | feature_extractor = CLIPFeatureExtractor() 618 | pipe = StableDiffusionPipeline(vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor) 619 | pipe.save_pretrained(args.dump_path) 620 | -------------------------------------------------------------------------------- /tools/ckpt_merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from tqdm import tqdm 5 | 6 | parser = argparse.ArgumentParser(description="Merge two models") 7 | parser.add_argument("model_0", type=str, help="Path to model 0") 8 | parser.add_argument("model_1", type=str, help="Path to model 1") 9 | parser.add_argument("--alpha", type=float, help="Alpha value, optional, defaults to 0.5", default=0.5, required=False) 10 | parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False) 11 | parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False) 12 | parser.add_argument("--without_vae", action="store_true", help="Do not merge VAE", required=False) 13 | 14 | args = parser.parse_args() 15 | 16 | device = args.device 17 | model_0 = torch.load(args.model_0, map_location=device) 18 | model_1 = torch.load(args.model_1, map_location=device) 19 | theta_0 = model_0["state_dict"] 20 | theta_1 = model_1["state_dict"] 21 | alpha = args.alpha 22 | 23 | output_file = f'{args.output}-{str(alpha)[2:] + "0"}.ckpt' 24 | 25 | # check if output file already exists, ask to overwrite 26 | if os.path.isfile(output_file): 27 | print("Output file already exists. Overwrite? 
(y/n)") 28 | while True: 29 | overwrite = input() 30 | if overwrite == "y": 31 | break 32 | elif overwrite == "n": 33 | print("Exiting...") 34 | exit() 35 | else: 36 | print("Please enter y or n") 37 | 38 | 39 | for key in tqdm(theta_0.keys(), desc="Stage 1/2"): 40 | # skip VAE model parameters to get better results(tested for anime models) 41 | # for anime model,with merging VAE model, the result will be worse (dark and blurry) 42 | if args.without_vae and "first_stage_model" in key: 43 | continue 44 | 45 | if "model" in key and key in theta_1: 46 | theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key] 47 | 48 | for key in tqdm(theta_1.keys(), desc="Stage 2/2"): 49 | if "model" in key and key not in theta_0: 50 | theta_0[key] = theta_1[key] 51 | 52 | print("Saving...") 53 | 54 | torch.save({"state_dict": theta_0}, output_file) 55 | 56 | print("Done!") 57 | -------------------------------------------------------------------------------- /tools/ckpt_prune.py: -------------------------------------------------------------------------------- 1 | # Prune a ckpt checkpoint: keep only the model weights and drop optimizer states. 2 | # usage: python tools/ckpt_prune.py path/to/model.ckpt 3 | import sys 4 | 5 | import torch 6 | 7 | model_path = sys.argv[1] 8 | sd = torch.load(model_path, map_location="cpu") 9 | if "state_dict" not in sd: 10 | pruned_sd = { 11 | "state_dict": dict(), 12 | } 13 | else: 14 | pruned_sd = dict() 15 | for k in sd.keys(): 16 | if k != "optimizer_states": 17 | if "state_dict" not in sd: 18 | pruned_sd["state_dict"][k] = sd[k] 19 | else: 20 | pruned_sd[k] = sd[k] 21 | torch.save(pruned_sd, "model-pruned.ckpt") -------------------------------------------------------------------------------- /tools/deepdanbooru-models/put_deepdanbooru_model_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrazyBoyM/dreambooth-for-diffusion/311717098636b9af72c4907fcc9436df9eb2c352/tools/deepdanbooru-models/put_deepdanbooru_model_here.txt -------------------------------------------------------------------------------- /tools/diagnose_tensorboard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Self-diagnosis script for TensorBoard. 16 | 17 | Instructions: Save this script to your local machine, then execute it in 18 | the same environment (virtualenv, Conda, etc.) from which you normally 19 | run TensorBoard. Read the output and follow the directions. 20 | """ 21 | 22 | 23 | # This script may only depend on the Python standard library. It is not 24 | # built with Bazel and should not assume any third-party dependencies.
25 | import dataclasses 26 | import errno 27 | import functools 28 | import hashlib 29 | import inspect 30 | import logging 31 | import os 32 | import pipes 33 | import shlex 34 | import socket 35 | import subprocess 36 | import sys 37 | import tempfile 38 | import textwrap 39 | import traceback 40 | 41 | 42 | # A *check* is a function (of no arguments) that performs a diagnostic, 43 | # writes log messages, and optionally yields suggestions. Each check 44 | # runs in isolation; exceptions will be caught and reported. 45 | CHECKS = [] 46 | 47 | 48 | @dataclasses.dataclass(frozen=True) 49 | class Suggestion: 50 | """A suggestion to the end user. 51 | 52 | Attributes: 53 | headline: A short description, like "Turn it off and on again". Should be 54 | imperative with no trailing punctuation. May contain inline Markdown. 55 | description: A full enumeration of the steps that the user should take to 56 | accept the suggestion. Within this string, prose should be formatted 57 | with `reflow`. May contain Markdown. 58 | """ 59 | 60 | headline: str 61 | description: str 62 | 63 | 64 | def check(fn): 65 | """Decorator to register a function as a check. 66 | 67 | Checks are run in the order in which they are registered. 68 | 69 | Args: 70 | fn: A function that takes no arguments and either returns `None` or 71 | returns a generator of `Suggestion`s. (The ability to return 72 | `None` is to work around the awkwardness of defining empty 73 | generator functions in Python.) 74 | 75 | Returns: 76 | A wrapped version of `fn` that returns a generator of `Suggestion`s. 77 | """ 78 | 79 | @functools.wraps(fn) 80 | def wrapper(): 81 | result = fn() 82 | return iter(()) if result is None else result 83 | 84 | CHECKS.append(wrapper) 85 | return wrapper 86 | 87 | 88 | def reflow(paragraph): 89 | return textwrap.fill(textwrap.dedent(paragraph).strip()) 90 | 91 | 92 | def pip(args): 93 | """Invoke command-line Pip with the specified args. 94 | 95 | Returns: 96 | A bytestring containing the output of Pip. 97 | """ 98 | # Suppress the Python 2.7 deprecation warning. 99 | PYTHONWARNINGS_KEY = "PYTHONWARNINGS" 100 | old_pythonwarnings = os.environ.get(PYTHONWARNINGS_KEY) 101 | new_pythonwarnings = "%s%s" % ( 102 | "ignore:DEPRECATION", 103 | ",%s" % old_pythonwarnings if old_pythonwarnings else "", 104 | ) 105 | command = [sys.executable, "-m", "pip", "--disable-pip-version-check"] 106 | command.extend(args) 107 | try: 108 | os.environ[PYTHONWARNINGS_KEY] = new_pythonwarnings 109 | return subprocess.check_output(command) 110 | finally: 111 | if old_pythonwarnings is None: 112 | del os.environ[PYTHONWARNINGS_KEY] 113 | else: 114 | os.environ[PYTHONWARNINGS_KEY] = old_pythonwarnings 115 | 116 | 117 | def which(name): 118 | """Return the path to a binary, or `None` if it's not on the path. 119 | 120 | Returns: 121 | A bytestring. 122 | """ 123 | binary = "where" if os.name == "nt" else "which" 124 | try: 125 | return subprocess.check_output([binary, name]) 126 | except subprocess.CalledProcessError: 127 | return None 128 | 129 | 130 | def sgetattr(attr, default): 131 | """Get an attribute off the `socket` module, or use a default.""" 132 | sentinel = object() 133 | result = getattr(socket, attr, sentinel) 134 | if result is sentinel: 135 | print("socket.%s does not exist" % attr) 136 | return default 137 | else: 138 | print("socket.%s = %r" % (attr, result)) 139 | return result 140 | 141 | 142 | @check 143 | def autoidentify(): 144 | """Print the Git hash of this version of `diagnose_tensorboard.py`. 
145 | 146 | Given this hash, use `git cat-file blob HASH` to recover the 147 | relevant version of the script. 148 | """ 149 | module = sys.modules[__name__] 150 | try: 151 | source = inspect.getsource(module).encode("utf-8") 152 | except TypeError: 153 | logging.info("diagnose_tensorboard.py source unavailable") 154 | else: 155 | # Git inserts a length-prefix before hashing; cf. `git-hash-object`. 156 | blob = b"blob %d\0%s" % (len(source), source) 157 | hash = hashlib.sha1(blob).hexdigest() 158 | logging.info("diagnose_tensorboard.py version %s", hash) 159 | 160 | 161 | @check 162 | def general(): 163 | logging.info("sys.version_info: %s", sys.version_info) 164 | logging.info("os.name: %s", os.name) 165 | na = type("N/A", (object,), {"__repr__": lambda self: "N/A"}) 166 | logging.info( 167 | "os.uname(): %r", 168 | getattr(os, "uname", na)(), 169 | ) 170 | logging.info( 171 | "sys.getwindowsversion(): %r", 172 | getattr(sys, "getwindowsversion", na)(), 173 | ) 174 | 175 | 176 | @check 177 | def package_management(): 178 | conda_meta = os.path.join(sys.prefix, "conda-meta") 179 | logging.info("has conda-meta: %s", os.path.exists(conda_meta)) 180 | logging.info("$VIRTUAL_ENV: %r", os.environ.get("VIRTUAL_ENV")) 181 | 182 | 183 | @check 184 | def installed_packages(): 185 | freeze = pip(["freeze", "--all"]).decode("utf-8").splitlines() 186 | packages = {line.split("==")[0]: line for line in freeze} 187 | packages_set = frozenset(packages) 188 | 189 | # For each of the following families, expect exactly one package to be 190 | # installed. 191 | expect_unique = [ 192 | frozenset( 193 | [ 194 | "tensorboard", 195 | "tb-nightly", 196 | "tensorflow-tensorboard", 197 | ] 198 | ), 199 | frozenset( 200 | [ 201 | "tensorflow", 202 | "tensorflow-gpu", 203 | "tf-nightly", 204 | "tf-nightly-2.0-preview", 205 | "tf-nightly-gpu", 206 | "tf-nightly-gpu-2.0-preview", 207 | ] 208 | ), 209 | frozenset( 210 | [ 211 | "tensorflow-estimator", 212 | "tensorflow-estimator-2.0-preview", 213 | "tf-estimator-nightly", 214 | ] 215 | ), 216 | ] 217 | salient_extras = frozenset(["tensorboard-data-server"]) 218 | 219 | found_conflict = False 220 | for family in expect_unique: 221 | actual = family & packages_set 222 | for package in actual: 223 | logging.info("installed: %s", packages[package]) 224 | if len(actual) == 0: 225 | logging.warning("no installation among: %s", sorted(family)) 226 | elif len(actual) > 1: 227 | logging.warning("conflicting installations: %s", sorted(actual)) 228 | found_conflict = True 229 | for package in sorted(salient_extras & packages_set): 230 | logging.info("installed: %s", packages[package]) 231 | 232 | if found_conflict: 233 | preamble = reflow( 234 | """ 235 | Conflicting package installations found. Depending on the order 236 | of installations and uninstallations, behavior may be undefined. 237 | Please uninstall ALL versions of TensorFlow and TensorBoard, 238 | then reinstall ONLY the desired version of TensorFlow, which 239 | will transitively pull in the proper version of TensorBoard. (If 240 | you use TensorBoard without TensorFlow, just reinstall the 241 | appropriate version of TensorBoard directly.) 
242 | """ 243 | ) 244 | packages_to_uninstall = sorted( 245 | frozenset().union(*expect_unique) & packages_set 246 | ) 247 | commands = [ 248 | "pip uninstall %s" % " ".join(packages_to_uninstall), 249 | "pip install tensorflow # or `tensorflow-gpu`, or `tf-nightly`, ...", 250 | ] 251 | message = "%s\n\nNamely:\n\n%s" % ( 252 | preamble, 253 | "\n".join("\t%s" % c for c in commands), 254 | ) 255 | yield Suggestion("Fix conflicting installations", message) 256 | 257 | wit_version = packages.get("tensorboard-plugin-wit") 258 | if wit_version == "tensorboard-plugin-wit==1.6.0.post2": 259 | # This is only incompatible with TensorBoard prior to 2.2.0, but 260 | # we just issue a blanket warning so that we don't have to pull 261 | # in a `pkg_resources` dep to parse the version. 262 | preamble = reflow( 263 | """ 264 | Versions of the What-If Tool (`tensorboard-plugin-wit`) 265 | prior to 1.6.0.post3 are incompatible with some versions of 266 | TensorBoard. Please upgrade this package to the latest 267 | version to resolve any startup errors: 268 | """ 269 | ) 270 | command = "pip install -U tensorboard-plugin-wit" 271 | message = "%s\n\n\t%s" % (preamble, command) 272 | yield Suggestion("Upgrade `tensorboard-plugin-wit`", message) 273 | 274 | 275 | @check 276 | def tensorboard_python_version(): 277 | from tensorboard import version 278 | 279 | logging.info("tensorboard.version.VERSION: %r", version.VERSION) 280 | 281 | 282 | @check 283 | def tensorflow_python_version(): 284 | import tensorflow as tf 285 | 286 | logging.info("tensorflow.__version__: %r", tf.__version__) 287 | logging.info("tensorflow.__git_version__: %r", tf.__git_version__) 288 | 289 | 290 | @check 291 | def tensorboard_data_server_version(): 292 | try: 293 | import tensorboard_data_server 294 | except ImportError: 295 | logging.info("no data server installed") 296 | return 297 | 298 | path = tensorboard_data_server.server_binary() 299 | logging.info("data server binary: %r", path) 300 | if path is None: 301 | return 302 | 303 | try: 304 | subprocess_output = subprocess.run( 305 | [path, "--version"], 306 | capture_output=True, 307 | check=True, 308 | ) 309 | except subprocess.CalledProcessError as e: 310 | logging.info("failed to check binary version: %s", e) 311 | else: 312 | logging.info( 313 | "data server binary version: %s", subprocess_output.stdout.strip() 314 | ) 315 | 316 | 317 | @check 318 | def tensorboard_binary_path(): 319 | logging.info("which tensorboard: %r", which("tensorboard")) 320 | 321 | 322 | @check 323 | def addrinfos(): 324 | sgetattr("has_ipv6", None) 325 | family = sgetattr("AF_UNSPEC", 0) 326 | socktype = sgetattr("SOCK_STREAM", 0) 327 | protocol = 0 328 | flags_loopback = sgetattr("AI_ADDRCONFIG", 0) 329 | flags_wildcard = sgetattr("AI_PASSIVE", 0) 330 | 331 | hints_loopback = (family, socktype, protocol, flags_loopback) 332 | infos_loopback = socket.getaddrinfo(None, 0, *hints_loopback) 333 | print("Loopback flags: %r" % (flags_loopback,)) 334 | print("Loopback infos: %r" % (infos_loopback,)) 335 | 336 | hints_wildcard = (family, socktype, protocol, flags_wildcard) 337 | infos_wildcard = socket.getaddrinfo(None, 0, *hints_wildcard) 338 | print("Wildcard flags: %r" % (flags_wildcard,)) 339 | print("Wildcard infos: %r" % (infos_wildcard,)) 340 | 341 | 342 | @check 343 | def readable_fqdn(): 344 | # May raise `UnicodeDecodeError` for non-ASCII hostnames: 345 | # https://github.com/tensorflow/tensorboard/issues/682 346 | try: 347 | logging.info("socket.getfqdn(): %r", socket.getfqdn()) 348 | except 
UnicodeDecodeError as e: 349 | try: 350 | binary_hostname = subprocess.check_output(["hostname"]).strip() 351 | except subprocess.CalledProcessError: 352 | binary_hostname = b"" 353 | is_non_ascii = not all( 354 | 0x20 355 | <= (ord(c) if not isinstance(c, int) else c) 356 | <= 0x7E # Python 2 357 | for c in binary_hostname 358 | ) 359 | if is_non_ascii: 360 | message = reflow( 361 | """ 362 | Your computer's hostname, %r, contains bytes outside of the 363 | printable ASCII range. Some versions of Python have trouble 364 | working with such names (https://bugs.python.org/issue26227). 365 | Consider changing to a hostname that only contains printable 366 | ASCII bytes. 367 | """ 368 | % (binary_hostname,) 369 | ) 370 | yield Suggestion("Use an ASCII hostname", message) 371 | else: 372 | message = reflow( 373 | """ 374 | Python can't read your computer's hostname, %r. This can occur 375 | if the hostname contains non-ASCII bytes 376 | (https://bugs.python.org/issue26227). Consider changing your 377 | hostname, rebooting your machine, and rerunning this diagnosis 378 | script to see if the problem is resolved. 379 | """ 380 | % (binary_hostname,) 381 | ) 382 | yield Suggestion("Use a simpler hostname", message) 383 | raise e 384 | 385 | 386 | @check 387 | def stat_tensorboardinfo(): 388 | # We don't use `manager._get_info_dir`, because (a) that requires 389 | # TensorBoard, and (b) that creates the directory if it doesn't exist. 390 | path = os.path.join(tempfile.gettempdir(), ".tensorboard-info") 391 | logging.info("directory: %s", path) 392 | try: 393 | stat_result = os.stat(path) 394 | except OSError as e: 395 | if e.errno == errno.ENOENT: 396 | # No problem; this is just fine. 397 | logging.info(".tensorboard-info directory does not exist") 398 | return 399 | else: 400 | raise 401 | logging.info("os.stat(...): %r", stat_result) 402 | logging.info("mode: 0o%o", stat_result.st_mode) 403 | if stat_result.st_mode & 0o777 != 0o777: 404 | preamble = reflow( 405 | """ 406 | The ".tensorboard-info" directory was created by an old version 407 | of TensorBoard, and its permissions are not set correctly; see 408 | issue #2010. Change that directory to be world-accessible (may 409 | require superuser privilege): 410 | """ 411 | ) 412 | # This error should only appear on Unices, so it's okay to use 413 | # Unix-specific utilities and shell syntax. 414 | quote = getattr(shlex, "quote", None) or pipes.quote # Python <3.3 415 | command = "chmod 777 %s" % quote(path) 416 | message = "%s\n\n\t%s" % (preamble, command) 417 | yield Suggestion('Fix permissions on "%s"' % path, message) 418 | 419 | 420 | @check 421 | def source_trees_without_genfiles(): 422 | roots = list(sys.path) 423 | if "" not in roots: 424 | # Catch problems that would occur in a Python interactive shell 425 | # (where `""` is prepended to `sys.path`) but not when 426 | # `diagnose_tensorboard.py` is run as a standalone script. 
427 | roots.insert(0, "") 428 | 429 | def has_tensorboard(root): 430 | return os.path.isfile(os.path.join(root, "tensorboard", "__init__.py")) 431 | 432 | def has_genfiles(root): 433 | sample_genfile = os.path.join("compat", "proto", "summary_pb2.py") 434 | return os.path.isfile(os.path.join(root, "tensorboard", sample_genfile)) 435 | 436 | def is_bad(root): 437 | return has_tensorboard(root) and not has_genfiles(root) 438 | 439 | tensorboard_roots = [root for root in roots if has_tensorboard(root)] 440 | bad_roots = [root for root in roots if is_bad(root)] 441 | 442 | logging.info( 443 | "tensorboard_roots (%d): %r; bad_roots (%d): %r", 444 | len(tensorboard_roots), 445 | tensorboard_roots, 446 | len(bad_roots), 447 | bad_roots, 448 | ) 449 | 450 | if bad_roots: 451 | if bad_roots == [""]: 452 | message = reflow( 453 | """ 454 | Your current directory contains a `tensorboard` Python package 455 | that does not include generated files. This can happen if your 456 | current directory includes the TensorBoard source tree (e.g., 457 | you are in the TensorBoard Git repository). Consider changing 458 | to a different directory. 459 | """ 460 | ) 461 | else: 462 | preamble = reflow( 463 | """ 464 | Your Python path contains a `tensorboard` package that does 465 | not include generated files. This can happen if your current 466 | directory includes the TensorBoard source tree (e.g., you are 467 | in the TensorBoard Git repository). The following directories 468 | from your Python path may be problematic: 469 | """ 470 | ) 471 | roots = [] 472 | realpaths_seen = set() 473 | for root in bad_roots: 474 | label = repr(root) if root else "current directory" 475 | realpath = os.path.realpath(root) 476 | if realpath in realpaths_seen: 477 | # virtualenvs on Ubuntu install to both `lib` and `local/lib`; 478 | # explicitly call out such duplicates to avoid confusion. 479 | label += " (duplicate underlying directory)" 480 | realpaths_seen.add(realpath) 481 | roots.append(label) 482 | message = "%s\n\n%s" % ( 483 | preamble, 484 | "\n".join(" - %s" % s for s in roots), 485 | ) 486 | yield Suggestion( 487 | "Avoid `tensorboard` packages without genfiles", message 488 | ) 489 | 490 | 491 | # Prefer to include this check last, as its output is long. 492 | @check 493 | def full_pip_freeze(): 494 | logging.info( 495 | "pip freeze --all:\n%s", pip(["freeze", "--all"]).decode("utf-8") 496 | ) 497 | 498 | 499 | def set_up_logging(): 500 | # Manually install handlers to prevent TensorFlow from stomping the 501 | # default configuration if it's imported: 502 | # https://github.com/tensorflow/tensorflow/issues/28147 503 | logger = logging.getLogger() 504 | logger.setLevel(logging.INFO) 505 | handler = logging.StreamHandler(sys.stdout) 506 | handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) 507 | logger.addHandler(handler) 508 | 509 | 510 | def main(): 511 | set_up_logging() 512 | 513 | print("### Diagnostics") 514 | print() 515 | 516 | print("
") 517 | print("Diagnostics output") 518 | print() 519 | 520 | markdown_code_fence = "``````" # seems likely to be sufficient 521 | print(markdown_code_fence) 522 | suggestions = [] 523 | for (i, check) in enumerate(CHECKS): 524 | if i > 0: 525 | print() 526 | print("--- check: %s" % check.__name__) 527 | try: 528 | suggestions.extend(check()) 529 | except Exception: 530 | traceback.print_exc(file=sys.stdout) 531 | pass 532 | print(markdown_code_fence) 533 | print() 534 | print("
") 535 | 536 | for suggestion in suggestions: 537 | print() 538 | print("### Suggestion: %s" % suggestion.headline) 539 | print() 540 | print(suggestion.description) 541 | 542 | print() 543 | print("### Next steps") 544 | print() 545 | if suggestions: 546 | print( 547 | reflow( 548 | """ 549 | Please try each suggestion enumerated above to determine whether 550 | it solves your problem. If none of these suggestions works, 551 | please copy ALL of the above output, including the lines 552 | containing only backticks, into your GitHub issue or comment. Be 553 | sure to redact any sensitive information. 554 | """ 555 | ) 556 | ) 557 | else: 558 | print( 559 | reflow( 560 | """ 561 | No action items identified. Please copy ALL of the above output, 562 | including the lines containing only backticks, into your GitHub 563 | issue or comment. Be sure to redact any sensitive information. 564 | """ 565 | ) 566 | ) 567 | 568 | 569 | if __name__ == "__main__": 570 | main() 571 | -------------------------------------------------------------------------------- /tools/diffusers2ckpt.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | 5 | import argparse 6 | import os.path as osp 7 | 8 | import torch 9 | 10 | 11 | # =================# 12 | # UNet Conversion # 13 | # =================# 14 | 15 | unet_conversion_map = [ 16 | # (stable-diffusion, HF Diffusers) 17 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 18 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 19 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 20 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 21 | ("input_blocks.0.0.weight", "conv_in.weight"), 22 | ("input_blocks.0.0.bias", "conv_in.bias"), 23 | ("out.0.weight", "conv_norm_out.weight"), 24 | ("out.0.bias", "conv_norm_out.bias"), 25 | ("out.2.weight", "conv_out.weight"), 26 | ("out.2.bias", "conv_out.bias"), 27 | ] 28 | 29 | unet_conversion_map_resnet = [ 30 | # (stable-diffusion, HF Diffusers) 31 | ("in_layers.0", "norm1"), 32 | ("in_layers.2", "conv1"), 33 | ("out_layers.0", "norm2"), 34 | ("out_layers.3", "conv2"), 35 | ("emb_layers.1", "time_emb_proj"), 36 | ("skip_connection", "conv_shortcut"), 37 | ] 38 | 39 | unet_conversion_map_layer = [] 40 | # hardcoded number of downblocks and resnets/attentions... 41 | # would need smarter logic for other networks. 42 | for i in range(4): 43 | # loop over downblocks/upblocks 44 | 45 | for j in range(2): 46 | # loop over resnets/attentions for downblocks 47 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 48 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 49 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 50 | 51 | if i < 3: 52 | # no attention layers in down_blocks.3 53 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 54 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 55 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 56 | 57 | for j in range(3): 58 | # loop over resnets/attentions for upblocks 59 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 60 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 
61 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 62 | 63 | if i > 0: 64 | # no attention layers in up_blocks.0 65 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 66 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 67 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 68 | 69 | if i < 3: 70 | # no downsample in down_blocks.3 71 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 72 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 73 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 74 | 75 | # no upsample in up_blocks.3 76 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 77 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 78 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 79 | 80 | hf_mid_atn_prefix = "mid_block.attentions.0." 81 | sd_mid_atn_prefix = "middle_block.1." 82 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 83 | 84 | for j in range(2): 85 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 86 | sd_mid_res_prefix = f"middle_block.{2*j}." 87 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 88 | 89 | 90 | def convert_unet_state_dict(unet_state_dict): 91 | # buyer beware: this is a *brittle* function, 92 | # and correct output requires that all of these pieces interact in 93 | # the exact order in which I have arranged them. 94 | mapping = {k: k for k in unet_state_dict.keys()} 95 | for sd_name, hf_name in unet_conversion_map: 96 | mapping[hf_name] = sd_name 97 | for k, v in mapping.items(): 98 | if "resnets" in k: 99 | for sd_part, hf_part in unet_conversion_map_resnet: 100 | v = v.replace(hf_part, sd_part) 101 | mapping[k] = v 102 | for k, v in mapping.items(): 103 | for sd_part, hf_part in unet_conversion_map_layer: 104 | v = v.replace(hf_part, sd_part) 105 | mapping[k] = v 106 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 107 | return new_state_dict 108 | 109 | 110 | # ================# 111 | # VAE Conversion # 112 | # ================# 113 | 114 | vae_conversion_map = [ 115 | # (stable-diffusion, HF Diffusers) 116 | ("nin_shortcut", "conv_shortcut"), 117 | ("norm_out", "conv_norm_out"), 118 | ("mid.attn_1.", "mid_block.attentions.0."), 119 | ] 120 | 121 | for i in range(4): 122 | # down_blocks have two resnets 123 | for j in range(2): 124 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 125 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 126 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 127 | 128 | if i < 3: 129 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 130 | sd_downsample_prefix = f"down.{i}.downsample." 131 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 132 | 133 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 134 | sd_upsample_prefix = f"up.{3-i}.upsample." 135 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 136 | 137 | # up_blocks have three resnets 138 | # also, up blocks in hf are numbered in reverse from sd 139 | for j in range(3): 140 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 141 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 142 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 143 | 144 | # this part accounts for mid blocks in both the encoder and the decoder 145 | for i in range(2): 146 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 147 | sd_mid_res_prefix = f"mid.block_{i+1}." 
148 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 149 | 150 | 151 | vae_conversion_map_attn = [ 152 | # (stable-diffusion, HF Diffusers) 153 | ("norm.", "group_norm."), 154 | ("q.", "query."), 155 | ("k.", "key."), 156 | ("v.", "value."), 157 | ("proj_out.", "proj_attn."), 158 | ] 159 | 160 | 161 | def reshape_weight_for_sd(w): 162 | # convert HF linear weights to SD conv2d weights 163 | return w.reshape(*w.shape, 1, 1) 164 | 165 | 166 | def convert_vae_state_dict(vae_state_dict): 167 | mapping = {k: k for k in vae_state_dict.keys()} 168 | for k, v in mapping.items(): 169 | for sd_part, hf_part in vae_conversion_map: 170 | v = v.replace(hf_part, sd_part) 171 | mapping[k] = v 172 | for k, v in mapping.items(): 173 | if "attentions" in k: 174 | for sd_part, hf_part in vae_conversion_map_attn: 175 | v = v.replace(hf_part, sd_part) 176 | mapping[k] = v 177 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 178 | weights_to_convert = ["q", "k", "v", "proj_out"] 179 | for k, v in new_state_dict.items(): 180 | for weight_name in weights_to_convert: 181 | if f"mid.attn_1.{weight_name}.weight" in k: 182 | print(f"Reshaping {k} for SD format") 183 | new_state_dict[k] = reshape_weight_for_sd(v) 184 | return new_state_dict 185 | 186 | 187 | # =========================# 188 | # Text Encoder Conversion # 189 | # =========================# 190 | # pretty much a no-op 191 | 192 | 193 | def convert_text_enc_state_dict(text_enc_dict): 194 | return text_enc_dict 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser() 199 | 200 | parser.add_argument("model_path", default=None, type=str, help="Path to the model to convert.") 201 | parser.add_argument("checkpoint_path", default=None, type=str, help="Path to the output model.") 202 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 203 | 204 | args = parser.parse_args() 205 | 206 | assert args.model_path is not None, "Must provide a model path!" 207 | 208 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 209 | 210 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 211 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 212 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 213 | 214 | # Convert the UNet model 215 | unet_state_dict = torch.load(unet_path, map_location="cpu") 216 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 217 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 218 | 219 | # Convert the VAE model 220 | vae_state_dict = torch.load(vae_path, map_location="cpu") 221 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 222 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 223 | 224 | # Convert the text encoder model 225 | text_enc_dict = torch.load(text_enc_path, map_location="cpu") 226 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 227 | text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 228 | 229 | # Put together new checkpoint 230 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 231 | if args.half: 232 | state_dict = {k: v.half() for k, v in state_dict.items()} 233 | state_dict = {"state_dict": state_dict} 234 | torch.save(state_dict, args.checkpoint_path) -------------------------------------------------------------------------------- /tools/handle_images.py: -------------------------------------------------------------------------------- 1 | import os, cv2, argparse 2 | import numpy as np 3 | 4 | # 修改透明背景为白色 5 | def transparence2white(img): 6 | sp=img.shape 7 | width=sp[0] 8 | height=sp[1] 9 | for yh in range(height): 10 | for xw in range(width): 11 | color_d=img[xw,yh] 12 | if(color_d[3]==0): 13 | img[xw,yh]=[255,255,255,255] 14 | return img 15 | 16 | # 修改透明背景为黑色 17 | def transparence2black(img): 18 | sp = img.shape 19 | width = sp[0] 20 | height = sp[1] 21 | for yh in range(height): 22 | for xw in range(width): 23 | color_d = img[xw, yh] 24 | if (color_d[3] == 0): 25 | img[xw, yh] = [0, 0, 0, 255] 26 | return img 27 | 28 | # 中心裁剪 29 | def center_crop(img, crop_size): 30 | h, w = img.shape[:2] 31 | th, tw = crop_size 32 | i = int(round((h - th) / 2.)) 33 | j = int(round((w - tw) / 2.)) 34 | return img[i:i + th, j:j + tw] 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument("--origin_image_path", default=None, type=str, help="Path to the images to convert.") 40 | parser.add_argument("--output_image_path", default=None, type=str, help="Path to the 1:1 output images.") 41 | parser.add_argument("--output_image_path_0", default=None, type=str, help="Path to the 3:2 output images.") 42 | parser.add_argument("--output_image_path_1", default=None, type=str, help="Path to the 2:3 output images.") 43 | parser.add_argument("--width", default=512, type=int, help="Width of the output images.") 44 | parser.add_argument("--height", default=512, type=int, help="Height of the output images.") 45 | parser.add_argument("--png", action="store_true", help="convert the transparent background to white/black.") 46 | 47 | 48 | args = parser.parse_args() 49 | 50 | path = args.origin_image_path 51 | save_path = args.output_image_path 52 | save_path_0 = args.output_image_path_0 53 | save_path_1 = args.output_image_path_1 54 | if save_path!=None: 55 | if not os.path.exists(save_path): 56 | os.makedirs(save_path) 57 | if save_path_0!=None: 58 | if not os.path.exists(save_path_0): 59 | os.makedirs(save_path_0) 60 | if save_path_1!=None: 61 | if not os.path.exists(save_path_1): 62 | os.makedirs(save_path_1) 63 | else: 64 | print('The folder already exists, please check the path.') 65 | 66 | # 只读取png、jpg、jpeg、bmp、webp格式 67 | allow_suffix = ['png', 'jpg', 'jpeg', 'bmp', 'webp'] 68 | image_list = os.listdir(path) 69 | image_list = [os.path.join(path, image) for image in image_list if image.split('.')[-1] in allow_suffix] 70 | 71 | width = args.width 72 | height = args.height 73 | ratio = 3 / 2 74 | for file, i in zip(image_list, range(1, len(image_list)+1)): 75 | print('Processing image: %s' % file) 76 | try: 77 | img = cv2.imread(file) 78 | if width==height: 79 | # 对图像进行center crop, 保证图像的长宽比为1:1, crop_size为图像的较短边 80 | crop_size = min(img.shape[:2]) 81 | img = center_crop(img, (crop_size, crop_size)) 82 | # 缩放图像 83 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 84 | if args.png: 85 | img = transparence2black(img) 86 | cv2.imwrite(os.path.join(save_path, 
str(i).zfill(4) + ".jpg"), img) 87 | else: 88 | height_temp, width_temp, _ = img.shape 89 | # 如果宽度大于高度,则裁剪成3:2的宽高比 90 | if width_temp > height_temp: 91 | new_width = width_temp 92 | new_height = int(width_temp / ratio) 93 | left = 0 94 | top = 0 95 | img = img[top:top+new_height, left:left+new_width] 96 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 97 | if args.png: 98 | img = transparence2black(img) 99 | cv2.imwrite(os.path.join(save_path_0, str(i).zfill(4) + ".jpg"), img) 100 | else: 101 | # 反之,则裁剪成2:3的宽高比 102 | new_height = height_temp 103 | new_width = int(height_temp * ratio) 104 | left = 0 105 | top = 0 106 | img = img[top:top+new_height, left:left+new_width] 107 | img = cv2.resize(img, (height,width), interpolation=cv2.INTER_AREA) 108 | if args.png: 109 | img = transparence2black(img) 110 | cv2.imwrite(os.path.join(save_path_1, str(i).zfill(4) + ".jpg"), img) 111 | # img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 112 | 113 | # # 如果是透明图,将透明背景转换为白色或者黑色 114 | # if args.png: 115 | # img = transparence2black(img) 116 | 117 | # cv2.imwrite(os.path.join(save_path, str(i).zfill(4) + ".jpg"), img) 118 | except Exception as e: 119 | print(e) 120 | os.remove(path+file) # 删除无效图片 121 | print("删除无效图片: " + path+file) -------------------------------------------------------------------------------- /tools/label_images.py: -------------------------------------------------------------------------------- 1 | # from AUTOMATC1111 2 | # maybe modified by Nyanko Lepsoni 3 | # modified by crosstyan 4 | import os.path 5 | import re 6 | import tempfile 7 | import argparse 8 | import glob 9 | import zipfile 10 | import deepdanbooru as dd 11 | import tensorflow as tf 12 | import numpy as np 13 | 14 | from basicsr.utils.download_util import load_file_from_url 15 | from PIL import Image 16 | from tqdm import tqdm 17 | 18 | re_special = re.compile(r"([\\()])") 19 | 20 | def get_deepbooru_tags_model(model_path: str): 21 | if not os.path.exists(os.path.join(model_path, "project.json")): 22 | is_abs = os.path.isabs(model_path) 23 | if not is_abs: 24 | model_path = os.path.abspath(model_path) 25 | 26 | load_file_from_url( 27 | r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", 28 | model_path, 29 | ) 30 | with zipfile.ZipFile( 31 | os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r" 32 | ) as zip_ref: 33 | zip_ref.extractall(model_path) 34 | os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) 35 | 36 | tags = dd.project.load_tags_from_project(model_path) 37 | model = dd.project.load_model_from_project(model_path, compile_model=False) 38 | return model, tags 39 | 40 | 41 | def get_deepbooru_tags_from_model( 42 | model, 43 | tags, 44 | pil_image, 45 | threshold, 46 | alpha_sort=False, 47 | use_spaces=True, 48 | use_escape=True, 49 | include_ranks=False, 50 | ): 51 | width = model.input_shape[2] 52 | height = model.input_shape[1] 53 | image = np.array(pil_image) 54 | image = tf.image.resize( 55 | image, 56 | size=(height, width), 57 | method=tf.image.ResizeMethod.AREA, 58 | preserve_aspect_ratio=True, 59 | ) 60 | image = image.numpy() # EagerTensor to np.array 61 | image = dd.image.transform_and_pad_image(image, width, height) 62 | image = image / 255.0 63 | image_shape = image.shape 64 | image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) 65 | 66 | y = model.predict(image)[0] 67 | 68 | result_dict = {} 69 | 70 | for i, tag in 
enumerate(tags): 71 | result_dict[tag] = y[i] 72 | 73 | unsorted_tags_in_theshold = [] 74 | result_tags_print = [] 75 | for tag in tags: 76 | if result_dict[tag] >= threshold: 77 | if tag.startswith("rating:"): 78 | continue 79 | unsorted_tags_in_theshold.append((result_dict[tag], tag)) 80 | result_tags_print.append(f"{result_dict[tag]} {tag}") 81 | 82 | # sort tags 83 | result_tags_out = [] 84 | sort_ndx = 0 85 | if alpha_sort: 86 | sort_ndx = 1 87 | 88 | # sort by reverse by likelihood and normal for alpha, and format tag text as requested 89 | unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) 90 | for weight, tag in unsorted_tags_in_theshold: 91 | tag_outformat = tag 92 | if use_spaces: 93 | tag_outformat = tag_outformat.replace("_", " ") 94 | if use_escape: 95 | tag_outformat = re.sub(re_special, r"\\\1", tag_outformat) 96 | if include_ranks: 97 | tag_outformat = f"({tag_outformat}:{weight:.3f})" 98 | 99 | result_tags_out.append(tag_outformat) 100 | 101 | # print("\n".join(sorted(result_tags_print, reverse=True))) 102 | 103 | return ", ".join(result_tags_out) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--path", type=str, default=".") 109 | parser.add_argument("--threshold", type=int, default=0.75) 110 | parser.add_argument("--alpha_sort", type=bool, default=False) 111 | parser.add_argument("--use_spaces", type=bool, default=True) 112 | parser.add_argument("--use_escape", type=bool, default=True) 113 | parser.add_argument("--model_path", type=str, default="") 114 | parser.add_argument("--include_ranks", type=bool, default=False) 115 | 116 | args = parser.parse_args() 117 | 118 | global model_path 119 | model_path:str 120 | if args.model_path == "": 121 | script_path = os.path.realpath(__file__) 122 | default_model_path = os.path.join(os.path.dirname(script_path), "deepdanbooru-models") 123 | # print("No model path specified, using default model path: {}".format(default_model_path)) 124 | model_path = default_model_path 125 | else: 126 | model_path = args.model_path 127 | 128 | types = ('*.jpg', '*.png', '*.jpeg', '*.gif', '*.webp', '*.bmp') 129 | files_grabbed = [] 130 | for files in types: 131 | files_grabbed.extend(glob.glob(os.path.join(args.path, files))) 132 | # print(glob.glob(args.path + files)) 133 | 134 | model, tags = get_deepbooru_tags_model(model_path) 135 | for image_path in tqdm(files_grabbed, desc="Processing"): 136 | image = Image.open(image_path).convert("RGB") 137 | prompt = get_deepbooru_tags_from_model( 138 | model, 139 | tags, 140 | image, 141 | args.threshold, 142 | alpha_sort=args.alpha_sort, 143 | use_spaces=args.use_spaces, 144 | use_escape=args.use_escape, 145 | include_ranks=args.include_ranks, 146 | ) 147 | image_name = os.path.splitext(os.path.basename(image_path))[0] 148 | txt_filename = os.path.join(args.path, f"{image_name}.txt") 149 | # print(f"writing {txt_filename}: {prompt}") 150 | with open(txt_filename, 'w') as f: 151 | f.write(prompt) 152 | 153 | -------------------------------------------------------------------------------- /tools/test_cuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | print(torch.cuda.is_available()) 3 | -------------------------------------------------------------------------------- /tools/train_dreambooth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | 
import os 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint 12 | from torch.utils.data import Dataset 13 | 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import set_seed 17 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 18 | from diffusers.optimization import get_scheduler 19 | from huggingface_hub import HfFolder, Repository, whoami 20 | from PIL import Image 21 | from torchvision import transforms 22 | from tqdm.auto import tqdm 23 | from transformers import CLIPTextModel, CLIPTokenizer 24 | 25 | 26 | logger = get_logger(__name__) 27 | 28 | 29 | def parse_args(input_args=None): 30 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 31 | parser.add_argument( 32 | "--pretrained_model_name_or_path", 33 | type=str, 34 | default=None, 35 | required=True, 36 | help="Path to pretrained model or model identifier from huggingface.co/models.", 37 | ) 38 | parser.add_argument( 39 | "--revision", 40 | type=str, 41 | default=None, 42 | required=False, 43 | help="Revision of pretrained model identifier from huggingface.co/models.", 44 | ) 45 | parser.add_argument( 46 | "--tokenizer_name", 47 | type=str, 48 | default=None, 49 | help="Pretrained tokenizer name or path if not the same as model_name", 50 | ) 51 | parser.add_argument( 52 | "--instance_data_dir", 53 | type=str, 54 | default=None, 55 | required=True, 56 | help="A folder containing the training data of instance images.", 57 | ) 58 | parser.add_argument( 59 | "--class_data_dir", 60 | type=str, 61 | default=None, 62 | required=False, 63 | help="A folder containing the training data of class images.", 64 | ) 65 | parser.add_argument( 66 | "--instance_prompt", 67 | type=str, 68 | default=None, 69 | help="The prompt with identifier specifying the instance", 70 | ) 71 | parser.add_argument( 72 | "--class_prompt", 73 | type=str, 74 | default=None, 75 | help="The prompt to specify images in the same class as provided instance images.", 76 | ) 77 | parser.add_argument( 78 | "--with_prior_preservation", 79 | default=False, 80 | action="store_true", 81 | help="Flag to add prior preservation loss.", 82 | ) 83 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 84 | parser.add_argument( 85 | "--num_class_images", 86 | type=int, 87 | default=100, 88 | help=( 89 | "Minimal class images for prior preservation loss. If not have enough images, additional images will be" 90 | " sampled with class_prompt." 
91 | ), 92 | ) 93 | parser.add_argument( 94 | "--output_dir", 95 | type=str, 96 | default="text-inversion-model", 97 | help="The output directory where the model predictions and checkpoints will be written.", 98 | ) 99 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 100 | parser.add_argument( 101 | "--resolution", 102 | type=int, 103 | default=512, 104 | help=( 105 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 106 | " resolution" 107 | ), 108 | ) 109 | parser.add_argument( 110 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 111 | ) 112 | parser.add_argument( 113 | "--use_filename_as_label", action="store_true", help="Uses the filename as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance" 114 | ) 115 | parser.add_argument( 116 | "--use_txt_as_label", action="store_true", help="Uses the filename.txt file's content as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance" 117 | ) 118 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 119 | parser.add_argument( 120 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 121 | ) 122 | parser.add_argument( 123 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 124 | ) 125 | parser.add_argument("--num_train_epochs", type=int, default=1) 126 | parser.add_argument( 127 | "--max_train_steps", 128 | type=int, 129 | default=None, 130 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 131 | ) 132 | parser.add_argument( 133 | "--gradient_accumulation_steps", 134 | type=int, 135 | default=1, 136 | help="Number of updates steps to accumulate before performing a backward/update pass.", 137 | ) 138 | parser.add_argument( 139 | "--gradient_checkpointing", 140 | action="store_true", 141 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 142 | ) 143 | parser.add_argument( 144 | "--learning_rate", 145 | type=float, 146 | default=5e-6, 147 | help="Initial learning rate (after the potential warmup period) to use.", 148 | ) 149 | parser.add_argument( 150 | "--scale_lr", 151 | action="store_true", 152 | default=False, 153 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 154 | ) 155 | parser.add_argument( 156 | "--lr_scheduler", 157 | type=str, 158 | default="constant", 159 | help=( 160 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 161 | ' "constant", "constant_with_warmup"]' 162 | ), 163 | ) 164 | parser.add_argument( 165 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 166 | ) 167 | parser.add_argument( 168 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
169 | ) 170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 176 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 177 | parser.add_argument( 178 | "--hub_model_id", 179 | type=str, 180 | default=None, 181 | help="The name of the repository to keep in sync with the local `output_dir`.", 182 | ) 183 | parser.add_argument( 184 | "--logging_dir", 185 | type=str, 186 | default="logs", 187 | help=( 188 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 189 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 190 | ), 191 | ) 192 | parser.add_argument( 193 | "--log_with", 194 | type=str, 195 | default="tensorboard", 196 | choices=["tensorboard", "wandb"] 197 | ) 198 | parser.add_argument( 199 | "--mixed_precision", 200 | type=str, 201 | default="no", 202 | choices=["no", "fp16", "bf16"], 203 | help=( 204 | "Whether to use mixed precision. Choose" 205 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 206 | "and an Nvidia Ampere GPU." 207 | ), 208 | ) 209 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 210 | parser.add_argument("--save_model_every_n_steps", type=int) 211 | parser.add_argument("--auto_test_model", action="store_true", help="Whether or not to automatically test the model after saving it") 212 | parser.add_argument("--test_prompt", type=str, default="A photo of a cat", help="The prompt to use for testing the model.") 213 | parser.add_argument("--test_prompts_file", type=str, default=None, help="The file containing the prompts to use for testing the model.example: test_prompts.txt, each line is a prompt") 214 | parser.add_argument("--test_negative_prompt", type=str, default="", help="The negative prompt to use for testing the model.") 215 | parser.add_argument("--test_seed", type=int, default=42, help="The seed to use for testing the model.") 216 | parser.add_argument("--test_num_per_prompt", type=int, default=1, help="The number of images to generate per prompt.") 217 | 218 | if input_args is not None: 219 | args = parser.parse_args(input_args) 220 | else: 221 | args = parser.parse_args() 222 | 223 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 224 | if env_local_rank != -1 and env_local_rank != args.local_rank: 225 | args.local_rank = env_local_rank 226 | 227 | if args.instance_data_dir is None: 228 | raise ValueError("You must specify a train data directory.") 229 | 230 | if args.with_prior_preservation: 231 | if args.class_data_dir is None: 232 | raise ValueError("You must specify a data directory for class images.") 233 | if args.class_prompt is None: 234 | raise ValueError("You must specify prompt for class images.") 235 | 236 | return args 237 | 238 | # turns a path into a filename without the extension 239 | def get_filename(path): 240 | return path.stem 241 | 242 | def 
get_label_from_txt(path): 243 | txt_path = path.with_suffix(".txt") # get the path to the .txt file 244 | if txt_path.exists(): 245 | with open(txt_path, "r") as f: 246 | return f.read() 247 | else: 248 | return "" 249 | 250 | class DreamBoothDataset(Dataset): 251 | """ 252 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 253 | It pre-processes the images and the tokenizes prompts. 254 | """ 255 | 256 | def __init__( 257 | self, 258 | instance_data_root, 259 | instance_prompt, 260 | tokenizer, 261 | class_data_root=None, 262 | class_prompt=None, 263 | size=512, 264 | center_crop=False, 265 | use_filename_as_label=False, 266 | use_txt_as_label=False, 267 | ): 268 | self.size = size 269 | self.center_crop = center_crop 270 | self.tokenizer = tokenizer 271 | 272 | self.instance_data_root = Path(instance_data_root) 273 | if not self.instance_data_root.exists(): 274 | raise ValueError("Instance images root doesn't exists.") 275 | 276 | self.instance_images_path = list(self.instance_data_root.glob("*.jpg")) + list(self.instance_data_root.glob("*.png")) 277 | self.num_instance_images = len(self.instance_images_path) 278 | self.instance_prompt = instance_prompt 279 | self.use_filename_as_label = use_filename_as_label 280 | self.use_txt_as_label = use_txt_as_label 281 | self._length = self.num_instance_images 282 | 283 | if class_data_root is not None: 284 | self.class_data_root = Path(class_data_root) 285 | self.class_data_root.mkdir(parents=True, exist_ok=True) 286 | self.class_images_path = list(self.class_data_root.glob("*.jpg")) + list(self.class_data_root.glob("*.png")) 287 | self.num_class_images = len(self.class_images_path) 288 | self._length = max(self.num_class_images, self.num_instance_images) 289 | self.class_prompt = class_prompt 290 | else: 291 | self.class_data_root = None 292 | 293 | self.image_transforms = transforms.Compose( 294 | [ 295 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 296 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 297 | transforms.ToTensor(), 298 | transforms.Normalize([0.5], [0.5]), 299 | ] 300 | ) 301 | 302 | def __len__(self): 303 | return self._length 304 | 305 | def __getitem__(self, index): 306 | example = {} 307 | path = self.instance_images_path[index % self.num_instance_images] 308 | prompt = get_filename(path) if self.use_filename_as_label else self.instance_prompt 309 | prompt = get_label_from_txt(path) if self.use_txt_as_label else prompt 310 | 311 | print("prompt", prompt) 312 | 313 | instance_image = Image.open(path) 314 | if not instance_image.mode == "RGB": 315 | instance_image = instance_image.convert("RGB") 316 | example["instance_images"] = self.image_transforms(instance_image) 317 | example["instance_prompt_ids"] = self.tokenizer( 318 | prompt, 319 | padding="do_not_pad", 320 | truncation=True, 321 | max_length=self.tokenizer.model_max_length, 322 | ).input_ids 323 | 324 | if self.class_data_root: 325 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 326 | if not class_image.mode == "RGB": 327 | class_image = class_image.convert("RGB") 328 | example["class_images"] = self.image_transforms(class_image) 329 | example["class_prompt_ids"] = self.tokenizer( 330 | self.class_prompt, 331 | padding="do_not_pad", 332 | truncation=True, 333 | max_length=self.tokenizer.model_max_length, 334 | ).input_ids 335 | 336 | return example 337 | 338 | 339 | class PromptDataset(Dataset): 340 | "A simple dataset to 
prepare the prompts to generate class images on multiple GPUs." 341 | 342 | def __init__(self, prompt, num_samples): 343 | self.prompt = prompt 344 | self.num_samples = num_samples 345 | 346 | def __len__(self): 347 | return self.num_samples 348 | 349 | def __getitem__(self, index): 350 | example = {} 351 | example["prompt"] = self.prompt 352 | example["index"] = index 353 | return example 354 | 355 | 356 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 357 | if token is None: 358 | token = HfFolder.get_token() 359 | if organization is None: 360 | username = whoami(token)["name"] 361 | return f"{username}/{model_id}" 362 | else: 363 | return f"{organization}/{model_id}" 364 | 365 | def test_model(folder, args): 366 | if args.test_prompts_file is not None: 367 | with open(args.test_prompts_file, "r") as f: 368 | prompts = f.read().splitlines() 369 | else: 370 | prompts = [args.test_prompt] 371 | 372 | test_path = os.path.join(folder, "test") 373 | if not os.path.exists(test_path): 374 | os.makedirs(test_path) 375 | 376 | print("Testing the model...") 377 | from diffusers import DDIMScheduler 378 | 379 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 380 | torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 381 | pipeline = StableDiffusionPipeline.from_pretrained( 382 | folder, 383 | torch_dtype=torch_dtype, 384 | safety_checker=None, 385 | load_in_8bit=True, 386 | scheduler = DDIMScheduler( 387 | beta_start=0.00085, 388 | beta_end=0.012, 389 | beta_schedule="scaled_linear", 390 | clip_sample=False, 391 | set_alpha_to_one=False, 392 | ), 393 | ) 394 | pipeline.set_progress_bar_config(disable=True) 395 | pipeline.enable_attention_slicing() 396 | pipeline = pipeline.to(device) 397 | 398 | torch.manual_seed(args.test_seed) 399 | with torch.autocast('cuda'): 400 | for prompt in prompts: 401 | print(f"Generating test images for prompt: {prompt}") 402 | test_images = pipeline( 403 | prompt=prompt, 404 | width=512, 405 | height=512, 406 | negative_prompt=args.test_negative_prompt, 407 | num_inference_steps=30, 408 | num_images_per_prompt=args.test_num_per_prompt, 409 | ).images 410 | 411 | for index, image in enumerate(test_images): 412 | image.save(f"{test_path}/{prompt}_{index}.png") 413 | 414 | del pipeline 415 | if torch.cuda.is_available(): 416 | torch.cuda.empty_cache() 417 | 418 | print(f"Test completed.The examples are saved in {test_path}") 419 | 420 | 421 | def save_model(accelerator, unet, text_encoder, args, step=None): 422 | unet = accelerator.unwrap_model(unet) 423 | text_encoder = accelerator.unwrap_model(text_encoder) 424 | 425 | if step == None: 426 | folder = args.output_dir 427 | else: 428 | folder = args.output_dir + "-Step-" + str(step) 429 | 430 | print("Saving Model Checkpoint...") 431 | print("Directory: " + folder) 432 | 433 | # Create the pipeline using using the trained modules and save it. 
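# Only the unwrapped UNet and (optionally fine-tuned) text encoder are swapped in
# here; the VAE, tokenizer, scheduler and other pipeline components are reloaded
# unchanged from args.pretrained_model_name_or_path before saving to `folder`.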
434 | if accelerator.is_main_process: 435 | pipeline = StableDiffusionPipeline.from_pretrained( 436 | args.pretrained_model_name_or_path, 437 | unet=unet, 438 | text_encoder=text_encoder, 439 | revision=args.revision, 440 | ) 441 | pipeline.save_pretrained(folder) 442 | del pipeline 443 | if torch.cuda.is_available(): 444 | torch.cuda.empty_cache() 445 | 446 | if args.auto_test_model: 447 | print("Testing Model...") 448 | test_model(folder, args) 449 | 450 | if args.push_to_hub: 451 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 452 | 453 | 454 | def main(args): 455 | logging_dir = Path(args.logging_dir) 456 | 457 | accelerator = Accelerator( 458 | gradient_accumulation_steps=args.gradient_accumulation_steps, 459 | mixed_precision=args.mixed_precision, 460 | log_with=args.log_with, 461 | logging_dir=logging_dir, 462 | ) 463 | 464 | 465 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 466 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 467 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 468 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 469 | raise ValueError( 470 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 471 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 472 | ) 473 | 474 | if args.seed is not None: 475 | set_seed(args.seed) 476 | 477 | if args.with_prior_preservation: 478 | class_images_dir = Path(args.class_data_dir) 479 | if not class_images_dir.exists(): 480 | class_images_dir.mkdir(parents=True) 481 | cur_class_images = len(list(class_images_dir.iterdir())) 482 | 483 | if cur_class_images < args.num_class_images: 484 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 485 | pipeline = StableDiffusionPipeline.from_pretrained( 486 | args.pretrained_model_name_or_path, 487 | torch_dtype=torch_dtype, 488 | safety_checker=None, 489 | revision=args.revision, 490 | ) 491 | pipeline.set_progress_bar_config(disable=True) 492 | 493 | num_new_images = args.num_class_images - cur_class_images 494 | logger.info(f"Number of class images to sample: {num_new_images}.") 495 | 496 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 497 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 498 | 499 | sample_dataloader = accelerator.prepare(sample_dataloader) 500 | pipeline.to(accelerator.device) 501 | 502 | for example in tqdm( 503 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 504 | ): 505 | images = pipeline(example["prompt"]).images 506 | 507 | for i, image in enumerate(images): 508 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 509 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 510 | image.save(image_filename) 511 | 512 | del pipeline 513 | if torch.cuda.is_available(): 514 | torch.cuda.empty_cache() 515 | 516 | # Handle the repository creation 517 | if accelerator.is_main_process: 518 | if args.push_to_hub: 519 | if args.hub_model_id is None: 520 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 521 | else: 522 | repo_name = args.hub_model_id 523 | repo = 
Repository(args.output_dir, clone_from=repo_name) 524 | 525 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 526 | if "step_*" not in gitignore: 527 | gitignore.write("step_*\n") 528 | if "epoch_*" not in gitignore: 529 | gitignore.write("epoch_*\n") 530 | elif args.output_dir is not None: 531 | os.makedirs(args.output_dir, exist_ok=True) 532 | 533 | # Load the tokenizer 534 | if args.tokenizer_name: 535 | tokenizer = CLIPTokenizer.from_pretrained( 536 | args.tokenizer_name, 537 | revision=args.revision, 538 | ) 539 | elif args.pretrained_model_name_or_path: 540 | tokenizer = CLIPTokenizer.from_pretrained( 541 | args.pretrained_model_name_or_path, 542 | subfolder="tokenizer", 543 | revision=args.revision, 544 | ) 545 | 546 | # Load models and create wrapper for stable diffusion 547 | text_encoder = CLIPTextModel.from_pretrained( 548 | args.pretrained_model_name_or_path, 549 | subfolder="text_encoder", 550 | revision=args.revision, 551 | ) 552 | vae = AutoencoderKL.from_pretrained( 553 | args.pretrained_model_name_or_path, 554 | subfolder="vae", 555 | revision=args.revision, 556 | ) 557 | unet = UNet2DConditionModel.from_pretrained( 558 | args.pretrained_model_name_or_path, 559 | subfolder="unet", 560 | revision=args.revision, 561 | ) 562 | 563 | vae.requires_grad_(False) 564 | if not args.train_text_encoder: 565 | text_encoder.requires_grad_(False) 566 | 567 | if args.gradient_checkpointing: 568 | unet.enable_gradient_checkpointing() 569 | if args.train_text_encoder: 570 | text_encoder.gradient_checkpointing_enable() 571 | 572 | if args.scale_lr: 573 | args.learning_rate = ( 574 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 575 | ) 576 | 577 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 578 | if args.use_8bit_adam: 579 | try: 580 | import bitsandbytes as bnb 581 | except ImportError: 582 | raise ImportError( 583 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 584 | ) 585 | 586 | optimizer_class = bnb.optim.AdamW8bit 587 | else: 588 | optimizer_class = torch.optim.AdamW 589 | 590 | params_to_optimize = ( 591 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 592 | ) 593 | optimizer = optimizer_class( 594 | params_to_optimize, 595 | lr=args.learning_rate, 596 | betas=(args.adam_beta1, args.adam_beta2), 597 | weight_decay=args.adam_weight_decay, 598 | eps=args.adam_epsilon, 599 | ) 600 | 601 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 602 | 603 | train_dataset = DreamBoothDataset( 604 | instance_data_root=args.instance_data_dir, 605 | instance_prompt=args.instance_prompt, 606 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 607 | class_prompt=args.class_prompt, 608 | tokenizer=tokenizer, 609 | size=args.resolution, 610 | center_crop=args.center_crop, 611 | use_filename_as_label=args.use_filename_as_label, 612 | use_txt_as_label=args.use_txt_as_label, 613 | ) 614 | 615 | def collate_fn(examples): 616 | input_ids = [example["instance_prompt_ids"] for example in examples] 617 | pixel_values = [example["instance_images"] for example in examples] 618 | 619 | # Concat class and instance examples for prior preservation. 620 | # We do this to avoid doing two forward passes. 
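# The instance examples occupy the first half of the batch and the class
# (regularization) examples the second half; the corresponding noise predictions
# are split back apart with torch.chunk in the training loop below.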
621 | if args.with_prior_preservation: 622 | input_ids += [example["class_prompt_ids"] for example in examples] 623 | pixel_values += [example["class_images"] for example in examples] 624 | 625 | pixel_values = torch.stack(pixel_values) 626 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 627 | 628 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 629 | 630 | batch = { 631 | "input_ids": input_ids, 632 | "pixel_values": pixel_values, 633 | } 634 | return batch 635 | 636 | train_dataloader = torch.utils.data.DataLoader( 637 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 638 | ) 639 | 640 | # Scheduler and math around the number of training steps. 641 | overrode_max_train_steps = False 642 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 643 | if args.max_train_steps is None: 644 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 645 | overrode_max_train_steps = True 646 | 647 | lr_scheduler = get_scheduler( 648 | args.lr_scheduler, 649 | optimizer=optimizer, 650 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 651 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 652 | ) 653 | 654 | if args.train_text_encoder: 655 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 656 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 657 | ) 658 | else: 659 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 660 | unet, optimizer, train_dataloader, lr_scheduler 661 | ) 662 | 663 | weight_dtype = torch.float32 664 | if args.mixed_precision == "fp16": 665 | weight_dtype = torch.float16 666 | elif args.mixed_precision == "bf16": 667 | weight_dtype = torch.bfloat16 668 | 669 | # Move text_encode and vae to gpu. 670 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 671 | # as these models are only used for inference, keeping weights in full precision is not required. 672 | vae.to(accelerator.device, dtype=weight_dtype) 673 | if not args.train_text_encoder: 674 | text_encoder.to(accelerator.device, dtype=weight_dtype) 675 | 676 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 677 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 678 | if overrode_max_train_steps: 679 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 680 | # Afterwards we recalculate our number of training epochs 681 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 682 | 683 | # We need to initialize the trackers we use, and also store our configuration. 684 | # The trackers initializes automatically on the main process. 685 | if accelerator.is_main_process: 686 | accelerator.init_trackers("dreambooth", config=vars(args)) 687 | 688 | # Train! 
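# The effective batch size per optimizer step combines the per-device batch size,
# the number of accelerate processes, and gradient accumulation:
#   total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps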
689 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 690 | 691 | logger.info("***** Running training *****") 692 | logger.info(f" Num examples = {len(train_dataset)}") 693 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 694 | logger.info(f" Num Epochs = {args.num_train_epochs}") 695 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 696 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 697 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 698 | logger.info(f" Total optimization steps = {args.max_train_steps}") 699 | # Only show the progress bar once on each machine. 700 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 701 | progress_bar.set_description("Steps") 702 | global_step = 0 703 | 704 | for epoch in range(args.num_train_epochs): 705 | unet.train() 706 | if args.train_text_encoder: 707 | text_encoder.train() 708 | for step, batch in enumerate(train_dataloader): 709 | with accelerator.accumulate(unet): 710 | # Convert images to latent space 711 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 712 | latents = latents * 0.18215 713 | 714 | # Sample noise that we'll add to the latents 715 | noise = torch.randn_like(latents) 716 | bsz = latents.shape[0] 717 | # Sample a random timestep for each image 718 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 719 | timesteps = timesteps.long() 720 | 721 | # Add noise to the latents according to the noise magnitude at each timestep 722 | # (this is the forward diffusion process) 723 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 724 | 725 | # Get the text embedding for conditioning 726 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 727 | 728 | # Predict the noise residual 729 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 730 | 731 | if args.with_prior_preservation: 732 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 733 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 734 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 735 | 736 | # Compute instance loss 737 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 738 | 739 | # Compute prior loss 740 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 741 | 742 | # Add the prior loss to the instance loss. 
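# Prior-preservation objective: total loss = instance MSE + prior_loss_weight * class MSE,
# which keeps the model close to its prior on the class while fitting the instance images.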
743 | loss = loss + args.prior_loss_weight * prior_loss 744 | else: 745 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 746 | 747 | accelerator.backward(loss) 748 | if accelerator.sync_gradients: 749 | params_to_clip = ( 750 | itertools.chain(unet.parameters(), text_encoder.parameters()) 751 | if args.train_text_encoder 752 | else unet.parameters() 753 | ) 754 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 755 | optimizer.step() 756 | lr_scheduler.step() 757 | optimizer.zero_grad() 758 | 759 | # Checks if the accelerator has performed an optimization step behind the scenes 760 | if accelerator.sync_gradients: 761 | progress_bar.update(1) 762 | global_step += 1 763 | 764 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 765 | progress_bar.set_postfix(**logs) 766 | accelerator.log(logs, step=global_step) 767 | 768 | if global_step >= args.max_train_steps: 769 | break 770 | 771 | 772 | if args.save_model_every_n_steps != None and (global_step % args.save_model_every_n_steps) == 0: 773 | save_model(accelerator, unet, text_encoder, args, global_step) 774 | 775 | accelerator.wait_for_everyone() 776 | 777 | save_model(accelerator, unet, text_encoder, args, step=None) 778 | 779 | accelerator.end_training() 780 | 781 | 782 | if __name__ == "__main__": 783 | args = parse_args() 784 | main(args) 785 | -------------------------------------------------------------------------------- /tools/train_dreambooth_rect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import itertools 4 | import math 5 | import os 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.utils.checkpoint 12 | from torch.utils.data import Dataset 13 | 14 | from accelerate import Accelerator 15 | from accelerate.logging import get_logger 16 | from accelerate.utils import set_seed 17 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 18 | from diffusers.optimization import get_scheduler 19 | from huggingface_hub import HfFolder, Repository, whoami 20 | from PIL import Image 21 | from torchvision import transforms 22 | from tqdm.auto import tqdm 23 | from transformers import CLIPTextModel, CLIPTokenizer 24 | 25 | 26 | logger = get_logger(__name__) 27 | 28 | 29 | def parse_args(input_args=None): 30 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 31 | parser.add_argument( 32 | "--pretrained_model_name_or_path", 33 | type=str, 34 | default=None, 35 | required=True, 36 | help="Path to pretrained model or model identifier from huggingface.co/models.", 37 | ) 38 | parser.add_argument( 39 | "--revision", 40 | type=str, 41 | default=None, 42 | required=False, 43 | help="Revision of pretrained model identifier from huggingface.co/models.", 44 | ) 45 | parser.add_argument( 46 | "--tokenizer_name", 47 | type=str, 48 | default=None, 49 | help="Pretrained tokenizer name or path if not the same as model_name", 50 | ) 51 | parser.add_argument( 52 | "--instance_data_dir", 53 | type=str, 54 | default=None, 55 | required=True, 56 | help="A folder containing the training data of instance images.", 57 | ) 58 | parser.add_argument( 59 | "--class_data_dir", 60 | type=str, 61 | default=None, 62 | required=False, 63 | help="A folder containing the training data of class images.", 64 | ) 65 | parser.add_argument( 66 | "--instance_prompt", 67 | 
type=str, 68 | default=None, 69 | help="The prompt with identifier specifying the instance", 70 | ) 71 | parser.add_argument( 72 | "--class_prompt", 73 | type=str, 74 | default=None, 75 | help="The prompt to specify images in the same class as provided instance images.", 76 | ) 77 | parser.add_argument( 78 | "--with_prior_preservation", 79 | default=False, 80 | action="store_true", 81 | help="Flag to add prior preservation loss.", 82 | ) 83 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 84 | parser.add_argument( 85 | "--num_class_images", 86 | type=int, 87 | default=100, 88 | help=( 89 | "Minimal class images for prior preservation loss. If not have enough images, additional images will be" 90 | " sampled with class_prompt." 91 | ), 92 | ) 93 | parser.add_argument( 94 | "--output_dir", 95 | type=str, 96 | default="text-inversion-model", 97 | help="The output directory where the model predictions and checkpoints will be written.", 98 | ) 99 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 100 | parser.add_argument( 101 | "--resolution", 102 | type=int, 103 | default=768, 104 | help=( 105 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 106 | " resolution" 107 | ), 108 | ) 109 | parser.add_argument( 110 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 111 | ) 112 | 113 | parser.add_argument("--width", type=int, default=None, help="training width.") 114 | parser.add_argument("--height", type=int, default=None, help="training height.") 115 | 116 | parser.add_argument( 117 | "--use_filename_as_label", action="store_true", help="Uses the filename as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance" 118 | ) 119 | parser.add_argument( 120 | "--use_txt_as_label", action="store_true", help="Uses the filename.txt file's content as the image labels instead of the instance_prompt, useful for regularization when training for styles with wide image variance" 121 | ) 122 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 123 | parser.add_argument( 124 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 125 | ) 126 | parser.add_argument( 127 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 128 | ) 129 | parser.add_argument("--num_train_epochs", type=int, default=1) 130 | parser.add_argument( 131 | "--max_train_steps", 132 | type=int, 133 | default=None, 134 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 135 | ) 136 | parser.add_argument( 137 | "--gradient_accumulation_steps", 138 | type=int, 139 | default=1, 140 | help="Number of updates steps to accumulate before performing a backward/update pass.", 141 | ) 142 | parser.add_argument( 143 | "--gradient_checkpointing", 144 | action="store_true", 145 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 146 | ) 147 | parser.add_argument( 148 | "--learning_rate", 149 | type=float, 150 | default=5e-6, 151 | help="Initial learning rate (after the potential warmup period) to use.", 152 | ) 153 | parser.add_argument( 154 | "--scale_lr", 155 | action="store_true", 156 | default=False, 157 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 158 | ) 159 | parser.add_argument( 160 | "--lr_scheduler", 161 | type=str, 162 | default="constant", 163 | help=( 164 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 165 | ' "constant", "constant_with_warmup"]' 166 | ), 167 | ) 168 | parser.add_argument( 169 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 170 | ) 171 | parser.add_argument( 172 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 173 | ) 174 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 175 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 176 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 177 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 178 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 179 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 180 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 181 | parser.add_argument( 182 | "--hub_model_id", 183 | type=str, 184 | default=None, 185 | help="The name of the repository to keep in sync with the local `output_dir`.", 186 | ) 187 | parser.add_argument( 188 | "--logging_dir", 189 | type=str, 190 | default="logs", 191 | help=( 192 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 193 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 194 | ), 195 | ) 196 | parser.add_argument( 197 | "--log_with", 198 | type=str, 199 | default="tensorboard", 200 | choices=["tensorboard", "wandb"] 201 | ) 202 | parser.add_argument( 203 | "--mixed_precision", 204 | type=str, 205 | default="no", 206 | choices=["no", "fp16", "bf16"], 207 | help=( 208 | "Whether to use mixed precision. Choose" 209 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 210 | "and an Nvidia Ampere GPU." 
211 | ), 212 | ) 213 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 214 | parser.add_argument("--save_model_every_n_steps", type=int) 215 | parser.add_argument("--auto_test_model", action="store_true", help="Whether or not to automatically test the model after saving it") 216 | parser.add_argument("--test_prompt", type=str, default="A photo of a cat", help="The prompt to use for testing the model.") 217 | parser.add_argument("--test_prompts_file", type=str, default=None, help="The file containing the prompts to use for testing the model.example: test_prompts.txt, each line is a prompt") 218 | parser.add_argument("--test_negative_prompt", type=str, default="", help="The negative prompt to use for testing the model.") 219 | parser.add_argument("--test_seed", type=int, default=42, help="The seed to use for testing the model.") 220 | parser.add_argument("--test_num_per_prompt", type=int, default=1, help="The number of images to generate per prompt.") 221 | 222 | if input_args is not None: 223 | args = parser.parse_args(input_args) 224 | else: 225 | args = parser.parse_args() 226 | 227 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 228 | if env_local_rank != -1 and env_local_rank != args.local_rank: 229 | args.local_rank = env_local_rank 230 | 231 | if args.instance_data_dir is None: 232 | raise ValueError("You must specify a train data directory.") 233 | 234 | if args.with_prior_preservation: 235 | if args.class_data_dir is None: 236 | raise ValueError("You must specify a data directory for class images.") 237 | if args.class_prompt is None: 238 | raise ValueError("You must specify prompt for class images.") 239 | 240 | return args 241 | 242 | # turns a path into a filename without the extension 243 | def get_filename(path): 244 | return path.stem 245 | 246 | def get_label_from_txt(path): 247 | txt_path = path.with_suffix(".txt") # get the path to the .txt file 248 | if txt_path.exists(): 249 | with open(txt_path, "r") as f: 250 | return f.read() 251 | else: 252 | return "" 253 | 254 | class DreamBoothDataset(Dataset): 255 | """ 256 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 257 | It pre-processes the images and the tokenizes prompts. 
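    Prompt selection per example (see __getitem__): when --use_txt_as_label is set, the content of
    the matching <image>.txt file is used (an empty string if that file is missing); otherwise, when
    --use_filename_as_label is set, the file name without its extension is used; otherwise every
    image uses --instance_prompt. Note that when width != height the transform pipeline below only
    converts to tensors and normalizes (no resize/crop), so rectangular inputs are expected to be
    pre-cropped to the target size.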
258 | """ 259 | 260 | def __init__( 261 | self, 262 | instance_data_root, 263 | instance_prompt, 264 | tokenizer, 265 | class_data_root=None, 266 | class_prompt=None, 267 | size_width=768, 268 | size_height=512, 269 | center_crop=False, 270 | use_filename_as_label=False, 271 | use_txt_as_label=False, 272 | ): 273 | # self.size = size 274 | self.size_width=size_width 275 | self.size_height=size_height 276 | self.center_crop = center_crop 277 | self.tokenizer = tokenizer 278 | 279 | self.instance_data_root = Path(instance_data_root) 280 | if not self.instance_data_root.exists(): 281 | raise ValueError("Instance images root doesn't exists.") 282 | 283 | self.instance_images_path = list(self.instance_data_root.glob("*.jpg")) + list(self.instance_data_root.glob("*.png")) 284 | self.num_instance_images = len(self.instance_images_path) 285 | self.instance_prompt = instance_prompt 286 | self.use_filename_as_label = use_filename_as_label 287 | self.use_txt_as_label = use_txt_as_label 288 | self._length = self.num_instance_images 289 | 290 | if class_data_root is not None: 291 | self.class_data_root = Path(class_data_root) 292 | self.class_data_root.mkdir(parents=True, exist_ok=True) 293 | self.class_images_path = list(self.class_data_root.glob("*.jpg")) + list(self.class_data_root.glob("*.png")) 294 | self.num_class_images = len(self.class_images_path) 295 | self._length = max(self.num_class_images, self.num_instance_images) 296 | self.class_prompt = class_prompt 297 | else: 298 | self.class_data_root = None 299 | 300 | if size_width!=size_height: 301 | self.image_transforms = transforms.Compose( 302 | [ 303 | # transforms.Resize((self.size_height,self.size_width), interpolation=transforms.InterpolationMode.BILINEAR), 304 | transforms.ToTensor(), 305 | transforms.Normalize([0.5], [0.5]), 306 | ] 307 | ) 308 | elif size_width==size_height: 309 | self.image_transforms = transforms.Compose( 310 | [ 311 | transforms.Resize(self.size_height, interpolation=transforms.InterpolationMode.BILINEAR), 312 | transforms.CenterCrop(self.size_height) if center_crop else transforms.RandomCrop(self.size_height), 313 | transforms.ToTensor(), 314 | transforms.Normalize([0.5], [0.5]), 315 | ] 316 | ) 317 | 318 | def __len__(self): 319 | return self._length 320 | 321 | def __getitem__(self, index): 322 | example = {} 323 | path = self.instance_images_path[index % self.num_instance_images] 324 | prompt = get_filename(path) if self.use_filename_as_label else self.instance_prompt 325 | prompt = get_label_from_txt(path) if self.use_txt_as_label else prompt 326 | 327 | print("prompt", prompt) 328 | 329 | instance_image = Image.open(path) 330 | if not instance_image.mode == "RGB": 331 | instance_image = instance_image.convert("RGB") 332 | example["instance_images"] = self.image_transforms(instance_image) 333 | example["instance_prompt_ids"] = self.tokenizer( 334 | prompt, 335 | padding="do_not_pad", 336 | truncation=True, 337 | max_length=self.tokenizer.model_max_length, 338 | ).input_ids 339 | 340 | if self.class_data_root: 341 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 342 | if not class_image.mode == "RGB": 343 | class_image = class_image.convert("RGB") 344 | example["class_images"] = self.image_transforms(class_image) 345 | example["class_prompt_ids"] = self.tokenizer( 346 | self.class_prompt, 347 | padding="do_not_pad", 348 | truncation=True, 349 | max_length=self.tokenizer.model_max_length, 350 | ).input_ids 351 | 352 | return example 353 | 354 | 355 | class PromptDataset(Dataset): 
356 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 357 | 358 | def __init__(self, prompt, num_samples): 359 | self.prompt = prompt 360 | self.num_samples = num_samples 361 | 362 | def __len__(self): 363 | return self.num_samples 364 | 365 | def __getitem__(self, index): 366 | example = {} 367 | example["prompt"] = self.prompt 368 | example["index"] = index 369 | return example 370 | 371 | 372 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 373 | if token is None: 374 | token = HfFolder.get_token() 375 | if organization is None: 376 | username = whoami(token)["name"] 377 | return f"{username}/{model_id}" 378 | else: 379 | return f"{organization}/{model_id}" 380 | 381 | def test_model(folder, args): 382 | if args.test_prompts_file is not None: 383 | with open(args.test_prompts_file, "r") as f: 384 | prompts = f.read().splitlines() 385 | else: 386 | prompts = [args.test_prompt] 387 | 388 | test_path = os.path.join(folder, "test") 389 | if not os.path.exists(test_path): 390 | os.makedirs(test_path) 391 | 392 | print("Testing the model...") 393 | from diffusers import DDIMScheduler 394 | 395 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 396 | torch_dtype = torch.float16 if device.type == "cuda" else torch.float32 397 | pipeline = StableDiffusionPipeline.from_pretrained( 398 | folder, 399 | torch_dtype=torch_dtype, 400 | safety_checker=None, 401 | load_in_8bit=True, 402 | scheduler = DDIMScheduler( 403 | beta_start=0.00085, 404 | beta_end=0.012, 405 | beta_schedule="scaled_linear", 406 | clip_sample=False, 407 | set_alpha_to_one=False, 408 | ), 409 | ) 410 | pipeline.set_progress_bar_config(disable=True) 411 | pipeline.enable_attention_slicing() 412 | pipeline = pipeline.to(device) 413 | 414 | torch.manual_seed(args.test_seed) 415 | with torch.autocast('cuda'): 416 | for prompt in prompts: 417 | print(f"Generating test images for prompt: {prompt}") 418 | test_images = pipeline( 419 | prompt=prompt, 420 | width=512, 421 | height=512, 422 | negative_prompt=args.test_negative_prompt, 423 | num_inference_steps=30, 424 | num_images_per_prompt=args.test_num_per_prompt, 425 | ).images 426 | 427 | for index, image in enumerate(test_images): 428 | image.save(f"{test_path}/{prompt}_{index}.png") 429 | 430 | del pipeline 431 | if torch.cuda.is_available(): 432 | torch.cuda.empty_cache() 433 | 434 | print(f"Test completed.The examples are saved in {test_path}") 435 | 436 | 437 | def save_model(accelerator, unet, text_encoder, args, step=None): 438 | unet = accelerator.unwrap_model(unet) 439 | text_encoder = accelerator.unwrap_model(text_encoder) 440 | 441 | if step == None: 442 | folder = args.output_dir 443 | else: 444 | folder = args.output_dir + "-Step-" + str(step) 445 | 446 | print("Saving Model Checkpoint...") 447 | print("Directory: " + folder) 448 | 449 | # Create the pipeline using using the trained modules and save it. 
450 | if accelerator.is_main_process: 451 | pipeline = StableDiffusionPipeline.from_pretrained( 452 | args.pretrained_model_name_or_path, 453 | unet=unet, 454 | text_encoder=text_encoder, 455 | revision=args.revision, 456 | ) 457 | pipeline.save_pretrained(folder) 458 | del pipeline 459 | if torch.cuda.is_available(): 460 | torch.cuda.empty_cache() 461 | 462 | if args.auto_test_model: 463 | print("Testing Model...") 464 | test_model(folder, args) 465 | 466 | if args.push_to_hub: 467 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 468 | 469 | 470 | def main(args): 471 | logging_dir = Path(args.logging_dir) 472 | 473 | accelerator = Accelerator( 474 | gradient_accumulation_steps=args.gradient_accumulation_steps, 475 | mixed_precision=args.mixed_precision, 476 | log_with=args.log_with, 477 | logging_dir=logging_dir, 478 | ) 479 | 480 | 481 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 482 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 483 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 484 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 485 | raise ValueError( 486 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 487 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 488 | ) 489 | 490 | if args.seed is not None: 491 | set_seed(args.seed) 492 | 493 | if args.with_prior_preservation: 494 | class_images_dir = Path(args.class_data_dir) 495 | if not class_images_dir.exists(): 496 | class_images_dir.mkdir(parents=True) 497 | cur_class_images = len(list(class_images_dir.iterdir())) 498 | 499 | if cur_class_images < args.num_class_images: 500 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 501 | pipeline = StableDiffusionPipeline.from_pretrained( 502 | args.pretrained_model_name_or_path, 503 | torch_dtype=torch_dtype, 504 | safety_checker=None, 505 | revision=args.revision, 506 | ) 507 | pipeline.set_progress_bar_config(disable=True) 508 | 509 | num_new_images = args.num_class_images - cur_class_images 510 | logger.info(f"Number of class images to sample: {num_new_images}.") 511 | 512 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 513 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 514 | 515 | sample_dataloader = accelerator.prepare(sample_dataloader) 516 | pipeline.to(accelerator.device) 517 | 518 | for example in tqdm( 519 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 520 | ): 521 | images = pipeline(example["prompt"]).images 522 | 523 | for i, image in enumerate(images): 524 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 525 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 526 | image.save(image_filename) 527 | 528 | del pipeline 529 | if torch.cuda.is_available(): 530 | torch.cuda.empty_cache() 531 | 532 | # Handle the repository creation 533 | if accelerator.is_main_process: 534 | if args.push_to_hub: 535 | if args.hub_model_id is None: 536 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 537 | else: 538 | repo_name = args.hub_model_id 539 | repo = 
Repository(args.output_dir, clone_from=repo_name) 540 | 541 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 542 | if "step_*" not in gitignore: 543 | gitignore.write("step_*\n") 544 | if "epoch_*" not in gitignore: 545 | gitignore.write("epoch_*\n") 546 | elif args.output_dir is not None: 547 | os.makedirs(args.output_dir, exist_ok=True) 548 | 549 | # Load the tokenizer 550 | if args.tokenizer_name: 551 | tokenizer = CLIPTokenizer.from_pretrained( 552 | args.tokenizer_name, 553 | revision=args.revision, 554 | ) 555 | elif args.pretrained_model_name_or_path: 556 | tokenizer = CLIPTokenizer.from_pretrained( 557 | args.pretrained_model_name_or_path, 558 | subfolder="tokenizer", 559 | revision=args.revision, 560 | ) 561 | 562 | # Load models and create wrapper for stable diffusion 563 | text_encoder = CLIPTextModel.from_pretrained( 564 | args.pretrained_model_name_or_path, 565 | subfolder="text_encoder", 566 | revision=args.revision, 567 | ) 568 | vae = AutoencoderKL.from_pretrained( 569 | args.pretrained_model_name_or_path, 570 | subfolder="vae", 571 | revision=args.revision, 572 | ) 573 | unet = UNet2DConditionModel.from_pretrained( 574 | args.pretrained_model_name_or_path, 575 | subfolder="unet", 576 | revision=args.revision, 577 | ) 578 | 579 | vae.requires_grad_(False) 580 | if not args.train_text_encoder: 581 | text_encoder.requires_grad_(False) 582 | 583 | if args.gradient_checkpointing: 584 | unet.enable_gradient_checkpointing() 585 | if args.train_text_encoder: 586 | text_encoder.gradient_checkpointing_enable() 587 | 588 | if args.scale_lr: 589 | args.learning_rate = ( 590 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 591 | ) 592 | 593 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 594 | if args.use_8bit_adam: 595 | try: 596 | import bitsandbytes as bnb 597 | except ImportError: 598 | raise ImportError( 599 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 600 | ) 601 | 602 | optimizer_class = bnb.optim.AdamW8bit 603 | else: 604 | optimizer_class = torch.optim.AdamW 605 | 606 | params_to_optimize = ( 607 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 608 | ) 609 | optimizer = optimizer_class( 610 | params_to_optimize, 611 | lr=args.learning_rate, 612 | betas=(args.adam_beta1, args.adam_beta2), 613 | weight_decay=args.adam_weight_decay, 614 | eps=args.adam_epsilon, 615 | ) 616 | 617 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 618 | 619 | train_dataset = DreamBoothDataset( 620 | instance_data_root=args.instance_data_dir, 621 | instance_prompt=args.instance_prompt, 622 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 623 | class_prompt=args.class_prompt, 624 | tokenizer=tokenizer, 625 | size_width=args.width, 626 | size_height=args.height, 627 | center_crop=args.center_crop, 628 | use_filename_as_label=args.use_filename_as_label, 629 | use_txt_as_label=args.use_txt_as_label, 630 | ) 631 | 632 | def collate_fn(examples): 633 | input_ids = [example["instance_prompt_ids"] for example in examples] 634 | pixel_values = [example["instance_images"] for example in examples] 635 | 636 | # Concat class and instance examples for prior preservation. 637 | # We do this to avoid doing two forward passes. 
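        # Resulting layout: the first train_batch_size entries are instance examples and the last
        # train_batch_size entries are class (prior) examples, so the UNet runs once on a batch of
        # size 2 * train_batch_size and the loss code later splits the prediction back apart with
        # torch.chunk(noise_pred, 2, dim=0).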
638 | if args.with_prior_preservation: 639 | input_ids += [example["class_prompt_ids"] for example in examples] 640 | pixel_values += [example["class_images"] for example in examples] 641 | 642 | pixel_values = torch.stack(pixel_values) 643 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 644 | 645 | input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids 646 | 647 | batch = { 648 | "input_ids": input_ids, 649 | "pixel_values": pixel_values, 650 | } 651 | return batch 652 | 653 | train_dataloader = torch.utils.data.DataLoader( 654 | train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 655 | ) 656 | 657 | # Scheduler and math around the number of training steps. 658 | overrode_max_train_steps = False 659 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 660 | if args.max_train_steps is None: 661 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 662 | overrode_max_train_steps = True 663 | 664 | lr_scheduler = get_scheduler( 665 | args.lr_scheduler, 666 | optimizer=optimizer, 667 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 668 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 669 | ) 670 | 671 | if args.train_text_encoder: 672 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 673 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 674 | ) 675 | else: 676 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 677 | unet, optimizer, train_dataloader, lr_scheduler 678 | ) 679 | 680 | weight_dtype = torch.float32 681 | if args.mixed_precision == "fp16": 682 | weight_dtype = torch.float16 683 | elif args.mixed_precision == "bf16": 684 | weight_dtype = torch.bfloat16 685 | 686 | # Move text_encode and vae to gpu. 687 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 688 | # as these models are only used for inference, keeping weights in full precision is not required. 689 | vae.to(accelerator.device, dtype=weight_dtype) 690 | if not args.train_text_encoder: 691 | text_encoder.to(accelerator.device, dtype=weight_dtype) 692 | 693 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 694 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 695 | if overrode_max_train_steps: 696 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 697 | # Afterwards we recalculate our number of training epochs 698 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 699 | 700 | # We need to initialize the trackers we use, and also store our configuration. 701 | # The trackers initializes automatically on the main process. 702 | if accelerator.is_main_process: 703 | accelerator.init_trackers("dreambooth", config=vars(args)) 704 | 705 | # Train! 
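    # Effective batch size per optimizer step = per-device batch size * number of processes *
    # gradient accumulation steps; e.g. the provided train_*_rect.sh scripts use
    # train_batch_size=1 and gradient_accumulation_steps=1, so with a single process that is
    # one image per optimizer step.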
706 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 707 | 708 | logger.info("***** Running training *****") 709 | logger.info(f" Num examples = {len(train_dataset)}") 710 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 711 | logger.info(f" Num Epochs = {args.num_train_epochs}") 712 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 713 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 714 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 715 | logger.info(f" Total optimization steps = {args.max_train_steps}") 716 | # Only show the progress bar once on each machine. 717 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 718 | progress_bar.set_description("Steps") 719 | global_step = 0 720 | 721 | for epoch in range(args.num_train_epochs): 722 | unet.train() 723 | if args.train_text_encoder: 724 | text_encoder.train() 725 | for step, batch in enumerate(train_dataloader): 726 | with accelerator.accumulate(unet): 727 | # Convert images to latent space 728 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 729 | latents = latents * 0.18215 730 | 731 | # Sample noise that we'll add to the latents 732 | noise = torch.randn_like(latents) 733 | bsz = latents.shape[0] 734 | # Sample a random timestep for each image 735 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 736 | timesteps = timesteps.long() 737 | 738 | # Add noise to the latents according to the noise magnitude at each timestep 739 | # (this is the forward diffusion process) 740 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 741 | 742 | # Get the text embedding for conditioning 743 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 744 | 745 | # Predict the noise residual 746 | # print(noisy_latents.shape) 747 | # print(encoder_hidden_states.shape) 748 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 749 | 750 | if args.with_prior_preservation: 751 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. 752 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) 753 | noise, noise_prior = torch.chunk(noise, 2, dim=0) 754 | 755 | # Compute instance loss 756 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() 757 | 758 | # Compute prior loss 759 | prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") 760 | 761 | # Add the prior loss to the instance loss. 
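                    # Combined DreamBooth objective:
                    #   loss = MSE(noise_pred_instance, noise_instance)
                    #          + prior_loss_weight * MSE(noise_pred_class, noise_class)
                    # With the default --prior_loss_weight=1.0 both terms count equally; a smaller
                    # weight relaxes the class-prior (regularization) term.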
762 | loss = loss + args.prior_loss_weight * prior_loss 763 | else: 764 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 765 | 766 | accelerator.backward(loss) 767 | if accelerator.sync_gradients: 768 | params_to_clip = ( 769 | itertools.chain(unet.parameters(), text_encoder.parameters()) 770 | if args.train_text_encoder 771 | else unet.parameters() 772 | ) 773 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 774 | optimizer.step() 775 | lr_scheduler.step() 776 | optimizer.zero_grad() 777 | 778 | # Checks if the accelerator has performed an optimization step behind the scenes 779 | if accelerator.sync_gradients: 780 | progress_bar.update(1) 781 | global_step += 1 782 | 783 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 784 | progress_bar.set_postfix(**logs) 785 | accelerator.log(logs, step=global_step) 786 | 787 | if global_step >= args.max_train_steps: 788 | break 789 | 790 | 791 | if args.save_model_every_n_steps != None and (global_step % args.save_model_every_n_steps) == 0: 792 | save_model(accelerator, unet, text_encoder, args, global_step) 793 | 794 | accelerator.wait_for_everyone() 795 | 796 | save_model(accelerator, unet, text_encoder, args, step=None) 797 | 798 | accelerator.end_training() 799 | 800 | 801 | if __name__ == "__main__": 802 | args = parse_args() 803 | main(args) 804 | -------------------------------------------------------------------------------- /tools/train_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import math 4 | import os 5 | import random 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import numpy as np 10 | # import torch 11 | import oneflow as torch 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | from torch.utils.data import Dataset 15 | 16 | import PIL 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import set_seed 20 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel 21 | from diffusers.optimization import get_scheduler 22 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 23 | from huggingface_hub import HfFolder, Repository, whoami 24 | from PIL import Image 25 | from torchvision import transforms 26 | from tqdm.auto import tqdm 27 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 28 | 29 | 30 | logger = get_logger(__name__) 31 | 32 | 33 | def save_progress(text_encoder, placeholder_token_id, accelerator, args): 34 | logger.info("Saving embeddings") 35 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 36 | learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} 37 | torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin")) 38 | 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 42 | parser.add_argument( 43 | "--save_steps", 44 | type=int, 45 | default=500, 46 | help="Save learned_embeds.bin every X updates steps.", 47 | ) 48 | parser.add_argument( 49 | "--pretrained_model_name_or_path", 50 | type=str, 51 | default=None, 52 | required=True, 53 | help="Path to pretrained model or model identifier from huggingface.co/models.", 54 | ) 55 | parser.add_argument( 56 | 
"--tokenizer_name", 57 | type=str, 58 | default=None, 59 | help="Pretrained tokenizer name or path if not the same as model_name", 60 | ) 61 | parser.add_argument( 62 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." 63 | ) 64 | parser.add_argument( 65 | "--placeholder_token", 66 | type=str, 67 | default=None, 68 | required=True, 69 | help="A token to use as a placeholder for the concept.", 70 | ) 71 | parser.add_argument( 72 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." 73 | ) 74 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") 75 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 76 | parser.add_argument( 77 | "--output_dir", 78 | type=str, 79 | default="text-inversion-model", 80 | help="The output directory where the model predictions and checkpoints will be written.", 81 | ) 82 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 83 | parser.add_argument( 84 | "--resolution", 85 | type=int, 86 | default=512, 87 | help=( 88 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 89 | " resolution" 90 | ), 91 | ) 92 | parser.add_argument( 93 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 94 | ) 95 | parser.add_argument( 96 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 97 | ) 98 | parser.add_argument("--num_train_epochs", type=int, default=100) 99 | parser.add_argument( 100 | "--max_train_steps", 101 | type=int, 102 | default=5000, 103 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 104 | ) 105 | parser.add_argument( 106 | "--gradient_accumulation_steps", 107 | type=int, 108 | default=1, 109 | help="Number of updates steps to accumulate before performing a backward/update pass.", 110 | ) 111 | parser.add_argument( 112 | "--learning_rate", 113 | type=float, 114 | default=1e-4, 115 | help="Initial learning rate (after the potential warmup period) to use.", 116 | ) 117 | parser.add_argument( 118 | "--scale_lr", 119 | action="store_true", 120 | default=True, 121 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 122 | ) 123 | parser.add_argument( 124 | "--lr_scheduler", 125 | type=str, 126 | default="constant", 127 | help=( 128 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 129 | ' "constant", "constant_with_warmup"]' 130 | ), 131 | ) 132 | parser.add_argument( 133 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
134 | ) 135 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 136 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 137 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 138 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 139 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 140 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 141 | parser.add_argument( 142 | "--hub_model_id", 143 | type=str, 144 | default=None, 145 | help="The name of the repository to keep in sync with the local `output_dir`.", 146 | ) 147 | parser.add_argument( 148 | "--logging_dir", 149 | type=str, 150 | default="logs", 151 | help=( 152 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 153 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 154 | ), 155 | ) 156 | parser.add_argument( 157 | "--mixed_precision", 158 | type=str, 159 | default="no", 160 | choices=["no", "fp16", "bf16"], 161 | help=( 162 | "Whether to use mixed precision. Choose" 163 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 164 | "and an Nvidia Ampere GPU." 165 | ), 166 | ) 167 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 168 | 169 | args = parser.parse_args() 170 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 171 | if env_local_rank != -1 and env_local_rank != args.local_rank: 172 | args.local_rank = env_local_rank 173 | 174 | if args.train_data_dir is None: 175 | raise ValueError("You must specify a train data directory.") 176 | 177 | return args 178 | 179 | 180 | imagenet_templates_small = [ 181 | "a photo of a {}", 182 | "a rendering of a {}", 183 | "a cropped photo of the {}", 184 | "the photo of a {}", 185 | "a photo of a clean {}", 186 | "a photo of a dirty {}", 187 | "a dark photo of the {}", 188 | "a photo of my {}", 189 | "a photo of the cool {}", 190 | "a close-up photo of a {}", 191 | "a bright photo of the {}", 192 | "a cropped photo of a {}", 193 | "a photo of the {}", 194 | "a good photo of the {}", 195 | "a photo of one {}", 196 | "a close-up photo of the {}", 197 | "a rendition of the {}", 198 | "a photo of the clean {}", 199 | "a rendition of a {}", 200 | "a photo of a nice {}", 201 | "a good photo of a {}", 202 | "a photo of the nice {}", 203 | "a photo of the small {}", 204 | "a photo of the weird {}", 205 | "a photo of the large {}", 206 | "a photo of a cool {}", 207 | "a photo of a small {}", 208 | ] 209 | 210 | imagenet_style_templates_small = [ 211 | "a painting in the style of {}", 212 | "a rendering in the style of {}", 213 | "a cropped painting in the style of {}", 214 | "the painting in the style of {}", 215 | "a clean painting in the style of {}", 216 | "a dirty painting in the style of {}", 217 | "a dark painting in the style of {}", 218 | "a picture in the style of {}", 219 | "a cool painting in the style of {}", 220 | "a close-up painting in the style of {}", 221 | "a bright painting in the style of {}", 222 | "a cropped painting in the style of {}", 223 | "a good painting in the style of {}", 224 | "a close-up painting in the style of {}", 225 | "a rendition in the style of {}", 226 | "a nice painting in the style of {}", 
227 | "a small painting in the style of {}", 228 | "a weird painting in the style of {}", 229 | "a large painting in the style of {}", 230 | ] 231 | 232 | 233 | class TextualInversionDataset(Dataset): 234 | def __init__( 235 | self, 236 | data_root, 237 | tokenizer, 238 | learnable_property="object", # [object, style] 239 | size=512, 240 | repeats=100, 241 | interpolation="bicubic", 242 | flip_p=0.5, 243 | set="train", 244 | placeholder_token="*", 245 | center_crop=False, 246 | ): 247 | self.data_root = data_root 248 | self.tokenizer = tokenizer 249 | self.learnable_property = learnable_property 250 | self.size = size 251 | self.placeholder_token = placeholder_token 252 | self.center_crop = center_crop 253 | self.flip_p = flip_p 254 | 255 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 256 | 257 | self.num_images = len(self.image_paths) 258 | self._length = self.num_images 259 | 260 | if set == "train": 261 | self._length = self.num_images * repeats 262 | 263 | self.interpolation = { 264 | "linear": PIL.Image.LINEAR, 265 | "bilinear": PIL.Image.BILINEAR, 266 | "bicubic": PIL.Image.BICUBIC, 267 | "lanczos": PIL.Image.LANCZOS, 268 | }[interpolation] 269 | 270 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 271 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 272 | 273 | def __len__(self): 274 | return self._length 275 | 276 | def __getitem__(self, i): 277 | example = {} 278 | image = Image.open(self.image_paths[i % self.num_images]) 279 | 280 | if not image.mode == "RGB": 281 | image = image.convert("RGB") 282 | 283 | placeholder_string = self.placeholder_token 284 | text = random.choice(self.templates).format(placeholder_string) 285 | 286 | example["input_ids"] = self.tokenizer( 287 | text, 288 | padding="max_length", 289 | truncation=True, 290 | max_length=self.tokenizer.model_max_length, 291 | return_tensors="pt", 292 | ).input_ids[0] 293 | 294 | # default to score-sde preprocessing 295 | img = np.array(image).astype(np.uint8) 296 | 297 | if self.center_crop: 298 | crop = min(img.shape[0], img.shape[1]) 299 | h, w, = ( 300 | img.shape[0], 301 | img.shape[1], 302 | ) 303 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 304 | 305 | image = Image.fromarray(img) 306 | image = image.resize((self.size, self.size), resample=self.interpolation) 307 | 308 | image = self.flip_transform(image) 309 | image = np.array(image).astype(np.uint8) 310 | image = (image / 127.5 - 1.0).astype(np.float32) 311 | 312 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 313 | return example 314 | 315 | 316 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 317 | if token is None: 318 | token = HfFolder.get_token() 319 | if organization is None: 320 | username = whoami(token)["name"] 321 | return f"{username}/{model_id}" 322 | else: 323 | return f"{organization}/{model_id}" 324 | 325 | 326 | def freeze_params(params): 327 | for param in params: 328 | param.requires_grad = False 329 | 330 | 331 | def main(): 332 | args = parse_args() 333 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 334 | 335 | accelerator = Accelerator( 336 | gradient_accumulation_steps=args.gradient_accumulation_steps, 337 | mixed_precision=args.mixed_precision, 338 | log_with="tensorboard", 339 | logging_dir=logging_dir, 340 | ) 341 | 342 | # If passed along, set the training seed 
now. 343 | if args.seed is not None: 344 | set_seed(args.seed) 345 | 346 | # Handle the repository creation 347 | if accelerator.is_main_process: 348 | if args.push_to_hub: 349 | if args.hub_model_id is None: 350 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 351 | else: 352 | repo_name = args.hub_model_id 353 | repo = Repository(args.output_dir, clone_from=repo_name) 354 | 355 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 356 | if "step_*" not in gitignore: 357 | gitignore.write("step_*\n") 358 | if "epoch_*" not in gitignore: 359 | gitignore.write("epoch_*\n") 360 | elif args.output_dir is not None: 361 | os.makedirs(args.output_dir, exist_ok=True) 362 | 363 | # Load the tokenizer and add the placeholder token as a additional special token 364 | if args.tokenizer_name: 365 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 366 | elif args.pretrained_model_name_or_path: 367 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 368 | 369 | # Add the placeholder token in tokenizer 370 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) 371 | if num_added_tokens == 0: 372 | raise ValueError( 373 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 374 | " `placeholder_token` that is not already in the tokenizer." 375 | ) 376 | 377 | # Convert the initializer_token, placeholder_token to ids 378 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 379 | # Check if initializer_token is a single token or a sequence of tokens 380 | if len(token_ids) > 1: 381 | raise ValueError("The initializer token must be a single token.") 382 | 383 | initializer_token_id = token_ids[0] 384 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 385 | 386 | # Load models and create wrapper for stable diffusion 387 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 388 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 389 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 390 | 391 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 392 | text_encoder.resize_token_embeddings(len(tokenizer)) 393 | 394 | # Initialise the newly added placeholder token with the embeddings of the initializer token 395 | token_embeds = text_encoder.get_input_embeddings().weight.data 396 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] 397 | 398 | # Freeze vae and unet 399 | freeze_params(vae.parameters()) 400 | freeze_params(unet.parameters()) 401 | # Freeze all parameters except for the token embeddings in text encoder 402 | params_to_freeze = itertools.chain( 403 | text_encoder.text_model.encoder.parameters(), 404 | text_encoder.text_model.final_layer_norm.parameters(), 405 | text_encoder.text_model.embeddings.position_embedding.parameters(), 406 | ) 407 | freeze_params(params_to_freeze) 408 | 409 | if args.scale_lr: 410 | args.learning_rate = ( 411 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 412 | ) 413 | 414 | # Initialize the optimizer 415 | optimizer = torch.optim.AdamW( 416 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 417 | lr=args.learning_rate, 418 | betas=(args.adam_beta1, 
args.adam_beta2), 419 | weight_decay=args.adam_weight_decay, 420 | eps=args.adam_epsilon, 421 | ) 422 | 423 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 424 | 425 | train_dataset = TextualInversionDataset( 426 | data_root=args.train_data_dir, 427 | tokenizer=tokenizer, 428 | size=args.resolution, 429 | placeholder_token=args.placeholder_token, 430 | repeats=args.repeats, 431 | learnable_property=args.learnable_property, 432 | center_crop=args.center_crop, 433 | set="train", 434 | ) 435 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) 436 | 437 | # Scheduler and math around the number of training steps. 438 | overrode_max_train_steps = False 439 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 440 | if args.max_train_steps is None: 441 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 442 | overrode_max_train_steps = True 443 | 444 | lr_scheduler = get_scheduler( 445 | args.lr_scheduler, 446 | optimizer=optimizer, 447 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 448 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 449 | ) 450 | 451 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 452 | text_encoder, optimizer, train_dataloader, lr_scheduler 453 | ) 454 | 455 | # Move vae and unet to device 456 | vae.to(accelerator.device) 457 | unet.to(accelerator.device) 458 | 459 | # Keep vae and unet in eval model as we don't train these 460 | vae.eval() 461 | unet.eval() 462 | 463 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 464 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 465 | if overrode_max_train_steps: 466 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 467 | # Afterwards we recalculate our number of training epochs 468 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 469 | 470 | # We need to initialize the trackers we use, and also store our configuration. 471 | # The trackers initializes automatically on the main process. 472 | if accelerator.is_main_process: 473 | accelerator.init_trackers("textual_inversion", config=vars(args)) 474 | 475 | # Train! 476 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 477 | 478 | logger.info("***** Running training *****") 479 | logger.info(f" Num examples = {len(train_dataset)}") 480 | logger.info(f" Num Epochs = {args.num_train_epochs}") 481 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 482 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 483 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 484 | logger.info(f" Total optimization steps = {args.max_train_steps}") 485 | # Only show the progress bar once on each machine. 
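    # In the loop below only the newly added placeholder token's embedding is actually updated:
    # after backward(), the gradients of every other row of the token-embedding matrix are zeroed
    # out before optimizer.step(), so the rest of the text encoder stays frozen.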
486 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 487 | progress_bar.set_description("Steps") 488 | global_step = 0 489 | 490 | for epoch in range(args.num_train_epochs): 491 | text_encoder.train() 492 | for step, batch in enumerate(train_dataloader): 493 | with accelerator.accumulate(text_encoder): 494 | # Convert images to latent space 495 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 496 | latents = latents * 0.18215 497 | 498 | # Sample noise that we'll add to the latents 499 | noise = torch.randn(latents.shape).to(latents.device) 500 | bsz = latents.shape[0] 501 | # Sample a random timestep for each image 502 | timesteps = torch.randint( 503 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device 504 | ).long() 505 | 506 | # Add noise to the latents according to the noise magnitude at each timestep 507 | # (this is the forward diffusion process) 508 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 509 | 510 | # Get the text embedding for conditioning 511 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 512 | 513 | # Predict the noise residual 514 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 515 | 516 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 517 | accelerator.backward(loss) 518 | 519 | # Zero out the gradients for all token embeddings except the newly added 520 | # embeddings for the concept, as we only want to optimize the concept embeddings 521 | # if accelerator.num_processes > 1: 522 | # grads = text_encoder.module.get_input_embeddings().weight.grad 523 | # else: 524 | # grads = text_encoder.get_input_embeddings().weight.grad 525 | grads = text_encoder.module.get_input_embeddings().weight.grad 526 | # Get the index for tokens that we want to zero the grads for 527 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id 528 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) 529 | 530 | optimizer.step() 531 | lr_scheduler.step() 532 | optimizer.zero_grad() 533 | 534 | # Checks if the accelerator has performed an optimization step behind the scenes 535 | if accelerator.sync_gradients: 536 | progress_bar.update(1) 537 | global_step += 1 538 | if global_step % args.save_steps == 0: 539 | save_progress(text_encoder, placeholder_token_id, accelerator, args) 540 | 541 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 542 | progress_bar.set_postfix(**logs) 543 | accelerator.log(logs, step=global_step) 544 | 545 | if global_step >= args.max_train_steps: 546 | break 547 | 548 | accelerator.wait_for_everyone() 549 | 550 | # Create the pipeline using using the trained modules and save it. 
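    # Two artifacts are saved: the full diffusers pipeline (save_pretrained into args.output_dir)
    # and, via save_progress(), a small learned_embeds.bin mapping args.placeholder_token to its
    # trained embedding vector. As noted in train_textual_inversion.sh, this embedding is intended
    # for use with diffusers-based inference.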
551 | if accelerator.is_main_process: 552 | pipeline = StableDiffusionPipeline( 553 | text_encoder=accelerator.unwrap_model(text_encoder), 554 | vae=vae, 555 | unet=unet, 556 | tokenizer=tokenizer, 557 | scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), 558 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), 559 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 560 | ) 561 | pipeline.save_pretrained(args.output_dir) 562 | # Also save the newly trained embeddings 563 | save_progress(text_encoder, placeholder_token_id, accelerator, args) 564 | 565 | if args.push_to_hub: 566 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 567 | 568 | accelerator.end_training() 569 | 570 | 571 | if __name__ == "__main__": 572 | main() -------------------------------------------------------------------------------- /tools/upload_cos.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # by ruochen 3 | # 需要先执行 pip install -U cos-python-sdk-v5 4 | from qcloud_cos import CosConfig 5 | from qcloud_cos import CosS3Client 6 | 7 | secret_id = 'abc123' # 替换为用户的 secretId 8 | secret_key = 'abc123' # 替换为用户的 secretKey 9 | region = 'ap-guangzhou' # 替换为用户的 Region 10 | 11 | config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key) 12 | client = CosS3Client(config) 13 | 14 | response = client.upload_file( 15 | Bucket='xxx', # 替换为存储桶名称 16 | LocalFilePath='../ckpt_models/newModel.ckpt', # 本地文件的路径 17 | Key='newModel.ckpt', # 上传之后的文件名 18 | ) 19 | print(response['ETag']) -------------------------------------------------------------------------------- /train_object.sh: -------------------------------------------------------------------------------- 1 | # 用于训练特定物体/人物的方法(只需单一标签) 2 | export MODEL_NAME="./model" 3 | export INSTANCE_DIR="./datasets/test2" 4 | export OUTPUT_DIR="./new_model" 5 | export CLASS_DIR="./datasets/class" # 用于存放模型生成的先验知识的图片文件夹,请勿改动 6 | export LOG_DIR="/root/tf-logs" 7 | export TEST_PROMPTS_FILE="./test_prompts_object.txt" 8 | 9 | rm -rf $CLASS_DIR/* # 如果你要训练与上次不同的特定物体/人物,需要先清空该文件夹。其他时候可以注释掉这一行(前面加#) 10 | rm -rf $LOG_DIR/* 11 | 12 | accelerate launch tools/train_dreambooth.py \ 13 | --train_text_encoder \ 14 | --pretrained_model_name_or_path=$MODEL_NAME \ 15 | --mixed_precision="fp16" \ 16 | --instance_data_dir=$INSTANCE_DIR \ 17 | --instance_prompt="a photo of dog" \ 18 | --with_prior_preservation --prior_loss_weight=1.0 \ 19 | --class_prompt="a photo of dog" \ 20 | --class_data_dir=$CLASS_DIR \ 21 | --num_class_images=200 \ 22 | --output_dir=$OUTPUT_DIR \ 23 | --logging_dir=$LOG_DIR \ 24 | --center_crop \ 25 | --resolution=512 \ 26 | --train_batch_size=1 \ 27 | --gradient_accumulation_steps=1 --gradient_checkpointing \ 28 | --use_8bit_adam \ 29 | --learning_rate=2e-6 \ 30 | --lr_scheduler="constant" \ 31 | --lr_warmup_steps=0 \ 32 | --auto_test_model \ 33 | --test_prompts_file=$TEST_PROMPTS_FILE \ 34 | --test_seed=123 \ 35 | --test_num_per_prompt=3 \ 36 | --max_train_steps=1000 \ 37 | --save_model_every_n_steps=500 38 | 39 | # 如果max_train_steps改大了,请记得把save_model_every_n_steps也改大 40 | # 不然磁盘很容易中间就满了 41 | 42 | # 以下是核心参数介绍: 43 | # 主要的几个 44 | # --train_text_encoder 训练文本编码器 45 | # --mixed_precision="fp16" 混合精度训练 46 | # - center_crop 47 | # 是否裁剪图片,一般如果你的数据集不是正方形的话,需要裁剪 48 | # - resolution 49 | # 图片的分辨率,一般是512,使用该参数会自动缩放输入图像 50 | # 
可以配合center_crop使用,达到裁剪成正方形并缩放到512*512的效果 51 | # - instance_prompt 52 | # 如果你希望训练的是特定的人物,使用该参数 53 | # 如 --instance_prompt="a photo of girl" 54 | # - class_prompt 55 | # 如果你希望训练的是某个特定的类别,使用该参数可能提升一定的训练效果 56 | # - use_txt_as_label 57 | # 是否读取与图片同名的txt文件作为label 58 | # 如果你要训练的是整个大模型的图像风格,那么可以使用该参数 59 | # 该选项会忽略instance_prompt参数传入的内容 60 | # - learning_rate 61 | # 学习率,一般是2e-6,是训练中需要调整的关键参数 62 | # 太大会导致模型不收敛,太小的话,训练速度会变慢 63 | # - lr_scheduler, 可选项有constant, linear, cosine, cosine_with_restarts, cosine_with_hard_restarts 64 | # 学习率调整策略,一般是constant,即不调整,如果你的数据集很大,可以尝试其他的,但是可能会导致模型不收敛,需要调整学习率 65 | # - lr_warmup_steps,如果你使用的是constant,那么这个参数可以忽略, 66 | # 如果使用其他的,那么这个参数可以设置为0,即不使用warmup 67 | # 也可以设置为其他的值,比如1000,即在前1000个step中,学习率从0慢慢增加到learning_rate的值 68 | # 一般不需要设置, 除非你的数据集很大,训练收敛很慢 69 | # - max_train_steps 70 | # 训练的最大步数,一般是1000,如果你的数据集比较大,那么可以适当增大该值 71 | # - save_model_every_n_steps 72 | # 每多少步保存一次模型,方便查看中间训练的结果找出最优的模型,也可以用于断点续训 73 | 74 | # --with_prior_preservation,--prior_loss_weight=1.0,分别是使用先验知识保留和先验损失权重 75 | # 如果你的数据样本比较少,那么可以使用这两个参数,可以提升训练效果,还可以防止过拟合(即生成的图片与训练的图片相似度过高) 76 | 77 | # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt 78 | # 分别是自动测试模型(每save_model_every_n_steps步后)、测试的文本、随机种子、每个文本测试的次数 79 | # 测试的样本图片会保存在模型输出目录下的test文件夹中 -------------------------------------------------------------------------------- /train_object_rect.sh: -------------------------------------------------------------------------------- 1 | # 主要用于训练风格、作画能力(需要每张图片都有对应的标签描述) 2 | export MODEL_NAME="./model" 3 | export INSTANCE_DIR="./datasets/a1" 4 | export OUTPUT_DIR="./new_model" 5 | export CLASS_DIR="./datasets/class" # 用于存放模型生成的先验知识的图片文件夹,请勿改动 6 | export LOG_DIR="/root/tf-logs" 7 | export TEST_PROMPTS_FILE="./test_prompts_object.txt" 8 | 9 | rm -rf $LOG_DIR/* 10 | 11 | accelerate launch tools/train_dreambooth_rect.py \ 12 | --pretrained_model_name_or_path=$MODEL_NAME \ 13 | --mixed_precision="fp16" \ 14 | --instance_data_dir=$INSTANCE_DIR \ 15 | --use_txt_as_label \ 16 | --output_dir=$OUTPUT_DIR \ 17 | --logging_dir=$LOG_DIR \ 18 | --width=768 \ 19 | --height=512 \ 20 | --train_batch_size=1 \ 21 | --use_8bit_adam \ 22 | --gradient_accumulation_steps=1 --gradient_checkpointing \ 23 | --learning_rate=2e-6 \ 24 | --lr_scheduler="constant" \ 25 | --lr_warmup_steps=0 \ 26 | --max_train_steps=1000 \ 27 | --save_model_every_n_steps=500 \ 28 | --auto_test_model \ 29 | --test_prompts_file=$TEST_PROMPTS_FILE \ 30 | --test_seed=123 \ 31 | --test_num_per_prompt=3 32 | 33 | 34 | # 如果max_train_steps改大了,请记得把save_model_every_n_steps也改大 35 | # 不然磁盘很容易中间就满了 36 | 37 | # 以下是核心参数介绍: 38 | # 主要的几个 39 | # --train_text_encoder 训练文本编码器 40 | # --mixed_precision="fp16" 混合精度训练 41 | # - instance_prompt 42 | # 如果你希望训练的是特定的人物,使用该参数 43 | # 如 --instance_prompt="a photo of girl" 44 | # - use_txt_as_label 45 | # 是否读取与图片同名的txt文件作为label 46 | # 如果你要训练的是整个大模型的图像风格,那么可以使用该参数 47 | # 该选项会忽略instance_prompt参数传入的内容 48 | # - learning_rate 49 | # 学习率,一般是2e-6,是训练中需要调整的关键参数 50 | # 太大会导致模型不收敛,太小的话,训练速度会变慢 51 | # - max_train_steps 52 | # 训练的最大步数,一般是1000,如果你的数据集比较大,那么可以适当增大该值 53 | # - save_model_every_n_steps 54 | # 每多少步保存一次模型,方便查看中间训练的结果找出最优的模型,也可以用于断点续训 55 | 56 | # --train_text_encoder # 除了图像生成器,也训练文本编码器 57 | 58 | # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt 59 | # 分别是自动测试模型(每save_model_every_n_steps步后)、测试的文本、随机种子、每个文本测试的次数 60 | # 测试的样本图片会保存在模型输出目录下的test文件夹中 -------------------------------------------------------------------------------- /train_style.sh: 
-------------------------------------------------------------------------------- 1 | # 主要用于训练风格、作画能力(需要每张图片都有对应的标签描述) 2 | export MODEL_NAME="./model" 3 | export INSTANCE_DIR="./datasets/test2" 4 | export OUTPUT_DIR="./new_model" 5 | export LOG_DIR="/root/tf-logs" 6 | export TEST_PROMPTS_FILE="./test_prompts_style.txt" 7 | 8 | rm -rf $LOG_DIR/* 9 | 10 | accelerate launch tools/train_dreambooth.py \ 11 | --pretrained_model_name_or_path=$MODEL_NAME \ 12 | --mixed_precision="fp16" \ 13 | --instance_data_dir=$INSTANCE_DIR \ 14 | --use_txt_as_label \ 15 | --output_dir=$OUTPUT_DIR \ 16 | --logging_dir=$LOG_DIR \ 17 | --center_crop \ 18 | --resolution=768 \ 19 | --train_batch_size=1 \ 20 | --use_8bit_adam \ 21 | --gradient_accumulation_steps=1 --gradient_checkpointing \ 22 | --learning_rate=2e-6 \ 23 | --lr_scheduler="constant" \ 24 | --lr_warmup_steps=0 \ 25 | --max_train_steps=1000 \ 26 | --save_model_every_n_steps=500 \ 27 | --auto_test_model \ 28 | --test_prompts_file=$TEST_PROMPTS_FILE \ 29 | --test_seed=123 \ 30 | --test_num_per_prompt=3 31 | 32 | # 如果max_train_steps改大了,请记得把save_model_every_n_steps也改大,不然磁盘容易中间就满了 33 | 34 | # 以下是核心参数介绍: 35 | # 主要的几个 36 | # --train_text_encoder 训练文本编码器 37 | # --mixed_precision="fp16" 混合精度训练 38 | # - center_crop 39 | # 是否裁剪图片,一般如果你的数据集不是正方形的话,需要裁剪 40 | # - resolution 41 | # 图片的分辨率,一般是512,使用该参数会自动缩放输入图像 42 | # 可以配合center_crop使用,达到裁剪成正方形并缩放到512*512的效果 43 | # - instance_prompt 44 | # 如果你希望训练的是特定的人物,使用该参数 45 | # 如 --instance_prompt="a photo of girl" 46 | # - use_txt_as_label 47 | # 是否读取与图片同名的txt文件作为label 48 | # 如果你要训练的是整个大模型的图像风格,那么可以使用该参数 49 | # 该选项会忽略instance_prompt参数传入的内容 50 | # - learning_rate 51 | # 学习率,一般是2e-6,是训练中需要调整的关键参数 52 | # 太大会导致模型不收敛,太小的话,训练速度会变慢 53 | # - max_train_steps 54 | # 训练的最大步数,一般是1000,如果你的数据集比较大,那么可以适当增大该值 55 | # - save_model_every_n_steps 56 | # 每多少步保存一次模型,方便查看中间训练的结果找出最优的模型,也可以用于断点续训 57 | 58 | # --train_text_encoder # 除了图像生成器,也训练文本编码器 59 | 60 | # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt 61 | # 分别是自动测试模型(每save_model_every_n_steps步后)、测试的文本、随机种子、每个文本测试的次数 62 | # 测试的样本图片会保存在模型输出目录下的test文件夹中 -------------------------------------------------------------------------------- /train_style_rect.sh: -------------------------------------------------------------------------------- 1 | # 主要用于训练风格、作画能力(需要每张图片都有对应的标签描述)[矩形输入图片] 2 | export MODEL_NAME="./model" 3 | export INSTANCE_DIR="./datasets/a1" 4 | export OUTPUT_DIR="./new_model" 5 | export LOG_DIR="/root/tf-logs" 6 | export TEST_PROMPTS_FILE="./test_prompts_style.txt" 7 | 8 | rm -rf $LOG_DIR/* 9 | 10 | accelerate launch tools/train_dreambooth_rect.py \ 11 | --pretrained_model_name_or_path=$MODEL_NAME \ 12 | --mixed_precision="fp16" \ 13 | --instance_data_dir=$INSTANCE_DIR \ 14 | --use_txt_as_label \ 15 | --output_dir=$OUTPUT_DIR \ 16 | --logging_dir=$LOG_DIR \ 17 | --width=512 \ 18 | --height=768 \ 19 | --train_batch_size=1 \ 20 | --use_8bit_adam \ 21 | --gradient_accumulation_steps=1 --gradient_checkpointing \ 22 | --learning_rate=2e-6 \ 23 | --lr_scheduler="constant" \ 24 | --lr_warmup_steps=0 \ 25 | --max_train_steps=1000 \ 26 | --save_model_every_n_steps=500 \ 27 | --auto_test_model \ 28 | --test_prompts_file=$TEST_PROMPTS_FILE \ 29 | --test_seed=123 \ 30 | --test_num_per_prompt=3 31 | 32 | 33 | # 如果max_train_steps改大了,请记得把save_model_every_n_steps也改大 34 | # 不然磁盘很容易中间就满了 35 | 36 | # 以下是核心参数介绍: 37 | # 主要的几个 38 | # --train_text_encoder 训练文本编码器 39 | # --mixed_precision="fp16" 混合精度训练 40 | # - instance_prompt 41 | # 如果你希望训练的是特定的人物,使用该参数 42 | # 如 --instance_prompt="a photo 
of girl" 43 | # - use_txt_as_label 44 | # 是否读取与图片同名的txt文件作为label 45 | # 如果你要训练的是整个大模型的图像风格,那么可以使用该参数 46 | # 该选项会忽略instance_prompt参数传入的内容 47 | # - learning_rate 48 | # 学习率,一般是2e-6,是训练中需要调整的关键参数 49 | # 太大会导致模型不收敛,太小的话,训练速度会变慢 50 | # - max_train_steps 51 | # 训练的最大步数,一般是1000,如果你的数据集比较大,那么可以适当增大该值 52 | # - save_model_every_n_steps 53 | # 每多少步保存一次模型,方便查看中间训练的结果找出最优的模型,也可以用于断点续训 54 | 55 | # --train_text_encoder # 除了图像生成器,也训练文本编码器 56 | 57 | # --auto_test_model, --test_prompts_file, --test_seed, --test_num_per_prompt 58 | # 分别是自动测试模型(每save_model_every_n_steps步后)、测试的文本、随机种子、每个文本测试的次数 59 | # 测试的样本图片会保存在模型输出目录下的test文件夹中 -------------------------------------------------------------------------------- /train_textual_inversion.sh: -------------------------------------------------------------------------------- 1 | # 这是另一种finetune模型的方法,名为textual inversion,效果一般,仅内置一份供参考。 2 | # 提示:该方法训练出的概念编码只能在diffusers使用。暂时不支持在diffusers之外的推理框架使用。(如webui) 3 | #!/bin/bash 4 | export LOG_DIR="/root/tf-logs" 5 | 6 | accelerate launch ./tools/train_textual_inversion.py \ 7 | --pretrained_model_name_or_path="./model/" \ 8 | --train_data_dir="./datasets/test" \ 9 | --learnable_property="style" \ 10 | --placeholder_token="" --initializer_token="girl" \ 11 | --resolution=512 \ 12 | --train_batch_size=1 \ 13 | --gradient_accumulation_steps=4 \ 14 | --learning_rate=5.0e-04 --scale_lr \ 15 | --lr_scheduler="constant" \ 16 | --lr_warmup_steps=0 \ 17 | --save_steps=200 \ 18 | --max_train_steps=3000 \ 19 | --mixed_precision="fp16" \ 20 | --logging_dir=$LOG_DIR \ 21 | --output_dir="output_model" 22 | 23 | # --learnable_property为style时训练特定风格,为object时训练特定物体/人物。 24 | # --placeholder_token为训练时的占位符,--initializer_token为训练时的初始化词。 25 | # --resolution为训练时的分辨率,--train_batch_size为训练时的batch size,--gradient_accumulation_steps为梯度累积步数。 26 | # --learning_rate为训练时的学习率,--scale_lr为是否对学习率进行缩放,--lr_scheduler为学习率调度器,--lr_warmup_steps为学习率预热步数。 27 | # --save_steps为保存模型的步数,--max_train_steps为最大训练步数,--mixed_precision为混合精度训练模式。 28 | # --logging_dir为日志保存路径,--output_dir为模型保存路径。 29 | # --pretrained_model_name_or_path为预训练模型路径,--train_data_dir为训练数据路径,必须为文件夹,文件夹内为处理后的图片。 -------------------------------------------------------------------------------- /运行.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "a0b34c19-4215-46f9-9def-65e73629665c", 7 | "metadata": {}, 8 | "source": [ 9 | "# Dreambooth Stable Diffusion 集成化环境训练\n", 10 | "如果你是在autodl上的机器可以直接使用封装好的镜像创建实例,开箱即用 \n", 11 | "如果是本地或者其他服务器上也可以使用,需要手动安装一些pip包\n", 12 | "\n", 13 | "## 注意\n", 14 | "本项目仅供用于学习、测试人工智能技术使用 \n", 15 | "请勿用于训练生成不良或侵权图片内容\n", 16 | "\n", 17 | "## 关于项目\n", 18 | "在autodl封装的镜像名称为:dreambooth-for-diffusion \n", 19 | "可在创建实例时直接选择公开的算法镜像使用。 \n", 20 | "在autodl内蒙A区A5000的机器上封装,如遇到问题且无法自行解决的朋友请使用同一环境。 \n", 21 | "白菜写教程时做了尽可能多的测试,但仍然无法确保每一个环节都完全覆盖 \n", 22 | "如有小错误可尝试手动解决,或者访问git项目地址查看最新的README \n", 23 | "项目地址:https://github.com/CrazyBoyM/dreambooth-for-diffusion\n", 24 | "\n", 25 | "## 强烈建议\n", 26 | "1.用vscode的ssh功能远程连接到本服务器,训练体验更好,autodl自带的notebook也不错,有文件上传、下载功能。 \n", 27 | "(vscode连接autodl教程:https://www.autodl.com/docs/vscode/ ) \n", 28 | "### 2.(重要)把dreambooth-for-diffusion文件夹整个移动到/root/autodl-tmp/路径下进行训练(数据盘),避免系统盘空间满\n", 29 | "有的机器数据盘也很小,需要自行关注开合适的机器或进行扩容\n", 30 | "\n", 31 | "如果遇到问题可到b站主页找该教程对应训练演示的视频:https://space.bilibili.com/291593914\n", 32 | "(因为现在写时视频还没做) \n", 33 | "\n", 34 | "## 服务器的数据迁移\n", 35 | "经常关机后再开机发现机器资源被占用了,这时候你只能另外开一台机器了 \n", 36 | "但是对于已经关机的机器在菜单上有个功能是“跨实例拷贝数据”, \n", 37 | 
"可以很方便地同步/root/autodl-tmp文件夹下的内容到其他已开机的机器(所以推荐工作文件都放这) \n", 38 | "(注意,只适用于同一区域的机器之间)\n", 39 | "数据迁移教程:https://www.autodl.com/docs/migrate_instance/" 40 | ] 41 | }, 42 | { 43 | "attachments": {}, 44 | "cell_type": "markdown", 45 | "id": "f091e609-bacc-469a-b6cf-bffe331a8944", 46 | "metadata": {}, 47 | "source": [ 48 | "### 本文件为notebook在线运行版\n", 49 | "具体详细的教程和参数说明请在根目录下教程.md 文件中查看。 \n", 50 | "在notebook中执行linux命令,需要前面加个!(感叹号) \n", 51 | "代码块前如果有个[*],表示正在运行该步骤,并不是卡住了\n" 52 | ] 53 | }, 54 | { 55 | "attachments": {}, 56 | "cell_type": "markdown", 57 | "id": "3555d8bd-fb3f-4303-8915-ec6fefcc780c", 58 | "metadata": {}, 59 | "source": [ 60 | "# 笔者前言\n", 61 | "\n", 62 | "linux压缩一个文件夹为单个文件包的命令:\n", 63 | "```\n", 64 | "!zip xx.zip -r ./xxx\n", 65 | "```\n", 66 | "解压一个包到文件夹:\n", 67 | "```\n", 68 | "!unzip xx.zip -d xxx\n", 69 | "```\n", 70 | "或许你在上传、下载数据集时会用到。\n", 71 | "\n", 72 | "其他linux基础命令:https://www.autodl.com/docs/linux/\n", 73 | "\n", 74 | "关于文件上传下载的提速可查看官网文档推荐的几种方式:https://www.autodl.com/docs/scp/" 75 | ] 76 | }, 77 | { 78 | "attachments": {}, 79 | "cell_type": "markdown", 80 | "id": "34cf6ed1-f2b1-4abd-baf6-565ac00567ab", 81 | "metadata": {}, 82 | "source": [ 83 | "### 首先,进入工作文件夹(记得先把dreambooth-for-diffusion文件夹移动到autodl-tmp目录下)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "a1249a32-ce15-4b1b-8068-8149ad40588b", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "%cd /root/autodl-tmp/dreambooth-for-diffusion" 94 | ] 95 | }, 96 | { 97 | "attachments": {}, 98 | "cell_type": "markdown", 99 | "id": "ccba0e31-f01d-43e5-b474-7d88e0b09bd8", 100 | "metadata": {}, 101 | "source": [ 102 | "# 准备数据集\n", 103 | "该部分请参考教程.md文件中的详细内容自行上传并处理你的数据集 \n", 104 | "dreambooth-for-diffusion/datasets/test中为16张仅供于学习测试的样本数据,便于你了解以下代码的用处 \n" 105 | ] 106 | }, 107 | { 108 | "attachments": {}, 109 | "cell_type": "markdown", 110 | "id": "470113f6-795a-41f8-a6b3-09f854a4cbc3", 111 | "metadata": {}, 112 | "source": [ 113 | "## 一键裁剪\n", 114 | "### 图像批量center crop裁剪(正方形裁剪)\n", 115 | "./datasets/test是原始图片数据文件夹,请上传你的图片数据并进行更换 \n", 116 | "width和height请设置为8的整倍数,并记得修改训练脚本中的参数 \n", 117 | "(在显存低于20G的设备上请修改使用小于768的分辨率数据去训练,比如512) \n", 118 | "如果是对透明底的png图处理成纯色底可以加--png参数,具体可以看对应的代码文件" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 18, 124 | "id": "8b696a59", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "!python tools/handle_images.py --origin_image_path ./datasets/test \\\n", 129 | " --output_image_path ./datasets/test2 \\\n", 130 | " --width=768 --height=768" 131 | ] 132 | }, 133 | { 134 | "attachments": {}, 135 | "cell_type": "markdown", 136 | "id": "2633307a", 137 | "metadata": {}, 138 | "source": [ 139 | "[可选] 保留更高质量的裁剪(矩形裁剪) \n", 140 | "-- 不需要修改width和height,自适应" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 20, 146 | "id": "10d2bb3d-9002-4d3b-a4be-f5f74a008b9c", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "!python tools/handle_images.py --origin_image_path ./datasets/test \\\n", 151 | " --output_image_path_0 ./datasets/a1 --output_image_path_1 ./datasets/a2 \\\n", 152 | " --width=768 --height=512 " 153 | ] 154 | }, 155 | { 156 | "attachments": {}, 157 | "cell_type": "markdown", 158 | "id": "34efda73-9cb4-4a54-8aac-489ded452a50", 159 | "metadata": {}, 160 | "source": [ 161 | "## 一键打标签\n", 162 | "### 图像批量自动标注\n", 163 | "使用deepdanbooru生成tags标注文件。(仅针对纯二次元类图片效果较好,其他风格请手动标注) \n", 164 | "./datasets/test2中是需要打标注的图片数据,请按需更换为自己的路径 " 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 
169 | "execution_count": 2, 170 | "id": "8863a53a-4650-4f27-863e-2a70e8b89e11", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "# 该步根据需要标注文件数量不同,需要运行一段时间(测试6000张图片需要10分钟)\n", 175 | "!python tools/label_images.py --path=./datasets/test2 --model_path='./deepdanbooru/'" 176 | ] 177 | }, 178 | { 179 | "attachments": {}, 180 | "cell_type": "markdown", 181 | "id": "def72b19-9851-400f-8672-48023b3e95fb", 182 | "metadata": {}, 183 | "source": [ 184 | "## 转换ckpt检查点文件为diffusers官方权重\n", 185 | "输出的文件在dreambooth-for-diffusion/model下 \n", 186 | "./ckpt_models/sd_1-5.ckpt需要更换为你自己的权重文件路径 " 187 | ] 188 | }, 189 | { 190 | "attachments": {}, 191 | "cell_type": "markdown", 192 | "id": "0582e3c4-e899-4a3b-a468-d49e7775efc6", 193 | "metadata": {}, 194 | "source": [ 195 | "如需转换写实风格模型:" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "05aaf7fd-315f-45b4-9b22-70a46a18424f", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "# 该步需要运行大约一分钟 \n", 206 | "!python tools/ckpt2diffusers.py \\\n", 207 | " --checkpoint_path=./ckpt_models/sd_1-5.ckpt \\\n", 208 | " --dump_path=./model \\\n", 209 | " --original_config_file=./ckpt_models/model.yaml \\\n", 210 | " --scheduler_type=\"ddim\"" 211 | ] 212 | }, 213 | { 214 | "attachments": {}, 215 | "cell_type": "markdown", 216 | "id": "48c7893a-22db-4ea2-95dc-93fdbd6b5c4b", 217 | "metadata": {}, 218 | "source": [ 219 | "如需转换二次元风格模型:" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "f7afb70d-7af4-4bd1-804e-40927f1257e2", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "# 该步需要运行大约一分钟 \n", 230 | "!python tools/ckpt2diffusers.py \\\n", 231 | " --checkpoint_path=./ckpt_models/nd_lastest.ckpt \\\n", 232 | " --dump_path=./model \\\n", 233 | " --vae_path=./ckpt_models/animevae.pt \\\n", 234 | " --original_config_file=./ckpt_models/model.yaml \\\n", 235 | " --scheduler_type=\"ddim\"" 236 | ] 237 | }, 238 | { 239 | "attachments": {}, 240 | "cell_type": "markdown", 241 | "id": "a1edb9be-1de3-488e-baa3-8f3ab6b8f269", 242 | "metadata": {}, 243 | "source": [ 244 | "对于需要转换某个特殊模型(7g)并遇到报错的同学,ckpt_models里的nd_lastest.ckpt就是你需要的文件。 \n", 245 | "如果希望手动转换,我在./tools下放了一份ckpt_prune.py可以参考。" 246 | ] 247 | }, 248 | { 249 | "attachments": {}, 250 | "cell_type": "markdown", 251 | "id": "3a3470d3-1691-438c-b8d7-df2cbf885614", 252 | "metadata": {}, 253 | "source": [ 254 | "# 训练Unet和text encoder\n", 255 | "以下训练脚本会自动帮你启动tensorboard日志监控进程,入口可参考: https://www.autodl.com/docs/tensorboard/ \n", 256 | "使用tensorboard面板可以帮助分析loss在不同step的总体下降情况 \n", 257 | "如果你嫌输出太长,可以在以下命令每一行后加一句 &> log.txt, 会把输出都扔到这个文件中 \n", 258 | "```\n", 259 | "!sh train_style.sh &> log.txt\n", 260 | "```\n", 261 | "本代码包环境已在A5000、3090测试通过,如果你在某些机器上运行遇到问题可以尝试卸载编译的xformers\n", 262 | "```\n", 263 | "!pip uninstall xformers\n", 264 | "```" 265 | ] 266 | }, 267 | { 268 | "attachments": {}, 269 | "cell_type": "markdown", 270 | "id": "98645b45-4cf1-49f8-b2bb-42a5a8771164", 271 | "metadata": {}, 272 | "source": [ 273 | "### 如果需要训练特定人、事物: \n", 274 | "(推荐准备3~5张风格统一、特定对象的图片) \n", 275 | "请打开train_object.sh具体修改里面的参数" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "id": "8b6833e3-8d3f-438a-b45d-0711e9724496", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "# 大约十分钟后才会在tensorboard有日志(因为前十分钟在生成同类别伪图)\n", 286 | "!sh train_object.sh " 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "c7b8f2e4", 293 | "metadata": {}, 
294 | "outputs": [], 295 | "source": [ 296 | "# 选择矩形图像数据集训练\n", 297 | "!sh train_object_rect.sh --width 512 --height 768 ## " 298 | ] 299 | }, 300 | { 301 | "attachments": {}, 302 | "cell_type": "markdown", 303 | "id": "594a0352-8bb5-45de-bb19-0028b671569b", 304 | "metadata": {}, 305 | "source": [ 306 | "### 如果要训练画风: \n", 307 | "(推荐准备3000+张图片,包含尽可能的多样性,数据决定训练出的模型质量) \n", 308 | "请打开train_style.sh具体修改里面的参数 " 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 3, 314 | "id": "442cff33-d264-4096-97e2-0c578229c814", 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "# 正常训练立刻就可以在tensorboard看到日志\n", 319 | "# 如果输入图像是正方形\n", 320 | "!sh train_style.sh " 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "f387b807", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "# 如果输入图像是统一尺寸的矩形\n", 331 | "!sh train_style_rect.sh --width 768 --height 512 " 332 | ] 333 | }, 334 | { 335 | "attachments": {}, 336 | "cell_type": "markdown", 337 | "id": "3aa1d170-e2d1-4f72-8b0c-b6bfd5f0c318", 338 | "metadata": {}, 339 | "source": [ 340 | "后台训练法请参考教程.md中的内容" 341 | ] 342 | }, 343 | { 344 | "attachments": {}, 345 | "cell_type": "markdown", 346 | "id": "9aece8a8-c9ec-41eb-b6ad-c6c88b6203e1", 347 | "metadata": {}, 348 | "source": [ 349 | "省钱训练法(训练成功后自动关机,适合步数很大且夜晚训练的场景)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 6, 355 | "id": "52fff58d-1a88-4a59-a961-b13b52812425", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "!sh back_train.sh" 360 | ] 361 | }, 362 | { 363 | "attachments": {}, 364 | "cell_type": "markdown", 365 | "id": "17557280-3a5a-4bde-95c3-f20e1ccffa4d", 366 | "metadata": {}, 367 | "source": [ 368 | "## 拓展:训练Textual inversion\n" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "36a543ee-56f8-405a-baaa-b784d96c7d40", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "!sh train_textual_inversion.sh" 379 | ] 380 | }, 381 | { 382 | "attachments": {}, 383 | "cell_type": "markdown", 384 | "id": "f467b2e9-9170-4f19-aea9-7ce0b4e5444e", 385 | "metadata": {}, 386 | "source": [ 387 | "### 测试训练效果\n", 388 | "打开dreambooth-for-diffusion/test_model.py文件,修改其中的model_path和prompt,然后执行以下测试 \n", 389 | "会在左侧生成测试图片 test-1、2、3.png" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 4, 395 | "id": "b462f33b-48e2-4092-b3de-463025e4ff9e", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "# 大约5~10s \n", 400 | "!python test_model.py" 401 | ] 402 | }, 403 | { 404 | "attachments": {}, 405 | "cell_type": "markdown", 406 | "id": "47abb5fd-2f84-4344-a9cf-539b52515971", 407 | "metadata": {}, 408 | "source": [ 409 | "### 转换diffusers官方权重为ckpt检查点文件\n", 410 | "输出的文件在dreambooth-for-diffusion/ckpt_models/中,名为newModel.ckpt" 411 | ] 412 | }, 413 | { 414 | "attachments": {}, 415 | "cell_type": "markdown", 416 | "id": "5bfe9643-ef1d-42a3-a427-c4904f3a8631", 417 | "metadata": {}, 418 | "source": [ 419 | "原始保存:" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "id": "2ad27225-10ed-4b3c-9978-bd909404949c", 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel.ckpt " 430 | ] 431 | }, 432 | { 433 | "attachments": {}, 434 | "cell_type": "markdown", 435 | "id": "b08a5e37-97d3-4c1e-9ba7-e331af23437f", 436 | "metadata": {}, 437 | "source": [ 438 | "以下代码添加--half 
保存float16半精度,权重文件大小会减半(约2g),效果基本一致" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "id": "cba99145-6aab-41b6-a5b7-6e0c4fd96641", 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "!python tools/diffusers2ckpt.py ./new_model ./ckpt_models/newModel_half.ckpt --half" 449 | ] 450 | }, 451 | { 452 | "attachments": {}, 453 | "cell_type": "markdown", 454 | "id": "d1f98d06-27f3-45b6-85df-c57cda5d6166", 455 | "metadata": {}, 456 | "source": [ 457 | "下载ckpt文件,去玩吧~" 458 | ] 459 | }, 460 | { 461 | "attachments": {}, 462 | "cell_type": "markdown", 463 | "id": "b13f0627-1d0a-4ae2-ab9c-90a605ee4a0e", 464 | "metadata": {}, 465 | "source": [ 466 | "有问题可以进XDiffusion QQ Group:455521885 " 467 | ] 468 | }, 469 | { 470 | "attachments": {}, 471 | "cell_type": "markdown", 472 | "id": "b939a03f-23c9-410d-89be-02e154eeb6b4", 473 | "metadata": {}, 474 | "source": [ 475 | "### 记得定期清理不需要的中间权重和文件,不然容易导致空间满\n", 476 | "大部分问题已在教程.md中详细记录,也包含其他非autodl机器手动部署该训练一体化封装代码包的步骤" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "3236d62e-fa3d-4826-874e-431f208cfb6d", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "# 清理文件的示例\n", 487 | "!rm -rf ./model* # 删除当前目录model文件/文件夹\n", 488 | "!rm -rf ./new_* # 删除当前目录所有new_开头的模型文件夹\n", 489 | "# !rm -rf ./datasets/test2 #删除datasets中的test2数据集 " 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "id": "224924ae-2d6d-47d0-aa36-0989a6572bd2", 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [] 499 | } 500 | ], 501 | "metadata": { 502 | "kernelspec": { 503 | "display_name": "Python 3 (ipykernel)", 504 | "language": "python", 505 | "name": "python3" 506 | }, 507 | "language_info": { 508 | "codemirror_mode": { 509 | "name": "ipython", 510 | "version": 3 511 | }, 512 | "file_extension": ".py", 513 | "mimetype": "text/x-python", 514 | "name": "python", 515 | "nbconvert_exporter": "python", 516 | "pygments_lexer": "ipython3", 517 | "version": "3.9.12" 518 | } 519 | }, 520 | "nbformat": 4, 521 | "nbformat_minor": 5 522 | } 523 | --------------------------------------------------------------------------------
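
## 附:test_model.py 的参考写法(草图)
上文的README与notebook都提到“打开test_model.py修改其中的model_path和prompt”,但test_model.py的内容未在本文中展示。下面只是基于diffusers常规StableDiffusionPipeline用法的一个最小草图,其中的model_path、prompt、输出文件名均为示意假设,实际请以仓库内的test_model.py为准:
```python
# 最小草图:加载训练得到的diffusers权重,按prompt生成测试图
# 假设:model_path、prompt、输出文件名均为示意,请按实际情况修改
import torch
from diffusers import StableDiffusionPipeline

model_path = "./new_model"                   # 训练输出的diffusers权重目录
prompt = "1girl, masterpiece, best quality"  # 换成你自己的测试描述

pipe = StableDiffusionPipeline.from_pretrained(
    model_path, torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# 固定随机种子,连续生成3张,分别保存为test-1.png、test-2.png、test-3.png
generator = torch.Generator("cuda").manual_seed(123)
for i in range(3):
    image = pipe(
        prompt, num_inference_steps=30, guidance_scale=7.5, generator=generator
    ).images[0]
    image.save(f"test-{i + 1}.png")
```
生成的test-1、2、3.png与notebook中描述的输出位置对应;显存不足时可去掉torch_dtype=torch.float16并改到CPU上跑,但速度会慢很多。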
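
## 附:在diffusers中加载textual inversion概念编码(草图)
train_textual_inversion.sh的注释里提到“该方法训练出的概念编码只能在diffusers使用”。下面是一个加载该编码做推理的最小草图,前提假设是脚本沿用了diffusers官方textual inversion示例的保存格式(即在output_model目录下生成learned_embeds.bin,内容为“占位符token到嵌入向量”的字典);文件路径、prompt与输出文件名同样只是示意:
```python
# 最小草图:把textual inversion学到的占位符token注入pipeline后出图
# 假设:learned_embeds.bin的格式为 {token: 一维嵌入张量},路径与prompt均为示意
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "./model", torch_dtype=torch.float16
).to("cuda")

learned = torch.load("output_model/learned_embeds.bin", map_location="cpu")
token, embeds = next(iter(learned.items()))

# 注册新token,并把学到的嵌入写进text encoder的输入嵌入矩阵
pipe.tokenizer.add_tokens(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embeds.to(
    pipe.text_encoder.dtype
)

image = pipe(f"a picture in the style of {token}", num_inference_steps=30).images[0]
image.save("ti_test.png")
```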