├── .github ├── dependabot.yml └── workflows │ └── typos.yml ├── .gitignore ├── LICENSE.md ├── README-ja.md ├── README.md ├── XTI_hijack.py ├── _typos.toml ├── bitsandbytes_windows ├── cextension.py ├── libbitsandbytes_cpu.dll ├── libbitsandbytes_cuda116.dll ├── libbitsandbytes_cuda118.dll └── main.py ├── docs ├── config_README-ja.md ├── fine_tune_README_ja.md ├── gen_img_README-ja.md ├── train_README-ja.md ├── train_README-zh.md ├── train_db_README-ja.md ├── train_db_README-zh.md ├── train_lllite_README-ja.md ├── train_lllite_README.md ├── train_network_README-ja.md ├── train_network_README-zh.md └── train_ti_README-ja.md ├── fine_tune.py ├── finetune ├── blip │ ├── blip.py │ ├── med.py │ ├── med_config.json │ └── vit.py ├── clean_captions_and_tags.py ├── hypernetwork_nai.py ├── make_captions.py ├── make_captions_by_git.py ├── merge_captions_to_metadata.py ├── merge_dd_tags_to_metadata.py ├── prepare_buckets_latents.py ├── tag_images_by_wd14_tagger.py └── wd14_ordered_tagger.py ├── gen_img_diffusers.py ├── library ├── __init__.py ├── attention_processors.py ├── config_util.py ├── custom_train_functions.py ├── huggingface_util.py ├── hypernetwork.py ├── ipex │ ├── __init__.py │ ├── attention.py │ ├── diffusers.py │ ├── gradscaler.py │ └── hijacks.py ├── ipex_interop.py ├── lpw_stable_diffusion.py ├── model_util.py ├── original_unet.py ├── sai_model_spec.py ├── sdxl_lpw_stable_diffusion.py ├── sdxl_model_util.py ├── sdxl_original_unet.py ├── sdxl_train_util.py ├── slicing_vae.py ├── train_util.py └── utils.py ├── networks ├── check_lora_weights.py ├── control_net_lllite.py ├── control_net_lllite_for_train.py ├── dylora.py ├── extract_lora_from_dylora.py ├── extract_lora_from_models.py ├── lora.py ├── lora_diffusers.py ├── lora_fa.py ├── lora_interrogator.py ├── merge_lora.py ├── merge_lora_old.py ├── oft.py ├── resize_lora.py ├── sdxl_merge_lora.py └── svd_merge_lora.py ├── notebook ├── config_file.toml └── train.ipynb ├── requirements.txt ├── sdxl_gen_img.py ├── sdxl_minimal_inference.py ├── sdxl_train.py ├── sdxl_train_control_net_lllite.py ├── sdxl_train_control_net_lllite_old.py ├── sdxl_train_network.py ├── sdxl_train_textual_inversion.py ├── setup.py ├── tools ├── cache_latents.py ├── cache_text_encoder_outputs.py ├── canny.py ├── convert_diffusers20_original_sd.py ├── detect_face_rotate.py ├── latent_upscaler.py ├── merge_models.py ├── original_control_net.py ├── resize_images_to_resolution.py └── show_metadata.py ├── train_controlnet.py ├── train_db.py ├── train_network.py ├── train_textual_inversion.py └── train_textual_inversion_XTI.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | -------------------------------------------------------------------------------- /.github/workflows/typos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # yamllint disable rule:line-length 3 | name: Typos 4 | 5 | on: # yamllint disable-line rule:truthy 6 | push: 7 | pull_request: 8 | types: 9 | - opened 10 | - synchronize 11 | - reopened 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: typos-action 21 | uses: crate-ci/typos@v1.16.26 22 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | logs 2 | __pycache__ 3 | wd14_tagger_model 4 | venv 5 | *.egg-info 6 | build 7 | .vscode 8 | wandb 9 | -------------------------------------------------------------------------------- /README-ja.md: -------------------------------------------------------------------------------- 1 | SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。 2 | 3 | SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。 4 | 5 | ## リポジトリについて 6 | Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 7 | 8 | [README in English](./README.md) ←更新情報はこちらにあります 9 | 10 | GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 11 | 12 | 以下のスクリプトがあります。 13 | 14 | * DreamBooth、U-NetおよびText Encoderの学習をサポート 15 | * fine-tuning、同上 16 | * LoRAの学習をサポート 17 | * 画像生成 18 | * モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換) 19 | 20 | ## 使用法について 21 | 22 | * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど 23 | * [データセット設定](./docs/config_README-ja.md) 24 | * [DreamBoothの学習について](./docs/train_db_README-ja.md) 25 | * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): 26 | * [LoRAの学習について](./docs/train_network_README-ja.md) 27 | * [Textual Inversionの学習について](./docs/train_ti_README-ja.md) 28 | * [画像生成スクリプト](./docs/gen_img_README-ja.md) 29 | * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad) 30 | 31 | ## Windowsでの動作に必要なプログラム 32 | 33 | Python 3.10.6およびGitが必要です。 34 | 35 | - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe 36 | - git: https://git-scm.com/download/win 37 | 38 | PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。 39 | (venvに限らずスクリプトの実行が可能になりますので注意してください。) 40 | 41 | - PowerShellを管理者として開きます。 42 | - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。 43 | - 管理者のPowerShellを閉じます。 44 | 45 | ## Windows環境でのインストール 46 | 47 | スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。 48 | 49 | 以下の例ではPyTorchは2.0.1/CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。 50 | 51 | (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) 52 | 53 | PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。 54 | 55 | ```powershell 56 | git clone https://github.com/kohya-ss/sd-scripts.git 57 | cd sd-scripts 58 | 59 | python -m venv venv 60 | .\venv\Scripts\activate 61 | 62 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 63 | pip install --upgrade -r requirements.txt 64 | pip install xformers==0.0.20 65 | 66 | accelerate config 67 | ``` 68 | 69 | コマンドプロンプトでも同一です。 70 | 71 | (注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。) 72 | 73 | accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) 74 | 75 | ※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。 76 | 77 | ```txt 78 | - This machine 79 | - No distributed training 80 | - NO 81 | - NO 82 | - NO 83 | - all 84 | - fp16 85 | ``` 86 | 87 | ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( 88 | ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? 
[all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) 89 | 90 | ### オプション:`bitsandbytes`(8bit optimizer)を使う 91 | 92 | `bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます(0.41.1または以降のバージョンを推奨)。 93 | 94 | Windowsでは0.35.0または0.41.1を推奨します。 95 | 96 | - `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。 97 | - `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。 98 | 99 | 注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659 100 | 101 | 以下の手順に従い、`bitsandbytes`をインストールしてください。 102 | 103 | ### 0.35.0を使う場合 104 | 105 | PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。 106 | 107 | ```powershell 108 | cd sd-scripts 109 | .\venv\Scripts\activate 110 | pip install bitsandbytes==0.35.0 111 | 112 | cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ 113 | cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py 114 | cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py 115 | ``` 116 | 117 | ### 0.41.1を使う場合 118 | 119 | jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。 120 | 121 | ```powershell 122 | python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui 123 | ``` 124 | 125 | ## アップグレード 126 | 127 | 新しいリリースがあった場合、以下のコマンドで更新できます。 128 | 129 | ```powershell 130 | cd sd-scripts 131 | git pull 132 | .\venv\Scripts\activate 133 | pip install --use-pep517 --upgrade -r requirements.txt 134 | ``` 135 | 136 | コマンドが成功すれば新しいバージョンが使用できます。 137 | 138 | ## 謝意 139 | 140 | LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。 141 | 142 | Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。 143 | 144 | ## ライセンス 145 | 146 | スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。 147 | 148 | [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT 149 | 150 | [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT 151 | 152 | [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause 153 | 154 | 155 | -------------------------------------------------------------------------------- /XTI_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from library.ipex_interop import init_ipex 3 | 4 | init_ipex() 5 | from typing import Union, List, Optional, Dict, Any, Tuple 6 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 7 | 8 | from library.original_unet import SampleOutput 9 | 10 | 11 | def unet_forward_XTI( 12 | self, 13 | sample: torch.FloatTensor, 14 | timestep: Union[torch.Tensor, float, int], 15 | encoder_hidden_states: torch.Tensor, 16 | class_labels: Optional[torch.Tensor] = None, 17 | return_dict: bool = True, 18 | ) -> Union[Dict, Tuple]: 19 | r""" 20 | Args: 21 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 22 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 23 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 24 | return_dict (`bool`, *optional*, defaults to `True`): 25 | Whether or not to return a dict instead of a plain 
tuple. 26 | 27 | Returns: 28 | `SampleOutput` or `tuple`: 29 | `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 30 | """ 31 | # By default samples have to be AT least a multiple of the overall upsampling factor. 32 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 33 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 34 | # on the fly if necessary. 35 | # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある 36 | # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する 37 | # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い 38 | default_overall_up_factor = 2**self.num_upsamplers 39 | 40 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 41 | # 64で割り切れないときはupsamplerにサイズを伝える 42 | forward_upsample_size = False 43 | upsample_size = None 44 | 45 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 46 | # logger.info("Forward upsample size to force interpolation output size.") 47 | forward_upsample_size = True 48 | 49 | # 1. time 50 | timesteps = timestep 51 | timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 52 | 53 | t_emb = self.time_proj(timesteps) 54 | 55 | # timesteps does not contain any weights and will always return f32 tensors 56 | # but time_embedding might actually be running in fp16. so we need to cast here. 57 | # there might be better ways to encapsulate this. 58 | # timestepsは重みを含まないので常にfloat32のテンソルを返す 59 | # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある 60 | # time_projでキャストしておけばいいんじゃね? 61 | t_emb = t_emb.to(dtype=self.dtype) 62 | emb = self.time_embedding(t_emb) 63 | 64 | # 2. pre-process 65 | sample = self.conv_in(sample) 66 | 67 | # 3. down 68 | down_block_res_samples = (sample,) 69 | down_i = 0 70 | for downsample_block in self.down_blocks: 71 | # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 72 | # まあこちらのほうがわかりやすいかもしれない 73 | if downsample_block.has_cross_attention: 74 | sample, res_samples = downsample_block( 75 | hidden_states=sample, 76 | temb=emb, 77 | encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], 78 | ) 79 | down_i += 2 80 | else: 81 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 82 | 83 | down_block_res_samples += res_samples 84 | 85 | # 4. mid 86 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) 87 | 88 | # 5. up 89 | up_i = 7 90 | for i, upsample_block in enumerate(self.up_blocks): 91 | is_final_block = i == len(self.up_blocks) - 1 92 | 93 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 94 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection 95 | 96 | # if we have not reached the final block and need to forward the upsample size, we do it here 97 | # 前述のように最後のブロック以外ではupsample_sizeを伝える 98 | if not is_final_block and forward_upsample_size: 99 | upsample_size = down_block_res_samples[-1].shape[2:] 100 | 101 | if upsample_block.has_cross_attention: 102 | sample = upsample_block( 103 | hidden_states=sample, 104 | temb=emb, 105 | res_hidden_states_tuple=res_samples, 106 | encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], 107 | upsample_size=upsample_size, 108 | ) 109 | up_i += 3 110 | else: 111 | sample = upsample_block( 112 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 113 | ) 114 | 115 | # 6. 
post-process 116 | sample = self.conv_norm_out(sample) 117 | sample = self.conv_act(sample) 118 | sample = self.conv_out(sample) 119 | 120 | if not return_dict: 121 | return (sample,) 122 | 123 | return SampleOutput(sample=sample) 124 | 125 | 126 | def downblock_forward_XTI( 127 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None 128 | ): 129 | output_states = () 130 | i = 0 131 | 132 | for resnet, attn in zip(self.resnets, self.attentions): 133 | if self.training and self.gradient_checkpointing: 134 | 135 | def create_custom_forward(module, return_dict=None): 136 | def custom_forward(*inputs): 137 | if return_dict is not None: 138 | return module(*inputs, return_dict=return_dict) 139 | else: 140 | return module(*inputs) 141 | 142 | return custom_forward 143 | 144 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 145 | hidden_states = torch.utils.checkpoint.checkpoint( 146 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 147 | )[0] 148 | else: 149 | hidden_states = resnet(hidden_states, temb) 150 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 151 | 152 | output_states += (hidden_states,) 153 | i += 1 154 | 155 | if self.downsamplers is not None: 156 | for downsampler in self.downsamplers: 157 | hidden_states = downsampler(hidden_states) 158 | 159 | output_states += (hidden_states,) 160 | 161 | return hidden_states, output_states 162 | 163 | 164 | def upblock_forward_XTI( 165 | self, 166 | hidden_states, 167 | res_hidden_states_tuple, 168 | temb=None, 169 | encoder_hidden_states=None, 170 | upsample_size=None, 171 | ): 172 | i = 0 173 | for resnet, attn in zip(self.resnets, self.attentions): 174 | # pop res hidden states 175 | res_hidden_states = res_hidden_states_tuple[-1] 176 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 177 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 178 | 179 | if self.training and self.gradient_checkpointing: 180 | 181 | def create_custom_forward(module, return_dict=None): 182 | def custom_forward(*inputs): 183 | if return_dict is not None: 184 | return module(*inputs, return_dict=return_dict) 185 | else: 186 | return module(*inputs) 187 | 188 | return custom_forward 189 | 190 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 191 | hidden_states = torch.utils.checkpoint.checkpoint( 192 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 193 | )[0] 194 | else: 195 | hidden_states = resnet(hidden_states, temb) 196 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 197 | 198 | i += 1 199 | 200 | if self.upsamplers is not None: 201 | for upsampler in self.upsamplers: 202 | hidden_states = upsampler(hidden_states, upsample_size) 203 | 204 | return hidden_states 205 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | # Files for typos 2 | # Instruction: https://github.com/marketplace/actions/typos-action#getting-started 3 | 4 | [default.extend-identifiers] 5 | 6 | [default.extend-words] 7 | NIN="NIN" 8 | parms="parms" 9 | nin="nin" 10 | extention="extention" # Intentionally left 11 | nd="nd" 12 | shs="shs" 13 | sts="sts" 14 | scs="scs" 15 | cpc="cpc" 16 | coc="coc" 17 | cic="cic" 18 
| msm="msm" 19 | usu="usu" 20 | ici="ici" 21 | lvl="lvl" 22 | dii="dii" 23 | muk="muk" 24 | ori="ori" 25 | hru="hru" 26 | rik="rik" 27 | koo="koo" 28 | yos="yos" 29 | wn="wn" 30 | 31 | 32 | [files] 33 | extend-exclude = ["_typos.toml", "venv"] 34 | -------------------------------------------------------------------------------- /bitsandbytes_windows/cextension.py: -------------------------------------------------------------------------------- 1 | import ctypes as ct 2 | from pathlib import Path 3 | from warnings import warn 4 | 5 | from .cuda_setup.main import evaluate_cuda_setup 6 | 7 | 8 | class CUDALibrary_Singleton(object): 9 | _instance = None 10 | 11 | def __init__(self): 12 | raise RuntimeError("Call get_instance() instead") 13 | 14 | def initialize(self): 15 | binary_name = evaluate_cuda_setup() 16 | package_dir = Path(__file__).parent 17 | binary_path = package_dir / binary_name 18 | 19 | if not binary_path.exists(): 20 | print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}") 21 | legacy_binary_name = "libbitsandbytes.so" 22 | print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") 23 | binary_path = package_dir / legacy_binary_name 24 | if not binary_path.exists(): 25 | print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!') 26 | print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') 27 | raise Exception('CUDA SETUP: Setup Failed!') 28 | # self.lib = ct.cdll.LoadLibrary(binary_path) 29 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 30 | else: 31 | print(f"CUDA SETUP: Loading binary {binary_path}...") 32 | # self.lib = ct.cdll.LoadLibrary(binary_path) 33 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 34 | 35 | @classmethod 36 | def get_instance(cls): 37 | if cls._instance is None: 38 | cls._instance = cls.__new__(cls) 39 | cls._instance.initialize() 40 | return cls._instance 41 | 42 | 43 | lib = CUDALibrary_Singleton.get_instance().lib 44 | try: 45 | lib.cadam32bit_g32 46 | lib.get_context.restype = ct.c_void_p 47 | lib.get_cusparse.restype = ct.c_void_p 48 | COMPILED_WITH_CUDA = True 49 | except AttributeError: 50 | warn( 51 | "The installed version of bitsandbytes was compiled without GPU support. " 52 | "8-bit optimizers and GPU quantization are unavailable." 
53 | ) 54 | COMPILED_WITH_CUDA = False 55 | -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cpu.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cagliostrolab/sd-scripts-ani3/a8cf51571ebbcb4b8d0d413b1e02d68d7d53bbbf/bitsandbytes_windows/libbitsandbytes_cpu.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cuda116.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cagliostrolab/sd-scripts-ani3/a8cf51571ebbcb4b8d0d413b1e02d68d7d53bbbf/bitsandbytes_windows/libbitsandbytes_cuda116.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cuda118.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cagliostrolab/sd-scripts-ani3/a8cf51571ebbcb4b8d0d413b1e02d68d7d53bbbf/bitsandbytes_windows/libbitsandbytes_cuda118.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | extract factors the build is dependent on: 3 | [X] compute capability 4 | [ ] TODO: Q - What if we have multiple GPUs of different makes? 5 | - CUDA version 6 | - Software: 7 | - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) 8 | - CuBLAS-LT: full-build 8-bit optimizer 9 | - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) 10 | 11 | evaluation: 12 | - if paths faulty, return meaningful error 13 | - else: 14 | - determine CUDA version 15 | - determine capabilities 16 | - based on that set the default path 17 | """ 18 | 19 | import ctypes 20 | 21 | from .paths import determine_cuda_runtime_lib_path 22 | 23 | 24 | def check_cuda_result(cuda, result_val): 25 | # 3. Check for CUDA errors 26 | if result_val != 0: 27 | error_str = ctypes.c_char_p() 28 | cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) 29 | print(f"CUDA exception! Error code: {error_str.value.decode()}") 30 | 31 | def get_cuda_version(cuda, cudart_path): 32 | # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION 33 | try: 34 | cudart = ctypes.CDLL(cudart_path) 35 | except OSError: 36 | # TODO: shouldn't we error or at least warn here? 37 | print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') 38 | return None 39 | 40 | version = ctypes.c_int() 41 | check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) 42 | version = int(version.value) 43 | major = version//1000 44 | minor = (version-(major*1000))//10 45 | 46 | if major < 11: 47 | print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') 48 | 49 | return f'{major}{minor}' 50 | 51 | 52 | def get_cuda_lib_handle(): 53 | # 1. find libcuda.so library (GPU driver) (/usr/lib) 54 | try: 55 | cuda = ctypes.CDLL("libcuda.so") 56 | except OSError: 57 | # TODO: shouldn't we error or at least warn here? 58 | print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? 
If you are on a cluster, make sure you are on a CUDA machine!') 59 | return None 60 | check_cuda_result(cuda, cuda.cuInit(0)) 61 | 62 | return cuda 63 | 64 | 65 | def get_compute_capabilities(cuda): 66 | """ 67 | 1. find libcuda.so library (GPU driver) (/usr/lib) 68 | init_device -> init variables -> call function by reference 69 | 2. call extern C function to determine CC 70 | (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) 71 | 3. Check for CUDA errors 72 | https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api 73 | # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 74 | """ 75 | 76 | 77 | nGpus = ctypes.c_int() 78 | cc_major = ctypes.c_int() 79 | cc_minor = ctypes.c_int() 80 | 81 | device = ctypes.c_int() 82 | 83 | check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) 84 | ccs = [] 85 | for i in range(nGpus.value): 86 | check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) 87 | ref_major = ctypes.byref(cc_major) 88 | ref_minor = ctypes.byref(cc_minor) 89 | # 2. call extern C function to determine CC 90 | check_cuda_result( 91 | cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) 92 | ) 93 | ccs.append(f"{cc_major.value}.{cc_minor.value}") 94 | 95 | return ccs 96 | 97 | 98 | # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error 99 | def get_compute_capability(cuda): 100 | """ 101 | Extracts the highest compute capbility from all available GPUs, as compute 102 | capabilities are downwards compatible. If no GPUs are detected, it returns 103 | None. 104 | """ 105 | ccs = get_compute_capabilities(cuda) 106 | if ccs is not None: 107 | # TODO: handle different compute capabilities; for now, take the max 108 | return ccs[-1] 109 | return None 110 | 111 | 112 | def evaluate_cuda_setup(): 113 | print('') 114 | print('='*35 + 'BUG REPORT' + '='*35) 115 | print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') 116 | print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') 117 | print('='*80) 118 | return "libbitsandbytes_cuda116.dll" # $$$ 119 | 120 | binary_name = "libbitsandbytes_cpu.so" 121 | #if not torch.cuda.is_available(): 122 | #print('No GPU detected. Loading CPU library...') 123 | #return binary_name 124 | 125 | cudart_path = determine_cuda_runtime_lib_path() 126 | if cudart_path is None: 127 | print( 128 | "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" 129 | ) 130 | return binary_name 131 | 132 | print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") 133 | cuda = get_cuda_lib_handle() 134 | cc = get_compute_capability(cuda) 135 | print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") 136 | cuda_version_string = get_cuda_version(cuda, cudart_path) 137 | 138 | 139 | if cc == '': 140 | print( 141 | "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." 
142 | ) 143 | return binary_name 144 | 145 | # 7.5 is the minimum CC vor cublaslt 146 | has_cublaslt = cc in ["7.5", "8.0", "8.6"] 147 | 148 | # TODO: 149 | # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) 150 | # (2) Multiple CUDA versions installed 151 | 152 | # we use ls -l instead of nvcc to determine the cuda version 153 | # since most installations will have the libcudart.so installed, but not the compiler 154 | print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') 155 | 156 | def get_binary_name(): 157 | "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" 158 | bin_base_name = "libbitsandbytes_cuda" 159 | if has_cublaslt: 160 | return f"{bin_base_name}{cuda_version_string}.so" 161 | else: 162 | return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" 163 | 164 | binary_name = get_binary_name() 165 | 166 | return binary_name 167 | -------------------------------------------------------------------------------- /docs/fine_tune_README_ja.md: -------------------------------------------------------------------------------- 1 | NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません) 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | # 概要 6 | 7 | Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。 8 | 9 | * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。 10 | * 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。 11 | * トークン長を75から225に拡張する。 12 | * BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。 13 | * Hypernetworkの学習にも対応する。 14 | * Stable Diffusion v2.0(baseおよび768/v)に対応。 15 | * VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。 16 | 17 | デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。 18 | 19 | # 追加機能について 20 | 21 | ## CLIPの出力の変更 22 | 23 | プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。 24 | 元のまま、最後の層の出力を用いることも可能です。 25 | 26 | ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。 27 | 28 | ## 正方形以外の解像度での学習 29 | 30 | Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。 31 | 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。 32 | 33 | 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 34 | 35 | ## トークン長の75から225への拡張 36 | 37 | Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。 38 | ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。 39 | 40 | ※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。 41 | 42 | ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。 43 | 44 | # 学習の手順 45 | 46 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 47 | 48 | ## データの準備 49 | 50 | [学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。 51 | 52 | ## 学習の実行 53 | たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。 54 | 55 | ``` 56 | accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 57 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 58 | 
--output_dir=<学習したモデルの出力先フォルダ> 59 | --output_name=<学習したモデル出力時のファイル名> 60 | --dataset_config=<データ準備で作成した.tomlファイル> 61 | --save_model_as=safetensors 62 | --learning_rate=5e-6 --max_train_steps=10000 63 | --use_8bit_adam --xformers --gradient_checkpointing 64 | --mixed_precision=fp16 65 | ``` 66 | 67 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 68 | 69 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 70 | 71 | `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 72 | 73 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 74 | 75 | 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 76 | 77 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 78 | 79 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 80 | 81 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 82 | 83 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。 84 | 85 | ### よく使われるオプションについて 86 | 87 | 以下の場合にはオプションに関するドキュメントを参照してください。 88 | 89 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 90 | - clip skipを2以上を前提としたモデルを学習する 91 | - 75トークンを超えたキャプションで学習する 92 | 93 | ### バッチサイズについて 94 | 95 | モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。 96 | 97 | ### 学習率について 98 | 99 | 1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。 100 | 101 | ### 以前の形式のデータセット指定をした場合のコマンドライン 102 | 103 | 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 104 | 105 | ``` 106 | accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 107 | --pretrained_model_name_or_path=model.ckpt 108 | --in_json meta_lat.json 109 | --train_data_dir=train_data 110 | --output_dir=fine_tuned 111 | --shuffle_caption 112 | --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 113 | --use_8bit_adam --xformers --gradient_checkpointing 114 | --mixed_precision=bf16 115 | --save_every_n_epochs=4 116 | ``` 117 | 118 | 129 | 130 | # fine tuning特有のその他の主なオプション 131 | 132 | すべてのオプションについては別文書を参照してください。 133 | 134 | ## `train_text_encoder` 135 | Text Encoderも学習対象とします。メモリ使用量が若干増加します。 136 | 137 | 通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。 138 | 139 | ## `diffusers_xformers` 140 | スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。 141 | -------------------------------------------------------------------------------- /docs/train_db_README-ja.md: -------------------------------------------------------------------------------- 1 | DreamBoothのガイドです。 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | # 概要 6 | 7 | DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。 8 | 9 | 具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。 10 | 11 | スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。 12 | 13 | スクリプトの主な機能は以下の通りです。 14 | 15 | - 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam 
Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。 16 | - xformersによる省メモリ化。 17 | - 512x512だけではなく任意サイズでの学習。 18 | - augmentationによる品質の向上。 19 | - DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。 20 | - Stable Diffusion形式でのモデルの読み書き。 21 | - Aspect Ratio Bucketing。 22 | - Stable Diffusion v2.0対応。 23 | 24 | # 学習の手順 25 | 26 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 27 | 28 | ## データの準備 29 | 30 | [学習データの準備について](./train_README-ja.md) を参照してください。 31 | 32 | ## 学習の実行 33 | 34 | スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。 35 | 36 | ``` 37 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 38 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 39 | --dataset_config=<データ準備で作成した.tomlファイル> 40 | --output_dir=<学習したモデルの出力先フォルダ> 41 | --output_name=<学習したモデル出力時のファイル名> 42 | --save_model_as=safetensors 43 | --prior_loss_weight=1.0 44 | --max_train_steps=1600 45 | --learning_rate=1e-6 46 | --optimizer_type="AdamW8bit" 47 | --xformers 48 | --mixed_precision="fp16" 49 | --cache_latents 50 | --gradient_checkpointing 51 | ``` 52 | 53 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 54 | 55 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 56 | 57 | `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 58 | 59 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 60 | 61 | `prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。 62 | 63 | 学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。 64 | 65 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 66 | 67 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 68 | 69 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 70 | 71 | 省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。 72 | 73 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。 74 | 75 | ### よく使われるオプションについて 76 | 77 | 以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。 78 | 79 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 80 | - clip skipを2以上を前提としたモデルを学習する 81 | - 75トークンを超えたキャプションで学習する 82 | 83 | ### DreamBoothでのステップ数について 84 | 85 | 当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。 86 | 87 | 元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。 88 | 89 | (学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。) 90 | 91 | ### DreamBoothでのバッチサイズについて 92 | 93 | モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。 94 | 95 | ### 学習率について 96 | 97 | Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。 98 | 99 | ### 以前の形式のデータセット指定をした場合のコマンドライン 100 | 101 | 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 102 | 103 | ``` 104 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 105 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 106 | --train_data_dir=<学習用データのディレクトリ> 107 | --reg_data_dir=<正則化画像のディレクトリ> 108 | --output_dir=<学習したモデルの出力先ディレクトリ> 109 
| --output_name=<学習したモデル出力時のファイル名> 110 | --prior_loss_weight=1.0 111 | --resolution=512 112 | --train_batch_size=1 113 | --learning_rate=1e-6 114 | --max_train_steps=1600 115 | --use_8bit_adam 116 | --xformers 117 | --mixed_precision="bf16" 118 | --cache_latents 119 | --gradient_checkpointing 120 | ``` 121 | 122 | ## 学習したモデルで画像生成する 123 | 124 | 学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。 125 | 126 | v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。 127 | 128 | v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。 129 | 130 | ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) 131 | 132 | 各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。 133 | 134 | # DreamBooth特有のその他の主なオプション 135 | 136 | すべてのオプションについては別文書を参照してください。 137 | 138 | ## Text Encoderの学習を途中から行わない --stop_text_encoder_training 139 | 140 | stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。 141 | 142 | (恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。) 143 | 144 | ## Tokenizerのパディングをしない --no_token_padding 145 | no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。 146 | 147 | 148 | 168 | -------------------------------------------------------------------------------- /docs/train_db_README-zh.md: -------------------------------------------------------------------------------- 1 | 这是DreamBooth的指南。 2 | 3 | 请同时查看[关于学习的通用文档](./train_README-zh.md)。 4 | 5 | # 概要 6 | 7 | DreamBooth是一种将特定主题添加到图像生成模型中进行学习,并使用特定识别子生成它的技术。论文链接。 8 | 9 | 具体来说,它可以将角色和绘画风格等添加到Stable Diffusion模型中进行学习,并使用特定的单词(例如`shs`)来调用(呈现在生成的图像中)。 10 | 11 | 脚本基于Diffusers的DreamBooth,但添加了以下功能(一些功能已在原始脚本中得到支持)。 12 | 13 | 脚本的主要功能如下: 14 | 15 | - 使用8位Adam优化器和潜在变量的缓存来节省内存(与Shivam Shrirao版相似)。 16 | - 使用xformers来节省内存。 17 | - 不仅支持512x512,还支持任意尺寸的训练。 18 | - 通过数据增强来提高质量。 19 | - 支持DreamBooth和Text Encoder + U-Net的微调。 20 | - 支持以Stable Diffusion格式读写模型。 21 | - 支持Aspect Ratio Bucketing。 22 | - 支持Stable Diffusion v2.0。 23 | 24 | # 训练步骤 25 | 26 | 请先参阅此存储库的README以进行环境设置。 27 | 28 | ## 准备数据 29 | 30 | 请参阅[有关准备训练数据的说明](./train_README-zh.md)。 31 | 32 | ## 运行训练 33 | 34 | 运行脚本。以下是最大程度地节省内存的命令(实际上,这将在一行中输入)。请根据需要修改每行。它似乎需要约12GB的VRAM才能运行。 35 | ``` 36 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 37 | --pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录> 38 | --dataset_config=<数据准备时创建的.toml文件> 39 | --output_dir=<训练模型的输出目录> 40 | --output_name=<训练模型输出时的文件名> 41 | --save_model_as=safetensors 42 | --prior_loss_weight=1.0 43 | --max_train_steps=1600 44 | --learning_rate=1e-6 45 | --optimizer_type="AdamW8bit" 46 | --xformers 47 | --mixed_precision="fp16" 48 | --cache_latents 49 | --gradient_checkpointing 50 | ``` 51 | `num_cpu_threads_per_process` 通常应该设置为1。 52 | 53 | `pretrained_model_name_or_path` 指定要进行追加训练的基础模型。可以指定 Stable Diffusion 的 checkpoint 文件(.ckpt 或 .safetensors)、Diffusers 的本地模型目录或模型 ID(如 "stabilityai/stable-diffusion-2")。 54 | 55 | `output_dir` 指定保存训练后模型的文件夹。在 `output_name` 中指定模型文件名,不包括扩展名。使用 `save_model_as` 指定以 safetensors 格式保存。 56 | 57 | 在 `dataset_config` 中指定 `.toml` 文件。初始批处理大小应为 `1`,以减少内存消耗。 58 | 59 | `prior_loss_weight` 是正则化图像损失的权重。通常设为1.0。 60 | 61 | 将要训练的步数 `max_train_steps` 设置为1600。在这里,学习率 `learning_rate` 被设置为1e-6。 62 | 63 | 为了节省内存,设置 `mixed_precision="fp16"`(在 RTX30 系列及更高版本中也可以设置为 
`bf16`)。同时指定 `gradient_checkpointing`。 64 | 65 | 为了使用内存消耗较少的 8bit AdamW 优化器(将模型优化为适合于训练数据的状态),指定 `optimizer_type="AdamW8bit"`。 66 | 67 | 指定 `xformers` 选项,并使用 xformers 的 CrossAttention。如果未安装 xformers 或出现错误(具体情况取决于环境,例如使用 `mixed_precision="no"`),则可以指定 `mem_eff_attn` 选项以使用省内存版的 CrossAttention(速度会变慢)。 68 | 69 | 为了节省内存,指定 `cache_latents` 选项以缓存 VAE 的输出。 70 | 71 | 如果有足够的内存,请编辑 `.toml` 文件将批处理大小增加到大约 `4`(可能会提高速度和精度)。此外,取消 `cache_latents` 选项可以进行数据增强。 72 | 73 | ### 常用选项 74 | 75 | 对于以下情况,请参阅“常用选项”部分。 76 | 77 | - 学习 Stable Diffusion 2.x 或其衍生模型。 78 | - 学习基于 clip skip 大于等于2的模型。 79 | - 学习超过75个令牌的标题。 80 | 81 | ### 关于DreamBooth中的步数 82 | 83 | 为了实现省内存化,该脚本中每个步骤的学习次数减半(因为学习和正则化的图像在训练时被分为不同的批次)。 84 | 85 | 要进行与原始Diffusers版或XavierXiao的Stable Diffusion版几乎相同的学习,请将步骤数加倍。 86 | 87 | (虽然在将学习图像和正则化图像整合后再打乱顺序,但我认为对学习没有太大影响。) 88 | 89 | 关于DreamBooth的批量大小 90 | 91 | 与像LoRA这样的学习相比,为了训练整个模型,内存消耗量会更大(与微调相同)。 92 | 93 | 关于学习率 94 | 95 | 在Diffusers版中,学习率为5e-6,而在Stable Diffusion版中为1e-6,因此在上面的示例中指定了1e-6。 96 | 97 | 当使用旧格式的数据集指定命令行时 98 | 99 | 使用选项指定分辨率和批量大小。命令行示例如下。 100 | ``` 101 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 102 | --pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录> 103 | --train_data_dir=<训练数据的目录> 104 | --reg_data_dir=<正则化图像的目录> 105 | --output_dir=<训练后模型的输出目录> 106 | --output_name=<训练后模型输出文件的名称> 107 | --prior_loss_weight=1.0 108 | --resolution=512 109 | --train_batch_size=1 110 | --learning_rate=1e-6 111 | --max_train_steps=1600 112 | --use_8bit_adam 113 | --xformers 114 | --mixed_precision="bf16" 115 | --cache_latents 116 | --gradient_checkpointing 117 | ``` 118 | 119 | ## 使用训练好的模型生成图像 120 | 121 | 训练完成后,将在指定的文件夹中以指定的名称输出safetensors文件。 122 | 123 | 对于v1.4/1.5和其他派生模型,可以在此模型中使用Automatic1111先生的WebUI进行推断。请将其放置在models\Stable-diffusion文件夹中。 124 | 125 | 对于使用v2.x模型在WebUI中生成图像的情况,需要单独的.yaml文件来描述模型的规格。对于v2.x base,需要v2-inference.yaml,对于768/v,则需要v2-inference-v.yaml。请将它们放置在相同的文件夹中,并将文件扩展名之前的部分命名为与模型相同的名称。 126 | ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) 127 | 128 | 每个yaml文件都在[Stability AI的SD2.0存储库](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)……之中。 129 | 130 | # DreamBooth的其他主要选项 131 | 132 | 有关所有选项的详细信息,请参阅另一份文档。 133 | 134 | ## 不在中途开始对文本编码器进行训练 --stop_text_encoder_training 135 | 136 | 如果在stop_text_encoder_training选项中指定一个数字,则在该步骤之后,将不再对文本编码器进行训练,只会对U-Net进行训练。在某些情况下,可能会期望提高精度。 137 | 138 | (我们推测可能会有时候仅仅文本编码器会过度学习,而这样做可以避免这种情况,但详细影响尚不清楚。) 139 | 140 | ## 不进行分词器的填充 --no_token_padding 141 | 142 | 如果指定no_token_padding选项,则不会对分词器的输出进行填充(与Diffusers版本的旧DreamBooth相同)。 143 | 144 | 163 | -------------------------------------------------------------------------------- /docs/train_lllite_README-ja.md: -------------------------------------------------------------------------------- 1 | # ControlNet-LLLite について 2 | 3 | __きわめて実験的な実装のため、将来的に大きく変更される可能性があります。__ 4 | 5 | ## 概要 6 | ControlNet-LLLite は、[ControlNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 7 | 8 | ## サンプルの重みファイルと推論 9 | 10 | こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite 11 | 12 | ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI 13 | 14 | 生成サンプルはこのページの末尾にあります。 15 | 16 | ## モデル構造 17 | ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 18 | 19 | 
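参考までに、LLLiteモジュールの構造を示す最小限の概念的なスケッチを以下に示します(Linearに追加する場合の例です)。クラス名・引数名・入力の形状・次元数(cond_emb_dim=32、rank=64 は後述のサンプル設定に合わせた説明用の値)はいずれも仮のもので、conditioning imageのエンコード方法や注入位置など、実際の実装(`networks/control_net_lllite.py`)とは細部が異なります。正確な構造は必ずソースコードを参照してください。

```python
# 説明用のスケッチです。実際の networks/control_net_lllite.py とは細部が異なります。
import torch
import torch.nn as nn


class LLLiteModuleSketch(nn.Module):
    """conditioning image embedding + LoRA風の小型ネットワーク(概念図)"""

    def __init__(self, org_linear: nn.Linear, cond_emb_dim: int = 32, rank: int = 64, multiplier: float = 1.0):
        super().__init__()
        self.org_linear = org_linear  # 追加対象のU-NetのLinear(学習時は凍結)
        self.multiplier = multiplier
        in_dim = org_linear.in_features

        # conditioning image embedding:制御用画像の特徴(ここでは仮にトークンごとの3chとする)を潜在空間へ写像
        self.cond_emb = nn.Sequential(nn.Linear(3, cond_emb_dim), nn.ReLU())

        # LoRAに似た down -> mid -> up の小型ネットワーク(midでconditioningを混合)
        self.down = nn.Sequential(nn.Linear(in_dim, rank), nn.ReLU())
        self.mid = nn.Sequential(nn.Linear(rank + cond_emb_dim, rank), nn.ReLU())
        self.up = nn.Linear(rank, in_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, in_dim) 対象Linearへの入力
        # cond: (batch, seq_len, 3) トークンに対応づけた制御用画像の特徴(形状は説明用の仮定)
        cx = self.cond_emb(cond)
        delta = self.up(self.mid(torch.cat([self.down(x), cx], dim=-1)))
        # 元の重みは変更せず、小型ネットワークの出力を入力に加算してから元のLinearを通す
        return self.org_linear(x + delta * self.multiplier)


if __name__ == "__main__":
    m = LLLiteModuleSketch(nn.Linear(320, 320))
    h = torch.randn(2, 77, 320)
    cond = torch.randn(2, 77, 3)
    print(m(h, cond).shape)  # torch.Size([2, 77, 320])
```

このように学習対象は小型ネットワークとconditioning image embeddingのみで、元のU-Netの重みは凍結したままのため、追加パラメータ数はLoRAと同程度に抑えられます(前述のとおり容量が小さいぶん、画風学習には不向きです)。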
推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。 20 | 21 | ## モデルの学習 22 | 23 | ### データセットの準備 24 | 通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 25 | 26 | たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。 27 | 28 | ```toml 29 | [[datasets.subsets]] 30 | image_dir = "path/to/image/dir" 31 | caption_extension = ".txt" 32 | conditioning_data_dir = "path/to/conditioning/image/dir" 33 | ``` 34 | 35 | 現時点の制約として、random_cropは使用できません。 36 | 37 | 学習データとしては、元のモデルで生成した画像を学習用画像として、そこから加工した画像をconditioning imageとした、合成によるデータセットを用いるのがもっとも簡単です(データセットの品質的には問題があるかもしれません)。具体的なデータセットの合成方法については後述します。 38 | 39 | なお、元モデルと異なる画風の画像を学習用画像とすると、制御に加えて、その画風についても学ぶ必要が生じます。ControlNet-LLLiteは容量が少ないため、画風学習には不向きです。このような場合には、後述の次元数を多めにしてください。 40 | 41 | ### 学習 42 | スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 43 | 44 | 学習時にはメモリを大量に使用しますので、キャッシュやgradient checkpointingなどの省メモリ化のオプションを有効にしてください。また`--full_bf16` オプションで、BFloat16を使用するのも有効です(RTX 30シリーズ以降のGPUが必要です)。24GB VRAMで動作確認しています。 45 | 46 | conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 47 | 48 | (サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) 49 | 50 | 以下は .toml の設定例です。 51 | 52 | ```toml 53 | pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors" 54 | max_train_epochs = 12 55 | max_data_loader_n_workers = 4 56 | persistent_data_loader_workers = true 57 | seed = 42 58 | gradient_checkpointing = true 59 | mixed_precision = "bf16" 60 | save_precision = "bf16" 61 | full_bf16 = true 62 | optimizer_type = "adamw8bit" 63 | learning_rate = 2e-4 64 | xformers = true 65 | output_dir = "/path/to/output/dir" 66 | output_name = "output_name" 67 | save_every_n_epochs = 1 68 | save_model_as = "safetensors" 69 | vae_batch_size = 4 70 | cache_latents = true 71 | cache_latents_to_disk = true 72 | cache_text_encoder_outputs = true 73 | cache_text_encoder_outputs_to_disk = true 74 | network_dim = 64 75 | cond_emb_dim = 32 76 | dataset_config = "/path/to/dataset.toml" 77 | ``` 78 | 79 | ### 推論 80 | 81 | スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 82 | 83 | `--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 84 | 85 | ## データセットの合成方法 86 | 87 | ### 学習用画像の生成 88 | 89 | 学習のベースとなるモデルで画像生成を行います。Web UIやComfyUIなどで生成してください。画像サイズはモデルのデフォルトサイズで良いと思われます(1024x1024など)。bucketingを用いることもできます。その場合は適宜適切な解像度で生成してください。 90 | 91 | 生成時のキャプション等は、ControlNet-LLLiteの利用時に生成したい画像にあわせるのが良いと思われます。 92 | 93 | 生成した画像を任意のディレクトリに保存してください。このディレクトリをデータセットの設定ファイルで指定します。 94 | 95 | 当リポジトリ内の `sdxl_gen_img.py` でも生成できます。例えば以下のように実行します。 96 | 97 | ```dos 98 | python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on 
{classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn" 99 | ``` 100 | 101 | VRAM 24GBの設定です。VRAMサイズにより`--batch_size` `--vae_batch_size`を調整してください。 102 | 103 | `--prompt`でワイルドカードを利用してランダムに生成しています。適宜調整してください。 104 | 105 | ### 画像の加工 106 | 107 | 外部のプログラムを用いて、生成した画像を加工します。加工した画像を任意のディレクトリに保存してください。これらがconditioning imageになります。 108 | 109 | 加工にはたとえばCannyなら以下のようなスクリプトが使えます。 110 | 111 | ```python 112 | import glob 113 | import os 114 | import random 115 | import cv2 116 | import numpy as np 117 | 118 | IMAGES_DIR = "path/to/generated/images" 119 | CANNY_DIR = "path/to/canny/images" 120 | 121 | os.makedirs(CANNY_DIR, exist_ok=True) 122 | img_files = glob.glob(IMAGES_DIR + "/*.png") 123 | for img_file in img_files: 124 | can_file = CANNY_DIR + "/" + os.path.basename(img_file) 125 | if os.path.exists(can_file): 126 | print("Skip: " + img_file) 127 | continue 128 | 129 | print(img_file) 130 | 131 | img = cv2.imread(img_file) 132 | 133 | # random threshold 134 | # while True: 135 | # threshold1 = random.randint(0, 127) 136 | # threshold2 = random.randint(128, 255) 137 | # if threshold2 - threshold1 > 80: 138 | # break 139 | 140 | # fixed threshold 141 | threshold1 = 100 142 | threshold2 = 200 143 | 144 | img = cv2.Canny(img, threshold1, threshold2) 145 | 146 | cv2.imwrite(can_file, img) 147 | ``` 148 | 149 | ### キャプションファイルの作成 150 | 151 | 学習用画像のbasenameと同じ名前で、それぞれの画像に対応したキャプションファイルを作成してください。生成時のプロンプトをそのまま利用すれば良いと思われます。 152 | 153 | `sdxl_gen_img.py` で生成した場合は、画像内のメタデータに生成時のプロンプトが記録されていますので、以下のようなスクリプトで学習用画像と同じディレクトリにキャプションファイルを作成できます(拡張子 `.txt`)。 154 | 155 | ```python 156 | import glob 157 | import os 158 | from PIL import Image 159 | 160 | IMAGES_DIR = "path/to/generated/images" 161 | 162 | img_files = glob.glob(IMAGES_DIR + "/*.png") 163 | for img_file in img_files: 164 | cap_file = img_file.replace(".png", ".txt") 165 | if os.path.exists(cap_file): 166 | print(f"Skip: {img_file}") 167 | continue 168 | print(img_file) 169 | 170 | img = Image.open(img_file) 171 | prompt = img.text["prompt"] if "prompt" in img.text else "" 172 | if prompt == "": 173 | print(f"Prompt not found in {img_file}") 174 | 175 | with open(cap_file, "w") as f: 176 | f.write(prompt + "\n") 177 | ``` 178 | 179 | ### データセットの設定ファイルの作成 180 | 181 | コマンドラインオプションからの指定も可能ですが、`.toml`ファイルを作成する場合は `conditioning_data_dir` に加工した画像を保存したディレクトリを指定します。 182 | 183 | 以下は設定ファイルの例です。 184 | 185 | ```toml 186 | [general] 187 | flip_aug = false 188 | color_aug = false 189 | resolution = [1024,1024] 190 | 191 | [[datasets]] 192 | batch_size = 8 193 | enable_bucket = false 194 | 195 | [[datasets.subsets]] 196 | image_dir = "path/to/generated/image/dir" 197 | caption_extension = ".txt" 198 | conditioning_data_dir = "path/to/canny/image/dir" 199 | ``` 200 | 201 | ## 謝辞 202 | 203 | ControlNetの作者である lllyasviel 氏、実装上のアドバイスとトラブル解決へのご尽力をいただいた furusu 氏、ControlNetデータセットを実装していただいた ddPn08 氏に感謝いたします。 204 | 205 | ## サンプル 206 | Canny 207 | 
![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) 208 | 209 | ![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) 210 | 211 | ![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) 212 | 213 | ![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) 214 | 215 | -------------------------------------------------------------------------------- /docs/train_ti_README-ja.md: -------------------------------------------------------------------------------- 1 | [Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | 実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 6 | 7 | 学習したモデルはWeb UIでもそのまま使えます。 8 | 9 | # 学習の手順 10 | 11 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 12 | 13 | ## データの準備 14 | 15 | [学習データの準備について](./train_README-ja.md) を参照してください。 16 | 17 | ## 学習の実行 18 | 19 | ``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。 20 | 21 | ``` 22 | accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py 23 | --dataset_config=<データ準備で作成した.tomlファイル> 24 | --output_dir=<学習したモデルの出力先フォルダ> 25 | --output_name=<学習したモデル出力時のファイル名> 26 | --save_model_as=safetensors 27 | --prior_loss_weight=1.0 28 | --max_train_steps=1600 29 | --learning_rate=1e-6 30 | --optimizer_type="AdamW8bit" 31 | --xformers 32 | --mixed_precision="fp16" 33 | --cache_latents 34 | --gradient_checkpointing 35 | --token_string=mychar4 --init_word=cute --num_vectors_per_token=4 36 | ``` 37 | 38 | ``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。 39 | 40 | プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。 41 | 42 | ``` 43 | input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407, 44 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 45 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 46 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 47 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 48 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 50 | 49407, 49407, 49407, 49407, 49407, 49407, 49407]]) 51 | ``` 52 | 53 | tokenizerがすでに持っている単語(一般的な単語)は使用できません。 54 | 55 | ``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。 56 | 57 | ``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。 58 | 59 | 以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。 60 | 61 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 62 | 63 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 64 | 65 | `output_dir` 
に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 66 | 67 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 68 | 69 | 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 70 | 71 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 72 | 73 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 74 | 75 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 76 | 77 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。 78 | 79 | ### よく使われるオプションについて 80 | 81 | 以下の場合にはオプションに関するドキュメントを参照してください。 82 | 83 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 84 | - clip skipを2以上を前提としたモデルを学習する 85 | - 75トークンを超えたキャプションで学習する 86 | 87 | ### Textual Inversionでのバッチサイズについて 88 | 89 | モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。 90 | 91 | # Textual Inversionのその他の主なオプション 92 | 93 | すべてのオプションについては別文書を参照してください。 94 | 95 | * `--weights` 96 | * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。 97 | * `--use_object_template` 98 | * キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。 99 | * `--use_style_template` 100 | * キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。 101 | 102 | ## 当リポジトリ内の画像生成スクリプトで生成する 103 | 104 | gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。 105 | 106 | -------------------------------------------------------------------------------- /finetune/blip/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /finetune/clean_captions_and_tags.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. 
@kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | import re 9 | 10 | from tqdm import tqdm 11 | 12 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 13 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 14 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 15 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 16 | 17 | # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する 18 | PATTERNS_REMOVE_IN_MULTI = [ 19 | PATTERN_HAIR_LENGTH, 20 | PATTERN_HAIR_CUT, 21 | re.compile(r', [\w\-]+ eyes, '), 22 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 23 | # 複数の髪型定義がある場合は削除する 24 | re.compile( 25 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 26 | ] 27 | 28 | 29 | def clean_tags(image_key, tags): 30 | # replace '_' to ' ' 31 | tags = tags.replace('^_^', '^@@@^') 32 | tags = tags.replace('_', ' ') 33 | tags = tags.replace('^@@@^', '^_^') 34 | 35 | # remove rating: deepdanbooruのみ 36 | tokens = tags.split(", rating") 37 | if len(tokens) == 1: 38 | # WD14 taggerのときはこちらになるのでメッセージは出さない 39 | # print("no rating:") 40 | # print(f"{image_key} {tags}") 41 | pass 42 | else: 43 | if len(tokens) > 2: 44 | print("multiple ratings:") 45 | print(f"{image_key} {tags}") 46 | tags = tokens[0] 47 | 48 | tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 49 | 50 | # 複数の人物がいる場合は髪色等のタグを削除する 51 | if 'girls' in tags or 'boys' in tags: 52 | for pat in PATTERNS_REMOVE_IN_MULTI: 53 | found = pat.findall(tags) 54 | if len(found) > 1: # 二つ以上、タグがある 55 | tags = pat.sub("", tags) 56 | 57 | # 髪の特殊対応 58 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) 59 | if srch_hair_len: 60 | org = srch_hair_len.group() 61 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 62 | 63 | found = PATTERN_HAIR.findall(tags) 64 | if len(found) > 1: 65 | tags = PATTERN_HAIR.sub("", tags) 66 | 67 | if srch_hair_len: 68 | tags = tags.replace(", @@@, ", org) # 戻す 69 | 70 | # white shirtとshirtみたいな重複タグの削除 71 | found = PATTERN_WORD.findall(tags) 72 | for word in found: 73 | if re.search(f", ((\w+) )+{word}, ", tags): 74 | tags = tags.replace(f", {word}, ", "") 75 | 76 | tags = tags.replace(", , ", ", ") 77 | assert tags.startswith(", ") and tags.endswith(", ") 78 | tags = tags[2:-2] 79 | return tags 80 | 81 | 82 | # 上から順に検索、置換される 83 | # ('置換元文字列', '置換後文字列') 84 | CAPTION_REPLACEMENTS = [ 85 | ('anime anime', 'anime'), 86 | ('young ', ''), 87 | ('anime girl', 'girl'), 88 | ('cartoon female', 'girl'), 89 | ('cartoon lady', 'girl'), 90 | ('cartoon character', 'girl'), # a or ~s 91 | ('cartoon woman', 'girl'), 92 | ('cartoon women', 'girls'), 93 | ('cartoon girl', 'girl'), 94 | ('anime female', 'girl'), 95 | ('anime lady', 'girl'), 96 | ('anime character', 'girl'), # a or ~s 97 | ('anime woman', 'girl'), 98 | ('anime women', 'girls'), 99 | ('lady', 'girl'), 100 | ('female', 'girl'), 101 | ('woman', 'girl'), 102 | ('women', 'girls'), 103 | ('people', 'girls'), 104 | ('person', 'girl'), 105 | ('a cartoon figure', 'a figure'), 106 | ('a cartoon image', 'an image'), 107 | ('a cartoon picture', 'a picture'), 108 | ('an anime cartoon image', 'an image'), 109 | ('a cartoon anime drawing', 'a drawing'), 110 | ('a cartoon drawing', 'a drawing'), 111 | ('girl girl', 'girl'), 112 | ] 113 | 114 | 115 | def clean_caption(caption): 116 | for rf, rt in CAPTION_REPLACEMENTS: 117 | replaced = True 118 | while replaced: 119 | bef = caption 120 | caption = caption.replace(rf, rt) 121 | replaced = bef 
!= caption 122 | return caption 123 | 124 | 125 | def main(args): 126 | if os.path.exists(args.in_json): 127 | print(f"loading existing metadata: {args.in_json}") 128 | with open(args.in_json, "rt", encoding='utf-8') as f: 129 | metadata = json.load(f) 130 | else: 131 | print("no metadata / メタデータファイルがありません") 132 | return 133 | 134 | print("cleaning captions and tags.") 135 | image_keys = list(metadata.keys()) 136 | for image_key in tqdm(image_keys): 137 | tags = metadata[image_key].get('tags') 138 | if tags is None: 139 | print(f"image does not have tags / メタデータにタグがありません: {image_key}") 140 | else: 141 | org = tags 142 | tags = clean_tags(image_key, tags) 143 | metadata[image_key]['tags'] = tags 144 | if args.debug and org != tags: 145 | print("FROM: " + org) 146 | print("TO: " + tags) 147 | 148 | caption = metadata[image_key].get('caption') 149 | if caption is None: 150 | print(f"image does not have caption / メタデータにキャプションがありません: {image_key}") 151 | else: 152 | org = caption 153 | caption = clean_caption(caption) 154 | metadata[image_key]['caption'] = caption 155 | if args.debug and org != caption: 156 | print("FROM: " + org) 157 | print("TO: " + caption) 158 | 159 | # metadataを書き出して終わり 160 | print(f"writing metadata: {args.out_json}") 161 | with open(args.out_json, "wt", encoding='utf-8') as f: 162 | json.dump(metadata, f, indent=2) 163 | print("done!") 164 | 165 | 166 | def setup_parser() -> argparse.ArgumentParser: 167 | parser = argparse.ArgumentParser() 168 | # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 169 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 170 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 171 | parser.add_argument("--debug", action="store_true", help="debug mode") 172 | 173 | return parser 174 | 175 | 176 | if __name__ == '__main__': 177 | parser = setup_parser() 178 | 179 | args, unknown = parser.parse_known_args() 180 | if len(unknown) == 1: 181 | print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. 
Please specify two arguments: in_json and out_json.") 182 | print("All captions and tags in the metadata are processed.") 183 | print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") 184 | print("メタデータ内のすべてのキャプションとタグが処理されます。") 185 | args.in_json = args.out_json 186 | args.out_json = unknown[0] 187 | elif len(unknown) > 0: 188 | raise ValueError(f"error: unrecognized arguments: {unknown}") 189 | 190 | main(args) 191 | -------------------------------------------------------------------------------- /finetune/hypernetwork_nai.py: -------------------------------------------------------------------------------- 1 | # NAI compatible 2 | 3 | import torch 4 | 5 | 6 | class HypernetworkModule(torch.nn.Module): 7 | def __init__(self, dim, multiplier=1.0): 8 | super().__init__() 9 | 10 | linear1 = torch.nn.Linear(dim, dim * 2) 11 | linear2 = torch.nn.Linear(dim * 2, dim) 12 | linear1.weight.data.normal_(mean=0.0, std=0.01) 13 | linear1.bias.data.zero_() 14 | linear2.weight.data.normal_(mean=0.0, std=0.01) 15 | linear2.bias.data.zero_() 16 | linears = [linear1, linear2] 17 | 18 | self.linear = torch.nn.Sequential(*linears) 19 | self.multiplier = multiplier 20 | 21 | def forward(self, x): 22 | return x + self.linear(x) * self.multiplier 23 | 24 | 25 | class Hypernetwork(torch.nn.Module): 26 | enable_sizes = [320, 640, 768, 1280] 27 | # return self.modules[Hypernetwork.enable_sizes.index(size)] 28 | 29 | def __init__(self, multiplier=1.0) -> None: 30 | super().__init__() 31 | self.modules = [] 32 | for size in Hypernetwork.enable_sizes: 33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) 34 | self.register_module(f"{size}_0", self.modules[-1][0]) 35 | self.register_module(f"{size}_1", self.modules[-1][1]) 36 | 37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet): 38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks 39 | for block in blocks: 40 | for subblk in block: 41 | if 'SpatialTransformer' in str(type(subblk)): 42 | for tf_block in subblk.transformer_blocks: 43 | for attn in [tf_block.attn1, tf_block.attn2]: 44 | size = attn.context_dim 45 | if size in Hypernetwork.enable_sizes: 46 | attn.hypernetwork = self 47 | else: 48 | attn.hypernetwork = None 49 | 50 | def apply_to_diffusers(self, text_encoder, vae, unet): 51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks 52 | for block in blocks: 53 | if hasattr(block, 'attentions'): 54 | for subblk in block.attentions: 55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ 56 | for tf_block in subblk.transformer_blocks: 57 | for attn in [tf_block.attn1, tf_block.attn2]: 58 | size = attn.to_k.in_features 59 | if size in Hypernetwork.enable_sizes: 60 | attn.hypernetwork = self 61 | else: 62 | attn.hypernetwork = None 63 | return True # TODO error checking 64 | 65 | def forward(self, x, context): 66 | size = context.shape[-1] 67 | assert size in Hypernetwork.enable_sizes 68 | module = self.modules[Hypernetwork.enable_sizes.index(size)] 69 | return module[0].forward(context), module[1].forward(context) 70 | 71 | def load_from_state_dict(self, state_dict): 72 | # old ver to new ver 73 | changes = { 74 | 'linear1.bias': 'linear.0.bias', 75 | 'linear1.weight': 'linear.0.weight', 76 | 'linear2.bias': 'linear.1.bias', 77 | 'linear2.weight': 'linear.1.weight', 78 | } 79 | for key_from, key_to in changes.items(): 80 | if key_from in state_dict: 81 | 
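As a quick illustration of the module defined above, a NAI-style `HypernetworkModule` is a two-layer linear expansion applied as a residual update to the cross-attention context. A standalone sketch with hypothetical shapes (weight initialisation omitted):

```python
# Illustrative sketch, not repository code: what one HypernetworkModule computes.
import torch

dim = 768                            # one of Hypernetwork.enable_sizes
context = torch.randn(2, 77, dim)    # (batch, tokens, dim) text-encoder output, dummy data

linear = torch.nn.Sequential(
    torch.nn.Linear(dim, dim * 2),   # same layout as HypernetworkModule (no activation in between)
    torch.nn.Linear(dim * 2, dim),
)
multiplier = 1.0

modified = context + linear(context) * multiplier  # residual update, as in HypernetworkModule.forward
# Hypernetwork.forward() returns one such modified context for the keys and one for the values.
print(modified.shape)  # torch.Size([2, 77, 768])
```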
state_dict[key_to] = state_dict[key_from] 82 | del state_dict[key_from] 83 | 84 | for size, sd in state_dict.items(): 85 | if type(size) == int: 86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) 87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) 88 | return True 89 | 90 | def get_state_dict(self): 91 | state_dict = {} 92 | for i, size in enumerate(Hypernetwork.enable_sizes): 93 | sd0 = self.modules[i][0].state_dict() 94 | sd1 = self.modules[i][1].state_dict() 95 | state_dict[size] = [sd0, sd1] 96 | return state_dict 97 | -------------------------------------------------------------------------------- /finetune/make_captions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import json 5 | import random 6 | import sys 7 | 8 | from pathlib import Path 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | import torch 13 | from torchvision import transforms 14 | from torchvision.transforms.functional import InterpolationMode 15 | sys.path.append(os.path.dirname(__file__)) 16 | from blip.blip import blip_decoder, is_url 17 | import library.train_util as train_util 18 | 19 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | IMAGE_SIZE = 384 23 | 24 | # 正方形でいいのか? という気がするがソースがそうなので 25 | IMAGE_TRANSFORM = transforms.Compose( 26 | [ 27 | transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 30 | ] 31 | ) 32 | 33 | 34 | # 共通化したいが微妙に処理が異なる…… 35 | class ImageLoadingTransformDataset(torch.utils.data.Dataset): 36 | def __init__(self, image_paths): 37 | self.images = image_paths 38 | 39 | def __len__(self): 40 | return len(self.images) 41 | 42 | def __getitem__(self, idx): 43 | img_path = self.images[idx] 44 | 45 | try: 46 | image = Image.open(img_path).convert("RGB") 47 | # convert to tensor temporarily so dataloader will accept it 48 | tensor = IMAGE_TRANSFORM(image) 49 | except Exception as e: 50 | print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 51 | return None 52 | 53 | return (tensor, img_path) 54 | 55 | 56 | def collate_fn_remove_corrupted(batch): 57 | """Collate function that allows to remove corrupted examples in the 58 | dataloader. It expects that the dataloader returns 'None' when that occurs. 59 | The 'None's in the batch are removed. 
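For instance (values below are dummies), the collate function simply drops the `None` placeholders that the dataset returns for unreadable images:

```python
# Illustrative sketch, not repository code.
batch = [("img_a.png", 1), None, ("img_b.png", 2)]       # None marks a corrupted example
filtered = list(filter(lambda x: x is not None, batch))  # same filtering as collate_fn_remove_corrupted
assert filtered == [("img_a.png", 1), ("img_b.png", 2)]
```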
60 | """ 61 | # Filter out all the Nones (corrupted examples) 62 | batch = list(filter(lambda x: x is not None, batch)) 63 | return batch 64 | 65 | 66 | def main(args): 67 | # fix the seed for reproducibility 68 | seed = args.seed # + utils.get_rank() 69 | torch.manual_seed(seed) 70 | np.random.seed(seed) 71 | random.seed(seed) 72 | 73 | if not os.path.exists("blip"): 74 | args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path 75 | 76 | cwd = os.getcwd() 77 | print("Current Working Directory is: ", cwd) 78 | os.chdir("finetune") 79 | if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): 80 | args.caption_weights = os.path.join("..", args.caption_weights) 81 | 82 | print(f"load images from {args.train_data_dir}") 83 | train_data_dir_path = Path(args.train_data_dir) 84 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 85 | print(f"found {len(image_paths)} images.") 86 | 87 | print(f"loading BLIP caption: {args.caption_weights}") 88 | model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") 89 | model.eval() 90 | model = model.to(DEVICE) 91 | print("BLIP loaded") 92 | 93 | # captioningする 94 | def run_batch(path_imgs): 95 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 96 | 97 | with torch.no_grad(): 98 | if args.beam_search: 99 | captions = model.generate( 100 | imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length 101 | ) 102 | else: 103 | captions = model.generate( 104 | imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length 105 | ) 106 | 107 | for (image_path, _), caption in zip(path_imgs, captions): 108 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 109 | f.write(caption + "\n") 110 | if args.debug: 111 | print(image_path, caption) 112 | 113 | # 読み込みの高速化のためにDataLoaderを使うオプション 114 | if args.max_data_loader_n_workers is not None: 115 | dataset = ImageLoadingTransformDataset(image_paths) 116 | data = torch.utils.data.DataLoader( 117 | dataset, 118 | batch_size=args.batch_size, 119 | shuffle=False, 120 | num_workers=args.max_data_loader_n_workers, 121 | collate_fn=collate_fn_remove_corrupted, 122 | drop_last=False, 123 | ) 124 | else: 125 | data = [[(None, ip)] for ip in image_paths] 126 | 127 | b_imgs = [] 128 | for data_entry in tqdm(data, smoothing=0.0): 129 | for data in data_entry: 130 | if data is None: 131 | continue 132 | 133 | img_tensor, image_path = data 134 | if img_tensor is None: 135 | try: 136 | raw_image = Image.open(image_path) 137 | if raw_image.mode != "RGB": 138 | raw_image = raw_image.convert("RGB") 139 | img_tensor = IMAGE_TRANSFORM(raw_image) 140 | except Exception as e: 141 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 142 | continue 143 | 144 | b_imgs.append((image_path, img_tensor)) 145 | if len(b_imgs) >= args.batch_size: 146 | run_batch(b_imgs) 147 | b_imgs.clear() 148 | if len(b_imgs) > 0: 149 | run_batch(b_imgs) 150 | 151 | print("done!") 152 | 153 | 154 | def setup_parser() -> argparse.ArgumentParser: 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 157 | parser.add_argument( 158 | "--caption_weights", 159 | type=str, 160 | default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", 
161 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", 162 | ) 163 | parser.add_argument( 164 | "--caption_extention", 165 | type=str, 166 | default=None, 167 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", 168 | ) 169 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 170 | parser.add_argument( 171 | "--beam_search", 172 | action="store_true", 173 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", 174 | ) 175 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 176 | parser.add_argument( 177 | "--max_data_loader_n_workers", 178 | type=int, 179 | default=None, 180 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 181 | ) 182 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 183 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 184 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 185 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 186 | parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") 187 | parser.add_argument("--debug", action="store_true", help="debug mode") 188 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 189 | 190 | return parser 191 | 192 | 193 | if __name__ == "__main__": 194 | parser = setup_parser() 195 | 196 | args = parser.parse_args() 197 | 198 | # スペルミスしていたオプションを復元する 199 | if args.caption_extention is not None: 200 | args.caption_extension = args.caption_extention 201 | 202 | main(args) 203 | -------------------------------------------------------------------------------- /finetune/make_captions_by_git.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from pathlib import Path 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torch 9 | from transformers import AutoProcessor, AutoModelForCausalLM 10 | from transformers.generation.utils import GenerationMixin 11 | 12 | import library.train_util as train_util 13 | 14 | 15 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | PATTERN_REPLACE = [ 18 | re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), 19 | re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), 20 | re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), 21 | re.compile(r"with the number \d+ on (it|\w+ \w+)"), 22 | re.compile(r'with the words "'), 23 | re.compile(r"word \w+ on it"), 24 | re.compile(r"that says the word \w+ on it"), 25 | re.compile("that says'the word \"( on it)?"), 26 | ] 27 | 28 | # 誤検知しまくりの with the word xxxx を消す 29 | 30 | 31 | def remove_words(captions, debug): 32 | removed_caps = [] 33 | for caption in captions: 34 | cap = caption 35 | for pat in PATTERN_REPLACE: 36 | cap = pat.sub("", cap) 37 | if debug and cap != caption: 38 | print(caption) 39 | print(cap) 40 | 
removed_caps.append(cap) 41 | return removed_caps 42 | 43 | 44 | def collate_fn_remove_corrupted(batch): 45 | """Collate function that allows to remove corrupted examples in the 46 | dataloader. It expects that the dataloader returns 'None' when that occurs. 47 | The 'None's in the batch are removed. 48 | """ 49 | # Filter out all the Nones (corrupted examples) 50 | batch = list(filter(lambda x: x is not None, batch)) 51 | return batch 52 | 53 | 54 | def main(args): 55 | r""" 56 | transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト 57 | 58 | # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 59 | org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation 60 | curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように 61 | 62 | # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す 63 | # ここより上で置き換えようとするとすごく大変 64 | def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): 65 | input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) 66 | if input_ids.size()[0] != curr_batch_size[0]: 67 | input_ids = input_ids.repeat(curr_batch_size[0], 1) 68 | return input_ids 69 | 70 | GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch 71 | """ 72 | 73 | print(f"load images from {args.train_data_dir}") 74 | train_data_dir_path = Path(args.train_data_dir) 75 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 76 | print(f"found {len(image_paths)} images.") 77 | 78 | # できればcacheに依存せず明示的にダウンロードしたい 79 | print(f"loading GIT: {args.model_id}") 80 | git_processor = AutoProcessor.from_pretrained(args.model_id) 81 | git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) 82 | print("GIT loaded") 83 | 84 | # captioningする 85 | def run_batch(path_imgs): 86 | imgs = [im for _, im in path_imgs] 87 | 88 | # curr_batch_size[0] = len(path_imgs) 89 | inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 90 | generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) 91 | captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) 92 | 93 | if args.remove_words: 94 | captions = remove_words(captions, args.debug) 95 | 96 | for (image_path, _), caption in zip(path_imgs, captions): 97 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 98 | f.write(caption + "\n") 99 | if args.debug: 100 | print(image_path, caption) 101 | 102 | # 読み込みの高速化のためにDataLoaderを使うオプション 103 | if args.max_data_loader_n_workers is not None: 104 | dataset = train_util.ImageLoadingDataset(image_paths) 105 | data = torch.utils.data.DataLoader( 106 | dataset, 107 | batch_size=args.batch_size, 108 | shuffle=False, 109 | num_workers=args.max_data_loader_n_workers, 110 | collate_fn=collate_fn_remove_corrupted, 111 | drop_last=False, 112 | ) 113 | else: 114 | data = [[(None, ip)] for ip in image_paths] 115 | 116 | b_imgs = [] 117 | for data_entry in tqdm(data, smoothing=0.0): 118 | for data in data_entry: 119 | if data is None: 120 | continue 121 | 122 | image, image_path = data 123 | if image is None: 124 | try: 125 | image = Image.open(image_path) 126 | if image.mode != "RGB": 127 | image = image.convert("RGB") 128 | except Exception as e: 129 | print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 130 | continue 131 | 132 | b_imgs.append((image_path, image)) 133 | if len(b_imgs) >= args.batch_size: 134 | 
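The batched loop above can be reduced to a single-image sketch for experimentation. It assumes `transformers` is installed, a local `sample.png` exists (hypothetical path), and uses the same default model id:

```python
# Hedged sketch, not repository code: caption one image with microsoft/git-large-textcaps.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("microsoft/git-large-textcaps")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-textcaps").to(device)

image = Image.open("sample.png").convert("RGB")                    # hypothetical input image
inputs = processor(images=image, return_tensors="pt").to(device)  # images are passed in PIL form
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```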
run_batch(b_imgs) 135 | b_imgs.clear() 136 | 137 | if len(b_imgs) > 0: 138 | run_batch(b_imgs) 139 | 140 | print("done!") 141 | 142 | 143 | def setup_parser() -> argparse.ArgumentParser: 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 146 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 147 | parser.add_argument( 148 | "--model_id", 149 | type=str, 150 | default="microsoft/git-large-textcaps", 151 | help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", 152 | ) 153 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 154 | parser.add_argument( 155 | "--max_data_loader_n_workers", 156 | type=int, 157 | default=None, 158 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 159 | ) 160 | parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") 161 | parser.add_argument( 162 | "--remove_words", 163 | action="store_true", 164 | help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", 165 | ) 166 | parser.add_argument("--debug", action="store_true", help="debug mode") 167 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 168 | 169 | return parser 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = setup_parser() 174 | 175 | args = parser.parse_args() 176 | main(args) 177 | -------------------------------------------------------------------------------- /finetune/merge_captions_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge caption texts to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | caption_path = image_path.with_suffix(args.caption_extension) 30 | caption = caption_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(caption_path): 33 | caption_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['caption'] = caption 40 | if args.debug: 41 | print(image_key, caption) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 
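The resulting metadata file is a flat JSON object keyed by image stem (or by full path when `--full_path` is given). A sketch of its shape with hypothetical entries:

```python
# Illustrative sketch, not repository code.
import json

metadata = {
    "img_0001": {"caption": "a girl standing in a classroom"},    # hypothetical caption text
    "img_0002": {"caption": "a drawing of a girl with long hair"},
}
print(json.dumps(metadata, indent=2))  # same indent=2 layout the script writes
```

merge_dd_tags_to_metadata.py, shown further below, adds a `"tags"` field to the same per-image entries.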
| Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8') 46 | print("done!") 47 | 48 | 49 | def setup_parser() -> argparse.ArgumentParser: 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 52 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 53 | parser.add_argument("--in_json", type=str, 54 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 55 | parser.add_argument("--caption_extention", type=str, default=None, 56 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)") 57 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子") 58 | parser.add_argument("--full_path", action="store_true", 59 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 60 | parser.add_argument("--recursive", action="store_true", 61 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 62 | parser.add_argument("--debug", action="store_true", help="debug mode") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | 72 | # スペルミスしていたオプションを復元する 73 | if args.caption_extention is not None: 74 | args.caption_extension = args.caption_extention 75 | 76 | main(args) 77 | -------------------------------------------------------------------------------- /finetune/merge_dd_tags_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | 9 | def main(args): 10 | assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 11 | 12 | train_data_dir_path = Path(args.train_data_dir) 13 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 14 | print(f"found {len(image_paths)} images.") 15 | 16 | if args.in_json is None and Path(args.out_json).is_file(): 17 | args.in_json = args.out_json 18 | 19 | if args.in_json is not None: 20 | print(f"loading existing metadata: {args.in_json}") 21 | metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8')) 22 | print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") 23 | else: 24 | print("new metadata will be created / 新しいメタデータファイルが作成されます") 25 | metadata = {} 26 | 27 | print("merge tags to metadata json.") 28 | for image_path in tqdm(image_paths): 29 | tags_path = image_path.with_suffix(args.caption_extension) 30 | tags = tags_path.read_text(encoding='utf-8').strip() 31 | 32 | if not os.path.exists(tags_path): 33 | tags_path = os.path.join(image_path, args.caption_extension) 34 | 35 | image_key = str(image_path) if args.full_path else image_path.stem 36 | if image_key not in metadata: 37 | metadata[image_key] = {} 38 | 39 | metadata[image_key]['tags'] = tags 40 | if args.debug: 41 | print(image_key, tags) 42 | 43 | # metadataを書き出して終わり 44 | print(f"writing metadata: {args.out_json}") 45 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), 
encoding='utf-8') 46 | 47 | print("done!") 48 | 49 | 50 | def setup_parser() -> argparse.ArgumentParser: 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 53 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 54 | parser.add_argument("--in_json", type=str, 55 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)") 56 | parser.add_argument("--full_path", action="store_true", 57 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") 58 | parser.add_argument("--recursive", action="store_true", 59 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す") 60 | parser.add_argument("--caption_extension", type=str, default=".txt", 61 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子") 62 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags") 63 | 64 | return parser 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = setup_parser() 69 | 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cagliostrolab/sd-scripts-ani3/a8cf51571ebbcb4b8d0d413b1e02d68d7d53bbbf/library/__init__.py -------------------------------------------------------------------------------- /library/attention_processors.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | from einops import rearrange 4 | import torch 5 | from diffusers.models.attention_processor import Attention 6 | 7 | 8 | # flash attention forwards and backwards 9 | 10 | # https://arxiv.org/abs/2205.14135 11 | 12 | EPSILON = 1e-6 13 | 14 | 15 | class FlashAttentionFunction(torch.autograd.function.Function): 16 | @staticmethod 17 | @torch.no_grad() 18 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 19 | """Algorithm 2 in the paper""" 20 | 21 | device = q.device 22 | dtype = q.dtype 23 | max_neg_value = -torch.finfo(q.dtype).max 24 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 25 | 26 | o = torch.zeros_like(q) 27 | all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) 28 | all_row_maxes = torch.full( 29 | (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device 30 | ) 31 | 32 | scale = q.shape[-1] ** -0.5 33 | 34 | if mask is None: 35 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) 36 | else: 37 | mask = rearrange(mask, "b n -> b 1 1 n") 38 | mask = mask.split(q_bucket_size, dim=-1) 39 | 40 | row_splits = zip( 41 | q.split(q_bucket_size, dim=-2), 42 | o.split(q_bucket_size, dim=-2), 43 | mask, 44 | all_row_sums.split(q_bucket_size, dim=-2), 45 | all_row_maxes.split(q_bucket_size, dim=-2), 46 | ) 47 | 48 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 49 | q_start_index = ind * q_bucket_size - qk_len_diff 50 | 51 | col_splits = zip( 52 | k.split(k_bucket_size, dim=-2), 53 | v.split(k_bucket_size, dim=-2), 54 | ) 55 | 56 | for k_ind, (kc, vc) in enumerate(col_splits): 57 | k_start_index = k_ind * k_bucket_size 58 | 59 | attn_weights = ( 60 | torch.einsum("... i d, ... j d -> ... 
i j", qc, kc) * scale 61 | ) 62 | 63 | if row_mask is not None: 64 | attn_weights.masked_fill_(~row_mask, max_neg_value) 65 | 66 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 67 | causal_mask = torch.ones( 68 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 69 | ).triu(q_start_index - k_start_index + 1) 70 | attn_weights.masked_fill_(causal_mask, max_neg_value) 71 | 72 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) 73 | attn_weights -= block_row_maxes 74 | exp_weights = torch.exp(attn_weights) 75 | 76 | if row_mask is not None: 77 | exp_weights.masked_fill_(~row_mask, 0.0) 78 | 79 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( 80 | min=EPSILON 81 | ) 82 | 83 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 84 | 85 | exp_values = torch.einsum( 86 | "... i j, ... j d -> ... i d", exp_weights, vc 87 | ) 88 | 89 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 90 | exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) 91 | 92 | new_row_sums = ( 93 | exp_row_max_diff * row_sums 94 | + exp_block_row_max_diff * block_row_sums 95 | ) 96 | 97 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( 98 | (exp_block_row_max_diff / new_row_sums) * exp_values 99 | ) 100 | 101 | row_maxes.copy_(new_row_maxes) 102 | row_sums.copy_(new_row_sums) 103 | 104 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 105 | ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) 106 | 107 | return o 108 | 109 | @staticmethod 110 | @torch.no_grad() 111 | def backward(ctx, do): 112 | """Algorithm 4 in the paper""" 113 | 114 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 115 | q, k, v, o, l, m = ctx.saved_tensors 116 | 117 | device = q.device 118 | 119 | max_neg_value = -torch.finfo(q.dtype).max 120 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 121 | 122 | dq = torch.zeros_like(q) 123 | dk = torch.zeros_like(k) 124 | dv = torch.zeros_like(v) 125 | 126 | row_splits = zip( 127 | q.split(q_bucket_size, dim=-2), 128 | o.split(q_bucket_size, dim=-2), 129 | do.split(q_bucket_size, dim=-2), 130 | mask, 131 | l.split(q_bucket_size, dim=-2), 132 | m.split(q_bucket_size, dim=-2), 133 | dq.split(q_bucket_size, dim=-2), 134 | ) 135 | 136 | for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): 137 | q_start_index = ind * q_bucket_size - qk_len_diff 138 | 139 | col_splits = zip( 140 | k.split(k_bucket_size, dim=-2), 141 | v.split(k_bucket_size, dim=-2), 142 | dk.split(k_bucket_size, dim=-2), 143 | dv.split(k_bucket_size, dim=-2), 144 | ) 145 | 146 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 147 | k_start_index = k_ind * k_bucket_size 148 | 149 | attn_weights = ( 150 | torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale 151 | ) 152 | 153 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 154 | causal_mask = torch.ones( 155 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 156 | ).triu(q_start_index - k_start_index + 1) 157 | attn_weights.masked_fill_(causal_mask, max_neg_value) 158 | 159 | exp_attn_weights = torch.exp(attn_weights - mc) 160 | 161 | if row_mask is not None: 162 | exp_attn_weights.masked_fill_(~row_mask, 0.0) 163 | 164 | p = exp_attn_weights / lc 165 | 166 | dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) 167 | dp = torch.einsum("... i d, ... j d -> ... 
i j", doc, vc) 168 | 169 | D = (doc * oc).sum(dim=-1, keepdims=True) 170 | ds = p * scale * (dp - D) 171 | 172 | dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) 173 | dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) 174 | 175 | dqc.add_(dq_chunk) 176 | dkc.add_(dk_chunk) 177 | dvc.add_(dv_chunk) 178 | 179 | return dq, dk, dv, None, None, None, None 180 | 181 | 182 | class FlashAttnProcessor: 183 | def __call__( 184 | self, 185 | attn: Attention, 186 | hidden_states, 187 | encoder_hidden_states=None, 188 | attention_mask=None, 189 | ) -> Any: 190 | q_bucket_size = 512 191 | k_bucket_size = 1024 192 | 193 | h = attn.heads 194 | q = attn.to_q(hidden_states) 195 | 196 | encoder_hidden_states = ( 197 | encoder_hidden_states 198 | if encoder_hidden_states is not None 199 | else hidden_states 200 | ) 201 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) 202 | 203 | if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: 204 | context_k, context_v = attn.hypernetwork.forward( 205 | hidden_states, encoder_hidden_states 206 | ) 207 | context_k = context_k.to(hidden_states.dtype) 208 | context_v = context_v.to(hidden_states.dtype) 209 | else: 210 | context_k = encoder_hidden_states 211 | context_v = encoder_hidden_states 212 | 213 | k = attn.to_k(context_k) 214 | v = attn.to_v(context_v) 215 | del encoder_hidden_states, hidden_states 216 | 217 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 218 | 219 | out = FlashAttentionFunction.apply( 220 | q, k, v, attention_mask, False, q_bucket_size, k_bucket_size 221 | ) 222 | 223 | out = rearrange(out, "b h n d -> b n (h d)") 224 | 225 | out = attn.to_out[0](out) 226 | out = attn.to_out[1](out) 227 | return out 228 | -------------------------------------------------------------------------------- /library/huggingface_util.py: -------------------------------------------------------------------------------- 1 | from typing import Union, BinaryIO 2 | from huggingface_hub import HfApi 3 | from pathlib import Path 4 | import argparse 5 | import os 6 | from library.utils import fire_in_thread 7 | 8 | 9 | def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): 10 | api = HfApi( 11 | token=token, 12 | ) 13 | try: 14 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 15 | return True 16 | except: 17 | return False 18 | 19 | 20 | def upload( 21 | args: argparse.Namespace, 22 | src: Union[str, Path, bytes, BinaryIO], 23 | dest_suffix: str = "", 24 | force_sync_upload: bool = False, 25 | ): 26 | repo_id = args.huggingface_repo_id 27 | repo_type = args.huggingface_repo_type 28 | token = args.huggingface_token 29 | path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None 30 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" 31 | api = HfApi(token=token) 32 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): 33 | try: 34 | api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) 35 | except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので 36 | print("===========================================") 37 | print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") 38 | print("===========================================") 39 | 40 | is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) 41 | 42 | def uploader(): 43 | 
try: 44 | if is_folder: 45 | api.upload_folder( 46 | repo_id=repo_id, 47 | repo_type=repo_type, 48 | folder_path=src, 49 | path_in_repo=path_in_repo, 50 | ) 51 | else: 52 | api.upload_file( 53 | repo_id=repo_id, 54 | repo_type=repo_type, 55 | path_or_fileobj=src, 56 | path_in_repo=path_in_repo, 57 | ) 58 | except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので 59 | print("===========================================") 60 | print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") 61 | print("===========================================") 62 | 63 | if args.async_upload and not force_sync_upload: 64 | fire_in_thread(uploader) 65 | else: 66 | uploader() 67 | 68 | 69 | def list_dir( 70 | repo_id: str, 71 | subfolder: str, 72 | repo_type: str, 73 | revision: str = "main", 74 | token: str = None, 75 | ): 76 | api = HfApi( 77 | token=token, 78 | ) 79 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 80 | file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] 81 | return file_list 82 | -------------------------------------------------------------------------------- /library/hypernetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from diffusers.models.attention_processor import ( 4 | Attention, 5 | AttnProcessor2_0, 6 | SlicedAttnProcessor, 7 | XFormersAttnProcessor 8 | ) 9 | 10 | try: 11 | import xformers.ops 12 | except: 13 | xformers = None 14 | 15 | 16 | loaded_networks = [] 17 | 18 | 19 | def apply_single_hypernetwork( 20 | hypernetwork, hidden_states, encoder_hidden_states 21 | ): 22 | context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) 23 | return context_k, context_v 24 | 25 | 26 | def apply_hypernetworks(context_k, context_v, layer=None): 27 | if len(loaded_networks) == 0: 28 | return context_v, context_v 29 | for hypernetwork in loaded_networks: 30 | context_k, context_v = hypernetwork.forward(context_k, context_v) 31 | 32 | context_k = context_k.to(dtype=context_k.dtype) 33 | context_v = context_v.to(dtype=context_k.dtype) 34 | 35 | return context_k, context_v 36 | 37 | 38 | 39 | def xformers_forward( 40 | self: XFormersAttnProcessor, 41 | attn: Attention, 42 | hidden_states: torch.Tensor, 43 | encoder_hidden_states: torch.Tensor = None, 44 | attention_mask: torch.Tensor = None, 45 | ): 46 | batch_size, sequence_length, _ = ( 47 | hidden_states.shape 48 | if encoder_hidden_states is None 49 | else encoder_hidden_states.shape 50 | ) 51 | 52 | attention_mask = attn.prepare_attention_mask( 53 | attention_mask, sequence_length, batch_size 54 | ) 55 | 56 | query = attn.to_q(hidden_states) 57 | 58 | if encoder_hidden_states is None: 59 | encoder_hidden_states = hidden_states 60 | elif attn.norm_cross: 61 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 62 | 63 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 64 | 65 | key = attn.to_k(context_k) 66 | value = attn.to_v(context_v) 67 | 68 | query = attn.head_to_batch_dim(query).contiguous() 69 | key = attn.head_to_batch_dim(key).contiguous() 70 | value = attn.head_to_batch_dim(value).contiguous() 71 | 72 | hidden_states = xformers.ops.memory_efficient_attention( 73 | query, 74 | key, 75 | value, 76 | attn_bias=attention_mask, 77 | op=self.attention_op, 78 | scale=attn.scale, 79 | ) 80 | hidden_states = hidden_states.to(query.dtype) 81 | hidden_states = 
attn.batch_to_head_dim(hidden_states) 82 | 83 | # linear proj 84 | hidden_states = attn.to_out[0](hidden_states) 85 | # dropout 86 | hidden_states = attn.to_out[1](hidden_states) 87 | return hidden_states 88 | 89 | 90 | def sliced_attn_forward( 91 | self: SlicedAttnProcessor, 92 | attn: Attention, 93 | hidden_states: torch.Tensor, 94 | encoder_hidden_states: torch.Tensor = None, 95 | attention_mask: torch.Tensor = None, 96 | ): 97 | batch_size, sequence_length, _ = ( 98 | hidden_states.shape 99 | if encoder_hidden_states is None 100 | else encoder_hidden_states.shape 101 | ) 102 | attention_mask = attn.prepare_attention_mask( 103 | attention_mask, sequence_length, batch_size 104 | ) 105 | 106 | query = attn.to_q(hidden_states) 107 | dim = query.shape[-1] 108 | query = attn.head_to_batch_dim(query) 109 | 110 | if encoder_hidden_states is None: 111 | encoder_hidden_states = hidden_states 112 | elif attn.norm_cross: 113 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 114 | 115 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 116 | 117 | key = attn.to_k(context_k) 118 | value = attn.to_v(context_v) 119 | key = attn.head_to_batch_dim(key) 120 | value = attn.head_to_batch_dim(value) 121 | 122 | batch_size_attention, query_tokens, _ = query.shape 123 | hidden_states = torch.zeros( 124 | (batch_size_attention, query_tokens, dim // attn.heads), 125 | device=query.device, 126 | dtype=query.dtype, 127 | ) 128 | 129 | for i in range(batch_size_attention // self.slice_size): 130 | start_idx = i * self.slice_size 131 | end_idx = (i + 1) * self.slice_size 132 | 133 | query_slice = query[start_idx:end_idx] 134 | key_slice = key[start_idx:end_idx] 135 | attn_mask_slice = ( 136 | attention_mask[start_idx:end_idx] if attention_mask is not None else None 137 | ) 138 | 139 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 140 | 141 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 142 | 143 | hidden_states[start_idx:end_idx] = attn_slice 144 | 145 | hidden_states = attn.batch_to_head_dim(hidden_states) 146 | 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | 152 | return hidden_states 153 | 154 | 155 | def v2_0_forward( 156 | self: AttnProcessor2_0, 157 | attn: Attention, 158 | hidden_states, 159 | encoder_hidden_states=None, 160 | attention_mask=None, 161 | ): 162 | batch_size, sequence_length, _ = ( 163 | hidden_states.shape 164 | if encoder_hidden_states is None 165 | else encoder_hidden_states.shape 166 | ) 167 | inner_dim = hidden_states.shape[-1] 168 | 169 | if attention_mask is not None: 170 | attention_mask = attn.prepare_attention_mask( 171 | attention_mask, sequence_length, batch_size 172 | ) 173 | # scaled_dot_product_attention expects attention_mask shape to be 174 | # (batch, heads, source_length, target_length) 175 | attention_mask = attention_mask.view( 176 | batch_size, attn.heads, -1, attention_mask.shape[-1] 177 | ) 178 | 179 | query = attn.to_q(hidden_states) 180 | 181 | if encoder_hidden_states is None: 182 | encoder_hidden_states = hidden_states 183 | elif attn.norm_cross: 184 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 185 | 186 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 187 | 188 | key = attn.to_k(context_k) 189 | value = attn.to_v(context_v) 190 | 191 | head_dim = inner_dim // attn.heads 192 | query = 
query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 193 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 194 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 195 | 196 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 197 | # TODO: add support for attn.scale when we move to Torch 2.1 198 | hidden_states = F.scaled_dot_product_attention( 199 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 200 | ) 201 | 202 | hidden_states = hidden_states.transpose(1, 2).reshape( 203 | batch_size, -1, attn.heads * head_dim 204 | ) 205 | hidden_states = hidden_states.to(query.dtype) 206 | 207 | # linear proj 208 | hidden_states = attn.to_out[0](hidden_states) 209 | # dropout 210 | hidden_states = attn.to_out[1](hidden_states) 211 | return hidden_states 212 | 213 | 214 | def replace_attentions_for_hypernetwork(): 215 | import diffusers.models.attention_processor 216 | 217 | diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( 218 | xformers_forward 219 | ) 220 | diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( 221 | sliced_attn_forward 222 | ) 223 | diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward 224 | -------------------------------------------------------------------------------- /library/ipex/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import contextlib 4 | import torch 5 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 6 | from .hijacks import ipex_hijacks 7 | 8 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 9 | 10 | def ipex_init(): # pylint: disable=too-many-statements 11 | try: 12 | # Replace cuda with xpu: 13 | torch.cuda.current_device = torch.xpu.current_device 14 | torch.cuda.current_stream = torch.xpu.current_stream 15 | torch.cuda.device = torch.xpu.device 16 | torch.cuda.device_count = torch.xpu.device_count 17 | torch.cuda.device_of = torch.xpu.device_of 18 | torch.cuda.get_device_name = torch.xpu.get_device_name 19 | torch.cuda.get_device_properties = torch.xpu.get_device_properties 20 | torch.cuda.init = torch.xpu.init 21 | torch.cuda.is_available = torch.xpu.is_available 22 | torch.cuda.is_initialized = torch.xpu.is_initialized 23 | torch.cuda.is_current_stream_capturing = lambda: False 24 | torch.cuda.set_device = torch.xpu.set_device 25 | torch.cuda.stream = torch.xpu.stream 26 | torch.cuda.synchronize = torch.xpu.synchronize 27 | torch.cuda.Event = torch.xpu.Event 28 | torch.cuda.Stream = torch.xpu.Stream 29 | torch.cuda.FloatTensor = torch.xpu.FloatTensor 30 | torch.Tensor.cuda = torch.Tensor.xpu 31 | torch.Tensor.is_cuda = torch.Tensor.is_xpu 32 | torch.UntypedStorage.cuda = torch.UntypedStorage.xpu 33 | torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock 34 | torch.cuda._initialized = torch.xpu.lazy_init._initialized 35 | torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker 36 | torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls 37 | torch.cuda._tls = torch.xpu.lazy_init._tls 38 | torch.cuda.threading = torch.xpu.lazy_init.threading 39 | torch.cuda.traceback = torch.xpu.lazy_init.traceback 40 | torch.cuda.Optional = torch.xpu.Optional 41 | torch.cuda.__cached__ = torch.xpu.__cached__ 42 | torch.cuda.__loader__ = torch.xpu.__loader__ 43 | torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage 44 | 
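For reference, a hedged sketch of calling the initializer being defined here; as the end of the function shows, `ipex_init()` returns a `(success, exception)` tuple, so a caller can fall back gracefully when no XPU setup is available:

```python
# Hedged sketch, not repository code: enable the CUDA -> XPU aliasing defined above.
from library.ipex import ipex_init

ok, err = ipex_init()
if not ok:
    print(f"IPEX initialization failed, running without XPU aliasing: {err}")
```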
torch.cuda.Tuple = torch.xpu.Tuple 45 | torch.cuda.streams = torch.xpu.streams 46 | torch.cuda._lazy_new = torch.xpu._lazy_new 47 | torch.cuda.FloatStorage = torch.xpu.FloatStorage 48 | torch.cuda.Any = torch.xpu.Any 49 | torch.cuda.__doc__ = torch.xpu.__doc__ 50 | torch.cuda.default_generators = torch.xpu.default_generators 51 | torch.cuda.HalfTensor = torch.xpu.HalfTensor 52 | torch.cuda._get_device_index = torch.xpu._get_device_index 53 | torch.cuda.__path__ = torch.xpu.__path__ 54 | torch.cuda.Device = torch.xpu.Device 55 | torch.cuda.IntTensor = torch.xpu.IntTensor 56 | torch.cuda.ByteStorage = torch.xpu.ByteStorage 57 | torch.cuda.set_stream = torch.xpu.set_stream 58 | torch.cuda.BoolStorage = torch.xpu.BoolStorage 59 | torch.cuda.os = torch.xpu.os 60 | torch.cuda.torch = torch.xpu.torch 61 | torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage 62 | torch.cuda.Union = torch.xpu.Union 63 | torch.cuda.DoubleTensor = torch.xpu.DoubleTensor 64 | torch.cuda.ShortTensor = torch.xpu.ShortTensor 65 | torch.cuda.LongTensor = torch.xpu.LongTensor 66 | torch.cuda.IntStorage = torch.xpu.IntStorage 67 | torch.cuda.LongStorage = torch.xpu.LongStorage 68 | torch.cuda.__annotations__ = torch.xpu.__annotations__ 69 | torch.cuda.__package__ = torch.xpu.__package__ 70 | torch.cuda.__builtins__ = torch.xpu.__builtins__ 71 | torch.cuda.CharTensor = torch.xpu.CharTensor 72 | torch.cuda.List = torch.xpu.List 73 | torch.cuda._lazy_init = torch.xpu._lazy_init 74 | torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor 75 | torch.cuda.DoubleStorage = torch.xpu.DoubleStorage 76 | torch.cuda.ByteTensor = torch.xpu.ByteTensor 77 | torch.cuda.StreamContext = torch.xpu.StreamContext 78 | torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage 79 | torch.cuda.ShortStorage = torch.xpu.ShortStorage 80 | torch.cuda._lazy_call = torch.xpu._lazy_call 81 | torch.cuda.HalfStorage = torch.xpu.HalfStorage 82 | torch.cuda.random = torch.xpu.random 83 | torch.cuda._device = torch.xpu._device 84 | torch.cuda.classproperty = torch.xpu.classproperty 85 | torch.cuda.__name__ = torch.xpu.__name__ 86 | torch.cuda._device_t = torch.xpu._device_t 87 | torch.cuda.warnings = torch.xpu.warnings 88 | torch.cuda.__spec__ = torch.xpu.__spec__ 89 | torch.cuda.BoolTensor = torch.xpu.BoolTensor 90 | torch.cuda.CharStorage = torch.xpu.CharStorage 91 | torch.cuda.__file__ = torch.xpu.__file__ 92 | torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork 93 | # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing 94 | 95 | # Memory: 96 | torch.cuda.memory = torch.xpu.memory 97 | if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): 98 | torch.xpu.empty_cache = lambda: None 99 | torch.cuda.empty_cache = torch.xpu.empty_cache 100 | torch.cuda.memory_stats = torch.xpu.memory_stats 101 | torch.cuda.memory_summary = torch.xpu.memory_summary 102 | torch.cuda.memory_snapshot = torch.xpu.memory_snapshot 103 | torch.cuda.memory_allocated = torch.xpu.memory_allocated 104 | torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated 105 | torch.cuda.memory_reserved = torch.xpu.memory_reserved 106 | torch.cuda.memory_cached = torch.xpu.memory_reserved 107 | torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved 108 | torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved 109 | torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats 110 | torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats 111 | torch.cuda.reset_max_memory_allocated = 
torch.xpu.reset_peak_memory_stats 112 | torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict 113 | torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats 114 | 115 | # RNG: 116 | torch.cuda.get_rng_state = torch.xpu.get_rng_state 117 | torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all 118 | torch.cuda.set_rng_state = torch.xpu.set_rng_state 119 | torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all 120 | torch.cuda.manual_seed = torch.xpu.manual_seed 121 | torch.cuda.manual_seed_all = torch.xpu.manual_seed_all 122 | torch.cuda.seed = torch.xpu.seed 123 | torch.cuda.seed_all = torch.xpu.seed_all 124 | torch.cuda.initial_seed = torch.xpu.initial_seed 125 | 126 | # AMP: 127 | torch.cuda.amp = torch.xpu.amp 128 | if not hasattr(torch.cuda.amp, "common"): 129 | torch.cuda.amp.common = contextlib.nullcontext() 130 | torch.cuda.amp.common.amp_definitely_not_available = lambda: False 131 | try: 132 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler 133 | except Exception: # pylint: disable=broad-exception-caught 134 | try: 135 | from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error 136 | gradscaler_init() 137 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler 138 | except Exception: # pylint: disable=broad-exception-caught 139 | torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler 140 | 141 | # C 142 | torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream 143 | ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count 144 | ipex._C._DeviceProperties.major = 2023 145 | ipex._C._DeviceProperties.minor = 2 146 | 147 | # Fix functions with ipex: 148 | torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] 149 | torch._utils._get_available_device_type = lambda: "xpu" 150 | torch.has_cuda = True 151 | torch.cuda.has_half = True 152 | torch.cuda.is_bf16_supported = lambda *args, **kwargs: True 153 | torch.cuda.is_fp16_supported = lambda *args, **kwargs: True 154 | torch.version.cuda = "11.7" 155 | torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] 156 | torch.cuda.get_device_properties.major = 11 157 | torch.cuda.get_device_properties.minor = 7 158 | torch.cuda.ipc_collect = lambda *args, **kwargs: None 159 | torch.cuda.utilization = lambda *args, **kwargs: 0 160 | 161 | ipex_hijacks() 162 | if not torch.xpu.has_fp64_dtype(): 163 | try: 164 | from .diffusers import ipex_diffusers 165 | ipex_diffusers() 166 | except Exception: # pylint: disable=broad-exception-caught 167 | pass 168 | except Exception as e: 169 | return False, e 170 | return True, None 171 | -------------------------------------------------------------------------------- /library/ipex/attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 4 | from functools import cache 5 | 6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 7 | 8 | # ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers 9 | 10 | sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) 11 | attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) 12 | 13 | # Find something 
divisible with the input_tokens 14 | @cache 15 | def find_slice_size(slice_size, slice_block_size): 16 | while (slice_size * slice_block_size) > attention_slice_rate: 17 | slice_size = slice_size // 2 18 | if slice_size <= 1: 19 | slice_size = 1 20 | break 21 | return slice_size 22 | 23 | # Find slice sizes for SDPA 24 | @cache 25 | def find_sdpa_slice_sizes(query_shape, query_element_size): 26 | if len(query_shape) == 3: 27 | batch_size_attention, query_tokens, shape_three = query_shape 28 | shape_four = 1 29 | else: 30 | batch_size_attention, query_tokens, shape_three, shape_four = query_shape 31 | 32 | slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size 33 | block_size = batch_size_attention * slice_block_size 34 | 35 | split_slice_size = batch_size_attention 36 | split_2_slice_size = query_tokens 37 | split_3_slice_size = shape_three 38 | 39 | do_split = False 40 | do_split_2 = False 41 | do_split_3 = False 42 | 43 | if block_size > sdpa_slice_trigger_rate: 44 | do_split = True 45 | split_slice_size = find_slice_size(split_slice_size, slice_block_size) 46 | if split_slice_size * slice_block_size > attention_slice_rate: 47 | slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size 48 | do_split_2 = True 49 | split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) 50 | if split_2_slice_size * slice_2_block_size > attention_slice_rate: 51 | slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size 52 | do_split_3 = True 53 | split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) 54 | 55 | return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size 56 | 57 | # Find slice sizes for BMM 58 | @cache 59 | def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): 60 | batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] 61 | slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size 62 | block_size = batch_size_attention * slice_block_size 63 | 64 | split_slice_size = batch_size_attention 65 | split_2_slice_size = input_tokens 66 | split_3_slice_size = mat2_atten_shape 67 | 68 | do_split = False 69 | do_split_2 = False 70 | do_split_3 = False 71 | 72 | if block_size > attention_slice_rate: 73 | do_split = True 74 | split_slice_size = find_slice_size(split_slice_size, slice_block_size) 75 | if split_slice_size * slice_block_size > attention_slice_rate: 76 | slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size 77 | do_split_2 = True 78 | split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) 79 | if split_2_slice_size * slice_2_block_size > attention_slice_rate: 80 | slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size 81 | do_split_3 = True 82 | split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) 83 | 84 | return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size 85 | 86 | 87 | original_torch_bmm = torch.bmm 88 | def torch_bmm_32_bit(input, mat2, *, out=None): 89 | if input.device.type != "xpu": 90 | return original_torch_bmm(input, mat2, out=out) 91 | do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) 92 | 93 | # Slice BMM 94 
| if do_split: 95 | batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] 96 | hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) 97 | for i in range(batch_size_attention // split_slice_size): 98 | start_idx = i * split_slice_size 99 | end_idx = (i + 1) * split_slice_size 100 | if do_split_2: 101 | for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name 102 | start_idx_2 = i2 * split_2_slice_size 103 | end_idx_2 = (i2 + 1) * split_2_slice_size 104 | if do_split_3: 105 | for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name 106 | start_idx_3 = i3 * split_3_slice_size 107 | end_idx_3 = (i3 + 1) * split_3_slice_size 108 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( 109 | input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 110 | mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 111 | out=out 112 | ) 113 | else: 114 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( 115 | input[start_idx:end_idx, start_idx_2:end_idx_2], 116 | mat2[start_idx:end_idx, start_idx_2:end_idx_2], 117 | out=out 118 | ) 119 | else: 120 | hidden_states[start_idx:end_idx] = original_torch_bmm( 121 | input[start_idx:end_idx], 122 | mat2[start_idx:end_idx], 123 | out=out 124 | ) 125 | else: 126 | return original_torch_bmm(input, mat2, out=out) 127 | return hidden_states 128 | 129 | original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention 130 | def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): 131 | if query.device.type != "xpu": 132 | return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) 133 | do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) 134 | 135 | # Slice SDPA 136 | if do_split: 137 | batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] 138 | hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) 139 | for i in range(batch_size_attention // split_slice_size): 140 | start_idx = i * split_slice_size 141 | end_idx = (i + 1) * split_slice_size 142 | if do_split_2: 143 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name 144 | start_idx_2 = i2 * split_2_slice_size 145 | end_idx_2 = (i2 + 1) * split_2_slice_size 146 | if do_split_3: 147 | for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name 148 | start_idx_3 = i3 * split_3_slice_size 149 | end_idx_3 = (i3 + 1) * split_3_slice_size 150 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( 151 | query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 152 | key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 153 | value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], 154 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, 155 | dropout_p=dropout_p, is_causal=is_causal 156 | ) 157 | else: 158 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( 159 | 
query[start_idx:end_idx, start_idx_2:end_idx_2], 160 | key[start_idx:end_idx, start_idx_2:end_idx_2], 161 | value[start_idx:end_idx, start_idx_2:end_idx_2], 162 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, 163 | dropout_p=dropout_p, is_causal=is_causal 164 | ) 165 | else: 166 | hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( 167 | query[start_idx:end_idx], 168 | key[start_idx:end_idx], 169 | value[start_idx:end_idx], 170 | attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, 171 | dropout_p=dropout_p, is_causal=is_causal 172 | ) 173 | else: 174 | return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) 175 | return hidden_states 176 | -------------------------------------------------------------------------------- /library/ipex/gradscaler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 4 | import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import 5 | 6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 7 | 8 | device_supports_fp64 = torch.xpu.has_fp64_dtype() 9 | OptState = ipex.cpu.autocast._grad_scaler.OptState 10 | _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator 11 | _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state 12 | 13 | def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument 14 | per_device_inv_scale = _MultiDeviceReplicator(inv_scale) 15 | per_device_found_inf = _MultiDeviceReplicator(found_inf) 16 | 17 | # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. 18 | # There could be hundreds of grads, so we'd like to iterate through them just once. 19 | # However, we don't know their devices or dtypes in advance. 20 | 21 | # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict 22 | # Google says mypy struggles with defaultdicts type annotations. 23 | per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] 24 | # sync grad to master weight 25 | if hasattr(optimizer, "sync_grad"): 26 | optimizer.sync_grad() 27 | with torch.no_grad(): 28 | for group in optimizer.param_groups: 29 | for param in group["params"]: 30 | if param.grad is None: 31 | continue 32 | if (not allow_fp16) and param.grad.dtype == torch.float16: 33 | raise ValueError("Attempting to unscale FP16 gradients.") 34 | if param.grad.is_sparse: 35 | # is_coalesced() == False means the sparse grad has values with duplicate indices. 36 | # coalesce() deduplicates indices and adds all values that have the same index. 37 | # For scaled fp16 values, there's a good chance coalescing will cause overflow, 38 | # so we should check the coalesced _values(). 39 | if param.grad.dtype is torch.float16: 40 | param.grad = param.grad.coalesce() 41 | to_unscale = param.grad._values() 42 | else: 43 | to_unscale = param.grad 44 | 45 | # -: is there a way to split by device and dtype without appending in the inner loop? 
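# Note: the gradients are copied to CPU below so that IPEX's core._amp_foreach_non_finite_check_and_unscale_
# can run the inf/NaN check and the unscaling on CPU tensors; the stock torch.cuda GradScaler keeps this step on the device.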
46 | to_unscale = to_unscale.to("cpu") 47 | per_device_and_dtype_grads[to_unscale.device][ 48 | to_unscale.dtype 49 | ].append(to_unscale) 50 | 51 | for _, per_dtype_grads in per_device_and_dtype_grads.items(): 52 | for grads in per_dtype_grads.values(): 53 | core._amp_foreach_non_finite_check_and_unscale_( 54 | grads, 55 | per_device_found_inf.get("cpu"), 56 | per_device_inv_scale.get("cpu"), 57 | ) 58 | 59 | return per_device_found_inf._per_device_tensors 60 | 61 | def unscale_(self, optimizer): 62 | """ 63 | Divides ("unscales") the optimizer's gradient tensors by the scale factor. 64 | :meth:`unscale_` is optional, serving cases where you need to 65 | :ref:`modify or inspect gradients` 66 | between the backward pass(es) and :meth:`step`. 67 | If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 68 | Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: 69 | ... 70 | scaler.scale(loss).backward() 71 | scaler.unscale_(optimizer) 72 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 73 | scaler.step(optimizer) 74 | scaler.update() 75 | Args: 76 | optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. 77 | .. warning:: 78 | :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, 79 | and only after all gradients for that optimizer's assigned parameters have been accumulated. 80 | Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. 81 | .. warning:: 82 | :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. 83 | """ 84 | if not self._enabled: 85 | return 86 | 87 | self._check_scale_growth_tracker("unscale_") 88 | 89 | optimizer_state = self._per_optimizer_states[id(optimizer)] 90 | 91 | if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise 92 | raise RuntimeError( 93 | "unscale_() has already been called on this optimizer since the last update()." 94 | ) 95 | elif optimizer_state["stage"] is OptState.STEPPED: 96 | raise RuntimeError("unscale_() is being called after step().") 97 | 98 | # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 99 | assert self._scale is not None 100 | if device_supports_fp64: 101 | inv_scale = self._scale.double().reciprocal().float() 102 | else: 103 | inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) 104 | found_inf = torch.full( 105 | (1,), 0.0, dtype=torch.float32, device=self._scale.device 106 | ) 107 | 108 | optimizer_state["found_inf_per_device"] = self._unscale_grads_( 109 | optimizer, inv_scale, found_inf, False 110 | ) 111 | optimizer_state["stage"] = OptState.UNSCALED 112 | 113 | def update(self, new_scale=None): 114 | """ 115 | Updates the scale factor. 116 | If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` 117 | to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, 118 | the scale is multiplied by ``growth_factor`` to increase it. 119 | Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not 120 | used directly, it's used to fill GradScaler's internal scale tensor. So if 121 | ``new_scale`` was a tensor, later in-place changes to that tensor will not further 122 | affect the scale GradScaler uses internally.) 123 | Args: 124 | new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. 
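For example, ``scaler.update()`` keeps the automatic growth/backoff behaviour, while ``scaler.update(2.0**16)`` fills the internal scale tensor with that value.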
125 | .. warning:: 126 | :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has 127 | been invoked for all optimizers used this iteration. 128 | """ 129 | if not self._enabled: 130 | return 131 | 132 | _scale, _growth_tracker = self._check_scale_growth_tracker("update") 133 | 134 | if new_scale is not None: 135 | # Accept a new user-defined scale. 136 | if isinstance(new_scale, float): 137 | self._scale.fill_(new_scale) # type: ignore[union-attr] 138 | else: 139 | reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." 140 | assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] 141 | assert new_scale.numel() == 1, reason 142 | assert new_scale.requires_grad is False, reason 143 | self._scale.copy_(new_scale) # type: ignore[union-attr] 144 | else: 145 | # Consume shared inf/nan data collected from optimizers to update the scale. 146 | # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 147 | found_infs = [ 148 | found_inf.to(device="cpu", non_blocking=True) 149 | for state in self._per_optimizer_states.values() 150 | for found_inf in state["found_inf_per_device"].values() 151 | ] 152 | 153 | assert len(found_infs) > 0, "No inf checks were recorded prior to update." 154 | 155 | found_inf_combined = found_infs[0] 156 | if len(found_infs) > 1: 157 | for i in range(1, len(found_infs)): 158 | found_inf_combined += found_infs[i] 159 | 160 | to_device = _scale.device 161 | _scale = _scale.to("cpu") 162 | _growth_tracker = _growth_tracker.to("cpu") 163 | 164 | core._amp_update_scale_( 165 | _scale, 166 | _growth_tracker, 167 | found_inf_combined, 168 | self._growth_factor, 169 | self._backoff_factor, 170 | self._growth_interval, 171 | ) 172 | 173 | _scale = _scale.to(to_device) 174 | _growth_tracker = _growth_tracker.to(to_device) 175 | # To prepare for next iteration, clear the data collected from optimizers this iteration. 176 | self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) 177 | 178 | def gradscaler_init(): 179 | torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler 180 | torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ 181 | torch.xpu.amp.GradScaler.unscale_ = unscale_ 182 | torch.xpu.amp.GradScaler.update = update 183 | return torch.xpu.amp.GradScaler 184 | -------------------------------------------------------------------------------- /library/ipex_interop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def init_ipex(): 5 | """ 6 | Try to import `intel_extension_for_pytorch`, and apply 7 | the hijacks using `library.ipex.ipex_init`. 8 | 9 | If IPEX is not installed, this function does nothing. 
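    Example (an illustrative sketch; assumes an XPU-enabled PyTorch build with IPEX installed):

        from library.ipex_interop import init_ipex

        init_ipex()  # returns silently if intel_extension_for_pytorch is not importable
        import torch
        # on XPU machines the hijack redirects torch.cuda.* calls to their torch.xpu equivalents
        print(torch.xpu.is_available())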
10 | """ 11 | try: 12 | import intel_extension_for_pytorch as ipex # noqa 13 | except ImportError: 14 | return 15 | 16 | try: 17 | from library.ipex import ipex_init 18 | 19 | if torch.xpu.is_available(): 20 | is_initialized, error_message = ipex_init() 21 | if not is_initialized: 22 | print("failed to initialize ipex:", error_message) 23 | except Exception as e: 24 | print("failed to initialize ipex:", e) 25 | -------------------------------------------------------------------------------- /library/utils.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from typing import * 3 | 4 | 5 | def fire_in_thread(f, *args, **kwargs): 6 | threading.Thread(target=f, args=args, kwargs=kwargs).start() -------------------------------------------------------------------------------- /networks/check_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from safetensors.torch import load_file 5 | 6 | 7 | def main(file): 8 | print(f"loading: {file}") 9 | if os.path.splitext(file)[1] == ".safetensors": 10 | sd = load_file(file) 11 | else: 12 | sd = torch.load(file, map_location="cpu") 13 | 14 | values = [] 15 | 16 | keys = list(sd.keys()) 17 | for key in keys: 18 | if "lora_up" in key or "lora_down" in key: 19 | values.append((key, sd[key])) 20 | print(f"number of LoRA modules: {len(values)}") 21 | 22 | if args.show_all_keys: 23 | for key in [k for k in keys if k not in values]: 24 | values.append((key, sd[key])) 25 | print(f"number of all modules: {len(values)}") 26 | 27 | for key, value in values: 28 | value = value.to(torch.float32) 29 | print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") 30 | 31 | 32 | def setup_parser() -> argparse.ArgumentParser: 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") 35 | parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") 36 | 37 | return parser 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = setup_parser() 42 | 43 | args = parser.parse_args() 44 | 45 | main(args.file) 46 | -------------------------------------------------------------------------------- /networks/extract_lora_from_dylora.py: -------------------------------------------------------------------------------- 1 | # Convert LoRA to different rank approximation (should only be used to go to lower rank) 2 | # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo 4 | 5 | import argparse 6 | import math 7 | import os 8 | import torch 9 | from safetensors.torch import load_file, save_file, safe_open 10 | from tqdm import tqdm 11 | from library import train_util, model_util 12 | import numpy as np 13 | 14 | 15 | def load_state_dict(file_name): 16 | if model_util.is_safetensors(file_name): 17 | sd = load_file(file_name) 18 | with safe_open(file_name, framework="pt") as f: 19 | metadata = f.metadata() 20 | else: 21 | sd = torch.load(file_name, map_location="cpu") 22 | metadata = None 23 | 24 | return sd, metadata 25 | 26 | 27 | def save_to_file(file_name, model, metadata): 28 | if model_util.is_safetensors(file_name): 29 | save_file(model, file_name, metadata) 30 | else: 31 | torch.save(model, file_name) 32 | 33 | 34 | def 
split_lora_model(lora_sd, unit): 35 | max_rank = 0 36 | 37 | # Extract loaded lora dim and alpha 38 | for key, value in lora_sd.items(): 39 | if "lora_down" in key: 40 | rank = value.size()[0] 41 | if rank > max_rank: 42 | max_rank = rank 43 | print(f"Max rank: {max_rank}") 44 | 45 | rank = unit 46 | split_models = [] 47 | new_alpha = None 48 | while rank < max_rank: 49 | print(f"Splitting rank {rank}") 50 | new_sd = {} 51 | for key, value in lora_sd.items(): 52 | if "lora_down" in key: 53 | new_sd[key] = value[:rank].contiguous() 54 | elif "lora_up" in key: 55 | new_sd[key] = value[:, :rank].contiguous() 56 | else: 57 | # なぜかscaleするとおかしくなる…… 58 | # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] 59 | # scale = math.sqrt(this_rank / rank) # rank is > unit 60 | # print(key, value.size(), this_rank, rank, value, scale) 61 | # new_alpha = value * scale # always same 62 | # new_sd[key] = new_alpha 63 | new_sd[key] = value 64 | 65 | split_models.append((new_sd, rank, new_alpha)) 66 | rank += unit 67 | 68 | return max_rank, split_models 69 | 70 | 71 | def split(args): 72 | print("loading Model...") 73 | lora_sd, metadata = load_state_dict(args.model) 74 | 75 | print("Splitting Model...") 76 | original_rank, split_models = split_lora_model(lora_sd, args.unit) 77 | 78 | comment = metadata.get("ss_training_comment", "") 79 | for state_dict, new_rank, new_alpha in split_models: 80 | # update metadata 81 | if metadata is None: 82 | new_metadata = {} 83 | else: 84 | new_metadata = metadata.copy() 85 | 86 | new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" 87 | new_metadata["ss_network_dim"] = str(new_rank) 88 | # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) 89 | 90 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 91 | metadata["sshs_model_hash"] = model_hash 92 | metadata["sshs_legacy_hash"] = legacy_hash 93 | 94 | filename, ext = os.path.splitext(args.save_to) 95 | model_file_name = filename + f"-{new_rank:04d}{ext}" 96 | 97 | print(f"saving model to: {model_file_name}") 98 | save_to_file(model_file_name, state_dict, new_metadata) 99 | 100 | 101 | def setup_parser() -> argparse.ArgumentParser: 102 | parser = argparse.ArgumentParser() 103 | 104 | parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") 105 | parser.add_argument( 106 | "--save_to", 107 | type=str, 108 | default=None, 109 | help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", 110 | ) 111 | parser.add_argument( 112 | "--model", 113 | type=str, 114 | default=None, 115 | help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", 116 | ) 117 | 118 | return parser 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = setup_parser() 123 | 124 | args = parser.parse_args() 125 | split(args) 126 | -------------------------------------------------------------------------------- /networks/lora_interrogator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from tqdm import tqdm 4 | from library import model_util 5 | import library.train_util as train_util 6 | import argparse 7 | from transformers import CLIPTokenizer 8 | import torch 9 | 10 | import library.model_util as model_util 11 | import lora 12 | 13 | TOKENIZER_PATH = "openai/clip-vit-large-patch14" 14 | V2_STABLE_DIFFUSION_PATH = 
"stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う 15 | 16 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def interrogate(args): 20 | weights_dtype = torch.float16 21 | 22 | # いろいろ準備する 23 | print(f"loading SD model: {args.sd_model}") 24 | args.pretrained_model_name_or_path = args.sd_model 25 | args.vae = None 26 | text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) 27 | 28 | print(f"loading LoRA: {args.model}") 29 | network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) 30 | 31 | # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい 32 | has_te_weight = False 33 | for key in weights_sd.keys(): 34 | if 'lora_te' in key: 35 | has_te_weight = True 36 | break 37 | if not has_te_weight: 38 | print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") 39 | return 40 | del vae 41 | 42 | print("loading tokenizer") 43 | if args.v2: 44 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") 45 | else: 46 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) 47 | 48 | text_encoder.to(DEVICE, dtype=weights_dtype) 49 | text_encoder.eval() 50 | unet.to(DEVICE, dtype=weights_dtype) 51 | unet.eval() # U-Netは呼び出さないので不要だけど 52 | 53 | # トークンをひとつひとつ当たっていく 54 | token_id_start = 0 55 | token_id_end = max(tokenizer.all_special_ids) 56 | print(f"interrogate tokens are: {token_id_start} to {token_id_end}") 57 | 58 | def get_all_embeddings(text_encoder): 59 | embs = [] 60 | with torch.no_grad(): 61 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): 62 | batch = [] 63 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): 64 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] 65 | # tokens = [tid] # こちらは結果がいまひとつ 66 | batch.append(tokens) 67 | 68 | # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] 69 | # clip skip対応 70 | batch = torch.tensor(batch).to(DEVICE) 71 | if args.clip_skip is None: 72 | encoder_hidden_states = text_encoder(batch)[0] 73 | else: 74 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) 75 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] 76 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) 77 | encoder_hidden_states = encoder_hidden_states.to("cpu") 78 | 79 | embs.extend(encoder_hidden_states) 80 | return torch.stack(embs) 81 | 82 | print("get original text encoder embeddings.") 83 | orig_embs = get_all_embeddings(text_encoder) 84 | 85 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) 86 | info = network.load_state_dict(weights_sd, strict=False) 87 | print(f"Loading LoRA weights: {info}") 88 | 89 | network.to(DEVICE, dtype=weights_dtype) 90 | network.eval() 91 | 92 | del unet 93 | 94 | print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") 95 | print("get text encoder embeddings with lora.") 96 | lora_embs = get_all_embeddings(text_encoder) 97 | 98 | # 比べる:とりあえず単純に差分の絶対値で 99 | print("comparing...") 100 | diffs = {} 101 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): 102 | diff = 
torch.mean(torch.abs(orig_emb - lora_emb)) 103 | # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない 104 | diff = float(diff.detach().to('cpu').numpy()) 105 | diffs[token_id_start + i] = diff 106 | 107 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) 108 | 109 | # 結果を表示する 110 | print("top 100:") 111 | for i, (token, diff) in enumerate(diffs_sorted[:100]): 112 | # if diff < 1e-6: 113 | # break 114 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) 115 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") 116 | 117 | 118 | def setup_parser() -> argparse.ArgumentParser: 119 | parser = argparse.ArgumentParser() 120 | 121 | parser.add_argument("--v2", action='store_true', 122 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 123 | parser.add_argument("--sd_model", type=str, default=None, 124 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") 125 | parser.add_argument("--model", type=str, default=None, 126 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") 127 | parser.add_argument("--batch_size", type=int, default=16, 128 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ") 129 | parser.add_argument("--clip_skip", type=int, default=None, 130 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") 131 | 132 | return parser 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = setup_parser() 137 | 138 | args = parser.parse_args() 139 | interrogate(args) 140 | -------------------------------------------------------------------------------- /networks/merge_lora_old.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | import library.model_util as model_util 8 | import lora 9 | 10 | 11 | def load_state_dict(file_name, dtype): 12 | if os.path.splitext(file_name)[1] == '.safetensors': 13 | sd = load_file(file_name) 14 | else: 15 | sd = torch.load(file_name, map_location='cpu') 16 | for key in list(sd.keys()): 17 | if type(sd[key]) == torch.Tensor: 18 | sd[key] = sd[key].to(dtype) 19 | return sd 20 | 21 | 22 | def save_to_file(file_name, model, state_dict, dtype): 23 | if dtype is not None: 24 | for key in list(state_dict.keys()): 25 | if type(state_dict[key]) == torch.Tensor: 26 | state_dict[key] = state_dict[key].to(dtype) 27 | 28 | if os.path.splitext(file_name)[1] == '.safetensors': 29 | save_file(model, file_name) 30 | else: 31 | torch.save(model, file_name) 32 | 33 | 34 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 35 | text_encoder.to(merge_dtype) 36 | unet.to(merge_dtype) 37 | 38 | # create module map 39 | name_to_module = {} 40 | for i, root_module in enumerate([text_encoder, unet]): 41 | if i == 0: 42 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 43 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 44 | else: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 46 | target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE 47 | 48 | for name, module in root_module.named_modules(): 49 | if module.__class__.__name__ in target_replace_modules: 50 | for child_name, child_module in module.named_modules(): 51 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == 
"Conv2d" and child_module.kernel_size == (1, 1)): 52 | lora_name = prefix + '.' + name + '.' + child_name 53 | lora_name = lora_name.replace('.', '_') 54 | name_to_module[lora_name] = child_module 55 | 56 | for model, ratio in zip(models, ratios): 57 | print(f"loading: {model}") 58 | lora_sd = load_state_dict(model, merge_dtype) 59 | 60 | print(f"merging...") 61 | for key in lora_sd.keys(): 62 | if "lora_down" in key: 63 | up_key = key.replace("lora_down", "lora_up") 64 | alpha_key = key[:key.index("lora_down")] + 'alpha' 65 | 66 | # find original module for this lora 67 | module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" 68 | if module_name not in name_to_module: 69 | print(f"no module found for LoRA weight: {key}") 70 | continue 71 | module = name_to_module[module_name] 72 | # print(f"apply {key} to {module}") 73 | 74 | down_weight = lora_sd[key] 75 | up_weight = lora_sd[up_key] 76 | 77 | dim = down_weight.size()[0] 78 | alpha = lora_sd.get(alpha_key, dim) 79 | scale = alpha / dim 80 | 81 | # W <- W + U * D 82 | weight = module.weight 83 | if len(weight.size()) == 2: 84 | # linear 85 | weight = weight + ratio * (up_weight @ down_weight) * scale 86 | else: 87 | # conv2d 88 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale 89 | 90 | module.weight = torch.nn.Parameter(weight) 91 | 92 | 93 | def merge_lora_models(models, ratios, merge_dtype): 94 | merged_sd = {} 95 | 96 | alpha = None 97 | dim = None 98 | for model, ratio in zip(models, ratios): 99 | print(f"loading: {model}") 100 | lora_sd = load_state_dict(model, merge_dtype) 101 | 102 | print(f"merging...") 103 | for key in lora_sd.keys(): 104 | if 'alpha' in key: 105 | if key in merged_sd: 106 | assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" 107 | else: 108 | alpha = lora_sd[key].detach().numpy() 109 | merged_sd[key] = lora_sd[key] 110 | else: 111 | if key in merged_sd: 112 | assert merged_sd[key].size() == lora_sd[key].size( 113 | ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 114 | merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio 115 | else: 116 | if "lora_down" in key: 117 | dim = lora_sd[key].size()[0] 118 | merged_sd[key] = lora_sd[key] * ratio 119 | 120 | print(f"dim (rank): {dim}, alpha: {alpha}") 121 | if alpha is None: 122 | alpha = dim 123 | 124 | return merged_sd, dim, alpha 125 | 126 | 127 | def merge(args): 128 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 129 | 130 | def str_to_dtype(p): 131 | if p == 'float': 132 | return torch.float 133 | if p == 'fp16': 134 | return torch.float16 135 | if p == 'bf16': 136 | return torch.bfloat16 137 | return None 138 | 139 | merge_dtype = str_to_dtype(args.precision) 140 | save_dtype = str_to_dtype(args.save_precision) 141 | if save_dtype is None: 142 | save_dtype = merge_dtype 143 | 144 | if args.sd_model is not None: 145 | print(f"loading SD model: {args.sd_model}") 146 | 147 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 148 | 149 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 150 | 151 | print(f"\nsaving SD model to: {args.save_to}") 152 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, 153 | args.sd_model, 0, 0, save_dtype, vae) 154 | else: 155 | state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) 156 | 157 | print(f"\nsaving model to: {args.save_to}") 158 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 159 | 160 | 161 | def setup_parser() -> argparse.ArgumentParser: 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--v2", action='store_true', 164 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 165 | parser.add_argument("--save_precision", type=str, default=None, 166 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 167 | parser.add_argument("--precision", type=str, default="float", 168 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 169 | parser.add_argument("--sd_model", type=str, default=None, 170 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") 171 | parser.add_argument("--save_to", type=str, default=None, 172 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 173 | parser.add_argument("--models", type=str, nargs='*', 174 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 175 | parser.add_argument("--ratios", type=float, nargs='*', 176 | help="ratios for each model / それぞれのLoRAモデルの比率") 177 | 178 | return parser 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = setup_parser() 183 | 184 | args = parser.parse_args() 185 | merge(args) 186 | -------------------------------------------------------------------------------- /notebook/config_file.toml: -------------------------------------------------------------------------------- 1 | [sdxl_arguments] 2 | cache_text_encoder_outputs = false 3 | no_half_vae = false 4 | min_timestep = 0 5 | max_timestep = 1000 6 | 7 | [model_arguments] 8 | pretrained_model_name_or_path = "/workspace/model/animagine-xl-3.0-base.safetensors" 9 | vae = "/workspace/vae/sdxl_vae.safetensors" 10 | 11 | [dataset_arguments] 12 
| shuffle_caption = true 13 | debug_dataset = false 14 | in_json = "/workspace/fine_tune/animagine-xl-3.1_lat.json" 15 | train_data_dir = "/workspace/train_data/animagine-xl-3.1" 16 | dataset_repeats = 1 17 | keep_tokens_separator = "|||" 18 | resolution = "1024, 1024" 19 | caption_dropout_rate = 0 20 | caption_tag_dropout_rate = 0 21 | caption_dropout_every_n_epochs = 0 22 | token_warmup_min = 1 23 | token_warmup_step = 0 24 | 25 | [training_arguments] 26 | output_dir = "/workspace/fine_tune/outputs/animagine-xl-3.1" 27 | output_name = "animagine-xl-3.1" 28 | save_precision = "fp16" 29 | save_every_n_steps = 1000 30 | save_last_n_steps = true 31 | save_state = true 32 | save_last_n_steps_state = true 33 | train_batch_size = 16 34 | max_token_length = 225 35 | mem_eff_attn = false 36 | xformers = true 37 | sdpa = false 38 | max_train_epochs = 10 39 | max_data_loader_n_workers = 8 40 | persistent_data_loader_workers = true 41 | gradient_checkpointing = true 42 | gradient_accumulation_steps = 3 43 | mixed_precision = "fp16" 44 | ddp_gradient_as_bucket_view = true 45 | ddp_static_graph = true 46 | ddp_timeout = 100000 47 | 48 | [logging_arguments] 49 | log_with = "wandb" 50 | log_tracker_name = "animagine-xl-3.1" 51 | logging_dir = "/workspace/fine_tune/logs" 52 | 53 | [sample_prompt_arguments] 54 | sample_every_n_steps = 100 55 | sample_sampler = "euler_a" 56 | 57 | [saving_arguments] 58 | save_model_as = "safetensors" 59 | 60 | [optimizer_arguments] 61 | optimizer_type = "AdamW" 62 | learning_rate = 1e-5 63 | train_text_encoder = true 64 | optimizer_args = [ "weight_decay=0.1", "betas=0.9,0.99",] 65 | lr_scheduler = "cosine_with_restarts" 66 | lr_scheduler_num_cycles = 10 67 | lr_scheduler_type = "LoraEasyCustomOptimizer.CustomOptimizers.CosineAnnealingWarmupRestarts" 68 | lr_scheduler_args = [ "min_lr=1e-06", "gamma=0.9", "first_cycle_steps=9099",] 69 | max_grad_norm = 1.0 70 | 71 | [advanced_training_config] 72 | resume_from_huggingface = false 73 | 74 | [save_to_hub_config] 75 | huggingface_repo_type = "model" 76 | huggingface_path_in_repo = "model/animagine-xl-3.1_20240320_104513" 77 | huggingface_token = "" 78 | async_upload = true 79 | save_state_to_huggingface = true 80 | huggingface_repo_visibility = "private" 81 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | transformers==4.36.2 3 | diffusers[torch]==0.25.0 4 | ftfy==6.1.1 5 | # albumentations==1.3.0 6 | opencv-python==4.7.0.68 7 | einops==0.6.1 8 | pytorch-lightning==1.9.0 9 | # bitsandbytes==0.39.1 10 | tensorboard==2.10.1 11 | safetensors==0.4.2 12 | # gradio==3.16.2 13 | altair==4.2.2 14 | easygui==0.98.3 15 | toml==0.10.2 16 | voluptuous==0.13.1 17 | huggingface-hub==0.20.1 18 | # for BLIP captioning 19 | # requests==2.28.2 20 | # timm==0.6.12 21 | # fairscale==0.4.13 22 | # for WD14 captioning (tensorflow) 23 | # tensorflow==2.10.1 24 | # for WD14 captioning (onnx) 25 | # onnx==1.14.1 26 | # onnxruntime-gpu==1.16.0 27 | # onnxruntime==1.16.0 28 | # this is for onnx: 29 | # protobuf==3.20.3 30 | # open clip for SDXL 31 | open-clip-torch==2.20.0 32 | # for kohya_ss library 33 | -e . 
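# note: "-e ." installs this repository's own "library" package (declared in setup.py) in editable mode,
# so the training scripts can import it, e.g. "from library import train_util".
# illustrative check after installation (run inside the activated venv):
#   python -c "from library import train_util"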
34 | -------------------------------------------------------------------------------- /sdxl_train_network.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from library.ipex_interop import init_ipex 5 | 6 | init_ipex() 7 | 8 | from library import sdxl_model_util, sdxl_train_util, train_util 9 | import train_network 10 | 11 | 12 | class SdxlNetworkTrainer(train_network.NetworkTrainer): 13 | def __init__(self): 14 | super().__init__() 15 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 16 | self.is_sdxl = True 17 | 18 | def assert_extra_args(self, args, train_dataset_group): 19 | super().assert_extra_args(args, train_dataset_group) 20 | sdxl_train_util.verify_sdxl_training_args(args) 21 | 22 | if args.cache_text_encoder_outputs: 23 | assert ( 24 | train_dataset_group.is_text_encoder_output_cacheable() 25 | ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" 26 | 27 | assert ( 28 | args.network_train_unet_only or not args.cache_text_encoder_outputs 29 | ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" 30 | 31 | train_dataset_group.verify_bucket_reso_steps(32) 32 | 33 | def load_target_model(self, args, weight_dtype, accelerator): 34 | ( 35 | load_stable_diffusion_format, 36 | text_encoder1, 37 | text_encoder2, 38 | vae, 39 | unet, 40 | logit_scale, 41 | ckpt_info, 42 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) 43 | 44 | self.load_stable_diffusion_format = load_stable_diffusion_format 45 | self.logit_scale = logit_scale 46 | self.ckpt_info = ckpt_info 47 | 48 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet 49 | 50 | def load_tokenizer(self, args): 51 | tokenizer = sdxl_train_util.load_tokenizers(args) 52 | return tokenizer 53 | 54 | def is_text_encoder_outputs_cached(self, args): 55 | return args.cache_text_encoder_outputs 56 | 57 | def cache_text_encoder_outputs_if_needed( 58 | self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype 59 | ): 60 | if args.cache_text_encoder_outputs: 61 | if not args.lowram: 62 | # メモリ消費を減らす / reduce memory consumption 63 | print("move vae and unet to cpu to save memory") 64 | org_vae_device = vae.device 65 | org_unet_device = unet.device 66 | vae.to("cpu") 67 | unet.to("cpu") 68 | if torch.cuda.is_available(): 69 | torch.cuda.empty_cache() 70 | 71 | # When TE is not trained, it will not be prepared, so we need to use explicit autocast 72 | with accelerator.autocast(): 73 | dataset.cache_text_encoder_outputs( 74 | tokenizers, 75 | text_encoders, 76 | accelerator.device, 77 | weight_dtype, 78 | args.cache_text_encoder_outputs_to_disk, 79 | accelerator.is_main_process, 80 | ) 81 | 82 | text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU 83 | text_encoders[1].to("cpu", dtype=torch.float32) 84 | if torch.cuda.is_available(): 85 | torch.cuda.empty_cache() 86 | 87 | if not args.lowram: 88 | print("move vae and unet back to original device") 89 | vae.to(org_vae_device) 90 | unet.to(org_unet_device) 91 | else: 92 | # Text Encoderから毎回出力を取得するので、GPUに乗せておく / outputs are fetched from the Text Encoder every step, so keep it on the GPU 93 |
text_encoders[0].to(accelerator.device, dtype=weight_dtype) 94 | text_encoders[1].to(accelerator.device, dtype=weight_dtype) 95 | 96 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 97 | if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: 98 | input_ids1 = batch["input_ids"] 99 | input_ids2 = batch["input_ids2"] 100 | with torch.enable_grad(): 101 | # Get the text embedding for conditioning 102 | # TODO support weighted captions 103 | # if args.weighted_captions: 104 | # encoder_hidden_states = get_weighted_text_embeddings( 105 | # tokenizer, 106 | # text_encoder, 107 | # batch["captions"], 108 | # accelerator.device, 109 | # args.max_token_length // 75 if args.max_token_length else 1, 110 | # clip_skip=args.clip_skip, 111 | # ) 112 | # else: 113 | input_ids1 = input_ids1.to(accelerator.device) 114 | input_ids2 = input_ids2.to(accelerator.device) 115 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 116 | args.max_token_length, 117 | input_ids1, 118 | input_ids2, 119 | tokenizers[0], 120 | tokenizers[1], 121 | text_encoders[0], 122 | text_encoders[1], 123 | None if not args.full_fp16 else weight_dtype, 124 | accelerator=accelerator, 125 | ) 126 | else: 127 | encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) 128 | encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) 129 | pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) 130 | 131 | # # verify that the text encoder outputs are correct 132 | # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( 133 | # args.max_token_length, 134 | # batch["input_ids"].to(text_encoders[0].device), 135 | # batch["input_ids2"].to(text_encoders[0].device), 136 | # tokenizers[0], 137 | # tokenizers[1], 138 | # text_encoders[0], 139 | # text_encoders[1], 140 | # None if not args.full_fp16 else weight_dtype, 141 | # ) 142 | # b_size = encoder_hidden_states1.shape[0] 143 | # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 144 | # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 145 | # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 146 | # print("text encoder outputs verified") 147 | 148 | return encoder_hidden_states1, encoder_hidden_states2, pool2 149 | 150 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 151 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 152 | 153 | # get size embeddings 154 | orig_size = batch["original_sizes_hw"] 155 | crop_size = batch["crop_top_lefts"] 156 | target_size = batch["target_sizes_hw"] 157 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 158 | 159 | # concat embeddings 160 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds 161 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 162 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 163 | 164 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 165 | return noise_pred 166 | 167 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, 
text_encoder, unet): 168 | sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) 169 | 170 | 171 | def setup_parser() -> argparse.ArgumentParser: 172 | parser = train_network.setup_parser() 173 | sdxl_train_util.add_sdxl_training_arguments(parser) 174 | return parser 175 | 176 | 177 | if __name__ == "__main__": 178 | parser = setup_parser() 179 | 180 | args = parser.parse_args() 181 | args = train_util.read_config_from_file(args, parser) 182 | 183 | trainer = SdxlNetworkTrainer() 184 | trainer.train(args) 185 | -------------------------------------------------------------------------------- /sdxl_train_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import regex 5 | import torch 6 | from library.ipex_interop import init_ipex 7 | 8 | init_ipex() 9 | import open_clip 10 | from library import sdxl_model_util, sdxl_train_util, train_util 11 | 12 | import train_textual_inversion 13 | 14 | 15 | class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): 16 | def __init__(self): 17 | super().__init__() 18 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 19 | self.is_sdxl = True 20 | 21 | def assert_extra_args(self, args, train_dataset_group): 22 | super().assert_extra_args(args, train_dataset_group) 23 | sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) 24 | 25 | train_dataset_group.verify_bucket_reso_steps(32) 26 | 27 | def load_target_model(self, args, weight_dtype, accelerator): 28 | ( 29 | load_stable_diffusion_format, 30 | text_encoder1, 31 | text_encoder2, 32 | vae, 33 | unet, 34 | logit_scale, 35 | ckpt_info, 36 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) 37 | 38 | self.load_stable_diffusion_format = load_stable_diffusion_format 39 | self.logit_scale = logit_scale 40 | self.ckpt_info = ckpt_info 41 | 42 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet 43 | 44 | def load_tokenizer(self, args): 45 | tokenizer = sdxl_train_util.load_tokenizers(args) 46 | return tokenizer 47 | 48 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 49 | input_ids1 = batch["input_ids"] 50 | input_ids2 = batch["input_ids2"] 51 | with torch.enable_grad(): 52 | input_ids1 = input_ids1.to(accelerator.device) 53 | input_ids2 = input_ids2.to(accelerator.device) 54 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 55 | args.max_token_length, 56 | input_ids1, 57 | input_ids2, 58 | tokenizers[0], 59 | tokenizers[1], 60 | text_encoders[0], 61 | text_encoders[1], 62 | None if not args.full_fp16 else weight_dtype, 63 | accelerator=accelerator, 64 | ) 65 | return encoder_hidden_states1, encoder_hidden_states2, pool2 66 | 67 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 68 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 69 | 70 | # get size embeddings 71 | orig_size = batch["original_sizes_hw"] 72 | crop_size = batch["crop_top_lefts"] 73 | target_size = batch["target_sizes_hw"] 74 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 75 | 76 | # concat embeddings 77 | encoder_hidden_states1, encoder_hidden_states2, pool2 = 
text_conds 78 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 79 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 80 | 81 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 82 | return noise_pred 83 | 84 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): 85 | sdxl_train_util.sample_images( 86 | accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement 87 | ) 88 | 89 | def save_weights(self, file, updated_embs, save_dtype, metadata): 90 | state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} 91 | 92 | if save_dtype is not None: 93 | for key in list(state_dict.keys()): 94 | v = state_dict[key] 95 | v = v.detach().clone().to("cpu").to(save_dtype) 96 | state_dict[key] = v 97 | 98 | if os.path.splitext(file)[1] == ".safetensors": 99 | from safetensors.torch import save_file 100 | 101 | save_file(state_dict, file, metadata) 102 | else: 103 | torch.save(state_dict, file) 104 | 105 | def load_weights(self, file): 106 | if os.path.splitext(file)[1] == ".safetensors": 107 | from safetensors.torch import load_file 108 | 109 | data = load_file(file) 110 | else: 111 | data = torch.load(file, map_location="cpu") 112 | 113 | emb_l = data.get("clip_l", None) # ViT-L text encoder 1 114 | emb_g = data.get("clip_g", None) # BiG-G text encoder 2 115 | 116 | assert ( 117 | emb_l is not None or emb_g is not None 118 | ), f"weight file does not contain weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" 119 | 120 | return [emb_l, emb_g] 121 | 122 | 123 | def setup_parser() -> argparse.ArgumentParser: 124 | parser = train_textual_inversion.setup_parser() 125 | # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching 126 | # sdxl_train_util.add_sdxl_training_arguments(parser) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = setup_parser() 132 | 133 | args = parser.parse_args() 134 | args = train_util.read_config_from_file(args, parser) 135 | 136 | trainer = SdxlTextualInversionTrainer() 137 | trainer.train(args) 138 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name = "library", packages = find_packages()) -------------------------------------------------------------------------------- /tools/cache_latents.py: -------------------------------------------------------------------------------- 1 | # latentsのdiskへの事前キャッシュを行う / cache latents to disk 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | 20 | 21 | def cache_to_disk(args: argparse.Namespace) -> None: 22 | train_util.prepare_dataset_args(args, True) 23 | 24 | # check cache latents arg 25 | assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" 26 | 27 | use_dreambooth_method = args.in_json is None 28 | 29 | if args.seed is not None:
30 | set_seed(args.seed) # 乱数系列を初期化する 31 | 32 | # tokenizerを準備する:datasetを動かすために必要 33 | if args.sdxl: 34 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 35 | tokenizers = [tokenizer1, tokenizer2] 36 | else: 37 | tokenizer = train_util.load_tokenizer(args) 38 | tokenizers = [tokenizer] 39 | 40 | # データセットを準備する 41 | if args.dataset_class is None: 42 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 43 | if args.dataset_config is not None: 44 | print(f"Load dataset config from {args.dataset_config}") 45 | user_config = config_util.load_user_config(args.dataset_config) 46 | ignored = ["train_data_dir", "in_json"] 47 | if any(getattr(args, attr) is not None for attr in ignored): 48 | print( 49 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 50 | ", ".join(ignored) 51 | ) 52 | ) 53 | else: 54 | if use_dreambooth_method: 55 | print("Using DreamBooth method.") 56 | user_config = { 57 | "datasets": [ 58 | { 59 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 60 | args.train_data_dir, args.reg_data_dir 61 | ) 62 | } 63 | ] 64 | } 65 | else: 66 | print("Training with captions.") 67 | user_config = { 68 | "datasets": [ 69 | { 70 | "subsets": [ 71 | { 72 | "image_dir": args.train_data_dir, 73 | "metadata_file": args.in_json, 74 | } 75 | ] 76 | } 77 | ] 78 | } 79 | 80 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 81 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 82 | else: 83 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 84 | 85 | # datasetのcache_latentsを呼ばなければ、生の画像が返る 86 | 87 | current_epoch = Value("i", 0) 88 | current_step = Value("i", 0) 89 | ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None 90 | collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) 91 | 92 | # acceleratorを準備する 93 | print("prepare accelerator") 94 | accelerator = train_util.prepare_accelerator(args) 95 | 96 | # mixed precisionに対応した型を用意しておき適宜castする 97 | weight_dtype, _ = train_util.prepare_dtype(args) 98 | vae_dtype = torch.float32 if args.no_half_vae else weight_dtype 99 | 100 | # モデルを読み込む 101 | print("load model") 102 | if args.sdxl: 103 | (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 104 | else: 105 | _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) 106 | 107 | if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える 108 | vae.set_use_memory_efficient_attention_xformers(args.xformers) 109 | vae.to(accelerator.device, dtype=vae_dtype) 110 | vae.requires_grad_(False) 111 | vae.eval() 112 | 113 | # dataloaderを準備する 114 | train_dataset_group.set_caching_mode("latents") 115 | 116 | # DataLoaderのプロセス数:0はメインプロセスになる 117 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで 118 | 119 | train_dataloader = torch.utils.data.DataLoader( 120 | train_dataset_group, 121 | batch_size=1, 122 | shuffle=True, 123 | collate_fn=collator, 124 | num_workers=n_workers, 125 | persistent_workers=args.persistent_data_loader_workers, 126 | ) 127 | 128 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 129 | train_dataloader = accelerator.prepare(train_dataloader) 130 | 131 | # データ取得のためのループ 132 | for batch in tqdm(train_dataloader): 133 | b_size = len(batch["images"]) 134 | vae_batch_size = b_size if 
args.vae_batch_size is None else args.vae_batch_size 135 | flip_aug = batch["flip_aug"] 136 | random_crop = batch["random_crop"] 137 | bucket_reso = batch["bucket_reso"] 138 | 139 | # バッチを分割して処理する 140 | for i in range(0, b_size, vae_batch_size): 141 | images = batch["images"][i : i + vae_batch_size] 142 | absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] 143 | resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] 144 | 145 | image_infos = [] 146 | for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): 147 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 148 | image_info.image = image 149 | image_info.bucket_reso = bucket_reso 150 | image_info.resized_size = resized_size 151 | image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" 152 | 153 | if args.skip_existing: 154 | if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): 155 | print(f"Skipping {image_info.latents_npz} because it already exists.") 156 | continue 157 | 158 | image_infos.append(image_info) 159 | 160 | if len(image_infos) > 0: 161 | train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) 162 | 163 | accelerator.wait_for_everyone() 164 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 165 | 166 | 167 | def setup_parser() -> argparse.ArgumentParser: 168 | parser = argparse.ArgumentParser() 169 | 170 | train_util.add_sd_models_arguments(parser) 171 | train_util.add_training_arguments(parser, True) 172 | train_util.add_dataset_arguments(parser, True, True, True) 173 | config_util.add_config_arguments(parser) 174 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 175 | parser.add_argument( 176 | "--no_half_vae", 177 | action="store_true", 178 | help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", 179 | ) 180 | parser.add_argument( 181 | "--skip_existing", 182 | action="store_true", 183 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 184 | ) 185 | return parser 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = setup_parser() 190 | 191 | args = parser.parse_args() 192 | args = train_util.read_config_from_file(args, parser) 193 | 194 | cache_to_disk(args) 195 | -------------------------------------------------------------------------------- /tools/cache_text_encoder_outputs.py: -------------------------------------------------------------------------------- 1 | # text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | 20 | 21 | def cache_to_disk(args: argparse.Namespace) -> None: 22 | train_util.prepare_dataset_args(args, True) 23 | 24 | # check cache arg 25 | assert ( 26 | args.cache_text_encoder_outputs_to_disk 27 | ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" 28 | 29 | # できるだけ準備はしておくが今のところSDXLのみしか動かない 30 | 
assert ( 31 | args.sdxl 32 | ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" 33 | 34 | use_dreambooth_method = args.in_json is None 35 | 36 | if args.seed is not None: 37 | set_seed(args.seed) # 乱数系列を初期化する 38 | 39 | # tokenizerを準備する:datasetを動かすために必要 40 | if args.sdxl: 41 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 42 | tokenizers = [tokenizer1, tokenizer2] 43 | else: 44 | tokenizer = train_util.load_tokenizer(args) 45 | tokenizers = [tokenizer] 46 | 47 | # データセットを準備する 48 | if args.dataset_class is None: 49 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 50 | if args.dataset_config is not None: 51 | print(f"Load dataset config from {args.dataset_config}") 52 | user_config = config_util.load_user_config(args.dataset_config) 53 | ignored = ["train_data_dir", "in_json"] 54 | if any(getattr(args, attr) is not None for attr in ignored): 55 | print( 56 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 57 | ", ".join(ignored) 58 | ) 59 | ) 60 | else: 61 | if use_dreambooth_method: 62 | print("Using DreamBooth method.") 63 | user_config = { 64 | "datasets": [ 65 | { 66 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 67 | args.train_data_dir, args.reg_data_dir 68 | ) 69 | } 70 | ] 71 | } 72 | else: 73 | print("Training with captions.") 74 | user_config = { 75 | "datasets": [ 76 | { 77 | "subsets": [ 78 | { 79 | "image_dir": args.train_data_dir, 80 | "metadata_file": args.in_json, 81 | } 82 | ] 83 | } 84 | ] 85 | } 86 | 87 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 88 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 89 | else: 90 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 91 | 92 | current_epoch = Value("i", 0) 93 | current_step = Value("i", 0) 94 | ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None 95 | collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) 96 | 97 | # acceleratorを準備する 98 | print("prepare accelerator") 99 | accelerator = train_util.prepare_accelerator(args) 100 | 101 | # mixed precisionに対応した型を用意しておき適宜castする 102 | weight_dtype, _ = train_util.prepare_dtype(args) 103 | 104 | # モデルを読み込む 105 | print("load model") 106 | if args.sdxl: 107 | (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 108 | text_encoders = [text_encoder1, text_encoder2] 109 | else: 110 | text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) 111 | text_encoders = [text_encoder1] 112 | 113 | for text_encoder in text_encoders: 114 | text_encoder.to(accelerator.device, dtype=weight_dtype) 115 | text_encoder.requires_grad_(False) 116 | text_encoder.eval() 117 | 118 | # dataloaderを準備する 119 | train_dataset_group.set_caching_mode("text") 120 | 121 | # DataLoaderのプロセス数:0はメインプロセスになる 122 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで 123 | 124 | train_dataloader = torch.utils.data.DataLoader( 125 | train_dataset_group, 126 | batch_size=1, 127 | shuffle=True, 128 | collate_fn=collator, 129 | num_workers=n_workers, 130 | persistent_workers=args.persistent_data_loader_workers, 131 | ) 132 | 133 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 134 | train_dataloader = 
accelerator.prepare(train_dataloader) 135 | 136 | # データ取得のためのループ 137 | for batch in tqdm(train_dataloader): 138 | absolute_paths = batch["absolute_paths"] 139 | input_ids1_list = batch["input_ids1_list"] 140 | input_ids2_list = batch["input_ids2_list"] 141 | 142 | image_infos = [] 143 | for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): 144 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 145 | image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX 146 | image_info 147 | 148 | if args.skip_existing: 149 | if os.path.exists(image_info.text_encoder_outputs_npz): 150 | print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") 151 | continue 152 | 153 | image_info.input_ids1 = input_ids1 154 | image_info.input_ids2 = input_ids2 155 | image_infos.append(image_info) 156 | 157 | if len(image_infos) > 0: 158 | b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) 159 | b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) 160 | train_util.cache_batch_text_encoder_outputs( 161 | image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype 162 | ) 163 | 164 | accelerator.wait_for_everyone() 165 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 166 | 167 | 168 | def setup_parser() -> argparse.ArgumentParser: 169 | parser = argparse.ArgumentParser() 170 | 171 | train_util.add_sd_models_arguments(parser) 172 | train_util.add_training_arguments(parser, True) 173 | train_util.add_dataset_arguments(parser, True, True, True) 174 | config_util.add_config_arguments(parser) 175 | sdxl_train_util.add_sdxl_training_arguments(parser) 176 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 177 | parser.add_argument( 178 | "--skip_existing", 179 | action="store_true", 180 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 181 | ) 182 | return parser 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = setup_parser() 187 | 188 | args = parser.parse_args() 189 | args = train_util.read_config_from_file(args, parser) 190 | 191 | cache_to_disk(args) 192 | -------------------------------------------------------------------------------- /tools/canny.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | 4 | 5 | def canny(args): 6 | img = cv2.imread(args.input) 7 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 8 | 9 | canny_img = cv2.Canny(img, args.thres1, args.thres2) 10 | # canny_img = 255 - canny_img 11 | 12 | cv2.imwrite(args.output, canny_img) 13 | print("done!") 14 | 15 | 16 | def setup_parser() -> argparse.ArgumentParser: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--input", type=str, default=None, help="input path") 19 | parser.add_argument("--output", type=str, default=None, help="output path") 20 | parser.add_argument("--thres1", type=int, default=32, help="thres1") 21 | parser.add_argument("--thres2", type=int, default=224, help="thres2") 22 | 23 | return parser 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = setup_parser() 28 | 29 | args = parser.parse_args() 30 | canny(args) 31 | 
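Both caching tools above follow the same on-disk convention: tools/cache_latents.py writes each training image's latents to a sibling `<image basename>.npz` file, and tools/cache_text_encoder_outputs.py does the same using the suffix from `train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX`; `--skip_existing` checks for those files (the latents script additionally validates them via `train_util.is_disk_cached_latents_is_expected`). The sketch below is not part of the repository; it is a minimal way, under those assumptions, to list which images in a folder still lack a latent cache, reusing the `.npz` naming rule from cache_latents.py. The folder name, function name, and extension list are illustrative.

```python
# Minimal sketch (not part of the repository): list images that still lack a latent
# cache, using the same "<image basename>.npz" rule as tools/cache_latents.py above.
import glob
import os

IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")  # illustrative extension list


def find_uncached_images(image_dir: str) -> list[str]:
    uncached = []
    for path in sorted(glob.glob(os.path.join(image_dir, "*"))):
        if not path.lower().endswith(IMAGE_EXTS):
            continue
        npz_path = os.path.splitext(path)[0] + ".npz"  # same naming rule as cache_latents.py
        if not os.path.exists(npz_path):
            uncached.append(path)
    return uncached


if __name__ == "__main__":
    missing = find_uncached_images("train_data")  # hypothetical folder name
    print(f"{len(missing)} images still need latent caching")
```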
-------------------------------------------------------------------------------- /tools/convert_diffusers20_original_sd.py: -------------------------------------------------------------------------------- 1 | # convert Diffusers v1.x/v2.0 model to original Stable Diffusion 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from diffusers import StableDiffusionPipeline 7 | 8 | import library.model_util as model_util 9 | 10 | 11 | def convert(args): 12 | # 引数を確認する 13 | load_dtype = torch.float16 if args.fp16 else None 14 | 15 | save_dtype = None 16 | if args.fp16 or args.save_precision_as == "fp16": 17 | save_dtype = torch.float16 18 | elif args.bf16 or args.save_precision_as == "bf16": 19 | save_dtype = torch.bfloat16 20 | elif args.float or args.save_precision_as == "float": 21 | save_dtype = torch.float 22 | 23 | is_load_ckpt = os.path.isfile(args.model_to_load) 24 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 25 | 26 | assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" 27 | # assert ( 28 | # is_save_ckpt or args.reference_model is not None 29 | # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" 30 | 31 | # モデルを読み込む 32 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 33 | print(f"loading {msg}: {args.model_to_load}") 34 | 35 | if is_load_ckpt: 36 | v2_model = args.v2 37 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( 38 | v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection 39 | ) 40 | else: 41 | pipe = StableDiffusionPipeline.from_pretrained( 42 | args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant 43 | ) 44 | text_encoder = pipe.text_encoder 45 | vae = pipe.vae 46 | unet = pipe.unet 47 | 48 | if args.v1 == args.v2: 49 | # 自動判定する 50 | v2_model = unet.config.cross_attention_dim == 1024 51 | print("checking model version: model is " + ("v2" if v2_model else "v1")) 52 | else: 53 | v2_model = not args.v1 54 | 55 | # 変換して保存する 56 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 57 | print(f"converting and saving as {msg}: {args.model_to_save}") 58 | 59 | if is_save_ckpt: 60 | original_model = args.model_to_load if is_load_ckpt else None 61 | key_count = model_util.save_stable_diffusion_checkpoint( 62 | v2_model, 63 | args.model_to_save, 64 | text_encoder, 65 | unet, 66 | original_model, 67 | args.epoch, 68 | args.global_step, 69 | None if args.metadata is None else eval(args.metadata), 70 | save_dtype=save_dtype, 71 | vae=vae, 72 | ) 73 | print(f"model saved. 
total converted state_dict keys: {key_count}") 74 | else: 75 | print( 76 | f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" 77 | ) 78 | model_util.save_diffusers_checkpoint( 79 | v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors 80 | ) 81 | print("model saved.") 82 | 83 | 84 | def setup_parser() -> argparse.ArgumentParser: 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument( 87 | "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" 88 | ) 89 | parser.add_argument( 90 | "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" 91 | ) 92 | parser.add_argument( 93 | "--unet_use_linear_projection", 94 | action="store_true", 95 | help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)", 96 | ) 97 | parser.add_argument( 98 | "--fp16", 99 | action="store_true", 100 | help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", 101 | ) 102 | parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") 103 | parser.add_argument( 104 | "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" 105 | ) 106 | parser.add_argument( 107 | "--save_precision_as", 108 | type=str, 109 | default="no", 110 | choices=["fp16", "bf16", "float"], 111 | help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", 112 | ) 113 | parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") 114 | parser.add_argument( 115 | "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" 116 | ) 117 | parser.add_argument( 118 | "--metadata", 119 | type=str, 120 | default=None, 121 | help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', 122 | ) 123 | parser.add_argument( 124 | "--variant", 125 | type=str, 126 | default=None, 127 | help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. 
Example: fp16", 128 | ) 129 | parser.add_argument( 130 | "--reference_model", 131 | type=str, 132 | default=None, 133 | help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`", 134 | ) 135 | parser.add_argument( 136 | "--use_safetensors", 137 | action="store_true", 138 | help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", 139 | ) 140 | 141 | parser.add_argument( 142 | "model_to_load", 143 | type=str, 144 | default=None, 145 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", 146 | ) 147 | parser.add_argument( 148 | "model_to_save", 149 | type=str, 150 | default=None, 151 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", 152 | ) 153 | return parser 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = setup_parser() 158 | 159 | args = parser.parse_args() 160 | convert(args) 161 | -------------------------------------------------------------------------------- /tools/detect_face_rotate.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | # 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す 5 | 6 | # v2: extract max face if multiple faces are found 7 | # v3: add crop_ratio option 8 | # v4: add multiple faces extraction and min/max size 9 | 10 | import argparse 11 | import math 12 | import cv2 13 | import glob 14 | import os 15 | from anime_face_detector import create_detector 16 | from tqdm import tqdm 17 | import numpy as np 18 | 19 | KP_REYE = 11 20 | KP_LEYE = 19 21 | 22 | SCORE_THRES = 0.90 23 | 24 | 25 | def detect_faces(detector, image, min_size): 26 | preds = detector(image) # bgr 27 | # print(len(preds)) 28 | 29 | faces = [] 30 | for pred in preds: 31 | bb = pred['bbox'] 32 | score = bb[-1] 33 | if score < SCORE_THRES: 34 | continue 35 | 36 | left, top, right, bottom = bb[:4] 37 | cx = int((left + right) / 2) 38 | cy = int((top + bottom) / 2) 39 | fw = int(right - left) 40 | fh = int(bottom - top) 41 | 42 | lex, ley = pred['keypoints'][KP_LEYE, 0:2] 43 | rex, rey = pred['keypoints'][KP_REYE, 0:2] 44 | angle = math.atan2(ley - rey, lex - rex) 45 | angle = angle / math.pi * 180 46 | 47 | faces.append((cx, cy, fw, fh, angle)) 48 | 49 | faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 50 | return faces 51 | 52 | 53 | def rotate_image(image, angle, cx, cy): 54 | h, w = image.shape[0:2] 55 | rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) 56 | 57 | # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 58 | # nh = max(h, int(w * math.sin(angle))) 59 | # nw = max(w, int(h * math.sin(angle))) 60 | # if nh > h or nw > w: 61 | # pad_y = nh - h 62 | # pad_t = pad_y // 2 63 | # pad_x = nw - w 64 | # pad_l = pad_x // 2 65 | # m = np.array([[0, 0, pad_l], 66 | # [0, 0, pad_t]]) 67 | # rot_mat = rot_mat + m 68 | # h, w = nh, nw 69 | # cx += pad_l 70 | # cy += pad_t 71 | 72 | result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) 73 | return result, cx, cy 74 | 75 | 76 | def 
process(args): 77 | assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" 78 | assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" 79 | 80 | # アニメ顔検出モデルを読み込む 81 | print("loading face detector.") 82 | detector = create_detector('yolov3') 83 | 84 | # cropの引数を解析する 85 | if args.crop_size is None: 86 | crop_width = crop_height = None 87 | else: 88 | tokens = args.crop_size.split(',') 89 | assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" 90 | crop_width, crop_height = [int(t) for t in tokens] 91 | 92 | if args.crop_ratio is None: 93 | crop_h_ratio = crop_v_ratio = None 94 | else: 95 | tokens = args.crop_ratio.split(',') 96 | assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" 97 | crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] 98 | 99 | # 画像を処理する 100 | print("processing.") 101 | output_extension = ".png" 102 | 103 | os.makedirs(args.dst_dir, exist_ok=True) 104 | paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ 105 | glob.glob(os.path.join(args.src_dir, "*.webp")) 106 | for path in tqdm(paths): 107 | basename = os.path.splitext(os.path.basename(path))[0] 108 | 109 | # image = cv2.imread(path) # 日本語ファイル名でエラーになる 110 | image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) 111 | if len(image.shape) == 2: 112 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 113 | if image.shape[2] == 4: 114 | print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") 115 | image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい 116 | 117 | h, w = image.shape[:2] 118 | 119 | faces = detect_faces(detector, image, args.multiple_faces) 120 | for i, face in enumerate(faces): 121 | cx, cy, fw, fh, angle = face 122 | face_size = max(fw, fh) 123 | if args.min_size is not None and face_size < args.min_size: 124 | continue 125 | if args.max_size is not None and face_size >= args.max_size: 126 | continue 127 | face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" 128 | 129 | # オプション指定があれば回転する 130 | face_img = image 131 | if args.rotate: 132 | face_img, cx, cy = rotate_image(face_img, angle, cx, cy) 133 | 134 | # オプション指定があれば顔を中心に切り出す 135 | if crop_width is not None or crop_h_ratio is not None: 136 | cur_crop_width, cur_crop_height = crop_width, crop_height 137 | if crop_h_ratio is not None: 138 | cur_crop_width = int(face_size * crop_h_ratio + .5) 139 | cur_crop_height = int(face_size * crop_v_ratio + .5) 140 | 141 | # リサイズを必要なら行う 142 | scale = 1.0 143 | if args.resize_face_size is not None: 144 | # 顔サイズを基準にリサイズする 145 | scale = args.resize_face_size / face_size 146 | if scale < cur_crop_width / w: 147 | print( 148 | f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 149 | scale = cur_crop_width / w 150 | if scale < cur_crop_height / h: 151 | print( 152 | f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") 153 | scale = cur_crop_height / h 154 | elif crop_h_ratio is not None: 155 | # 倍率指定の時にはリサイズしない 156 | pass 157 | else: 158 | # 切り出しサイズ指定あり 159 | if w < cur_crop_width: 160 | print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") 161 | scale = cur_crop_width / w 162 | if h < cur_crop_height: 163 | print(f"image height 
too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") 164 | scale = cur_crop_height / h 165 | if args.resize_fit: 166 | scale = max(cur_crop_width / w, cur_crop_height / h) 167 | 168 | if scale != 1.0: 169 | w = int(w * scale + .5) 170 | h = int(h * scale + .5) 171 | face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) 172 | cx = int(cx * scale + .5) 173 | cy = int(cy * scale + .5) 174 | fw = int(fw * scale + .5) 175 | fh = int(fh * scale + .5) 176 | 177 | cur_crop_width = min(cur_crop_width, face_img.shape[1]) 178 | cur_crop_height = min(cur_crop_height, face_img.shape[0]) 179 | 180 | x = cx - cur_crop_width // 2 181 | cx = cur_crop_width // 2 182 | if x < 0: 183 | cx = cx + x 184 | x = 0 185 | elif x + cur_crop_width > w: 186 | cx = cx + (x + cur_crop_width - w) 187 | x = w - cur_crop_width 188 | face_img = face_img[:, x:x+cur_crop_width] 189 | 190 | y = cy - cur_crop_height // 2 191 | cy = cur_crop_height // 2 192 | if y < 0: 193 | cy = cy + y 194 | y = 0 195 | elif y + cur_crop_height > h: 196 | cy = cy + (y + cur_crop_height - h) 197 | y = h - cur_crop_height 198 | face_img = face_img[y:y + cur_crop_height] 199 | 200 | # # debug 201 | # print(path, cx, cy, angle) 202 | # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) 203 | # cv2.imshow("image", crp) 204 | # if cv2.waitKey() == 27: 205 | # break 206 | # cv2.destroyAllWindows() 207 | 208 | # debug 209 | if args.debug: 210 | cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) 211 | 212 | _, buf = cv2.imencode(output_extension, face_img) 213 | with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: 214 | buf.tofile(f) 215 | 216 | 217 | def setup_parser() -> argparse.ArgumentParser: 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") 220 | parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") 221 | parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") 222 | parser.add_argument("--resize_fit", action="store_true", 223 | help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") 224 | parser.add_argument("--resize_face_size", type=int, default=None, 225 | help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") 226 | parser.add_argument("--crop_size", type=str, default=None, 227 | help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") 228 | parser.add_argument("--crop_ratio", type=str, default=None, 229 | help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") 230 | parser.add_argument("--min_size", type=int, default=None, 231 | help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") 232 | parser.add_argument("--max_size", type=int, default=None, 233 | help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") 234 | parser.add_argument("--multiple_faces", action="store_true", 235 | help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") 236 | parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") 237 | 238 | return parser 239 | 240 | 241 | if __name__ == '__main__': 242 | parser = setup_parser() 243 | 244 | args = parser.parse_args() 245 | 246 | process(args) 247 | 
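In tools/detect_face_rotate.py above, the crop step centers a `cur_crop_width` × `cur_crop_height` window on the detected face and then shifts the window back inside the image when it would overflow, updating the stored face center to match. The snippet below is only an illustrative restatement of that clamping for a single axis; the function name and example numbers are ours, not the script's.

```python
# Illustrative restatement (not used by the script) of the face-centered crop clamping
# in tools/detect_face_rotate.py: center a window of size `crop` on coordinate `c`,
# then shift it so it stays inside [0, full). Returns the window start and the face
# center's new coordinate inside the window.
def clamp_window(c: int, crop: int, full: int) -> tuple[int, int]:
    start = c - crop // 2
    start = max(0, min(start, full - crop))  # keep the window inside the image
    return start, c - start


# Example: a face centered at x=900 in a 1000px-wide image with a 512px crop gives
# start=488, so the crop spans [488, 1000) and the face sits at x=412 inside it.
print(clamp_window(900, 512, 1000))  # (488, 412)
```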
-------------------------------------------------------------------------------- /tools/merge_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from safetensors import safe_open 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | 9 | 10 | def is_unet_key(key): 11 | # VAE or TextEncoder, the last one is for SDXL 12 | return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key) 13 | 14 | 15 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 16 | ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), 17 | ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), 18 | ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), 19 | ] 20 | 21 | 22 | # support for models with different text encoder keys 23 | def replace_text_encoder_key(key): 24 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 25 | if key.startswith(rep_from): 26 | return True, rep_to + key[len(rep_from) :] 27 | return False, key 28 | 29 | 30 | def merge(args): 31 | if args.precision == "fp16": 32 | dtype = torch.float16 33 | elif args.precision == "bf16": 34 | dtype = torch.bfloat16 35 | else: 36 | dtype = torch.float 37 | 38 | if args.saving_precision == "fp16": 39 | save_dtype = torch.float16 40 | elif args.saving_precision == "bf16": 41 | save_dtype = torch.bfloat16 42 | else: 43 | save_dtype = torch.float 44 | 45 | # check if all models are safetensors 46 | for model in args.models: 47 | if not model.endswith("safetensors"): 48 | print(f"Model {model} is not a safetensors model") 49 | exit() 50 | if not os.path.isfile(model): 51 | print(f"Model {model} does not exist") 52 | exit() 53 | 54 | assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" 55 | 56 | # load and merge 57 | ratio = 1.0 / len(args.models) # default 58 | supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later 59 | 60 | merged_sd = None 61 | first_model_keys = set() # check missing keys in other models 62 | for i, model in enumerate(args.models): 63 | if args.ratios is not None: 64 | ratio = args.ratios[i] 65 | 66 | if merged_sd is None: 67 | # load first model 68 | print(f"Loading model {model}, ratio = {ratio}...") 69 | merged_sd = {} 70 | with safe_open(model, framework="pt", device=args.device) as f: 71 | for key in tqdm(f.keys()): 72 | value = f.get_tensor(key) 73 | _, key = replace_text_encoder_key(key) 74 | 75 | first_model_keys.add(key) 76 | 77 | if not is_unet_key(key) and args.unet_only: 78 | supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder 79 | continue 80 | 81 | value = ratio * value.to(dtype) # first model's value * ratio 82 | merged_sd[key] = value 83 | 84 | print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) 85 | continue 86 | 87 | # load other models 88 | print(f"Loading model {model}, ratio = {ratio}...") 89 | 90 | with safe_open(model, framework="pt", device=args.device) as f: 91 | model_keys = f.keys() 92 | for key in tqdm(model_keys): 93 | _, new_key = replace_text_encoder_key(key) 94 | if new_key not in merged_sd: 95 | if args.show_skipped and new_key not in first_model_keys: 96 | print(f"Skip: {new_key}") 97 | continue 98 | 99 | value = f.get_tensor(key) 100 | merged_sd[new_key] = 
merged_sd[new_key] + ratio * value.to(dtype) 101 | 102 | # enumerate keys not in this model 103 | model_keys = set(model_keys) 104 | for key in merged_sd.keys(): 105 | if key in model_keys: 106 | continue 107 | print(f"Key {key} not in model {model}, use first model's value") 108 | if key in supplementary_key_ratios: 109 | supplementary_key_ratios[key] += ratio 110 | else: 111 | supplementary_key_ratios[key] = ratio 112 | 113 | # add supplementary keys' value (including VAE and TextEncoder) 114 | if len(supplementary_key_ratios) > 0: 115 | print("add first model's value") 116 | with safe_open(args.models[0], framework="pt", device=args.device) as f: 117 | for key in tqdm(f.keys()): 118 | _, new_key = replace_text_encoder_key(key) 119 | if new_key not in supplementary_key_ratios: 120 | continue 121 | 122 | if is_unet_key(new_key): # not VAE or TextEncoder 123 | print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") 124 | 125 | value = f.get_tensor(key) # original key 126 | 127 | if new_key not in merged_sd: 128 | merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype) 129 | else: 130 | merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype) 131 | 132 | # save 133 | output_file = args.output 134 | if not output_file.endswith(".safetensors"): 135 | output_file = output_file + ".safetensors" 136 | 137 | print(f"Saving to {output_file}...") 138 | 139 | # convert to save_dtype 140 | for k in merged_sd.keys(): 141 | merged_sd[k] = merged_sd[k].to(save_dtype) 142 | 143 | save_file(merged_sd, output_file) 144 | 145 | print("Done!") 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = argparse.ArgumentParser(description="Merge models") 150 | parser.add_argument("--models", nargs="+", type=str, help="Models to merge") 151 | parser.add_argument("--output", type=str, help="Output model") 152 | parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0") 153 | parser.add_argument("--unet_only", action="store_true", help="Only merge unet") 154 | parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu") 155 | parser.add_argument( 156 | "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float" 157 | ) 158 | parser.add_argument( 159 | "--saving_precision", 160 | type=str, 161 | default="float", 162 | choices=["float", "fp16", "bf16"], 163 | help="Saving precision, default is float", 164 | ) 165 | parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)") 166 | 167 | args = parser.parse_args() 168 | merge(args) 169 | -------------------------------------------------------------------------------- /tools/resize_images_to_resolution.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import argparse 5 | import shutil 6 | import math 7 | from PIL import Image 8 | import numpy as np 9 | 10 | 11 | def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): 12 | # Split the max_resolution string by "," and strip any whitespaces 13 | max_resolutions = [res.strip() for res in max_resolution.split(',')] 14 | 15 | # # Calculate max_pixels from max_resolution string 16 | # max_pixels = int(max_resolution.split("x")[0]) * 
int(max_resolution.split("x")[1]) 17 | 18 | # Create destination folder if it does not exist 19 | if not os.path.exists(dst_img_folder): 20 | os.makedirs(dst_img_folder) 21 | 22 | # Select interpolation method 23 | if interpolation == 'lanczos4': 24 | cv2_interpolation = cv2.INTER_LANCZOS4 25 | elif interpolation == 'cubic': 26 | cv2_interpolation = cv2.INTER_CUBIC 27 | else: 28 | cv2_interpolation = cv2.INTER_AREA 29 | 30 | # Iterate through all files in src_img_folder 31 | img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py 32 | for filename in os.listdir(src_img_folder): 33 | # Check if the image is png, jpg or webp etc... 34 | if not filename.endswith(img_exts): 35 | # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) 36 | shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) 37 | continue 38 | 39 | # Load image 40 | # img = cv2.imread(os.path.join(src_img_folder, filename)) 41 | image = Image.open(os.path.join(src_img_folder, filename)) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | img = np.array(image, np.uint8) 45 | 46 | base, _ = os.path.splitext(filename) 47 | for max_resolution in max_resolutions: 48 | # Calculate max_pixels from max_resolution string 49 | max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 50 | 51 | # Calculate current number of pixels 52 | current_pixels = img.shape[0] * img.shape[1] 53 | 54 | # Check if the image needs resizing 55 | if current_pixels > max_pixels: 56 | # Calculate scaling factor 57 | scale_factor = max_pixels / current_pixels 58 | 59 | # Calculate new dimensions 60 | new_height = int(img.shape[0] * math.sqrt(scale_factor)) 61 | new_width = int(img.shape[1] * math.sqrt(scale_factor)) 62 | 63 | # Resize image 64 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) 65 | else: 66 | new_height, new_width = img.shape[0:2] 67 | 68 | # Calculate the new height and width that are divisible by divisible_by (with/without resizing) 69 | new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by 70 | new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by 71 | 72 | # Center crop the image to the calculated dimensions 73 | y = int((img.shape[0] - new_height) / 2) 74 | x = int((img.shape[1] - new_width) / 2) 75 | img = img[y:y + new_height, x:x + new_width] 76 | 77 | # Split filename into base and extension 78 | new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') 79 | 80 | # Save resized image in dst_img_folder 81 | # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) 82 | image = Image.fromarray(img) 83 | image.save(os.path.join(dst_img_folder, new_filename), quality=100) 84 | 85 | proc = "Resized" if current_pixels > max_pixels else "Saved" 86 | print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") 87 | 88 | # If other files with same basename, copy them with resolution suffix 89 | if copy_associated_files: 90 | asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) 91 | for asoc_file in asoc_files: 92 | ext = os.path.splitext(asoc_file)[1] 93 | if ext in img_exts: 94 | continue 95 | for max_resolution in max_resolutions: 96 | new_asoc_file = base + '+' + max_resolution + ext 97 | print(f"Copy {asoc_file} as {new_asoc_file}") 98 | 
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) 99 | 100 | 101 | def setup_parser() -> argparse.ArgumentParser: 102 | parser = argparse.ArgumentParser( 103 | description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') 104 | parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') 105 | parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') 106 | parser.add_argument('--max_resolution', type=str, 107 | help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") 108 | parser.add_argument('--divisible_by', type=int, 109 | help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) 110 | parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], 111 | default='area', help='Interpolation method for resizing / リサイズ時の補完方法') 112 | parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') 113 | parser.add_argument('--copy_associated_files', action='store_true', 114 | help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') 115 | 116 | return parser 117 | 118 | 119 | def main(): 120 | parser = setup_parser() 121 | 122 | args = parser.parse_args() 123 | resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, 124 | args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /tools/show_metadata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from safetensors import safe_open 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--model", type=str, required=True) 7 | args = parser.parse_args() 8 | 9 | with safe_open(args.model, framework="pt") as f: 10 | metadata = f.metadata() 11 | 12 | if metadata is None: 13 | print("No metadata found") 14 | else: 15 | # metadata is json dict, but not pretty printed 16 | # sort by key and pretty print 17 | print(json.dumps(metadata, indent=4, sort_keys=True)) 18 | 19 | --------------------------------------------------------------------------------
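tools/show_metadata.py only reads the safetensors header: the dictionary it prints is whatever string-to-string mapping was passed as `metadata` when the file was saved (for example, `save_weights` in sdxl_train_textual_inversion.py passes one to `save_file`). A minimal round-trip sketch, not part of the repository and using an illustrative file name and metadata keys, could look like this:

```python
# Minimal round-trip sketch (not part of the repository): safetensors metadata is a flat
# str -> str mapping supplied at save time, which tools/show_metadata.py reads back via
# safe_open(...).metadata(). File name, tensor key, and metadata keys are illustrative.
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

tensors = {"example.weight": torch.zeros(2, 2)}
metadata = {"title": "demo model", "note": "written by the sketch"}  # values must be strings

save_file(tensors, "demo.safetensors", metadata=metadata)

with safe_open("demo.safetensors", framework="pt") as f:
    print(json.dumps(f.metadata(), indent=4, sort_keys=True))
```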