├── environment.yaml ├── requirements.txt ├── README.md ├── merge_dd_tags_to_metadata.py ├── merge_captions_to_metadata.py ├── hypernetwork_nai.py ├── clean_captions_and_tags.py ├── make_captions.py ├── tag_images_by_wd14_tagger.py ├── prepare_buckets_latents.py ├── fine_tune.py └── model_util.py /environment.yaml: -------------------------------------------------------------------------------- 1 | name: wd14tagger 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - git 7 | - python=3.8.10 8 | - pip=20.3 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | transformers>=4.21.0 3 | ftfy 4 | albumentations 5 | opencv-python 6 | einops 7 | pytorch_lightning 8 | safetensors 9 | tensorflow 10 | diffusers[torch]==0.10.2 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # WD14Tagger 3 | 4 | Automatically tag images with booru tags. 5 | 6 | ## Requirements 7 | 8 | This project requires cuDNN 8 to run the tagger on the GPU. 9 | 10 | On Arch: 11 | ```bash 12 | sudo pacman -S cudnn 13 | ``` 14 | 15 | ## Installation 16 | 17 | ```bash 18 | git clone https://github.com/KutsuyaYuki/WD14Tagger 19 | cd WD14Tagger 20 | conda env create -f environment.yaml 21 | conda activate wd14tagger 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | ## Usage/Examples 26 | 27 | ```bash 28 | python tag_images_by_wd14_tagger.py \ 29 | input \ 30 | --batch_size 4 \ 31 | --caption_extension .txt 32 | ``` 33 | 34 | Replace `input` with the path to the folder that contains your images (.jpg, .png and .webp files are tagged).
For example, if they are located in a folder called images on your desktop: 35 | 36 | ```bash 37 | python tag_images_by_wd14_tagger.py \ 38 | ~/Desktop/images \ 39 | --batch_size 4 \ 40 | --caption_extension .txt 41 | ``` 42 | -------------------------------------------------------------------------------- /merge_dd_tags_to_metadata.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | def main(args): 13 | image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and os.path.isfile(args.out_json): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | with open(args.in_json, "rt", encoding='utf-8') as f: 22 | metadata = json.load(f) 23 | print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") 24 | else: 25 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 26 | metadata = {} 27 | 28 | print("merge tags to metadata json.") 29 | for image_path in tqdm(image_paths): 30 | tags_path = os.path.splitext(image_path)[0] + '.txt' 31 | with open(tags_path, "rt", encoding='utf-8') as f: 32 | tags = f.readlines()[0].strip() 33 | 34 | image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] 35 | if image_key not in metadata: 36 | metadata[image_key] = {} 37 | 38 | metadata[image_key]['tags'] = tags 39 | if args.debug: 40 | print(image_key, tags) 41 | 42 | # metadataを書き出して終わり 43 | print(f"writing metadata: {args.out_json}") 44 | with open(args.out_json, "wt", encoding='utf-8') as f: 45 | json.dump(metadata, f, indent=2) 46 | print("done!") 47 | 48 | 49 | if __name__ 
== '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 52 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 53 | parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 54 | parser.add_argument("--full_path", action="store_true", 55 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 56 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags") 57 | 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /merge_captions_to_metadata.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. 
@kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | def main(args): 13 | image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and os.path.isfile(args.out_json): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | with open(args.in_json, "rt", encoding='utf-8') as f: 22 | metadata = json.load(f) 23 | print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") 24 | else: 25 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 26 | metadata = {} 27 | 28 | print("merge caption texts to metadata json.") 29 | for image_path in tqdm(image_paths): 30 | caption_path = os.path.splitext(image_path)[0] + args.caption_extension 31 | with open(caption_path, "rt", encoding='utf-8') as f: 32 | caption = f.readlines()[0].strip() 33 | 34 | image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] 35 | if image_key not in metadata: 36 | metadata[image_key] = {} 37 | 38 | metadata[image_key]['caption'] = caption 39 | if args.debug: 40 | print(image_key, caption) 41 | 42 | # metadataを書き出して終わり 43 | print(f"writing metadata: {args.out_json}") 44 | with open(args.out_json, "wt", encoding='utf-8') as f: 45 | json.dump(metadata, f, indent=2) 46 | print("done!") 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 52 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 53 | parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 54 
| parser.add_argument("--caption_extention", type=str, default=None, 55 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 56 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") 57 | parser.add_argument("--full_path", action="store_true", 58 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 59 | parser.add_argument("--debug", action="store_true", help="debug mode") 60 | 61 | args = parser.parse_args() 62 | 63 | # スペルミスしていたオプションを復元する 64 | if args.caption_extention is not None: 65 | args.caption_extension = args.caption_extention 66 | 67 | main(args) 68 | -------------------------------------------------------------------------------- /hypernetwork_nai.py: -------------------------------------------------------------------------------- 1 | # NAI compatible 2 | 3 | import torch 4 | 5 | 6 | class HypernetworkModule(torch.nn.Module): 7 | def __init__(self, dim, multiplier=1.0): 8 | super().__init__() 9 | 10 | linear1 = torch.nn.Linear(dim, dim * 2) 11 | linear2 = torch.nn.Linear(dim * 2, dim) 12 | linear1.weight.data.normal_(mean=0.0, std=0.01) 13 | linear1.bias.data.zero_() 14 | linear2.weight.data.normal_(mean=0.0, std=0.01) 15 | linear2.bias.data.zero_() 16 | linears = [linear1, linear2] 17 | 18 | self.linear = torch.nn.Sequential(*linears) 19 | self.multiplier = multiplier 20 | 21 | def forward(self, x): 22 | return x + self.linear(x) * self.multiplier 23 | 24 | 25 | class Hypernetwork(torch.nn.Module): 26 | enable_sizes = [320, 640, 768, 1280] 27 | # return self.modules[Hypernetwork.enable_sizes.index(size)] 28 | 29 | def __init__(self, multiplier=1.0) -> None: 30 | super().__init__() 31 | self.modules = [] 32 | for size in Hypernetwork.enable_sizes: 33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) 34 | 
self.register_module(f"{size}_0", self.modules[-1][0]) 35 | self.register_module(f"{size}_1", self.modules[-1][1]) 36 | 37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet): 38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks 39 | for block in blocks: 40 | for subblk in block: 41 | if 'SpatialTransformer' in str(type(subblk)): 42 | for tf_block in subblk.transformer_blocks: 43 | for attn in [tf_block.attn1, tf_block.attn2]: 44 | size = attn.context_dim 45 | if size in Hypernetwork.enable_sizes: 46 | attn.hypernetwork = self 47 | else: 48 | attn.hypernetwork = None 49 | 50 | def apply_to_diffusers(self, text_encoder, vae, unet): 51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks 52 | for block in blocks: 53 | if hasattr(block, 'attentions'): 54 | for subblk in block.attentions: 55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ 56 | for tf_block in subblk.transformer_blocks: 57 | for attn in [tf_block.attn1, tf_block.attn2]: 58 | size = attn.to_k.in_features 59 | if size in Hypernetwork.enable_sizes: 60 | attn.hypernetwork = self 61 | else: 62 | attn.hypernetwork = None 63 | return True # TODO error checking 64 | 65 | def forward(self, x, context): 66 | size = context.shape[-1] 67 | assert size in Hypernetwork.enable_sizes 68 | module = self.modules[Hypernetwork.enable_sizes.index(size)] 69 | return module[0].forward(context), module[1].forward(context) 70 | 71 | def load_from_state_dict(self, state_dict): 72 | # old ver to new ver 73 | changes = { 74 | 'linear1.bias': 'linear.0.bias', 75 | 'linear1.weight': 'linear.0.weight', 76 | 'linear2.bias': 'linear.1.bias', 77 | 'linear2.weight': 'linear.1.weight', 78 | } 79 | for key_from, key_to in changes.items(): 80 | if key_from in state_dict: 81 | state_dict[key_to] = state_dict[key_from] 82 | del state_dict[key_from] 83 | 84 | for size, sd in state_dict.items(): 85 | if type(size) == int: 86 | 
self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) 87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) 88 | return True 89 | 90 | def get_state_dict(self): 91 | state_dict = {} 92 | for i, size in enumerate(Hypernetwork.enable_sizes): 93 | sd0 = self.modules[i][0].state_dict() 94 | sd1 = self.modules[i][1].state_dict() 95 | state_dict[size] = [sd0, sd1] 96 | return state_dict 97 | -------------------------------------------------------------------------------- /clean_captions_and_tags.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | def clean_tags(image_key, tags): 13 | # replace '_' to ' ' 14 | tags = tags.replace('_', ' ') 15 | 16 | # remove rating: deepdanbooruのみ 17 | tokens = tags.split(", rating") 18 | if len(tokens) == 1: 19 | # WD14 taggerのときはこちらになるのでメッセージは出さない 20 | # print("no rating:") 21 | # print(f"{image_key} {tags}") 22 | pass 23 | else: 24 | if len(tokens) > 2: 25 | print("multiple ratings:") 26 | print(f"{image_key} {tags}") 27 | tags = tokens[0] 28 | 29 | return tags 30 | 31 | 32 | # 上から順に検索、置換される 33 | # ('置換元文字列', '置換後文字列') 34 | CAPTION_REPLACEMENTS = [ 35 | ('anime anime', 'anime'), 36 | ('young ', ''), 37 | ('anime girl', 'girl'), 38 | ('cartoon female', 'girl'), 39 | ('cartoon lady', 'girl'), 40 | ('cartoon character', 'girl'), # a or ~s 41 | ('cartoon woman', 'girl'), 42 | ('cartoon women', 'girls'), 43 | ('cartoon girl', 'girl'), 44 | ('anime female', 'girl'), 45 | ('anime lady', 'girl'), 46 | ('anime character', 'girl'), # a or ~s 47 | ('anime woman', 'girl'), 48 | ('anime women', 'girls'), 49 | ('lady', 'girl'), 50 | ('female', 'girl'), 51 | ('woman', 'girl'), 52 | ('women', 'girls'), 53 | ('people', 'girls'), 54 | ('person', 'girl'), 
55 | ('a cartoon figure', 'a figure'), 56 | ('a cartoon image', 'an image'), 57 | ('a cartoon picture', 'a picture'), 58 | ('an anime cartoon image', 'an image'), 59 | ('a cartoon anime drawing', 'a drawing'), 60 | ('a cartoon drawing', 'a drawing'), 61 | ('girl girl', 'girl'), 62 | ] 63 | 64 | 65 | def clean_caption(caption): 66 | for rf, rt in CAPTION_REPLACEMENTS: 67 | replaced = True 68 | while replaced: 69 | bef = caption 70 | caption = caption.replace(rf, rt) 71 | replaced = bef != caption 72 | return caption 73 | 74 | 75 | def main(args): 76 | if os.path.exists(args.in_json): 77 | print(f"loading existing metadata: {args.in_json}") 78 | with open(args.in_json, "rt", encoding='utf-8') as f: 79 | metadata = json.load(f) 80 | else: 81 | print("no metadata / メタデータファイルがありません") 82 | return 83 | 84 | print("cleaning captions and tags.") 85 | image_keys = list(metadata.keys()) 86 | for image_key in tqdm(image_keys): 87 | tags = metadata[image_key].get('tags') 88 | if tags is None: 89 | print(f"image does not have tags / メタデータにタグがありません: {image_key}") 90 | else: 91 | metadata[image_key]['tags'] = clean_tags(image_key, tags) 92 | 93 | caption = metadata[image_key].get('caption') 94 | if caption is None: 95 | print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") 96 | else: 97 | metadata[image_key]['caption'] = clean_caption(caption) 98 | 99 | # metadataを書き出して終わり 100 | print(f"writing metadata: {args.out_json}") 101 | with open(args.out_json, "wt", encoding='utf-8') as f: 102 | json.dump(metadata, f, indent=2) 103 | print("done!") 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 109 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 110 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 111 | 112 | args, unknown = 
parser.parse_known_args() 113 | if len(unknown) == 1: 114 | print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") 115 | print("All captions and tags in the metadata are processed.") 116 | print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") 117 | print("メタデータ内のすべてのキャプションとタグが処理されます。") 118 | args.in_json = args.out_json 119 | args.out_json = unknown[0] 120 | elif len(unknown) > 0: 121 | raise ValueError(f"error: unrecognized arguments: {unknown}") 122 | 123 | main(args) 124 | -------------------------------------------------------------------------------- /make_captions.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import torch 13 | from torchvision import transforms 14 | from torchvision.transforms.functional import InterpolationMode 15 | from models.blip import blip_decoder 16 | # from Salesforce_BLIP.models.blip import blip_decoder 17 | 18 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def main(args): 22 | image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) 23 | print(f"found {len(image_paths)} images.") 24 | 25 | print(f"loading BLIP caption: {args.caption_weights}") 26 | image_size = 384 27 | model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large') 28 | model.eval() 29 | model = model.to(DEVICE) 30 | print("BLIP loaded") 31 | 32 | # 正方形でいいのか? 
という気がするがソースがそうなので 33 | transform = transforms.Compose([ 34 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 35 | transforms.ToTensor(), 36 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 37 | ]) 38 | 39 | # captioningする 40 | def run_batch(path_imgs): 41 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 42 | 43 | with torch.no_grad(): 44 | if args.beam_search: 45 | captions = model.generate(imgs, sample=False, num_beams=args.num_beams, 46 | max_length=args.max_length, min_length=args.min_length) 47 | else: 48 | captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length) 49 | 50 | for (image_path, _), caption in zip(path_imgs, captions): 51 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 52 | f.write(caption + "\n") 53 | if args.debug: 54 | print(image_path, caption) 55 | 56 | b_imgs = [] 57 | for image_path in tqdm(image_paths): 58 | raw_image = Image.open(image_path) 59 | if raw_image.mode != "RGB": 60 | print(f"convert image mode {raw_image.mode} to RGB: {image_path}") 61 | raw_image = raw_image.convert("RGB") 62 | 63 | image = transform(raw_image) 64 | b_imgs.append((image_path, image)) 65 | if len(b_imgs) >= args.batch_size: 66 | run_batch(b_imgs) 67 | b_imgs.clear() 68 | if len(b_imgs) > 0: 69 | run_batch(b_imgs) 70 | 71 | print("done!") 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 77 | parser.add_argument("caption_weights", type=str, 78 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") 79 | parser.add_argument("--caption_extention", type=str, default=None, 80 | help="extension of caption file (for backward compatibility) / 
出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 81 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 82 | parser.add_argument("--beam_search", action="store_true", 83 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)") 84 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 85 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 86 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 87 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 88 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 89 | parser.add_argument("--debug", action="store_true", help="debug mode") 90 | 91 | args = parser.parse_args() 92 | 93 | # スペルミスしていたオプションを復元する 94 | if args.caption_extention is not None: 95 | args.caption_extension = args.caption_extention 96 | 97 | main(args) 98 | -------------------------------------------------------------------------------- /tag_images_by_wd14_tagger.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. 
@kohya_ss 3 | 4 | import argparse 5 | import csv 6 | import glob 7 | import os 8 | 9 | from PIL import Image 10 | import cv2 11 | from tqdm import tqdm 12 | import numpy as np 13 | from tensorflow.keras.models import load_model 14 | from huggingface_hub import hf_hub_download 15 | 16 | # from wd14 tagger 17 | IMAGE_SIZE = 448 18 | 19 | WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger' 20 | FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] 21 | SUB_DIR = "variables" 22 | SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] 23 | CSV_FILE = FILES[-1] 24 | 25 | 26 | def main(args): 27 | # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする 28 | # depreacatedの警告が出るけどなくなったらその時 29 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 30 | if not os.path.exists(args.model_dir) or args.force_download: 31 | print("downloading wd14 tagger model from hf_hub") 32 | for file in FILES: 33 | hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) 34 | for file in SUB_DIR_FILES: 35 | hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join( 36 | args.model_dir, SUB_DIR), force_download=True, force_filename=file) 37 | 38 | # 画像を読み込む 39 | image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ 40 | glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) 41 | print(f"found {len(image_paths)} images.") 42 | 43 | print("loading model and labels") 44 | model = load_model(args.model_dir) 45 | 46 | # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") 47 | # 依存ライブラリを増やしたくないので自力で読むよ 48 | with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f: 49 | reader = csv.reader(f) 50 | l = [row for row in reader] 51 | header = l[0] # tag_id,name,category,count 52 | rows = l[1:] 53 | assert header[0] == 'tag_id' and header[1] == 'name' and 
header[2] == 'category', f"unexpected csv format: {header}" 54 | 55 | tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ 56 | 57 | # 推論する 58 | def run_batch(path_imgs): 59 | imgs = np.array([im for _, im in path_imgs]) 60 | 61 | probs = model(imgs, training=False) 62 | probs = probs.numpy() 63 | 64 | for (image_path, _), prob in zip(path_imgs, probs): 65 | # 最初の4つはratingなので無視する 66 | # # First 4 labels are actually ratings: pick one with argmax 67 | # ratings_names = label_names[:4] 68 | # rating_index = ratings_names["probs"].argmax() 69 | # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] 70 | 71 | # それ以降はタグなのでconfidenceがthresholdより高いものを追加する 72 | # Everything else is tags: pick any where prediction confidence > threshold 73 | tag_text = "" 74 | for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで 75 | if p >= args.thresh: 76 | tag_text += ", " + tags[i] 77 | 78 | if len(tag_text) > 0: 79 | tag_text = tag_text[2:] # 最初の ", " を消す 80 | 81 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: 82 | f.write(tag_text + '\n') 83 | if args.debug: 84 | print(image_path, tag_text) 85 | 86 | b_imgs = [] 87 | for image_path in tqdm(image_paths): 88 | img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く 89 | if img.mode != 'RGB': 90 | img = img.convert("RGB") 91 | img = np.array(img) 92 | img = img[:, :, ::-1] # RGB->BGR 93 | 94 | # pad to square 95 | size = max(img.shape[0:2]) 96 | pad_x = size - img.shape[1] 97 | pad_y = size - img.shape[0] 98 | pad_l = pad_x // 2 99 | pad_t = pad_y // 2 100 | img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255) 101 | 102 | interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 103 | img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) 104 | # cv2.imshow("img", img) 105 | # cv2.waitKey() 106 | # cv2.destroyAllWindows() 
107 | 108 | img = img.astype(np.float32) 109 | b_imgs.append((image_path, img)) 110 | 111 | if len(b_imgs) >= args.batch_size: 112 | run_batch(b_imgs) 113 | b_imgs.clear() 114 | 115 | if len(b_imgs) > 0: 116 | run_batch(b_imgs) 117 | 118 | print("done!") 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 124 | parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO, 125 | help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") 126 | parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", 127 | help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ") 128 | parser.add_argument("--force_download", action='store_true', 129 | help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします") 130 | parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") 131 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 132 | parser.add_argument("--caption_extention", type=str, default=None, 133 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 134 | parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") 135 | parser.add_argument("--debug", action="store_true", help="debug mode") 136 | 137 | args = parser.parse_args() 138 | 139 | # スペルミスしていたオプションを復元する 140 | if args.caption_extention is not None: 141 | args.caption_extension = args.caption_extention 142 | 143 | main(args) 144 | -------------------------------------------------------------------------------- /prepare_buckets_latents.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 
Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | 9 | from tqdm import tqdm 10 | import numpy as np 11 | from diffusers import AutoencoderKL 12 | from PIL import Image 13 | import cv2 14 | import torch 15 | from torchvision import transforms 16 | 17 | import model_util 18 | 19 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | IMAGE_TRANSFORMS = transforms.Compose( 22 | [ 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.5], [0.5]), 25 | ] 26 | ) 27 | 28 | 29 | def get_latents(vae, images, weight_dtype): 30 | img_tensors = [IMAGE_TRANSFORMS(image) for image in images] 31 | img_tensors = torch.stack(img_tensors) 32 | img_tensors = img_tensors.to(DEVICE, weight_dtype) 33 | with torch.no_grad(): 34 | latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() 35 | return latents 36 | 37 | 38 | def main(args): 39 | image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) 40 | print(f"found {len(image_paths)} images.") 41 | 42 | if os.path.exists(args.in_json): 43 | print(f"loading existing metadata: {args.in_json}") 44 | with open(args.in_json, "rt", encoding='utf-8') as f: 45 | metadata = json.load(f) 46 | else: 47 | print(f"no metadata / メタデータファイルがありません: {args.in_json}") 48 | return 49 | 50 | weight_dtype = torch.float32 51 | if args.mixed_precision == "fp16": 52 | weight_dtype = torch.float16 53 | elif args.mixed_precision == "bf16": 54 | weight_dtype = torch.bfloat16 55 | 56 | vae = model_util.load_vae(args.model_name_or_path, weight_dtype) 57 | vae.eval() 58 | vae.to(DEVICE, dtype=weight_dtype) 59 | 60 | # bucketのサイズを計算する 61 | max_reso = tuple([int(t) for t in args.max_resolution.split(',')]) 62 | assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" 63 | 64 | bucket_resos, bucket_aspect_ratios = 
model_util.make_bucket_resolutions( 65 | max_reso, args.min_bucket_reso, args.max_bucket_reso) 66 | 67 | # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する 68 | bucket_aspect_ratios = np.array(bucket_aspect_ratios) 69 | buckets_imgs = [[] for _ in range(len(bucket_resos))] 70 | bucket_counts = [0 for _ in range(len(bucket_resos))] 71 | img_ar_errors = [] 72 | for i, image_path in enumerate(tqdm(image_paths)): 73 | image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] 74 | if image_key not in metadata: 75 | metadata[image_key] = {} 76 | 77 | image = Image.open(image_path) 78 | if image.mode != 'RGB': 79 | image = image.convert("RGB") 80 | 81 | aspect_ratio = image.width / image.height 82 | ar_errors = bucket_aspect_ratios - aspect_ratio 83 | bucket_id = np.abs(ar_errors).argmin() 84 | reso = bucket_resos[bucket_id] 85 | ar_error = ar_errors[bucket_id] 86 | img_ar_errors.append(abs(ar_error)) 87 | 88 | # どのサイズにリサイズするか→トリミングする方向で 89 | if ar_error <= 0: # 横が長い→縦を合わせる 90 | scale = reso[1] / image.height 91 | else: 92 | scale = reso[0] / image.width 93 | 94 | resized_size = (int(image.width * scale + .5), int(image.height * scale + .5)) 95 | 96 | # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size, 97 | # bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1]) 98 | 99 | assert resized_size[0] == reso[0] or resized_size[1] == reso[ 100 | 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" 101 | assert resized_size[0] >= reso[0] and resized_size[1] >= reso[ 102 | 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" 103 | 104 | # 画像をリサイズしてトリミングする 105 | # PILにinter_areaがないのでcv2で…… 106 | image = np.array(image) 107 | image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) 108 | if resized_size[0] > reso[0]: 109 | trim_size = resized_size[0] - 
reso[0] 110 | image = image[:, trim_size//2:trim_size//2 + reso[0]] 111 | elif resized_size[1] > reso[1]: 112 | trim_size = resized_size[1] - reso[1] 113 | image = image[trim_size//2:trim_size//2 + reso[1]] 114 | assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" 115 | 116 | # # debug 117 | # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1]) 118 | 119 | # バッチへ追加 120 | buckets_imgs[bucket_id].append((image_key, reso, image)) 121 | bucket_counts[bucket_id] += 1 122 | metadata[image_key]['train_resolution'] = reso 123 | 124 | # バッチを推論するか判定して推論する 125 | is_last = i == len(image_paths) - 1 126 | for j in range(len(buckets_imgs)): 127 | bucket = buckets_imgs[j] 128 | if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: 129 | latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) 130 | 131 | for (image_key, reso, _), latent in zip(bucket, latents): 132 | np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent) 133 | 134 | bucket.clear() 135 | 136 | for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)): 137 | print(f"bucket {i} {reso}: {count}") 138 | img_ar_errors = np.array(img_ar_errors) 139 | print(f"mean ar error: {np.mean(img_ar_errors)}") 140 | 141 | # metadataを書き出して終わり 142 | print(f"writing metadata: {args.out_json}") 143 | with open(args.out_json, "wt", encoding='utf-8') as f: 144 | json.dump(metadata, f, indent=2) 145 | print("done!") 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 151 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 152 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 153 | parser.add_argument("model_name_or_path", type=str, help="model name or 
path to encode latents / latentを取得するためのモデル") 154 | parser.add_argument("--v2", action='store_true', 155 | help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') 156 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 157 | parser.add_argument("--max_resolution", type=str, default="512,512", 158 | help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)") 159 | parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") 160 | parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") 161 | parser.add_argument("--mixed_precision", type=str, default="no", 162 | choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") 163 | parser.add_argument("--full_path", action="store_true", 164 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 165 | 166 | args = parser.parse_args() 167 | main(args) 168 | -------------------------------------------------------------------------------- /fine_tune.py: -------------------------------------------------------------------------------- 1 | # v2: select precision for saved checkpoint 2 | # v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset) 3 | # v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model 4 | # v5: refactor to use model_util, support safetensors, add settings to use Diffusers' xformers, add log prefix 5 | # v6: model_util update 6 | # v7: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0, support full path in metadata 7 | # v8: experimental full fp16 training. 
8 | 9 | # The license of this script is Apache License 2.0, the same as train_dreambooth.py 10 | # License: 11 | # Copyright 2022 Kohya S. @kohya_ss 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 24 | 25 | # License of included scripts: 26 | 27 | # Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE 28 | 29 | # Memory efficient attention: 30 | # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py 31 | # MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE 32 | 33 | import argparse 34 | import math 35 | import os 36 | import random 37 | import json 38 | import importlib 39 | import time 40 | 41 | from tqdm import tqdm 42 | import torch 43 | from accelerate import Accelerator 44 | from accelerate.utils import set_seed 45 | from transformers import CLIPTokenizer 46 | import diffusers 47 | from diffusers import DDPMScheduler, StableDiffusionPipeline 48 | import numpy as np 49 | from einops import rearrange 50 | from torch import einsum 51 | 52 | import model_util 53 | 54 | # Tokenizer: use a published tokenizer rather than loading one from the checkpoint; TOKENIZER_PATH is used for v1 (non --v2) models via CLIPTokenizer.from_pretrained in train() 55 | TOKENIZER_PATH = "openai/clip-vit-large-patch14" 56 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer (subfolder="tokenizer") is loaded from here; v2 and v2.1 share the same tokenizer spec 57 | 58 | # checkpoint / saved-state names: EPOCH_STATE_NAME is formatted with the 1-based epoch number, and both *_STATE_NAME values are joined onto output_dir and passed to accelerator.save_state 59 | EPOCH_STATE_NAME = "epoch-{:06d}-state" 60 | LAST_STATE_NAME = "last-state" 61 | 62 |
LAST_DIFFUSERS_DIR_NAME = "last" 63 | EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" 64 | 65 | 66 | def collate_fn(examples): 67 | return examples[0] 68 | 69 | 70 | class FineTuningDataset(torch.utils.data.Dataset): 71 | def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, dataset_repeats, debug) -> None: 72 | super().__init__() 73 | 74 | self.metadata = metadata 75 | self.train_data_dir = train_data_dir 76 | self.batch_size = batch_size 77 | self.tokenizer: CLIPTokenizer = tokenizer 78 | self.max_token_length = max_token_length 79 | self.shuffle_caption = shuffle_caption 80 | self.debug = debug 81 | 82 | self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 83 | 84 | print("make buckets") 85 | 86 | # 最初に数を数える 87 | self.bucket_resos = set() 88 | for img_md in metadata.values(): 89 | if 'train_resolution' in img_md: 90 | self.bucket_resos.add(tuple(img_md['train_resolution'])) 91 | self.bucket_resos = list(self.bucket_resos) 92 | self.bucket_resos.sort() 93 | print(f"number of buckets: {len(self.bucket_resos)}") 94 | 95 | reso_to_index = {} 96 | for i, reso in enumerate(self.bucket_resos): 97 | reso_to_index[reso] = i 98 | 99 | # bucketに割り当てていく 100 | self.buckets = [[] for _ in range(len(self.bucket_resos))] 101 | n = 1 if dataset_repeats is None else dataset_repeats 102 | images_count = 0 103 | for image_key, img_md in metadata.items(): 104 | if 'train_resolution' not in img_md: 105 | continue 106 | if not os.path.exists(self.image_key_to_npz_file(image_key)): 107 | continue 108 | 109 | reso = tuple(img_md['train_resolution']) 110 | for _ in range(n): 111 | self.buckets[reso_to_index[reso]].append(image_key) 112 | images_count += n 113 | 114 | # 参照用indexを作る 115 | self.buckets_indices = [] 116 | for bucket_index, bucket in enumerate(self.buckets): 117 | batch_count = int(math.ceil(len(bucket) / self.batch_size)) 118 | for batch_index in range(batch_count): 119 | 
self.buckets_indices.append((bucket_index, batch_index)) 120 | 121 | self.shuffle_buckets() 122 | self._length = len(self.buckets_indices) 123 | self.images_count = images_count 124 | 125 | def show_buckets(self): 126 | for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)): 127 | print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") 128 | 129 | def shuffle_buckets(self): 130 | random.shuffle(self.buckets_indices) 131 | for bucket in self.buckets: 132 | random.shuffle(bucket) 133 | 134 | def image_key_to_npz_file(self, image_key): 135 | npz_file = os.path.splitext(image_key)[0] + '.npz' 136 | if os.path.exists(npz_file): 137 | return npz_file 138 | return os.path.join(self.train_data_dir, image_key + '.npz') 139 | 140 | def load_latent(self, image_key): 141 | return np.load(self.image_key_to_npz_file(image_key))['arr_0'] 142 | 143 | def __len__(self): 144 | return self._length 145 | 146 | def __getitem__(self, index): 147 | if index == 0: 148 | self.shuffle_buckets() 149 | 150 | bucket = self.buckets[self.buckets_indices[index][0]] 151 | image_index = self.buckets_indices[index][1] * self.batch_size 152 | 153 | input_ids_list = [] 154 | latents_list = [] 155 | captions = [] 156 | for image_key in bucket[image_index:image_index + self.batch_size]: 157 | img_md = self.metadata[image_key] 158 | caption = img_md.get('caption') 159 | tags = img_md.get('tags') 160 | 161 | if caption is None: 162 | caption = tags 163 | elif tags is not None and len(tags) > 0: 164 | caption = caption + ', ' + tags 165 | assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}" 166 | 167 | latents = self.load_latent(image_key) 168 | 169 | if self.shuffle_caption: 170 | tokens = caption.strip().split(",") 171 | random.shuffle(tokens) 172 | caption = ",".join(tokens).strip() 173 | 174 | captions.append(caption) 175 | 176 | input_ids = self.tokenizer(caption, padding="max_length", truncation=True, 177 | 
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids 178 | 179 | if self.tokenizer_max_length > self.tokenizer.model_max_length: 180 | input_ids = input_ids.squeeze(0) 181 | iids_list = [] 182 | if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: 183 | # v1 184 | # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する 185 | # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に 186 | for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) 187 | ids_chunk = (input_ids[0].unsqueeze(0), 188 | input_ids[i:i + self.tokenizer.model_max_length - 2], 189 | input_ids[-1].unsqueeze(0)) 190 | ids_chunk = torch.cat(ids_chunk) 191 | iids_list.append(ids_chunk) 192 | else: 193 | # v2 194 | # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する 195 | for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): 196 | ids_chunk = (input_ids[0].unsqueeze(0), # BOS 197 | input_ids[i:i + self.tokenizer.model_max_length - 2], 198 | input_ids[-1].unsqueeze(0)) # PAD or EOS 199 | ids_chunk = torch.cat(ids_chunk) 200 | 201 | # 末尾が または の場合は、何もしなくてよい 202 | # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) 203 | if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: 204 | ids_chunk[-1] = self.tokenizer.eos_token_id 205 | # 先頭が ... の場合は ... 
に変える 206 | if ids_chunk[1] == self.tokenizer.pad_token_id: 207 | ids_chunk[1] = self.tokenizer.eos_token_id 208 | 209 | iids_list.append(ids_chunk) 210 | 211 | input_ids = torch.stack(iids_list) # 3,77 212 | 213 | input_ids_list.append(input_ids) 214 | latents_list.append(torch.FloatTensor(latents)) 215 | 216 | example = {} 217 | example['input_ids'] = torch.stack(input_ids_list) 218 | example['latents'] = torch.stack(latents_list) 219 | if self.debug: 220 | example['image_keys'] = bucket[image_index:image_index + self.batch_size] 221 | example['captions'] = captions 222 | return example 223 | 224 | 225 | def save_hypernetwork(output_file, hypernetwork): 226 | state_dict = hypernetwork.get_state_dict() 227 | torch.save(state_dict, output_file) 228 | 229 | 230 | def train(args): 231 | fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training 232 | 233 | # その他のオプション設定を確認する 234 | if args.v_parameterization and not args.v2: 235 | print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") 236 | if args.v2 and args.clip_skip is not None: 237 | print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") 238 | 239 | # モデル形式のオプション設定を確認する 240 | # v11からDiffUsersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffUsersも途中保存に対応した 241 | use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) 242 | 243 | # 乱数系列を初期化する 244 | if args.seed is not None: 245 | set_seed(args.seed) 246 | 247 | # メタデータを読み込む 248 | if os.path.exists(args.in_json): 249 | print(f"loading existing metadata: {args.in_json}") 250 | with open(args.in_json, "rt", encoding='utf-8') as f: 251 | metadata = json.load(f) 252 | else: 253 | print(f"no metadata / メタデータファイルがありません: {args.in_json}") 254 | return 255 | 256 | # tokenizerを読み込む 257 | print("prepare tokenizer") 258 | if args.v2: 259 | tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") 260 | else: 261 | tokenizer = 
CLIPTokenizer.from_pretrained(TOKENIZER_PATH) 262 | 263 | if args.max_token_length is not None: 264 | print(f"update token length: {args.max_token_length}") 265 | 266 | # datasetを用意する 267 | print("prepare dataset") 268 | train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size, 269 | tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset) 270 | 271 | print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") 272 | print(f"Total images / 画像数: {train_dataset.images_count}") 273 | 274 | if len(train_dataset) == 0: 275 | print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") 276 | return 277 | 278 | if args.debug_dataset: 279 | train_dataset.show_buckets() 280 | i = 0 281 | for example in train_dataset: 282 | print(f"image: {example['image_keys']}") 283 | print(f"captions: {example['captions']}") 284 | print(f"latents: {example['latents'].shape}") 285 | print(f"input_ids: {example['input_ids'].shape}") 286 | print(example['input_ids']) 287 | i += 1 288 | if i >= 8: 289 | break 290 | return 291 | 292 | # acceleratorを準備する 293 | print("prepare accelerator") 294 | if args.logging_dir is None: 295 | log_with = None 296 | logging_dir = None 297 | else: 298 | log_with = "tensorboard" 299 | log_prefix = "" if args.log_prefix is None else args.log_prefix 300 | logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) 301 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, 302 | mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) 303 | 304 | # accelerateの互換性問題を解決する 305 | accelerator_0_15 = True 306 | try: 307 | accelerator.unwrap_model("dummy", True) 308 | print("Using accelerator 0.15.0 or above.") 309 | except TypeError: 310 | accelerator_0_15 = False 311 | 312 | def unwrap_model(model): 313 | if accelerator_0_15: 
314 | return accelerator.unwrap_model(model, True) 315 | return accelerator.unwrap_model(model) 316 | 317 | # mixed precisionに対応した型を用意しておき適宜castする 318 | weight_dtype = torch.float32 319 | if args.mixed_precision == "fp16": 320 | weight_dtype = torch.float16 321 | elif args.mixed_precision == "bf16": 322 | weight_dtype = torch.bfloat16 323 | 324 | save_dtype = None 325 | if args.save_precision == "fp16": 326 | save_dtype = torch.float16 327 | elif args.save_precision == "bf16": 328 | save_dtype = torch.bfloat16 329 | elif args.save_precision == "float": 330 | save_dtype = torch.float32 331 | 332 | # モデルを読み込む 333 | if use_stable_diffusion_format: 334 | print("load StableDiffusion checkpoint") 335 | text_encoder, _, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) 336 | else: 337 | print("load Diffusers pretrained models") 338 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) 339 | # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる 340 | text_encoder = pipe.text_encoder 341 | unet = pipe.unet 342 | del pipe 343 | 344 | # Diffusers版のxformers使用フラグを設定する関数 345 | def set_diffusers_xformers_flag(model, valid): 346 | # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう 347 | # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) 348 | # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか 349 | 350 | # Recursively walk through all the children. 
351 | # Any children which exposes the set_use_memory_efficient_attention_xformers method 352 | # gets the message 353 | def fn_recursive_set_mem_eff(module: torch.nn.Module): 354 | if hasattr(module, "set_use_memory_efficient_attention_xformers"): 355 | module.set_use_memory_efficient_attention_xformers(valid) 356 | 357 | for child in module.children(): 358 | fn_recursive_set_mem_eff(child) 359 | 360 | fn_recursive_set_mem_eff(model) 361 | 362 | # モデルに xformers とか memory efficient attention を組み込む 363 | if args.diffusers_xformers: 364 | print("Use xformers by Diffusers") 365 | set_diffusers_xformers_flag(unet, True) 366 | else: 367 | # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある 368 | print("Disable Diffusers' xformers") 369 | set_diffusers_xformers_flag(unet, False) 370 | replace_unet_modules(unet, args.mem_eff_attn, args.xformers) 371 | 372 | if not fine_tuning: 373 | # Hypernetwork 374 | print("import hypernetwork module:", args.hypernetwork_module) 375 | hyp_module = importlib.import_module(args.hypernetwork_module) 376 | 377 | hypernetwork = hyp_module.Hypernetwork() 378 | 379 | if args.hypernetwork_weights is not None: 380 | print("load hypernetwork weights from:", args.hypernetwork_weights) 381 | hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') 382 | success = hypernetwork.load_from_state_dict(hyp_sd) 383 | assert success, "hypernetwork weights loading failed." 
384 | 385 | print("apply hypernetwork") 386 | hypernetwork.apply_to_diffusers(None, text_encoder, unet) 387 | 388 | # 学習を準備する:モデルを適切な状態にする 389 | training_models = [] 390 | if fine_tuning: 391 | if args.gradient_checkpointing: 392 | unet.enable_gradient_checkpointing() 393 | training_models.append(unet) 394 | 395 | if args.train_text_encoder: 396 | print("enable text encoder training") 397 | if args.gradient_checkpointing: 398 | text_encoder.gradient_checkpointing_enable() 399 | training_models.append(text_encoder) 400 | else: 401 | text_encoder.to(accelerator.device, dtype=weight_dtype) 402 | text_encoder.requires_grad_(False) # text encoderは学習しない 403 | text_encoder.eval() 404 | else: 405 | unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない 406 | unet.requires_grad_(False) 407 | unet.eval() 408 | text_encoder.to(accelerator.device, dtype=weight_dtype) 409 | text_encoder.requires_grad_(False) 410 | text_encoder.eval() 411 | training_models.append(hypernetwork) 412 | 413 | for m in training_models: 414 | m.requires_grad_(True) 415 | params = [] 416 | for m in training_models: 417 | params.extend(m.parameters()) 418 | params_to_optimize = params 419 | 420 | # 学習に必要なクラスを準備する 421 | print("prepare optimizer, data loader etc.") 422 | 423 | # 8-bit Adamを使う 424 | if args.use_8bit_adam: 425 | try: 426 | import bitsandbytes as bnb 427 | except ImportError: 428 | raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") 429 | print("use 8-bit Adam optimizer") 430 | optimizer_class = bnb.optim.AdamW8bit 431 | else: 432 | optimizer_class = torch.optim.AdamW 433 | 434 | # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 435 | optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) 436 | 437 | # dataloaderを準備する 438 | # DataLoaderのプロセス数:0はメインプロセスになる 439 | n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 440 | train_dataloader = torch.utils.data.DataLoader( 441 | train_dataset, batch_size=1, 
shuffle=False, collate_fn=collate_fn, num_workers=n_workers) 442 | 443 | # lr schedulerを用意する 444 | lr_scheduler = diffusers.optimization.get_scheduler( 445 | args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) 446 | 447 | # acceleratorがなんかよろしくやってくれるらしい 448 | if args.full_fp16: 449 | assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" 450 | print("enable full fp16 training.") 451 | 452 | if fine_tuning: 453 | # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする 454 | if args.full_fp16: 455 | unet.to(weight_dtype) 456 | text_encoder.to(weight_dtype) 457 | 458 | if args.train_text_encoder: 459 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 460 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler) 461 | else: 462 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) 463 | else: 464 | if args.full_fp16: 465 | unet.to(weight_dtype) 466 | hypernetwork.to(weight_dtype) 467 | 468 | unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 469 | unet, hypernetwork, optimizer, train_dataloader, lr_scheduler) 470 | 471 | # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする 472 | if args.full_fp16: 473 | org_unscale_grads = accelerator.scaler._unscale_grads_ 474 | 475 | def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): 476 | return org_unscale_grads(optimizer, inv_scale, found_inf, True) 477 | 478 | accelerator.scaler._unscale_grads_ = _unscale_grads_replacer 479 | 480 | # TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す 481 | 482 | # resumeする 483 | if args.resume is not None: 484 | print(f"resume training from state: {args.resume}") 485 | accelerator.load_state(args.resume) 486 | 487 | # epoch数を計算する 488 | num_update_steps_per_epoch = 
math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 489 | num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 490 | 491 | # 学習する 492 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 493 | print("running training / 学習開始") 494 | print(f" num examples / サンプル数: {train_dataset.images_count}") 495 | print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") 496 | print(f" num epochs / epoch数: {num_train_epochs}") 497 | print(f" batch size per device / バッチサイズ: {args.train_batch_size}") 498 | print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") 499 | print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") 500 | print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") 501 | 502 | progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") 503 | global_step = 0 504 | 505 | # v4で更新:clip_sample=Falseに 506 | # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる 507 | # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ 508 | # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') 509 | noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 510 | num_train_timesteps=1000, clip_sample=False) 511 | 512 | if accelerator.is_main_process: 513 | accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork") 514 | 515 | # 以下 train_dreambooth.py からほぼコピペ 516 | for epoch in range(num_train_epochs): 517 | print(f"epoch {epoch+1}/{num_train_epochs}") 518 | for m in training_models: 519 | m.train() 520 | 521 | loss_total = 0 522 | for step, batch in enumerate(train_dataloader): 523 | with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく 524 | latents = batch["latents"].to(accelerator.device) 525 | latents = 
latents * 0.18215 526 | b_size = latents.shape[0] 527 | 528 | # with torch.no_grad(): 529 | with torch.set_grad_enabled(args.train_text_encoder): 530 | # Get the text embedding for conditioning 531 | input_ids = batch["input_ids"].to(accelerator.device) 532 | input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 533 | 534 | if args.clip_skip is None: 535 | encoder_hidden_states = text_encoder(input_ids)[0] 536 | else: 537 | enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) 538 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] 539 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) 540 | 541 | # bs*3, 77, 768 or 1024 542 | encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) 543 | 544 | if args.max_token_length is not None: 545 | if args.v2: 546 | # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん 547 | states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # 548 | for i in range(1, args.max_token_length, tokenizer.model_max_length): 549 | chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで 550 | if i > 0: 551 | for j in range(len(chunk)): 552 | if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン 553 | chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする 554 | states_list.append(chunk) # の後から の前まで 555 | states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか 556 | encoder_hidden_states = torch.cat(states_list, dim=1) 557 | else: 558 | # v1: ... の三連を ... 
へ戻す 559 | states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # 560 | for i in range(1, args.max_token_length, tokenizer.model_max_length): 561 | states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで 562 | states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # 563 | encoder_hidden_states = torch.cat(states_list, dim=1) 564 | 565 | # Sample noise that we'll add to the latents 566 | noise = torch.randn_like(latents, device=latents.device) 567 | 568 | # Sample a random timestep for each image 569 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) 570 | timesteps = timesteps.long() 571 | 572 | # Add noise to the latents according to the noise magnitude at each timestep 573 | # (this is the forward diffusion process) 574 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 575 | 576 | # Predict the noise residual 577 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 578 | 579 | if args.v_parameterization: 580 | # v-parameterization training 581 | # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う 582 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 583 | else: 584 | target = noise 585 | 586 | loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") 587 | 588 | accelerator.backward(loss) 589 | if accelerator.sync_gradients: 590 | params_to_clip = [] 591 | for m in training_models: 592 | params_to_clip.extend(m.parameters()) 593 | accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) 594 | 595 | optimizer.step() 596 | lr_scheduler.step() 597 | optimizer.zero_grad(set_to_none=True) 598 | 599 | # Checks if the accelerator has performed an optimization step behind the scenes 600 | if accelerator.sync_gradients: 601 | progress_bar.update(1) 602 | global_step += 1 603 | 604 | current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず 605 | if 
args.logging_dir is not None: 606 | logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} 607 | accelerator.log(logs, step=global_step) 608 | 609 | loss_total += current_loss 610 | avr_loss = loss_total / (step+1) 611 | logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} 612 | progress_bar.set_postfix(**logs) 613 | 614 | if global_step >= args.max_train_steps: 615 | break 616 | 617 | if args.logging_dir is not None: 618 | logs = {"epoch_loss": loss_total / len(train_dataloader)} 619 | accelerator.log(logs, step=epoch+1) 620 | 621 | accelerator.wait_for_everyone() 622 | 623 | if args.save_every_n_epochs is not None: 624 | if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: 625 | print("saving checkpoint.") 626 | os.makedirs(args.output_dir, exist_ok=True) 627 | ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1)) 628 | 629 | if fine_tuning: 630 | if use_stable_diffusion_format: 631 | model_util.save_stable_diffusion_checkpoint( 632 | args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), 633 | args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype) 634 | else: 635 | out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) 636 | os.makedirs(out_dir, exist_ok=True) 637 | model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), 638 | unwrap_model(unet), args.pretrained_model_name_or_path, use_safetensors=args.use_safetensors) 639 | else: 640 | save_hypernetwork(ckpt_file, unwrap_model(hypernetwork)) 641 | 642 | if args.save_state: 643 | print("saving state.") 644 | accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) 645 | 646 | is_main_process = accelerator.is_main_process 647 | if is_main_process: 648 | if fine_tuning: 649 | unet = unwrap_model(unet) 650 | text_encoder = unwrap_model(text_encoder) 651 | else: 652 | hypernetwork = 
unwrap_model(hypernetwork) 653 | 654 | accelerator.end_training() 655 | 656 | if args.save_state: 657 | print("saving last state.") 658 | accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) 659 | 660 | del accelerator # この後メモリを使うのでこれは消す 661 | 662 | if is_main_process: 663 | os.makedirs(args.output_dir, exist_ok=True) 664 | if fine_tuning: 665 | if use_stable_diffusion_format: 666 | ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors)) 667 | print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") 668 | model_util.save_stable_diffusion_checkpoint( 669 | args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype) 670 | else: 671 | # Create the pipeline using using the trained modules and save it. 672 | print(f"save trained model as Diffusers to {args.output_dir}") 673 | out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) 674 | os.makedirs(out_dir, exist_ok=True) 675 | model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, 676 | args.pretrained_model_name_or_path, use_safetensors=args.use_safetensors) 677 | else: 678 | ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors)) 679 | print(f"save trained model to {ckpt_file}") 680 | save_hypernetwork(ckpt_file, hypernetwork) 681 | 682 | print("model saved.") 683 | 684 | 685 | # region モジュール入れ替え部 686 | """ 687 | 高速化のためのモジュール入れ替え 688 | """ 689 | 690 | # FlashAttentionを使うCrossAttention 691 | # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py 692 | # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE 693 | 694 | # constants 695 | 696 | EPSILON = 1e-6 697 | 698 | # helper functions 699 | 700 | 701 | def exists(val): 702 | return val is not None 703 | 704 | 705 | def default(val, d): 706 | return val if 
exists(val) else d 707 | 708 | # flash attention forwards and backwards 709 | 710 | # https://arxiv.org/abs/2205.14135 711 | 712 | 713 | class FlashAttentionFunction(torch.autograd.function.Function): 714 | @ staticmethod 715 | @ torch.no_grad() 716 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 717 | """ Algorithm 2 in the paper """ 718 | 719 | device = q.device 720 | dtype = q.dtype 721 | max_neg_value = -torch.finfo(q.dtype).max 722 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 723 | 724 | o = torch.zeros_like(q) 725 | all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) 726 | all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) 727 | 728 | scale = (q.shape[-1] ** -0.5) 729 | 730 | if not exists(mask): 731 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) 732 | else: 733 | mask = rearrange(mask, 'b n -> b 1 1 n') 734 | mask = mask.split(q_bucket_size, dim=-1) 735 | 736 | row_splits = zip( 737 | q.split(q_bucket_size, dim=-2), 738 | o.split(q_bucket_size, dim=-2), 739 | mask, 740 | all_row_sums.split(q_bucket_size, dim=-2), 741 | all_row_maxes.split(q_bucket_size, dim=-2), 742 | ) 743 | 744 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 745 | q_start_index = ind * q_bucket_size - qk_len_diff 746 | 747 | col_splits = zip( 748 | k.split(k_bucket_size, dim=-2), 749 | v.split(k_bucket_size, dim=-2), 750 | ) 751 | 752 | for k_ind, (kc, vc) in enumerate(col_splits): 753 | k_start_index = k_ind * k_bucket_size 754 | 755 | attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale 756 | 757 | if exists(row_mask): 758 | attn_weights.masked_fill_(~row_mask, max_neg_value) 759 | 760 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 761 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, 762 | device=device).triu(q_start_index - k_start_index + 1) 763 | attn_weights.masked_fill_(causal_mask, max_neg_value) 764 | 765 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) 766 | attn_weights -= block_row_maxes 767 | exp_weights = torch.exp(attn_weights) 768 | 769 | if exists(row_mask): 770 | exp_weights.masked_fill_(~row_mask, 0.) 771 | 772 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) 773 | 774 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 775 | 776 | exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) 777 | 778 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 779 | exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) 780 | 781 | new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums 782 | 783 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) 784 | 785 | row_maxes.copy_(new_row_maxes) 786 | row_sums.copy_(new_row_sums) 787 | 788 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 789 | ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) 790 | 791 | return o 792 | 793 | @ staticmethod 794 | @ torch.no_grad() 795 | def backward(ctx, do): 796 | """ Algorithm 4 in the paper """ 797 | 798 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 799 | q, k, v, o, l, m = ctx.saved_tensors 800 | 801 | device = q.device 802 | 803 | max_neg_value = -torch.finfo(q.dtype).max 804 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 805 | 806 | dq = torch.zeros_like(q) 807 | dk = torch.zeros_like(k) 808 | dv = torch.zeros_like(v) 809 | 810 | row_splits = zip( 811 | q.split(q_bucket_size, 
dim=-2), 812 | o.split(q_bucket_size, dim=-2), 813 | do.split(q_bucket_size, dim=-2), 814 | mask, 815 | l.split(q_bucket_size, dim=-2), 816 | m.split(q_bucket_size, dim=-2), 817 | dq.split(q_bucket_size, dim=-2) 818 | ) 819 | 820 | for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): 821 | q_start_index = ind * q_bucket_size - qk_len_diff 822 | 823 | col_splits = zip( 824 | k.split(k_bucket_size, dim=-2), 825 | v.split(k_bucket_size, dim=-2), 826 | dk.split(k_bucket_size, dim=-2), 827 | dv.split(k_bucket_size, dim=-2), 828 | ) 829 | 830 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 831 | k_start_index = k_ind * k_bucket_size 832 | 833 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale 834 | 835 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 836 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, 837 | device=device).triu(q_start_index - k_start_index + 1) 838 | attn_weights.masked_fill_(causal_mask, max_neg_value) 839 | 840 | exp_attn_weights = torch.exp(attn_weights - mc) 841 | 842 | if exists(row_mask): 843 | exp_attn_weights.masked_fill_(~row_mask, 0.) 844 | 845 | p = exp_attn_weights / lc 846 | 847 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) 848 | dp = einsum('... i d, ... j d -> ... i j', doc, vc) 849 | 850 | D = (doc * oc).sum(dim=-1, keepdims=True) 851 | ds = p * scale * (dp - D) 852 | 853 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) 854 | dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) 855 | 856 | dqc.add_(dq_chunk) 857 | dkc.add_(dk_chunk) 858 | dvc.add_(dv_chunk) 859 | 860 | return dq, dk, dv, None, None, None, None 861 | 862 | 863 | def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): 864 | if mem_eff_attn: 865 | replace_unet_cross_attn_to_memory_efficient() 866 | elif xformers: 867 | replace_unet_cross_attn_to_xformers() 868 | 869 | 870 | def replace_unet_cross_attn_to_memory_efficient(): 871 | print("Replace CrossAttention.forward to use FlashAttention (not xformers)") 872 | flash_func = FlashAttentionFunction 873 | 874 | def forward_flash_attn(self, x, context=None, mask=None): 875 | q_bucket_size = 512 876 | k_bucket_size = 1024 877 | 878 | h = self.heads 879 | q = self.to_q(x) 880 | 881 | context = context if context is not None else x 882 | context = context.to(x.dtype) 883 | 884 | if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: 885 | context_k, context_v = self.hypernetwork.forward(x, context) 886 | context_k = context_k.to(x.dtype) 887 | context_v = context_v.to(x.dtype) 888 | else: 889 | context_k = context 890 | context_v = context 891 | 892 | k = self.to_k(context_k) 893 | v = self.to_v(context_v) 894 | del context, x 895 | 896 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 897 | 898 | out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) 899 | 900 | out = rearrange(out, 'b h n d -> b n (h d)') 901 | 902 | # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) 903 | out = self.to_out[0](out) 904 | out = self.to_out[1](out) 905 | return out 906 | 907 | diffusers.models.attention.CrossAttention.forward = forward_flash_attn 908 | 909 | 910 | def replace_unet_cross_attn_to_xformers(): 911 | print("Replace CrossAttention.forward to use xformers") 912 | try: 913 | import xformers.ops 914 | except ImportError: 915 | raise ImportError("No xformers / xformersがインストールされていないようです") 916 | 917 | def 
forward_xformers(self, x, context=None, mask=None): 918 | h = self.heads 919 | q_in = self.to_q(x) 920 | 921 | context = default(context, x) 922 | context = context.to(x.dtype) 923 | 924 | if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: 925 | context_k, context_v = self.hypernetwork.forward(x, context) 926 | context_k = context_k.to(x.dtype) 927 | context_v = context_v.to(x.dtype) 928 | else: 929 | context_k = context 930 | context_v = context 931 | 932 | k_in = self.to_k(context_k) 933 | v_in = self.to_v(context_v) 934 | 935 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) 936 | del q_in, k_in, v_in 937 | 938 | q = q.contiguous() 939 | k = k.contiguous() 940 | v = v.contiguous() 941 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる 942 | 943 | out = rearrange(out, 'b n h d -> b n (h d)', h=h) 944 | 945 | # diffusers 0.7.0~ 946 | out = self.to_out[0](out) 947 | out = self.to_out[1](out) 948 | return out 949 | 950 | diffusers.models.attention.CrossAttention.forward = forward_xformers 951 | # endregion 952 | 953 | 954 | if __name__ == '__main__': 955 | # torch.cuda.set_per_process_memory_fraction(0.48) 956 | parser = argparse.ArgumentParser() 957 | parser.add_argument("--v2", action='store_true', 958 | help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') 959 | parser.add_argument("--v_parameterization", action='store_true', 960 | help='enable v-parameterization training / v-parameterization学習を有効にする') 961 | parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, 962 | help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") 963 | parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル") 964 | parser.add_argument("--shuffle_caption", action="store_true", 965 | help="shuffle comma-separated 
caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") 966 | parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") 967 | parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数") 968 | parser.add_argument("--output_dir", type=str, default=None, 969 | help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") 970 | parser.add_argument("--use_safetensors", action='store_true', 971 | help="use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する") 972 | parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") 973 | parser.add_argument("--hypernetwork_module", type=str, default=None, 974 | help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール') 975 | parser.add_argument("--hypernetwork_weights", type=str, default=None, 976 | help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)') 977 | parser.add_argument("--save_every_n_epochs", type=int, default=None, 978 | help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") 979 | parser.add_argument("--save_state", action="store_true", 980 | help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する") 981 | parser.add_argument("--resume", type=str, default=None, 982 | help="saved state to resume training / 学習再開するモデルのstate") 983 | parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], 984 | help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") 985 | parser.add_argument("--train_batch_size", type=int, default=1, 986 | help="batch size for training / 学習時のバッチサイズ") 987 | parser.add_argument("--use_8bit_adam", action="store_true", 988 | help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") 989 | parser.add_argument("--mem_eff_attn", action="store_true", 990 | help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") 991 | parser.add_argument("--xformers", action="store_true", 992 | help="use xformers for CrossAttention / CrossAttentionにxformersを使う") 993 | parser.add_argument("--diffusers_xformers", action='store_true', 994 | help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') 995 | parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") 996 | parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") 997 | parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") 998 | parser.add_argument("--gradient_checkpointing", action="store_true", 999 | help="enable gradient checkpointing / grandient checkpointingを有効にする") 1000 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 1001 | help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") 1002 | parser.add_argument("--mixed_precision", type=str, default="no", 1003 | choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") 1004 | 
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") 1005 | parser.add_argument("--save_precision", type=str, default=None, 1006 | choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") 1007 | parser.add_argument("--clip_skip", type=int, default=None, 1008 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") 1009 | parser.add_argument("--debug_dataset", action="store_true", 1010 | help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") 1011 | parser.add_argument("--logging_dir", type=str, default=None, 1012 | help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") 1013 | parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") 1014 | parser.add_argument("--lr_scheduler", type=str, default="constant", 1015 | help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") 1016 | parser.add_argument("--lr_warmup_steps", type=int, default=0, 1017 | help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") 1018 | 1019 | args = parser.parse_args() 1020 | train(args) 1021 | -------------------------------------------------------------------------------- /model_util.py: -------------------------------------------------------------------------------- 1 | # v1: split from train_db_fixed.py. 
2 | # v2: support safetensors 3 | 4 | import math 5 | import os 6 | import torch 7 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig 8 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 9 | from safetensors.torch import load_file, save_file 10 | 11 | # DiffUsers版StableDiffusionのモデルパラメータ 12 | NUM_TRAIN_TIMESTEPS = 1000 13 | BETA_START = 0.00085 14 | BETA_END = 0.0120 15 | 16 | UNET_PARAMS_MODEL_CHANNELS = 320 17 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] 18 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] 19 | UNET_PARAMS_IMAGE_SIZE = 32 # unused 20 | UNET_PARAMS_IN_CHANNELS = 4 21 | UNET_PARAMS_OUT_CHANNELS = 4 22 | UNET_PARAMS_NUM_RES_BLOCKS = 2 23 | UNET_PARAMS_CONTEXT_DIM = 768 24 | UNET_PARAMS_NUM_HEADS = 8 25 | 26 | VAE_PARAMS_Z_CHANNELS = 4 27 | VAE_PARAMS_RESOLUTION = 256 28 | VAE_PARAMS_IN_CHANNELS = 3 29 | VAE_PARAMS_OUT_CH = 3 30 | VAE_PARAMS_CH = 128 31 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4] 32 | VAE_PARAMS_NUM_RES_BLOCKS = 2 33 | 34 | # V2 35 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] 36 | V2_UNET_PARAMS_CONTEXT_DIM = 1024 37 | 38 | 39 | # region StableDiffusion->Diffusersの変換コード 40 | # convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) 41 | 42 | 43 | def shave_segments(path, n_shave_prefix_segments=1): 44 | """ 45 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
46 | """ 47 | if n_shave_prefix_segments >= 0: 48 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 49 | else: 50 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 51 | 52 | 53 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 54 | """ 55 | Updates paths inside resnets to the new naming scheme (local renaming) 56 | """ 57 | mapping = [] 58 | for old_item in old_list: 59 | new_item = old_item.replace("in_layers.0", "norm1") 60 | new_item = new_item.replace("in_layers.2", "conv1") 61 | 62 | new_item = new_item.replace("out_layers.0", "norm2") 63 | new_item = new_item.replace("out_layers.3", "conv2") 64 | 65 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 66 | new_item = new_item.replace("skip_connection", "conv_shortcut") 67 | 68 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 69 | 70 | mapping.append({"old": old_item, "new": new_item}) 71 | 72 | return mapping 73 | 74 | 75 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 76 | """ 77 | Updates paths inside resnets to the new naming scheme (local renaming) 78 | """ 79 | mapping = [] 80 | for old_item in old_list: 81 | new_item = old_item 82 | 83 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 84 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 85 | 86 | mapping.append({"old": old_item, "new": new_item}) 87 | 88 | return mapping 89 | 90 | 91 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 92 | """ 93 | Updates paths inside attentions to the new naming scheme (local renaming) 94 | """ 95 | mapping = [] 96 | for old_item in old_list: 97 | new_item = old_item 98 | 99 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 100 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 101 | 102 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 103 | # new_item = new_item.replace('proj_out.bias', 
'proj_attn.bias') 104 | 105 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 106 | 107 | mapping.append({"old": old_item, "new": new_item}) 108 | 109 | return mapping 110 | 111 | 112 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 113 | """ 114 | Updates paths inside attentions to the new naming scheme (local renaming) 115 | """ 116 | mapping = [] 117 | for old_item in old_list: 118 | new_item = old_item 119 | 120 | new_item = new_item.replace("norm.weight", "group_norm.weight") 121 | new_item = new_item.replace("norm.bias", "group_norm.bias") 122 | 123 | new_item = new_item.replace("q.weight", "query.weight") 124 | new_item = new_item.replace("q.bias", "query.bias") 125 | 126 | new_item = new_item.replace("k.weight", "key.weight") 127 | new_item = new_item.replace("k.bias", "key.bias") 128 | 129 | new_item = new_item.replace("v.weight", "value.weight") 130 | new_item = new_item.replace("v.bias", "value.bias") 131 | 132 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 133 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 134 | 135 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 136 | 137 | mapping.append({"old": old_item, "new": new_item}) 138 | 139 | return mapping 140 | 141 | 142 | def assign_to_checkpoint( 143 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 144 | ): 145 | """ 146 | This does the final conversion step: take locally converted weights and apply a global renaming 147 | to them. It splits attention layers, and takes into account additional replacements 148 | that may arise. 149 | 150 | Assigns the weights to the new checkpoint. 151 | """ 152 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 153 | 154 | # Splits the attention layers into three variables. 
155 | if attention_paths_to_split is not None: 156 | for path, path_map in attention_paths_to_split.items(): 157 | old_tensor = old_checkpoint[path] 158 | channels = old_tensor.shape[0] // 3 159 | 160 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 161 | 162 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 163 | 164 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 165 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 166 | 167 | checkpoint[path_map["query"]] = query.reshape(target_shape) 168 | checkpoint[path_map["key"]] = key.reshape(target_shape) 169 | checkpoint[path_map["value"]] = value.reshape(target_shape) 170 | 171 | for path in paths: 172 | new_path = path["new"] 173 | 174 | # These have already been assigned 175 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 176 | continue 177 | 178 | # Global renaming happens here 179 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 180 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 181 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 182 | 183 | if additional_replacements is not None: 184 | for replacement in additional_replacements: 185 | new_path = new_path.replace(replacement["old"], replacement["new"]) 186 | 187 | # proj_attn.weight has to be converted from conv 1D to linear 188 | if "proj_attn.weight" in new_path: 189 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 190 | else: 191 | checkpoint[new_path] = old_checkpoint[path["old"]] 192 | 193 | 194 | def conv_attn_to_linear(checkpoint): 195 | keys = list(checkpoint.keys()) 196 | attn_keys = ["query.weight", "key.weight", "value.weight"] 197 | for key in keys: 198 | if ".".join(key.split(".")[-2:]) in attn_keys: 199 | if checkpoint[key].ndim > 2: 200 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 201 | elif "proj_attn.weight" in key: 
202 | if checkpoint[key].ndim > 2: 203 | checkpoint[key] = checkpoint[key][:, :, 0] 204 | 205 | 206 | def linear_transformer_to_conv(checkpoint): 207 | keys = list(checkpoint.keys()) 208 | tf_keys = ["proj_in.weight", "proj_out.weight"] 209 | for key in keys: 210 | if ".".join(key.split(".")[-2:]) in tf_keys: 211 | if checkpoint[key].ndim == 2: 212 | checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) 213 | 214 | 215 | def convert_ldm_unet_checkpoint(v2, checkpoint, config): 216 | """ 217 | Takes a state dict and a config, and returns a converted checkpoint. 218 | """ 219 | 220 | # extract state_dict for UNet 221 | unet_state_dict = {} 222 | unet_key = "model.diffusion_model." 223 | keys = list(checkpoint.keys()) 224 | for key in keys: 225 | if key.startswith(unet_key): 226 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 227 | 228 | new_checkpoint = {} 229 | 230 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 231 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 232 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 233 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 234 | 235 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 236 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 237 | 238 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 239 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 240 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 241 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 242 | 243 | # Retrieves the keys for the input blocks only 244 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 245 | input_blocks = { 246 | layer_id: [key for key in unet_state_dict if 
f"input_blocks.{layer_id}." in key] 247 | for layer_id in range(num_input_blocks) 248 | } 249 | 250 | # Retrieves the keys for the middle blocks only 251 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 252 | middle_blocks = { 253 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] 254 | for layer_id in range(num_middle_blocks) 255 | } 256 | 257 | # Retrieves the keys for the output blocks only 258 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 259 | output_blocks = { 260 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] 261 | for layer_id in range(num_output_blocks) 262 | } 263 | 264 | for i in range(1, num_input_blocks): 265 | block_id = (i - 1) // (config["layers_per_block"] + 1) 266 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 267 | 268 | resnets = [ 269 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 270 | ] 271 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 272 | 273 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 274 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 275 | f"input_blocks.{i}.0.op.weight" 276 | ) 277 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 278 | f"input_blocks.{i}.0.op.bias" 279 | ) 280 | 281 | paths = renew_resnet_paths(resnets) 282 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 283 | assign_to_checkpoint( 284 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 285 | ) 286 | 287 | if len(attentions): 288 | paths = renew_attention_paths(attentions) 289 | meta_path = {"old": f"input_blocks.{i}.1", "new": 
f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 290 | assign_to_checkpoint( 291 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 292 | ) 293 | 294 | resnet_0 = middle_blocks[0] 295 | attentions = middle_blocks[1] 296 | resnet_1 = middle_blocks[2] 297 | 298 | resnet_0_paths = renew_resnet_paths(resnet_0) 299 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 300 | 301 | resnet_1_paths = renew_resnet_paths(resnet_1) 302 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 303 | 304 | attentions_paths = renew_attention_paths(attentions) 305 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 306 | assign_to_checkpoint( 307 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 308 | ) 309 | 310 | for i in range(num_output_blocks): 311 | block_id = i // (config["layers_per_block"] + 1) 312 | layer_in_block_id = i % (config["layers_per_block"] + 1) 313 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 314 | output_block_list = {} 315 | 316 | for layer in output_block_layers: 317 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 318 | if layer_id in output_block_list: 319 | output_block_list[layer_id].append(layer_name) 320 | else: 321 | output_block_list[layer_id] = [layer_name] 322 | 323 | if len(output_block_list) > 1: 324 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 325 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 326 | 327 | resnet_0_paths = renew_resnet_paths(resnets) 328 | paths = renew_resnet_paths(resnets) 329 | 330 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 331 | assign_to_checkpoint( 332 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 333 | ) 
334 | 335 | # オリジナル: 336 | # if ["conv.weight", "conv.bias"] in output_block_list.values(): 337 | # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 338 | 339 | # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが 340 | for l in output_block_list.values(): 341 | l.sort() 342 | 343 | if ["conv.bias", "conv.weight"] in output_block_list.values(): 344 | index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) 345 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 346 | f"output_blocks.{i}.{index}.conv.bias" 347 | ] 348 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 349 | f"output_blocks.{i}.{index}.conv.weight" 350 | ] 351 | 352 | # Clear attentions as they have been attributed above. 353 | if len(attentions) == 2: 354 | attentions = [] 355 | 356 | if len(attentions): 357 | paths = renew_attention_paths(attentions) 358 | meta_path = { 359 | "old": f"output_blocks.{i}.1", 360 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 361 | } 362 | assign_to_checkpoint( 363 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 364 | ) 365 | else: 366 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 367 | for path in resnet_0_paths: 368 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 369 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 370 | 371 | new_checkpoint[new_path] = unet_state_dict[old_path] 372 | 373 | # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する 374 | if v2: 375 | linear_transformer_to_conv(new_checkpoint) 376 | 377 | return new_checkpoint 378 | 379 | 380 | def convert_ldm_vae_checkpoint(checkpoint, config): 381 | # extract state dict for VAE 382 | vae_state_dict = {} 383 | vae_key = "first_stage_model." 
384 | keys = list(checkpoint.keys()) 385 | for key in keys: 386 | if key.startswith(vae_key): 387 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 388 | # if len(vae_state_dict) == 0: 389 | # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict 390 | # vae_state_dict = checkpoint 391 | 392 | new_checkpoint = {} 393 | 394 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 395 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 396 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 397 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 398 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 399 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 400 | 401 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 402 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 403 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 404 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 405 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 406 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 407 | 408 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 409 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 410 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 411 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 412 | 413 | # Retrieves the keys for the encoder down blocks only 414 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 415 | down_blocks = { 416 | layer_id: 
[key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 417 | } 418 | 419 | # Retrieves the keys for the decoder up blocks only 420 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 421 | up_blocks = { 422 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 423 | } 424 | 425 | for i in range(num_down_blocks): 426 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 427 | 428 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 429 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 430 | f"encoder.down.{i}.downsample.conv.weight" 431 | ) 432 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 433 | f"encoder.down.{i}.downsample.conv.bias" 434 | ) 435 | 436 | paths = renew_vae_resnet_paths(resnets) 437 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 438 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 439 | 440 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 441 | num_mid_res_blocks = 2 442 | for i in range(1, num_mid_res_blocks + 1): 443 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 444 | 445 | paths = renew_vae_resnet_paths(resnets) 446 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 447 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 448 | 449 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 450 | paths = renew_vae_attention_paths(mid_attentions) 451 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 452 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, 
additional_replacements=[meta_path], config=config) 453 | conv_attn_to_linear(new_checkpoint) 454 | 455 | for i in range(num_up_blocks): 456 | block_id = num_up_blocks - 1 - i 457 | resnets = [ 458 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 459 | ] 460 | 461 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 462 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 463 | f"decoder.up.{block_id}.upsample.conv.weight" 464 | ] 465 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 466 | f"decoder.up.{block_id}.upsample.conv.bias" 467 | ] 468 | 469 | paths = renew_vae_resnet_paths(resnets) 470 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 471 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 472 | 473 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 474 | num_mid_res_blocks = 2 475 | for i in range(1, num_mid_res_blocks + 1): 476 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 477 | 478 | paths = renew_vae_resnet_paths(resnets) 479 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 480 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 481 | 482 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 483 | paths = renew_vae_attention_paths(mid_attentions) 484 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 485 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 486 | conv_attn_to_linear(new_checkpoint) 487 | return new_checkpoint 488 | 489 | 490 | def create_unet_diffusers_config(v2): 491 | """ 492 | Creates a config for the diffusers based on the config of the LDM model. 
493 | """ 494 | # unet_params = original_config.model.params.unet_config.params 495 | 496 | block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] 497 | 498 | down_block_types = [] 499 | resolution = 1 500 | for i in range(len(block_out_channels)): 501 | block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" 502 | down_block_types.append(block_type) 503 | if i != len(block_out_channels) - 1: 504 | resolution *= 2 505 | 506 | up_block_types = [] 507 | for i in range(len(block_out_channels)): 508 | block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" 509 | up_block_types.append(block_type) 510 | resolution //= 2 511 | 512 | config = dict( 513 | sample_size=UNET_PARAMS_IMAGE_SIZE, 514 | in_channels=UNET_PARAMS_IN_CHANNELS, 515 | out_channels=UNET_PARAMS_OUT_CHANNELS, 516 | down_block_types=tuple(down_block_types), 517 | up_block_types=tuple(up_block_types), 518 | block_out_channels=tuple(block_out_channels), 519 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, 520 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, 521 | attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, 522 | ) 523 | 524 | return config 525 | 526 | 527 | def create_vae_diffusers_config(): 528 | """ 529 | Creates a config for the diffusers based on the config of the LDM model. 
530 | """ 531 | # vae_params = original_config.model.params.first_stage_config.params.ddconfig 532 | # _ = original_config.model.params.first_stage_config.params.embed_dim 533 | block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] 534 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 535 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 536 | 537 | config = dict( 538 | sample_size=VAE_PARAMS_RESOLUTION, 539 | in_channels=VAE_PARAMS_IN_CHANNELS, 540 | out_channels=VAE_PARAMS_OUT_CH, 541 | down_block_types=tuple(down_block_types), 542 | up_block_types=tuple(up_block_types), 543 | block_out_channels=tuple(block_out_channels), 544 | latent_channels=VAE_PARAMS_Z_CHANNELS, 545 | layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, 546 | ) 547 | return config 548 | 549 | 550 | def convert_ldm_clip_checkpoint_v1(checkpoint): 551 | keys = list(checkpoint.keys()) 552 | text_model_dict = {} 553 | for key in keys: 554 | if key.startswith("cond_stage_model.transformer"): 555 | text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] 556 | return text_model_dict 557 | 558 | 559 | def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): 560 | # 嫌になるくらい違うぞ! 561 | def convert_key(key): 562 | if not key.startswith("cond_stage_model"): 563 | return None 564 | 565 | # common conversion 566 | key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") 567 | key = key.replace("cond_stage_model.model.", "text_model.") 568 | 569 | if "resblocks" in key: 570 | # resblocks conversion 571 | key = key.replace(".resblocks.", ".layers.") 572 | if ".ln_" in key: 573 | key = key.replace(".ln_", ".layer_norm") 574 | elif ".mlp." 
in key: 575 | key = key.replace(".c_fc.", ".fc1.") 576 | key = key.replace(".c_proj.", ".fc2.") 577 | elif '.attn.out_proj' in key: 578 | key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") 579 | elif '.attn.in_proj' in key: 580 | key = None # 特殊なので後で処理する 581 | else: 582 | raise ValueError(f"unexpected key in SD: {key}") 583 | elif '.positional_embedding' in key: 584 | key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") 585 | elif '.text_projection' in key: 586 | key = None # 使われない??? 587 | elif '.logit_scale' in key: 588 | key = None # 使われない??? 589 | elif '.token_embedding' in key: 590 | key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") 591 | elif '.ln_final' in key: 592 | key = key.replace(".ln_final", ".final_layer_norm") 593 | return key 594 | 595 | keys = list(checkpoint.keys()) 596 | new_sd = {} 597 | for key in keys: 598 | # remove resblocks 23 599 | if '.resblocks.23.' in key: 600 | continue 601 | new_key = convert_key(key) 602 | if new_key is None: 603 | continue 604 | new_sd[new_key] = checkpoint[key] 605 | 606 | # attnの変換 607 | for key in keys: 608 | if '.resblocks.23.' 
in key: 609 | continue 610 | if '.resblocks' in key and '.attn.in_proj_' in key: 611 | # 三つに分割 612 | values = torch.chunk(checkpoint[key], 3) 613 | 614 | key_suffix = ".weight" if "weight" in key else ".bias" 615 | key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") 616 | key_pfx = key_pfx.replace("_weight", "") 617 | key_pfx = key_pfx.replace("_bias", "") 618 | key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") 619 | new_sd[key_pfx + "q_proj" + key_suffix] = values[0] 620 | new_sd[key_pfx + "k_proj" + key_suffix] = values[1] 621 | new_sd[key_pfx + "v_proj" + key_suffix] = values[2] 622 | 623 | # position_idsの追加 624 | new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) 625 | return new_sd 626 | 627 | # endregion 628 | 629 | 630 | # region Diffusers->StableDiffusion の変換コード 631 | # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) 632 | 633 | def conv_transformer_to_linear(checkpoint): 634 | keys = list(checkpoint.keys()) 635 | tf_keys = ["proj_in.weight", "proj_out.weight"] 636 | for key in keys: 637 | if ".".join(key.split(".")[-2:]) in tf_keys: 638 | if checkpoint[key].ndim > 2: 639 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 640 | 641 | 642 | def convert_unet_state_dict_to_sd(v2, unet_state_dict): 643 | unet_conversion_map = [ 644 | # (stable-diffusion, HF Diffusers) 645 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 646 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 647 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 648 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 649 | ("input_blocks.0.0.weight", "conv_in.weight"), 650 | ("input_blocks.0.0.bias", "conv_in.bias"), 651 | ("out.0.weight", "conv_norm_out.weight"), 652 | ("out.0.bias", "conv_norm_out.bias"), 653 | ("out.2.weight", "conv_out.weight"), 654 | ("out.2.bias", "conv_out.bias"), 655 | ] 656 | 657 | unet_conversion_map_resnet = [ 
658 | # (stable-diffusion, HF Diffusers) 659 | ("in_layers.0", "norm1"), 660 | ("in_layers.2", "conv1"), 661 | ("out_layers.0", "norm2"), 662 | ("out_layers.3", "conv2"), 663 | ("emb_layers.1", "time_emb_proj"), 664 | ("skip_connection", "conv_shortcut"), 665 | ] 666 | 667 | unet_conversion_map_layer = [] 668 | for i in range(4): 669 | # loop over downblocks/upblocks 670 | 671 | for j in range(2): 672 | # loop over resnets/attentions for downblocks 673 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 674 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 675 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 676 | 677 | if i < 3: 678 | # no attention layers in down_blocks.3 679 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 680 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 681 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 682 | 683 | for j in range(3): 684 | # loop over resnets/attentions for upblocks 685 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 686 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 687 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 688 | 689 | if i > 0: 690 | # no attention layers in up_blocks.0 691 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 692 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 693 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 694 | 695 | if i < 3: 696 | # no downsample in down_blocks.3 697 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 698 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 699 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 700 | 701 | # no upsample in up_blocks.3 702 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 703 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 
704 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 705 | 706 | hf_mid_atn_prefix = "mid_block.attentions.0." 707 | sd_mid_atn_prefix = "middle_block.1." 708 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 709 | 710 | for j in range(2): 711 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 712 | sd_mid_res_prefix = f"middle_block.{2*j}." 713 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 714 | 715 | # buyer beware: this is a *brittle* function, 716 | # and correct output requires that all of these pieces interact in 717 | # the exact order in which I have arranged them. 718 | mapping = {k: k for k in unet_state_dict.keys()} 719 | for sd_name, hf_name in unet_conversion_map: 720 | mapping[hf_name] = sd_name 721 | for k, v in mapping.items(): 722 | if "resnets" in k: 723 | for sd_part, hf_part in unet_conversion_map_resnet: 724 | v = v.replace(hf_part, sd_part) 725 | mapping[k] = v 726 | for k, v in mapping.items(): 727 | for sd_part, hf_part in unet_conversion_map_layer: 728 | v = v.replace(hf_part, sd_part) 729 | mapping[k] = v 730 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 731 | 732 | if v2: 733 | conv_transformer_to_linear(new_state_dict) 734 | 735 | return new_state_dict 736 | 737 | 738 | # ================# 739 | # VAE Conversion # 740 | # ================# 741 | 742 | def reshape_weight_for_sd(w): 743 | # convert HF linear weights to SD conv2d weights 744 | return w.reshape(*w.shape, 1, 1) 745 | 746 | 747 | def convert_vae_state_dict(vae_state_dict): 748 | vae_conversion_map = [ 749 | # (stable-diffusion, HF Diffusers) 750 | ("nin_shortcut", "conv_shortcut"), 751 | ("norm_out", "conv_norm_out"), 752 | ("mid.attn_1.", "mid_block.attentions.0."), 753 | ] 754 | 755 | for i in range(4): 756 | # down_blocks have two resnets 757 | for j in range(2): 758 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 
759 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 760 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 761 | 762 | if i < 3: 763 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 764 | sd_downsample_prefix = f"down.{i}.downsample." 765 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 766 | 767 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 768 | sd_upsample_prefix = f"up.{3-i}.upsample." 769 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 770 | 771 | # up_blocks have three resnets 772 | # also, up blocks in hf are numbered in reverse from sd 773 | for j in range(3): 774 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 775 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 776 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 777 | 778 | # this part accounts for mid blocks in both the encoder and the decoder 779 | for i in range(2): 780 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 781 | sd_mid_res_prefix = f"mid.block_{i+1}." 
782 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 783 | 784 | vae_conversion_map_attn = [ 785 | # (stable-diffusion, HF Diffusers) 786 | ("norm.", "group_norm."), 787 | ("q.", "query."), 788 | ("k.", "key."), 789 | ("v.", "value."), 790 | ("proj_out.", "proj_attn."), 791 | ] 792 | 793 | mapping = {k: k for k in vae_state_dict.keys()} 794 | for k, v in mapping.items(): 795 | for sd_part, hf_part in vae_conversion_map: 796 | v = v.replace(hf_part, sd_part) 797 | mapping[k] = v 798 | for k, v in mapping.items(): 799 | if "attentions" in k: 800 | for sd_part, hf_part in vae_conversion_map_attn: 801 | v = v.replace(hf_part, sd_part) 802 | mapping[k] = v 803 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 804 | weights_to_convert = ["q", "k", "v", "proj_out"] 805 | for k, v in new_state_dict.items(): 806 | for weight_name in weights_to_convert: 807 | if f"mid.attn_1.{weight_name}.weight" in k: 808 | # print(f"Reshaping {k} for SD format") 809 | new_state_dict[k] = reshape_weight_for_sd(v) 810 | 811 | return new_state_dict 812 | 813 | 814 | # endregion 815 | 816 | # region 自作のモデル読み書きなど 817 | 818 | def is_safetensors(path): 819 | return os.path.splitext(path)[1].lower() == '.safetensors' 820 | 821 | 822 | def load_checkpoint_with_text_encoder_conversion(ckpt_path): 823 | # text encoderの格納形式が違うモデルに対応する ('text_model'がない) 824 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 825 | ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), 826 | ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), 827 | ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') 828 | ] 829 | 830 | if is_safetensors(ckpt_path): 831 | checkpoint = None 832 | state_dict = load_file(ckpt_path, "cpu") 833 | else: 834 | checkpoint = torch.load(ckpt_path, map_location="cpu") 835 | if "state_dict" in checkpoint: 836 | state_dict = 
checkpoint["state_dict"] 837 | else: 838 | state_dict = checkpoint 839 | checkpoint = None 840 | 841 | key_reps = [] 842 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 843 | for key in state_dict.keys(): 844 | if key.startswith(rep_from): 845 | new_key = rep_to + key[len(rep_from):] 846 | key_reps.append((key, new_key)) 847 | 848 | for key, new_key in key_reps: 849 | state_dict[new_key] = state_dict[key] 850 | del state_dict[key] 851 | 852 | return checkpoint, state_dict 853 | 854 | 855 | # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 856 | def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): 857 | _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) 858 | if dtype is not None: 859 | for k, v in state_dict.items(): 860 | if type(v) is torch.Tensor: 861 | state_dict[k] = v.to(dtype) 862 | 863 | # Convert the UNet2DConditionModel model. 864 | unet_config = create_unet_diffusers_config(v2) 865 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) 866 | 867 | unet = UNet2DConditionModel(**unet_config) 868 | info = unet.load_state_dict(converted_unet_checkpoint) 869 | print("loading u-net:", info) 870 | 871 | # Convert the VAE model. 
872 | vae_config = create_vae_diffusers_config() 873 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) 874 | 875 | vae = AutoencoderKL(**vae_config) 876 | info = vae.load_state_dict(converted_vae_checkpoint) 877 | print("loadint vae:", info) 878 | 879 | # convert text_model 880 | if v2: 881 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) 882 | cfg = CLIPTextConfig( 883 | vocab_size=49408, 884 | hidden_size=1024, 885 | intermediate_size=4096, 886 | num_hidden_layers=23, 887 | num_attention_heads=16, 888 | max_position_embeddings=77, 889 | hidden_act="gelu", 890 | layer_norm_eps=1e-05, 891 | dropout=0.0, 892 | attention_dropout=0.0, 893 | initializer_range=0.02, 894 | initializer_factor=1.0, 895 | pad_token_id=1, 896 | bos_token_id=0, 897 | eos_token_id=2, 898 | model_type="clip_text_model", 899 | projection_dim=512, 900 | torch_dtype="float32", 901 | transformers_version="4.25.0.dev0", 902 | ) 903 | text_model = CLIPTextModel._from_config(cfg) 904 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 905 | else: 906 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) 907 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 908 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 909 | print("loading text encoder:", info) 910 | 911 | return text_model, vae, unet 912 | 913 | 914 | def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): 915 | def convert_key(key): 916 | # position_idsの除去 917 | if ".position_ids" in key: 918 | return None 919 | 920 | # common 921 | key = key.replace("text_model.encoder.", "transformer.") 922 | key = key.replace("text_model.", "") 923 | if "layers" in key: 924 | # resblocks conversion 925 | key = key.replace(".layers.", ".resblocks.") 926 | if ".layer_norm" in key: 927 | key = key.replace(".layer_norm", ".ln_") 928 | elif ".mlp." 
in key: 929 | key = key.replace(".fc1.", ".c_fc.") 930 | key = key.replace(".fc2.", ".c_proj.") 931 | elif '.self_attn.out_proj' in key: 932 | key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") 933 | elif '.self_attn.' in key: 934 | key = None # 特殊なので後で処理する 935 | else: 936 | raise ValueError(f"unexpected key in DiffUsers model: {key}") 937 | elif '.position_embedding' in key: 938 | key = key.replace("embeddings.position_embedding.weight", "positional_embedding") 939 | elif '.token_embedding' in key: 940 | key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") 941 | elif 'final_layer_norm' in key: 942 | key = key.replace("final_layer_norm", "ln_final") 943 | return key 944 | 945 | keys = list(checkpoint.keys()) 946 | new_sd = {} 947 | for key in keys: 948 | new_key = convert_key(key) 949 | if new_key is None: 950 | continue 951 | new_sd[new_key] = checkpoint[key] 952 | 953 | # attnの変換 954 | for key in keys: 955 | if 'layers' in key and 'q_proj' in key: 956 | # 三つを結合 957 | key_q = key 958 | key_k = key.replace("q_proj", "k_proj") 959 | key_v = key.replace("q_proj", "v_proj") 960 | 961 | value_q = checkpoint[key_q] 962 | value_k = checkpoint[key_k] 963 | value_v = checkpoint[key_v] 964 | value = torch.cat([value_q, value_k, value_v]) 965 | 966 | new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") 967 | new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") 968 | new_sd[new_key] = value 969 | 970 | # 最後の層などを捏造するか 971 | if make_dummy_weights: 972 | print("make dummy weights for resblock.23, text_projection and logit scale.") 973 | keys = list(new_sd.keys()) 974 | for key in keys: 975 | if key.startswith("transformer.resblocks.22."): 976 | new_sd[key.replace(".22.", ".23.")] = new_sd[key] 977 | 978 | # Diffusersに含まれない重みを作っておく 979 | new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) 980 | new_sd['logit_scale'] = torch.tensor(1) 981 | 982 | 
return new_sd 983 | 984 | 985 | def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): 986 | if ckpt_path is not None: 987 | # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む 988 | checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) 989 | if checkpoint is None: # safetensors または state_dictのckpt 990 | checkpoint = {} 991 | strict = False 992 | else: 993 | strict = True 994 | if "state_dict" in state_dict: 995 | del state_dict["state_dict"] 996 | else: 997 | # 新しく作る 998 | checkpoint = {} 999 | state_dict = {} 1000 | strict = False 1001 | 1002 | def update_sd(prefix, sd): 1003 | for k, v in sd.items(): 1004 | key = prefix + k 1005 | assert not strict or key in state_dict, f"Illegal key in save SD: {key}" 1006 | if save_dtype is not None: 1007 | v = v.detach().clone().to("cpu").to(save_dtype) 1008 | state_dict[key] = v 1009 | 1010 | # Convert the UNet model 1011 | unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) 1012 | update_sd("model.diffusion_model.", unet_state_dict) 1013 | 1014 | # Convert the text encoder model 1015 | if v2: 1016 | make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる 1017 | text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) 1018 | update_sd("cond_stage_model.model.", text_enc_dict) 1019 | else: 1020 | text_enc_dict = text_encoder.state_dict() 1021 | update_sd("cond_stage_model.transformer.", text_enc_dict) 1022 | 1023 | # Convert the VAE 1024 | if vae is not None: 1025 | vae_dict = convert_vae_state_dict(vae.state_dict()) 1026 | update_sd("first_stage_model.", vae_dict) 1027 | 1028 | # Put together new checkpoint 1029 | key_count = len(state_dict.keys()) 1030 | new_ckpt = {'state_dict': state_dict} 1031 | 1032 | if 'epoch' in checkpoint: 1033 | epochs += checkpoint['epoch'] 1034 | if 'global_step' in checkpoint: 1035 | steps += 
checkpoint['global_step'] 1036 | 1037 | new_ckpt['epoch'] = epochs 1038 | new_ckpt['global_step'] = steps 1039 | 1040 | if is_safetensors(output_file): 1041 | # TODO Tensor以外のdictの値を削除したほうがいいか 1042 | save_file(state_dict, output_file) 1043 | else: 1044 | torch.save(new_ckpt, output_file) 1045 | 1046 | return key_count 1047 | 1048 | 1049 | def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): 1050 | if vae is None: 1051 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 1052 | pipeline = StableDiffusionPipeline( 1053 | unet=unet, 1054 | text_encoder=text_encoder, 1055 | vae=vae, 1056 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), 1057 | tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), 1058 | safety_checker=None, 1059 | feature_extractor=None, 1060 | requires_safety_checker=None, 1061 | ) 1062 | pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) 1063 | 1064 | 1065 | VAE_PREFIX = "first_stage_model." 
def load_vae(vae_id, dtype):
    """Load a VAE and return it as a Diffusers ``AutoencoderKL``.

    ``vae_id`` may be one of:

    * anything that is not an existing file (a Diffusers model directory or a
      Hugging Face model id): loaded with ``AutoencoderKL.from_pretrained``
      from the root, retried with ``subfolder="vae"`` if that raises
      ``EnvironmentError``;
    * a ``.bin`` file holding a Diffusers-format VAE state dict;
    * any other file: a Stable Diffusion (LDM-layout) checkpoint, either a
      full model or a VAE-only one, read with ``safetensors`` when the
      extension is ``.safetensors`` and with ``torch.load`` otherwise, then
      converted to the Diffusers layout.

    Args:
        vae_id: path or model id, as described above.
        dtype: torch dtype to cast the weights to, or ``None`` to keep the
            stored dtype.

    Returns:
        The loaded ``AutoencoderKL``.
    """
    print(f"load VAE: {vae_id}")
    # A directory is never a file, so this also covers local Diffusers folders.
    if not os.path.isfile(vae_id):
        # Diffusers local/remote
        try:
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
        except EnvironmentError as e:
            print(f"exception occurs in loading vae: {e}")
            print("retry with subfolder='vae'")
            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
        return vae

    # local single-file weights
    vae_config = create_vae_diffusers_config()

    if vae_id.endswith(".bin"):
        # SD 1.5 VAE on Huggingface: already a Diffusers-format state dict.
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
    else:
        # StableDiffusion (LDM) format checkpoint
        if is_safetensors(vae_id):
            vae_model = load_file(vae_id, "cpu")
        else:
            # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
            vae_model = torch.load(vae_id, map_location="cpu")
        # Some checkpoints wrap the weights in "state_dict", others (including
        # every safetensors file) are bare state dicts.
        vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model

        # A VAE-only checkpoint lacks the "first_stage_model." prefix that the
        # LDM->Diffusers converter expects, so add it.
        if not any(key.startswith(VAE_PREFIX) for key in vae_sd):
            vae_sd = {VAE_PREFIX + key: value for key, value in vae_sd.items()}

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    if dtype is not None:
        # Honour dtype here too, as the from_pretrained branch does via torch_dtype.
        vae = vae.to(dtype)
    return vae


def _ckpt_extension(use_safetensors):
    """Return the checkpoint file extension for the chosen serialization format."""
    return ".safetensors" if use_safetensors else ".ckpt"


def get_epoch_ckpt_name(use_safetensors, epoch):
    """Return the file name for the checkpoint saved at ``epoch``, e.g. ``epoch-000003.ckpt``."""
    return f"epoch-{epoch:06d}" + _ckpt_extension(use_safetensors)


def get_last_ckpt_name(use_safetensors):
    """Return the file name for the final checkpoint, e.g. ``last.safetensors``."""
    return "last" + _ckpt_extension(use_safetensors)


# endregion


def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
    """Build the (width, height) resolutions used for aspect-ratio bucketing.

    The pixel budget is ``max_reso`` measured in ``divisible``-sized cells:
    ``max_area = (max_width // divisible) * (max_height // divisible)``.
    The returned set contains:

    * the largest square whose side is a multiple of ``divisible`` and whose
      cell count does not exceed ``max_area``;
    * for each width from ``min_size`` to ``max_size`` in steps of
      ``divisible``, the tallest height (a multiple of ``divisible``, capped at
      ``max_size``) within the budget, plus its transposed resolution.

    Args:
        max_reso: ``(max_width, max_height)`` defining the pixel budget.
        min_size: smallest width to enumerate; must be >= ``divisible``.
        max_size: largest width to enumerate and cap for every height.
        divisible: granularity of the bucket sides; must be positive.

    Returns:
        ``(resos, aspect_ratios)``: the sorted list of unique ``(w, h)``
        tuples and the matching list of ``w / h`` floats.

    Raises:
        ValueError: if ``divisible`` is not positive or ``min_size`` is
            smaller than ``divisible`` (which would divide by zero below).
    """
    if divisible <= 0:
        raise ValueError(f"divisible must be positive, got {divisible}")
    if min_size < divisible:
        raise ValueError(f"min_size ({min_size}) must be at least divisible ({divisible})")

    max_width, max_height = max_reso
    max_area = (max_width // divisible) * (max_height // divisible)

    resos = set()

    size = int(math.sqrt(max_area)) * divisible
    resos.add((size, size))

    size = min_size
    while size <= max_size:
        width = size
        height = min(max_size, (max_area // (width // divisible)) * divisible)
        resos.add((width, height))
        resos.add((height, width))
        size += divisible

    resos = sorted(resos)
    aspect_ratios = [w / h for w, h in resos]
    return resos, aspect_ratios


if __name__ == '__main__':
    # Quick self-check: print the default buckets for a 512x768 budget and
    # report any aspect ratio that appears more than once.
    resos, aspect_ratios = make_bucket_resolutions((512, 768))
    print(len(resos))
    print(resos)
    print(aspect_ratios)

    ars = set()
    for ar in aspect_ratios:
        if ar in ars:
            print("error! duplicate ar:", ar)
        ars.add(ar)