├── library ├── __init__.py ├── utils.py ├── huggingface_util.py ├── hypernetwork.py └── attention_processors.py ├── setup.py ├── finetune ├── blip │ ├── med_config.json │ └── blip.py ├── merge_dd_tags_to_metadata.py ├── merge_captions_to_metadata.py ├── hypernetwork_nai.py ├── clean_captions_and_tags.py ├── make_captions_by_git.py ├── make_captions.py ├── merge_all_to_metadata.py └── prepare_buckets_latents.py ├── requirements.txt ├── tools ├── canny.py ├── merge_block_weighted.py ├── merge_vae.py ├── resize_images_to_resolution.py ├── convert_diffusers20_original_sd.py ├── cache_text_encoder_outputs.py ├── cache_latents.py └── detect_face_rotate.py ├── networks ├── check_lora_weights.py ├── extract_lora_from_dylora.py ├── lora_interrogator.py ├── merge_lora_old.py ├── svd_merge_lora.py ├── merge_lora.py ├── sdxl_merge_lora.py └── extract_lora_from_models.py ├── sdxl_train_textual_inversion.py ├── XTI_hijack.py ├── sdxl_train_network.py └── LICENSE.md /library/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name = "library", packages = find_packages()) -------------------------------------------------------------------------------- /library/utils.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from typing import * 3 | 4 | 5 | def fire_in_thread(f, *args, **kwargs): 6 | threading.Thread(target=f, args=args, kwargs=kwargs).start() -------------------------------------------------------------------------------- /finetune/blip/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.19.0 2 | salesforce-lavis==1.0.0 3 | transformers==4.30.2 4 | diffusers[torch]==0.18.2 5 | ftfy==6.1.1 6 | # albumentations==1.3.0 7 | opencv-python==4.7.0.68 8 | einops==0.6.0 9 | pytorch-lightning==1.9.0 10 | bitsandbytes==0.39.1 11 | safetensors==0.3.1 12 | toml==0.10.2 13 | voluptuous==0.13.1 14 | huggingface-hub==0.15.1 15 | wandb==0.15.7 16 | # for loading Diffusers' SDXL 17 | invisible-watermark==0.2.0 18 | open-clip-torch==2.20.0 19 | # for kohya trainer 20 | gallery-dl==1.25.6 21 | gdown==4.7.1 22 | imjoy-elfinder==0.1.61 23 | dadaptation==3.1 24 | lion-pytorch==0.1.2 25 | lycoris-lora==1.8.1.dev3 26 | # for kohya_ss library 27 | -e . 
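# A minimal install sketch (assuming the repository root as the working directory):
#   pip install -r requirements.txt
# The trailing "-e ." entry installs the local "library" package defined in setup.py in
# editable mode, so the finetune/, tools/, and networks/ scripts can import library.*.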
28 | -------------------------------------------------------------------------------- /tools/canny.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | 4 | 5 | def canny(args): 6 | img = cv2.imread(args.input) 7 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 8 | 9 | canny_img = cv2.Canny(img, args.thres1, args.thres2) 10 | # canny_img = 255 - canny_img 11 | 12 | cv2.imwrite(args.output, canny_img) 13 | print("done!") 14 | 15 | 16 | def setup_parser() -> argparse.ArgumentParser: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--input", type=str, default=None, help="input path") 19 | parser.add_argument("--output", type=str, default=None, help="output path") 20 | parser.add_argument("--thres1", type=int, default=32, help="thres1") 21 | parser.add_argument("--thres2", type=int, default=224, help="thres2") 22 | 23 | return parser 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = setup_parser() 28 | 29 | args = parser.parse_args() 30 | canny(args) 31 | -------------------------------------------------------------------------------- /networks/check_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from safetensors.torch import load_file 5 | 6 | 7 | def main(file): 8 | print(f"loading: {file}") 9 | if os.path.splitext(file)[1] == '.safetensors': 10 | sd = load_file(file) 11 | else: 12 | sd = torch.load(file, map_location='cpu') 13 | 14 | values = [] 15 | 16 | keys = list(sd.keys()) 17 | for key in keys: 18 | if 'lora_up' in key or 'lora_down' in key: 19 | values.append((key, sd[key])) 20 | print(f"number of LoRA modules: {len(values)}") 21 | 22 | for key, value in values: 23 | value = value.to(torch.float32) 24 | print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") 25 | 26 | 27 | def setup_parser() -> argparse.ArgumentParser: 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") 30 | 31 | return parser 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = setup_parser() 36 | 37 | args = parser.parse_args() 38 | 39 | main(args.file) 40 | -------------------------------------------------------------------------------- /library/huggingface_util.py: -------------------------------------------------------------------------------- 1 | from typing import Union, BinaryIO 2 | from huggingface_hub import HfApi 3 | from pathlib import Path 4 | import argparse 5 | import os 6 | from library.utils import fire_in_thread 7 | 8 | 9 | def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): 10 | api = HfApi( 11 | token=token, 12 | ) 13 | try: 14 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 15 | return True 16 | except: 17 | return False 18 | 19 | 20 | def upload( 21 | args: argparse.Namespace, 22 | src: Union[str, Path, bytes, BinaryIO], 23 | dest_suffix: str = "", 24 | force_sync_upload: bool = False, 25 | ): 26 | repo_id = args.huggingface_repo_id 27 | repo_type = args.huggingface_repo_type 28 | token = args.huggingface_token 29 | path_in_repo = args.huggingface_path_in_repo + dest_suffix 30 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" 31 | api = HfApi(token=token) 32 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): 33 | try: 34 | 
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) 35 | except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので 36 | print("===========================================") 37 | print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") 38 | print("===========================================") 39 | 40 | is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) 41 | 42 | def uploader(): 43 | try: 44 | if is_folder: 45 | api.upload_folder( 46 | repo_id=repo_id, 47 | repo_type=repo_type, 48 | folder_path=src, 49 | path_in_repo=path_in_repo, 50 | ) 51 | else: 52 | api.upload_file( 53 | repo_id=repo_id, 54 | repo_type=repo_type, 55 | path_or_fileobj=src, 56 | path_in_repo=path_in_repo, 57 | ) 58 | except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので 59 | print("===========================================") 60 | print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") 61 | print("===========================================") 62 | 63 | if args.async_upload and not force_sync_upload: 64 | fire_in_thread(uploader) 65 | else: 66 | uploader() 67 | 68 | 69 | def list_dir( 70 | repo_id: str, 71 | subfolder: str, 72 | repo_type: str, 73 | revision: str = "main", 74 | token: str = None, 75 | ): 76 | api = HfApi( 77 | token=token, 78 | ) 79 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 80 | file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] 81 | return file_list 82 | -------------------------------------------------------------------------------- /finetune/merge_dd_tags_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge tags to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | tags_path = image_path.with_suffix(args.caption_extension) 30 | tags = tags_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(tags_path): 33 | tags_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['tags'] = tags 40 | if args.debug: 41 | print(image_key, tags) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 46 | 47 | print("done!") 48 | 49 | 50 | def setup_parser() -> 
argparse.ArgumentParser: 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 53 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 54 | parser.add_argument("--in_json", type=str, 55 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 56 | parser.add_argument("--full_path", action="store_true", 57 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 58 | parser.add_argument("--recursive", action="store_true", 59 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 60 | parser.add_argument("--caption_extension", type=str, default=".txt", 61 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") 62 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /finetune/merge_captions_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge caption texts to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | caption_path = image_path.with_suffix(args.caption_extension) 30 | caption = caption_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(caption_path): 33 | caption_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['caption'] = caption 40 | if args.debug: 41 | print(image_key, caption) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 46 | print("done!") 47 | 48 | 49 | def setup_parser() -> argparse.ArgumentParser: 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 52 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 53 | parser.add_argument("--in_json", type=str, 54 | 
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 55 | parser.add_argument("--caption_extention", type=str, default=None, 56 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 57 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") 58 | parser.add_argument("--full_path", action="store_true", 59 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 60 | parser.add_argument("--recursive", action="store_true", 61 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 62 | parser.add_argument("--debug", action="store_true", help="debug mode") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | 72 | # スペルミスしていたオプションを復元する 73 | if args.caption_extention is not None: 74 | args.caption_extension = args.caption_extention 75 | 76 | main(args) 77 | -------------------------------------------------------------------------------- /finetune/hypernetwork_nai.py: -------------------------------------------------------------------------------- 1 | # NAI compatible 2 | 3 | import torch 4 | 5 | 6 | class HypernetworkModule(torch.nn.Module): 7 | def __init__(self, dim, multiplier=1.0): 8 | super().__init__() 9 | 10 | linear1 = torch.nn.Linear(dim, dim * 2) 11 | linear2 = torch.nn.Linear(dim * 2, dim) 12 | linear1.weight.data.normal_(mean=0.0, std=0.01) 13 | linear1.bias.data.zero_() 14 | linear2.weight.data.normal_(mean=0.0, std=0.01) 15 | linear2.bias.data.zero_() 16 | linears = [linear1, linear2] 17 | 18 | self.linear = torch.nn.Sequential(*linears) 19 | self.multiplier = multiplier 20 | 21 | def forward(self, x): 22 | return x + self.linear(x) * self.multiplier 23 | 24 | 25 | class Hypernetwork(torch.nn.Module): 26 | enable_sizes = [320, 640, 768, 1280] 27 | # return self.modules[Hypernetwork.enable_sizes.index(size)] 28 | 29 | def __init__(self, multiplier=1.0) -> None: 30 | super().__init__() 31 | self.modules = [] 32 | for size in Hypernetwork.enable_sizes: 33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) 34 | self.register_module(f"{size}_0", self.modules[-1][0]) 35 | self.register_module(f"{size}_1", self.modules[-1][1]) 36 | 37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet): 38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks 39 | for block in blocks: 40 | for subblk in block: 41 | if 'SpatialTransformer' in str(type(subblk)): 42 | for tf_block in subblk.transformer_blocks: 43 | for attn in [tf_block.attn1, tf_block.attn2]: 44 | size = attn.context_dim 45 | if size in Hypernetwork.enable_sizes: 46 | attn.hypernetwork = self 47 | else: 48 | attn.hypernetwork = None 49 | 50 | def apply_to_diffusers(self, text_encoder, vae, unet): 51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks 52 | for block in blocks: 53 | if hasattr(block, 'attentions'): 54 | for subblk in block.attentions: 55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ 56 | for tf_block in subblk.transformer_blocks: 57 | for attn in [tf_block.attn1, tf_block.attn2]: 58 | size = attn.to_k.in_features 59 | if size in Hypernetwork.enable_sizes: 60 | 
attn.hypernetwork = self 61 | else: 62 | attn.hypernetwork = None 63 | return True # TODO error checking 64 | 65 | def forward(self, x, context): 66 | size = context.shape[-1] 67 | assert size in Hypernetwork.enable_sizes 68 | module = self.modules[Hypernetwork.enable_sizes.index(size)] 69 | return module[0].forward(context), module[1].forward(context) 70 | 71 | def load_from_state_dict(self, state_dict): 72 | # old ver to new ver 73 | changes = { 74 | 'linear1.bias': 'linear.0.bias', 75 | 'linear1.weight': 'linear.0.weight', 76 | 'linear2.bias': 'linear.1.bias', 77 | 'linear2.weight': 'linear.1.weight', 78 | } 79 | for key_from, key_to in changes.items(): 80 | if key_from in state_dict: 81 | state_dict[key_to] = state_dict[key_from] 82 | del state_dict[key_from] 83 | 84 | for size, sd in state_dict.items(): 85 | if type(size) == int: 86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) 87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) 88 | return True 89 | 90 | def get_state_dict(self): 91 | state_dict = {} 92 | for i, size in enumerate(Hypernetwork.enable_sizes): 93 | sd0 = self.modules[i][0].state_dict() 94 | sd1 = self.modules[i][1].state_dict() 95 | state_dict[size] = [sd0, sd1] 96 | return state_dict 97 | -------------------------------------------------------------------------------- /tools/merge_block_weighted.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/eyriewow/merge-models 2 | 3 | import os 4 | import argparse 5 | import re 6 | import torch 7 | from tqdm import tqdm 8 | 9 | 10 | NUM_INPUT_BLOCKS = 12 11 | NUM_MID_BLOCK = 1 12 | NUM_OUTPUT_BLOCKS = 12 13 | NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS 14 | 15 | 16 | def merge(args): 17 | if args.weights is None: 18 | weights = None 19 | else: 20 | weights = [float(w) for w in args.weights.split(',')] 21 | if len(weights) != NUM_TOTAL_BLOCKS: 22 | print(f"weights value must be {NUM_TOTAL_BLOCKS}.") 23 | return 24 | 25 | device = args.device 26 | print("loading", args.model_0) 27 | model_0 = torch.load(args.model_0, map_location=device) 28 | print("loading", args.model_1) 29 | model_1 = torch.load(args.model_1, map_location=device) 30 | theta_0 = model_0["state_dict"] 31 | theta_1 = model_1["state_dict"] 32 | alpha = args.base_alpha 33 | 34 | output_file = f'{args.output}-{str(alpha)[2:] + "0"}-bw.ckpt' 35 | 36 | # check if output file already exists, ask to overwrite 37 | if os.path.isfile(output_file): 38 | print("Output file already exists. Overwrite? (y/n)") 39 | while True: 40 | overwrite = input() 41 | if overwrite == "y": 42 | break 43 | elif overwrite == "n": 44 | print("Exiting...") 45 | return 46 | else: 47 | print("Please enter y or n") 48 | 49 | re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 50 | re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 51 | re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 52 | 53 | for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not args.verbose else theta_0.keys()): 54 | if "model" in key and key in theta_1: 55 | current_alpha = alpha 56 | 57 | # check weighted and U-Net or not 58 | if weights is not None and 'model.diffusion_model.' in key: 59 | # check block index 60 | weight_index = -1 61 | 62 | if 'time_embed' in key: 63 | weight_index = 0 # before input blocks 64 | elif '.out.' 
in key: 65 | weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks 66 | else: 67 | m = re_inp.search(key) 68 | if m: 69 | inp_idx = int(m.groups()[0]) 70 | weight_index = inp_idx 71 | else: 72 | m = re_mid.search(key) 73 | if m: 74 | weight_index = NUM_INPUT_BLOCKS 75 | else: 76 | m = re_out.search(key) 77 | if m: 78 | out_idx = int(m.groups()[0]) 79 | weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx 80 | 81 | if weight_index >= NUM_TOTAL_BLOCKS: 82 | print(f"error. illegal block index: {key}") 83 | if weight_index >= 0: 84 | current_alpha = weights[weight_index] 85 | if args.verbose: 86 | print(f"weighted '{key}': {current_alpha}") 87 | 88 | theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] 89 | 90 | for key in tqdm(theta_1.keys(), desc="Stage 2/2"): 91 | if "model" in key and key not in theta_0: 92 | theta_0[key] = theta_1[key] 93 | 94 | print("Saving...") 95 | 96 | torch.save({"state_dict": theta_0}, output_file) 97 | 98 | print("Done!") 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser(description="Merge two models with weights for each block") 103 | parser.add_argument("model_0", type=str, help="Path to model 0") 104 | parser.add_argument("model_1", type=str, help="Path to model 1") 105 | parser.add_argument("--base_alpha", type=float, 106 | help="Alpha value (for model 0) except U-Net, optional, defaults to 0.5", default=0.5, required=False) 107 | parser.add_argument("--output", type=str, help="Output file name, without extension", default="merged", required=False) 108 | parser.add_argument("--device", type=str, help="Device to use, defaults to cpu", default="cpu", required=False) 109 | parser.add_argument("--weights", type=str, 110 | help=f"comma separated {NUM_TOTAL_BLOCKS} weights value (for model 0) for each U-Net block", default=None, required=False) 111 | parser.add_argument("--verbose", action='store_true', help="show each block weight", required=False) 112 | 113 | args = parser.parse_args() 114 | merge(args) 115 | -------------------------------------------------------------------------------- /networks/extract_lora_from_dylora.py: -------------------------------------------------------------------------------- 1 | # Convert LoRA to different rank approximation (should only be used to go to lower rank) 2 | # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo 4 | 5 | import argparse 6 | import math 7 | import os 8 | import torch 9 | from safetensors.torch import load_file, save_file, safe_open 10 | from tqdm import tqdm 11 | from library import train_util, model_util 12 | import numpy as np 13 | 14 | 15 | def load_state_dict(file_name): 16 | if model_util.is_safetensors(file_name): 17 | sd = load_file(file_name) 18 | with safe_open(file_name, framework="pt") as f: 19 | metadata = f.metadata() 20 | else: 21 | sd = torch.load(file_name, map_location="cpu") 22 | metadata = None 23 | 24 | return sd, metadata 25 | 26 | 27 | def save_to_file(file_name, model, metadata): 28 | if model_util.is_safetensors(file_name): 29 | save_file(model, file_name, metadata) 30 | else: 31 | torch.save(model, file_name) 32 | 33 | 34 | def split_lora_model(lora_sd, unit): 35 | max_rank = 0 36 | 37 | # Extract loaded lora dim and alpha 38 | for key, value in lora_sd.items(): 39 | if "lora_down" in key: 40 | rank = value.size()[0] 41 | if rank > max_rank: 42 | max_rank = rank 43 | print(f"Max rank: 
{max_rank}") 44 | 45 | rank = unit 46 | split_models = [] 47 | new_alpha = None 48 | while rank < max_rank: 49 | print(f"Splitting rank {rank}") 50 | new_sd = {} 51 | for key, value in lora_sd.items(): 52 | if "lora_down" in key: 53 | new_sd[key] = value[:rank].contiguous() 54 | elif "lora_up" in key: 55 | new_sd[key] = value[:, :rank].contiguous() 56 | else: 57 | # なぜかscaleするとおかしくなる…… 58 | # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] 59 | # scale = math.sqrt(this_rank / rank) # rank is > unit 60 | # print(key, value.size(), this_rank, rank, value, scale) 61 | # new_alpha = value * scale # always same 62 | # new_sd[key] = new_alpha 63 | new_sd[key] = value 64 | 65 | split_models.append((new_sd, rank, new_alpha)) 66 | rank += unit 67 | 68 | return max_rank, split_models 69 | 70 | 71 | def split(args): 72 | print("loading Model...") 73 | lora_sd, metadata = load_state_dict(args.model) 74 | 75 | print("Splitting Model...") 76 | original_rank, split_models = split_lora_model(lora_sd, args.unit) 77 | 78 | comment = metadata.get("ss_training_comment", "") 79 | for state_dict, new_rank, new_alpha in split_models: 80 | # update metadata 81 | if metadata is None: 82 | new_metadata = {} 83 | else: 84 | new_metadata = metadata.copy() 85 | 86 | new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" 87 | new_metadata["ss_network_dim"] = str(new_rank) 88 | # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) 89 | 90 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 91 | metadata["sshs_model_hash"] = model_hash 92 | metadata["sshs_legacy_hash"] = legacy_hash 93 | 94 | filename, ext = os.path.splitext(args.save_to) 95 | model_file_name = filename + f"-{new_rank:04d}{ext}" 96 | 97 | print(f"saving model to: {model_file_name}") 98 | save_to_file(model_file_name, state_dict, new_metadata) 99 | 100 | 101 | def setup_parser() -> argparse.ArgumentParser: 102 | parser = argparse.ArgumentParser() 103 | 104 | parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") 105 | parser.add_argument( 106 | "--save_to", 107 | type=str, 108 | default=None, 109 | help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", 110 | ) 111 | parser.add_argument( 112 | "--model", 113 | type=str, 114 | default=None, 115 | help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", 116 | ) 117 | 118 | return parser 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = setup_parser() 123 | 124 | args = parser.parse_args() 125 | split(args) 126 | -------------------------------------------------------------------------------- /tools/merge_vae.py: -------------------------------------------------------------------------------- 1 | # License of this file is ASL 2.0 2 | 3 | import argparse 4 | import torch 5 | 6 | 7 | VAE_PREFIX = "first_stage_model." 8 | 9 | # copy from convert_diffusers_to_original_stable_diffusion.py ASL 2.0 10 | 11 | # ================# 12 | # VAE Conversion # 13 | # ================# 14 | 15 | vae_conversion_map = [ 16 | # (stable-diffusion, HF Diffusers) 17 | ("nin_shortcut", "conv_shortcut"), 18 | ("norm_out", "conv_norm_out"), 19 | ("mid.attn_1.", "mid_block.attentions.0."), 20 | ] 21 | 22 | for i in range(4): 23 | # down_blocks have two resnets 24 | for j in range(2): 25 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 
26 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 27 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 28 | 29 | if i < 3: 30 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 31 | sd_downsample_prefix = f"down.{i}.downsample." 32 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 33 | 34 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 35 | sd_upsample_prefix = f"up.{3-i}.upsample." 36 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 37 | 38 | # up_blocks have three resnets 39 | # also, up blocks in hf are numbered in reverse from sd 40 | for j in range(3): 41 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 42 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 43 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 44 | 45 | # this part accounts for mid blocks in both the encoder and the decoder 46 | for i in range(2): 47 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 48 | sd_mid_res_prefix = f"mid.block_{i+1}." 49 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 50 | 51 | 52 | vae_conversion_map_attn = [ 53 | # (stable-diffusion, HF Diffusers) 54 | ("norm.", "group_norm."), 55 | ("q.", "query."), 56 | ("k.", "key."), 57 | ("v.", "value."), 58 | ("proj_out.", "proj_attn."), 59 | ] 60 | 61 | 62 | def reshape_weight_for_sd(w): 63 | # convert HF linear weights to SD conv2d weights 64 | return w.reshape(*w.shape, 1, 1) 65 | 66 | 67 | def convert_vae_state_dict(vae_state_dict): 68 | mapping = {k: k for k in vae_state_dict.keys()} 69 | for k, v in mapping.items(): 70 | for sd_part, hf_part in vae_conversion_map: 71 | v = v.replace(hf_part, sd_part) 72 | mapping[k] = v 73 | for k, v in mapping.items(): 74 | if "attentions" in k: 75 | for sd_part, hf_part in vae_conversion_map_attn: 76 | v = v.replace(hf_part, sd_part) 77 | mapping[k] = v 78 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 79 | weights_to_convert = ["q", "k", "v", "proj_out"] 80 | for k, v in new_state_dict.items(): 81 | for weight_name in weights_to_convert: 82 | if f"mid.attn_1.{weight_name}.weight" in k: 83 | print(f"Reshaping {k} for SD format") 84 | new_state_dict[k] = reshape_weight_for_sd(v) 85 | return new_state_dict 86 | 87 | 88 | def convert_diffusers_vae(vae_path): 89 | vae_state_dict = torch.load(vae_path, map_location="cpu") 90 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 91 | return vae_state_dict 92 | 93 | 94 | def merge_vae(ckpt, vae, output): 95 | print(f"load checkpoint: {ckpt}") 96 | model = torch.load(ckpt, map_location="cpu") 97 | sd = model['state_dict'] 98 | 99 | full_model = False 100 | 101 | print(f"load VAE: {vae}") 102 | if vae.endswith(".bin"): 103 | print("convert diffusers VAE to stablediffusion") 104 | vae_sd = convert_diffusers_vae(vae) 105 | else: 106 | vae_model = torch.load(vae, map_location="cpu") 107 | vae_sd = vae_model['state_dict'] 108 | 109 | # vae only or full model 110 | for vae_key in vae_sd: 111 | if vae_key.startswith(VAE_PREFIX): 112 | full_model = True 113 | break 114 | 115 | count = 0 116 | for vae_key in vae_sd: 117 | sd_key = vae_key 118 | if full_model: 119 | if not sd_key.startswith(VAE_PREFIX): 120 | continue 121 | else: 122 | if sd_key not in sd: 123 | sd_key = VAE_PREFIX + sd_key 124 | if sd_key not in sd: 125 | print(f"key not exists in model: {vae_key}") 126 | continue 127 | sd[sd_key] = vae_sd[vae_key] 128 | count += 1 129 | print(f"{count} weights are copied") 130 | 131 | print(f"saving checkpoint to: {output}") 132 | torch.save(model, 
output) 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument("ckpt", type=str, help="target checkpoint to replace VAE / マージ対象のモデルcheckpoint") 138 | parser.add_argument("vae", type=str, help="VAE/model checkpoint to merge / マージするVAEまたはモデルのcheckpoint") 139 | parser.add_argument("output", type=str, help="output checkoint / 出力先checkpoint") 140 | args = parser.parse_args() 141 | 142 | merge_vae(args.ckpt, args.vae, args.output) 143 | -------------------------------------------------------------------------------- /sdxl_train_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import regex 5 | import torch 6 | import open_clip 7 | from library import sdxl_model_util, sdxl_train_util, train_util 8 | 9 | import train_textual_inversion 10 | 11 | 12 | class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): 13 | def __init__(self): 14 | super().__init__() 15 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 16 | 17 | def assert_extra_args(self, args, train_dataset_group): 18 | super().assert_extra_args(args, train_dataset_group) 19 | sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) 20 | 21 | def load_target_model(self, args, weight_dtype, accelerator): 22 | ( 23 | load_stable_diffusion_format, 24 | text_encoder1, 25 | text_encoder2, 26 | vae, 27 | unet, 28 | logit_scale, 29 | ckpt_info, 30 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) 31 | 32 | self.load_stable_diffusion_format = load_stable_diffusion_format 33 | self.logit_scale = logit_scale 34 | self.ckpt_info = ckpt_info 35 | 36 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet 37 | 38 | def load_tokenizer(self, args): 39 | tokenizer = sdxl_train_util.load_tokenizers(args) 40 | return tokenizer 41 | 42 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 43 | input_ids1 = batch["input_ids"] 44 | input_ids2 = batch["input_ids2"] 45 | with torch.enable_grad(): 46 | input_ids1 = input_ids1.to(accelerator.device) 47 | input_ids2 = input_ids2.to(accelerator.device) 48 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 49 | args.max_token_length, 50 | input_ids1, 51 | input_ids2, 52 | tokenizers[0], 53 | tokenizers[1], 54 | text_encoders[0], 55 | text_encoders[1], 56 | None if not args.full_fp16 else weight_dtype, 57 | ) 58 | return encoder_hidden_states1, encoder_hidden_states2, pool2 59 | 60 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 61 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 62 | 63 | # get size embeddings 64 | orig_size = batch["original_sizes_hw"] 65 | crop_size = batch["crop_top_lefts"] 66 | target_size = batch["target_sizes_hw"] 67 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 68 | 69 | # concat embeddings 70 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds 71 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 72 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 73 | 74 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 
75 | return noise_pred 76 | 77 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): 78 | sdxl_train_util.sample_images( 79 | accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement 80 | ) 81 | 82 | def save_weights(self, file, updated_embs, save_dtype): 83 | state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} 84 | 85 | if save_dtype is not None: 86 | for key in list(state_dict.keys()): 87 | v = state_dict[key] 88 | v = v.detach().clone().to("cpu").to(save_dtype) 89 | state_dict[key] = v 90 | 91 | if os.path.splitext(file)[1] == ".safetensors": 92 | from safetensors.torch import save_file 93 | 94 | save_file(state_dict, file) 95 | else: 96 | torch.save(state_dict, file) 97 | 98 | def load_weights(self, file): 99 | if os.path.splitext(file)[1] == ".safetensors": 100 | from safetensors.torch import load_file 101 | 102 | data = load_file(file) 103 | else: 104 | data = torch.load(file, map_location="cpu") 105 | 106 | emb_l = data.get("clip_l", None) # ViT-L text encoder 1 107 | emb_g = data.get("clip_g", None) # BiG-G text encoder 2 108 | 109 | assert ( 110 | emb_l is not None or emb_g is not None 111 | ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" 112 | 113 | return [emb_l, emb_g] 114 | 115 | 116 | def setup_parser() -> argparse.ArgumentParser: 117 | parser = train_textual_inversion.setup_parser() 118 | # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching 119 | # sdxl_train_util.add_sdxl_training_arguments(parser) 120 | return parser 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = setup_parser() 125 | 126 | args = parser.parse_args() 127 | args = train_util.read_config_from_file(args, parser) 128 | 129 | trainer = SdxlTextualInversionTrainer() 130 | trainer.train(args) 131 | -------------------------------------------------------------------------------- /networks/lora_interrogator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from tqdm import tqdm 4 | from library import model_util 5 | import library.train_util as train_util 6 | import argparse 7 | from transformers import CLIPTokenizer 8 | import torch 9 | 10 | import library.model_util as model_util 11 | import lora 12 | 13 | TOKENIZER_PATH = "openai/clip-vit-large-patch14" 14 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う 15 | 16 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def interrogate(args): 20 | weights_dtype = torch.float16 21 | 22 | # いろいろ準備する 23 | print(f"loading SD model: {args.sd_model}") 24 | args.pretrained_model_name_or_path = args.sd_model 25 | args.vae = None 26 | text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) 27 | 28 | print(f"loading LoRA: {args.model}") 29 | network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) 30 | 31 | # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい 32 | has_te_weight = False 33 | for key in weights_sd.keys(): 34 | if 'lora_te' in key: 35 | has_te_weight = True 36 | break 37 | if not has_te_weight: 38 | print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") 39 | return 40 | del vae 41 | 42 | print("loading tokenizer") 43 | if args.v2: 44 | tokenizer: CLIPTokenizer = 
CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") 45 | else: 46 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) 47 | 48 | text_encoder.to(DEVICE, dtype=weights_dtype) 49 | text_encoder.eval() 50 | unet.to(DEVICE, dtype=weights_dtype) 51 | unet.eval() # U-Netは呼び出さないので不要だけど 52 | 53 | # トークンをひとつひとつ当たっていく 54 | token_id_start = 0 55 | token_id_end = max(tokenizer.all_special_ids) 56 | print(f"interrogate tokens are: {token_id_start} to {token_id_end}") 57 | 58 | def get_all_embeddings(text_encoder): 59 | embs = [] 60 | with torch.no_grad(): 61 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): 62 | batch = [] 63 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): 64 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] 65 | # tokens = [tid] # こちらは結果がいまひとつ 66 | batch.append(tokens) 67 | 68 | # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] 69 | # clip skip対応 70 | batch = torch.tensor(batch).to(DEVICE) 71 | if args.clip_skip is None: 72 | encoder_hidden_states = text_encoder(batch)[0] 73 | else: 74 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) 75 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] 76 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) 77 | encoder_hidden_states = encoder_hidden_states.to("cpu") 78 | 79 | embs.extend(encoder_hidden_states) 80 | return torch.stack(embs) 81 | 82 | print("get original text encoder embeddings.") 83 | orig_embs = get_all_embeddings(text_encoder) 84 | 85 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) 86 | info = network.load_state_dict(weights_sd, strict=False) 87 | print(f"Loading LoRA weights: {info}") 88 | 89 | network.to(DEVICE, dtype=weights_dtype) 90 | network.eval() 91 | 92 | del unet 93 | 94 | print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") 95 | print("get text encoder embeddings with lora.") 96 | lora_embs = get_all_embeddings(text_encoder) 97 | 98 | # 比べる:とりあえず単純に差分の絶対値で 99 | print("comparing...") 100 | diffs = {} 101 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): 102 | diff = torch.mean(torch.abs(orig_emb - lora_emb)) 103 | # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない 104 | diff = float(diff.detach().to('cpu').numpy()) 105 | diffs[token_id_start + i] = diff 106 | 107 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) 108 | 109 | # 結果を表示する 110 | print("top 100:") 111 | for i, (token, diff) in enumerate(diffs_sorted[:100]): 112 | # if diff < 1e-6: 113 | # break 114 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) 115 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") 116 | 117 | 118 | def setup_parser() -> argparse.ArgumentParser: 119 | parser = argparse.ArgumentParser() 120 | 121 | parser.add_argument("--v2", action='store_true', 122 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 123 | parser.add_argument("--sd_model", type=str, default=None, 124 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") 125 | parser.add_argument("--model", type=str, 
default=None, 126 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") 127 | parser.add_argument("--batch_size", type=int, default=16, 128 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ") 129 | parser.add_argument("--clip_skip", type=int, default=None, 130 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") 131 | 132 | return parser 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = setup_parser() 137 | 138 | args = parser.parse_args() 139 | interrogate(args) 140 | -------------------------------------------------------------------------------- /tools/resize_images_to_resolution.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import argparse 5 | import shutil 6 | import math 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): 12 | # Split the max_resolution string by "," and strip any whitespaces 13 | max_resolutions = [res.strip() for res in max_resolution.split(',')] 14 | 15 | # # Calculate max_pixels from max_resolution string 16 | # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 17 | 18 | # Create destination folder if it does not exist 19 | if not os.path.exists(dst_img_folder): 20 | os.makedirs(dst_img_folder) 21 | 22 | # Select interpolation method 23 | if interpolation == 'lanczos4': 24 | cv2_interpolation = cv2.INTER_LANCZOS4 25 | elif interpolation == 'cubic': 26 | cv2_interpolation = cv2.INTER_CUBIC 27 | else: 28 | cv2_interpolation = cv2.INTER_AREA 29 | 30 | # Iterate through all files in src_img_folder 31 | img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py 32 | for filename in os.listdir(src_img_folder): 33 | # Check if the image is png, jpg or webp etc... 34 | if not filename.endswith(img_exts): 35 | # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) 
36 | shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) 37 | continue 38 | 39 | # Load image 40 | # img = cv2.imread(os.path.join(src_img_folder, filename)) 41 | image = Image.open(os.path.join(src_img_folder, filename)) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | img = np.array(image, np.uint8) 45 | 46 | base, _ = os.path.splitext(filename) 47 | for max_resolution in max_resolutions: 48 | # Calculate max_pixels from max_resolution string 49 | max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 50 | 51 | # Calculate current number of pixels 52 | current_pixels = img.shape[0] * img.shape[1] 53 | 54 | # Check if the image needs resizing 55 | if current_pixels > max_pixels: 56 | # Calculate scaling factor 57 | scale_factor = max_pixels / current_pixels 58 | 59 | # Calculate new dimensions 60 | new_height = int(img.shape[0] * math.sqrt(scale_factor)) 61 | new_width = int(img.shape[1] * math.sqrt(scale_factor)) 62 | 63 | # Resize image 64 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) 65 | else: 66 | new_height, new_width = img.shape[0:2] 67 | 68 | # Calculate the new height and width that are divisible by divisible_by (with/without resizing) 69 | new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by 70 | new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by 71 | 72 | # Center crop the image to the calculated dimensions 73 | y = int((img.shape[0] - new_height) / 2) 74 | x = int((img.shape[1] - new_width) / 2) 75 | img = img[y:y + new_height, x:x + new_width] 76 | 77 | # Split filename into base and extension 78 | new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') 79 | 80 | # Save resized image in dst_img_folder 81 | # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) 82 | image = Image.fromarray(img) 83 | image.save(os.path.join(dst_img_folder, new_filename), quality=100) 84 | 85 | proc = "Resized" if current_pixels > max_pixels else "Saved" 86 | print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") 87 | 88 | # If other files with same basename, copy them with resolution suffix 89 | if copy_associated_files: 90 | asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) 91 | for asoc_file in asoc_files: 92 | ext = os.path.splitext(asoc_file)[1] 93 | if ext in img_exts: 94 | continue 95 | for max_resolution in max_resolutions: 96 | new_asoc_file = base + '+' + max_resolution + ext 97 | print(f"Copy {asoc_file} as {new_asoc_file}") 98 | shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) 99 | 100 | 101 | def setup_parser() -> argparse.ArgumentParser: 102 | parser = argparse.ArgumentParser( 103 | description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') 104 | parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') 105 | parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') 106 | parser.add_argument('--max_resolution', type=str, 107 | help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") 108 | 
parser.add_argument('--divisible_by', type=int, 109 | help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) 110 | parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], 111 | default='area', help='Interpolation method for resizing / リサイズ時の補完方法') 112 | parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') 113 | parser.add_argument('--copy_associated_files', action='store_true', 114 | help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') 115 | 116 | return parser 117 | 118 | 119 | def main(): 120 | parser = setup_parser() 121 | 122 | args = parser.parse_args() 123 | resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, 124 | args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /tools/convert_diffusers20_original_sd.py: -------------------------------------------------------------------------------- 1 | # convert Diffusers v1.x/v2.0 model to original Stable Diffusion 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from diffusers import StableDiffusionPipeline 7 | 8 | import library.model_util as model_util 9 | 10 | 11 | def convert(args): 12 | # 引数を確認する 13 | load_dtype = torch.float16 if args.fp16 else None 14 | 15 | save_dtype = None 16 | if args.fp16 or args.save_precision_as == "fp16": 17 | save_dtype = torch.float16 18 | elif args.bf16 or args.save_precision_as == "bf16": 19 | save_dtype = torch.bfloat16 20 | elif args.float or args.save_precision_as == "float": 21 | save_dtype = torch.float 22 | 23 | is_load_ckpt = os.path.isfile(args.model_to_load) 24 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 25 | 26 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" 27 | # assert ( 28 | # is_save_ckpt or args.reference_model is not None 29 | # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" 30 | 31 | # モデルを読み込む 32 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 33 | print(f"loading {msg}: {args.model_to_load}") 34 | 35 | if is_load_ckpt: 36 | v2_model = args.v2 37 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection) 38 | else: 39 | pipe = StableDiffusionPipeline.from_pretrained( 40 | args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None 41 | ) 42 | text_encoder = pipe.text_encoder 43 | vae = pipe.vae 44 | unet = pipe.unet 45 | 46 | if args.v1 == args.v2: 47 | # 自動判定する 48 | v2_model = unet.config.cross_attention_dim == 1024 49 | print("checking model version: model is " + ("v2" if v2_model else "v1")) 50 | else: 51 | v2_model = not args.v1 52 | 53 | # 変換して保存する 54 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 55 | print(f"converting and saving as {msg}: {args.model_to_save}") 56 | 57 | if is_save_ckpt: 58 | original_model = args.model_to_load if is_load_ckpt else None 59 | key_count = model_util.save_stable_diffusion_checkpoint( 60 | v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae 61 | ) 
62 | print(f"model saved. total converted state_dict keys: {key_count}") 63 | else: 64 | print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}") 65 | model_util.save_diffusers_checkpoint( 66 | v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors 67 | ) 68 | print(f"model saved.") 69 | 70 | 71 | def setup_parser() -> argparse.ArgumentParser: 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument( 74 | "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" 75 | ) 76 | parser.add_argument( 77 | "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" 78 | ) 79 | parser.add_argument( 80 | "--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)" 81 | ) 82 | parser.add_argument( 83 | "--fp16", 84 | action="store_true", 85 | help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", 86 | ) 87 | parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") 88 | parser.add_argument( 89 | "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" 90 | ) 91 | parser.add_argument( 92 | "--save_precision_as", 93 | type=str, 94 | default="no", 95 | choices=["fp16", "bf16", "float"], 96 | help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", 97 | ) 98 | parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") 99 | parser.add_argument( 100 | "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" 101 | ) 102 | parser.add_argument( 103 | "--reference_model", 104 | type=str, 105 | default=None, 106 | help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`", 107 | ) 108 | parser.add_argument( 109 | "--use_safetensors", 110 | action="store_true", 111 | help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", 112 | ) 113 | 114 | parser.add_argument( 115 | "model_to_load", 116 | type=str, 117 | default=None, 118 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", 119 | ) 120 | parser.add_argument( 121 | "model_to_save", 122 | type=str, 123 | default=None, 124 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", 125 | ) 126 | return parser 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = setup_parser() 131 | 132 | args = parser.parse_args() 133 | convert(args) 134 | -------------------------------------------------------------------------------- /finetune/clean_captions_and_tags.py: 
-------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | import re 9 | 10 | from tqdm import tqdm 11 | 12 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 13 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 14 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 15 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 16 | 17 | # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する 18 | PATTERNS_REMOVE_IN_MULTI = [ 19 | PATTERN_HAIR_LENGTH, 20 | PATTERN_HAIR_CUT, 21 | re.compile(r', [\w\-]+ eyes, '), 22 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 23 | # 複数の髪型定義がある場合は削除する 24 | re.compile( 25 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 26 | ] 27 | 28 | 29 | def clean_tags(image_key, tags): 30 | # replace '_' to ' ' 31 | tags = tags.replace('^_^', '^@@@^') 32 | tags = tags.replace('_', ' ') 33 | tags = tags.replace('^@@@^', '^_^') 34 | 35 | # remove rating: deepdanbooruのみ 36 | tokens = tags.split(", rating") 37 | if len(tokens) == 1: 38 | # WD14 taggerのときはこちらになるのでメッセージは出さない 39 | # print("no rating:") 40 | # print(f"{image_key} {tags}") 41 | pass 42 | else: 43 | if len(tokens) > 2: 44 | print("multiple ratings:") 45 | print(f"{image_key} {tags}") 46 | tags = tokens[0] 47 | 48 | tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 49 | 50 | # 複数の人物がいる場合は髪色等のタグを削除する 51 | if 'girls' in tags or 'boys' in tags: 52 | for pat in PATTERNS_REMOVE_IN_MULTI: 53 | found = pat.findall(tags) 54 | if len(found) > 1: # 二つ以上、タグがある 55 | tags = pat.sub("", tags) 56 | 57 | # 髪の特殊対応 58 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) 59 | if srch_hair_len: 60 | org = srch_hair_len.group() 61 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 62 | 63 | found = PATTERN_HAIR.findall(tags) 64 | if len(found) > 1: 65 | tags = PATTERN_HAIR.sub("", tags) 66 | 67 | if srch_hair_len: 68 | tags = tags.replace(", @@@, ", org) # 戻す 69 | 70 | # white shirtとshirtみたいな重複タグの削除 71 | found = PATTERN_WORD.findall(tags) 72 | for word in found: 73 | if re.search(f", ((\w+) )+{word}, ", tags): 74 | tags = tags.replace(f", {word}, ", "") 75 | 76 | tags = tags.replace(", , ", ", ") 77 | assert tags.startswith(", ") and tags.endswith(", ") 78 | tags = tags[2:-2] 79 | return tags 80 | 81 | 82 | # 上から順に検索、置換される 83 | # ('置換元文字列', '置換後文字列') 84 | CAPTION_REPLACEMENTS = [ 85 | ('anime anime', 'anime'), 86 | ('young ', ''), 87 | ('anime girl', 'girl'), 88 | ('cartoon female', 'girl'), 89 | ('cartoon lady', 'girl'), 90 | ('cartoon character', 'girl'), # a or ~s 91 | ('cartoon woman', 'girl'), 92 | ('cartoon women', 'girls'), 93 | ('cartoon girl', 'girl'), 94 | ('anime female', 'girl'), 95 | ('anime lady', 'girl'), 96 | ('anime character', 'girl'), # a or ~s 97 | ('anime woman', 'girl'), 98 | ('anime women', 'girls'), 99 | ('lady', 'girl'), 100 | ('female', 'girl'), 101 | ('woman', 'girl'), 102 | ('women', 'girls'), 103 | ('people', 'girls'), 104 | ('person', 'girl'), 105 | ('a cartoon figure', 'a figure'), 106 | ('a cartoon image', 'an image'), 107 | ('a cartoon picture', 'a picture'), 108 | ('an anime cartoon image', 'an image'), 109 | ('a cartoon anime drawing', 'a drawing'), 110 | ('a cartoon drawing', 'a drawing'), 111 | ('girl girl', 'girl'), 112 | ] 113 | 114 | 115 | def clean_caption(caption): 116 | for rf, rt 
in CAPTION_REPLACEMENTS: 117 | replaced = True 118 | while replaced: 119 | bef = caption 120 | caption = caption.replace(rf, rt) 121 | replaced = bef != caption 122 | return caption 123 | 124 | 125 | def main(args): 126 | if os.path.exists(args.in_json): 127 | print(f"loading existing metadata: {args.in_json}") 128 | with open(args.in_json, "rt", encoding='utf-8') as f: 129 | metadata = json.load(f) 130 | else: 131 | print("no metadata / メタデータファイルがありません") 132 | return 133 | 134 | print("cleaning captions and tags.") 135 | image_keys = list(metadata.keys()) 136 | for image_key in tqdm(image_keys): 137 | tags = metadata[image_key].get('tags') 138 | if tags is None: 139 | print(f"image does not have tags / メタデータにタグがありません: {image_key}") 140 | else: 141 | org = tags 142 | tags = clean_tags(image_key, tags) 143 | metadata[image_key]['tags'] = tags 144 | if args.debug and org != tags: 145 | print("FROM: " + org) 146 | print("TO: " + tags) 147 | 148 | caption = metadata[image_key].get('caption') 149 | if caption is None: 150 | print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") 151 | else: 152 | org = caption 153 | caption = clean_caption(caption) 154 | metadata[image_key]['caption'] = caption 155 | if args.debug and org != caption: 156 | print("FROM: " + org) 157 | print("TO: " + caption) 158 | 159 | # metadataを書き出して終わり 160 | print(f"writing metadata: {args.out_json}") 161 | with open(args.out_json, "wt", encoding='utf-8') as f: 162 | json.dump(metadata, f, indent=2) 163 | print("done!") 164 | 165 | 166 | def setup_parser() -> argparse.ArgumentParser: 167 | parser = argparse.ArgumentParser() 168 | # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 169 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 170 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 171 | parser.add_argument("--debug", action="store_true", help="debug mode") 172 | 173 | return parser 174 | 175 | 176 | if __name__ == '__main__': 177 | parser = setup_parser() 178 | 179 | args, unknown = parser.parse_known_args() 180 | if len(unknown) == 1: 181 | print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. 
Please specify two arguments: in_json and out_json.") 182 | print("All captions and tags in the metadata are processed.") 183 | print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") 184 | print("メタデータ内のすべてのキャプションとタグが処理されます。") 185 | args.in_json = args.out_json 186 | args.out_json = unknown[0] 187 | elif len(unknown) > 0: 188 | raise ValueError(f"error: unrecognized arguments: {unknown}") 189 | 190 | main(args) 191 | -------------------------------------------------------------------------------- /finetune/make_captions_by_git.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from pathlib import Path 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torch 9 | from transformers import AutoProcessor, AutoModelForCausalLM 10 | from transformers.generation.utils import GenerationMixin 11 | 12 | import library.train_util as train_util 13 | 14 | 15 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | PATTERN_REPLACE = [ 18 | re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), 19 | re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), 20 | re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), 21 | re.compile(r"with the number \d+ on (it|\w+ \w+)"), 22 | re.compile(r'with the words "'), 23 | re.compile(r"word \w+ on it"), 24 | re.compile(r"that says the word \w+ on it"), 25 | re.compile("that says'the word \"( on it)?"), 26 | ] 27 | 28 | # 誤検知しまくりの with the word xxxx を消す 29 | 30 | 31 | def remove_words(captions, debug): 32 | removed_caps = [] 33 | for caption in captions: 34 | cap = caption 35 | for pat in PATTERN_REPLACE: 36 | cap = pat.sub("", cap) 37 | if debug and cap != caption: 38 | print(caption) 39 | print(cap) 40 | removed_caps.append(cap) 41 | return removed_caps 42 | 43 | 44 | def collate_fn_remove_corrupted(batch): 45 | """Collate function that allows to remove corrupted examples in the 46 | dataloader. It expects that the dataloader returns 'None' when that occurs. 47 | The 'None's in the batch are removed. 
48 | """ 49 | # Filter out all the Nones (corrupted examples) 50 | batch = list(filter(lambda x: x is not None, batch)) 51 | return batch 52 | 53 | 54 | def main(args): 55 | # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 56 | org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation 57 | curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように 58 | 59 | # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す 60 | # ここより上で置き換えようとするとすごく大変 61 | def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): 62 | input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) 63 | if input_ids.size()[0] != curr_batch_size[0]: 64 | input_ids = input_ids.repeat(curr_batch_size[0], 1) 65 | return input_ids 66 | 67 | GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch 68 | 69 | print(f"load images from {args.train_data_dir}") 70 | train_data_dir_path = Path(args.train_data_dir) 71 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 72 | print(f"found {len(image_paths)} images.") 73 | 74 | # できればcacheに依存せず明示的にダウンロードしたい 75 | print(f"loading GIT: {args.model_id}") 76 | git_processor = AutoProcessor.from_pretrained(args.model_id) 77 | git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) 78 | print("GIT loaded") 79 | 80 | # captioningする 81 | def run_batch(path_imgs): 82 | imgs = [im for _, im in path_imgs] 83 | 84 | curr_batch_size[0] = len(path_imgs) 85 | inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 86 | generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) 87 | captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) 88 | 89 | if args.remove_words: 90 | captions = remove_words(captions, args.debug) 91 | 92 | for (image_path, _), caption in zip(path_imgs, captions): 93 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 94 | f.write(caption + "\n") 95 | if args.debug: 96 | print(image_path, caption) 97 | 98 | # 読み込みの高速化のためにDataLoaderを使うオプション 99 | if args.max_data_loader_n_workers is not None: 100 | dataset = train_util.ImageLoadingDataset(image_paths) 101 | data = torch.utils.data.DataLoader( 102 | dataset, 103 | batch_size=args.batch_size, 104 | shuffle=False, 105 | num_workers=args.max_data_loader_n_workers, 106 | collate_fn=collate_fn_remove_corrupted, 107 | drop_last=False, 108 | ) 109 | else: 110 | data = [[(None, ip)] for ip in image_paths] 111 | 112 | b_imgs = [] 113 | for data_entry in tqdm(data, smoothing=0.0): 114 | for data in data_entry: 115 | if data is None: 116 | continue 117 | 118 | image, image_path = data 119 | if image is None: 120 | try: 121 | image = Image.open(image_path) 122 | if image.mode != "RGB": 123 | image = image.convert("RGB") 124 | except Exception as e: 125 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 126 | continue 127 | 128 | b_imgs.append((image_path, image)) 129 | if len(b_imgs) >= args.batch_size: 130 | run_batch(b_imgs) 131 | b_imgs.clear() 132 | 133 | if len(b_imgs) > 0: 134 | run_batch(b_imgs) 135 | 136 | print("done!") 137 | 138 | 139 | def setup_parser() -> argparse.ArgumentParser: 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 142 | parser.add_argument("--caption_extension", type=str, 
default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 143 | parser.add_argument( 144 | "--model_id", 145 | type=str, 146 | default="microsoft/git-large-textcaps", 147 | help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", 148 | ) 149 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 150 | parser.add_argument( 151 | "--max_data_loader_n_workers", 152 | type=int, 153 | default=None, 154 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 155 | ) 156 | parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") 157 | parser.add_argument( 158 | "--remove_words", 159 | action="store_true", 160 | help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", 161 | ) 162 | parser.add_argument("--debug", action="store_true", help="debug mode") 163 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 164 | 165 | return parser 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = setup_parser() 170 | 171 | args = parser.parse_args() 172 | main(args) 173 | -------------------------------------------------------------------------------- /networks/merge_lora_old.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | import library.model_util as model_util 8 | import lora 9 | 10 | 11 | def load_state_dict(file_name, dtype): 12 | if os.path.splitext(file_name)[1] == '.safetensors': 13 | sd = load_file(file_name) 14 | else: 15 | sd = torch.load(file_name, map_location='cpu') 16 | for key in list(sd.keys()): 17 | if type(sd[key]) == torch.Tensor: 18 | sd[key] = sd[key].to(dtype) 19 | return sd 20 | 21 | 22 | def save_to_file(file_name, model, state_dict, dtype): 23 | if dtype is not None: 24 | for key in list(state_dict.keys()): 25 | if type(state_dict[key]) == torch.Tensor: 26 | state_dict[key] = state_dict[key].to(dtype) 27 | 28 | if os.path.splitext(file_name)[1] == '.safetensors': 29 | save_file(model, file_name) 30 | else: 31 | torch.save(model, file_name) 32 | 33 | 34 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 35 | text_encoder.to(merge_dtype) 36 | unet.to(merge_dtype) 37 | 38 | # create module map 39 | name_to_module = {} 40 | for i, root_module in enumerate([text_encoder, unet]): 41 | if i == 0: 42 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 43 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 44 | else: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 46 | target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE 47 | 48 | for name, module in root_module.named_modules(): 49 | if module.__class__.__name__ in target_replace_modules: 50 | for child_name, child_module in module.named_modules(): 51 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): 52 | lora_name = prefix + '.' + name + '.' 
+ child_name 53 | lora_name = lora_name.replace('.', '_') 54 | name_to_module[lora_name] = child_module 55 | 56 | for model, ratio in zip(models, ratios): 57 | print(f"loading: {model}") 58 | lora_sd = load_state_dict(model, merge_dtype) 59 | 60 | print(f"merging...") 61 | for key in lora_sd.keys(): 62 | if "lora_down" in key: 63 | up_key = key.replace("lora_down", "lora_up") 64 | alpha_key = key[:key.index("lora_down")] + 'alpha' 65 | 66 | # find original module for this lora 67 | module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" 68 | if module_name not in name_to_module: 69 | print(f"no module found for LoRA weight: {key}") 70 | continue 71 | module = name_to_module[module_name] 72 | # print(f"apply {key} to {module}") 73 | 74 | down_weight = lora_sd[key] 75 | up_weight = lora_sd[up_key] 76 | 77 | dim = down_weight.size()[0] 78 | alpha = lora_sd.get(alpha_key, dim) 79 | scale = alpha / dim 80 | 81 | # W <- W + U * D 82 | weight = module.weight 83 | if len(weight.size()) == 2: 84 | # linear 85 | weight = weight + ratio * (up_weight @ down_weight) * scale 86 | else: 87 | # conv2d 88 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale 89 | 90 | module.weight = torch.nn.Parameter(weight) 91 | 92 | 93 | def merge_lora_models(models, ratios, merge_dtype): 94 | merged_sd = {} 95 | 96 | alpha = None 97 | dim = None 98 | for model, ratio in zip(models, ratios): 99 | print(f"loading: {model}") 100 | lora_sd = load_state_dict(model, merge_dtype) 101 | 102 | print(f"merging...") 103 | for key in lora_sd.keys(): 104 | if 'alpha' in key: 105 | if key in merged_sd: 106 | assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" 107 | else: 108 | alpha = lora_sd[key].detach().numpy() 109 | merged_sd[key] = lora_sd[key] 110 | else: 111 | if key in merged_sd: 112 | assert merged_sd[key].size() == lora_sd[key].size( 113 | ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 114 | merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio 115 | else: 116 | if "lora_down" in key: 117 | dim = lora_sd[key].size()[0] 118 | merged_sd[key] = lora_sd[key] * ratio 119 | 120 | print(f"dim (rank): {dim}, alpha: {alpha}") 121 | if alpha is None: 122 | alpha = dim 123 | 124 | return merged_sd, dim, alpha 125 | 126 | 127 | def merge(args): 128 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 129 | 130 | def str_to_dtype(p): 131 | if p == 'float': 132 | return torch.float 133 | if p == 'fp16': 134 | return torch.float16 135 | if p == 'bf16': 136 | return torch.bfloat16 137 | return None 138 | 139 | merge_dtype = str_to_dtype(args.precision) 140 | save_dtype = str_to_dtype(args.save_precision) 141 | if save_dtype is None: 142 | save_dtype = merge_dtype 143 | 144 | if args.sd_model is not None: 145 | print(f"loading SD model: {args.sd_model}") 146 | 147 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 148 | 149 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 150 | 151 | print(f"\nsaving SD model to: {args.save_to}") 152 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, 153 | args.sd_model, 0, 0, save_dtype, vae) 154 | else: 155 | state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) 156 | 157 | print(f"\nsaving model to: {args.save_to}") 158 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 159 | 160 | 161 | def setup_parser() -> argparse.ArgumentParser: 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--v2", action='store_true', 164 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 165 | parser.add_argument("--save_precision", type=str, default=None, 166 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 167 | parser.add_argument("--precision", type=str, default="float", 168 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 169 | parser.add_argument("--sd_model", type=str, default=None, 170 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") 171 | parser.add_argument("--save_to", type=str, default=None, 172 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 173 | parser.add_argument("--models", type=str, nargs='*', 174 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 175 | parser.add_argument("--ratios", type=float, nargs='*', 176 | help="ratios for each model / それぞれのLoRAモデルの比率") 177 | 178 | return parser 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = setup_parser() 183 | 184 | args = parser.parse_args() 185 | merge(args) 186 | -------------------------------------------------------------------------------- /networks/svd_merge_lora.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | import library.model_util as model_util 9 | import lora 10 | 11 | 12 | CLAMP_QUANTILE = 0.99 13 | 14 | 15 | def load_state_dict(file_name, dtype): 16 | if 
os.path.splitext(file_name)[1] == '.safetensors': 17 | sd = load_file(file_name) 18 | else: 19 | sd = torch.load(file_name, map_location='cpu') 20 | for key in list(sd.keys()): 21 | if type(sd[key]) == torch.Tensor: 22 | sd[key] = sd[key].to(dtype) 23 | return sd 24 | 25 | 26 | def save_to_file(file_name, state_dict, dtype): 27 | if dtype is not None: 28 | for key in list(state_dict.keys()): 29 | if type(state_dict[key]) == torch.Tensor: 30 | state_dict[key] = state_dict[key].to(dtype) 31 | 32 | if os.path.splitext(file_name)[1] == '.safetensors': 33 | save_file(state_dict, file_name) 34 | else: 35 | torch.save(state_dict, file_name) 36 | 37 | 38 | def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): 39 | print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") 40 | merged_sd = {} 41 | for model, ratio in zip(models, ratios): 42 | print(f"loading: {model}") 43 | lora_sd = load_state_dict(model, merge_dtype) 44 | 45 | # merge 46 | print(f"merging...") 47 | for key in tqdm(list(lora_sd.keys())): 48 | if 'lora_down' not in key: 49 | continue 50 | 51 | lora_module_name = key[:key.rfind(".lora_down")] 52 | 53 | down_weight = lora_sd[key] 54 | network_dim = down_weight.size()[0] 55 | 56 | up_weight = lora_sd[lora_module_name + '.lora_up.weight'] 57 | alpha = lora_sd.get(lora_module_name + '.alpha', network_dim) 58 | 59 | in_dim = down_weight.size()[1] 60 | out_dim = up_weight.size()[0] 61 | conv2d = len(down_weight.size()) == 4 62 | kernel_size = None if not conv2d else down_weight.size()[2:4] 63 | # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) 64 | 65 | # make original weight if not exist 66 | if lora_module_name not in merged_sd: 67 | weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) 68 | if device: 69 | weight = weight.to(device) 70 | else: 71 | weight = merged_sd[lora_module_name] 72 | 73 | # merge to weight 74 | if device: 75 | up_weight = up_weight.to(device) 76 | down_weight = down_weight.to(device) 77 | 78 | # W <- W + U * D 79 | scale = (alpha / network_dim) 80 | 81 | if device: # and isinstance(scale, torch.Tensor): 82 | scale = scale.to(device) 83 | 84 | if not conv2d: # linear 85 | weight = weight + ratio * (up_weight @ down_weight) * scale 86 | elif kernel_size == (1, 1): 87 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2) 88 | ).unsqueeze(2).unsqueeze(3) * scale 89 | else: 90 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 91 | weight = weight + ratio * conved * scale 92 | 93 | merged_sd[lora_module_name] = weight 94 | 95 | # extract from merged weights 96 | print("extract new lora...") 97 | merged_lora_sd = {} 98 | with torch.no_grad(): 99 | for lora_module_name, mat in tqdm(list(merged_sd.items())): 100 | conv2d = (len(mat.size()) == 4) 101 | kernel_size = None if not conv2d else mat.size()[2:4] 102 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 103 | out_dim, in_dim = mat.size()[0:2] 104 | 105 | if conv2d: 106 | if conv2d_3x3: 107 | mat = mat.flatten(start_dim=1) 108 | else: 109 | mat = mat.squeeze() 110 | 111 | module_new_rank = new_conv_rank if conv2d_3x3 else new_rank 112 | module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 113 | 114 | U, S, Vh = torch.linalg.svd(mat) 115 | 116 | U = U[:, :module_new_rank] 117 | S = S[:module_new_rank] 118 | U = U @ torch.diag(S) 119 | 120 | Vh = Vh[:module_new_rank, :] 121 
| 122 | dist = torch.cat([U.flatten(), Vh.flatten()]) 123 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 124 | low_val = -hi_val 125 | 126 | U = U.clamp(low_val, hi_val) 127 | Vh = Vh.clamp(low_val, hi_val) 128 | 129 | if conv2d: 130 | U = U.reshape(out_dim, module_new_rank, 1, 1) 131 | Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) 132 | 133 | up_weight = U 134 | down_weight = Vh 135 | 136 | merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous() 137 | merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous() 138 | merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank) 139 | 140 | return merged_lora_sd 141 | 142 | 143 | def merge(args): 144 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 145 | 146 | def str_to_dtype(p): 147 | if p == 'float': 148 | return torch.float 149 | if p == 'fp16': 150 | return torch.float16 151 | if p == 'bf16': 152 | return torch.bfloat16 153 | return None 154 | 155 | merge_dtype = str_to_dtype(args.precision) 156 | save_dtype = str_to_dtype(args.save_precision) 157 | if save_dtype is None: 158 | save_dtype = merge_dtype 159 | 160 | new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank 161 | state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) 162 | 163 | print(f"saving model to: {args.save_to}") 164 | save_to_file(args.save_to, state_dict, save_dtype) 165 | 166 | 167 | def setup_parser() -> argparse.ArgumentParser: 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument("--save_precision", type=str, default=None, 170 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 171 | parser.add_argument("--precision", type=str, default="float", 172 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 173 | parser.add_argument("--save_to", type=str, default=None, 174 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 175 | parser.add_argument("--models", type=str, nargs='*', 176 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 177 | parser.add_argument("--ratios", type=float, nargs='*', 178 | help="ratios for each model / それぞれのLoRAモデルの比率") 179 | parser.add_argument("--new_rank", type=int, default=4, 180 | help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") 181 | parser.add_argument("--new_conv_rank", type=int, default=None, 182 | help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ") 183 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 184 | 185 | return parser 186 | 187 | 188 | if __name__ == '__main__': 189 | parser = setup_parser() 190 | 191 | args = parser.parse_args() 192 | merge(args) 193 | -------------------------------------------------------------------------------- /library/hypernetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from diffusers.models.attention_processor import ( 4 | Attention, 5 | AttnProcessor2_0, 6 | SlicedAttnProcessor, 7 | XFormersAttnProcessor 8 | ) 9 | 10 | try: 11 | import xformers.ops 12 | 
except: 13 | xformers = None 14 | 15 | 16 | loaded_networks = [] 17 | 18 | 19 | def apply_single_hypernetwork( 20 | hypernetwork, hidden_states, encoder_hidden_states 21 | ): 22 | context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) 23 | return context_k, context_v 24 | 25 | 26 | def apply_hypernetworks(context_k, context_v, layer=None): 27 | if len(loaded_networks) == 0: 28 | return context_v, context_v 29 | for hypernetwork in loaded_networks: 30 | context_k, context_v = hypernetwork.forward(context_k, context_v) 31 | 32 | context_k = context_k.to(dtype=context_k.dtype) 33 | context_v = context_v.to(dtype=context_k.dtype) 34 | 35 | return context_k, context_v 36 | 37 | 38 | 39 | def xformers_forward( 40 | self: XFormersAttnProcessor, 41 | attn: Attention, 42 | hidden_states: torch.Tensor, 43 | encoder_hidden_states: torch.Tensor = None, 44 | attention_mask: torch.Tensor = None, 45 | ): 46 | batch_size, sequence_length, _ = ( 47 | hidden_states.shape 48 | if encoder_hidden_states is None 49 | else encoder_hidden_states.shape 50 | ) 51 | 52 | attention_mask = attn.prepare_attention_mask( 53 | attention_mask, sequence_length, batch_size 54 | ) 55 | 56 | query = attn.to_q(hidden_states) 57 | 58 | if encoder_hidden_states is None: 59 | encoder_hidden_states = hidden_states 60 | elif attn.norm_cross: 61 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 62 | 63 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 64 | 65 | key = attn.to_k(context_k) 66 | value = attn.to_v(context_v) 67 | 68 | query = attn.head_to_batch_dim(query).contiguous() 69 | key = attn.head_to_batch_dim(key).contiguous() 70 | value = attn.head_to_batch_dim(value).contiguous() 71 | 72 | hidden_states = xformers.ops.memory_efficient_attention( 73 | query, 74 | key, 75 | value, 76 | attn_bias=attention_mask, 77 | op=self.attention_op, 78 | scale=attn.scale, 79 | ) 80 | hidden_states = hidden_states.to(query.dtype) 81 | hidden_states = attn.batch_to_head_dim(hidden_states) 82 | 83 | # linear proj 84 | hidden_states = attn.to_out[0](hidden_states) 85 | # dropout 86 | hidden_states = attn.to_out[1](hidden_states) 87 | return hidden_states 88 | 89 | 90 | def sliced_attn_forward( 91 | self: SlicedAttnProcessor, 92 | attn: Attention, 93 | hidden_states: torch.Tensor, 94 | encoder_hidden_states: torch.Tensor = None, 95 | attention_mask: torch.Tensor = None, 96 | ): 97 | batch_size, sequence_length, _ = ( 98 | hidden_states.shape 99 | if encoder_hidden_states is None 100 | else encoder_hidden_states.shape 101 | ) 102 | attention_mask = attn.prepare_attention_mask( 103 | attention_mask, sequence_length, batch_size 104 | ) 105 | 106 | query = attn.to_q(hidden_states) 107 | dim = query.shape[-1] 108 | query = attn.head_to_batch_dim(query) 109 | 110 | if encoder_hidden_states is None: 111 | encoder_hidden_states = hidden_states 112 | elif attn.norm_cross: 113 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 114 | 115 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 116 | 117 | key = attn.to_k(context_k) 118 | value = attn.to_v(context_v) 119 | key = attn.head_to_batch_dim(key) 120 | value = attn.head_to_batch_dim(value) 121 | 122 | batch_size_attention, query_tokens, _ = query.shape 123 | hidden_states = torch.zeros( 124 | (batch_size_attention, query_tokens, dim // attn.heads), 125 | device=query.device, 126 | dtype=query.dtype, 127 | ) 128 | 129 | for i in 
range(batch_size_attention // self.slice_size): 130 | start_idx = i * self.slice_size 131 | end_idx = (i + 1) * self.slice_size 132 | 133 | query_slice = query[start_idx:end_idx] 134 | key_slice = key[start_idx:end_idx] 135 | attn_mask_slice = ( 136 | attention_mask[start_idx:end_idx] if attention_mask is not None else None 137 | ) 138 | 139 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 140 | 141 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 142 | 143 | hidden_states[start_idx:end_idx] = attn_slice 144 | 145 | hidden_states = attn.batch_to_head_dim(hidden_states) 146 | 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | 152 | return hidden_states 153 | 154 | 155 | def v2_0_forward( 156 | self: AttnProcessor2_0, 157 | attn: Attention, 158 | hidden_states, 159 | encoder_hidden_states=None, 160 | attention_mask=None, 161 | ): 162 | batch_size, sequence_length, _ = ( 163 | hidden_states.shape 164 | if encoder_hidden_states is None 165 | else encoder_hidden_states.shape 166 | ) 167 | inner_dim = hidden_states.shape[-1] 168 | 169 | if attention_mask is not None: 170 | attention_mask = attn.prepare_attention_mask( 171 | attention_mask, sequence_length, batch_size 172 | ) 173 | # scaled_dot_product_attention expects attention_mask shape to be 174 | # (batch, heads, source_length, target_length) 175 | attention_mask = attention_mask.view( 176 | batch_size, attn.heads, -1, attention_mask.shape[-1] 177 | ) 178 | 179 | query = attn.to_q(hidden_states) 180 | 181 | if encoder_hidden_states is None: 182 | encoder_hidden_states = hidden_states 183 | elif attn.norm_cross: 184 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 185 | 186 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 187 | 188 | key = attn.to_k(context_k) 189 | value = attn.to_v(context_v) 190 | 191 | head_dim = inner_dim // attn.heads 192 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 193 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 194 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 195 | 196 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 197 | # TODO: add support for attn.scale when we move to Torch 2.1 198 | hidden_states = F.scaled_dot_product_attention( 199 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 200 | ) 201 | 202 | hidden_states = hidden_states.transpose(1, 2).reshape( 203 | batch_size, -1, attn.heads * head_dim 204 | ) 205 | hidden_states = hidden_states.to(query.dtype) 206 | 207 | # linear proj 208 | hidden_states = attn.to_out[0](hidden_states) 209 | # dropout 210 | hidden_states = attn.to_out[1](hidden_states) 211 | return hidden_states 212 | 213 | 214 | def replace_attentions_for_hypernetwork(): 215 | import diffusers.models.attention_processor 216 | 217 | diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( 218 | xformers_forward 219 | ) 220 | diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( 221 | sliced_attn_forward 222 | ) 223 | diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward 224 | -------------------------------------------------------------------------------- /tools/cache_text_encoder_outputs.py: -------------------------------------------------------------------------------- 1 | # text encoder出力のdiskへの事前キャッシュを行う / 
cache text encoder outputs to disk in advance 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | 20 | 21 | def cache_to_disk(args: argparse.Namespace) -> None: 22 | train_util.prepare_dataset_args(args, True) 23 | 24 | # check cache arg 25 | assert ( 26 | args.cache_text_encoder_outputs_to_disk 27 | ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" 28 | 29 | # できるだけ準備はしておくが今のところSDXLのみしか動かない 30 | assert ( 31 | args.sdxl 32 | ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" 33 | 34 | use_dreambooth_method = args.in_json is None 35 | 36 | if args.seed is not None: 37 | set_seed(args.seed) # 乱数系列を初期化する 38 | 39 | # tokenizerを準備する:datasetを動かすために必要 40 | if args.sdxl: 41 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 42 | tokenizers = [tokenizer1, tokenizer2] 43 | else: 44 | tokenizer = train_util.load_tokenizer(args) 45 | tokenizers = [tokenizer] 46 | 47 | # データセットを準備する 48 | if args.dataset_class is None: 49 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 50 | if args.dataset_config is not None: 51 | print(f"Load dataset config from {args.dataset_config}") 52 | user_config = config_util.load_user_config(args.dataset_config) 53 | ignored = ["train_data_dir", "in_json"] 54 | if any(getattr(args, attr) is not None for attr in ignored): 55 | print( 56 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 57 | ", ".join(ignored) 58 | ) 59 | ) 60 | else: 61 | if use_dreambooth_method: 62 | print("Using DreamBooth method.") 63 | user_config = { 64 | "datasets": [ 65 | { 66 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 67 | args.train_data_dir, args.reg_data_dir 68 | ) 69 | } 70 | ] 71 | } 72 | else: 73 | print("Training with captions.") 74 | user_config = { 75 | "datasets": [ 76 | { 77 | "subsets": [ 78 | { 79 | "image_dir": args.train_data_dir, 80 | "metadata_file": args.in_json, 81 | } 82 | ] 83 | } 84 | ] 85 | } 86 | 87 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 88 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 89 | else: 90 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 91 | 92 | current_epoch = Value("i", 0) 93 | current_step = Value("i", 0) 94 | ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None 95 | collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) 96 | 97 | # acceleratorを準備する 98 | print("prepare accelerator") 99 | accelerator = train_util.prepare_accelerator(args) 100 | 101 | # mixed precisionに対応した型を用意しておき適宜castする 102 | weight_dtype, _ = train_util.prepare_dtype(args) 103 | 104 | # モデルを読み込む 105 | print("load model") 106 | if args.sdxl: 107 | (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 108 | text_encoders = [text_encoder1, text_encoder2] 109 | else: 110 | text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, 
accelerator) 111 | text_encoders = [text_encoder1] 112 | 113 | for text_encoder in text_encoders: 114 | text_encoder.to(accelerator.device, dtype=weight_dtype) 115 | text_encoder.requires_grad_(False) 116 | text_encoder.eval() 117 | 118 | # dataloaderを準備する 119 | train_dataset_group.set_caching_mode("text") 120 | 121 | # DataLoaderのプロセス数:0はメインプロセスになる 122 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで 123 | 124 | train_dataloader = torch.utils.data.DataLoader( 125 | train_dataset_group, 126 | batch_size=1, 127 | shuffle=True, 128 | collate_fn=collater, 129 | num_workers=n_workers, 130 | persistent_workers=args.persistent_data_loader_workers, 131 | ) 132 | 133 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 134 | train_dataloader = accelerator.prepare(train_dataloader) 135 | 136 | # データ取得のためのループ 137 | for batch in tqdm(train_dataloader): 138 | absolute_paths = batch["absolute_paths"] 139 | input_ids1_list = batch["input_ids1_list"] 140 | input_ids2_list = batch["input_ids2_list"] 141 | 142 | image_infos = [] 143 | for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): 144 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 145 | image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX 146 | image_info 147 | 148 | if args.skip_existing: 149 | if os.path.exists(image_info.text_encoder_outputs_npz): 150 | print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") 151 | continue 152 | 153 | image_info.input_ids1 = input_ids1 154 | image_info.input_ids2 = input_ids2 155 | image_infos.append(image_info) 156 | 157 | if len(image_infos) > 0: 158 | b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) 159 | b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) 160 | train_util.cache_batch_text_encoder_outputs( 161 | image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype 162 | ) 163 | 164 | accelerator.wait_for_everyone() 165 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 166 | 167 | 168 | def setup_parser() -> argparse.ArgumentParser: 169 | parser = argparse.ArgumentParser() 170 | 171 | train_util.add_sd_models_arguments(parser) 172 | train_util.add_training_arguments(parser, True) 173 | train_util.add_dataset_arguments(parser, True, True, True) 174 | config_util.add_config_arguments(parser) 175 | sdxl_train_util.add_sdxl_training_arguments(parser) 176 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 177 | parser.add_argument( 178 | "--skip_existing", 179 | action="store_true", 180 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 181 | ) 182 | return parser 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = setup_parser() 187 | 188 | args = parser.parse_args() 189 | args = train_util.read_config_from_file(args, parser) 190 | 191 | cache_to_disk(args) 192 | -------------------------------------------------------------------------------- /tools/cache_latents.py: -------------------------------------------------------------------------------- 1 | # latentsのdiskへの事前キャッシュを行う / cache latents to disk 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 
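# Usage sketch (hypothetical path, intentionally incomplete): --sdxl, --skip_existing and
# --no_half_vae are defined in this file, and the assert in cache_to_disk() below means
# --cache_latents_to_disk must be passed as well; the remaining model and dataset options
# come from the shared train_util argument helpers.
#
#   python tools/cache_latents.py --sdxl --cache_latents_to_disk --skip_existing \
#       --train_data_dir /path/to/train_images   # plus the usual model arguments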
7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | 20 | 21 | def cache_to_disk(args: argparse.Namespace) -> None: 22 | train_util.prepare_dataset_args(args, True) 23 | 24 | # check cache latents arg 25 | assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" 26 | 27 | use_dreambooth_method = args.in_json is None 28 | 29 | if args.seed is not None: 30 | set_seed(args.seed) # 乱数系列を初期化する 31 | 32 | # tokenizerを準備する:datasetを動かすために必要 33 | if args.sdxl: 34 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 35 | tokenizers = [tokenizer1, tokenizer2] 36 | else: 37 | tokenizer = train_util.load_tokenizer(args) 38 | tokenizers = [tokenizer] 39 | 40 | # データセットを準備する 41 | if args.dataset_class is None: 42 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 43 | if args.dataset_config is not None: 44 | print(f"Load dataset config from {args.dataset_config}") 45 | user_config = config_util.load_user_config(args.dataset_config) 46 | ignored = ["train_data_dir", "in_json"] 47 | if any(getattr(args, attr) is not None for attr in ignored): 48 | print( 49 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 50 | ", ".join(ignored) 51 | ) 52 | ) 53 | else: 54 | if use_dreambooth_method: 55 | print("Using DreamBooth method.") 56 | user_config = { 57 | "datasets": [ 58 | { 59 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 60 | args.train_data_dir, args.reg_data_dir 61 | ) 62 | } 63 | ] 64 | } 65 | else: 66 | print("Training with captions.") 67 | user_config = { 68 | "datasets": [ 69 | { 70 | "subsets": [ 71 | { 72 | "image_dir": args.train_data_dir, 73 | "metadata_file": args.in_json, 74 | } 75 | ] 76 | } 77 | ] 78 | } 79 | 80 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 81 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 82 | else: 83 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 84 | 85 | # datasetのcache_latentsを呼ばなければ、生の画像が返る 86 | 87 | current_epoch = Value("i", 0) 88 | current_step = Value("i", 0) 89 | ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None 90 | collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) 91 | 92 | # acceleratorを準備する 93 | print("prepare accelerator") 94 | accelerator = train_util.prepare_accelerator(args) 95 | 96 | # mixed precisionに対応した型を用意しておき適宜castする 97 | weight_dtype, _ = train_util.prepare_dtype(args) 98 | vae_dtype = torch.float32 if args.no_half_vae else weight_dtype 99 | 100 | # モデルを読み込む 101 | print("load model") 102 | if args.sdxl: 103 | (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 104 | else: 105 | _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) 106 | 107 | if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える 108 | vae.set_use_memory_efficient_attention_xformers(args.xformers) 109 | vae.to(accelerator.device, dtype=vae_dtype) 110 | vae.requires_grad_(False) 111 | vae.eval() 112 | 113 | # dataloaderを準備する 114 | train_dataset_group.set_caching_mode("latents") 115 | 116 | 
# DataLoaderのプロセス数:0はメインプロセスになる 117 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで 118 | 119 | train_dataloader = torch.utils.data.DataLoader( 120 | train_dataset_group, 121 | batch_size=1, 122 | shuffle=True, 123 | collate_fn=collater, 124 | num_workers=n_workers, 125 | persistent_workers=args.persistent_data_loader_workers, 126 | ) 127 | 128 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 129 | train_dataloader = accelerator.prepare(train_dataloader) 130 | 131 | # データ取得のためのループ 132 | for batch in tqdm(train_dataloader): 133 | b_size = len(batch["images"]) 134 | vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size 135 | flip_aug = batch["flip_aug"] 136 | random_crop = batch["random_crop"] 137 | bucket_reso = batch["bucket_reso"] 138 | 139 | # バッチを分割して処理する 140 | for i in range(0, b_size, vae_batch_size): 141 | images = batch["images"][i : i + vae_batch_size] 142 | absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] 143 | resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] 144 | 145 | image_infos = [] 146 | for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): 147 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 148 | image_info.image = image 149 | image_info.bucket_reso = bucket_reso 150 | image_info.resized_size = resized_size 151 | image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" 152 | 153 | if args.skip_existing: 154 | if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): 155 | print(f"Skipping {image_info.latents_npz} because it already exists.") 156 | continue 157 | 158 | image_infos.append(image_info) 159 | 160 | if len(image_infos) > 0: 161 | train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) 162 | 163 | accelerator.wait_for_everyone() 164 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 165 | 166 | 167 | def setup_parser() -> argparse.ArgumentParser: 168 | parser = argparse.ArgumentParser() 169 | 170 | train_util.add_sd_models_arguments(parser) 171 | train_util.add_training_arguments(parser, True) 172 | train_util.add_dataset_arguments(parser, True, True, True) 173 | config_util.add_config_arguments(parser) 174 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 175 | parser.add_argument( 176 | "--no_half_vae", 177 | action="store_true", 178 | help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", 179 | ) 180 | parser.add_argument( 181 | "--skip_existing", 182 | action="store_true", 183 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 184 | ) 185 | return parser 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = setup_parser() 190 | 191 | args = parser.parse_args() 192 | args = train_util.read_config_from_file(args, parser) 193 | 194 | cache_to_disk(args) 195 | -------------------------------------------------------------------------------- /finetune/make_captions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import json 5 | import random 6 | import sys 7 | 8 | from pathlib import Path 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as 
np 12 | import torch 13 | from torchvision import transforms 14 | from torchvision.transforms.functional import InterpolationMode 15 | sys.path.append(os.path.dirname(__file__)) 16 | from blip.blip import blip_decoder 17 | import library.train_util as train_util 18 | 19 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | IMAGE_SIZE = 384 23 | 24 | # 正方形でいいのか? という気がするがソースがそうなので 25 | IMAGE_TRANSFORM = transforms.Compose( 26 | [ 27 | transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 30 | ] 31 | ) 32 | 33 | 34 | # 共通化したいが微妙に処理が異なる…… 35 | class ImageLoadingTransformDataset(torch.utils.data.Dataset): 36 | def __init__(self, image_paths): 37 | self.images = image_paths 38 | 39 | def __len__(self): 40 | return len(self.images) 41 | 42 | def __getitem__(self, idx): 43 | img_path = self.images[idx] 44 | 45 | try: 46 | image = Image.open(img_path).convert("RGB") 47 | # convert to tensor temporarily so dataloader will accept it 48 | tensor = IMAGE_TRANSFORM(image) 49 | except Exception as e: 50 | print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 51 | return None 52 | 53 | return (tensor, img_path) 54 | 55 | 56 | def collate_fn_remove_corrupted(batch): 57 | """Collate function that allows to remove corrupted examples in the 58 | dataloader. It expects that the dataloader returns 'None' when that occurs. 59 | The 'None's in the batch are removed. 60 | """ 61 | # Filter out all the Nones (corrupted examples) 62 | batch = list(filter(lambda x: x is not None, batch)) 63 | return batch 64 | 65 | 66 | def main(args): 67 | # fix the seed for reproducibility 68 | seed = args.seed # + utils.get_rank() 69 | torch.manual_seed(seed) 70 | np.random.seed(seed) 71 | random.seed(seed) 72 | 73 | if not os.path.exists("blip"): 74 | args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path 75 | 76 | cwd = os.getcwd() 77 | print("Current Working Directory is: ", cwd) 78 | os.chdir("finetune") 79 | 80 | print(f"load images from {args.train_data_dir}") 81 | train_data_dir_path = Path(args.train_data_dir) 82 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 83 | print(f"found {len(image_paths)} images.") 84 | 85 | print(f"loading BLIP caption: {args.caption_weights}") 86 | model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") 87 | model.eval() 88 | model = model.to(DEVICE) 89 | print("BLIP loaded") 90 | 91 | # captioningする 92 | def run_batch(path_imgs): 93 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 94 | 95 | with torch.no_grad(): 96 | if args.beam_search: 97 | captions = model.generate( 98 | imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length 99 | ) 100 | else: 101 | captions = model.generate( 102 | imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length 103 | ) 104 | 105 | for (image_path, _), caption in zip(path_imgs, captions): 106 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 107 | f.write(caption + "\n") 108 | if args.debug: 109 | print(image_path, caption) 110 | 111 | # 読み込みの高速化のためにDataLoaderを使うオプション 112 | if args.max_data_loader_n_workers is not None: 113 | dataset = ImageLoadingTransformDataset(image_paths) 114 
| data = torch.utils.data.DataLoader( 115 | dataset, 116 | batch_size=args.batch_size, 117 | shuffle=False, 118 | num_workers=args.max_data_loader_n_workers, 119 | collate_fn=collate_fn_remove_corrupted, 120 | drop_last=False, 121 | ) 122 | else: 123 | data = [[(None, ip)] for ip in image_paths] 124 | 125 | b_imgs = [] 126 | for data_entry in tqdm(data, smoothing=0.0): 127 | for data in data_entry: 128 | if data is None: 129 | continue 130 | 131 | img_tensor, image_path = data 132 | if img_tensor is None: 133 | try: 134 | raw_image = Image.open(image_path) 135 | if raw_image.mode != "RGB": 136 | raw_image = raw_image.convert("RGB") 137 | img_tensor = IMAGE_TRANSFORM(raw_image) 138 | except Exception as e: 139 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 140 | continue 141 | 142 | b_imgs.append((image_path, img_tensor)) 143 | if len(b_imgs) >= args.batch_size: 144 | run_batch(b_imgs) 145 | b_imgs.clear() 146 | if len(b_imgs) > 0: 147 | run_batch(b_imgs) 148 | 149 | print("done!") 150 | 151 | 152 | def setup_parser() -> argparse.ArgumentParser: 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 155 | parser.add_argument( 156 | "--caption_weights", 157 | type=str, 158 | default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", 159 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", 160 | ) 161 | parser.add_argument( 162 | "--caption_extention", 163 | type=str, 164 | default=None, 165 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", 166 | ) 167 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 168 | parser.add_argument( 169 | "--beam_search", 170 | action="store_true", 171 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", 172 | ) 173 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 174 | parser.add_argument( 175 | "--max_data_loader_n_workers", 176 | type=int, 177 | default=None, 178 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 179 | ) 180 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 181 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 182 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 183 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 184 | parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") 185 | parser.add_argument("--debug", action="store_true", help="debug mode") 186 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 187 | 188 | return parser 189 | 190 | 191 | if __name__ == "__main__": 192 | parser = setup_parser() 193 | 194 | args = parser.parse_args() 195 | 196 | # スペルミスしていたオプションを復元する 197 | if args.caption_extention is not None: 198 | args.caption_extension = args.caption_extention 199 | 200 | main(args) 201 | 
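# Usage sketch for the BLIP captioning script above. The directory path is hypothetical;
# the flags are the ones defined in setup_parser() of make_captions.py.
#
#   python finetune/make_captions.py /path/to/train_images \
#       --batch_size 4 --beam_search --num_beams 3 \
#       --max_data_loader_n_workers 2 --caption_extension .caption
#
# Each image gets a sibling caption file (image.jpg -> image.caption); pass --recursive
# to also caption images in subfolders.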
-------------------------------------------------------------------------------- /XTI_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List, Optional, Dict, Any, Tuple 3 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 4 | 5 | from library.original_unet import SampleOutput 6 | 7 | 8 | def unet_forward_XTI( 9 | self, 10 | sample: torch.FloatTensor, 11 | timestep: Union[torch.Tensor, float, int], 12 | encoder_hidden_states: torch.Tensor, 13 | class_labels: Optional[torch.Tensor] = None, 14 | return_dict: bool = True, 15 | ) -> Union[Dict, Tuple]: 16 | r""" 17 | Args: 18 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 19 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 20 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 21 | return_dict (`bool`, *optional*, defaults to `True`): 22 | Whether or not to return a dict instead of a plain tuple. 23 | 24 | Returns: 25 | `SampleOutput` or `tuple`: 26 | `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 27 | """ 28 | # By default samples have to be AT least a multiple of the overall upsampling factor. 29 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 30 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 31 | # on the fly if necessary. 32 | # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある 33 | # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する 34 | # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い 35 | default_overall_up_factor = 2**self.num_upsamplers 36 | 37 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 38 | # 64で割り切れないときはupsamplerにサイズを伝える 39 | forward_upsample_size = False 40 | upsample_size = None 41 | 42 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 43 | # logger.info("Forward upsample size to force interpolation output size.") 44 | forward_upsample_size = True 45 | 46 | # 1. time 47 | timesteps = timestep 48 | timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 49 | 50 | t_emb = self.time_proj(timesteps) 51 | 52 | # timesteps does not contain any weights and will always return f32 tensors 53 | # but time_embedding might actually be running in fp16. so we need to cast here. 54 | # there might be better ways to encapsulate this. 55 | # timestepsは重みを含まないので常にfloat32のテンソルを返す 56 | # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある 57 | # time_projでキャストしておけばいいんじゃね? 58 | t_emb = t_emb.to(dtype=self.dtype) 59 | emb = self.time_embedding(t_emb) 60 | 61 | # 2. pre-process 62 | sample = self.conv_in(sample) 63 | 64 | # 3. down 65 | down_block_res_samples = (sample,) 66 | down_i = 0 67 | for downsample_block in self.down_blocks: 68 | # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 69 | # まあこちらのほうがわかりやすいかもしれない 70 | if downsample_block.has_cross_attention: 71 | sample, res_samples = downsample_block( 72 | hidden_states=sample, 73 | temb=emb, 74 | encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], 75 | ) 76 | down_i += 2 77 | else: 78 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 79 | 80 | down_block_res_samples += res_samples 81 | 82 | # 4. mid 83 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) 84 | 85 | # 5. 
up 86 | up_i = 7 87 | for i, upsample_block in enumerate(self.up_blocks): 88 | is_final_block = i == len(self.up_blocks) - 1 89 | 90 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 91 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection 92 | 93 | # if we have not reached the final block and need to forward the upsample size, we do it here 94 | # 前述のように最後のブロック以外ではupsample_sizeを伝える 95 | if not is_final_block and forward_upsample_size: 96 | upsample_size = down_block_res_samples[-1].shape[2:] 97 | 98 | if upsample_block.has_cross_attention: 99 | sample = upsample_block( 100 | hidden_states=sample, 101 | temb=emb, 102 | res_hidden_states_tuple=res_samples, 103 | encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], 104 | upsample_size=upsample_size, 105 | ) 106 | up_i += 3 107 | else: 108 | sample = upsample_block( 109 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 110 | ) 111 | 112 | # 6. post-process 113 | sample = self.conv_norm_out(sample) 114 | sample = self.conv_act(sample) 115 | sample = self.conv_out(sample) 116 | 117 | if not return_dict: 118 | return (sample,) 119 | 120 | return SampleOutput(sample=sample) 121 | 122 | 123 | def downblock_forward_XTI( 124 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None 125 | ): 126 | output_states = () 127 | i = 0 128 | 129 | for resnet, attn in zip(self.resnets, self.attentions): 130 | if self.training and self.gradient_checkpointing: 131 | 132 | def create_custom_forward(module, return_dict=None): 133 | def custom_forward(*inputs): 134 | if return_dict is not None: 135 | return module(*inputs, return_dict=return_dict) 136 | else: 137 | return module(*inputs) 138 | 139 | return custom_forward 140 | 141 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 142 | hidden_states = torch.utils.checkpoint.checkpoint( 143 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 144 | )[0] 145 | else: 146 | hidden_states = resnet(hidden_states, temb) 147 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 148 | 149 | output_states += (hidden_states,) 150 | i += 1 151 | 152 | if self.downsamplers is not None: 153 | for downsampler in self.downsamplers: 154 | hidden_states = downsampler(hidden_states) 155 | 156 | output_states += (hidden_states,) 157 | 158 | return hidden_states, output_states 159 | 160 | 161 | def upblock_forward_XTI( 162 | self, 163 | hidden_states, 164 | res_hidden_states_tuple, 165 | temb=None, 166 | encoder_hidden_states=None, 167 | upsample_size=None, 168 | ): 169 | i = 0 170 | for resnet, attn in zip(self.resnets, self.attentions): 171 | # pop res hidden states 172 | res_hidden_states = res_hidden_states_tuple[-1] 173 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 174 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 175 | 176 | if self.training and self.gradient_checkpointing: 177 | 178 | def create_custom_forward(module, return_dict=None): 179 | def custom_forward(*inputs): 180 | if return_dict is not None: 181 | return module(*inputs, return_dict=return_dict) 182 | else: 183 | return module(*inputs) 184 | 185 | return custom_forward 186 | 187 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 188 | hidden_states = 
torch.utils.checkpoint.checkpoint( 189 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 190 | )[0] 191 | else: 192 | hidden_states = resnet(hidden_states, temb) 193 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 194 | 195 | i += 1 196 | 197 | if self.upsamplers is not None: 198 | for upsampler in self.upsamplers: 199 | hidden_states = upsampler(hidden_states, upsample_size) 200 | 201 | return hidden_states 202 | -------------------------------------------------------------------------------- /finetune/merge_all_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | from pathlib import Path 6 | from typing import List 7 | from tqdm import tqdm 8 | from collections import Counter 9 | import library.train_util as train_util 10 | 11 | TAGS_EXT = ".txt" 12 | CAPTION_EXT = ".caption" 13 | 14 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 15 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 16 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 17 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 18 | 19 | PATTERNS_REMOVE_IN_MULTI = [ 20 | PATTERN_HAIR_LENGTH, 21 | PATTERN_HAIR_CUT, 22 | re.compile(r', [\w\-]+ eyes, '), 23 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 24 | re.compile( 25 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 26 | ] 27 | 28 | CAPTION_REPLACEMENTS = [ 29 | ('anime anime', 'anime'), 30 | ('young ', ''), 31 | ('anime girl', 'girl'), 32 | ('cartoon female', 'girl'), 33 | ('cartoon lady', 'girl'), 34 | ('cartoon character', 'girl'), 35 | ('cartoon woman', 'girl'), 36 | ('cartoon women', 'girls'), 37 | ('cartoon girl', 'girl'), 38 | ('anime female', 'girl'), 39 | ('anime lady', 'girl'), 40 | ('anime character', 'girl'), 41 | ('anime woman', 'girl'), 42 | ('anime women', 'girls'), 43 | ('lady', 'girl'), 44 | ('female', 'girl'), 45 | ('woman', 'girl'), 46 | ('women', 'girls'), 47 | ('people', 'girls'), 48 | ('person', 'girl'), 49 | ('a cartoon figure', 'a figure'), 50 | ('a cartoon image', 'an image'), 51 | ('a cartoon picture', 'a picture'), 52 | ('an anime cartoon image', 'an image'), 53 | ('a cartoon anime drawing', 'a drawing'), 54 | ('a cartoon drawing', 'a drawing'), 55 | ('girl girl', 'girl'), 56 | ] 57 | 58 | def clean_tags(image_key, tags): 59 | tags = tags.replace('^_^', '^@@@^') 60 | tags = tags.replace('_', ' ') 61 | tags = tags.replace('^@@@^', '^_^') 62 | 63 | tokens = tags.split(", rating") 64 | if len(tokens) == 1: 65 | pass 66 | else: 67 | if len(tokens) > 2: 68 | print("multiple ratings:") 69 | print(f"{image_key} {tags}") 70 | tags = tokens[0] 71 | 72 | tags = ", " + tags.replace(", ", ", , ") + ", " 73 | 74 | if 'girls' in tags or 'boys' in tags: 75 | for pat in PATTERNS_REMOVE_IN_MULTI: 76 | found = pat.findall(tags) 77 | if len(found) > 1: 78 | tags = pat.sub("", tags) 79 | 80 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) 81 | if srch_hair_len: 82 | org = srch_hair_len.group() 83 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 84 | 85 | found = PATTERN_HAIR.findall(tags) 86 | if len(found) > 1: 87 | tags = PATTERN_HAIR.sub("", tags) 88 | 89 | if srch_hair_len: 90 | tags = tags.replace(", @@@, ", org) 91 | 92 | found = PATTERN_WORD.findall(tags) 93 | for word in found: 94 | if re.search(f", ((\w+) )+{word}, ", tags): 95 | tags = 
tags.replace(f", {word}, ", "") 96 | 97 | tags = tags.replace(", , ", ", ") 98 | assert tags.startswith(", ") and tags.endswith(", ") 99 | tags = tags[2:-2] 100 | return tags 101 | 102 | def clean_caption(caption): 103 | for rf, rt in CAPTION_REPLACEMENTS: 104 | replaced = True 105 | while replaced: 106 | bef = caption 107 | caption = caption.replace(rf, rt) 108 | replaced = bef != caption 109 | return caption 110 | 111 | def count_files(image_paths, metadata): 112 | counts = Counter({'_captions': 0, '_tags': 0}) 113 | 114 | for image_key in metadata: 115 | if 'tags' not in metadata[image_key]: 116 | counts['_tags'] += 1 117 | if 'caption' not in metadata[image_key]: 118 | counts['_captions'] += 1 119 | 120 | return counts 121 | 122 | def report_counts(counts, total_files): 123 | for key, value in counts.items(): 124 | if value == total_files: 125 | print(f"No {key.replace('_', '')} found for any of the {total_files} images") 126 | elif value == 0: 127 | print(f"All {total_files} images have {key.replace('_', '')}") 128 | else: 129 | print(f"{total_files - value}/{total_files} images have {key.replace('_', '')}") 130 | 131 | def merge_metadata(image_paths, metadata, full_path): 132 | for image_path in tqdm(image_paths): 133 | tags_path = image_path.with_suffix(TAGS_EXT) 134 | if not tags_path.exists(): 135 | tags_path = image_path.joinpath(TAGS_EXT) 136 | 137 | caption_path = image_path.with_suffix(CAPTION_EXT) 138 | if not caption_path.exists(): 139 | caption_path = image_path.joinpath(CAPTION_EXT) 140 | 141 | image_key = str(image_path) if full_path else image_path.stem 142 | if image_key not in metadata: 143 | metadata[image_key] = {} 144 | 145 | if tags_path.is_file(): 146 | tags = tags_path.read_text(encoding='utf-8').strip() 147 | metadata[image_key]['tags'] = tags 148 | 149 | if caption_path.is_file(): 150 | caption = caption_path.read_text(encoding='utf-8').strip() 151 | metadata[image_key]['caption'] = caption 152 | 153 | counts = count_files(image_paths, metadata) 154 | report_counts(counts, len(image_paths)) 155 | 156 | return metadata 157 | 158 | def clean_metadata(metadata): 159 | image_keys = list(metadata.keys()) 160 | for image_key in tqdm(image_keys): 161 | tags = metadata[image_key].get('tags') 162 | if tags is not None: 163 | org = tags 164 | tags = clean_tags(image_key, tags) 165 | metadata[image_key]['tags'] = tags 166 | 167 | caption = metadata[image_key].get('caption') 168 | if caption is not None: 169 | org = caption 170 | caption = clean_caption(caption) 171 | metadata[image_key]['caption'] = caption 172 | 173 | return metadata 174 | 175 | def main(args): 176 | assert not args.recursive or (args.recursive and args.full_path), "--recursive requires --full_path!" 
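# Tags are read from a ".txt" file and captions from a ".caption" file sharing each image's stem;
# the output JSON maps each image key to those values. An illustrative (hypothetical) entry:
#   "train_images/img_0001.png": {"tags": "1girl, long hair", "caption": "a girl with long hair"}
# A hypothetical invocation, assuming images and their sidecar files live under train_images/:
#   python finetune/merge_all_to_metadata.py train_images meta.json --full_path --recursive --clean_caption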
177 | 178 | train_data_dir_path = Path(args.train_data_dir) 179 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 180 | print(f"Found {len(image_paths)} images.") 181 | 182 | if args.in_json is not None: 183 | print(f"Loading existing metadata: {args.in_json}") 184 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 185 | print("Metadata for existing images will be overwritten") 186 | else: 187 | print("Creating a new metadata file") 188 | metadata = {} 189 | 190 | print("Merging tags and captions into metadata json.") 191 | metadata = merge_metadata(image_paths, metadata, args.full_path) 192 | 193 | if args.clean_caption: 194 | print("Cleaning captions and tags.") 195 | metadata = clean_metadata(metadata) 196 | 197 | if args.debug: 198 | print("Debug: image_key, tags, caption") 199 | for image_key, data in metadata.items(): 200 | print(image_key, data['tags'], data['caption']) 201 | 202 | print(f"Writing metadata: {args.out_json}") 203 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 204 | print("Done!") 205 | 206 | def setup_parser() -> argparse.ArgumentParser: 207 | parser = argparse.ArgumentParser() 208 | parser.add_argument("train_data_dir", type=str, help="directory for train images") 209 | parser.add_argument("out_json", type=str, help="metadata file to output") 210 | parser.add_argument("--in_json", type=str, 211 | help="metadata file to input (if omitted and out_json exists, existing out_json is read)") 212 | parser.add_argument("--full_path", action="store_true", 213 | help="use full path as image-key in metadata (supports multiple directories)") 214 | parser.add_argument("--recursive", action="store_true", 215 | help="recursively search for training tags and captions in all child folders of train_data_dir") 216 | parser.add_argument("--debug", action="store_true", help="debug mode") 217 | parser.add_argument("--clean_caption", action="store_true", help="clean captions and tags in metadata") 218 | 219 | return parser 220 | 221 | if __name__ == '__main__': 222 | parser = setup_parser() 223 | 224 | args = parser.parse_args() 225 | 226 | main(args) 227 | -------------------------------------------------------------------------------- /sdxl_train_network.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from library import sdxl_model_util, sdxl_train_util, train_util 4 | import train_network 5 | 6 | 7 | class SdxlNetworkTrainer(train_network.NetworkTrainer): 8 | def __init__(self): 9 | super().__init__() 10 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 11 | 12 | def assert_extra_args(self, args, train_dataset_group): 13 | super().assert_extra_args(args, train_dataset_group) 14 | sdxl_train_util.verify_sdxl_training_args(args) 15 | 16 | if args.cache_text_encoder_outputs: 17 | assert ( 18 | train_dataset_group.is_text_encoder_output_cacheable() 19 | ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" 20 | 21 | assert ( 22 | args.network_train_unet_only or not args.cache_text_encoder_outputs 23 | ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" 24 | 25 | def load_target_model(self, args, weight_dtype, 
accelerator): 26 | ( 27 | load_stable_diffusion_format, 28 | text_encoder1, 29 | text_encoder2, 30 | vae, 31 | unet, 32 | logit_scale, 33 | ckpt_info, 34 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) 35 | 36 | self.load_stable_diffusion_format = load_stable_diffusion_format 37 | self.logit_scale = logit_scale 38 | self.ckpt_info = ckpt_info 39 | 40 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet 41 | 42 | def load_tokenizer(self, args): 43 | tokenizer = sdxl_train_util.load_tokenizers(args) 44 | return tokenizer 45 | 46 | def is_text_encoder_outputs_cached(self, args): 47 | return args.cache_text_encoder_outputs 48 | 49 | def cache_text_encoder_outputs_if_needed( 50 | self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype 51 | ): 52 | if args.cache_text_encoder_outputs: 53 | if not args.lowram: 54 | # メモリ消費を減らす 55 | print("move vae and unet to cpu to save memory") 56 | org_vae_device = vae.device 57 | org_unet_device = unet.device 58 | vae.to("cpu") 59 | unet.to("cpu") 60 | if torch.cuda.is_available(): 61 | torch.cuda.empty_cache() 62 | 63 | dataset.cache_text_encoder_outputs( 64 | tokenizers, 65 | text_encoders, 66 | accelerator.device, 67 | weight_dtype, 68 | args.cache_text_encoder_outputs_to_disk, 69 | accelerator.is_main_process, 70 | ) 71 | 72 | text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU 73 | text_encoders[1].to("cpu", dtype=torch.float32) 74 | if torch.cuda.is_available(): 75 | torch.cuda.empty_cache() 76 | 77 | if not args.lowram: 78 | print("move vae and unet back to original device") 79 | vae.to(org_vae_device) 80 | unet.to(org_unet_device) 81 | else: 82 | # Text Encoderから毎回出力を取得するので、GPUに乗せておく 83 | text_encoders[0].to(accelerator.device) 84 | text_encoders[1].to(accelerator.device) 85 | 86 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 87 | if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: 88 | input_ids1 = batch["input_ids"] 89 | input_ids2 = batch["input_ids2"] 90 | with torch.enable_grad(): 91 | # Get the text embedding for conditioning 92 | # TODO support weighted captions 93 | # if args.weighted_captions: 94 | # encoder_hidden_states = get_weighted_text_embeddings( 95 | # tokenizer, 96 | # text_encoder, 97 | # batch["captions"], 98 | # accelerator.device, 99 | # args.max_token_length // 75 if args.max_token_length else 1, 100 | # clip_skip=args.clip_skip, 101 | # ) 102 | # else: 103 | input_ids1 = input_ids1.to(accelerator.device) 104 | input_ids2 = input_ids2.to(accelerator.device) 105 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 106 | args.max_token_length, 107 | input_ids1, 108 | input_ids2, 109 | tokenizers[0], 110 | tokenizers[1], 111 | text_encoders[0], 112 | text_encoders[1], 113 | None if not args.full_fp16 else weight_dtype, 114 | ) 115 | else: 116 | encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) 117 | encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) 118 | pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) 119 | 120 | # # verify that the text encoder outputs are correct 121 | # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( 122 | # args.max_token_length, 123 | # 
batch["input_ids"].to(text_encoders[0].device), 124 | # batch["input_ids2"].to(text_encoders[0].device), 125 | # tokenizers[0], 126 | # tokenizers[1], 127 | # text_encoders[0], 128 | # text_encoders[1], 129 | # None if not args.full_fp16 else weight_dtype, 130 | # ) 131 | # b_size = encoder_hidden_states1.shape[0] 132 | # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 133 | # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 134 | # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 135 | # print("text encoder outputs verified") 136 | 137 | 138 | return encoder_hidden_states1, encoder_hidden_states2, pool2 139 | 140 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 141 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 142 | 143 | # get size embeddings 144 | orig_size = batch["original_sizes_hw"] 145 | crop_size = batch["crop_top_lefts"] 146 | target_size = batch["target_sizes_hw"] 147 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 148 | 149 | # concat embeddings 150 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds 151 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 152 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 153 | 154 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 155 | return noise_pred 156 | 157 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): 158 | sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) 159 | 160 | 161 | def setup_parser() -> argparse.ArgumentParser: 162 | parser = train_network.setup_parser() 163 | sdxl_train_util.add_sdxl_training_arguments(parser) 164 | return parser 165 | 166 | 167 | if __name__ == "__main__": 168 | parser = setup_parser() 169 | 170 | args = parser.parse_args() 171 | args = train_util.read_config_from_file(args, parser) 172 | 173 | trainer = SdxlNetworkTrainer() 174 | trainer.train(args) 175 | -------------------------------------------------------------------------------- /library/attention_processors.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | from einops import rearrange 4 | import torch 5 | from diffusers.models.attention_processor import Attention 6 | 7 | 8 | # flash attention forwards and backwards 9 | 10 | # https://arxiv.org/abs/2205.14135 11 | 12 | EPSILON = 1e-6 13 | 14 | 15 | class FlashAttentionFunction(torch.autograd.function.Function): 16 | @staticmethod 17 | @torch.no_grad() 18 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 19 | """Algorithm 2 in the paper""" 20 | 21 | device = q.device 22 | dtype = q.dtype 23 | max_neg_value = -torch.finfo(q.dtype).max 24 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 25 | 26 | o = torch.zeros_like(q) 27 | all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) 28 | all_row_maxes = torch.full( 29 | (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device 30 | ) 31 | 32 | scale = q.shape[-1] ** -0.5 33 | 34 | if mask is None: 35 | mask = (None,) * math.ceil(q.shape[-2] 
/ q_bucket_size) 36 | else: 37 | mask = rearrange(mask, "b n -> b 1 1 n") 38 | mask = mask.split(q_bucket_size, dim=-1) 39 | 40 | row_splits = zip( 41 | q.split(q_bucket_size, dim=-2), 42 | o.split(q_bucket_size, dim=-2), 43 | mask, 44 | all_row_sums.split(q_bucket_size, dim=-2), 45 | all_row_maxes.split(q_bucket_size, dim=-2), 46 | ) 47 | 48 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 49 | q_start_index = ind * q_bucket_size - qk_len_diff 50 | 51 | col_splits = zip( 52 | k.split(k_bucket_size, dim=-2), 53 | v.split(k_bucket_size, dim=-2), 54 | ) 55 | 56 | for k_ind, (kc, vc) in enumerate(col_splits): 57 | k_start_index = k_ind * k_bucket_size 58 | 59 | attn_weights = ( 60 | torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale 61 | ) 62 | 63 | if row_mask is not None: 64 | attn_weights.masked_fill_(~row_mask, max_neg_value) 65 | 66 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 67 | causal_mask = torch.ones( 68 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 69 | ).triu(q_start_index - k_start_index + 1) 70 | attn_weights.masked_fill_(causal_mask, max_neg_value) 71 | 72 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) 73 | attn_weights -= block_row_maxes 74 | exp_weights = torch.exp(attn_weights) 75 | 76 | if row_mask is not None: 77 | exp_weights.masked_fill_(~row_mask, 0.0) 78 | 79 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( 80 | min=EPSILON 81 | ) 82 | 83 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 84 | 85 | exp_values = torch.einsum( 86 | "... i j, ... j d -> ... i d", exp_weights, vc 87 | ) 88 | 89 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 90 | exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) 91 | 92 | new_row_sums = ( 93 | exp_row_max_diff * row_sums 94 | + exp_block_row_max_diff * block_row_sums 95 | ) 96 | 97 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( 98 | (exp_block_row_max_diff / new_row_sums) * exp_values 99 | ) 100 | 101 | row_maxes.copy_(new_row_maxes) 102 | row_sums.copy_(new_row_sums) 103 | 104 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 105 | ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) 106 | 107 | return o 108 | 109 | @staticmethod 110 | @torch.no_grad() 111 | def backward(ctx, do): 112 | """Algorithm 4 in the paper""" 113 | 114 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 115 | q, k, v, o, l, m = ctx.saved_tensors 116 | 117 | device = q.device 118 | 119 | max_neg_value = -torch.finfo(q.dtype).max 120 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 121 | 122 | dq = torch.zeros_like(q) 123 | dk = torch.zeros_like(k) 124 | dv = torch.zeros_like(v) 125 | 126 | row_splits = zip( 127 | q.split(q_bucket_size, dim=-2), 128 | o.split(q_bucket_size, dim=-2), 129 | do.split(q_bucket_size, dim=-2), 130 | mask, 131 | l.split(q_bucket_size, dim=-2), 132 | m.split(q_bucket_size, dim=-2), 133 | dq.split(q_bucket_size, dim=-2), 134 | ) 135 | 136 | for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): 137 | q_start_index = ind * q_bucket_size - qk_len_diff 138 | 139 | col_splits = zip( 140 | k.split(k_bucket_size, dim=-2), 141 | v.split(k_bucket_size, dim=-2), 142 | dk.split(k_bucket_size, dim=-2), 143 | dv.split(k_bucket_size, dim=-2), 144 | ) 145 | 146 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 147 | k_start_index = k_ind * k_bucket_size 148 | 149 | attn_weights = ( 150 | torch.einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale 151 | ) 152 | 153 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 154 | causal_mask = torch.ones( 155 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 156 | ).triu(q_start_index - k_start_index + 1) 157 | attn_weights.masked_fill_(causal_mask, max_neg_value) 158 | 159 | exp_attn_weights = torch.exp(attn_weights - mc) 160 | 161 | if row_mask is not None: 162 | exp_attn_weights.masked_fill_(~row_mask, 0.0) 163 | 164 | p = exp_attn_weights / lc 165 | 166 | dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) 167 | dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) 168 | 169 | D = (doc * oc).sum(dim=-1, keepdims=True) 170 | ds = p * scale * (dp - D) 171 | 172 | dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) 173 | dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) 174 | 175 | dqc.add_(dq_chunk) 176 | dkc.add_(dk_chunk) 177 | dvc.add_(dv_chunk) 178 | 179 | return dq, dk, dv, None, None, None, None 180 | 181 | 182 | class FlashAttnProcessor: 183 | def __call__( 184 | self, 185 | attn: Attention, 186 | hidden_states, 187 | encoder_hidden_states=None, 188 | attention_mask=None, 189 | ) -> Any: 190 | q_bucket_size = 512 191 | k_bucket_size = 1024 192 | 193 | h = attn.heads 194 | q = attn.to_q(hidden_states) 195 | 196 | encoder_hidden_states = ( 197 | encoder_hidden_states 198 | if encoder_hidden_states is not None 199 | else hidden_states 200 | ) 201 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) 202 | 203 | if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: 204 | context_k, context_v = attn.hypernetwork.forward( 205 | hidden_states, encoder_hidden_states 206 | ) 207 | context_k = context_k.to(hidden_states.dtype) 208 | context_v = context_v.to(hidden_states.dtype) 209 | else: 210 | context_k = encoder_hidden_states 211 | context_v = encoder_hidden_states 212 | 213 | k = attn.to_k(context_k) 214 | v = attn.to_v(context_v) 215 | del encoder_hidden_states, hidden_states 216 | 217 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 218 | 219 | out = FlashAttentionFunction.apply( 220 | q, k, v, attention_mask, False, q_bucket_size, k_bucket_size 221 | ) 222 | 223 | out = rearrange(out, "b h n d -> b n (h d)") 224 | 225 | out = attn.to_out[0](out) 226 | out = attn.to_out[1](out) 227 | return out 228 | -------------------------------------------------------------------------------- /tools/detect_face_rotate.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします 2 | # (c) 2022 Kohya S. 
@kohya_ss 3 | 4 | # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す 5 | 6 | # v2: extract max face if multiple faces are found 7 | # v3: add crop_ratio option 8 | # v4: add multiple faces extraction and min/max size 9 | 10 | import argparse 11 | import math 12 | import cv2 13 | import glob 14 | import os 15 | from anime_face_detector import create_detector 16 | from tqdm import tqdm 17 | import numpy as np 18 | 19 | KP_REYE = 11 20 | KP_LEYE = 19 21 | 22 | SCORE_THRES = 0.90 23 | 24 | 25 | def detect_faces(detector, image, min_size): 26 | preds = detector(image) # bgr 27 | # print(len(preds)) 28 | 29 | faces = [] 30 | for pred in preds: 31 | bb = pred['bbox'] 32 | score = bb[-1] 33 | if score < SCORE_THRES: 34 | continue 35 | 36 | left, top, right, bottom = bb[:4] 37 | cx = int((left + right) / 2) 38 | cy = int((top + bottom) / 2) 39 | fw = int(right - left) 40 | fh = int(bottom - top) 41 | 42 | lex, ley = pred['keypoints'][KP_LEYE, 0:2] 43 | rex, rey = pred['keypoints'][KP_REYE, 0:2] 44 | angle = math.atan2(ley - rey, lex - rex) 45 | angle = angle / math.pi * 180 46 | 47 | faces.append((cx, cy, fw, fh, angle)) 48 | 49 | faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 50 | return faces 51 | 52 | 53 | def rotate_image(image, angle, cx, cy): 54 | h, w = image.shape[0:2] 55 | rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) 56 | 57 | # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 58 | # nh = max(h, int(w * math.sin(angle))) 59 | # nw = max(w, int(h * math.sin(angle))) 60 | # if nh > h or nw > w: 61 | # pad_y = nh - h 62 | # pad_t = pad_y // 2 63 | # pad_x = nw - w 64 | # pad_l = pad_x // 2 65 | # m = np.array([[0, 0, pad_l], 66 | # [0, 0, pad_t]]) 67 | # rot_mat = rot_mat + m 68 | # h, w = nh, nw 69 | # cx += pad_l 70 | # cy += pad_t 71 | 72 | result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) 73 | return result, cx, cy 74 | 75 | 76 | def process(args): 77 | assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" 78 | assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" 79 | 80 | # アニメ顔検出モデルを読み込む 81 | print("loading face detector.") 82 | detector = create_detector('yolov3') 83 | 84 | # cropの引数を解析する 85 | if args.crop_size is None: 86 | crop_width = crop_height = None 87 | else: 88 | tokens = args.crop_size.split(',') 89 | assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" 90 | crop_width, crop_height = [int(t) for t in tokens] 91 | 92 | if args.crop_ratio is None: 93 | crop_h_ratio = crop_v_ratio = None 94 | else: 95 | tokens = args.crop_ratio.split(',') 96 | assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" 97 | crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] 98 | 99 | # 画像を処理する 100 | print("processing.") 101 | output_extension = ".png" 102 | 103 | os.makedirs(args.dst_dir, exist_ok=True) 104 | paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ 105 | glob.glob(os.path.join(args.src_dir, "*.webp")) 106 | for path in tqdm(paths): 107 | basename = os.path.splitext(os.path.basename(path))[0] 108 | 109 | # image = cv2.imread(path) # 日本語ファイル名でエラーになる 110 | image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) 111 | if len(image.shape) == 2: 112 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 113 | if image.shape[2] 
== 4: 114 | print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") 115 | image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい 116 | 117 | h, w = image.shape[:2] 118 | 119 | faces = detect_faces(detector, image, args.multiple_faces) 120 | for i, face in enumerate(faces): 121 | cx, cy, fw, fh, angle = face 122 | face_size = max(fw, fh) 123 | if args.min_size is not None and face_size < args.min_size: 124 | continue 125 | if args.max_size is not None and face_size >= args.max_size: 126 | continue 127 | face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" 128 | 129 | # オプション指定があれば回転する 130 | face_img = image 131 | if args.rotate: 132 | face_img, cx, cy = rotate_image(face_img, angle, cx, cy) 133 | 134 | # オプション指定があれば顔を中心に切り出す 135 | if crop_width is not None or crop_h_ratio is not None: 136 | cur_crop_width, cur_crop_height = crop_width, crop_height 137 | if crop_h_ratio is not None: 138 | cur_crop_width = int(face_size * crop_h_ratio + .5) 139 | cur_crop_height = int(face_size * crop_v_ratio + .5) 140 | 141 | # リサイズを必要なら行う 142 | scale = 1.0 143 | if args.resize_face_size is not None: 144 | # 顔サイズを基準にリサイズする 145 | scale = args.resize_face_size / face_size 146 | if scale < cur_crop_width / w: 147 | print( 148 | f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 149 | scale = cur_crop_width / w 150 | if scale < cur_crop_height / h: 151 | print( 152 | f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 153 | scale = cur_crop_height / h 154 | elif crop_h_ratio is not None: 155 | # 倍率指定の時にはリサイズしない 156 | pass 157 | else: 158 | # 切り出しサイズ指定あり 159 | if w < cur_crop_width: 160 | print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") 161 | scale = cur_crop_width / w 162 | if h < cur_crop_height: 163 | print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") 164 | scale = cur_crop_height / h 165 | if args.resize_fit: 166 | scale = max(cur_crop_width / w, cur_crop_height / h) 167 | 168 | if scale != 1.0: 169 | w = int(w * scale + .5) 170 | h = int(h * scale + .5) 171 | face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) 172 | cx = int(cx * scale + .5) 173 | cy = int(cy * scale + .5) 174 | fw = int(fw * scale + .5) 175 | fh = int(fh * scale + .5) 176 | 177 | cur_crop_width = min(cur_crop_width, face_img.shape[1]) 178 | cur_crop_height = min(cur_crop_height, face_img.shape[0]) 179 | 180 | x = cx - cur_crop_width // 2 181 | cx = cur_crop_width // 2 182 | if x < 0: 183 | cx = cx + x 184 | x = 0 185 | elif x + cur_crop_width > w: 186 | cx = cx + (x + cur_crop_width - w) 187 | x = w - cur_crop_width 188 | face_img = face_img[:, x:x+cur_crop_width] 189 | 190 | y = cy - cur_crop_height // 2 191 | cy = cur_crop_height // 2 192 | if y < 0: 193 | cy = cy + y 194 | y = 0 195 | elif y + cur_crop_height > h: 196 | cy = cy + (y + cur_crop_height - h) 197 | y = h - cur_crop_height 198 | face_img = face_img[y:y + cur_crop_height] 199 | 200 | # # debug 201 | # print(path, cx, cy, angle) 202 | # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) 203 | # cv2.imshow("image", crp) 204 | # if cv2.waitKey() == 27: 205 | # break 206 | # cv2.destroyAllWindows() 207 | 208 | # debug 209 | if args.debug: 210 | cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) 211 | 212 | _, buf = cv2.imencode(output_extension, face_img) 213 | with 
open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: 214 | buf.tofile(f) 215 | 216 | 217 | def setup_parser() -> argparse.ArgumentParser: 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") 220 | parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") 221 | parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") 222 | parser.add_argument("--resize_fit", action="store_true", 223 | help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") 224 | parser.add_argument("--resize_face_size", type=int, default=None, 225 | help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") 226 | parser.add_argument("--crop_size", type=str, default=None, 227 | help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") 228 | parser.add_argument("--crop_ratio", type=str, default=None, 229 | help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") 230 | parser.add_argument("--min_size", type=int, default=None, 231 | help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") 232 | parser.add_argument("--max_size", type=int, default=None, 233 | help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") 234 | parser.add_argument("--multiple_faces", action="store_true", 235 | help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") 236 | parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") 237 | 238 | return parser 239 | 240 | 241 | if __name__ == '__main__': 242 | parser = setup_parser() 243 | 244 | args = parser.parse_args() 245 | 246 | process(args) 247 | -------------------------------------------------------------------------------- /networks/merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import torch 5 | from safetensors.torch import load_file, save_file 6 | import library.model_util as model_util 7 | import lora 8 | 9 | 10 | def load_state_dict(file_name, dtype): 11 | if os.path.splitext(file_name)[1] == ".safetensors": 12 | sd = load_file(file_name) 13 | else: 14 | sd = torch.load(file_name, map_location="cpu") 15 | for key in list(sd.keys()): 16 | if type(sd[key]) == torch.Tensor: 17 | sd[key] = sd[key].to(dtype) 18 | return sd 19 | 20 | 21 | def save_to_file(file_name, model, state_dict, dtype): 22 | if dtype is not None: 23 | for key in list(state_dict.keys()): 24 | if type(state_dict[key]) == torch.Tensor: 25 | state_dict[key] = state_dict[key].to(dtype) 26 | 27 | if os.path.splitext(file_name)[1] == ".safetensors": 28 | save_file(model, file_name) 29 | else: 30 | torch.save(model, file_name) 31 | 32 | 33 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 34 | text_encoder.to(merge_dtype) 35 | unet.to(merge_dtype) 36 | 37 | # create module map 38 | name_to_module = {} 39 | for i, root_module in enumerate([text_encoder, unet]): 40 | if i == 0: 41 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 42 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 43 | else: 44 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 45 | target_replace_modules = ( 46 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + 
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 47 | ) 48 | 49 | for name, module in root_module.named_modules(): 50 | if module.__class__.__name__ in target_replace_modules: 51 | for child_name, child_module in module.named_modules(): 52 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 53 | lora_name = prefix + "." + name + "." + child_name 54 | lora_name = lora_name.replace(".", "_") 55 | name_to_module[lora_name] = child_module 56 | 57 | for model, ratio in zip(models, ratios): 58 | print(f"loading: {model}") 59 | lora_sd = load_state_dict(model, merge_dtype) 60 | 61 | print(f"merging...") 62 | for key in lora_sd.keys(): 63 | if "lora_down" in key: 64 | up_key = key.replace("lora_down", "lora_up") 65 | alpha_key = key[: key.index("lora_down")] + "alpha" 66 | 67 | # find original module for this lora 68 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 69 | if module_name not in name_to_module: 70 | print(f"no module found for LoRA weight: {key}") 71 | continue 72 | module = name_to_module[module_name] 73 | # print(f"apply {key} to {module}") 74 | 75 | down_weight = lora_sd[key] 76 | up_weight = lora_sd[up_key] 77 | 78 | dim = down_weight.size()[0] 79 | alpha = lora_sd.get(alpha_key, dim) 80 | scale = alpha / dim 81 | 82 | # W <- W + U * D 83 | weight = module.weight 84 | # print(module_name, down_weight.size(), up_weight.size()) 85 | if len(weight.size()) == 2: 86 | # linear 87 | weight = weight + ratio * (up_weight @ down_weight) * scale 88 | elif down_weight.size()[2:4] == (1, 1): 89 | # conv2d 1x1 90 | weight = ( 91 | weight 92 | + ratio 93 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 94 | * scale 95 | ) 96 | else: 97 | # conv2d 3x3 98 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 99 | # print(conved.size(), weight.size(), module.stride, module.padding) 100 | weight = weight + ratio * conved * scale 101 | 102 | module.weight = torch.nn.Parameter(weight) 103 | 104 | 105 | def merge_lora_models(models, ratios, merge_dtype): 106 | base_alphas = {} # alpha for merged model 107 | base_dims = {} 108 | 109 | merged_sd = {} 110 | for model, ratio in zip(models, ratios): 111 | print(f"loading: {model}") 112 | lora_sd = load_state_dict(model, merge_dtype) 113 | 114 | # get alpha and dim 115 | alphas = {} # alpha for current model 116 | dims = {} # dims for current model 117 | for key in lora_sd.keys(): 118 | if "alpha" in key: 119 | lora_module_name = key[: key.rfind(".alpha")] 120 | alpha = float(lora_sd[key].detach().numpy()) 121 | alphas[lora_module_name] = alpha 122 | if lora_module_name not in base_alphas: 123 | base_alphas[lora_module_name] = alpha 124 | elif "lora_down" in key: 125 | lora_module_name = key[: key.rfind(".lora_down")] 126 | dim = lora_sd[key].size()[0] 127 | dims[lora_module_name] = dim 128 | if lora_module_name not in base_dims: 129 | base_dims[lora_module_name] = dim 130 | 131 | for lora_module_name in dims.keys(): 132 | if lora_module_name not in alphas: 133 | alpha = dims[lora_module_name] 134 | alphas[lora_module_name] = alpha 135 | if lora_module_name not in base_alphas: 136 | base_alphas[lora_module_name] = alpha 137 | 138 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") 139 | 140 | # merge 141 | print(f"merging...") 142 | for key in lora_sd.keys(): 143 | if "alpha" in key: 144 | continue 145 | 146 | lora_module_name = key[: 
key.rfind(".lora_")] 147 | 148 | base_alpha = base_alphas[lora_module_name] 149 | alpha = alphas[lora_module_name] 150 | 151 | scale = math.sqrt(alpha / base_alpha) * ratio 152 | 153 | if key in merged_sd: 154 | assert ( 155 | merged_sd[key].size() == lora_sd[key].size() 156 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 157 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale 158 | else: 159 | merged_sd[key] = lora_sd[key] * scale 160 | 161 | # set alpha to sd 162 | for lora_module_name, alpha in base_alphas.items(): 163 | key = lora_module_name + ".alpha" 164 | merged_sd[key] = torch.tensor(alpha) 165 | 166 | print("merged model") 167 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") 168 | 169 | return merged_sd 170 | 171 | 172 | def merge(args): 173 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 174 | 175 | def str_to_dtype(p): 176 | if p == "float": 177 | return torch.float 178 | if p == "fp16": 179 | return torch.float16 180 | if p == "bf16": 181 | return torch.bfloat16 182 | return None 183 | 184 | merge_dtype = str_to_dtype(args.precision) 185 | save_dtype = str_to_dtype(args.save_precision) 186 | if save_dtype is None: 187 | save_dtype = merge_dtype 188 | 189 | if args.sd_model is not None: 190 | print(f"loading SD model: {args.sd_model}") 191 | 192 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 193 | 194 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 195 | 196 | print(f"saving SD model to: {args.save_to}") 197 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) 198 | else: 199 | state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) 200 | 201 | print(f"saving model to: {args.save_to}") 202 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 203 | 204 | 205 | def setup_parser() -> argparse.ArgumentParser: 206 | parser = argparse.ArgumentParser() 207 | parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") 208 | parser.add_argument( 209 | "--save_precision", 210 | type=str, 211 | default=None, 212 | choices=[None, "float", "fp16", "bf16"], 213 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 214 | ) 215 | parser.add_argument( 216 | "--precision", 217 | type=str, 218 | default="float", 219 | choices=["float", "fp16", "bf16"], 220 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 221 | ) 222 | parser.add_argument( 223 | "--sd_model", 224 | type=str, 225 | default=None, 226 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", 227 | ) 228 | parser.add_argument( 229 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 230 | ) 231 | parser.add_argument( 232 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 233 | ) 234 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 235 | 236 | return parser 237 | 238 | 239 | if __name__ == "__main__": 240 | parser = setup_parser() 241 | 242 | args = 
parser.parse_args() 243 | merge(args) 244 | -------------------------------------------------------------------------------- /networks/sdxl_merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import torch 5 | from safetensors.torch import load_file, save_file 6 | from tqdm import tqdm 7 | from library import sdxl_model_util 8 | import library.model_util as model_util 9 | import lora 10 | 11 | 12 | def load_state_dict(file_name, dtype): 13 | if os.path.splitext(file_name)[1] == ".safetensors": 14 | sd = load_file(file_name) 15 | else: 16 | sd = torch.load(file_name, map_location="cpu") 17 | for key in list(sd.keys()): 18 | if type(sd[key]) == torch.Tensor: 19 | sd[key] = sd[key].to(dtype) 20 | return sd 21 | 22 | 23 | def save_to_file(file_name, model, state_dict, dtype): 24 | if dtype is not None: 25 | for key in list(state_dict.keys()): 26 | if type(state_dict[key]) == torch.Tensor: 27 | state_dict[key] = state_dict[key].to(dtype) 28 | 29 | if os.path.splitext(file_name)[1] == ".safetensors": 30 | save_file(model, file_name) 31 | else: 32 | torch.save(model, file_name) 33 | 34 | 35 | def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): 36 | text_encoder1.to(merge_dtype) 37 | text_encoder2.to(merge_dtype) 38 | unet.to(merge_dtype) 39 | 40 | # create module map 41 | name_to_module = {} 42 | for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): 43 | if i <= 1: 44 | if i == 0: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 46 | else: 47 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 48 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 49 | else: 50 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 51 | target_replace_modules = ( 52 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 53 | ) 54 | 55 | for name, module in root_module.named_modules(): 56 | if module.__class__.__name__ in target_replace_modules: 57 | for child_name, child_module in module.named_modules(): 58 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 59 | lora_name = prefix + "." + name + "." 
+ child_name 60 | lora_name = lora_name.replace(".", "_") 61 | name_to_module[lora_name] = child_module 62 | 63 | for model, ratio in zip(models, ratios): 64 | print(f"loading: {model}") 65 | lora_sd = load_state_dict(model, merge_dtype) 66 | 67 | print(f"merging...") 68 | for key in tqdm(lora_sd.keys()): 69 | if "lora_down" in key: 70 | up_key = key.replace("lora_down", "lora_up") 71 | alpha_key = key[: key.index("lora_down")] + "alpha" 72 | 73 | # find original module for this lora 74 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 75 | if module_name not in name_to_module: 76 | print(f"no module found for LoRA weight: {key}") 77 | continue 78 | module = name_to_module[module_name] 79 | # print(f"apply {key} to {module}") 80 | 81 | down_weight = lora_sd[key] 82 | up_weight = lora_sd[up_key] 83 | 84 | dim = down_weight.size()[0] 85 | alpha = lora_sd.get(alpha_key, dim) 86 | scale = alpha / dim 87 | 88 | # W <- W + U * D 89 | weight = module.weight 90 | # print(module_name, down_weight.size(), up_weight.size()) 91 | if len(weight.size()) == 2: 92 | # linear 93 | weight = weight + ratio * (up_weight @ down_weight) * scale 94 | elif down_weight.size()[2:4] == (1, 1): 95 | # conv2d 1x1 96 | weight = ( 97 | weight 98 | + ratio 99 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 100 | * scale 101 | ) 102 | else: 103 | # conv2d 3x3 104 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 105 | # print(conved.size(), weight.size(), module.stride, module.padding) 106 | weight = weight + ratio * conved * scale 107 | 108 | module.weight = torch.nn.Parameter(weight) 109 | 110 | 111 | def merge_lora_models(models, ratios, merge_dtype): 112 | base_alphas = {} # alpha for merged model 113 | base_dims = {} 114 | 115 | merged_sd = {} 116 | for model, ratio in zip(models, ratios): 117 | print(f"loading: {model}") 118 | lora_sd = load_state_dict(model, merge_dtype) 119 | 120 | # get alpha and dim 121 | alphas = {} # alpha for current model 122 | dims = {} # dims for current model 123 | for key in lora_sd.keys(): 124 | if "alpha" in key: 125 | lora_module_name = key[: key.rfind(".alpha")] 126 | alpha = float(lora_sd[key].detach().numpy()) 127 | alphas[lora_module_name] = alpha 128 | if lora_module_name not in base_alphas: 129 | base_alphas[lora_module_name] = alpha 130 | elif "lora_down" in key: 131 | lora_module_name = key[: key.rfind(".lora_down")] 132 | dim = lora_sd[key].size()[0] 133 | dims[lora_module_name] = dim 134 | if lora_module_name not in base_dims: 135 | base_dims[lora_module_name] = dim 136 | 137 | for lora_module_name in dims.keys(): 138 | if lora_module_name not in alphas: 139 | alpha = dims[lora_module_name] 140 | alphas[lora_module_name] = alpha 141 | if lora_module_name not in base_alphas: 142 | base_alphas[lora_module_name] = alpha 143 | 144 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") 145 | 146 | # merge 147 | print(f"merging...") 148 | for key in tqdm(lora_sd.keys()): 149 | if "alpha" in key: 150 | continue 151 | 152 | lora_module_name = key[: key.rfind(".lora_")] 153 | 154 | base_alpha = base_alphas[lora_module_name] 155 | alpha = alphas[lora_module_name] 156 | 157 | scale = math.sqrt(alpha / base_alpha) * ratio 158 | 159 | if key in merged_sd: 160 | assert ( 161 | merged_sd[key].size() == lora_sd[key].size() 162 | ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 163 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale 164 | else: 165 | merged_sd[key] = lora_sd[key] * scale 166 | 167 | # set alpha to sd 168 | for lora_module_name, alpha in base_alphas.items(): 169 | key = lora_module_name + ".alpha" 170 | merged_sd[key] = torch.tensor(alpha) 171 | 172 | print("merged model") 173 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") 174 | 175 | return merged_sd 176 | 177 | 178 | def merge(args): 179 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 180 | 181 | def str_to_dtype(p): 182 | if p == "float": 183 | return torch.float 184 | if p == "fp16": 185 | return torch.float16 186 | if p == "bf16": 187 | return torch.bfloat16 188 | return None 189 | 190 | merge_dtype = str_to_dtype(args.precision) 191 | save_dtype = str_to_dtype(args.save_precision) 192 | if save_dtype is None: 193 | save_dtype = merge_dtype 194 | 195 | if args.sd_model is not None: 196 | print(f"loading SD model: {args.sd_model}") 197 | 198 | ( 199 | text_model1, 200 | text_model2, 201 | vae, 202 | unet, 203 | logit_scale, 204 | ckpt_info, 205 | ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu") 206 | 207 | merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) 208 | 209 | print(f"saving SD model to: {args.save_to}") 210 | sdxl_model_util.save_stable_diffusion_checkpoint( 211 | args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, save_dtype 212 | ) 213 | else: 214 | state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) 215 | 216 | print(f"saving model to: {args.save_to}") 217 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 218 | 219 | 220 | def setup_parser() -> argparse.ArgumentParser: 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument( 223 | "--save_precision", 224 | type=str, 225 | default=None, 226 | choices=[None, "float", "fp16", "bf16"], 227 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 228 | ) 229 | parser.add_argument( 230 | "--precision", 231 | type=str, 232 | default="float", 233 | choices=["float", "fp16", "bf16"], 234 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 235 | ) 236 | parser.add_argument( 237 | "--sd_model", 238 | type=str, 239 | default=None, 240 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", 241 | ) 242 | parser.add_argument( 243 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 244 | ) 245 | parser.add_argument( 246 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 247 | ) 248 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 249 | 250 | return parser 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = setup_parser() 255 | 256 | args = parser.parse_args() 257 | merge(args) 258 | -------------------------------------------------------------------------------- /networks/extract_lora_from_models.py: -------------------------------------------------------------------------------- 1 | # extract approximating LoRA by svd from two SD models 2 | # The 
code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo! 4 | 5 | import argparse 6 | import json 7 | import os 8 | import torch 9 | from safetensors.torch import load_file, save_file 10 | from tqdm import tqdm 11 | import library.model_util as model_util 12 | import library.sdxl_model_util as sdxl_model_util 13 | import lora 14 | 15 | 16 | CLAMP_QUANTILE = 0.99 17 | MIN_DIFF = 1e-4 18 | 19 | 20 | def save_to_file(file_name, model, state_dict, dtype): 21 | if dtype is not None: 22 | for key in list(state_dict.keys()): 23 | if type(state_dict[key]) == torch.Tensor: 24 | state_dict[key] = state_dict[key].to(dtype) 25 | 26 | if os.path.splitext(file_name)[1] == ".safetensors": 27 | save_file(model, file_name) 28 | else: 29 | torch.save(model, file_name) 30 | 31 | 32 | def svd(args): 33 | def str_to_dtype(p): 34 | if p == "float": 35 | return torch.float 36 | if p == "fp16": 37 | return torch.float16 38 | if p == "bf16": 39 | return torch.bfloat16 40 | return None 41 | 42 | assert args.v2 != args.sdxl or ( 43 | not args.v2 and not args.sdxl 44 | ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" 45 | if args.v_parameterization is None: 46 | args.v_parameterization = args.v2 47 | 48 | save_dtype = str_to_dtype(args.save_precision) 49 | 50 | # load models 51 | if not args.sdxl: 52 | print(f"loading original SD model : {args.model_org}") 53 | text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) 54 | text_encoders_o = [text_encoder_o] 55 | print(f"loading tuned SD model : {args.model_tuned}") 56 | text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) 57 | text_encoders_t = [text_encoder_t] 58 | model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) 59 | else: 60 | print(f"loading original SDXL model : {args.model_org}") 61 | text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( 62 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu" 63 | ) 64 | text_encoders_o = [text_encoder_o1, text_encoder_o2] 65 | print(f"loading original SDXL model : {args.model_tuned}") 66 | text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( 67 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu" 68 | ) 69 | text_encoders_t = [text_encoder_t1, text_encoder_t2] 70 | model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9 71 | 72 | # create LoRA network to extract weights: Use dim (rank) as alpha 73 | if args.conv_dim is None: 74 | kwargs = {} 75 | else: 76 | kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} 77 | 78 | lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) 79 | lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) 80 | assert len(lora_network_o.text_encoder_loras) == len( 81 | lora_network_t.text_encoder_loras 82 | ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " 83 | 84 | # get diffs 85 | diffs = {} 86 | text_encoder_different = False 87 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): 88 | lora_name = lora_o.lora_name 89 | module_o = lora_o.org_module 90 | module_t = lora_t.org_module 91 | diff = module_t.weight - 
module_o.weight 92 | 93 | # Text Encoder might be same 94 | if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: 95 | text_encoder_different = True 96 | print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") 97 | 98 | diff = diff.float() 99 | diffs[lora_name] = diff 100 | 101 | if not text_encoder_different: 102 | print("Text encoder is same. Extract U-Net only.") 103 | lora_network_o.text_encoder_loras = [] 104 | diffs = {} 105 | 106 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): 107 | lora_name = lora_o.lora_name 108 | module_o = lora_o.org_module 109 | module_t = lora_t.org_module 110 | diff = module_t.weight - module_o.weight 111 | diff = diff.float() 112 | 113 | if args.device: 114 | diff = diff.to(args.device) 115 | 116 | diffs[lora_name] = diff 117 | 118 | # make LoRA with svd 119 | print("calculating by svd") 120 | lora_weights = {} 121 | with torch.no_grad(): 122 | for lora_name, mat in tqdm(list(diffs.items())): 123 | # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 124 | conv2d = len(mat.size()) == 4 125 | kernel_size = None if not conv2d else mat.size()[2:4] 126 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 127 | 128 | rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim 129 | out_dim, in_dim = mat.size()[0:2] 130 | 131 | if args.device: 132 | mat = mat.to(args.device) 133 | 134 | # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) 135 | rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 136 | 137 | if conv2d: 138 | if conv2d_3x3: 139 | mat = mat.flatten(start_dim=1) 140 | else: 141 | mat = mat.squeeze() 142 | 143 | U, S, Vh = torch.linalg.svd(mat) 144 | 145 | U = U[:, :rank] 146 | S = S[:rank] 147 | U = U @ torch.diag(S) 148 | 149 | Vh = Vh[:rank, :] 150 | 151 | dist = torch.cat([U.flatten(), Vh.flatten()]) 152 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 153 | low_val = -hi_val 154 | 155 | U = U.clamp(low_val, hi_val) 156 | Vh = Vh.clamp(low_val, hi_val) 157 | 158 | if conv2d: 159 | U = U.reshape(out_dim, rank, 1, 1) 160 | Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) 161 | 162 | U = U.to("cpu").contiguous() 163 | Vh = Vh.to("cpu").contiguous() 164 | 165 | lora_weights[lora_name] = (U, Vh) 166 | 167 | # make state dict for LoRA 168 | lora_sd = {} 169 | for lora_name, (up_weight, down_weight) in lora_weights.items(): 170 | lora_sd[lora_name + ".lora_up.weight"] = up_weight 171 | lora_sd[lora_name + ".lora_down.weight"] = down_weight 172 | lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) 173 | 174 | # load state dict to LoRA and save it 175 | lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) 176 | lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict 177 | 178 | info = lora_network_save.load_state_dict(lora_sd) 179 | print(f"Loading extracted LoRA weights: {info}") 180 | 181 | dir_name = os.path.dirname(args.save_to) 182 | if dir_name and not os.path.exists(dir_name): 183 | os.makedirs(dir_name, exist_ok=True) 184 | 185 | # minimum metadata 186 | net_kwargs = {} 187 | if args.conv_dim is not None: 188 | net_kwargs["conv_dim"] = args.conv_dim 189 | net_kwargs["conv_alpha"] = args.conv_dim 190 | 191 | metadata = { 192 | "ss_v2": str(args.v2), 193 | "ss_base_model_version": model_version, 194 | "ss_network_module": "networks.lora", 195 | 
"ss_network_dim": str(args.dim), 196 | "ss_network_alpha": str(args.dim), 197 | "ss_network_args": json.dumps(net_kwargs), 198 | } 199 | 200 | lora_network_save.save_weights(args.save_to, save_dtype, metadata) 201 | print(f"LoRA weights are saved to: {args.save_to}") 202 | 203 | 204 | def setup_parser() -> argparse.ArgumentParser: 205 | parser = argparse.ArgumentParser() 206 | parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") 207 | parser.add_argument( 208 | "--v_parameterization", 209 | type=bool, 210 | default=None, 211 | help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", 212 | ) 213 | parser.add_argument( 214 | "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む" 215 | ) 216 | parser.add_argument( 217 | "--save_precision", 218 | type=str, 219 | default=None, 220 | choices=[None, "float", "fp16", "bf16"], 221 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", 222 | ) 223 | parser.add_argument( 224 | "--model_org", 225 | type=str, 226 | default=None, 227 | help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", 228 | ) 229 | parser.add_argument( 230 | "--model_tuned", 231 | type=str, 232 | default=None, 233 | help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", 234 | ) 235 | parser.add_argument( 236 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 237 | ) 238 | parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") 239 | parser.add_argument( 240 | "--conv_dim", 241 | type=int, 242 | default=None, 243 | help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", 244 | ) 245 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 246 | 247 | return parser 248 | 249 | 250 | if __name__ == "__main__": 251 | parser = setup_parser() 252 | 253 | args = parser.parse_args() 254 | svd(args) 255 | -------------------------------------------------------------------------------- /finetune/prepare_buckets_latents.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from pathlib import Path 6 | from typing import List 7 | from tqdm import tqdm 8 | import numpy as np 9 | from PIL import Image 10 | import cv2 11 | import torch 12 | from torchvision import transforms 13 | 14 | import library.model_util as model_util 15 | import library.train_util as train_util 16 | 17 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | IMAGE_TRANSFORMS = transforms.Compose( 20 | [ 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.5], [0.5]), 23 | ] 24 | ) 25 | 26 | 27 | def collate_fn_remove_corrupted(batch): 28 | """Collate function that allows to remove corrupted examples in the 29 | dataloader. It expects that the dataloader returns 'None' when that occurs. 30 | The 'None's in the batch are removed. 
31 | """ 32 | # Filter out all the Nones (corrupted examples) 33 | batch = list(filter(lambda x: x is not None, batch)) 34 | return batch 35 | 36 | 37 | def get_npz_filename(data_dir, image_key, is_full_path, recursive): 38 | if is_full_path: 39 | base_name = os.path.splitext(os.path.basename(image_key))[0] 40 | relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) 41 | else: 42 | base_name = image_key 43 | relative_path = "" 44 | 45 | if recursive and relative_path: 46 | return os.path.join(data_dir, relative_path, base_name) + ".npz" 47 | else: 48 | return os.path.join(data_dir, base_name) + ".npz" 49 | 50 | 51 | def main(args): 52 | # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" 53 | if args.bucket_reso_steps % 8 > 0: 54 | print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") 55 | 56 | train_data_dir_path = Path(args.train_data_dir) 57 | image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] 58 | print(f"found {len(image_paths)} images.") 59 | 60 | if os.path.exists(args.in_json): 61 | print(f"loading existing metadata: {args.in_json}") 62 | with open(args.in_json, "rt", encoding="utf-8") as f: 63 | metadata = json.load(f) 64 | else: 65 | print(f"no metadata / メタデータファイルがありません: {args.in_json}") 66 | return 67 | 68 | weight_dtype = torch.float32 69 | if args.mixed_precision == "fp16": 70 | weight_dtype = torch.float16 71 | elif args.mixed_precision == "bf16": 72 | weight_dtype = torch.bfloat16 73 | 74 | vae = model_util.load_vae(args.model_name_or_path, weight_dtype) 75 | vae.eval() 76 | vae.to(DEVICE, dtype=weight_dtype) 77 | 78 | # bucketのサイズを計算する 79 | max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) 80 | assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" 81 | 82 | bucket_manager = train_util.BucketManager( 83 | args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps 84 | ) 85 | if not args.bucket_no_upscale: 86 | bucket_manager.make_buckets() 87 | else: 88 | print( 89 | "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" 90 | ) 91 | 92 | # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する 93 | img_ar_errors = [] 94 | 95 | def process_batch(is_last): 96 | for bucket in bucket_manager.buckets: 97 | if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: 98 | train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) 99 | bucket.clear() 100 | 101 | # 読み込みの高速化のためにDataLoaderを使うオプション 102 | if args.max_data_loader_n_workers is not None: 103 | dataset = train_util.ImageLoadingDataset(image_paths) 104 | data = torch.utils.data.DataLoader( 105 | dataset, 106 | batch_size=1, 107 | shuffle=False, 108 | num_workers=args.max_data_loader_n_workers, 109 | collate_fn=collate_fn_remove_corrupted, 110 | drop_last=False, 111 | ) 112 | else: 113 | data = [[(None, ip)] for ip in image_paths] 114 | 115 | bucket_counts = {} 116 | for data_entry in tqdm(data, smoothing=0.0): 117 | if data_entry[0] is None: 118 | continue 119 | 120 | img_tensor, image_path = data_entry[0] 121 | if img_tensor is not None: 122 | image = transforms.functional.to_pil_image(img_tensor) 123 | else: 124 | try: 125 | image = 
Image.open(image_path) 126 | if image.mode != "RGB": 127 | image = image.convert("RGB") 128 | except Exception as e: 129 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 130 | continue 131 | 132 | image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0] 133 | if image_key not in metadata: 134 | metadata[image_key] = {} 135 | 136 | # 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変 137 | 138 | reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height) 139 | img_ar_errors.append(abs(ar_error)) 140 | bucket_counts[reso] = bucket_counts.get(reso, 0) + 1 141 | 142 | # メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て 143 | metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8) 144 | 145 | if not args.bucket_no_upscale: 146 | # upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する 147 | assert ( 148 | resized_size[0] == reso[0] or resized_size[1] == reso[1] 149 | ), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}" 150 | assert ( 151 | resized_size[0] >= reso[0] and resized_size[1] >= reso[1] 152 | ), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}" 153 | 154 | assert ( 155 | resized_size[0] >= reso[0] and resized_size[1] >= reso[1] 156 | ), f"internal error resized size is small: {resized_size}, {reso}" 157 | 158 | # 既に存在するファイルがあればshape等を確認して同じならskipする 159 | npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive) 160 | if args.skip_existing: 161 | if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug): 162 | continue 163 | 164 | # バッチへ追加 165 | image_info = train_util.ImageInfo(image_key, 1, "", False, image_path) 166 | image_info.latents_npz = npz_file_name 167 | image_info.bucket_reso = reso 168 | image_info.resized_size = resized_size 169 | image_info.image = image 170 | bucket_manager.add_image(reso, image_info) 171 | 172 | # バッチを推論するか判定して推論する 173 | process_batch(False) 174 | 175 | # 残りを処理する 176 | process_batch(True) 177 | 178 | bucket_manager.sort() 179 | for i, reso in enumerate(bucket_manager.resos): 180 | count = bucket_counts.get(reso, 0) 181 | if count > 0: 182 | print(f"bucket {i} {reso}: {count}") 183 | img_ar_errors = np.array(img_ar_errors) 184 | print(f"mean ar error: {np.mean(img_ar_errors)}") 185 | 186 | # metadataを書き出して終わり 187 | print(f"writing metadata: {args.out_json}") 188 | with open(args.out_json, "wt", encoding="utf-8") as f: 189 | json.dump(metadata, f, indent=2) 190 | print("done!") 191 | 192 | 193 | def setup_parser() -> argparse.ArgumentParser: 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 196 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 197 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 198 | parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") 199 | parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") 200 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 201 | parser.add_argument( 202 | "--max_data_loader_n_workers", 203 | type=int, 204 | default=None, 205 | help="enable image reading by DataLoader with this number of 
workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 206 | ) 207 | parser.add_argument( 208 | "--max_resolution", 209 | type=str, 210 | default="512,512", 211 | help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)", 212 | ) 213 | parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") 214 | parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") 215 | parser.add_argument( 216 | "--bucket_reso_steps", 217 | type=int, 218 | default=64, 219 | help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", 220 | ) 221 | parser.add_argument( 222 | "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" 223 | ) 224 | parser.add_argument( 225 | "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" 226 | ) 227 | parser.add_argument( 228 | "--full_path", 229 | action="store_true", 230 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", 231 | ) 232 | parser.add_argument( 233 | "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" 234 | ) 235 | parser.add_argument( 236 | "--skip_existing", 237 | action="store_true", 238 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 239 | ) 240 | parser.add_argument( 241 | "--recursive", 242 | action="store_true", 243 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", 244 | ) 245 | 246 | return parser 247 | 248 | 249 | if __name__ == "__main__": 250 | parser = setup_parser() 251 | 252 | args = parser.parse_args() 253 | main(args) 254 | -------------------------------------------------------------------------------- /finetune/blip/blip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | ''' 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | # from models.vit import VisionTransformer, interpolate_pos_embed 12 | # from models.med import BertConfig, BertModel, BertLMHeadModel 13 | from blip.vit import VisionTransformer, interpolate_pos_embed 14 | from blip.med import BertConfig, BertModel, BertLMHeadModel 15 | from transformers import BertTokenizer 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | 21 | import os 22 | from urllib.parse import urlparse 23 | from timm.models.hub import download_cached_file 24 | 25 | class BLIP_Base(nn.Module): 26 | def __init__(self, 27 | med_config = 'configs/med_config.json', 28 | image_size = 224, 29 | vit = 'base', 30 | vit_grad_ckpt = False, 31 | vit_ckpt_layer = 0, 32 | ): 33 | """ 34 | Args: 35 | med_config (str): path for the mixture of encoder-decoder model's configuration file 36 | image_size (int): input image size 37 | vit (str): model size of vision transformer 38 | """ 39 | super().__init__() 40 | 41 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 42 | self.tokenizer = init_tokenizer() 43 | med_config = BertConfig.from_json_file(med_config) 44 | med_config.encoder_width = vision_width 45 | self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) 46 | 47 | 48 | def forward(self, image, caption, mode): 49 | 50 | assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" 51 | text = self.tokenizer(caption, return_tensors="pt").to(image.device) 52 | 53 | if mode=='image': 54 | # return image features 55 | image_embeds = self.visual_encoder(image) 56 | return image_embeds 57 | 58 | elif mode=='text': 59 | # return text features 60 | text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, 61 | return_dict = True, mode = 'text') 62 | return text_output.last_hidden_state 63 | 64 | elif mode=='multimodal': 65 | # return multimodel features 66 | image_embeds = self.visual_encoder(image) 67 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 68 | 69 | text.input_ids[:,0] = self.tokenizer.enc_token_id 70 | output = self.text_encoder(text.input_ids, 71 | attention_mask = text.attention_mask, 72 | encoder_hidden_states = image_embeds, 73 | encoder_attention_mask = image_atts, 74 | return_dict = True, 75 | ) 76 | return output.last_hidden_state 77 | 78 | 79 | 80 | class BLIP_Decoder(nn.Module): 81 | def __init__(self, 82 | med_config = 'configs/med_config.json', 83 | image_size = 384, 84 | vit = 'base', 85 | vit_grad_ckpt = False, 86 | vit_ckpt_layer = 0, 87 | prompt = 'a picture of ', 88 | ): 89 | """ 90 | Args: 91 | med_config (str): path for the mixture of encoder-decoder model's configuration file 92 | image_size (int): input image size 93 | vit (str): model size of vision transformer 94 | """ 95 | super().__init__() 96 | 97 | self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) 98 | self.tokenizer = init_tokenizer() 99 | med_config = BertConfig.from_json_file(med_config) 100 | med_config.encoder_width = vision_width 101 | self.text_decoder = BertLMHeadModel(config=med_config) 102 | 103 | self.prompt = prompt 104 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 105 | 106 | 107 | 
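    # Training forward pass: the decoder is trained as a conditional language model.
    # Caption tokens attend to the image features via cross-attention, and the loss
    # ignores both padding tokens and the fixed prompt prefix, which are masked to
    # -100 before being passed as labels to the BERT LM head below.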
def forward(self, image, caption): 108 | 109 | image_embeds = self.visual_encoder(image) 110 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 111 | 112 | text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) 113 | 114 | text.input_ids[:,0] = self.tokenizer.bos_token_id 115 | 116 | decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) 117 | decoder_targets[:,:self.prompt_length] = -100 118 | 119 | decoder_output = self.text_decoder(text.input_ids, 120 | attention_mask = text.attention_mask, 121 | encoder_hidden_states = image_embeds, 122 | encoder_attention_mask = image_atts, 123 | labels = decoder_targets, 124 | return_dict = True, 125 | ) 126 | loss_lm = decoder_output.loss 127 | 128 | return loss_lm 129 | 130 | def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): 131 | image_embeds = self.visual_encoder(image) 132 | 133 | if not sample: 134 | image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) 135 | 136 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) 137 | model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} 138 | 139 | prompt = [self.prompt] * image.size(0) 140 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) 141 | input_ids[:,0] = self.tokenizer.bos_token_id 142 | input_ids = input_ids[:, :-1] 143 | 144 | if sample: 145 | #nucleus sampling 146 | outputs = self.text_decoder.generate(input_ids=input_ids, 147 | max_length=max_length, 148 | min_length=min_length, 149 | do_sample=True, 150 | top_p=top_p, 151 | num_return_sequences=1, 152 | eos_token_id=self.tokenizer.sep_token_id, 153 | pad_token_id=self.tokenizer.pad_token_id, 154 | repetition_penalty=1.1, 155 | **model_kwargs) 156 | else: 157 | #beam search 158 | outputs = self.text_decoder.generate(input_ids=input_ids, 159 | max_length=max_length, 160 | min_length=min_length, 161 | num_beams=num_beams, 162 | eos_token_id=self.tokenizer.sep_token_id, 163 | pad_token_id=self.tokenizer.pad_token_id, 164 | repetition_penalty=repetition_penalty, 165 | **model_kwargs) 166 | 167 | captions = [] 168 | for output in outputs: 169 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 170 | captions.append(caption[len(self.prompt):]) 171 | return captions 172 | 173 | 174 | def blip_decoder(pretrained='',**kwargs): 175 | model = BLIP_Decoder(**kwargs) 176 | if pretrained: 177 | model,msg = load_checkpoint(model,pretrained) 178 | assert(len(msg.missing_keys)==0) 179 | return model 180 | 181 | def blip_feature_extractor(pretrained='',**kwargs): 182 | model = BLIP_Base(**kwargs) 183 | if pretrained: 184 | model,msg = load_checkpoint(model,pretrained) 185 | assert(len(msg.missing_keys)==0) 186 | return model 187 | 188 | def init_tokenizer(): 189 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 190 | tokenizer.add_special_tokens({'bos_token':'[DEC]'}) 191 | tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) 192 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 193 | return tokenizer 194 | 195 | 196 | def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): 197 | 198 | assert vit in ['base', 'large'], "vit parameter must be base or large" 199 | if vit=='base': 200 | vision_width = 768 201 | visual_encoder = 
VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 202 | num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 203 | drop_path_rate=0 or drop_path_rate 204 | ) 205 | elif vit=='large': 206 | vision_width = 1024 207 | visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, 208 | num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, 209 | drop_path_rate=0.1 or drop_path_rate 210 | ) 211 | return visual_encoder, vision_width 212 | 213 | def is_url(url_or_filename): 214 | parsed = urlparse(url_or_filename) 215 | return parsed.scheme in ("http", "https") 216 | 217 | def load_checkpoint(model,url_or_filename): 218 | if is_url(url_or_filename): 219 | cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) 220 | checkpoint = torch.load(cached_file, map_location='cpu') 221 | elif os.path.isfile(url_or_filename): 222 | checkpoint = torch.load(url_or_filename, map_location='cpu') 223 | else: 224 | raise RuntimeError('checkpoint url or path is invalid') 225 | 226 | state_dict = checkpoint['model'] 227 | 228 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) 229 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 230 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], 231 | model.visual_encoder_m) 232 | for key in model.state_dict().keys(): 233 | if key in state_dict.keys(): 234 | if state_dict[key].shape!=model.state_dict()[key].shape: 235 | del state_dict[key] 236 | 237 | msg = model.load_state_dict(state_dict,strict=False) 238 | print('load checkpoint from %s'%url_or_filename) 239 | return model,msg 240 | 241 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
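As a supplement to networks/extract_lora_from_models.py above, the following is a minimal, self-contained sketch of its core SVD step for a single linear layer: the weight difference between the tuned and original module is decomposed with a truncated SVD, the singular values are folded into the up matrix, and both factors are clamped at the 0.99 quantile before being stored as the lora_up / lora_down pair. The function name extract_lora_pair and the toy shapes are illustrative only; the real script additionally handles Conv2d layers by flattening 3x3 kernels or squeezing 1x1 ones.

import torch

CLAMP_QUANTILE = 0.99  # same clamping quantile as the script above


def extract_lora_pair(delta: torch.Tensor, rank: int):
    """Approximate a 2-D weight difference with a rank-limited up/down pair."""
    out_dim, in_dim = delta.shape
    rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dims

    U, S, Vh = torch.linalg.svd(delta.float())
    U = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into the up matrix
    Vh = Vh[:rank, :]

    # clamp outliers so a few extreme values do not dominate the reconstruction
    hi_val = torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE)
    U = U.clamp(-hi_val, hi_val)
    Vh = Vh.clamp(-hi_val, hi_val)

    return U.contiguous(), Vh.contiguous()  # lora_up.weight, lora_down.weight


if __name__ == "__main__":
    delta = torch.randn(320, 768) * 1e-3  # toy weight difference (tuned - original)
    up, down = extract_lora_pair(delta, rank=4)
    print(up.shape, down.shape)           # torch.Size([320, 4]) torch.Size([4, 768])
    print(torch.dist(delta, up @ down))   # reconstruction error of the rank-4 approximation

Because the script later records alpha equal to the rank, the pair is applied at scale 1.0, so up @ down is meant to approximate the stored weight difference directly.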