├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ └── typos.yml ├── .gitignore ├── LICENSE.md ├── README-ja.md ├── README.md ├── XTI_hijack.py ├── _typos.toml ├── bitsandbytes_windows ├── cextension.py ├── libbitsandbytes_cpu.dll ├── libbitsandbytes_cuda116.dll ├── libbitsandbytes_cuda118.dll └── main.py ├── docs ├── config_README-en.md ├── config_README-ja.md ├── fine_tune_README_ja.md ├── gen_img_README-ja.md ├── masked_loss_README-ja.md ├── masked_loss_README.md ├── train_README-ja.md ├── train_README-zh.md ├── train_SDXL-en.md ├── train_db_README-ja.md ├── train_db_README-zh.md ├── train_lllite_README-ja.md ├── train_lllite_README.md ├── train_network_README-ja.md ├── train_network_README-zh.md ├── train_ti_README-ja.md ├── wd14_tagger_README-en.md └── wd14_tagger_README-ja.md ├── fine_tune.py ├── finetune ├── blip │ ├── blip.py │ ├── med.py │ ├── med_config.json │ └── vit.py ├── clean_captions_and_tags.py ├── hypernetwork_nai.py ├── make_captions.py ├── make_captions_by_git.py ├── merge_captions_to_metadata.py ├── merge_dd_tags_to_metadata.py ├── prepare_buckets_latents.py └── tag_images_by_wd14_tagger.py ├── gen_img.py ├── gen_img_diffusers.py ├── library ├── __init__.py ├── adafactor_fused.py ├── attention_processors.py ├── config_util.py ├── custom_train_functions.py ├── deepspeed_utils.py ├── device_utils.py ├── huggingface_util.py ├── hypernetwork.py ├── ipex │ ├── __init__.py │ ├── attention.py │ ├── diffusers.py │ ├── gradscaler.py │ └── hijacks.py ├── lpw_stable_diffusion.py ├── model_util.py ├── original_unet.py ├── sai_model_spec.py ├── sdxl_lpw_stable_diffusion.py ├── sdxl_model_util.py ├── sdxl_original_unet.py ├── sdxl_train_util.py ├── slicing_vae.py ├── train_util.py └── utils.py ├── networks ├── check_lora_weights.py ├── control_net_lllite.py ├── control_net_lllite_for_train.py ├── dylora.py ├── extract_lora_from_dylora.py ├── extract_lora_from_models.py ├── lora.py ├── lora_diffusers.py ├── lora_fa.py ├── lora_interrogator.py ├── merge_lora.py ├── merge_lora_old.py ├── oft.py ├── resize_lora.py ├── sdxl_merge_lora.py └── svd_merge_lora.py ├── requirements.txt ├── sdxl_gen_img.py ├── sdxl_minimal_inference.py ├── sdxl_train.py ├── sdxl_train_control_net_lllite.py ├── sdxl_train_control_net_lllite_old.py ├── sdxl_train_network.py ├── sdxl_train_textual_inversion.py ├── setup.py ├── tools ├── cache_latents.py ├── cache_text_encoder_outputs.py ├── canny.py ├── convert_diffusers20_original_sd.py ├── detect_face_rotate.py ├── latent_upscaler.py ├── merge_models.py ├── original_control_net.py ├── resize_images_to_resolution.py └── show_metadata.py ├── train_controlnet.py ├── train_db.py ├── train_network.py ├── train_textual_inversion.py └── train_textual_inversion_XTI.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: kohya-ss 4 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | -------------------------------------------------------------------------------- /.github/workflows/typos.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # yamllint disable rule:line-length 3 | name: Typos 4 | 5 | on: # yamllint disable-line rule:truthy 6 | push: 7 | pull_request: 8 | types: 9 | - opened 10 | - synchronize 11 | - reopened 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: typos-action 21 | uses: crate-ci/typos@v1.24.3 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs 2 | __pycache__ 3 | wd14_tagger_model 4 | venv 5 | *.egg-info 6 | build 7 | .vscode 8 | wandb 9 | -------------------------------------------------------------------------------- /README-ja.md: -------------------------------------------------------------------------------- 1 | ## リポジトリについて 2 | Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 3 | 4 | [README in English](./README.md) ←更新情報はこちらにあります 5 | 6 | 開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。 7 | 8 | FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。 9 | 10 | GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。 11 | 12 | 以下のスクリプトがあります。 13 | 14 | * DreamBooth、U-NetおよびText Encoderの学習をサポート 15 | * fine-tuning、同上 16 | * LoRAの学習をサポート 17 | * 画像生成 18 | * モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換) 19 | 20 | ## 使用法について 21 | 22 | * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど 23 | * [データセット設定](./docs/config_README-ja.md) 24 | * [SDXL学習](./docs/train_SDXL-en.md) (英語版) 25 | * [DreamBoothの学習について](./docs/train_db_README-ja.md) 26 | * [fine-tuningのガイド](./docs/fine_tune_README_ja.md): 27 | * [LoRAの学習について](./docs/train_network_README-ja.md) 28 | * [Textual Inversionの学習について](./docs/train_ti_README-ja.md) 29 | * [画像生成スクリプト](./docs/gen_img_README-ja.md) 30 | * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad) 31 | 32 | ## Windowsでの動作に必要なプログラム 33 | 34 | Python 3.10.6およびGitが必要です。 35 | 36 | - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe 37 | - git: https://git-scm.com/download/win 38 | 39 | Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。 40 | 41 | PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。 42 | (venvに限らずスクリプトの実行が可能になりますので注意してください。) 43 | 44 | - PowerShellを管理者として開きます。 45 | - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。 46 | - 管理者のPowerShellを閉じます。 47 | 48 | ## Windows環境でのインストール 49 | 50 | スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。 51 | 52 | (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) 53 | 54 | PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。 55 | 56 | ```powershell 57 | git clone https://github.com/kohya-ss/sd-scripts.git 58 | cd sd-scripts 59 | 60 | python -m venv venv 61 | .\venv\Scripts\activate 62 | 63 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 64 | pip install --upgrade -r requirements.txt 65 | pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 66 | 67 | accelerate config 68 | ``` 69 | 70 | コマンドプロンプトでも同一です。 71 | 72 | 注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 73 | 74 | この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 75 | 76 | PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。 77 | 78 | accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) 79 | 80 | ```txt 81 | - This machine 82 | - No distributed training 83 | - NO 84 | - NO 85 | - NO 86 | - all 87 | - fp16 88 | ``` 89 | 90 | ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問( 91 | ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。) 92 | 93 | ## アップグレード 94 | 95 | 新しいリリースがあった場合、以下のコマンドで更新できます。 96 | 97 | ```powershell 98 | cd sd-scripts 99 | git pull 100 | .\venv\Scripts\activate 101 | pip install --use-pep517 --upgrade -r requirements.txt 102 | ``` 103 | 104 | コマンドが成功すれば新しいバージョンが使用できます。 105 | 106 | ## 謝意 107 | 108 | LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。 109 | 110 | Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。 111 | 112 | ## ライセンス 113 | 114 | スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。 115 | 116 | [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT 117 | 118 | [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT 119 | 120 | [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause 121 | 122 | ## その他の情報 123 | 124 | ### LoRAの名称について 125 | 126 | `train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。 127 | 128 | 1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます) 129 | 130 | Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA 131 | 132 | 2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます) 133 | 134 | 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA 135 | 136 | デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。 137 | 138 | 143 | 144 | ### 学習中のサンプル画像生成 145 | 146 | プロンプトファイルは例えば以下のようになります。 147 | 148 | ``` 149 | # prompt 1 150 | masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28 151 | 152 | # prompt 2 153 | masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40 154 | ``` 155 | 156 | `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 157 | 158 | * `--n` Negative prompt up to the next option. 159 | * `--w` Specifies the width of the generated image. 160 | * `--h` Specifies the height of the generated image. 161 | * `--d` Specifies the seed of the generated image. 162 | * `--l` Specifies the CFG scale of the generated image. 163 | * `--s` Specifies the number of steps in the generation. 164 | 165 | `( )` や `[ ]` などの重みづけも動作します。 166 | -------------------------------------------------------------------------------- /XTI_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from library.device_utils import init_ipex 3 | init_ipex() 4 | 5 | from typing import Union, List, Optional, Dict, Any, Tuple 6 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput 7 | 8 | from library.original_unet import SampleOutput 9 | 10 | 11 | def unet_forward_XTI( 12 | self, 13 | sample: torch.FloatTensor, 14 | timestep: Union[torch.Tensor, float, int], 15 | encoder_hidden_states: torch.Tensor, 16 | class_labels: Optional[torch.Tensor] = None, 17 | return_dict: bool = True, 18 | ) -> Union[Dict, Tuple]: 19 | r""" 20 | Args: 21 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 22 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 23 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 24 | return_dict (`bool`, *optional*, defaults to `True`): 25 | Whether or not to return a dict instead of a plain tuple. 26 | 27 | Returns: 28 | `SampleOutput` or `tuple`: 29 | `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 30 | """ 31 | # By default samples have to be AT least a multiple of the overall upsampling factor. 32 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 33 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 34 | # on the fly if necessary. 35 | # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある 36 | # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する 37 | # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い 38 | default_overall_up_factor = 2**self.num_upsamplers 39 | 40 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 41 | # 64で割り切れないときはupsamplerにサイズを伝える 42 | forward_upsample_size = False 43 | upsample_size = None 44 | 45 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 46 | # logger.info("Forward upsample size to force interpolation output size.") 47 | forward_upsample_size = True 48 | 49 | # 1. time 50 | timesteps = timestep 51 | timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 52 | 53 | t_emb = self.time_proj(timesteps) 54 | 55 | # timesteps does not contain any weights and will always return f32 tensors 56 | # but time_embedding might actually be running in fp16. so we need to cast here. 57 | # there might be better ways to encapsulate this. 58 | # timestepsは重みを含まないので常にfloat32のテンソルを返す 59 | # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある 60 | # time_projでキャストしておけばいいんじゃね? 61 | t_emb = t_emb.to(dtype=self.dtype) 62 | emb = self.time_embedding(t_emb) 63 | 64 | # 2. pre-process 65 | sample = self.conv_in(sample) 66 | 67 | # 3. down 68 | down_block_res_samples = (sample,) 69 | down_i = 0 70 | for downsample_block in self.down_blocks: 71 | # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 72 | # まあこちらのほうがわかりやすいかもしれない 73 | if downsample_block.has_cross_attention: 74 | sample, res_samples = downsample_block( 75 | hidden_states=sample, 76 | temb=emb, 77 | encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2], 78 | ) 79 | down_i += 2 80 | else: 81 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 82 | 83 | down_block_res_samples += res_samples 84 | 85 | # 4. mid 86 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) 87 | 88 | # 5. up 89 | up_i = 7 90 | for i, upsample_block in enumerate(self.up_blocks): 91 | is_final_block = i == len(self.up_blocks) - 1 92 | 93 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 94 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection 95 | 96 | # if we have not reached the final block and need to forward the upsample size, we do it here 97 | # 前述のように最後のブロック以外ではupsample_sizeを伝える 98 | if not is_final_block and forward_upsample_size: 99 | upsample_size = down_block_res_samples[-1].shape[2:] 100 | 101 | if upsample_block.has_cross_attention: 102 | sample = upsample_block( 103 | hidden_states=sample, 104 | temb=emb, 105 | res_hidden_states_tuple=res_samples, 106 | encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3], 107 | upsample_size=upsample_size, 108 | ) 109 | up_i += 3 110 | else: 111 | sample = upsample_block( 112 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 113 | ) 114 | 115 | # 6. post-process 116 | sample = self.conv_norm_out(sample) 117 | sample = self.conv_act(sample) 118 | sample = self.conv_out(sample) 119 | 120 | if not return_dict: 121 | return (sample,) 122 | 123 | return SampleOutput(sample=sample) 124 | 125 | 126 | def downblock_forward_XTI( 127 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None 128 | ): 129 | output_states = () 130 | i = 0 131 | 132 | for resnet, attn in zip(self.resnets, self.attentions): 133 | if self.training and self.gradient_checkpointing: 134 | 135 | def create_custom_forward(module, return_dict=None): 136 | def custom_forward(*inputs): 137 | if return_dict is not None: 138 | return module(*inputs, return_dict=return_dict) 139 | else: 140 | return module(*inputs) 141 | 142 | return custom_forward 143 | 144 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 145 | hidden_states = torch.utils.checkpoint.checkpoint( 146 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 147 | )[0] 148 | else: 149 | hidden_states = resnet(hidden_states, temb) 150 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 151 | 152 | output_states += (hidden_states,) 153 | i += 1 154 | 155 | if self.downsamplers is not None: 156 | for downsampler in self.downsamplers: 157 | hidden_states = downsampler(hidden_states) 158 | 159 | output_states += (hidden_states,) 160 | 161 | return hidden_states, output_states 162 | 163 | 164 | def upblock_forward_XTI( 165 | self, 166 | hidden_states, 167 | res_hidden_states_tuple, 168 | temb=None, 169 | encoder_hidden_states=None, 170 | upsample_size=None, 171 | ): 172 | i = 0 173 | for resnet, attn in zip(self.resnets, self.attentions): 174 | # pop res hidden states 175 | res_hidden_states = res_hidden_states_tuple[-1] 176 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 177 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 178 | 179 | if self.training and self.gradient_checkpointing: 180 | 181 | def create_custom_forward(module, return_dict=None): 182 | def custom_forward(*inputs): 183 | if return_dict is not None: 184 | return module(*inputs, return_dict=return_dict) 185 | else: 186 | return module(*inputs) 187 | 188 | return custom_forward 189 | 190 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 191 | hidden_states = torch.utils.checkpoint.checkpoint( 192 | create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] 193 | )[0] 194 | else: 195 | hidden_states = resnet(hidden_states, temb) 196 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample 197 | 198 | i += 1 199 | 200 | if self.upsamplers is not None: 201 | for upsampler in self.upsamplers: 202 | hidden_states = upsampler(hidden_states, upsample_size) 203 | 204 | return hidden_states 205 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | # Files for typos 2 | # Instruction: https://github.com/marketplace/actions/typos-action#getting-started 3 | 4 | [default.extend-identifiers] 5 | ddPn08="ddPn08" 6 | 7 | [default.extend-words] 8 | NIN="NIN" 9 | parms="parms" 10 | nin="nin" 11 | extention="extention" # Intentionally left 12 | nd="nd" 13 | shs="shs" 14 | sts="sts" 15 | scs="scs" 16 | cpc="cpc" 17 | coc="coc" 18 | cic="cic" 19 | msm="msm" 20 | usu="usu" 21 | ici="ici" 22 | lvl="lvl" 23 | dii="dii" 24 | muk="muk" 25 | ori="ori" 26 | hru="hru" 27 | rik="rik" 28 | koo="koo" 29 | yos="yos" 30 | wn="wn" 31 | hime="hime" 32 | 33 | 34 | [files] 35 | extend-exclude = ["_typos.toml", "venv"] 36 | -------------------------------------------------------------------------------- /bitsandbytes_windows/cextension.py: -------------------------------------------------------------------------------- 1 | import ctypes as ct 2 | from pathlib import Path 3 | from warnings import warn 4 | 5 | from .cuda_setup.main import evaluate_cuda_setup 6 | 7 | 8 | class CUDALibrary_Singleton(object): 9 | _instance = None 10 | 11 | def __init__(self): 12 | raise RuntimeError("Call get_instance() instead") 13 | 14 | def initialize(self): 15 | binary_name = evaluate_cuda_setup() 16 | package_dir = Path(__file__).parent 17 | binary_path = package_dir / binary_name 18 | 19 | if not binary_path.exists(): 20 | print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}") 21 | legacy_binary_name = "libbitsandbytes.so" 22 | print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") 23 | binary_path = package_dir / legacy_binary_name 24 | if not binary_path.exists(): 25 | print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!') 26 | print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') 27 | raise Exception('CUDA SETUP: Setup Failed!') 28 | # self.lib = ct.cdll.LoadLibrary(binary_path) 29 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 30 | else: 31 | print(f"CUDA SETUP: Loading binary {binary_path}...") 32 | # self.lib = ct.cdll.LoadLibrary(binary_path) 33 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$ 34 | 35 | @classmethod 36 | def get_instance(cls): 37 | if cls._instance is None: 38 | cls._instance = cls.__new__(cls) 39 | cls._instance.initialize() 40 | return cls._instance 41 | 42 | 43 | lib = CUDALibrary_Singleton.get_instance().lib 44 | try: 45 | lib.cadam32bit_g32 46 | lib.get_context.restype = ct.c_void_p 47 | lib.get_cusparse.restype = ct.c_void_p 48 | COMPILED_WITH_CUDA = True 49 | except AttributeError: 50 | warn( 51 | "The installed version of bitsandbytes was compiled without GPU support. " 52 | "8-bit optimizers and GPU quantization are unavailable." 53 | ) 54 | COMPILED_WITH_CUDA = False 55 | -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cpu.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sd-scripts/a21b6a917e8ca2d0392f5861da2dddb510e389ad/bitsandbytes_windows/libbitsandbytes_cpu.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cuda116.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sd-scripts/a21b6a917e8ca2d0392f5861da2dddb510e389ad/bitsandbytes_windows/libbitsandbytes_cuda116.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/libbitsandbytes_cuda118.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sd-scripts/a21b6a917e8ca2d0392f5861da2dddb510e389ad/bitsandbytes_windows/libbitsandbytes_cuda118.dll -------------------------------------------------------------------------------- /bitsandbytes_windows/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | extract factors the build is dependent on: 3 | [X] compute capability 4 | [ ] TODO: Q - What if we have multiple GPUs of different makes? 5 | - CUDA version 6 | - Software: 7 | - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) 8 | - CuBLAS-LT: full-build 8-bit optimizer 9 | - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) 10 | 11 | evaluation: 12 | - if paths faulty, return meaningful error 13 | - else: 14 | - determine CUDA version 15 | - determine capabilities 16 | - based on that set the default path 17 | """ 18 | 19 | import ctypes 20 | 21 | from .paths import determine_cuda_runtime_lib_path 22 | 23 | 24 | def check_cuda_result(cuda, result_val): 25 | # 3. Check for CUDA errors 26 | if result_val != 0: 27 | error_str = ctypes.c_char_p() 28 | cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) 29 | print(f"CUDA exception! Error code: {error_str.value.decode()}") 30 | 31 | def get_cuda_version(cuda, cudart_path): 32 | # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION 33 | try: 34 | cudart = ctypes.CDLL(cudart_path) 35 | except OSError: 36 | # TODO: shouldn't we error or at least warn here? 37 | print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') 38 | return None 39 | 40 | version = ctypes.c_int() 41 | check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version))) 42 | version = int(version.value) 43 | major = version//1000 44 | minor = (version-(major*1000))//10 45 | 46 | if major < 11: 47 | print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') 48 | 49 | return f'{major}{minor}' 50 | 51 | 52 | def get_cuda_lib_handle(): 53 | # 1. find libcuda.so library (GPU driver) (/usr/lib) 54 | try: 55 | cuda = ctypes.CDLL("libcuda.so") 56 | except OSError: 57 | # TODO: shouldn't we error or at least warn here? 58 | print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') 59 | return None 60 | check_cuda_result(cuda, cuda.cuInit(0)) 61 | 62 | return cuda 63 | 64 | 65 | def get_compute_capabilities(cuda): 66 | """ 67 | 1. find libcuda.so library (GPU driver) (/usr/lib) 68 | init_device -> init variables -> call function by reference 69 | 2. call extern C function to determine CC 70 | (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) 71 | 3. Check for CUDA errors 72 | https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api 73 | # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 74 | """ 75 | 76 | 77 | nGpus = ctypes.c_int() 78 | cc_major = ctypes.c_int() 79 | cc_minor = ctypes.c_int() 80 | 81 | device = ctypes.c_int() 82 | 83 | check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) 84 | ccs = [] 85 | for i in range(nGpus.value): 86 | check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) 87 | ref_major = ctypes.byref(cc_major) 88 | ref_minor = ctypes.byref(cc_minor) 89 | # 2. call extern C function to determine CC 90 | check_cuda_result( 91 | cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) 92 | ) 93 | ccs.append(f"{cc_major.value}.{cc_minor.value}") 94 | 95 | return ccs 96 | 97 | 98 | # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error 99 | def get_compute_capability(cuda): 100 | """ 101 | Extracts the highest compute capbility from all available GPUs, as compute 102 | capabilities are downwards compatible. If no GPUs are detected, it returns 103 | None. 104 | """ 105 | ccs = get_compute_capabilities(cuda) 106 | if ccs is not None: 107 | # TODO: handle different compute capabilities; for now, take the max 108 | return ccs[-1] 109 | return None 110 | 111 | 112 | def evaluate_cuda_setup(): 113 | print('') 114 | print('='*35 + 'BUG REPORT' + '='*35) 115 | print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') 116 | print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link') 117 | print('='*80) 118 | return "libbitsandbytes_cuda116.dll" # $$$ 119 | 120 | binary_name = "libbitsandbytes_cpu.so" 121 | #if not torch.cuda.is_available(): 122 | #print('No GPU detected. Loading CPU library...') 123 | #return binary_name 124 | 125 | cudart_path = determine_cuda_runtime_lib_path() 126 | if cudart_path is None: 127 | print( 128 | "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" 129 | ) 130 | return binary_name 131 | 132 | print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") 133 | cuda = get_cuda_lib_handle() 134 | cc = get_compute_capability(cuda) 135 | print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") 136 | cuda_version_string = get_cuda_version(cuda, cudart_path) 137 | 138 | 139 | if cc == '': 140 | print( 141 | "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." 142 | ) 143 | return binary_name 144 | 145 | # 7.5 is the minimum CC vor cublaslt 146 | has_cublaslt = cc in ["7.5", "8.0", "8.6"] 147 | 148 | # TODO: 149 | # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) 150 | # (2) Multiple CUDA versions installed 151 | 152 | # we use ls -l instead of nvcc to determine the cuda version 153 | # since most installations will have the libcudart.so installed, but not the compiler 154 | print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') 155 | 156 | def get_binary_name(): 157 | "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" 158 | bin_base_name = "libbitsandbytes_cuda" 159 | if has_cublaslt: 160 | return f"{bin_base_name}{cuda_version_string}.so" 161 | else: 162 | return f"{bin_base_name}{cuda_version_string}_nocublaslt.so" 163 | 164 | binary_name = get_binary_name() 165 | 166 | return binary_name 167 | -------------------------------------------------------------------------------- /docs/fine_tune_README_ja.md: -------------------------------------------------------------------------------- 1 | NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません) 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | # 概要 6 | 7 | Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。 8 | 9 | * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。 10 | * 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。 11 | * トークン長を75から225に拡張する。 12 | * BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。 13 | * Hypernetworkの学習にも対応する。 14 | * Stable Diffusion v2.0(baseおよび768/v)に対応。 15 | * VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。 16 | 17 | デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。 18 | 19 | # 追加機能について 20 | 21 | ## CLIPの出力の変更 22 | 23 | プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。 24 | 元のまま、最後の層の出力を用いることも可能です。 25 | 26 | ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。 27 | 28 | ## 正方形以外の解像度での学習 29 | 30 | Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。 31 | 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。 32 | 33 | 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 34 | 35 | ## トークン長の75から225への拡張 36 | 37 | Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。 38 | ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。 39 | 40 | ※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。 41 | 42 | ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。 43 | 44 | # 学習の手順 45 | 46 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 47 | 48 | ## データの準備 49 | 50 | [学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。 51 | 52 | ## 学習の実行 53 | たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。 54 | 55 | ``` 56 | accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 57 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 58 | --output_dir=<学習したモデルの出力先フォルダ> 59 | --output_name=<学習したモデル出力時のファイル名> 60 | --dataset_config=<データ準備で作成した.tomlファイル> 61 | --save_model_as=safetensors 62 | --learning_rate=5e-6 --max_train_steps=10000 63 | --use_8bit_adam --xformers --gradient_checkpointing 64 | --mixed_precision=fp16 65 | ``` 66 | 67 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 68 | 69 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 70 | 71 | `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 72 | 73 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 74 | 75 | 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 76 | 77 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 78 | 79 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 80 | 81 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 82 | 83 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。 84 | 85 | ### よく使われるオプションについて 86 | 87 | 以下の場合にはオプションに関するドキュメントを参照してください。 88 | 89 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 90 | - clip skipを2以上を前提としたモデルを学習する 91 | - 75トークンを超えたキャプションで学習する 92 | 93 | ### バッチサイズについて 94 | 95 | モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。 96 | 97 | ### 学習率について 98 | 99 | 1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。 100 | 101 | ### 以前の形式のデータセット指定をした場合のコマンドライン 102 | 103 | 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 104 | 105 | ``` 106 | accelerate launch --num_cpu_threads_per_process 1 fine_tune.py 107 | --pretrained_model_name_or_path=model.ckpt 108 | --in_json meta_lat.json 109 | --train_data_dir=train_data 110 | --output_dir=fine_tuned 111 | --shuffle_caption 112 | --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 113 | --use_8bit_adam --xformers --gradient_checkpointing 114 | --mixed_precision=bf16 115 | --save_every_n_epochs=4 116 | ``` 117 | 118 | 129 | 130 | # fine tuning特有のその他の主なオプション 131 | 132 | すべてのオプションについては別文書を参照してください。 133 | 134 | ## `train_text_encoder` 135 | Text Encoderも学習対象とします。メモリ使用量が若干増加します。 136 | 137 | 通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。 138 | 139 | ## `diffusers_xformers` 140 | スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。 141 | -------------------------------------------------------------------------------- /docs/masked_loss_README-ja.md: -------------------------------------------------------------------------------- 1 | ## マスクロスについて 2 | 3 | マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。 4 | たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。 5 | 6 | マスクロスのマスクには、二種類の指定方法があります。 7 | 8 | - マスク画像を用いる方法 9 | - 透明度(アルファチャネル)を使用する方法 10 | 11 | なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。 12 | 13 | ### マスク画像を用いる方法 14 | 15 | 学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 16 | 17 | - 学習画像 18 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) 19 | - マスク画像 20 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) 21 | 22 | ```.toml 23 | [[datasets.subsets]] 24 | image_dir = "/path/to/a_zundamon" 25 | caption_extension = ".txt" 26 | conditioning_data_dir = "/path/to/a_zundamon_mask" 27 | num_repeats = 8 28 | ``` 29 | 30 | マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 31 | 32 | DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 33 | 34 | ### 透明度(アルファチャネル)を使用する方法 35 | 36 | 学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 37 | 38 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) 39 | 40 | ※それぞれの画像は透過PNG 41 | 42 | 学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 43 | 44 | ```toml 45 | [[datasets.subsets]] 46 | image_dir = "/path/to/image/dir" 47 | caption_extension = ".txt" 48 | num_repeats = 8 49 | alpha_mask = true 50 | ``` 51 | 52 | ## 学習時の注意事項 53 | 54 | - 現時点では DreamBooth 方式の dataset のみ対応しています。 55 | - マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。 56 | - マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証) 57 | - `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。 58 | -------------------------------------------------------------------------------- /docs/masked_loss_README.md: -------------------------------------------------------------------------------- 1 | ## Masked Loss 2 | 3 | Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background. 4 | 5 | There are two ways to specify the mask for masked loss. 6 | 7 | - Using a mask image 8 | - Using transparency (alpha channel) of the image 9 | 10 | The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html). 11 | 12 | ### Using a mask image 13 | 14 | This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image. 15 | 16 | - Training image 17 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) 18 | - Mask image 19 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) 20 | 21 | ```.toml 22 | [[datasets.subsets]] 23 | image_dir = "/path/to/a_zundamon" 24 | caption_extension = ".txt" 25 | conditioning_data_dir = "/path/to/a_zundamon_mask" 26 | num_repeats = 8 27 | ``` 28 | 29 | The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently. 30 | 31 | Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details. 32 | 33 | ### Using transparency (alpha channel) of the image 34 | 35 | The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5). 36 | 37 | ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) 38 | 39 | ※Each image is a transparent PNG 40 | 41 | Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this. 42 | 43 | ```toml 44 | [[datasets.subsets]] 45 | image_dir = "/path/to/image/dir" 46 | caption_extension = ".txt" 47 | num_repeats = 8 48 | alpha_mask = true 49 | ``` 50 | 51 | ## Notes on training 52 | 53 | - At the moment, only the dataset in the DreamBooth method is supported. 54 | - The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary. 55 | - If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified) 56 | - In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched. 57 | -------------------------------------------------------------------------------- /docs/train_SDXL-en.md: -------------------------------------------------------------------------------- 1 | ## SDXL training 2 | 3 | The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL. 4 | 5 | ### Training scripts for SDXL 6 | 7 | - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. 8 | - `--full_bf16` option is added. Thanks to KohakuBlueleaf! 9 | - This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage. 10 | - The full bfloat16 training might be unstable. Please use it at your own risk. 11 | - The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`. 12 | - 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`. 13 | - `prepare_buckets_latents.py` now supports SDXL fine-tuning. 14 | 15 | - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. 16 | 17 | - Both scripts has following additional options: 18 | - `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. 19 | - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs. 20 | 21 | - `--weighted_captions` option is not supported yet for both scripts. 22 | 23 | - `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`. 24 | - `--cache_text_encoder_outputs` is not supported. 25 | - There are two options for captions: 26 | 1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens. 27 | 2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored. 28 | - See below for the format of the embeddings. 29 | 30 | - `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000. 31 | 32 | ### Utility scripts for SDXL 33 | 34 | - `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance. 35 | - The options are almost the same as `sdxl_train.py'. See the help message for the usage. 36 | - Please launch the script as follows: 37 | `accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...` 38 | - This script should work with multi-GPU, but it is not tested in my environment. 39 | 40 | - `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance. 41 | - The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage. 42 | 43 | - `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage. 44 | 45 | ### Tips for SDXL training 46 | 47 | - The default resolution of SDXL is 1024x1024. 48 | - The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__: 49 | - Train U-Net only. 50 | - Use gradient checkpointing. 51 | - Use `--cache_text_encoder_outputs` option and caching latents. 52 | - Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work. 53 | - The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended: 54 | - Train U-Net only. 55 | - Use gradient checkpointing. 56 | - Use `--cache_text_encoder_outputs` option and caching latents. 57 | - Use one of 8bit optimizers or Adafactor optimizer. 58 | - Use lower dim (4 to 8 for 8GB GPU). 59 | - `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected. 60 | - PyTorch 2 seems to use slightly less GPU memory than PyTorch 1. 61 | - `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training. 62 | 63 | Example of the optimizer settings for Adafactor with the fixed learning rate: 64 | ```toml 65 | optimizer_type = "adafactor" 66 | optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] 67 | lr_scheduler = "constant_with_warmup" 68 | lr_warmup_steps = 100 69 | learning_rate = 4e-7 # SDXL original learning rate 70 | ``` 71 | 72 | ### Format of Textual Inversion embeddings for SDXL 73 | 74 | ```python 75 | from safetensors.torch import save_file 76 | 77 | state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768} 78 | save_file(state_dict, file) 79 | ``` 80 | 81 | ### ControlNet-LLLite 82 | 83 | ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details. 84 | 85 | -------------------------------------------------------------------------------- /docs/train_db_README-ja.md: -------------------------------------------------------------------------------- 1 | DreamBoothのガイドです。 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | # 概要 6 | 7 | DreamBoothとは、画像生成モデルに特定の主題を追加学習し、それを特定の識別子で生成する技術です。[論文はこちら](https://arxiv.org/abs/2208.12242)。 8 | 9 | 具体的には、Stable Diffusionのモデルにキャラや画風などを学ばせ、それを `shs` のような特定の単語で呼び出せる(生成画像に出現させる)ことができます。 10 | 11 | スクリプトは[DiffusersのDreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)を元にしていますが、以下のような機能追加を行っています(いくつかの機能は元のスクリプト側もその後対応しています)。 12 | 13 | スクリプトの主な機能は以下の通りです。 14 | 15 | - 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化([Shivam Shrirao氏版](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)と同様)。 16 | - xformersによる省メモリ化。 17 | - 512x512だけではなく任意サイズでの学習。 18 | - augmentationによる品質の向上。 19 | - DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。 20 | - Stable Diffusion形式でのモデルの読み書き。 21 | - Aspect Ratio Bucketing。 22 | - Stable Diffusion v2.0対応。 23 | 24 | # 学習の手順 25 | 26 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 27 | 28 | ## データの準備 29 | 30 | [学習データの準備について](./train_README-ja.md) を参照してください。 31 | 32 | ## 学習の実行 33 | 34 | スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。それぞれの行を必要に応じて書き換えてください。12GB程度のVRAMで動作するようです。 35 | 36 | ``` 37 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 38 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 39 | --dataset_config=<データ準備で作成した.tomlファイル> 40 | --output_dir=<学習したモデルの出力先フォルダ> 41 | --output_name=<学習したモデル出力時のファイル名> 42 | --save_model_as=safetensors 43 | --prior_loss_weight=1.0 44 | --max_train_steps=1600 45 | --learning_rate=1e-6 46 | --optimizer_type="AdamW8bit" 47 | --xformers 48 | --mixed_precision="fp16" 49 | --cache_latents 50 | --gradient_checkpointing 51 | ``` 52 | 53 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 54 | 55 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 56 | 57 | `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 58 | 59 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 60 | 61 | `prior_loss_weight` は正則化画像のlossの重みです。通常は1.0を指定します。 62 | 63 | 学習させるステップ数 `max_train_steps` を1600とします。学習率 `learning_rate` はここでは1e-6を指定しています。 64 | 65 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 66 | 67 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 68 | 69 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 70 | 71 | 省メモリ化のため `cache_latents` オプションを指定してVAEの出力をキャッシュします。 72 | 73 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。また `cache_latents` を外すことで augmentation が可能になります。 74 | 75 | ### よく使われるオプションについて 76 | 77 | 以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。 78 | 79 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 80 | - clip skipを2以上を前提としたモデルを学習する 81 | - 75トークンを超えたキャプションで学習する 82 | 83 | ### DreamBoothでのステップ数について 84 | 85 | 当スクリプトでは省メモリ化のため、ステップ当たりの学習回数が元のスクリプトの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。 86 | 87 | 元のDiffusers版やXavierXiao氏のStable Diffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。 88 | 89 | (学習画像と正則化画像をまとめてから shuffle するため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。) 90 | 91 | ### DreamBoothでのバッチサイズについて 92 | 93 | モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(fine tuningと同じ)。 94 | 95 | ### 学習率について 96 | 97 | Diffusers版では5e-6ですがStable Diffusion版は1e-6ですので、上のサンプルでは1e-6を指定しています。 98 | 99 | ### 以前の形式のデータセット指定をした場合のコマンドライン 100 | 101 | 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。 102 | 103 | ``` 104 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 105 | --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> 106 | --train_data_dir=<学習用データのディレクトリ> 107 | --reg_data_dir=<正則化画像のディレクトリ> 108 | --output_dir=<学習したモデルの出力先ディレクトリ> 109 | --output_name=<学習したモデル出力時のファイル名> 110 | --prior_loss_weight=1.0 111 | --resolution=512 112 | --train_batch_size=1 113 | --learning_rate=1e-6 114 | --max_train_steps=1600 115 | --use_8bit_adam 116 | --xformers 117 | --mixed_precision="bf16" 118 | --cache_latents 119 | --gradient_checkpointing 120 | ``` 121 | 122 | ## 学習したモデルで画像生成する 123 | 124 | 学習が終わると指定したフォルダに指定した名前でsafetensorsファイルが出力されます。 125 | 126 | v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。 127 | 128 | v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。 129 | 130 | ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) 131 | 132 | 各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。 133 | 134 | # DreamBooth特有のその他の主なオプション 135 | 136 | すべてのオプションについては別文書を参照してください。 137 | 138 | ## Text Encoderの学習を途中から行わない --stop_text_encoder_training 139 | 140 | stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。 141 | 142 | (恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。) 143 | 144 | ## Tokenizerのパディングをしない --no_token_padding 145 | no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。 146 | 147 | 148 | 168 | -------------------------------------------------------------------------------- /docs/train_db_README-zh.md: -------------------------------------------------------------------------------- 1 | 这是DreamBooth的指南。 2 | 3 | 请同时查看[关于学习的通用文档](./train_README-zh.md)。 4 | 5 | # 概要 6 | 7 | DreamBooth是一种将特定主题添加到图像生成模型中进行学习,并使用特定识别子生成它的技术。论文链接。 8 | 9 | 具体来说,它可以将角色和绘画风格等添加到Stable Diffusion模型中进行学习,并使用特定的单词(例如`shs`)来调用(呈现在生成的图像中)。 10 | 11 | 脚本基于Diffusers的DreamBooth,但添加了以下功能(一些功能已在原始脚本中得到支持)。 12 | 13 | 脚本的主要功能如下: 14 | 15 | - 使用8位Adam优化器和潜在变量的缓存来节省内存(与Shivam Shrirao版相似)。 16 | - 使用xformers来节省内存。 17 | - 不仅支持512x512,还支持任意尺寸的训练。 18 | - 通过数据增强来提高质量。 19 | - 支持DreamBooth和Text Encoder + U-Net的微调。 20 | - 支持以Stable Diffusion格式读写模型。 21 | - 支持Aspect Ratio Bucketing。 22 | - 支持Stable Diffusion v2.0。 23 | 24 | # 训练步骤 25 | 26 | 请先参阅此存储库的README以进行环境设置。 27 | 28 | ## 准备数据 29 | 30 | 请参阅[有关准备训练数据的说明](./train_README-zh.md)。 31 | 32 | ## 运行训练 33 | 34 | 运行脚本。以下是最大程度地节省内存的命令(实际上,这将在一行中输入)。请根据需要修改每行。它似乎需要约12GB的VRAM才能运行。 35 | ``` 36 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 37 | --pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录> 38 | --dataset_config=<数据准备时创建的.toml文件> 39 | --output_dir=<训练模型的输出目录> 40 | --output_name=<训练模型输出时的文件名> 41 | --save_model_as=safetensors 42 | --prior_loss_weight=1.0 43 | --max_train_steps=1600 44 | --learning_rate=1e-6 45 | --optimizer_type="AdamW8bit" 46 | --xformers 47 | --mixed_precision="fp16" 48 | --cache_latents 49 | --gradient_checkpointing 50 | ``` 51 | `num_cpu_threads_per_process` 通常应该设置为1。 52 | 53 | `pretrained_model_name_or_path` 指定要进行追加训练的基础模型。可以指定 Stable Diffusion 的 checkpoint 文件(.ckpt 或 .safetensors)、Diffusers 的本地模型目录或模型 ID(如 "stabilityai/stable-diffusion-2")。 54 | 55 | `output_dir` 指定保存训练后模型的文件夹。在 `output_name` 中指定模型文件名,不包括扩展名。使用 `save_model_as` 指定以 safetensors 格式保存。 56 | 57 | 在 `dataset_config` 中指定 `.toml` 文件。初始批处理大小应为 `1`,以减少内存消耗。 58 | 59 | `prior_loss_weight` 是正则化图像损失的权重。通常设为1.0。 60 | 61 | 将要训练的步数 `max_train_steps` 设置为1600。在这里,学习率 `learning_rate` 被设置为1e-6。 62 | 63 | 为了节省内存,设置 `mixed_precision="fp16"`(在 RTX30 系列及更高版本中也可以设置为 `bf16`)。同时指定 `gradient_checkpointing`。 64 | 65 | 为了使用内存消耗较少的 8bit AdamW 优化器(将模型优化为适合于训练数据的状态),指定 `optimizer_type="AdamW8bit"`。 66 | 67 | 指定 `xformers` 选项,并使用 xformers 的 CrossAttention。如果未安装 xformers 或出现错误(具体情况取决于环境,例如使用 `mixed_precision="no"`),则可以指定 `mem_eff_attn` 选项以使用省内存版的 CrossAttention(速度会变慢)。 68 | 69 | 为了节省内存,指定 `cache_latents` 选项以缓存 VAE 的输出。 70 | 71 | 如果有足够的内存,请编辑 `.toml` 文件将批处理大小增加到大约 `4`(可能会提高速度和精度)。此外,取消 `cache_latents` 选项可以进行数据增强。 72 | 73 | ### 常用选项 74 | 75 | 对于以下情况,请参阅“常用选项”部分。 76 | 77 | - 学习 Stable Diffusion 2.x 或其衍生模型。 78 | - 学习基于 clip skip 大于等于2的模型。 79 | - 学习超过75个令牌的标题。 80 | 81 | ### 关于DreamBooth中的步数 82 | 83 | 为了实现省内存化,该脚本中每个步骤的学习次数减半(因为学习和正则化的图像在训练时被分为不同的批次)。 84 | 85 | 要进行与原始Diffusers版或XavierXiao的Stable Diffusion版几乎相同的学习,请将步骤数加倍。 86 | 87 | (虽然在将学习图像和正则化图像整合后再打乱顺序,但我认为对学习没有太大影响。) 88 | 89 | 关于DreamBooth的批量大小 90 | 91 | 与像LoRA这样的学习相比,为了训练整个模型,内存消耗量会更大(与微调相同)。 92 | 93 | 关于学习率 94 | 95 | 在Diffusers版中,学习率为5e-6,而在Stable Diffusion版中为1e-6,因此在上面的示例中指定了1e-6。 96 | 97 | 当使用旧格式的数据集指定命令行时 98 | 99 | 使用选项指定分辨率和批量大小。命令行示例如下。 100 | ``` 101 | accelerate launch --num_cpu_threads_per_process 1 train_db.py 102 | --pretrained_model_name_or_path=<.ckpt或.safetensord或Diffusers版模型的目录> 103 | --train_data_dir=<训练数据的目录> 104 | --reg_data_dir=<正则化图像的目录> 105 | --output_dir=<训练后模型的输出目录> 106 | --output_name=<训练后模型输出文件的名称> 107 | --prior_loss_weight=1.0 108 | --resolution=512 109 | --train_batch_size=1 110 | --learning_rate=1e-6 111 | --max_train_steps=1600 112 | --use_8bit_adam 113 | --xformers 114 | --mixed_precision="bf16" 115 | --cache_latents 116 | --gradient_checkpointing 117 | ``` 118 | 119 | ## 使用训练好的模型生成图像 120 | 121 | 训练完成后,将在指定的文件夹中以指定的名称输出safetensors文件。 122 | 123 | 对于v1.4/1.5和其他派生模型,可以在此模型中使用Automatic1111先生的WebUI进行推断。请将其放置在models\Stable-diffusion文件夹中。 124 | 125 | 对于使用v2.x模型在WebUI中生成图像的情况,需要单独的.yaml文件来描述模型的规格。对于v2.x base,需要v2-inference.yaml,对于768/v,则需要v2-inference-v.yaml。请将它们放置在相同的文件夹中,并将文件扩展名之前的部分命名为与模型相同的名称。 126 | ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) 127 | 128 | 每个yaml文件都在[Stability AI的SD2.0存储库](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)……之中。 129 | 130 | # DreamBooth的其他主要选项 131 | 132 | 有关所有选项的详细信息,请参阅另一份文档。 133 | 134 | ## 不在中途开始对文本编码器进行训练 --stop_text_encoder_training 135 | 136 | 如果在stop_text_encoder_training选项中指定一个数字,则在该步骤之后,将不再对文本编码器进行训练,只会对U-Net进行训练。在某些情况下,可能会期望提高精度。 137 | 138 | (我们推测可能会有时候仅仅文本编码器会过度学习,而这样做可以避免这种情况,但详细影响尚不清楚。) 139 | 140 | ## 不进行分词器的填充 --no_token_padding 141 | 142 | 如果指定no_token_padding选项,则不会对分词器的输出进行填充(与Diffusers版本的旧DreamBooth相同)。 143 | 144 | 163 | -------------------------------------------------------------------------------- /docs/train_lllite_README-ja.md: -------------------------------------------------------------------------------- 1 | # ControlNet-LLLite について 2 | 3 | __きわめて実験的な実装のため、将来的に大きく変更される可能性があります。__ 4 | 5 | ## 概要 6 | ControlNet-LLLite は、[ControlNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 7 | 8 | ## サンプルの重みファイルと推論 9 | 10 | こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite 11 | 12 | ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI 13 | 14 | 生成サンプルはこのページの末尾にあります。 15 | 16 | ## モデル構造 17 | ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 18 | 19 | 推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。 20 | 21 | ## モデルの学習 22 | 23 | ### データセットの準備 24 | DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。 25 | 26 | (finetuning 方式の dataset はサポートしていません。) 27 | 28 | conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。 29 | 30 | たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。 31 | 32 | ```toml 33 | [[datasets.subsets]] 34 | image_dir = "path/to/image/dir" 35 | caption_extension = ".txt" 36 | conditioning_data_dir = "path/to/conditioning/image/dir" 37 | ``` 38 | 39 | 現時点の制約として、random_cropは使用できません。 40 | 41 | 学習データとしては、元のモデルで生成した画像を学習用画像として、そこから加工した画像をconditioning imageとした、合成によるデータセットを用いるのがもっとも簡単です(データセットの品質的には問題があるかもしれません)。具体的なデータセットの合成方法については後述します。 42 | 43 | なお、元モデルと異なる画風の画像を学習用画像とすると、制御に加えて、その画風についても学ぶ必要が生じます。ControlNet-LLLiteは容量が少ないため、画風学習には不向きです。このような場合には、後述の次元数を多めにしてください。 44 | 45 | ### 学習 46 | スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 47 | 48 | 学習時にはメモリを大量に使用しますので、キャッシュやgradient checkpointingなどの省メモリ化のオプションを有効にしてください。また`--full_bf16` オプションで、BFloat16を使用するのも有効です(RTX 30シリーズ以降のGPUが必要です)。24GB VRAMで動作確認しています。 49 | 50 | conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 51 | 52 | (サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) 53 | 54 | 以下は .toml の設定例です。 55 | 56 | ```toml 57 | pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors" 58 | max_train_epochs = 12 59 | max_data_loader_n_workers = 4 60 | persistent_data_loader_workers = true 61 | seed = 42 62 | gradient_checkpointing = true 63 | mixed_precision = "bf16" 64 | save_precision = "bf16" 65 | full_bf16 = true 66 | optimizer_type = "adamw8bit" 67 | learning_rate = 2e-4 68 | xformers = true 69 | output_dir = "/path/to/output/dir" 70 | output_name = "output_name" 71 | save_every_n_epochs = 1 72 | save_model_as = "safetensors" 73 | vae_batch_size = 4 74 | cache_latents = true 75 | cache_latents_to_disk = true 76 | cache_text_encoder_outputs = true 77 | cache_text_encoder_outputs_to_disk = true 78 | network_dim = 64 79 | cond_emb_dim = 32 80 | dataset_config = "/path/to/dataset.toml" 81 | ``` 82 | 83 | ### 推論 84 | 85 | スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 86 | 87 | `--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 88 | 89 | ## データセットの合成方法 90 | 91 | ### 学習用画像の生成 92 | 93 | 学習のベースとなるモデルで画像生成を行います。Web UIやComfyUIなどで生成してください。画像サイズはモデルのデフォルトサイズで良いと思われます(1024x1024など)。bucketingを用いることもできます。その場合は適宜適切な解像度で生成してください。 94 | 95 | 生成時のキャプション等は、ControlNet-LLLiteの利用時に生成したい画像にあわせるのが良いと思われます。 96 | 97 | 生成した画像を任意のディレクトリに保存してください。このディレクトリをデータセットの設定ファイルで指定します。 98 | 99 | 当リポジトリ内の `sdxl_gen_img.py` でも生成できます。例えば以下のように実行します。 100 | 101 | ```dos 102 | python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn" 103 | ``` 104 | 105 | VRAM 24GBの設定です。VRAMサイズにより`--batch_size` `--vae_batch_size`を調整してください。 106 | 107 | `--prompt`でワイルドカードを利用してランダムに生成しています。適宜調整してください。 108 | 109 | ### 画像の加工 110 | 111 | 外部のプログラムを用いて、生成した画像を加工します。加工した画像を任意のディレクトリに保存してください。これらがconditioning imageになります。 112 | 113 | 加工にはたとえばCannyなら以下のようなスクリプトが使えます。 114 | 115 | ```python 116 | import glob 117 | import os 118 | import random 119 | import cv2 120 | import numpy as np 121 | 122 | IMAGES_DIR = "path/to/generated/images" 123 | CANNY_DIR = "path/to/canny/images" 124 | 125 | os.makedirs(CANNY_DIR, exist_ok=True) 126 | img_files = glob.glob(IMAGES_DIR + "/*.png") 127 | for img_file in img_files: 128 | can_file = CANNY_DIR + "/" + os.path.basename(img_file) 129 | if os.path.exists(can_file): 130 | print("Skip: " + img_file) 131 | continue 132 | 133 | print(img_file) 134 | 135 | img = cv2.imread(img_file) 136 | 137 | # random threshold 138 | # while True: 139 | # threshold1 = random.randint(0, 127) 140 | # threshold2 = random.randint(128, 255) 141 | # if threshold2 - threshold1 > 80: 142 | # break 143 | 144 | # fixed threshold 145 | threshold1 = 100 146 | threshold2 = 200 147 | 148 | img = cv2.Canny(img, threshold1, threshold2) 149 | 150 | cv2.imwrite(can_file, img) 151 | ``` 152 | 153 | ### キャプションファイルの作成 154 | 155 | 学習用画像のbasenameと同じ名前で、それぞれの画像に対応したキャプションファイルを作成してください。生成時のプロンプトをそのまま利用すれば良いと思われます。 156 | 157 | `sdxl_gen_img.py` で生成した場合は、画像内のメタデータに生成時のプロンプトが記録されていますので、以下のようなスクリプトで学習用画像と同じディレクトリにキャプションファイルを作成できます(拡張子 `.txt`)。 158 | 159 | ```python 160 | import glob 161 | import os 162 | from PIL import Image 163 | 164 | IMAGES_DIR = "path/to/generated/images" 165 | 166 | img_files = glob.glob(IMAGES_DIR + "/*.png") 167 | for img_file in img_files: 168 | cap_file = img_file.replace(".png", ".txt") 169 | if os.path.exists(cap_file): 170 | print(f"Skip: {img_file}") 171 | continue 172 | print(img_file) 173 | 174 | img = Image.open(img_file) 175 | prompt = img.text["prompt"] if "prompt" in img.text else "" 176 | if prompt == "": 177 | print(f"Prompt not found in {img_file}") 178 | 179 | with open(cap_file, "w") as f: 180 | f.write(prompt + "\n") 181 | ``` 182 | 183 | ### データセットの設定ファイルの作成 184 | 185 | コマンドラインオプションからの指定も可能ですが、`.toml`ファイルを作成する場合は `conditioning_data_dir` に加工した画像を保存したディレクトリを指定します。 186 | 187 | 以下は設定ファイルの例です。 188 | 189 | ```toml 190 | [general] 191 | flip_aug = false 192 | color_aug = false 193 | resolution = [1024,1024] 194 | 195 | [[datasets]] 196 | batch_size = 8 197 | enable_bucket = false 198 | 199 | [[datasets.subsets]] 200 | image_dir = "path/to/generated/image/dir" 201 | caption_extension = ".txt" 202 | conditioning_data_dir = "path/to/canny/image/dir" 203 | ``` 204 | 205 | ## 謝辞 206 | 207 | ControlNetの作者である lllyasviel 氏、実装上のアドバイスとトラブル解決へのご尽力をいただいた furusu 氏、ControlNetデータセットを実装していただいた ddPn08 氏に感謝いたします。 208 | 209 | ## サンプル 210 | Canny 211 | ![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) 212 | 213 | ![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) 214 | 215 | ![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) 216 | 217 | ![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) 218 | 219 | -------------------------------------------------------------------------------- /docs/train_ti_README-ja.md: -------------------------------------------------------------------------------- 1 | [Textual Inversion](https://textual-inversion.github.io/) の学習についての説明です。 2 | 3 | [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。 4 | 5 | 実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。 6 | 7 | 学習したモデルはWeb UIでもそのまま使えます。 8 | 9 | # 学習の手順 10 | 11 | あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。 12 | 13 | ## データの準備 14 | 15 | [学習データの準備について](./train_README-ja.md) を参照してください。 16 | 17 | ## 学習の実行 18 | 19 | ``train_textual_inversion.py`` を用います。以下はコマンドラインの例です(DreamBooth手法)。 20 | 21 | ``` 22 | accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py 23 | --dataset_config=<データ準備で作成した.tomlファイル> 24 | --output_dir=<学習したモデルの出力先フォルダ> 25 | --output_name=<学習したモデル出力時のファイル名> 26 | --save_model_as=safetensors 27 | --prior_loss_weight=1.0 28 | --max_train_steps=1600 29 | --learning_rate=1e-6 30 | --optimizer_type="AdamW8bit" 31 | --xformers 32 | --mixed_precision="fp16" 33 | --cache_latents 34 | --gradient_checkpointing 35 | --token_string=mychar4 --init_word=cute --num_vectors_per_token=4 36 | ``` 37 | 38 | ``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。DreamBooth, class+identifier形式のデータセットとして、`token_string` をトークン文字列にするのが最も簡単で確実です。 39 | 40 | プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。 41 | 42 | ``` 43 | input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407, 44 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 45 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 46 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 47 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 48 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49 | 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 50 | 49407, 49407, 49407, 49407, 49407, 49407, 49407]]) 51 | ``` 52 | 53 | tokenizerがすでに持っている単語(一般的な単語)は使用できません。 54 | 55 | ``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。 56 | 57 | ``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。 58 | 59 | 以上がTextual Inversionのための主なオプションです。以降は他の学習スクリプトと同様です。 60 | 61 | `num_cpu_threads_per_process` には通常は1を指定するとよいようです。 62 | 63 | `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。 64 | 65 | `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。 66 | 67 | `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。 68 | 69 | 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。 70 | 71 | 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。 72 | 73 | オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。 74 | 75 | `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 76 | 77 | ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `8` くらいに増やしてください(高速化と精度向上の可能性があります)。 78 | 79 | ### よく使われるオプションについて 80 | 81 | 以下の場合にはオプションに関するドキュメントを参照してください。 82 | 83 | - Stable Diffusion 2.xまたはそこからの派生モデルを学習する 84 | - clip skipを2以上を前提としたモデルを学習する 85 | - 75トークンを超えたキャプションで学習する 86 | 87 | ### Textual Inversionでのバッチサイズについて 88 | 89 | モデル全体を学習するDreamBoothやfine tuningに比べてメモリ使用量が少ないため、バッチサイズは大きめにできます。 90 | 91 | # Textual Inversionのその他の主なオプション 92 | 93 | すべてのオプションについては別文書を参照してください。 94 | 95 | * `--weights` 96 | * 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。 97 | * `--use_object_template` 98 | * キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。 99 | * `--use_style_template` 100 | * キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。 101 | 102 | ## 当リポジトリ内の画像生成スクリプトで生成する 103 | 104 | gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。 105 | 106 | -------------------------------------------------------------------------------- /docs/wd14_tagger_README-en.md: -------------------------------------------------------------------------------- 1 | # Image Tagging using WD14Tagger 2 | 3 | This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger). 4 | 5 | Using onnx for inference is recommended. Please install onnx with the following command: 6 | 7 | ```powershell 8 | pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 9 | ``` 10 | 11 | The model weights will be automatically downloaded from Hugging Face. 12 | 13 | # Usage 14 | 15 | Run the script to perform tagging. 16 | 17 | ```powershell 18 | python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id --batch_size 19 | ``` 20 | 21 | For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be: 22 | 23 | ```powershell 24 | python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data 25 | ``` 26 | 27 | On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option). 28 | 29 | Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension. 30 | 31 | ![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) 32 | 33 | ![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) 34 | 35 | ## Example 36 | 37 | To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice): 38 | 39 | ``` 40 | python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 41 | --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive 42 | --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand 43 | --always_first_tags "1girl,1boy" ..\train_data 44 | ``` 45 | 46 | ## Available Repository IDs 47 | 48 | [SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`. 49 | 50 | # Options 51 | 52 | ## General Options 53 | 54 | - `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately. 55 | - `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity. 56 | - `--caption_extension`: File extension for caption files. Default is `.txt`. 57 | - `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used. 58 | - `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease. 59 | - `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`. 60 | - `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`. 61 | - `--recursive`: If specified, subfolders within the specified folder will also be processed recursively. 62 | - `--append_tags`: Append tags to existing tag files. 63 | - `--frequency_tags`: Output tag frequencies. 64 | - `--debug`: Debug mode. Outputs debug information if specified. 65 | 66 | ## Model Download 67 | 68 | - `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`. 69 | - `--force_download`: Re-download model files if specified. 70 | 71 | ## Tag Editing 72 | 73 | - `--remove_underscore`: Remove underscores from output tags. 74 | - `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`. 75 | - `--use_rating_tags`: Output rating tags at the beginning of the tags. 76 | - `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags. 77 | - `--character_tags_first`: Output character tags first. 78 | - `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`. 79 | - `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`. 80 | - `--caption_separator`: Separate tags with this string in the output file. Default is `, `. 81 | - `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \ 82 | For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag). 83 | 84 | When using `tag_replacement`, it is applied after `character_tag_expand`. 85 | 86 | When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores. 87 | 88 | When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`. 89 | -------------------------------------------------------------------------------- /docs/wd14_tagger_README-ja.md: -------------------------------------------------------------------------------- 1 | # WD14Taggerによるタグ付け 2 | 3 | こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。 4 | 5 | onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。 6 | 7 | ```powershell 8 | pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 9 | ``` 10 | 11 | モデルの重みはHugging Faceから自動的にダウンロードしてきます。 12 | 13 | # 使い方 14 | 15 | スクリプトを実行してタグ付けを行います。 16 | ``` 17 | python fintune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ> 18 | ``` 19 | 20 | レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。 21 | 22 | ``` 23 | python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data 24 | ``` 25 | 26 | 初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。 27 | 28 | タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。 29 | 30 | ![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) 31 | 32 | ![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) 33 | 34 | ## 記述例 35 | 36 | Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。 37 | 38 | ``` 39 | python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 40 | --batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive 41 | --use_rating_tags_as_last_tag --character_tags_first --character_tag_expand 42 | --always_first_tags "1girl,1boy" ..\train_data 43 | ``` 44 | 45 | ## 使用可能なリポジトリID 46 | 47 | [SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。 48 | 49 | # オプション 50 | 51 | ## 一般オプション 52 | 53 | - `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。 54 | - `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。 55 | - `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。 56 | - `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。 57 | - `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。 58 | - `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。 59 | - `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。 60 | - `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。 61 | - `--append_tags` : 既存のタグファイルにタグを追加します。 62 | - `--frequency_tags` : タグの頻度を出力します。 63 | - `--debug` : デバッグモード。指定するとデバッグ情報を出力します。 64 | 65 | ## モデルのダウンロード 66 | 67 | - `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。 68 | - `--force_download` : 指定するとモデルファイルを再ダウンロードします。 69 | 70 | ## タグ編集関連 71 | 72 | - `--remove_underscore` : 出力するタグからアンダースコアを削除します。 73 | - `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。 74 | - `--use_rating_tags` : タグの最初にレーティングタグを出力します。 75 | - `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。 76 | - `--character_tags_first` : キャラクタータグを最初に出力します。 77 | - `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。 78 | - `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。 79 | - `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。 80 | - `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\ 81 | たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。 82 | 83 | `tag_replacement` は `character_tag_expand` の後に適用されます。 84 | 85 | `remove_underscore` 指定時は、`undesired_tags`、`always_first_tags`、`tag_replacement` はアンダースコアを含めずに指定してください。 86 | 87 | `caption_separator` 指定時は、`undesired_tags`、`always_first_tags` は `caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。 88 | 89 | -------------------------------------------------------------------------------- /finetune/blip/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } 22 | -------------------------------------------------------------------------------- /finetune/clean_captions_and_tags.py: -------------------------------------------------------------------------------- 1 | # このスクリプトのライセンスは、Apache License 2.0とします 2 | # (c) 2022 Kohya S. @kohya_ss 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | import re 9 | 10 | from tqdm import tqdm 11 | from library.utils import setup_logging 12 | setup_logging() 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ') 17 | PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ') 18 | PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ') 19 | PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ') 20 | 21 | # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する 22 | PATTERNS_REMOVE_IN_MULTI = [ 23 | PATTERN_HAIR_LENGTH, 24 | PATTERN_HAIR_CUT, 25 | re.compile(r', [\w\-]+ eyes, '), 26 | re.compile(r', ([\w\-]+ sleeves|sleeveless), '), 27 | # 複数の髪型定義がある場合は削除する 28 | re.compile( 29 | r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '), 30 | ] 31 | 32 | 33 | def clean_tags(image_key, tags): 34 | # replace '_' to ' ' 35 | tags = tags.replace('^_^', '^@@@^') 36 | tags = tags.replace('_', ' ') 37 | tags = tags.replace('^@@@^', '^_^') 38 | 39 | # remove rating: deepdanbooruのみ 40 | tokens = tags.split(", rating") 41 | if len(tokens) == 1: 42 | # WD14 taggerのときはこちらになるのでメッセージは出さない 43 | # logger.info("no rating:") 44 | # logger.info(f"{image_key} {tags}") 45 | pass 46 | else: 47 | if len(tokens) > 2: 48 | logger.info("multiple ratings:") 49 | logger.info(f"{image_key} {tags}") 50 | tags = tokens[0] 51 | 52 | tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策 53 | 54 | # 複数の人物がいる場合は髪色等のタグを削除する 55 | if 'girls' in tags or 'boys' in tags: 56 | for pat in PATTERNS_REMOVE_IN_MULTI: 57 | found = pat.findall(tags) 58 | if len(found) > 1: # 二つ以上、タグがある 59 | tags = pat.sub("", tags) 60 | 61 | # 髪の特殊対応 62 | srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合) 63 | if srch_hair_len: 64 | org = srch_hair_len.group() 65 | tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags) 66 | 67 | found = PATTERN_HAIR.findall(tags) 68 | if len(found) > 1: 69 | tags = PATTERN_HAIR.sub("", tags) 70 | 71 | if srch_hair_len: 72 | tags = tags.replace(", @@@, ", org) # 戻す 73 | 74 | # white shirtとshirtみたいな重複タグの削除 75 | found = PATTERN_WORD.findall(tags) 76 | for word in found: 77 | if re.search(f", ((\w+) )+{word}, ", tags): 78 | tags = tags.replace(f", {word}, ", "") 79 | 80 | tags = tags.replace(", , ", ", ") 81 | assert tags.startswith(", ") and tags.endswith(", ") 82 | tags = tags[2:-2] 83 | return tags 84 | 85 | 86 | # 上から順に検索、置換される 87 | # ('置換元文字列', '置換後文字列') 88 | CAPTION_REPLACEMENTS = [ 89 | ('anime anime', 'anime'), 90 | ('young ', ''), 91 | ('anime girl', 'girl'), 92 | ('cartoon female', 'girl'), 93 | ('cartoon lady', 'girl'), 94 | ('cartoon character', 'girl'), # a or ~s 95 | ('cartoon woman', 'girl'), 96 | ('cartoon women', 'girls'), 97 | ('cartoon girl', 'girl'), 98 | ('anime female', 'girl'), 99 | ('anime lady', 'girl'), 100 | ('anime character', 'girl'), # a or ~s 101 | ('anime woman', 'girl'), 102 | ('anime women', 'girls'), 103 | ('lady', 'girl'), 104 | ('female', 'girl'), 105 | ('woman', 'girl'), 106 | ('women', 'girls'), 107 | ('people', 'girls'), 108 | ('person', 'girl'), 109 | ('a cartoon figure', 'a figure'), 110 | ('a cartoon image', 'an image'), 111 | ('a cartoon picture', 'a picture'), 112 | ('an anime cartoon image', 'an image'), 113 | ('a cartoon anime drawing', 'a drawing'), 114 | ('a cartoon drawing', 'a drawing'), 115 | ('girl girl', 'girl'), 116 | ] 117 | 118 | 119 | def clean_caption(caption): 120 | for rf, rt in CAPTION_REPLACEMENTS: 121 | replaced = True 122 | while replaced: 123 | bef = caption 124 | caption = caption.replace(rf, rt) 125 | replaced = bef != caption 126 | return caption 127 | 128 | 129 | def main(args): 130 | if os.path.exists(args.in_json): 131 | logger.info(f"loading existing metadata: {args.in_json}") 132 | with open(args.in_json, "rt", encoding='utf-8') as f: 133 | metadata = json.load(f) 134 | else: 135 | logger.error("no metadata / メタデータファイルがありません") 136 | return 137 | 138 | logger.info("cleaning captions and tags.") 139 | image_keys = list(metadata.keys()) 140 | for image_key in tqdm(image_keys): 141 | tags = metadata[image_key].get('tags') 142 | if tags is None: 143 | logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}") 144 | else: 145 | org = tags 146 | tags = clean_tags(image_key, tags) 147 | metadata[image_key]['tags'] = tags 148 | if args.debug and org != tags: 149 | logger.info("FROM: " + org) 150 | logger.info("TO: " + tags) 151 | 152 | caption = metadata[image_key].get('caption') 153 | if caption is None: 154 | logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}") 155 | else: 156 | org = caption 157 | caption = clean_caption(caption) 158 | metadata[image_key]['caption'] = caption 159 | if args.debug and org != caption: 160 | logger.info("FROM: " + org) 161 | logger.info("TO: " + caption) 162 | 163 | # metadataを書き出して終わり 164 | logger.info(f"writing metadata: {args.out_json}") 165 | with open(args.out_json, "wt", encoding='utf-8') as f: 166 | json.dump(metadata, f, indent=2) 167 | logger.info("done!") 168 | 169 | 170 | def setup_parser() -> argparse.ArgumentParser: 171 | parser = argparse.ArgumentParser() 172 | # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 173 | parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") 174 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 175 | parser.add_argument("--debug", action="store_true", help="debug mode") 176 | 177 | return parser 178 | 179 | 180 | if __name__ == '__main__': 181 | parser = setup_parser() 182 | 183 | args, unknown = parser.parse_known_args() 184 | if len(unknown) == 1: 185 | logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.") 186 | logger.warning("All captions and tags in the metadata are processed.") 187 | logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。") 188 | logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。") 189 | args.in_json = args.out_json 190 | args.out_json = unknown[0] 191 | elif len(unknown) > 0: 192 | raise ValueError(f"error: unrecognized arguments: {unknown}") 193 | 194 | main(args) 195 | -------------------------------------------------------------------------------- /finetune/hypernetwork_nai.py: -------------------------------------------------------------------------------- 1 | # NAI compatible 2 | 3 | import torch 4 | 5 | 6 | class HypernetworkModule(torch.nn.Module): 7 | def __init__(self, dim, multiplier=1.0): 8 | super().__init__() 9 | 10 | linear1 = torch.nn.Linear(dim, dim * 2) 11 | linear2 = torch.nn.Linear(dim * 2, dim) 12 | linear1.weight.data.normal_(mean=0.0, std=0.01) 13 | linear1.bias.data.zero_() 14 | linear2.weight.data.normal_(mean=0.0, std=0.01) 15 | linear2.bias.data.zero_() 16 | linears = [linear1, linear2] 17 | 18 | self.linear = torch.nn.Sequential(*linears) 19 | self.multiplier = multiplier 20 | 21 | def forward(self, x): 22 | return x + self.linear(x) * self.multiplier 23 | 24 | 25 | class Hypernetwork(torch.nn.Module): 26 | enable_sizes = [320, 640, 768, 1280] 27 | # return self.modules[Hypernetwork.enable_sizes.index(size)] 28 | 29 | def __init__(self, multiplier=1.0) -> None: 30 | super().__init__() 31 | self.modules = [] 32 | for size in Hypernetwork.enable_sizes: 33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier))) 34 | self.register_module(f"{size}_0", self.modules[-1][0]) 35 | self.register_module(f"{size}_1", self.modules[-1][1]) 36 | 37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet): 38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks 39 | for block in blocks: 40 | for subblk in block: 41 | if 'SpatialTransformer' in str(type(subblk)): 42 | for tf_block in subblk.transformer_blocks: 43 | for attn in [tf_block.attn1, tf_block.attn2]: 44 | size = attn.context_dim 45 | if size in Hypernetwork.enable_sizes: 46 | attn.hypernetwork = self 47 | else: 48 | attn.hypernetwork = None 49 | 50 | def apply_to_diffusers(self, text_encoder, vae, unet): 51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks 52 | for block in blocks: 53 | if hasattr(block, 'attentions'): 54 | for subblk in block.attentions: 55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~ 56 | for tf_block in subblk.transformer_blocks: 57 | for attn in [tf_block.attn1, tf_block.attn2]: 58 | size = attn.to_k.in_features 59 | if size in Hypernetwork.enable_sizes: 60 | attn.hypernetwork = self 61 | else: 62 | attn.hypernetwork = None 63 | return True # TODO error checking 64 | 65 | def forward(self, x, context): 66 | size = context.shape[-1] 67 | assert size in Hypernetwork.enable_sizes 68 | module = self.modules[Hypernetwork.enable_sizes.index(size)] 69 | return module[0].forward(context), module[1].forward(context) 70 | 71 | def load_from_state_dict(self, state_dict): 72 | # old ver to new ver 73 | changes = { 74 | 'linear1.bias': 'linear.0.bias', 75 | 'linear1.weight': 'linear.0.weight', 76 | 'linear2.bias': 'linear.1.bias', 77 | 'linear2.weight': 'linear.1.weight', 78 | } 79 | for key_from, key_to in changes.items(): 80 | if key_from in state_dict: 81 | state_dict[key_to] = state_dict[key_from] 82 | del state_dict[key_from] 83 | 84 | for size, sd in state_dict.items(): 85 | if type(size) == int: 86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True) 87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True) 88 | return True 89 | 90 | def get_state_dict(self): 91 | state_dict = {} 92 | for i, size in enumerate(Hypernetwork.enable_sizes): 93 | sd0 = self.modules[i][0].state_dict() 94 | sd1 = self.modules[i][1].state_dict() 95 | state_dict[size] = [sd0, sd1] 96 | return state_dict 97 | -------------------------------------------------------------------------------- /finetune/make_captions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import json 5 | import random 6 | import sys 7 | 8 | from pathlib import Path 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import numpy as np 12 | 13 | import torch 14 | from library.device_utils import init_ipex, get_preferred_device 15 | init_ipex() 16 | 17 | from torchvision import transforms 18 | from torchvision.transforms.functional import InterpolationMode 19 | sys.path.append(os.path.dirname(__file__)) 20 | from blip.blip import blip_decoder, is_url 21 | import library.train_util as train_util 22 | from library.utils import setup_logging 23 | setup_logging() 24 | import logging 25 | logger = logging.getLogger(__name__) 26 | 27 | DEVICE = get_preferred_device() 28 | 29 | 30 | IMAGE_SIZE = 384 31 | 32 | # 正方形でいいのか? という気がするがソースがそうなので 33 | IMAGE_TRANSFORM = transforms.Compose( 34 | [ 35 | transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC), 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 38 | ] 39 | ) 40 | 41 | 42 | # 共通化したいが微妙に処理が異なる…… 43 | class ImageLoadingTransformDataset(torch.utils.data.Dataset): 44 | def __init__(self, image_paths): 45 | self.images = image_paths 46 | 47 | def __len__(self): 48 | return len(self.images) 49 | 50 | def __getitem__(self, idx): 51 | img_path = self.images[idx] 52 | 53 | try: 54 | image = Image.open(img_path).convert("RGB") 55 | # convert to tensor temporarily so dataloader will accept it 56 | tensor = IMAGE_TRANSFORM(image) 57 | except Exception as e: 58 | logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") 59 | return None 60 | 61 | return (tensor, img_path) 62 | 63 | 64 | def collate_fn_remove_corrupted(batch): 65 | """Collate function that allows to remove corrupted examples in the 66 | dataloader. It expects that the dataloader returns 'None' when that occurs. 67 | The 'None's in the batch are removed. 68 | """ 69 | # Filter out all the Nones (corrupted examples) 70 | batch = list(filter(lambda x: x is not None, batch)) 71 | return batch 72 | 73 | 74 | def main(args): 75 | # fix the seed for reproducibility 76 | seed = args.seed # + utils.get_rank() 77 | torch.manual_seed(seed) 78 | np.random.seed(seed) 79 | random.seed(seed) 80 | 81 | if not os.path.exists("blip"): 82 | args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path 83 | 84 | cwd = os.getcwd() 85 | logger.info(f"Current Working Directory is: {cwd}") 86 | os.chdir("finetune") 87 | if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): 88 | args.caption_weights = os.path.join("..", args.caption_weights) 89 | 90 | logger.info(f"load images from {args.train_data_dir}") 91 | train_data_dir_path = Path(args.train_data_dir) 92 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 93 | logger.info(f"found {len(image_paths)} images.") 94 | 95 | logger.info(f"loading BLIP caption: {args.caption_weights}") 96 | model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json") 97 | model.eval() 98 | model = model.to(DEVICE) 99 | logger.info("BLIP loaded") 100 | 101 | # captioningする 102 | def run_batch(path_imgs): 103 | imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE) 104 | 105 | with torch.no_grad(): 106 | if args.beam_search: 107 | captions = model.generate( 108 | imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length 109 | ) 110 | else: 111 | captions = model.generate( 112 | imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length 113 | ) 114 | 115 | for (image_path, _), caption in zip(path_imgs, captions): 116 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 117 | f.write(caption + "\n") 118 | if args.debug: 119 | logger.info(f'{image_path} {caption}') 120 | 121 | # 読み込みの高速化のためにDataLoaderを使うオプション 122 | if args.max_data_loader_n_workers is not None: 123 | dataset = ImageLoadingTransformDataset(image_paths) 124 | data = torch.utils.data.DataLoader( 125 | dataset, 126 | batch_size=args.batch_size, 127 | shuffle=False, 128 | num_workers=args.max_data_loader_n_workers, 129 | collate_fn=collate_fn_remove_corrupted, 130 | drop_last=False, 131 | ) 132 | else: 133 | data = [[(None, ip)] for ip in image_paths] 134 | 135 | b_imgs = [] 136 | for data_entry in tqdm(data, smoothing=0.0): 137 | for data in data_entry: 138 | if data is None: 139 | continue 140 | 141 | img_tensor, image_path = data 142 | if img_tensor is None: 143 | try: 144 | raw_image = Image.open(image_path) 145 | if raw_image.mode != "RGB": 146 | raw_image = raw_image.convert("RGB") 147 | img_tensor = IMAGE_TRANSFORM(raw_image) 148 | except Exception as e: 149 | logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 150 | continue 151 | 152 | b_imgs.append((image_path, img_tensor)) 153 | if len(b_imgs) >= args.batch_size: 154 | run_batch(b_imgs) 155 | b_imgs.clear() 156 | if len(b_imgs) > 0: 157 | run_batch(b_imgs) 158 | 159 | logger.info("done!") 160 | 161 | 162 | def setup_parser() -> argparse.ArgumentParser: 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 165 | parser.add_argument( 166 | "--caption_weights", 167 | type=str, 168 | default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", 169 | help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)", 170 | ) 171 | parser.add_argument( 172 | "--caption_extention", 173 | type=str, 174 | default=None, 175 | help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)", 176 | ) 177 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 178 | parser.add_argument( 179 | "--beam_search", 180 | action="store_true", 181 | help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)", 182 | ) 183 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 184 | parser.add_argument( 185 | "--max_data_loader_n_workers", 186 | type=int, 187 | default=None, 188 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 189 | ) 190 | parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") 191 | parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") 192 | parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") 193 | parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") 194 | parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed") 195 | parser.add_argument("--debug", action="store_true", help="debug mode") 196 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 197 | 198 | return parser 199 | 200 | 201 | if __name__ == "__main__": 202 | parser = setup_parser() 203 | 204 | args = parser.parse_args() 205 | 206 | # スペルミスしていたオプションを復元する 207 | if args.caption_extention is not None: 208 | args.caption_extension = args.caption_extention 209 | 210 | main(args) 211 | -------------------------------------------------------------------------------- /finetune/make_captions_by_git.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | 5 | from pathlib import Path 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | import torch 10 | from library.device_utils import init_ipex, get_preferred_device 11 | init_ipex() 12 | 13 | from transformers import AutoProcessor, AutoModelForCausalLM 14 | from transformers.generation.utils import GenerationMixin 15 | 16 | import library.train_util as train_util 17 | from library.utils import setup_logging 18 | setup_logging() 19 | import logging 20 | logger = logging.getLogger(__name__) 21 | 22 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | 24 | PATTERN_REPLACE = [ 25 | re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), 26 | re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'), 27 | re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"), 28 | re.compile(r"with the number \d+ on (it|\w+ \w+)"), 29 | re.compile(r'with the words "'), 30 | re.compile(r"word \w+ on it"), 31 | re.compile(r"that says the word \w+ on it"), 32 | re.compile("that says'the word \"( on it)?"), 33 | ] 34 | 35 | # 誤検知しまくりの with the word xxxx を消す 36 | 37 | 38 | def remove_words(captions, debug): 39 | removed_caps = [] 40 | for caption in captions: 41 | cap = caption 42 | for pat in PATTERN_REPLACE: 43 | cap = pat.sub("", cap) 44 | if debug and cap != caption: 45 | logger.info(caption) 46 | logger.info(cap) 47 | removed_caps.append(cap) 48 | return removed_caps 49 | 50 | 51 | def collate_fn_remove_corrupted(batch): 52 | """Collate function that allows to remove corrupted examples in the 53 | dataloader. It expects that the dataloader returns 'None' when that occurs. 54 | The 'None's in the batch are removed. 55 | """ 56 | # Filter out all the Nones (corrupted examples) 57 | batch = list(filter(lambda x: x is not None, batch)) 58 | return batch 59 | 60 | 61 | def main(args): 62 | r""" 63 | transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト 64 | 65 | # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 66 | org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation 67 | curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように 68 | 69 | # input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す 70 | # ここより上で置き換えようとするとすごく大変 71 | def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs): 72 | input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs) 73 | if input_ids.size()[0] != curr_batch_size[0]: 74 | input_ids = input_ids.repeat(curr_batch_size[0], 1) 75 | return input_ids 76 | 77 | GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch 78 | """ 79 | 80 | logger.info(f"load images from {args.train_data_dir}") 81 | train_data_dir_path = Path(args.train_data_dir) 82 | image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 83 | logger.info(f"found {len(image_paths)} images.") 84 | 85 | # できればcacheに依存せず明示的にダウンロードしたい 86 | logger.info(f"loading GIT: {args.model_id}") 87 | git_processor = AutoProcessor.from_pretrained(args.model_id) 88 | git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE) 89 | logger.info("GIT loaded") 90 | 91 | # captioningする 92 | def run_batch(path_imgs): 93 | imgs = [im for _, im in path_imgs] 94 | 95 | # curr_batch_size[0] = len(path_imgs) 96 | inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 97 | generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) 98 | captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True) 99 | 100 | if args.remove_words: 101 | captions = remove_words(captions, args.debug) 102 | 103 | for (image_path, _), caption in zip(path_imgs, captions): 104 | with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: 105 | f.write(caption + "\n") 106 | if args.debug: 107 | logger.info(f"{image_path} {caption}") 108 | 109 | # 読み込みの高速化のためにDataLoaderを使うオプション 110 | if args.max_data_loader_n_workers is not None: 111 | dataset = train_util.ImageLoadingDataset(image_paths) 112 | data = torch.utils.data.DataLoader( 113 | dataset, 114 | batch_size=args.batch_size, 115 | shuffle=False, 116 | num_workers=args.max_data_loader_n_workers, 117 | collate_fn=collate_fn_remove_corrupted, 118 | drop_last=False, 119 | ) 120 | else: 121 | data = [[(None, ip)] for ip in image_paths] 122 | 123 | b_imgs = [] 124 | for data_entry in tqdm(data, smoothing=0.0): 125 | for data in data_entry: 126 | if data is None: 127 | continue 128 | 129 | image, image_path = data 130 | if image is None: 131 | try: 132 | image = Image.open(image_path) 133 | if image.mode != "RGB": 134 | image = image.convert("RGB") 135 | except Exception as e: 136 | logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") 137 | continue 138 | 139 | b_imgs.append((image_path, image)) 140 | if len(b_imgs) >= args.batch_size: 141 | run_batch(b_imgs) 142 | b_imgs.clear() 143 | 144 | if len(b_imgs) > 0: 145 | run_batch(b_imgs) 146 | 147 | logger.info("done!") 148 | 149 | 150 | def setup_parser() -> argparse.ArgumentParser: 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 153 | parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子") 154 | parser.add_argument( 155 | "--model_id", 156 | type=str, 157 | default="microsoft/git-large-textcaps", 158 | help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID", 159 | ) 160 | parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") 161 | parser.add_argument( 162 | "--max_data_loader_n_workers", 163 | type=int, 164 | default=None, 165 | help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", 166 | ) 167 | parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") 168 | parser.add_argument( 169 | "--remove_words", 170 | action="store_true", 171 | help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する", 172 | ) 173 | parser.add_argument("--debug", action="store_true", help="debug mode") 174 | parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する") 175 | 176 | return parser 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = setup_parser() 181 | 182 | args = parser.parse_args() 183 | main(args) 184 | -------------------------------------------------------------------------------- /finetune/merge_captions_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | from library.utils import setup_logging 9 | 10 | setup_logging() 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def main(args): 17 | assert not args.recursive or ( 18 | args.recursive and args.full_path 19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 20 | 21 | train_data_dir_path = Path(args.train_data_dir) 22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 23 | logger.info(f"found {len(image_paths)} images.") 24 | 25 | if args.in_json is None and Path(args.out_json).is_file(): 26 | args.in_json = args.out_json 27 | 28 | if args.in_json is not None: 29 | logger.info(f"loading existing metadata: {args.in_json}") 30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) 31 | logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます") 32 | else: 33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") 34 | metadata = {} 35 | 36 | logger.info("merge caption texts to metadata json.") 37 | for image_path in tqdm(image_paths): 38 | caption_path = image_path.with_suffix(args.caption_extension) 39 | caption = caption_path.read_text(encoding="utf-8").strip() 40 | 41 | if not os.path.exists(caption_path): 42 | caption_path = os.path.join(image_path, args.caption_extension) 43 | 44 | image_key = str(image_path) if args.full_path else image_path.stem 45 | if image_key not in metadata: 46 | metadata[image_key] = {} 47 | 48 | metadata[image_key]["caption"] = caption 49 | if args.debug: 50 | logger.info(f"{image_key} {caption}") 51 | 52 | # metadataを書き出して終わり 53 | logger.info(f"writing metadata: {args.out_json}") 54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") 55 | logger.info("done!") 56 | 57 | 58 | def setup_parser() -> argparse.ArgumentParser: 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 61 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 62 | parser.add_argument( 63 | "--in_json", 64 | type=str, 65 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", 66 | ) 67 | parser.add_argument( 68 | "--caption_extention", 69 | type=str, 70 | default=None, 71 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)", 72 | ) 73 | parser.add_argument( 74 | "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子" 75 | ) 76 | parser.add_argument( 77 | "--full_path", 78 | action="store_true", 79 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", 80 | ) 81 | parser.add_argument( 82 | "--recursive", 83 | action="store_true", 84 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", 85 | ) 86 | parser.add_argument("--debug", action="store_true", help="debug mode") 87 | 88 | return parser 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = setup_parser() 93 | 94 | args = parser.parse_args() 95 | 96 | # スペルミスしていたオプションを復元する 97 | if args.caption_extention is not None: 98 | args.caption_extension = args.caption_extention 99 | 100 | main(args) 101 | -------------------------------------------------------------------------------- /finetune/merge_dd_tags_to_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from typing import List 5 | from tqdm import tqdm 6 | import library.train_util as train_util 7 | import os 8 | from library.utils import setup_logging 9 | 10 | setup_logging() 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def main(args): 17 | assert not args.recursive or ( 18 | args.recursive and args.full_path 19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" 20 | 21 | train_data_dir_path = Path(args.train_data_dir) 22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) 23 | logger.info(f"found {len(image_paths)} images.") 24 | 25 | if args.in_json is None and Path(args.out_json).is_file(): 26 | args.in_json = args.out_json 27 | 28 | if args.in_json is not None: 29 | logger.info(f"loading existing metadata: {args.in_json}") 30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8")) 31 | logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます") 32 | else: 33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます") 34 | metadata = {} 35 | 36 | logger.info("merge tags to metadata json.") 37 | for image_path in tqdm(image_paths): 38 | tags_path = image_path.with_suffix(args.caption_extension) 39 | tags = tags_path.read_text(encoding="utf-8").strip() 40 | 41 | if not os.path.exists(tags_path): 42 | tags_path = os.path.join(image_path, args.caption_extension) 43 | 44 | image_key = str(image_path) if args.full_path else image_path.stem 45 | if image_key not in metadata: 46 | metadata[image_key] = {} 47 | 48 | metadata[image_key]["tags"] = tags 49 | if args.debug: 50 | logger.info(f"{image_key} {tags}") 51 | 52 | # metadataを書き出して終わり 53 | logger.info(f"writing metadata: {args.out_json}") 54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8") 55 | 56 | logger.info("done!") 57 | 58 | 59 | def setup_parser() -> argparse.ArgumentParser: 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") 62 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") 63 | parser.add_argument( 64 | "--in_json", 65 | type=str, 66 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)", 67 | ) 68 | parser.add_argument( 69 | "--full_path", 70 | action="store_true", 71 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", 72 | ) 73 | parser.add_argument( 74 | "--recursive", 75 | action="store_true", 76 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す", 77 | ) 78 | parser.add_argument( 79 | "--caption_extension", 80 | type=str, 81 | default=".txt", 82 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子", 83 | ) 84 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags") 85 | 86 | return parser 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = setup_parser() 91 | 92 | args = parser.parse_args() 93 | main(args) 94 | -------------------------------------------------------------------------------- /library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdbds/sd-scripts/a21b6a917e8ca2d0392f5861da2dddb510e389ad/library/__init__.py -------------------------------------------------------------------------------- /library/adafactor_fused.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from transformers import Adafactor 4 | 5 | @torch.no_grad() 6 | def adafactor_step_param(self, p, group): 7 | if p.grad is None: 8 | return 9 | grad = p.grad 10 | if grad.dtype in {torch.float16, torch.bfloat16}: 11 | grad = grad.float() 12 | if grad.is_sparse: 13 | raise RuntimeError("Adafactor does not support sparse gradients.") 14 | 15 | state = self.state[p] 16 | grad_shape = grad.shape 17 | 18 | factored, use_first_moment = Adafactor._get_options(group, grad_shape) 19 | # State Initialization 20 | if len(state) == 0: 21 | state["step"] = 0 22 | 23 | if use_first_moment: 24 | # Exponential moving average of gradient values 25 | state["exp_avg"] = torch.zeros_like(grad) 26 | if factored: 27 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) 28 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) 29 | else: 30 | state["exp_avg_sq"] = torch.zeros_like(grad) 31 | 32 | state["RMS"] = 0 33 | else: 34 | if use_first_moment: 35 | state["exp_avg"] = state["exp_avg"].to(grad) 36 | if factored: 37 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) 38 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) 39 | else: 40 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) 41 | 42 | p_data_fp32 = p 43 | if p.dtype in {torch.float16, torch.bfloat16}: 44 | p_data_fp32 = p_data_fp32.float() 45 | 46 | state["step"] += 1 47 | state["RMS"] = Adafactor._rms(p_data_fp32) 48 | lr = Adafactor._get_lr(group, state) 49 | 50 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) 51 | update = (grad ** 2) + group["eps"][0] 52 | if factored: 53 | exp_avg_sq_row = state["exp_avg_sq_row"] 54 | exp_avg_sq_col = state["exp_avg_sq_col"] 55 | 56 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) 57 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) 58 | 59 | # Approximation of exponential moving average of square of gradient 60 | update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) 61 | update.mul_(grad) 62 | else: 63 | exp_avg_sq = state["exp_avg_sq"] 64 | 65 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) 66 | update = exp_avg_sq.rsqrt().mul_(grad) 67 | 68 | update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) 69 | update.mul_(lr) 70 | 71 | if use_first_moment: 72 | exp_avg = state["exp_avg"] 73 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) 74 | update = exp_avg 75 | 76 | if group["weight_decay"] != 0: 77 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) 78 | 79 | p_data_fp32.add_(-update) 80 | 81 | if p.dtype in {torch.float16, torch.bfloat16}: 82 | p.copy_(p_data_fp32) 83 | 84 | 85 | @torch.no_grad() 86 | def adafactor_step(self, closure=None): 87 | """ 88 | Performs a single optimization step 89 | 90 | Arguments: 91 | closure (callable, optional): A closure that reevaluates the model 92 | and returns the loss. 93 | """ 94 | loss = None 95 | if closure is not None: 96 | loss = closure() 97 | 98 | for group in self.param_groups: 99 | for p in group["params"]: 100 | adafactor_step_param(self, p, group) 101 | 102 | return loss 103 | 104 | def patch_adafactor_fused(optimizer: Adafactor): 105 | optimizer.step_param = adafactor_step_param.__get__(optimizer) 106 | optimizer.step = adafactor_step.__get__(optimizer) 107 | -------------------------------------------------------------------------------- /library/attention_processors.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | from einops import rearrange 4 | import torch 5 | from diffusers.models.attention_processor import Attention 6 | 7 | 8 | # flash attention forwards and backwards 9 | 10 | # https://arxiv.org/abs/2205.14135 11 | 12 | EPSILON = 1e-6 13 | 14 | 15 | class FlashAttentionFunction(torch.autograd.function.Function): 16 | @staticmethod 17 | @torch.no_grad() 18 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 19 | """Algorithm 2 in the paper""" 20 | 21 | device = q.device 22 | dtype = q.dtype 23 | max_neg_value = -torch.finfo(q.dtype).max 24 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 25 | 26 | o = torch.zeros_like(q) 27 | all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) 28 | all_row_maxes = torch.full( 29 | (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device 30 | ) 31 | 32 | scale = q.shape[-1] ** -0.5 33 | 34 | if mask is None: 35 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) 36 | else: 37 | mask = rearrange(mask, "b n -> b 1 1 n") 38 | mask = mask.split(q_bucket_size, dim=-1) 39 | 40 | row_splits = zip( 41 | q.split(q_bucket_size, dim=-2), 42 | o.split(q_bucket_size, dim=-2), 43 | mask, 44 | all_row_sums.split(q_bucket_size, dim=-2), 45 | all_row_maxes.split(q_bucket_size, dim=-2), 46 | ) 47 | 48 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 49 | q_start_index = ind * q_bucket_size - qk_len_diff 50 | 51 | col_splits = zip( 52 | k.split(k_bucket_size, dim=-2), 53 | v.split(k_bucket_size, dim=-2), 54 | ) 55 | 56 | for k_ind, (kc, vc) in enumerate(col_splits): 57 | k_start_index = k_ind * k_bucket_size 58 | 59 | attn_weights = ( 60 | torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale 61 | ) 62 | 63 | if row_mask is not None: 64 | attn_weights.masked_fill_(~row_mask, max_neg_value) 65 | 66 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 67 | causal_mask = torch.ones( 68 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 69 | ).triu(q_start_index - k_start_index + 1) 70 | attn_weights.masked_fill_(causal_mask, max_neg_value) 71 | 72 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) 73 | attn_weights -= block_row_maxes 74 | exp_weights = torch.exp(attn_weights) 75 | 76 | if row_mask is not None: 77 | exp_weights.masked_fill_(~row_mask, 0.0) 78 | 79 | block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( 80 | min=EPSILON 81 | ) 82 | 83 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 84 | 85 | exp_values = torch.einsum( 86 | "... i j, ... j d -> ... i d", exp_weights, vc 87 | ) 88 | 89 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 90 | exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) 91 | 92 | new_row_sums = ( 93 | exp_row_max_diff * row_sums 94 | + exp_block_row_max_diff * block_row_sums 95 | ) 96 | 97 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( 98 | (exp_block_row_max_diff / new_row_sums) * exp_values 99 | ) 100 | 101 | row_maxes.copy_(new_row_maxes) 102 | row_sums.copy_(new_row_sums) 103 | 104 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 105 | ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) 106 | 107 | return o 108 | 109 | @staticmethod 110 | @torch.no_grad() 111 | def backward(ctx, do): 112 | """Algorithm 4 in the paper""" 113 | 114 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 115 | q, k, v, o, l, m = ctx.saved_tensors 116 | 117 | device = q.device 118 | 119 | max_neg_value = -torch.finfo(q.dtype).max 120 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 121 | 122 | dq = torch.zeros_like(q) 123 | dk = torch.zeros_like(k) 124 | dv = torch.zeros_like(v) 125 | 126 | row_splits = zip( 127 | q.split(q_bucket_size, dim=-2), 128 | o.split(q_bucket_size, dim=-2), 129 | do.split(q_bucket_size, dim=-2), 130 | mask, 131 | l.split(q_bucket_size, dim=-2), 132 | m.split(q_bucket_size, dim=-2), 133 | dq.split(q_bucket_size, dim=-2), 134 | ) 135 | 136 | for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): 137 | q_start_index = ind * q_bucket_size - qk_len_diff 138 | 139 | col_splits = zip( 140 | k.split(k_bucket_size, dim=-2), 141 | v.split(k_bucket_size, dim=-2), 142 | dk.split(k_bucket_size, dim=-2), 143 | dv.split(k_bucket_size, dim=-2), 144 | ) 145 | 146 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 147 | k_start_index = k_ind * k_bucket_size 148 | 149 | attn_weights = ( 150 | torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale 151 | ) 152 | 153 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 154 | causal_mask = torch.ones( 155 | (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device 156 | ).triu(q_start_index - k_start_index + 1) 157 | attn_weights.masked_fill_(causal_mask, max_neg_value) 158 | 159 | exp_attn_weights = torch.exp(attn_weights - mc) 160 | 161 | if row_mask is not None: 162 | exp_attn_weights.masked_fill_(~row_mask, 0.0) 163 | 164 | p = exp_attn_weights / lc 165 | 166 | dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) 167 | dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) 168 | 169 | D = (doc * oc).sum(dim=-1, keepdims=True) 170 | ds = p * scale * (dp - D) 171 | 172 | dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) 173 | dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) 174 | 175 | dqc.add_(dq_chunk) 176 | dkc.add_(dk_chunk) 177 | dvc.add_(dv_chunk) 178 | 179 | return dq, dk, dv, None, None, None, None 180 | 181 | 182 | class FlashAttnProcessor: 183 | def __call__( 184 | self, 185 | attn: Attention, 186 | hidden_states, 187 | encoder_hidden_states=None, 188 | attention_mask=None, 189 | ) -> Any: 190 | q_bucket_size = 512 191 | k_bucket_size = 1024 192 | 193 | h = attn.heads 194 | q = attn.to_q(hidden_states) 195 | 196 | encoder_hidden_states = ( 197 | encoder_hidden_states 198 | if encoder_hidden_states is not None 199 | else hidden_states 200 | ) 201 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) 202 | 203 | if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: 204 | context_k, context_v = attn.hypernetwork.forward( 205 | hidden_states, encoder_hidden_states 206 | ) 207 | context_k = context_k.to(hidden_states.dtype) 208 | context_v = context_v.to(hidden_states.dtype) 209 | else: 210 | context_k = encoder_hidden_states 211 | context_v = encoder_hidden_states 212 | 213 | k = attn.to_k(context_k) 214 | v = attn.to_v(context_v) 215 | del encoder_hidden_states, hidden_states 216 | 217 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 218 | 219 | out = FlashAttentionFunction.apply( 220 | q, k, v, attention_mask, False, q_bucket_size, k_bucket_size 221 | ) 222 | 223 | out = rearrange(out, "b h n d -> b n (h d)") 224 | 225 | out = attn.to_out[0](out) 226 | out = attn.to_out[1](out) 227 | return out 228 | -------------------------------------------------------------------------------- /library/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from accelerate import DeepSpeedPlugin, Accelerator 5 | 6 | from .utils import setup_logging 7 | 8 | setup_logging() 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def add_deepspeed_arguments(parser: argparse.ArgumentParser): 15 | # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed 16 | parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") 17 | parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") 18 | parser.add_argument( 19 | "--offload_optimizer_device", 20 | type=str, 21 | default=None, 22 | choices=[None, "cpu", "nvme"], 23 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", 24 | ) 25 | parser.add_argument( 26 | "--offload_optimizer_nvme_path", 27 | type=str, 28 | default=None, 29 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", 30 | ) 31 | parser.add_argument( 32 | "--offload_param_device", 33 | type=str, 34 | default=None, 35 | choices=[None, "cpu", "nvme"], 36 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", 37 | ) 38 | parser.add_argument( 39 | "--offload_param_nvme_path", 40 | type=str, 41 | default=None, 42 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", 43 | ) 44 | parser.add_argument( 45 | "--zero3_init_flag", 46 | action="store_true", 47 | help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." 48 | "Only applicable with ZeRO Stage-3.", 49 | ) 50 | parser.add_argument( 51 | "--zero3_save_16bit_model", 52 | action="store_true", 53 | help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", 54 | ) 55 | parser.add_argument( 56 | "--fp16_master_weights_and_gradients", 57 | action="store_true", 58 | help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.", 59 | ) 60 | 61 | 62 | def prepare_deepspeed_args(args: argparse.Namespace): 63 | if not args.deepspeed: 64 | return 65 | 66 | # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. 67 | args.max_data_loader_n_workers = 1 68 | 69 | 70 | def prepare_deepspeed_plugin(args: argparse.Namespace): 71 | if not args.deepspeed: 72 | return None 73 | 74 | try: 75 | import deepspeed 76 | except ImportError as e: 77 | logger.error( 78 | "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed" 79 | ) 80 | exit(1) 81 | 82 | deepspeed_plugin = DeepSpeedPlugin( 83 | zero_stage=args.zero_stage, 84 | gradient_accumulation_steps=args.gradient_accumulation_steps, 85 | gradient_clipping=args.max_grad_norm, 86 | offload_optimizer_device=args.offload_optimizer_device, 87 | offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, 88 | offload_param_device=args.offload_param_device, 89 | offload_param_nvme_path=args.offload_param_nvme_path, 90 | zero3_init_flag=args.zero3_init_flag, 91 | zero3_save_16bit_model=args.zero3_save_16bit_model, 92 | ) 93 | deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size 94 | deepspeed_plugin.deepspeed_config["train_batch_size"] = ( 95 | args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) 96 | ) 97 | deepspeed_plugin.set_mixed_precision(args.mixed_precision) 98 | if args.mixed_precision.lower() == "fp16": 99 | deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. 100 | if args.full_fp16 or args.fp16_master_weights_and_gradients: 101 | if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: 102 | deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True 103 | logger.info("[DeepSpeed] full fp16 enable.") 104 | else: 105 | logger.info( 106 | "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage." 107 | ) 108 | 109 | if args.offload_optimizer_device is not None: 110 | logger.info("[DeepSpeed] start to manually build cpu_adam.") 111 | deepspeed.ops.op_builder.CPUAdamBuilder().load() 112 | logger.info("[DeepSpeed] building cpu_adam done.") 113 | 114 | return deepspeed_plugin 115 | 116 | 117 | # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model. 118 | def prepare_deepspeed_model(args: argparse.Namespace, **models): 119 | # remove None from models 120 | models = {k: v for k, v in models.items() if v is not None} 121 | 122 | class DeepSpeedWrapper(torch.nn.Module): 123 | def __init__(self, **kw_models) -> None: 124 | super().__init__() 125 | self.models = torch.nn.ModuleDict() 126 | 127 | for key, model in kw_models.items(): 128 | if isinstance(model, list): 129 | model = torch.nn.ModuleList(model) 130 | assert isinstance( 131 | model, torch.nn.Module 132 | ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" 133 | self.models.update(torch.nn.ModuleDict({key: model})) 134 | 135 | def get_models(self): 136 | return self.models 137 | 138 | ds_model = DeepSpeedWrapper(**models) 139 | return ds_model 140 | -------------------------------------------------------------------------------- /library/device_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import gc 3 | 4 | import torch 5 | try: 6 | # intel gpu support for pytorch older than 2.5 7 | # ipex is not needed after pytorch 2.5 8 | import intel_extension_for_pytorch as ipex # noqa 9 | except Exception: 10 | pass 11 | 12 | 13 | try: 14 | HAS_CUDA = torch.cuda.is_available() 15 | except Exception: 16 | HAS_CUDA = False 17 | 18 | try: 19 | HAS_MPS = torch.backends.mps.is_available() 20 | except Exception: 21 | HAS_MPS = False 22 | 23 | try: 24 | HAS_XPU = torch.xpu.is_available() 25 | except Exception: 26 | HAS_XPU = False 27 | 28 | 29 | def clean_memory(): 30 | gc.collect() 31 | if HAS_CUDA: 32 | torch.cuda.empty_cache() 33 | if HAS_XPU: 34 | torch.xpu.empty_cache() 35 | if HAS_MPS: 36 | torch.mps.empty_cache() 37 | 38 | 39 | def clean_memory_on_device(device: torch.device): 40 | r""" 41 | Clean memory on the specified device, will be called from training scripts. 42 | """ 43 | gc.collect() 44 | 45 | # device may "cuda" or "cuda:0", so we need to check the type of device 46 | if device.type == "cuda": 47 | torch.cuda.empty_cache() 48 | if device.type == "xpu": 49 | torch.xpu.empty_cache() 50 | if device.type == "mps": 51 | torch.mps.empty_cache() 52 | 53 | 54 | @functools.lru_cache(maxsize=None) 55 | def get_preferred_device() -> torch.device: 56 | r""" 57 | Do not call this function from training scripts. Use accelerator.device instead. 58 | """ 59 | if HAS_CUDA: 60 | device = torch.device("cuda") 61 | elif HAS_XPU: 62 | device = torch.device("xpu") 63 | elif HAS_MPS: 64 | device = torch.device("mps") 65 | else: 66 | device = torch.device("cpu") 67 | print(f"get_preferred_device() -> {device}") 68 | return device 69 | 70 | 71 | def init_ipex(): 72 | """ 73 | Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. 74 | 75 | This function should run right after importing torch and before doing anything else. 76 | 77 | If xpu is not available, this function does nothing. 78 | """ 79 | try: 80 | if HAS_XPU: 81 | from library.ipex import ipex_init 82 | 83 | is_initialized, error_message = ipex_init() 84 | if not is_initialized: 85 | print("failed to initialize ipex:", error_message) 86 | else: 87 | return 88 | except Exception as e: 89 | print("failed to initialize ipex:", e) 90 | -------------------------------------------------------------------------------- /library/huggingface_util.py: -------------------------------------------------------------------------------- 1 | from typing import Union, BinaryIO 2 | from huggingface_hub import HfApi 3 | from pathlib import Path 4 | import argparse 5 | import os 6 | from library.utils import fire_in_thread 7 | from library.utils import setup_logging 8 | setup_logging() 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): 13 | api = HfApi( 14 | token=token, 15 | ) 16 | try: 17 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 18 | return True 19 | except: 20 | return False 21 | 22 | 23 | def upload( 24 | args: argparse.Namespace, 25 | src: Union[str, Path, bytes, BinaryIO], 26 | dest_suffix: str = "", 27 | force_sync_upload: bool = False, 28 | ): 29 | repo_id = args.huggingface_repo_id 30 | repo_type = args.huggingface_repo_type 31 | token = args.huggingface_token 32 | path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None 33 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" 34 | api = HfApi(token=token) 35 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): 36 | try: 37 | api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) 38 | except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので 39 | logger.error("===========================================") 40 | logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") 41 | logger.error("===========================================") 42 | 43 | is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) 44 | 45 | def uploader(): 46 | try: 47 | if is_folder: 48 | api.upload_folder( 49 | repo_id=repo_id, 50 | repo_type=repo_type, 51 | folder_path=src, 52 | path_in_repo=path_in_repo, 53 | ) 54 | else: 55 | api.upload_file( 56 | repo_id=repo_id, 57 | repo_type=repo_type, 58 | path_or_fileobj=src, 59 | path_in_repo=path_in_repo, 60 | ) 61 | except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので 62 | logger.error("===========================================") 63 | logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") 64 | logger.error("===========================================") 65 | 66 | if args.async_upload and not force_sync_upload: 67 | fire_in_thread(uploader) 68 | else: 69 | uploader() 70 | 71 | 72 | def list_dir( 73 | repo_id: str, 74 | subfolder: str, 75 | repo_type: str, 76 | revision: str = "main", 77 | token: str = None, 78 | ): 79 | api = HfApi( 80 | token=token, 81 | ) 82 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) 83 | file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] 84 | return file_list 85 | -------------------------------------------------------------------------------- /library/hypernetwork.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from diffusers.models.attention_processor import ( 4 | Attention, 5 | AttnProcessor2_0, 6 | SlicedAttnProcessor, 7 | XFormersAttnProcessor 8 | ) 9 | 10 | try: 11 | import xformers.ops 12 | except: 13 | xformers = None 14 | 15 | 16 | loaded_networks = [] 17 | 18 | 19 | def apply_single_hypernetwork( 20 | hypernetwork, hidden_states, encoder_hidden_states 21 | ): 22 | context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) 23 | return context_k, context_v 24 | 25 | 26 | def apply_hypernetworks(context_k, context_v, layer=None): 27 | if len(loaded_networks) == 0: 28 | return context_v, context_v 29 | for hypernetwork in loaded_networks: 30 | context_k, context_v = hypernetwork.forward(context_k, context_v) 31 | 32 | context_k = context_k.to(dtype=context_k.dtype) 33 | context_v = context_v.to(dtype=context_k.dtype) 34 | 35 | return context_k, context_v 36 | 37 | 38 | 39 | def xformers_forward( 40 | self: XFormersAttnProcessor, 41 | attn: Attention, 42 | hidden_states: torch.Tensor, 43 | encoder_hidden_states: torch.Tensor = None, 44 | attention_mask: torch.Tensor = None, 45 | ): 46 | batch_size, sequence_length, _ = ( 47 | hidden_states.shape 48 | if encoder_hidden_states is None 49 | else encoder_hidden_states.shape 50 | ) 51 | 52 | attention_mask = attn.prepare_attention_mask( 53 | attention_mask, sequence_length, batch_size 54 | ) 55 | 56 | query = attn.to_q(hidden_states) 57 | 58 | if encoder_hidden_states is None: 59 | encoder_hidden_states = hidden_states 60 | elif attn.norm_cross: 61 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 62 | 63 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 64 | 65 | key = attn.to_k(context_k) 66 | value = attn.to_v(context_v) 67 | 68 | query = attn.head_to_batch_dim(query).contiguous() 69 | key = attn.head_to_batch_dim(key).contiguous() 70 | value = attn.head_to_batch_dim(value).contiguous() 71 | 72 | hidden_states = xformers.ops.memory_efficient_attention( 73 | query, 74 | key, 75 | value, 76 | attn_bias=attention_mask, 77 | op=self.attention_op, 78 | scale=attn.scale, 79 | ) 80 | hidden_states = hidden_states.to(query.dtype) 81 | hidden_states = attn.batch_to_head_dim(hidden_states) 82 | 83 | # linear proj 84 | hidden_states = attn.to_out[0](hidden_states) 85 | # dropout 86 | hidden_states = attn.to_out[1](hidden_states) 87 | return hidden_states 88 | 89 | 90 | def sliced_attn_forward( 91 | self: SlicedAttnProcessor, 92 | attn: Attention, 93 | hidden_states: torch.Tensor, 94 | encoder_hidden_states: torch.Tensor = None, 95 | attention_mask: torch.Tensor = None, 96 | ): 97 | batch_size, sequence_length, _ = ( 98 | hidden_states.shape 99 | if encoder_hidden_states is None 100 | else encoder_hidden_states.shape 101 | ) 102 | attention_mask = attn.prepare_attention_mask( 103 | attention_mask, sequence_length, batch_size 104 | ) 105 | 106 | query = attn.to_q(hidden_states) 107 | dim = query.shape[-1] 108 | query = attn.head_to_batch_dim(query) 109 | 110 | if encoder_hidden_states is None: 111 | encoder_hidden_states = hidden_states 112 | elif attn.norm_cross: 113 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 114 | 115 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 116 | 117 | key = attn.to_k(context_k) 118 | value = attn.to_v(context_v) 119 | key = attn.head_to_batch_dim(key) 120 | value = attn.head_to_batch_dim(value) 121 | 122 | batch_size_attention, query_tokens, _ = query.shape 123 | hidden_states = torch.zeros( 124 | (batch_size_attention, query_tokens, dim // attn.heads), 125 | device=query.device, 126 | dtype=query.dtype, 127 | ) 128 | 129 | for i in range(batch_size_attention // self.slice_size): 130 | start_idx = i * self.slice_size 131 | end_idx = (i + 1) * self.slice_size 132 | 133 | query_slice = query[start_idx:end_idx] 134 | key_slice = key[start_idx:end_idx] 135 | attn_mask_slice = ( 136 | attention_mask[start_idx:end_idx] if attention_mask is not None else None 137 | ) 138 | 139 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 140 | 141 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 142 | 143 | hidden_states[start_idx:end_idx] = attn_slice 144 | 145 | hidden_states = attn.batch_to_head_dim(hidden_states) 146 | 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | 152 | return hidden_states 153 | 154 | 155 | def v2_0_forward( 156 | self: AttnProcessor2_0, 157 | attn: Attention, 158 | hidden_states, 159 | encoder_hidden_states=None, 160 | attention_mask=None, 161 | ): 162 | batch_size, sequence_length, _ = ( 163 | hidden_states.shape 164 | if encoder_hidden_states is None 165 | else encoder_hidden_states.shape 166 | ) 167 | inner_dim = hidden_states.shape[-1] 168 | 169 | if attention_mask is not None: 170 | attention_mask = attn.prepare_attention_mask( 171 | attention_mask, sequence_length, batch_size 172 | ) 173 | # scaled_dot_product_attention expects attention_mask shape to be 174 | # (batch, heads, source_length, target_length) 175 | attention_mask = attention_mask.view( 176 | batch_size, attn.heads, -1, attention_mask.shape[-1] 177 | ) 178 | 179 | query = attn.to_q(hidden_states) 180 | 181 | if encoder_hidden_states is None: 182 | encoder_hidden_states = hidden_states 183 | elif attn.norm_cross: 184 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 185 | 186 | context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) 187 | 188 | key = attn.to_k(context_k) 189 | value = attn.to_v(context_v) 190 | 191 | head_dim = inner_dim // attn.heads 192 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 193 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 194 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 195 | 196 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 197 | # TODO: add support for attn.scale when we move to Torch 2.1 198 | hidden_states = F.scaled_dot_product_attention( 199 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 200 | ) 201 | 202 | hidden_states = hidden_states.transpose(1, 2).reshape( 203 | batch_size, -1, attn.heads * head_dim 204 | ) 205 | hidden_states = hidden_states.to(query.dtype) 206 | 207 | # linear proj 208 | hidden_states = attn.to_out[0](hidden_states) 209 | # dropout 210 | hidden_states = attn.to_out[1](hidden_states) 211 | return hidden_states 212 | 213 | 214 | def replace_attentions_for_hypernetwork(): 215 | import diffusers.models.attention_processor 216 | 217 | diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( 218 | xformers_forward 219 | ) 220 | diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( 221 | sliced_attn_forward 222 | ) 223 | diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward 224 | -------------------------------------------------------------------------------- /library/ipex/attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from functools import cache, wraps 4 | 5 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 6 | 7 | # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers 8 | 9 | sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1)) 10 | attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5)) 11 | 12 | # Find something divisible with the input_tokens 13 | @cache 14 | def find_split_size(original_size, slice_block_size, slice_rate=2): 15 | split_size = original_size 16 | while True: 17 | if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0: 18 | return split_size 19 | split_size = split_size - 1 20 | if split_size <= 1: 21 | return 1 22 | return split_size 23 | 24 | 25 | # Find slice sizes for SDPA 26 | @cache 27 | def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3): 28 | batch_size, attn_heads, query_len, _ = query_shape 29 | _, _, key_len, _ = key_shape 30 | 31 | slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 32 | 33 | split_batch_size = batch_size 34 | split_head_size = attn_heads 35 | split_query_size = query_len 36 | 37 | do_batch_split = False 38 | do_head_split = False 39 | do_query_split = False 40 | 41 | if batch_size * slice_batch_size >= trigger_rate: 42 | do_batch_split = True 43 | split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate) 44 | 45 | if split_batch_size * slice_batch_size > slice_rate: 46 | slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024 47 | do_head_split = True 48 | split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate) 49 | 50 | if split_head_size * slice_head_size > slice_rate: 51 | slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024 52 | do_query_split = True 53 | split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate) 54 | 55 | return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size 56 | 57 | 58 | original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention 59 | @wraps(torch.nn.functional.scaled_dot_product_attention) 60 | def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): 61 | if query.device.type != "xpu": 62 | return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) 63 | is_unsqueezed = False 64 | if len(query.shape) == 3: 65 | query = query.unsqueeze(0) 66 | is_unsqueezed = True 67 | if len(key.shape) == 3: 68 | key = key.unsqueeze(0) 69 | if len(value.shape) == 3: 70 | value = value.unsqueeze(0) 71 | do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate) 72 | 73 | # Slice SDPA 74 | if do_batch_split: 75 | batch_size, attn_heads, query_len, _ = query.shape 76 | _, _, _, head_dim = value.shape 77 | hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype) 78 | if attn_mask is not None: 79 | attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2])) 80 | for ib in range(batch_size // split_batch_size): 81 | start_idx = ib * split_batch_size 82 | end_idx = (ib + 1) * split_batch_size 83 | if do_head_split: 84 | for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name 85 | start_idx_h = ih * split_head_size 86 | end_idx_h = (ih + 1) * split_head_size 87 | if do_query_split: 88 | for iq in range(query_len // split_query_size): # pylint: disable=invalid-name 89 | start_idx_q = iq * split_query_size 90 | end_idx_q = (iq + 1) * split_query_size 91 | hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention( 92 | query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :], 93 | key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], 94 | value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], 95 | attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask, 96 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 97 | ) 98 | else: 99 | hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention( 100 | query[start_idx:end_idx, start_idx_h:end_idx_h, :, :], 101 | key[start_idx:end_idx, start_idx_h:end_idx_h, :, :], 102 | value[start_idx:end_idx, start_idx_h:end_idx_h, :, :], 103 | attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask, 104 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 105 | ) 106 | else: 107 | hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention( 108 | query[start_idx:end_idx, :, :, :], 109 | key[start_idx:end_idx, :, :, :], 110 | value[start_idx:end_idx, :, :, :], 111 | attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask, 112 | dropout_p=dropout_p, is_causal=is_causal, **kwargs 113 | ) 114 | torch.xpu.synchronize(query.device) 115 | else: 116 | hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) 117 | if is_unsqueezed: 118 | hidden_states.squeeze(0) 119 | return hidden_states 120 | -------------------------------------------------------------------------------- /library/ipex/diffusers.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import torch 3 | import diffusers # pylint: disable=import-error 4 | 5 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 6 | 7 | 8 | # Diffusers FreeU 9 | original_fourier_filter = diffusers.utils.torch_utils.fourier_filter 10 | @wraps(diffusers.utils.torch_utils.fourier_filter) 11 | def fourier_filter(x_in, threshold, scale): 12 | return_dtype = x_in.dtype 13 | return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype) 14 | 15 | 16 | # fp64 error 17 | class FluxPosEmbed(torch.nn.Module): 18 | def __init__(self, theta: int, axes_dim): 19 | super().__init__() 20 | self.theta = theta 21 | self.axes_dim = axes_dim 22 | 23 | def forward(self, ids: torch.Tensor) -> torch.Tensor: 24 | n_axes = ids.shape[-1] 25 | cos_out = [] 26 | sin_out = [] 27 | pos = ids.float() 28 | for i in range(n_axes): 29 | cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed( 30 | self.axes_dim[i], 31 | pos[:, i], 32 | theta=self.theta, 33 | repeat_interleave_real=True, 34 | use_real=True, 35 | freqs_dtype=torch.float32, 36 | ) 37 | cos_out.append(cos) 38 | sin_out.append(sin) 39 | freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) 40 | freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) 41 | return freqs_cos, freqs_sin 42 | 43 | 44 | def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False): 45 | diffusers.utils.torch_utils.fourier_filter = fourier_filter 46 | if not device_supports_fp64: 47 | diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed 48 | -------------------------------------------------------------------------------- /networks/check_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from safetensors.torch import load_file 5 | from library.utils import setup_logging 6 | setup_logging() 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | def main(file): 11 | logger.info(f"loading: {file}") 12 | if os.path.splitext(file)[1] == ".safetensors": 13 | sd = load_file(file) 14 | else: 15 | sd = torch.load(file, map_location="cpu") 16 | 17 | values = [] 18 | 19 | keys = list(sd.keys()) 20 | for key in keys: 21 | if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: 22 | values.append((key, sd[key])) 23 | print(f"number of LoRA modules: {len(values)}") 24 | 25 | if args.show_all_keys: 26 | for key in [k for k in keys if k not in values]: 27 | values.append((key, sd[key])) 28 | print(f"number of all modules: {len(values)}") 29 | 30 | for key, value in values: 31 | value = value.to(torch.float32) 32 | print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") 33 | 34 | 35 | def setup_parser() -> argparse.ArgumentParser: 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") 38 | parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") 39 | 40 | return parser 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = setup_parser() 45 | 46 | args = parser.parse_args() 47 | 48 | main(args.file) 49 | -------------------------------------------------------------------------------- /networks/extract_lora_from_dylora.py: -------------------------------------------------------------------------------- 1 | # Convert LoRA to different rank approximation (should only be used to go to lower rank) 2 | # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo 4 | 5 | import argparse 6 | import math 7 | import os 8 | import torch 9 | from safetensors.torch import load_file, save_file, safe_open 10 | from tqdm import tqdm 11 | from library import train_util, model_util 12 | import numpy as np 13 | from library.utils import setup_logging 14 | setup_logging() 15 | import logging 16 | logger = logging.getLogger(__name__) 17 | 18 | def load_state_dict(file_name): 19 | if model_util.is_safetensors(file_name): 20 | sd = load_file(file_name) 21 | with safe_open(file_name, framework="pt") as f: 22 | metadata = f.metadata() 23 | else: 24 | sd = torch.load(file_name, map_location="cpu") 25 | metadata = None 26 | 27 | return sd, metadata 28 | 29 | 30 | def save_to_file(file_name, model, metadata): 31 | if model_util.is_safetensors(file_name): 32 | save_file(model, file_name, metadata) 33 | else: 34 | torch.save(model, file_name) 35 | 36 | 37 | def split_lora_model(lora_sd, unit): 38 | max_rank = 0 39 | 40 | # Extract loaded lora dim and alpha 41 | for key, value in lora_sd.items(): 42 | if "lora_down" in key: 43 | rank = value.size()[0] 44 | if rank > max_rank: 45 | max_rank = rank 46 | logger.info(f"Max rank: {max_rank}") 47 | 48 | rank = unit 49 | split_models = [] 50 | new_alpha = None 51 | while rank < max_rank: 52 | logger.info(f"Splitting rank {rank}") 53 | new_sd = {} 54 | for key, value in lora_sd.items(): 55 | if "lora_down" in key: 56 | new_sd[key] = value[:rank].contiguous() 57 | elif "lora_up" in key: 58 | new_sd[key] = value[:, :rank].contiguous() 59 | else: 60 | # なぜかscaleするとおかしくなる…… 61 | # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] 62 | # scale = math.sqrt(this_rank / rank) # rank is > unit 63 | # logger.info(key, value.size(), this_rank, rank, value, scale) 64 | # new_alpha = value * scale # always same 65 | # new_sd[key] = new_alpha 66 | new_sd[key] = value 67 | 68 | split_models.append((new_sd, rank, new_alpha)) 69 | rank += unit 70 | 71 | return max_rank, split_models 72 | 73 | 74 | def split(args): 75 | logger.info("loading Model...") 76 | lora_sd, metadata = load_state_dict(args.model) 77 | 78 | logger.info("Splitting Model...") 79 | original_rank, split_models = split_lora_model(lora_sd, args.unit) 80 | 81 | comment = metadata.get("ss_training_comment", "") 82 | for state_dict, new_rank, new_alpha in split_models: 83 | # update metadata 84 | if metadata is None: 85 | new_metadata = {} 86 | else: 87 | new_metadata = metadata.copy() 88 | 89 | new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}" 90 | new_metadata["ss_network_dim"] = str(new_rank) 91 | # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy()) 92 | 93 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 94 | metadata["sshs_model_hash"] = model_hash 95 | metadata["sshs_legacy_hash"] = legacy_hash 96 | 97 | filename, ext = os.path.splitext(args.save_to) 98 | model_file_name = filename + f"-{new_rank:04d}{ext}" 99 | 100 | logger.info(f"saving model to: {model_file_name}") 101 | save_to_file(model_file_name, state_dict, new_metadata) 102 | 103 | 104 | def setup_parser() -> argparse.ArgumentParser: 105 | parser = argparse.ArgumentParser() 106 | 107 | parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ") 108 | parser.add_argument( 109 | "--save_to", 110 | type=str, 111 | default=None, 112 | help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors", 113 | ) 114 | parser.add_argument( 115 | "--model", 116 | type=str, 117 | default=None, 118 | help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors", 119 | ) 120 | 121 | return parser 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = setup_parser() 126 | 127 | args = parser.parse_args() 128 | split(args) 129 | -------------------------------------------------------------------------------- /networks/lora_interrogator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from tqdm import tqdm 4 | from library import model_util 5 | import library.train_util as train_util 6 | import argparse 7 | from transformers import CLIPTokenizer 8 | 9 | import torch 10 | from library.device_utils import init_ipex, get_preferred_device 11 | init_ipex() 12 | 13 | import library.model_util as model_util 14 | import lora 15 | from library.utils import setup_logging 16 | setup_logging() 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | 20 | TOKENIZER_PATH = "openai/clip-vit-large-patch14" 21 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う 22 | 23 | DEVICE = get_preferred_device() 24 | 25 | 26 | def interrogate(args): 27 | weights_dtype = torch.float16 28 | 29 | # いろいろ準備する 30 | logger.info(f"loading SD model: {args.sd_model}") 31 | args.pretrained_model_name_or_path = args.sd_model 32 | args.vae = None 33 | text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) 34 | 35 | logger.info(f"loading LoRA: {args.model}") 36 | network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) 37 | 38 | # text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい 39 | has_te_weight = False 40 | for key in weights_sd.keys(): 41 | if 'lora_te' in key: 42 | has_te_weight = True 43 | break 44 | if not has_te_weight: 45 | logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") 46 | return 47 | del vae 48 | 49 | logger.info("loading tokenizer") 50 | if args.v2: 51 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") 52 | else: 53 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2) 54 | 55 | text_encoder.to(DEVICE, dtype=weights_dtype) 56 | text_encoder.eval() 57 | unet.to(DEVICE, dtype=weights_dtype) 58 | unet.eval() # U-Netは呼び出さないので不要だけど 59 | 60 | # トークンをひとつひとつ当たっていく 61 | token_id_start = 0 62 | token_id_end = max(tokenizer.all_special_ids) 63 | logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}") 64 | 65 | def get_all_embeddings(text_encoder): 66 | embs = [] 67 | with torch.no_grad(): 68 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)): 69 | batch = [] 70 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)): 71 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id] 72 | # tokens = [tid] # こちらは結果がいまひとつ 73 | batch.append(tokens) 74 | 75 | # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1] 76 | # clip skip対応 77 | batch = torch.tensor(batch).to(DEVICE) 78 | if args.clip_skip is None: 79 | encoder_hidden_states = text_encoder(batch)[0] 80 | else: 81 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True) 82 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] 83 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) 84 | encoder_hidden_states = encoder_hidden_states.to("cpu") 85 | 86 | embs.extend(encoder_hidden_states) 87 | return torch.stack(embs) 88 | 89 | logger.info("get original text encoder embeddings.") 90 | orig_embs = get_all_embeddings(text_encoder) 91 | 92 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) 93 | info = network.load_state_dict(weights_sd, strict=False) 94 | logger.info(f"Loading LoRA weights: {info}") 95 | 96 | network.to(DEVICE, dtype=weights_dtype) 97 | network.eval() 98 | 99 | del unet 100 | 101 | logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)") 102 | logger.info("get text encoder embeddings with lora.") 103 | lora_embs = get_all_embeddings(text_encoder) 104 | 105 | # 比べる:とりあえず単純に差分の絶対値で 106 | logger.info("comparing...") 107 | diffs = {} 108 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): 109 | diff = torch.mean(torch.abs(orig_emb - lora_emb)) 110 | # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない 111 | diff = float(diff.detach().to('cpu').numpy()) 112 | diffs[token_id_start + i] = diff 113 | 114 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1]) 115 | 116 | # 結果を表示する 117 | print("top 100:") 118 | for i, (token, diff) in enumerate(diffs_sorted[:100]): 119 | # if diff < 1e-6: 120 | # break 121 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token])) 122 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}") 123 | 124 | 125 | def setup_parser() -> argparse.ArgumentParser: 126 | parser = argparse.ArgumentParser() 127 | 128 | parser.add_argument("--v2", action='store_true', 129 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 130 | parser.add_argument("--sd_model", type=str, default=None, 131 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors") 132 | parser.add_argument("--model", type=str, default=None, 133 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors") 134 | parser.add_argument("--batch_size", type=int, default=16, 135 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ") 136 | parser.add_argument("--clip_skip", type=int, default=None, 137 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") 138 | 139 | return parser 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = setup_parser() 144 | 145 | args = parser.parse_args() 146 | interrogate(args) 147 | -------------------------------------------------------------------------------- /networks/merge_lora_old.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | import library.model_util as model_util 8 | import lora 9 | from library.utils import setup_logging 10 | setup_logging() 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | def load_state_dict(file_name, dtype): 15 | if os.path.splitext(file_name)[1] == '.safetensors': 16 | sd = load_file(file_name) 17 | else: 18 | sd = torch.load(file_name, map_location='cpu') 19 | for key in list(sd.keys()): 20 | if type(sd[key]) == torch.Tensor: 21 | sd[key] = sd[key].to(dtype) 22 | return sd 23 | 24 | 25 | def save_to_file(file_name, model, state_dict, dtype): 26 | if dtype is not None: 27 | for key in list(state_dict.keys()): 28 | if type(state_dict[key]) == torch.Tensor: 29 | state_dict[key] = state_dict[key].to(dtype) 30 | 31 | if os.path.splitext(file_name)[1] == '.safetensors': 32 | save_file(model, file_name) 33 | else: 34 | torch.save(model, file_name) 35 | 36 | 37 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 38 | text_encoder.to(merge_dtype) 39 | unet.to(merge_dtype) 40 | 41 | # create module map 42 | name_to_module = {} 43 | for i, root_module in enumerate([text_encoder, unet]): 44 | if i == 0: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 46 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 47 | else: 48 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 49 | target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE 50 | 51 | for name, module in root_module.named_modules(): 52 | if module.__class__.__name__ in target_replace_modules: 53 | for child_name, child_module in module.named_modules(): 54 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): 55 | lora_name = prefix + '.' + name + '.' + child_name 56 | lora_name = lora_name.replace('.', '_') 57 | name_to_module[lora_name] = child_module 58 | 59 | for model, ratio in zip(models, ratios): 60 | logger.info(f"loading: {model}") 61 | lora_sd = load_state_dict(model, merge_dtype) 62 | 63 | logger.info(f"merging...") 64 | for key in lora_sd.keys(): 65 | if "lora_down" in key: 66 | up_key = key.replace("lora_down", "lora_up") 67 | alpha_key = key[:key.index("lora_down")] + 'alpha' 68 | 69 | # find original module for this lora 70 | module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" 71 | if module_name not in name_to_module: 72 | logger.info(f"no module found for LoRA weight: {key}") 73 | continue 74 | module = name_to_module[module_name] 75 | # logger.info(f"apply {key} to {module}") 76 | 77 | down_weight = lora_sd[key] 78 | up_weight = lora_sd[up_key] 79 | 80 | dim = down_weight.size()[0] 81 | alpha = lora_sd.get(alpha_key, dim) 82 | scale = alpha / dim 83 | 84 | # W <- W + U * D 85 | weight = module.weight 86 | if len(weight.size()) == 2: 87 | # linear 88 | weight = weight + ratio * (up_weight @ down_weight) * scale 89 | else: 90 | # conv2d 91 | weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale 92 | 93 | module.weight = torch.nn.Parameter(weight) 94 | 95 | 96 | def merge_lora_models(models, ratios, merge_dtype): 97 | merged_sd = {} 98 | 99 | alpha = None 100 | dim = None 101 | for model, ratio in zip(models, ratios): 102 | logger.info(f"loading: {model}") 103 | lora_sd = load_state_dict(model, merge_dtype) 104 | 105 | logger.info(f"merging...") 106 | for key in lora_sd.keys(): 107 | if 'alpha' in key: 108 | if key in merged_sd: 109 | assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" 110 | else: 111 | alpha = lora_sd[key].detach().numpy() 112 | merged_sd[key] = lora_sd[key] 113 | else: 114 | if key in merged_sd: 115 | assert merged_sd[key].size() == lora_sd[key].size( 116 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 117 | merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio 118 | else: 119 | if "lora_down" in key: 120 | dim = lora_sd[key].size()[0] 121 | merged_sd[key] = lora_sd[key] * ratio 122 | 123 | logger.info(f"dim (rank): {dim}, alpha: {alpha}") 124 | if alpha is None: 125 | alpha = dim 126 | 127 | return merged_sd, dim, alpha 128 | 129 | 130 | def merge(args): 131 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 132 | 133 | def str_to_dtype(p): 134 | if p == 'float': 135 | return torch.float 136 | if p == 'fp16': 137 | return torch.float16 138 | if p == 'bf16': 139 | return torch.bfloat16 140 | return None 141 | 142 | merge_dtype = str_to_dtype(args.precision) 143 | save_dtype = str_to_dtype(args.save_precision) 144 | if save_dtype is None: 145 | save_dtype = merge_dtype 146 | 147 | if args.sd_model is not None: 148 | logger.info(f"loading SD model: {args.sd_model}") 149 | 150 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 151 | 152 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 153 | 154 | logger.info("") 155 | logger.info(f"saving SD model to: {args.save_to}") 156 | model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, 157 | args.sd_model, 0, 0, save_dtype, vae) 158 | else: 159 | state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) 160 | 161 | logger.info(f"") 162 | logger.info(f"saving model to: {args.save_to}") 163 | save_to_file(args.save_to, state_dict, state_dict, save_dtype) 164 | 165 | 166 | def setup_parser() -> argparse.ArgumentParser: 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--v2", action='store_true', 169 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') 170 | parser.add_argument("--save_precision", type=str, default=None, 171 | choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") 172 | parser.add_argument("--precision", type=str, default="float", 173 | choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") 174 | parser.add_argument("--sd_model", type=str, default=None, 175 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") 176 | parser.add_argument("--save_to", type=str, default=None, 177 | help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") 178 | parser.add_argument("--models", type=str, nargs='*', 179 | help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") 180 | parser.add_argument("--ratios", type=float, nargs='*', 181 | help="ratios for each model / それぞれのLoRAモデルの比率") 182 | 183 | return parser 184 | 185 | 186 | if __name__ == '__main__': 187 | parser = setup_parser() 188 | 189 | args = parser.parse_args() 190 | merge(args) 191 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.30.0 2 | transformers==4.44.0 3 | diffusers[torch]==0.25.0 4 | ftfy==6.1.1 5 | # albumentations==1.3.0 6 | opencv-python==4.8.1.78 7 | einops==0.7.0 8 | pytorch-lightning==1.9.0 9 | bitsandbytes==0.44.0 10 | prodigyopt==1.0 11 | lion-pytorch==0.0.6 12 | tensorboard 13 | safetensors==0.4.2 14 | # gradio==3.16.2 15 | altair==4.2.2 16 | easygui==0.98.3 17 | toml==0.10.2 18 | voluptuous==0.13.1 19 | huggingface-hub==0.24.5 20 | # for Image utils 21 | imagesize==1.4.1 22 | # for BLIP captioning 23 | # requests==2.28.2 24 | # timm==0.6.12 25 | # fairscale==0.4.13 26 | # for WD14 captioning (tensorflow) 27 | # tensorflow==2.10.1 28 | # for WD14 captioning (onnx) 29 | # onnx==1.15.0 30 | # onnxruntime-gpu==1.17.1 31 | # onnxruntime==1.17.1 32 | # for cuda 12.1(default 11.8) 33 | # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ 34 | 35 | # this is for onnx: 36 | # protobuf==3.20.3 37 | # open clip for SDXL 38 | # open-clip-torch==2.20.0 39 | # For logging 40 | rich==13.7.0 41 | # for kohya_ss library 42 | -e . 43 | -------------------------------------------------------------------------------- /sdxl_train_network.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from library.device_utils import init_ipex, clean_memory_on_device 5 | init_ipex() 6 | 7 | from library import sdxl_model_util, sdxl_train_util, train_util 8 | import train_network 9 | from library.utils import setup_logging 10 | setup_logging() 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | class SdxlNetworkTrainer(train_network.NetworkTrainer): 15 | def __init__(self): 16 | super().__init__() 17 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 18 | self.is_sdxl = True 19 | 20 | def assert_extra_args(self, args, train_dataset_group): 21 | sdxl_train_util.verify_sdxl_training_args(args) 22 | 23 | if args.cache_text_encoder_outputs: 24 | assert ( 25 | train_dataset_group.is_text_encoder_output_cacheable() 26 | ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" 27 | 28 | assert ( 29 | args.network_train_unet_only or not args.cache_text_encoder_outputs 30 | ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" 31 | 32 | train_dataset_group.verify_bucket_reso_steps(32) 33 | 34 | def load_target_model(self, args, weight_dtype, accelerator): 35 | ( 36 | load_stable_diffusion_format, 37 | text_encoder1, 38 | text_encoder2, 39 | vae, 40 | unet, 41 | logit_scale, 42 | ckpt_info, 43 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) 44 | 45 | self.load_stable_diffusion_format = load_stable_diffusion_format 46 | self.logit_scale = logit_scale 47 | self.ckpt_info = ckpt_info 48 | 49 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet 50 | 51 | def load_tokenizer(self, args): 52 | tokenizer = sdxl_train_util.load_tokenizers(args) 53 | return tokenizer 54 | 55 | def is_text_encoder_outputs_cached(self, args): 56 | return args.cache_text_encoder_outputs 57 | 58 | def cache_text_encoder_outputs_if_needed( 59 | self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype 60 | ): 61 | if args.cache_text_encoder_outputs: 62 | if not args.lowram: 63 | # メモリ消費を減らす 64 | logger.info("move vae and unet to cpu to save memory") 65 | org_vae_device = vae.device 66 | org_unet_device = unet.device 67 | vae.to("cpu") 68 | unet.to("cpu") 69 | clean_memory_on_device(accelerator.device) 70 | 71 | # When TE is not be trained, it will not be prepared so we need to use explicit autocast 72 | with accelerator.autocast(): 73 | dataset.cache_text_encoder_outputs( 74 | tokenizers, 75 | text_encoders, 76 | accelerator.device, 77 | weight_dtype, 78 | args.cache_text_encoder_outputs_to_disk, 79 | accelerator.is_main_process, 80 | ) 81 | 82 | text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU 83 | text_encoders[1].to("cpu", dtype=torch.float32) 84 | clean_memory_on_device(accelerator.device) 85 | 86 | if not args.lowram: 87 | logger.info("move vae and unet back to original device") 88 | vae.to(org_vae_device) 89 | unet.to(org_unet_device) 90 | else: 91 | # Text Encoderから毎回出力を取得するので、GPUに乗せておく 92 | text_encoders[0].to(accelerator.device, dtype=weight_dtype) 93 | text_encoders[1].to(accelerator.device, dtype=weight_dtype) 94 | 95 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 96 | if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: 97 | input_ids1 = batch["input_ids"] 98 | input_ids2 = batch["input_ids2"] 99 | with torch.enable_grad(): 100 | # Get the text embedding for conditioning 101 | # TODO support weighted captions 102 | # if args.weighted_captions: 103 | # encoder_hidden_states = get_weighted_text_embeddings( 104 | # tokenizer, 105 | # text_encoder, 106 | # batch["captions"], 107 | # accelerator.device, 108 | # args.max_token_length // 75 if args.max_token_length else 1, 109 | # clip_skip=args.clip_skip, 110 | # ) 111 | # else: 112 | input_ids1 = input_ids1.to(accelerator.device) 113 | input_ids2 = input_ids2.to(accelerator.device) 114 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 115 | args.max_token_length, 116 | input_ids1, 117 | input_ids2, 118 | tokenizers[0], 119 | tokenizers[1], 120 | text_encoders[0], 121 | text_encoders[1], 122 | None if not args.full_fp16 else weight_dtype, 123 | accelerator=accelerator, 124 | ) 125 | else: 126 | encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) 127 | encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) 128 | pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) 129 | 130 | # # verify that the text encoder outputs are correct 131 | # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( 132 | # args.max_token_length, 133 | # batch["input_ids"].to(text_encoders[0].device), 134 | # batch["input_ids2"].to(text_encoders[0].device), 135 | # tokenizers[0], 136 | # tokenizers[1], 137 | # text_encoders[0], 138 | # text_encoders[1], 139 | # None if not args.full_fp16 else weight_dtype, 140 | # ) 141 | # b_size = encoder_hidden_states1.shape[0] 142 | # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 143 | # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 144 | # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 145 | # logger.info("text encoder outputs verified") 146 | 147 | return encoder_hidden_states1, encoder_hidden_states2, pool2 148 | 149 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 150 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 151 | 152 | # get size embeddings 153 | orig_size = batch["original_sizes_hw"] 154 | crop_size = batch["crop_top_lefts"] 155 | target_size = batch["target_sizes_hw"] 156 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 157 | 158 | # concat embeddings 159 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds 160 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 161 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 162 | 163 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 164 | return noise_pred 165 | 166 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): 167 | sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) 168 | 169 | 170 | def setup_parser() -> argparse.ArgumentParser: 171 | parser = train_network.setup_parser() 172 | sdxl_train_util.add_sdxl_training_arguments(parser) 173 | return parser 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = setup_parser() 178 | 179 | args = parser.parse_args() 180 | train_util.verify_command_line_training_args(args) 181 | args = train_util.read_config_from_file(args, parser) 182 | 183 | trainer = SdxlNetworkTrainer() 184 | trainer.train(args) 185 | -------------------------------------------------------------------------------- /sdxl_train_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import regex 5 | 6 | import torch 7 | from library.device_utils import init_ipex 8 | init_ipex() 9 | 10 | from library import sdxl_model_util, sdxl_train_util, train_util 11 | 12 | import train_textual_inversion 13 | 14 | 15 | class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer): 16 | def __init__(self): 17 | super().__init__() 18 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR 19 | self.is_sdxl = True 20 | 21 | def assert_extra_args(self, args, train_dataset_group): 22 | super().assert_extra_args(args, train_dataset_group) 23 | sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) 24 | 25 | train_dataset_group.verify_bucket_reso_steps(32) 26 | 27 | def load_target_model(self, args, weight_dtype, accelerator): 28 | ( 29 | load_stable_diffusion_format, 30 | text_encoder1, 31 | text_encoder2, 32 | vae, 33 | unet, 34 | logit_scale, 35 | ckpt_info, 36 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) 37 | 38 | self.load_stable_diffusion_format = load_stable_diffusion_format 39 | self.logit_scale = logit_scale 40 | self.ckpt_info = ckpt_info 41 | 42 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet 43 | 44 | def load_tokenizer(self, args): 45 | tokenizer = sdxl_train_util.load_tokenizers(args) 46 | return tokenizer 47 | 48 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): 49 | input_ids1 = batch["input_ids"] 50 | input_ids2 = batch["input_ids2"] 51 | with torch.enable_grad(): 52 | input_ids1 = input_ids1.to(accelerator.device) 53 | input_ids2 = input_ids2.to(accelerator.device) 54 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( 55 | args.max_token_length, 56 | input_ids1, 57 | input_ids2, 58 | tokenizers[0], 59 | tokenizers[1], 60 | text_encoders[0], 61 | text_encoders[1], 62 | None if not args.full_fp16 else weight_dtype, 63 | accelerator=accelerator, 64 | ) 65 | return encoder_hidden_states1, encoder_hidden_states2, pool2 66 | 67 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): 68 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype 69 | 70 | # get size embeddings 71 | orig_size = batch["original_sizes_hw"] 72 | crop_size = batch["crop_top_lefts"] 73 | target_size = batch["target_sizes_hw"] 74 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) 75 | 76 | # concat embeddings 77 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds 78 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) 79 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) 80 | 81 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) 82 | return noise_pred 83 | 84 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): 85 | sdxl_train_util.sample_images( 86 | accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement 87 | ) 88 | 89 | def save_weights(self, file, updated_embs, save_dtype, metadata): 90 | state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]} 91 | 92 | if save_dtype is not None: 93 | for key in list(state_dict.keys()): 94 | v = state_dict[key] 95 | v = v.detach().clone().to("cpu").to(save_dtype) 96 | state_dict[key] = v 97 | 98 | if os.path.splitext(file)[1] == ".safetensors": 99 | from safetensors.torch import save_file 100 | 101 | save_file(state_dict, file, metadata) 102 | else: 103 | torch.save(state_dict, file) 104 | 105 | def load_weights(self, file): 106 | if os.path.splitext(file)[1] == ".safetensors": 107 | from safetensors.torch import load_file 108 | 109 | data = load_file(file) 110 | else: 111 | data = torch.load(file, map_location="cpu") 112 | 113 | emb_l = data.get("clip_l", None) # ViT-L text encoder 1 114 | emb_g = data.get("clip_g", None) # BiG-G text encoder 2 115 | 116 | assert ( 117 | emb_l is not None or emb_g is not None 118 | ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}" 119 | 120 | return [emb_l, emb_g] 121 | 122 | 123 | def setup_parser() -> argparse.ArgumentParser: 124 | parser = train_textual_inversion.setup_parser() 125 | # don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching 126 | # sdxl_train_util.add_sdxl_training_arguments(parser) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = setup_parser() 132 | 133 | args = parser.parse_args() 134 | train_util.verify_command_line_training_args(args) 135 | args = train_util.read_config_from_file(args, parser) 136 | 137 | trainer = SdxlTextualInversionTrainer() 138 | trainer.train(args) 139 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name = "library", packages = find_packages()) -------------------------------------------------------------------------------- /tools/cache_latents.py: -------------------------------------------------------------------------------- 1 | # latentsのdiskへの事前キャッシュを行う / cache latents to disk 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | from library.utils import setup_logging, add_logging_arguments 20 | setup_logging() 21 | import logging 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def cache_to_disk(args: argparse.Namespace) -> None: 27 | setup_logging(args, reset=True) 28 | train_util.prepare_dataset_args(args, True) 29 | 30 | # check cache latents arg 31 | assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" 32 | 33 | use_dreambooth_method = args.in_json is None 34 | 35 | if args.seed is not None: 36 | set_seed(args.seed) # 乱数系列を初期化する 37 | 38 | # tokenizerを準備する:datasetを動かすために必要 39 | if args.sdxl: 40 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 41 | tokenizers = [tokenizer1, tokenizer2] 42 | else: 43 | tokenizer = train_util.load_tokenizer(args) 44 | tokenizers = [tokenizer] 45 | 46 | # データセットを準備する 47 | if args.dataset_class is None: 48 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 49 | if args.dataset_config is not None: 50 | logger.info(f"Load dataset config from {args.dataset_config}") 51 | user_config = config_util.load_user_config(args.dataset_config) 52 | ignored = ["train_data_dir", "in_json"] 53 | if any(getattr(args, attr) is not None for attr in ignored): 54 | logger.warning( 55 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 56 | ", ".join(ignored) 57 | ) 58 | ) 59 | else: 60 | if use_dreambooth_method: 61 | logger.info("Using DreamBooth method.") 62 | user_config = { 63 | "datasets": [ 64 | { 65 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 66 | args.train_data_dir, args.reg_data_dir 67 | ) 68 | } 69 | ] 70 | } 71 | else: 72 | logger.info("Training with captions.") 73 | user_config = { 74 | "datasets": [ 75 | { 76 | "subsets": [ 77 | { 78 | "image_dir": args.train_data_dir, 79 | "metadata_file": args.in_json, 80 | } 81 | ] 82 | } 83 | ] 84 | } 85 | 86 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 87 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 88 | else: 89 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 90 | 91 | # datasetのcache_latentsを呼ばなければ、生の画像が返る 92 | 93 | current_epoch = Value("i", 0) 94 | current_step = Value("i", 0) 95 | ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None 96 | collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) 97 | 98 | # acceleratorを準備する 99 | logger.info("prepare accelerator") 100 | args.deepspeed = False 101 | accelerator = train_util.prepare_accelerator(args) 102 | 103 | # mixed precisionに対応した型を用意しておき適宜castする 104 | weight_dtype, _ = train_util.prepare_dtype(args) 105 | vae_dtype = torch.float32 if args.no_half_vae else weight_dtype 106 | 107 | # モデルを読み込む 108 | logger.info("load model") 109 | if args.sdxl: 110 | (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 111 | else: 112 | _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) 113 | 114 | if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える 115 | vae.set_use_memory_efficient_attention_xformers(args.xformers) 116 | vae.to(accelerator.device, dtype=vae_dtype) 117 | vae.requires_grad_(False) 118 | vae.eval() 119 | 120 | # dataloaderを準備する 121 | train_dataset_group.set_caching_mode("latents") 122 | 123 | # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 124 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers 125 | 126 | train_dataloader = torch.utils.data.DataLoader( 127 | train_dataset_group, 128 | batch_size=1, 129 | shuffle=True, 130 | collate_fn=collator, 131 | num_workers=n_workers, 132 | persistent_workers=args.persistent_data_loader_workers, 133 | ) 134 | 135 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 136 | train_dataloader = accelerator.prepare(train_dataloader) 137 | 138 | # データ取得のためのループ 139 | for batch in tqdm(train_dataloader): 140 | b_size = len(batch["images"]) 141 | vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size 142 | flip_aug = batch["flip_aug"] 143 | alpha_mask = batch["alpha_mask"] 144 | random_crop = batch["random_crop"] 145 | bucket_reso = batch["bucket_reso"] 146 | 147 | # バッチを分割して処理する 148 | for i in range(0, b_size, vae_batch_size): 149 | images = batch["images"][i : i + vae_batch_size] 150 | absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] 151 | resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] 152 | 153 | image_infos = [] 154 | for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): 155 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 156 | image_info.image = image 157 | image_info.bucket_reso = bucket_reso 158 | image_info.resized_size = resized_size 159 | image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" 160 | 161 | if args.skip_existing: 162 | if train_util.is_disk_cached_latents_is_expected( 163 | image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask 164 | ): 165 | logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") 166 | continue 167 | 168 | image_infos.append(image_info) 169 | 170 | if len(image_infos) > 0: 171 | train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) 172 | 173 | accelerator.wait_for_everyone() 174 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 175 | 176 | 177 | def setup_parser() -> argparse.ArgumentParser: 178 | parser = argparse.ArgumentParser() 179 | 180 | add_logging_arguments(parser) 181 | train_util.add_sd_models_arguments(parser) 182 | train_util.add_training_arguments(parser, True) 183 | train_util.add_dataset_arguments(parser, True, True, True) 184 | config_util.add_config_arguments(parser) 185 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 186 | parser.add_argument( 187 | "--no_half_vae", 188 | action="store_true", 189 | help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", 190 | ) 191 | parser.add_argument( 192 | "--skip_existing", 193 | action="store_true", 194 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 195 | ) 196 | return parser 197 | 198 | 199 | if __name__ == "__main__": 200 | parser = setup_parser() 201 | 202 | args = parser.parse_args() 203 | args = train_util.read_config_from_file(args, parser) 204 | 205 | cache_to_disk(args) 206 | -------------------------------------------------------------------------------- /tools/cache_text_encoder_outputs.py: -------------------------------------------------------------------------------- 1 | # text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance 2 | 3 | import argparse 4 | import math 5 | from multiprocessing import Value 6 | import os 7 | 8 | from accelerate.utils import set_seed 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from library import config_util 13 | from library import train_util 14 | from library import sdxl_train_util 15 | from library.config_util import ( 16 | ConfigSanitizer, 17 | BlueprintGenerator, 18 | ) 19 | from library.utils import setup_logging, add_logging_arguments 20 | setup_logging() 21 | import logging 22 | logger = logging.getLogger(__name__) 23 | 24 | def cache_to_disk(args: argparse.Namespace) -> None: 25 | setup_logging(args, reset=True) 26 | train_util.prepare_dataset_args(args, True) 27 | 28 | # check cache arg 29 | assert ( 30 | args.cache_text_encoder_outputs_to_disk 31 | ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります" 32 | 33 | # できるだけ準備はしておくが今のところSDXLのみしか動かない 34 | assert ( 35 | args.sdxl 36 | ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です" 37 | 38 | use_dreambooth_method = args.in_json is None 39 | 40 | if args.seed is not None: 41 | set_seed(args.seed) # 乱数系列を初期化する 42 | 43 | # tokenizerを準備する:datasetを動かすために必要 44 | if args.sdxl: 45 | tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) 46 | tokenizers = [tokenizer1, tokenizer2] 47 | else: 48 | tokenizer = train_util.load_tokenizer(args) 49 | tokenizers = [tokenizer] 50 | 51 | # データセットを準備する 52 | if args.dataset_class is None: 53 | blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) 54 | if args.dataset_config is not None: 55 | logger.info(f"Load dataset config from {args.dataset_config}") 56 | user_config = config_util.load_user_config(args.dataset_config) 57 | ignored = ["train_data_dir", "in_json"] 58 | if any(getattr(args, attr) is not None for attr in ignored): 59 | logger.warning( 60 | "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( 61 | ", ".join(ignored) 62 | ) 63 | ) 64 | else: 65 | if use_dreambooth_method: 66 | logger.info("Using DreamBooth method.") 67 | user_config = { 68 | "datasets": [ 69 | { 70 | "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( 71 | args.train_data_dir, args.reg_data_dir 72 | ) 73 | } 74 | ] 75 | } 76 | else: 77 | logger.info("Training with captions.") 78 | user_config = { 79 | "datasets": [ 80 | { 81 | "subsets": [ 82 | { 83 | "image_dir": args.train_data_dir, 84 | "metadata_file": args.in_json, 85 | } 86 | ] 87 | } 88 | ] 89 | } 90 | 91 | blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) 92 | train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) 93 | else: 94 | train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) 95 | 96 | current_epoch = Value("i", 0) 97 | current_step = Value("i", 0) 98 | ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None 99 | collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) 100 | 101 | # acceleratorを準備する 102 | logger.info("prepare accelerator") 103 | args.deepspeed = False 104 | accelerator = train_util.prepare_accelerator(args) 105 | 106 | # mixed precisionに対応した型を用意しておき適宜castする 107 | weight_dtype, _ = train_util.prepare_dtype(args) 108 | 109 | # モデルを読み込む 110 | logger.info("load model") 111 | if args.sdxl: 112 | (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) 113 | text_encoders = [text_encoder1, text_encoder2] 114 | else: 115 | text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) 116 | text_encoders = [text_encoder1] 117 | 118 | for text_encoder in text_encoders: 119 | text_encoder.to(accelerator.device, dtype=weight_dtype) 120 | text_encoder.requires_grad_(False) 121 | text_encoder.eval() 122 | 123 | # dataloaderを準備する 124 | train_dataset_group.set_caching_mode("text") 125 | 126 | # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 127 | n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers 128 | 129 | train_dataloader = torch.utils.data.DataLoader( 130 | train_dataset_group, 131 | batch_size=1, 132 | shuffle=True, 133 | collate_fn=collator, 134 | num_workers=n_workers, 135 | persistent_workers=args.persistent_data_loader_workers, 136 | ) 137 | 138 | # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず 139 | train_dataloader = accelerator.prepare(train_dataloader) 140 | 141 | # データ取得のためのループ 142 | for batch in tqdm(train_dataloader): 143 | absolute_paths = batch["absolute_paths"] 144 | input_ids1_list = batch["input_ids1_list"] 145 | input_ids2_list = batch["input_ids2_list"] 146 | 147 | image_infos = [] 148 | for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list): 149 | image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) 150 | image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX 151 | image_info 152 | 153 | if args.skip_existing: 154 | if os.path.exists(image_info.text_encoder_outputs_npz): 155 | logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") 156 | continue 157 | 158 | image_info.input_ids1 = input_ids1 159 | image_info.input_ids2 = input_ids2 160 | image_infos.append(image_info) 161 | 162 | if len(image_infos) > 0: 163 | b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) 164 | b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos]) 165 | train_util.cache_batch_text_encoder_outputs( 166 | image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype 167 | ) 168 | 169 | accelerator.wait_for_everyone() 170 | accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") 171 | 172 | 173 | def setup_parser() -> argparse.ArgumentParser: 174 | parser = argparse.ArgumentParser() 175 | 176 | add_logging_arguments(parser) 177 | train_util.add_sd_models_arguments(parser) 178 | train_util.add_training_arguments(parser, True) 179 | train_util.add_dataset_arguments(parser, True, True, True) 180 | config_util.add_config_arguments(parser) 181 | sdxl_train_util.add_sdxl_training_arguments(parser) 182 | parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する") 183 | parser.add_argument( 184 | "--skip_existing", 185 | action="store_true", 186 | help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", 187 | ) 188 | return parser 189 | 190 | 191 | if __name__ == "__main__": 192 | parser = setup_parser() 193 | 194 | args = parser.parse_args() 195 | args = train_util.read_config_from_file(args, parser) 196 | 197 | cache_to_disk(args) 198 | -------------------------------------------------------------------------------- /tools/canny.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | 4 | import logging 5 | from library.utils import setup_logging 6 | setup_logging() 7 | logger = logging.getLogger(__name__) 8 | 9 | def canny(args): 10 | img = cv2.imread(args.input) 11 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 12 | 13 | canny_img = cv2.Canny(img, args.thres1, args.thres2) 14 | # canny_img = 255 - canny_img 15 | 16 | cv2.imwrite(args.output, canny_img) 17 | logger.info("done!") 18 | 19 | 20 | def setup_parser() -> argparse.ArgumentParser: 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--input", type=str, default=None, help="input path") 23 | parser.add_argument("--output", type=str, default=None, help="output path") 24 | parser.add_argument("--thres1", type=int, default=32, help="thres1") 25 | parser.add_argument("--thres2", type=int, default=224, help="thres2") 26 | 27 | return parser 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = setup_parser() 32 | 33 | args = parser.parse_args() 34 | canny(args) 35 | -------------------------------------------------------------------------------- /tools/convert_diffusers20_original_sd.py: -------------------------------------------------------------------------------- 1 | # convert Diffusers v1.x/v2.0 model to original Stable Diffusion 2 | 3 | import argparse 4 | import os 5 | import torch 6 | from diffusers import StableDiffusionPipeline 7 | 8 | import library.model_util as model_util 9 | from library.utils import setup_logging 10 | setup_logging() 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | def convert(args): 15 | # 引数を確認する 16 | load_dtype = torch.float16 if args.fp16 else None 17 | 18 | save_dtype = None 19 | if args.fp16 or args.save_precision_as == "fp16": 20 | save_dtype = torch.float16 21 | elif args.bf16 or args.save_precision_as == "bf16": 22 | save_dtype = torch.bfloat16 23 | elif args.float or args.save_precision_as == "float": 24 | save_dtype = torch.float 25 | 26 | is_load_ckpt = os.path.isfile(args.model_to_load) 27 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 28 | 29 | assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" 30 | # assert ( 31 | # is_save_ckpt or args.reference_model is not None 32 | # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" 33 | 34 | # モデルを読み込む 35 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 36 | logger.info(f"loading {msg}: {args.model_to_load}") 37 | 38 | if is_load_ckpt: 39 | v2_model = args.v2 40 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( 41 | v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection 42 | ) 43 | else: 44 | pipe = StableDiffusionPipeline.from_pretrained( 45 | args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant 46 | ) 47 | text_encoder = pipe.text_encoder 48 | vae = pipe.vae 49 | unet = pipe.unet 50 | 51 | if args.v1 == args.v2: 52 | # 自動判定する 53 | v2_model = unet.config.cross_attention_dim == 1024 54 | logger.info("checking model version: model is " + ("v2" if v2_model else "v1")) 55 | else: 56 | v2_model = not args.v1 57 | 58 | # 変換して保存する 59 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 60 | logger.info(f"converting and saving as {msg}: {args.model_to_save}") 61 | 62 | if is_save_ckpt: 63 | original_model = args.model_to_load if is_load_ckpt else None 64 | key_count = model_util.save_stable_diffusion_checkpoint( 65 | v2_model, 66 | args.model_to_save, 67 | text_encoder, 68 | unet, 69 | original_model, 70 | args.epoch, 71 | args.global_step, 72 | None if args.metadata is None else eval(args.metadata), 73 | save_dtype=save_dtype, 74 | vae=vae, 75 | ) 76 | logger.info(f"model saved. total converted state_dict keys: {key_count}") 77 | else: 78 | logger.info( 79 | f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" 80 | ) 81 | model_util.save_diffusers_checkpoint( 82 | v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors 83 | ) 84 | logger.info("model saved.") 85 | 86 | 87 | def setup_parser() -> argparse.ArgumentParser: 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む" 91 | ) 92 | parser.add_argument( 93 | "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" 94 | ) 95 | parser.add_argument( 96 | "--unet_use_linear_projection", 97 | action="store_true", 98 | help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)", 99 | ) 100 | parser.add_argument( 101 | "--fp16", 102 | action="store_true", 103 | help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)", 104 | ) 105 | parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)") 106 | parser.add_argument( 107 | "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)" 108 | ) 109 | parser.add_argument( 110 | "--save_precision_as", 111 | type=str, 112 | default="no", 113 | choices=["fp16", "bf16", "float"], 114 | help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください", 115 | ) 116 | parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値") 117 | parser.add_argument( 118 | "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" 119 | ) 120 | parser.add_argument( 121 | "--metadata", 122 | type=str, 123 | default=None, 124 | help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', 125 | ) 126 | parser.add_argument( 127 | "--variant", 128 | type=str, 129 | default=None, 130 | help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16", 131 | ) 132 | parser.add_argument( 133 | "--reference_model", 134 | type=str, 135 | default=None, 136 | help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`", 137 | ) 138 | parser.add_argument( 139 | "--use_safetensors", 140 | action="store_true", 141 | help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)", 142 | ) 143 | 144 | parser.add_argument( 145 | "model_to_load", 146 | type=str, 147 | default=None, 148 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ", 149 | ) 150 | parser.add_argument( 151 | "model_to_save", 152 | type=str, 153 | default=None, 154 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存", 155 | ) 156 | return parser 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = setup_parser() 161 | 162 | args = parser.parse_args() 163 | convert(args) 164 | -------------------------------------------------------------------------------- /tools/merge_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from safetensors import safe_open 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | from library.utils import setup_logging 9 | setup_logging() 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | def is_unet_key(key): 14 | # VAE or TextEncoder, the last one is for SDXL 15 | return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key) 16 | 17 | 18 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 19 | ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), 20 | ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), 21 | ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), 22 | ] 23 | 24 | 25 | # support for models with different text encoder keys 26 | def replace_text_encoder_key(key): 27 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 28 | if key.startswith(rep_from): 29 | return True, rep_to + key[len(rep_from) :] 30 | return False, key 31 | 32 | 33 | def merge(args): 34 | if args.precision == "fp16": 35 | dtype = torch.float16 36 | elif args.precision == "bf16": 37 | dtype = torch.bfloat16 38 | else: 39 | dtype = torch.float 40 | 41 | if args.saving_precision == "fp16": 42 | save_dtype = torch.float16 43 | elif args.saving_precision == "bf16": 44 | save_dtype = torch.bfloat16 45 | else: 46 | save_dtype = torch.float 47 | 48 | # check if all models are safetensors 49 | for model in args.models: 50 | if not model.endswith("safetensors"): 51 | logger.info(f"Model {model} is not a safetensors model") 52 | exit() 53 | if not os.path.isfile(model): 54 | logger.info(f"Model {model} does not exist") 55 | exit() 56 | 57 | assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" 58 | 59 | # load and merge 60 | ratio = 1.0 / len(args.models) # default 61 | supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later 62 | 63 | merged_sd = None 64 | first_model_keys = set() # check missing keys in other models 65 | for i, model in enumerate(args.models): 66 | if args.ratios is not None: 67 | ratio = args.ratios[i] 68 | 69 | if merged_sd is None: 70 | # load first model 71 | logger.info(f"Loading model {model}, ratio = {ratio}...") 72 | merged_sd = {} 73 | with safe_open(model, framework="pt", device=args.device) as f: 74 | for key in tqdm(f.keys()): 75 | value = f.get_tensor(key) 76 | _, key = replace_text_encoder_key(key) 77 | 78 | first_model_keys.add(key) 79 | 80 | if not is_unet_key(key) and args.unet_only: 81 | supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder 82 | continue 83 | 84 | value = ratio * value.to(dtype) # first model's value * ratio 85 | merged_sd[key] = value 86 | 87 | logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) 88 | continue 89 | 90 | # load other models 91 | logger.info(f"Loading model {model}, ratio = {ratio}...") 92 | 93 | with safe_open(model, framework="pt", device=args.device) as f: 94 | model_keys = f.keys() 95 | for key in tqdm(model_keys): 96 | _, new_key = replace_text_encoder_key(key) 97 | if new_key not in merged_sd: 98 | if args.show_skipped and new_key not in first_model_keys: 99 | logger.info(f"Skip: {new_key}") 100 | continue 101 | 102 | value = f.get_tensor(key) 103 | merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype) 104 | 105 | # enumerate keys not in this model 106 | model_keys = set(model_keys) 107 | for key in merged_sd.keys(): 108 | if key in model_keys: 109 | continue 110 | logger.warning(f"Key {key} not in model {model}, use first model's value") 111 | if key in supplementary_key_ratios: 112 | supplementary_key_ratios[key] += ratio 113 | else: 114 | supplementary_key_ratios[key] = ratio 115 | 116 | # add supplementary keys' value (including VAE and TextEncoder) 117 | if len(supplementary_key_ratios) > 0: 118 | logger.info("add first model's value") 119 | with safe_open(args.models[0], framework="pt", device=args.device) as f: 120 | for key in tqdm(f.keys()): 121 | _, new_key = replace_text_encoder_key(key) 122 | if new_key not in supplementary_key_ratios: 123 | continue 124 | 125 | if is_unet_key(new_key): # not VAE or TextEncoder 126 | logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") 127 | 128 | value = f.get_tensor(key) # original key 129 | 130 | if new_key not in merged_sd: 131 | merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype) 132 | else: 133 | merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype) 134 | 135 | # save 136 | output_file = args.output 137 | if not output_file.endswith(".safetensors"): 138 | output_file = output_file + ".safetensors" 139 | 140 | logger.info(f"Saving to {output_file}...") 141 | 142 | # convert to save_dtype 143 | for k in merged_sd.keys(): 144 | merged_sd[k] = merged_sd[k].to(save_dtype) 145 | 146 | save_file(merged_sd, output_file) 147 | 148 | logger.info("Done!") 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser(description="Merge models") 153 | parser.add_argument("--models", nargs="+", type=str, help="Models to merge") 154 | parser.add_argument("--output", type=str, help="Output model") 155 | parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0") 156 | parser.add_argument("--unet_only", action="store_true", help="Only merge unet") 157 | parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu") 158 | parser.add_argument( 159 | "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float" 160 | ) 161 | parser.add_argument( 162 | "--saving_precision", 163 | type=str, 164 | default="float", 165 | choices=["float", "fp16", "bf16"], 166 | help="Saving precision, default is float", 167 | ) 168 | parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)") 169 | 170 | args = parser.parse_args() 171 | merge(args) 172 | -------------------------------------------------------------------------------- /tools/resize_images_to_resolution.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import argparse 5 | import shutil 6 | import math 7 | from PIL import Image 8 | import numpy as np 9 | from library.utils import setup_logging, pil_resize 10 | setup_logging() 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): 15 | # Split the max_resolution string by "," and strip any whitespaces 16 | max_resolutions = [res.strip() for res in max_resolution.split(',')] 17 | 18 | # # Calculate max_pixels from max_resolution string 19 | # max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 20 | 21 | # Create destination folder if it does not exist 22 | if not os.path.exists(dst_img_folder): 23 | os.makedirs(dst_img_folder) 24 | 25 | # Select interpolation method 26 | if interpolation == 'lanczos4': 27 | pil_interpolation = Image.LANCZOS 28 | elif interpolation == 'cubic': 29 | pil_interpolation = Image.BICUBIC 30 | else: 31 | cv2_interpolation = cv2.INTER_AREA 32 | 33 | # Iterate through all files in src_img_folder 34 | img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py 35 | for filename in os.listdir(src_img_folder): 36 | # Check if the image is png, jpg or webp etc... 37 | if not filename.endswith(img_exts): 38 | # Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.) 39 | shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename)) 40 | continue 41 | 42 | # Load image 43 | # img = cv2.imread(os.path.join(src_img_folder, filename)) 44 | image = Image.open(os.path.join(src_img_folder, filename)) 45 | if not image.mode == "RGB": 46 | image = image.convert("RGB") 47 | img = np.array(image, np.uint8) 48 | 49 | base, _ = os.path.splitext(filename) 50 | for max_resolution in max_resolutions: 51 | # Calculate max_pixels from max_resolution string 52 | max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1]) 53 | 54 | # Calculate current number of pixels 55 | current_pixels = img.shape[0] * img.shape[1] 56 | 57 | # Check if the image needs resizing 58 | if current_pixels > max_pixels: 59 | # Calculate scaling factor 60 | scale_factor = max_pixels / current_pixels 61 | 62 | # Calculate new dimensions 63 | new_height = int(img.shape[0] * math.sqrt(scale_factor)) 64 | new_width = int(img.shape[1] * math.sqrt(scale_factor)) 65 | 66 | # Resize image 67 | if cv2_interpolation: 68 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) 69 | else: 70 | img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) 71 | else: 72 | new_height, new_width = img.shape[0:2] 73 | 74 | # Calculate the new height and width that are divisible by divisible_by (with/without resizing) 75 | new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by 76 | new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by 77 | 78 | # Center crop the image to the calculated dimensions 79 | y = int((img.shape[0] - new_height) / 2) 80 | x = int((img.shape[1] - new_width) / 2) 81 | img = img[y:y + new_height, x:x + new_width] 82 | 83 | # Split filename into base and extension 84 | new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') 85 | 86 | # Save resized image in dst_img_folder 87 | # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) 88 | image = Image.fromarray(img) 89 | image.save(os.path.join(dst_img_folder, new_filename), quality=100) 90 | 91 | proc = "Resized" if current_pixels > max_pixels else "Saved" 92 | logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") 93 | 94 | # If other files with same basename, copy them with resolution suffix 95 | if copy_associated_files: 96 | asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*")) 97 | for asoc_file in asoc_files: 98 | ext = os.path.splitext(asoc_file)[1] 99 | if ext in img_exts: 100 | continue 101 | for max_resolution in max_resolutions: 102 | new_asoc_file = base + '+' + max_resolution + ext 103 | logger.info(f"Copy {asoc_file} as {new_asoc_file}") 104 | shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file)) 105 | 106 | 107 | def setup_parser() -> argparse.ArgumentParser: 108 | parser = argparse.ArgumentParser( 109 | description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします') 110 | parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ') 111 | parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ') 112 | parser.add_argument('--max_resolution', type=str, 113 | help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128") 114 | parser.add_argument('--divisible_by', type=int, 115 | help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1) 116 | parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'], 117 | default='area', help='Interpolation method for resizing / リサイズ時の補完方法') 118 | parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存') 119 | parser.add_argument('--copy_associated_files', action='store_true', 120 | help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする') 121 | 122 | return parser 123 | 124 | 125 | def main(): 126 | parser = setup_parser() 127 | 128 | args = parser.parse_args() 129 | resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, 130 | args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files) 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /tools/show_metadata.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from safetensors import safe_open 4 | from library.utils import setup_logging 5 | setup_logging() 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | with safe_open(args.model, framework="pt") as f: 14 | metadata = f.metadata() 15 | 16 | if metadata is None: 17 | logger.error("No metadata found") 18 | else: 19 | # metadata is json dict, but not pretty printed 20 | # sort by key and pretty print 21 | print(json.dumps(metadata, indent=4, sort_keys=True)) 22 | 23 | 24 | --------------------------------------------------------------------------------