├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_ja.md ├── calcmode_en.md ├── calcmode_ja.md ├── changelog.md ├── elemental_en.md ├── elemental_ja.md ├── install.py ├── sample.txt ├── scripts ├── A1111 │ ├── lora_patches.py │ ├── lyco_helpers.py │ ├── network.py │ ├── network_full.py │ ├── network_glora.py │ ├── network_hada.py │ ├── network_ia3.py │ ├── network_lokr.py │ ├── network_lora.py │ ├── network_norm.py │ ├── network_oft.py │ └── networks.py ├── GenParamGetter.py ├── Roboto-Regular.ttf ├── kohyas │ ├── extract_lora_from_models.py │ ├── ipex │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── diffusers.py │ │ ├── gradscaler.py │ │ └── hijacks.py │ ├── lora.py │ ├── merge_lora.py │ ├── model_util.py │ ├── original_unet.py │ ├── sai_model_spec.py │ ├── sdxl_merge_lora.py │ ├── sdxl_model_util.py │ ├── sdxl_original_unet.py │ ├── svd_merge_lora.py │ └── train_util.py ├── mbwpresets_master.txt ├── mergers │ ├── bcolors.py │ ├── components.py │ ├── mergers.py │ ├── model_util.py │ ├── pluslora.py │ └── xyplot.py └── supermerger.py └── style.css /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github:[hako-mikan] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | changelog.m 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README_ja.md: -------------------------------------------------------------------------------- 1 | # SuperMerger 2 | - [AUTOMATIC1111's stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 用のモデルマージ拡張 3 | - マージしたモデルを保存せず直接生成に使用できます 4 | 5 | [](README.md) 6 | [](#overview) 7 | [](https://github.com/sponsors/hako-mikan) 8 | 9 | 10 | # Overview 11 | このextentionではモデルをマージした際、保存せずに画像生成用のモデルとして読み込むことができます。 12 | これまでマージしたモデルはいったん保存して気に入らなければ削除するということが必要でしたが、このextentionを使うことでHDDやSSDの消耗を防ぐことができます。 13 | モデルの保存とロードの時間を節約できるため、比率を変更しながら連続生成することによりモデルマージの効率を大幅に向上させます。 14 | 15 | # もくじ 16 | - [Merge Models](#merge-models) 17 | - [Merge Block Weight](#merge-block-weight) 18 | - [XYZ Plot](#xyz-plot) 19 | - [Adjust](#adjust) 20 | - [Let the Dice Roll](#let-the-dice-roll) 21 | - [Elemental Merge](#elemental-merge) 22 | - [Generation Parameters](#generation-parameters) 23 | - [LoRA](#lora) 24 | - [Merge LoRAs](#merge-loras) 25 | - [Merge to Checipoint](#merge-to-checkpoint) 26 | - [Extract from Checkpoints](#extract-from-checkpoints) 27 | - [Other Tabs](#other-tabs) 28 | 29 | - [Calcomode](calcmode_ja.md) 30 | - [Elemental Merge](elemental_ja.md) 31 | 32 | 33 | # Recent Update 34 | 2023.11.10 35 | - LoRA抽出スクリプトの変更: SDXL,2.Xをサポート 36 | - SD2.XでのモデルへのLoRAマージをサポート 37 | - バグフィックス 38 | - 新しい計算方式extract merge(checkpoint/LoRA)を追加。作成した[subaqua](https://github.com/sbq0)氏に感謝します。 39 | 40 | ## 計算関数の変更 41 | いくつかのモードでは、計算に使用する機能が変更され、マージ計算の速度が向上しています。計算結果は同じですが、同じ結果が得られない場合は、オプションの「use old calc method」をチェックしてください。影響を受ける方式は以下の通りです: 42 | Weight Sum: normal, cosineA, cosineB 43 | Sum Twice:normal 44 | 
提案した[wkpark](https://github.com/wkpark)氏に感謝します。 45 | 46 | **注意!** 47 | XLモデルのマージには最低64GBのCPUメモリが必要です。64Gのメモリであっても併用しているソフトによってはシステムが不安定になる恐れがあるのでシステムが落ちてもいい状態で作業して下さい。私は久しぶりにブルースクリーンに遭遇しました。 48 | 49 | ## 既知の問題 50 | 他の拡張機能(sd-webui-prompt-all-in-oneなど)を同時にインストールしている場合、起動時にブラウザを自動的に開くオプションを有効にすると、動作が不安定になることがあります。Gradioの問題である可能性が高いので、修正は難しいです。そのオプションを無効にしてお使いください。 51 | 52 | すべての更新は[ここ](changelog.md) で確認できます。 53 | 54 | # つかいかた 55 | ## Merge Models 56 | ここでマージされたモデルは、Web-UIの生成モデルとしてロードされます。左上のモデル表示は変わりませんが、マージされたモデルは実際にロードされています。別のモデルが左上のモデル選択から選択されるまで、マージされたモデルはロードされたままになります。 57 | ### Basic Usage 58 | Select models A/B/(C), the merge mode, and alpha (beta), then press Merge/Merge and Gen to start the merging process. In the case of Merge and Gen, generation is carried out using the prompt and other settings specified in txt2img.The Gen button only generates images, and the Stop button interrupts the merging process. 59 | モデルA/B/(C)、merge mode、alpha (beta)を選択し、Merge/Merge and Genを押すとマージ処理が始まります。Merge and Genの場合は、txt2imgで指定されたプロンプトやその他の設定を使用して生成が行われます。Genボタンは画像のみを生成し、Stopボタンはマージを中断します。 60 | 61 | ### Load Settings From: 62 | マージログから設定を読み込みます。マージが行われるたびにログが更新され、1から始まる連続IDが割り当てられます。"-1"は最後のマージからの設定に対応し、"-2"は最後から二番目のものに対応します。マージログはextension/sd-webui-supermerger/mergehistory.csvに保存されます。Historyタブで閲覧や検索ができます。半角スペースで区切ってand/orで検索できます。 63 | ### Clear Cache 64 | Web-UIのモデルキャッシュ機能が有効になっている場合、SuperMergerは連続マージを高速化するためにモデルキャッシュを作成します。モデルはWeb-UIのキャッシュ機能とは別にキャッシュされます。使用後にキャッシュを削除するにはこのボタンを使用してください。キャッシュはVRAMではなくRAMに作成されます。 65 | 66 | # 各種設定 67 | ## マージモード 68 | ### Weight sum 69 | 通常のマージです。alphaが使用されます。α=0の場合Model A, α=1 の時model Bになります。 70 | ### Add difference 71 | 差分マージです。 72 | ### Triple sum 73 | マージを3モデル同時に行います。alpha,betaが使用されます。モデル選択窓が3つあったので追加した機能ですが、ちゃんと動くようです。MBWでも使えます。それぞれMBWのalpha,betaを入力してください。 74 | ### sum Twice 75 | Weight sumを2回行います。alpha,betaが使用されます。MBWモードでも使えます。それぞれMBWのalpha,betaを入力してください。 76 | 77 | ### use MBW 78 | チェックするとブロックごとのマージ(階層マージ)が有効になります。各ブロックごとの比率は下部のスライダーかプリセットで設定してください。 79 | 80 | 81 | ### Merge mode 82 | #### Weight sum $(1-\alpha) A + \alpha B$ 83 | 通常のマージです。alphaが使用されます。$\alpha$=0の場合Model A, $\alpha$=1 の時model Bになります。 84 | #### Add difference $A + \alpha (B-C)$ 85 | 差分を加算します。MBWが有効な場合、$\alpha$としてMBWベースが使用されます。 86 | #### Triple sum $(1-\alpha - \beta) A + \alpha B + \beta C$ 87 | 同時に3つのモデルをマージします。$\alpha$と$\beta$が使用されます。3つのモデル選択ウィンドウがあったためこの機能を追加しましたが、実際に効果的に動作するかはわかりません。 88 | #### sum Twice $(1-\beta)((1-\alpha)A+\alpha B)+\beta C$ 89 | Weight sumを2回行います。$\alpha$と$\beta$が使用されます。 90 | 91 | ### calcmode 92 | 各計算方法の詳細については[リンク先](calcmode_ja.md)を参照してください。計算方法とマージモードの対応表は以下の通りです。 93 | | Calcmode | Description | Merge Mode | 94 | |----|----|----| 95 | |normal | 通常の計算方法 | ALL | 96 | |cosineA | モデルAを基準にマージ中の損失を最小限にする計算を行います。 | Weight sum | 97 | |cosineB | モデルBを基準にマージ中の損失を最小限にする計算を行います。 | Weight sum | 98 | |trainDifference |モデルAに対してファインチューニングするかのように差分を'トレーニング'します。 | Add difference | 99 | |smoothAdd | 中央値フィルタとガウスフィルタの利点を混合した差分の追加 | Add difference | 100 | |smoothAdd MT| マルチスレッドを使用して計算を高速化します。 | Add difference | 101 | |extract | モデルB/Cの共通点・非共通点を抽出して追加します. 
| Add difference | 102 | |tensor| 加算の代わりにテンソルを比率で入れ替えます | Weight sum | 103 | |tensor2 |テンソルの次元が大きい場合、2次元目を基準にして入れ替えます | Weight sum | 104 | |self | モデル自身に$\alpha$を掛け合わせます | Weight sum | 105 | 106 | ### use MBW 107 | 階層マージを有効にします。Merge Block Weightで重みを設定してください。これを有効にすると、アルファとベータが無効になります。 108 | 109 | ### Options 110 | | 設定 | 説明 | 111 | |-----------------|---------------------------------------------------| 112 | | save model | マージ後のモデルを保存します。 | 113 | | overwrite | モデルの上書きを有効にします。 | 114 | | safetensors | safetensors形式で保存します。 | 115 | | fp16 | 半精度で保存します。 | 116 | | save metadata | 保存時にメタデータにマージ情報を保存します。(safetensorsのみ) | 117 | | prune | 保存時にモデルから不要なデータを削除します。 | 118 | | Reset CLIP ids | CLIP idsをリセットします。 | 119 | | use old calc method | 古い計算関数を使用します | 120 | | debug | デバッグ情報をCMDに出力します。 | 121 | 122 | ### save merged model ID to 123 | 生成された画像またはPNG情報にマージIDを保存するかどうかを選択できます。 124 | ### Additional Value 125 | 現在、calcmodeで'extract'が選択されている場合にのみ有効です。 126 | ### Custom Name (Optional) 127 | モデル名を設定できます。設定されていない場合は、自動的に決定されます。 128 | ### Bake in VAE 129 | モデルを保存するとき、選択されたVAEがモデルに組み込まれます。 130 | 131 | ### Save current Merge 132 | 現在ロードされているモデルを保存します。PCのスペックやその他の問題により、マージに時間がかかる場合に効果的です。 133 | 134 | ## Merge Block Weight 135 | これは階層ごとに割合を設定できるマージ技法です。各階層は背景の描写、キャラクター、絵柄などに対応している可能性があるため、階層ごとの割合を変更することで、様々なマージモデルを作成できます。 136 | 137 | 階層はSDのバージョンによって異なり、以下のような階層があります。 138 | 139 | BASEはテキストエンコーダを指し、プロンプトへの反応などに影響します。IN-OUTはU-Netを指し、画像生成を担当します。 140 | 141 | Stable diffusion 1.X, 2.X 142 | |1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|21|22|23|24|25|26| 143 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 144 | |BASE|IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|MID|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11| 145 | 146 | Stable diffusion XL 147 | |1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20| 148 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 149 | |BASE|IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|MID|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08| 150 | 151 | ## XYZ Plot 152 | ## XYZプロット 153 | 連続的にマージ画像を生成します。すべてのマージモードで効果的です。 154 | ### alpha, beta 155 | alphaとbetaの値を変更します。 156 | ### alpha と beta 157 | alphaとbetaを同時に変更します。alphaとbetaはスペース1つで区切り、各要素はカンマで区切ってください。1つの数字のみ入力された場合、alphaとbetaに同じ値が入力されます。 158 | 例: 0, 0.5 0.1, 0.3 0.4, 0.5 159 | ### MBW 160 | 階層ごとにマージを実行します。改行で区切った割合を入力してください。プリセットは使用できますが、必ず**改行で区切る**ようにしてください。TripleやTwiceの場合は、2行を1セットで入力してください。奇数行の入力ではエラーになります。 161 | ### seed 162 | シード値を変更します。-1を入力すると反対軸方向の固定シードになります。 163 | ### model_A, B, C 164 | モデルを変更します。モデル選択ウィンドウで選択したモデルは無視されます。 165 | ### pinpoint block 166 | MBWで特定の階層のみを変更します。反対軸にはalphaかbetaを選択してください。階層IDを入力すると、その階層のalpha (beta)のみが変更されます。他のタイプと同様にカンマで区切って入力してください。スペースやハイフンで区切ることで、同時に複数の階層を変更できます。先頭にNOTを入力すると効果範囲が反転します。 167 | #### 入力例 168 | IN01, OUT10 OUT11, OUT03-OUT06, OUT07-OUT11, NOT M00 OUT03-OUT06 169 | この場合 170 | - 1:IN01のみ変化 171 | - 2:OUT10およびOUT11が変化 172 | - 3:OUT03からOUT06が変化 173 | - 4:OUT07からOUT11が変化 174 | - 5:M00およびOUT03からOUT06以外が変化 175 | 176 | となります。0の打ち忘れに注意してください。 177 | ![xy_grid-0006-2934360860 0](https://user-images.githubusercontent.com/122196982/214343111-e82bb20a-799b-4026-8e3c-dd36e26841e3.jpg) 178 | 179 | 階層ID (大文字のみ有効) 180 | BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11 181 | 182 | XL model 183 | BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08 184 |
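なお、前述の階層マージ(Merge Block Weight)の考え方をコードにすると概ね以下のようになります。これは動作原理を説明するための最小限のスケッチであり、本拡張の実際の実装ではありません。`block_id()` はstate dictのキーを階層IDへ対応付けるための説明用の仮のヘルパーです。

```
# 階層マージ(MBW)の考え方を示す最小限のスケッチ(実際の実装とは異なります)
import torch

def block_id(key):
    # 説明用の仮のマッピング: 例 "model.diffusion_model.input_blocks.4..." -> "IN04"
    if "diffusion_model.input_blocks." in key:
        return f"IN{int(key.split('input_blocks.')[1].split('.')[0]):02}"
    if "diffusion_model.middle_block." in key:
        return "M00"
    if "diffusion_model.output_blocks." in key:
        return f"OUT{int(key.split('output_blocks.')[1].split('.')[0]):02}"
    return "BASE"  # テキストエンコーダなど

def mbw_weight_sum(theta_a, theta_b, block_alphas, default_alpha=0.5):
    """階層ごとの Weight sum: (1 - alpha) * A + alpha * B"""
    merged = {}
    for key, wa in theta_a.items():
        if key not in theta_b:
            merged[key] = wa.clone()
            continue
        alpha = block_alphas.get(block_id(key), default_alpha)
        merged[key] = (1 - alpha) * wa + alpha * theta_b[key].to(wa.dtype)
    return merged
```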
185 | ### calcmode 186 | 計算方式を変更します。適用できるマージモードとの対応に注意して下さい。カンマで区切ります。 187 | 188 | ### prompt 189 | プロンプトを変更できます。txt2imgのプロンプトは無視されます。ネガティブプロンプトは有効です。 190 | **改行で区切る**ことに注意して下さい。 191 | 192 | ### XYプロットの予約 193 | Reserve XY plotボタンはすぐさまプロットを実行せず、ボタンを押したときの設定のXYプロットの実行を予約します。予約したXYプロットは通常のXYプロットが終了した後か、ReservationタブのStart XY plotボタンを押すと実行が開始されます。予約はXYプロット実行時・未実行時いつでも可能です。予約一覧は自動更新されないのでリロードボタンを使用してください。エラー発生時はそのプロットを破棄して次の予約を実行します。すべての予約が終了するまで画像は表示されませんが、Finishedになったものについてはグリッドの生成は終わっているので、Image Browser等で見ることが可能です。 194 | 「|」を使用することで任意の場所で予約へ移動することも可能です。 195 | 0.1,0.2,0.3,0.4,0.5|0.6,0.7,0.8,0.9,1.0とすると 196 | 197 | 0.1,0.2,0.3,0.4,0.5 198 | 0.6,0.7,0.8,0.9,1.0 199 | というふたつの予約に分割され実行されます。これは要素が多すぎてグリッドが大きくなってしまう場合などに有効でしょう。 200 | 201 | ## Adjust 202 | これは、モデルの細部と色調を補正します。LoRAとは異なるメカニズムを採用しています。U-Netの入力と出力のポイントを調整することで、画像の細部と色調を調整できます。 203 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/fsample0.jpg) 204 | ## 使い方 205 | 206 | テキストボックスに直接入力するか、スライダーで値を決めてから上ボタンを押してテキストボックスに反映できます。空白のままにしておくと無視されます。 207 | カンマで区切った8つの数字を入力してください。 208 | 209 | ``` 210 | 0,0,0,0,0,0,0,0 211 | ``` 212 | これがデフォルトで、これらの値をずらすことで効果が現れます。 213 | 214 | ### Each setting value 215 | ### 各設定値の意味 216 | 8つの数字は以下に対応しています。 217 | 1. 描き込み/ノイズ 218 | 2. 描き込み/ノイズ 219 | 3. 描き込み/ノイズ 220 | 4. コントラスト/描き込み 221 | 5. 明るさ 222 | 6. 色調1 (シアン-赤) 223 | 7. 色調2 (マゼンタ-緑) 224 | 8. 色調3 (黄-青) 225 | 226 | 描き込みが増えるほどノイズも必然的に増えることに注意してください。また、Hires.fixを使用する場合、出力が異なる可能性があるので、使用される設定でテストすることをおすすめします。 227 | 228 | 値は+/-5程度までは問題ないと思われますが、モデルによって異なります。正の値を入力すると描き込みが増えます。色調には3種類あり、おおむねカラーバランスに対応しているようです。 229 | ### 1,2,3 Detail/Noise 230 | 1.はU-Netの入り口に相当する部分です。ここを調節すると画像の描き込み量が調節できます。ここはOUTに比べて構図が変わりやすいです。マイナスにするとフラットに、そして少しぼけた感じに。プラスにすると描き込みが増えノイジーになります。通常の生成でノイジーでもhires.fixできれいになることがあるので注意してください。2,3はOUTに相当する部分です。 231 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/fsample1.jpg) 232 |
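参考までに、「U-Netの入力と出力のポイントを調整する」という仕組みのイメージをコードで示すと以下のようになります。キー名・係数・チャンネルの対応はすべて説明用の仮定であり、実際の実装とは異なります。

```
# 説明用のスケッチ。キー名・係数・チャンネル対応はすべて仮定です。
import torch

IN_KEY   = "model.diffusion_model.input_blocks.0.0.weight"  # U-Netの入り口
OUT_KEY  = "model.diffusion_model.out.2.weight"             # U-Netの出口
BIAS_KEY = "model.diffusion_model.out.2.bias"

def adjust_sketch(sd, v):
    """v: 上記の8つの値のリスト。正の値で描き込みが増える方向に働かせる例。"""
    sd[IN_KEY]  = sd[IN_KEY]  * (1 + v[0] * 0.02)           # 1: 入力側の描き込み/ノイズ
    sd[OUT_KEY] = sd[OUT_KEY] * (1 + (v[1] + v[2]) * 0.01)  # 2,3: 出力側の描き込み
    # 4-8: コントラスト・明るさ・色調は、出口のconvのバイアスを
    # 潜在チャンネル方向に少しずつずらすイメージ(対応付けは仮)
    shift = torch.tensor([v[4] + v[5], v[4] + v[6], v[4] + v[7], v[3]]) * 0.02
    sd[BIAS_KEY] = sd[BIAS_KEY] + shift.to(sd[BIAS_KEY].dtype)
    return sd
```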
233 | ### 4. Contrast/Detail 234 | ここを調節するとコントラストや明るさが変わり、同時に描き込み量も変わります。サンプルを見てもらった方が早いですね。 235 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/fsample3.jpg) 236 | 237 | ### 5,6,7,8 Brightness, Color Tone 238 | 明るさと色調を補正できます。概ねカラーバランスに対応するようです。 239 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/asample1.jpg) 240 | 241 | ## Let the Dice roll 242 | ランダムにマージ比率を決定します。一度に複数のランダムマージが行えます。比率は各ブロック、各エレメントごとにランダムにすることが可能です。 243 | ## 使い方 244 | Let the Dice rollタブで使用できます。`Random Mode`を選択し`Run Rand`を押すと`Num of challenge`の回数分ランダムにウェイトが設定されて画像が生成されます。生成はXYZモードで動作するので`STOP`ボタンが有効です。`Seed for Random Ratio`は`-1`に設定して下さい。Num of challengeの回数が2回以上の場合、自動的に-1に設定されます。同じseedを使うと再現性があります。生成数が10を超える場合グリッドは自動的に2次元になります。`Settings`の`alpha`、`beta`はチェックするとランダム化されます。Elementalの場合`beta`は無効化されます。 245 | 246 | ## 各モード 247 | ### R,U,X 248 | 26ブロックすべてに対してランダムなウェイトが設定されます。`R`、`U`、`X`の違いは乱数の値の範囲です。Xは各層に対して`lower limit` ~ `upper limit`で指定します。 249 | R : 0 ~ 1 250 | U : -0.5 ~ 1.5 251 | X : lower limit ~ upper limit 252 | ### ER,EU,EX 253 | Elementすべてに対してランダムなウェイトが設定されます。`ER`、`EU`、`EX`の違いは乱数の値の範囲です。Xは各層に対して`lower limit` ~ `upper limit`で指定します。 254 | 255 | ### custom 256 | ランダム化される階層を`custom`で指定します。 257 | `R`、`U`、`X`、`ER`、`EU`、`EX`が使用できます。 258 | 例: 259 | ``` 260 | U,0,0,0,0,0,0,0,0,0,0,0,0,R,R,R,R,R,R,R,R,R,R,R,R,R 261 | U,0,0,0,0,0,0,0,0,0,0,0,0,ER,0,0,0,0,0,0,X,0,0,0,0,0 262 | ``` 263 | ### XYZモード 264 | typeに`random`を設定することで使用できます。ランダム化する回数を入力すると、回数分軸の要素が設定されます。 265 | ``` 266 | X type : seed, -1,-1,-1 267 | Y type : random, 5 268 | ``` 269 | とすると、3×5のgridができ、5回分ランダムにウェイトが設定されたモデルで生成されます。ランダム化の設定はランダムのパネルで設定して下さい。ここが`off`では正常に動作しません。 270 | 271 | ### Settings 272 | - `round` は丸める小数点以下の桁数を設定します。初期値は3で、0.123のようになります。 273 | - `save E-list` はElementalのキーと割合をcsv形式で`script/data/`に保存します。 274 | 275 | ## Elemental Merge 276 | [こちら](elemental_ja.md)を参照して下さい。 277 | 278 | 279 | 280 | ## Generation Parameters 281 | 282 | ここでは画像生成の条件も設定できます。ここで値を設定すると優先されます。 283 | ## Include/Exclude 284 | マージする(しない)階層を設定できます。Includeの場合ここで設定した階層のみマージされます。Excludeは逆でここで設定した階層のみマージされません。「print」にチェックを入れると、コマンドプロンプト画面で階層が除外されたか確認できます。「Adjust」にチェックを入れると、Adjustで使用する要素のみがマージ/除外されます。`attn`などの文字列も指定でき、この場合`attn`を含む要素のみマージ/除外されます。文字列はカンマで区切ってください。 285 | ## unload button 286 | 現在ロードされているモデルを削除します。kohya-ss GUI使用時にGPUメモリを解放するために使用します。モデルが削除されると画像生成ができなくなります。画像生成を行いたい場合は、再度モデルを選択してください。 287 | 288 | 289 | ## LoRA 290 | LoRA関連の機能です。基本的にはkohya-ssのスクリプトと同じですが、階層マージに対応します。現時点ではV2.X系のマージには対応していません。 291 | 292 | 注意:LyCORISは構造が特殊なため単独マージのみに対応しています。単独マージの比率は1,0のみ使用可能です。他の値を用いるとsame to Strengthでも階層LoRAの結果と一致しません。 293 | LoConは整数以外でもそれなりに一致します。 294 | 295 | LoCon/LyCoris のモデルへのマージにはweb-ui1.5以上が必要です。 296 | | 1.X,2.X | LoRA | LoCon | LyCORIS | 297 | |----------|-------|-------|---------| 298 | | Merge to Model | Yes | Yes | Yes | 299 | | Merge LoRAs | Yes | Yes | No | 300 | | Apply Block Weight(single)|Yes|Yes|Yes| 301 | | Extract From Models | Yes | No | No | 302 | 303 | | XL | LoRA | LoCon | LyCORIS | 304 | |----------|-------|-------|---------| 305 | | Merge to Model | Yes | Yes | Yes | 306 | | Merge LoRAs | Yes | Yes | No | 307 | | Extract From Models | Yes | No | No | 308 | 309 | 310 | ### Merge LoRAs 311 | ひとつまたは複数のLoRA同士をマージします。kohya-ss氏の最新のスクリプトを使用しているので、dimensionの異なるLoRA同士もマージ可能ですが、dimensionの変換の際はLoRAの再計算を行うため、生成される画像が大きく異なる可能性があることに注意してください。 312 | 313 | calculate dimensionボタンで各LoRAの次元を計算して表示し、ソート機能が有効になります。計算にはわりと時間がかかり、50個程度のLoRAでも数十秒かかります。新しくマージされたLoRAはリストに表示されないのでリロードボタンを押してください。次元の再計算は追加されたLoRAだけを計算します。 314 |
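参考として、2つのLoRAをひとつにまとめる際の基本的な考え方を示します。ΔWの和はup/downをrank方向に連結することで表現できる、という性質を使った最小限のスケッチで、実際のスクリプトが行う処理(dim変換時のSVDによる再計算など)は含みません。

```
# 2つのLoRAの和をrank方向の連結で表すスケッチ(実際のスクリプトとは異なります)
import torch

def merge_two_loras(up_a, down_a, ratio_a, up_b, down_b, ratio_b):
    # ΔW = ratio_a * up_a @ down_a + ratio_b * up_b @ down_b
    up   = torch.cat([up_a * ratio_a, up_b * ratio_b], dim=1)  # (out, rA+rB)
    down = torch.cat([down_a, down_b], dim=0)                  # (rA+rB, in)
    return up, down  # rankはrA+rBに増える。dimを揃えるにはSVD等での削減が必要

# 動作確認:
up_a, down_a = torch.randn(320, 8),  torch.randn(8, 768)
up_b, down_b = torch.randn(320, 16), torch.randn(16, 768)
up, down = merge_two_loras(up_a, down_a, 0.5, up_b, down_b, 0.5)
assert torch.allclose(up @ down, 0.5 * up_a @ down_a + 0.5 * up_b @ down_b, atol=1e-4)
```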
315 | ### Merge to Checkpoint 316 | Merge LoRAs into a model. Multiple LoRAs can be merged at the same time. 317 | Enter LoRA name1:ratio1:block1,LoRA name2:ratio2:block2,... 318 | LoRA can also be used alone. The ":block" part can be omitted. The ratio can be any value, including negative values. There is no restriction that the total must sum to 1 (of course, if it greatly exceeds 1, it will break down). 319 | To use ":block", use a block preset name from the bottom list of presets, or create your own. Ex: 320 | ``` 321 | LoRAname1:ratio1 322 | LoRAname1:ratio1:ALL 323 | LoRAname1:ratio1:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0 324 | ``` 325 | 326 | ### Extract from checkpoints 327 | ふたつのモデルの差分からLoRAを生成します。 328 | dimensionを指定すると指定されたdimensionで作製されます。無指定の場合は128で作製します。 329 | alphaとbetaによって配合比率を調整することができます。$(\alpha A - \beta B)$ alpha, beta = 1が通常のLoRA作成となります。 330 | 331 | ### Extract from two LoRAs 332 | [こちら](calcmode_ja.md#extractlora)を参照して下さい。 333 | 334 | ### Metadata 335 | #### create new 336 | 新しく最小限のMetadataを作製します。dim,alpha,basemodelのversion,filename,networktypeのみが作製されます。 337 | #### merge 338 | 各LoRAの情報が保存され、タグがマージされます。 339 | (各LoRAの情報はWeb-uiの簡易Metadata読み込み機能では見えません) 340 | #### save all 341 | 各LoRAの情報が保存されます。 342 | (各LoRAの情報はWeb-uiの簡易Metadata読み込み機能では見えません) 343 | #### use first LoRA 344 | 最初のLoRAの情報をそのままコピーします。 345 | 346 | ### Get Ratios from Prompt 347 | プロンプト欄からLoRAの比率設定を読み込みます。これはLoRA Block Weightの設定も含まれ、そのままマージが行えます。 348 | 349 | ### Difference between Normal Merge and SAME TO STRENGTH 350 | same to Strengthオプションを使用しない場合は、kohya-ss氏の作製したスクリプトのマージと同じ結果になります。この場合、下図のようにWeb-ui上でLoRAを適用した場合と異なる結果になります。これはLoRAをU-netに組み込む際の数式が関係しています。kohya-ss氏のスクリプトでは比率をそのまま掛けていますが、適用時の数式では比率が2乗されてしまうため、比率を1以外の数値に設定すると、あるいはマイナスに設定するとStrength(適用時の強度)と異なる結果となります。same to Strengthオプションを使用すると、マージ時に比率の平方根を掛けることで、適用時にStrengthと比率が同じ意味を持つように計算しています。また、マイナスでも効果が出るように計算しています。追加学習をしない場合などはsame to Strengthオプションを使用しても問題ないと思いますが、マージしたLoRAに対して追加学習をする場合は使用しない方がいいかもしれません。 351 | 352 | 下図は通常適用/same to Strengthオプション/通常マージの各場合の生成画像です。figma化とukiyoE LoRAのマージを使用しています。通常マージの場合はマイナス方向でも2乗されてプラスになっていることが分かります。 353 | ![xyz_grid-0014-1534704891](https://user-images.githubusercontent.com/122196982/218322034-b7171298-5159-4619-be1d-ac684da92ed9.jpg) 354 | 355 | ## Other tabs 356 | ## Analysis 357 | 2つのモデルの違いを分析します。モデルAとモデルBに、比較したい2つのモデルを選択してください。 358 | ### Mode 359 | 360 | ASimilarityモードは、qkvから計算されたテンソルを比較します。他のモードは各要素のコサイン類似度から計算します。ASimilarityモード以外では計算された差が小さくなるようです。ASimilarityモードは出力画像の違いに近い結果を与えるため、一般的にはこのモードを使用すべきです。 361 | このASimilarity分析は、[ASimilarity script](https://huggingface.co/JosephusCheung/ASimilarityCalculatior)を拡張して作成されました。 362 | 363 | ### Block Method 364 | ASimilarityモード以外のモードで各階層の比率を計算する方法です。Meanは平均を表し、minは最小値を表し、attn2は階層の計算結果としてattn2の値を出力します。 365 | 366 | ## History 367 | マージ履歴を検索することができます。検索機能は「and」と「or」の両方の検索に対応しています。 368 | 369 | ## Elements 370 | モデルに含まれるElementのリスト、階層の割り当て、およびテンソルのサイズを取得できます。 371 | 372 | ## 謝辞 373 | このスクリプトは[kohya](https://github.com/kohya-ss)氏、[bbc-mc](https://github.com/bbc-mc)氏のスクリプトを一部使用しています。また、拡張の開発に貢献した全ての方々にも感謝します。 374 | -------------------------------------------------------------------------------- /changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | ### 2023.10.15 3 | Adjust機能が改良されました。CD TunerのようにBrightness, Cyan-Red, Magenta-Green, Yellow-Blueのような色指定に変わります。 4 | その他バグフィックス 5 | 自動でClip Idのリセットを行うオプションを追加 6 | Adjust feature was improved, changed to "Brightness, Cyan-Red, Magenta-Green, Yellow-Blue" like CD Tuner 7 | other bug fixes were added.
8 | Added option to reset CLIP Ids. 9 | 10 | ### update 2023.10.03 11 | [Merge function of metadata for LoRAs is changed.](#Metadata) 12 | 13 | ### update 2023.09.02.1900(JST) 14 | モデルキャッシュに関する仕様が変わりました。 15 | モデルキャッシュを設定している場合、これまではweb-ui側のモデルキャッシュを使用していましたが、web-ui側の仕様変更により使えなくなりました。 16 | そこで、web-ui側のモデルキャッシュを無効にしてSuperMerger側でモデルをキャッシュするように変更しました。よって、モデルキャッシュを使用する設定にしている場合、SuperMerger使用後にモデルのキャッシュが残ることになります。SuperMeger使用後はClear cacheボタンでメモリの解放を行って下さい。 17 | 18 | The specifications regarding model caching have changed. If you have set up model caching, we used to utilize the model cache on the web-ui side. However, due to changes in the web-ui specifications, this is no longer possible. Therefore, I have disabled the model cache on the web-ui side and have made changes to cache the model on the SuperMerger side instead. As a result, if you have set it to use model caching, the model cache will remain after using SuperMerger. Please clear the cache using the "Clear cache" button to free up memory after using SuperMerger. 19 | 20 | 21 | bug fix/以下のバグを修正しました 22 | - XYZ plot でseedを選択すると発生するバグ 23 | - not work when selecting Seed in XYZ plot 24 | - LoRAマージができないバグ 25 | - Merging LoRA to checkpoint is not work 26 | - Hires fix を使用していないときでもDenoising Strength がPNGinfoに設定される 27 | - Denoising Strength is set to PNG info when Hires fix is not enabled 28 | - LOWVRAM/MEDVRAM使用時に正常に動作しない 29 | - bug when LOWVRAM/MEDVRAM 30 | 31 | ### update 2023.08.31 32 | 33 | - ほぼすべてのtxt2imgタブの設定を使えるようになりました 34 | - Almost all txt2img tab settings are now available in generation 35 | Thanks! [Filexor](https://github.com/Filexor) 36 | 37 | - support XL 38 | - XLモデル対応 39 | 40 | XL capabilities at the moment:XLでいまできること 41 | Merge/Block merge/マージ/階層マージ 42 | Merging LoRA into the model (supported within a few days)/モデルへのLoRAのマージ 43 | 44 | Cannot be done:できないこと 45 | Creating LoRA from model differences (TBD)/モデル差分からLoRA作成(未定) 46 | 47 | ### update 2023.07.07.2000(JST) 48 | - add new feature:[Random merge](#random-merge) 49 | - add new feature:[Adjust detail/colors](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/elemental_en.md#adjust) 50 | 51 | ### update 2023.06.28.2000(JST) 52 | - add Image Generation Parameters(prompt,seed,etc.) 
53 | for Vlad fork users, use this panel 54 | 55 | ### update 2023.06.27.2030 56 | - Add new calcmode "trainDifference"[detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#trainDifference) (Thanks [SwiftIllusion](https://github.com/SwiftIllusion)) 57 | - Add Z axis for XY plot 58 | - Add Analysis tab for caclrating the difference of models (thanks [Mohamed-Desouki](https://github.com/Mohamed-Desouki)) 59 | 60 | ### update 2023.06.24.0300(JST) 61 | - VAE bake feature added 62 | - support inpainting/pix2pix 63 | Thanks [wkpark](https://github.com/wkpark) 64 | 65 | ### update 2023.05.02.1900(JST) 66 | - bug fix : Resolved conflict with wildcard in dynamic prompt 67 | - new feature : restore face and tile option added 68 | 69 | ### update 2023.04.19.2030(JST) 70 | - New feature, optimization using cosine similarity method updated [detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#cosine) 71 | - New feature, tensor merge added [detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#tensor) 72 | - New XY plot type : calcmode,prompt 73 | 74 | ### bug fix 2023.02.19.2330(JST) 75 | いくつかのバグが修正されました 76 | - LOWRAMオプション有効時にエラーになる問題 77 | - Linuxでエラーになる問題 78 | - XY plotが正常に終了しない問題 79 | - 未ロードのモデルを設定時にエラーになる問題 80 | 81 | ### update to version 3 2023.02.17.2020(JST) 82 | - LoRA関係の機能を追加しました 83 | - Logを保存し、設定を呼び出せるようになりました 84 | - safetensors,fp16形式での保存に対応しました 85 | - weightのプリセットに対応しました 86 | - XYプロットの予約が可能になりました 87 | 88 | ### bug fix 2023.02.19.2330(JST) 89 | Several bugs have been fixed 90 | - Error when LOWRAM option is enabled 91 | - Error on Linux 92 | - XY plot did not finish properly 93 | - Error when setting unused models 94 | 95 | ### update to version 3 2023.02.17.2020(JST) 96 | - Added LoRA related functions 97 | - Logs can now be saved and settings can be recalled. 98 | - Save in safetensors and fp16 format is now supported. 99 | - Weight presets are now supported. 100 | - Reservation of XY plots is now possible. 101 | 102 | ### bug fix 2023.01.29.0000(JST) 103 | pinpoint blocksがX方向で使用できない問題を修正しました。 104 | pinpoint blocks選択時Triple,Twiceを使用できない問題を解決しました 105 | XY plot 使用時に一部軸タイプでMBWを使用できない問題を解決しました 106 | Fixed a problem where pinpoint blocks could not be used in the X axis. 107 | Fixed a problem in which Triple,Twice could not be used when selecting pinpoint blocks. 108 | Problem solved where MBW could not be used with some axis types when using XY plot. 109 | 110 | ### bug fixed 2023.01.28.0100(JST) 111 | MBWモードでSave current modelボタンが正常に動作しない問題を解決しました 112 | ファイル名が長すぎて保存時にエラーが出る問題を解決しました 113 | Problem solved where the "Save current model" button would not work properly in MBW mode 114 | Problem solved where an error would occur when saving a file with too long a file name 115 | 116 | ### bug fixed 2023.01.26.2100(JST) 117 | XY plotにおいてタイプMBWが使用できない不具合を修正しました 118 | Fixed a bug that type of MBW could work in XY plot 119 | 120 | ### updated 2023.01.25.0000(JST) 121 | added several features 122 | - added new merge mode "Triple sum","sum Twice" 123 | - added XY plot 124 | - 新しいマージモードを追加しました "Triple sum","sum Twice" 125 | - XY plot機能を追加しました 126 | 127 | ### bug fixed 2023.01.20.2350(JST) 128 | png infoがうまく保存できない問題を解決しました。 129 | Problem solved where png info could not be saved properly. 
130 | -------------------------------------------------------------------------------- /elemental_en.md: -------------------------------------------------------------------------------- 1 | # Elemental Merge 2 | - This is block-by-block merging taken beyond block-by-block merging. 3 | 4 | In a block-by-block merge, the merge ratio can be changed for each of the 25 blocks, but each block also consists of multiple elements, and in principle it is possible to change the ratio for each element. The number of elements exceeds 600, and it was doubtful whether that could be managed by hand, but we implemented it anyway. Merging element by element out of the blue is not recommended; it is best used as a final adjustment when a problem cannot be solved by block-by-block merging. 5 | The following images show the result of changing the elements in the OUT05 layer. The leftmost one is without merging, the second one changes all of the OUT05 layer (i.e., normal block-by-block merging), and the rest are element merges. As shown in the table below, attn2 and the like contain several more elements. 6 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample1.jpg) 7 | ## Usage 8 | Note that elemental merging is effective in both normal and block-by-block merging, and since it is computed last, it overwrites values specified for block-by-block merging. 9 | 10 | Set this in Elemental Merge. Note that if text is set here, it is applied automatically. Each element is listed in the table below, but it is not necessary to enter the full name of each element. 11 | You can check whether the effect is properly applied by enabling the "print change" check. If this check is enabled, the applied elements are displayed on the command prompt screen during the merge. 12 | 13 | ### Format 14 | Blocks:Element:Ratio, Blocks:Element:Ratio,... 15 | or 16 | Blocks:Element:Ratio 17 | Blocks:Element:Ratio 18 | Blocks:Element:Ratio 19 | 20 | Multiple specifications can be given by separating them with commas or newlines. Commas and newlines may be mixed. 21 | Blocks can be specified in uppercase, from BASE and IN00 through M00 to OUT11. If left blank, all blocks are applied. Multiple blocks can be specified by separating them with a space. 22 | Similarly, multiple elements can be specified by separating them with a space. 23 | Partial matching is used, so for example typing "attn" changes both attn1 and attn2, while typing "attn2" changes only attn2. If you want to be more specific, enter "attn2.to_out" and so on. 24 | 25 | OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5 26 | 27 | With this input, the ratio of elements containing attn2 or attn1.to_out in the OUT03, OUT04, and OUT05 layers will be 0.5. 28 | If the element column is left blank, all elements in the specified blocks will change, and the effect will be the same as a block-by-block merge. 29 | If there are duplicate specifications, the one entered later takes precedence. 30 | 31 | OUT06:attn:0.5,OUT06:attn2.to_k:0.2 32 | 33 | If this is entered, attn other than attn2.to_k in the OUT06 layer will be 0.5, and only attn2.to_k will be 0.2. 34 | 35 | You can invert the effect by entering NOT first. 36 | This can be set per block and per element. 37 | 38 | NOT OUT04:attn:1 39 | 40 | will set a ratio of 1 for attn in all blocks except the OUT04 layer. 41 | 42 | OUT05:NOT attn proj:0.2 43 | 44 | will set all elements except attn and proj in the OUT05 layer to 0.2. 45 |
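As a reference, the matching rules above (partial matching, later entries taking precedence, NOT inversion) could be sketched as follows. This is an illustrative helper under assumed intermediate data structures, not SuperMerger's actual parser.

```
# Illustrative sketch of the matching rules, not the actual parser.
def elemental_ratio(key, block, specs, default=None):
    """specs: list of (blocks, block_not, elements, elem_not, ratio) tuples."""
    ratio = default
    for blocks, block_not, elements, elem_not, r in specs:
        block_hit = (not blocks) or ((block in blocks) != block_not)
        # partial match: an element spec hits if it appears anywhere in the key
        elem_hit = (not elements) or (any(e in key for e in elements) != elem_not)
        if block_hit and elem_hit:
            ratio = r  # later entries take precedence
    return ratio

# "OUT06:attn:0.5,OUT06:attn2.to_k:0.2" expressed as spec tuples:
specs = [({"OUT06"}, False, ["attn"], False, 0.5),
         ({"OUT06"}, False, ["attn2.to_k"], False, 0.2)]
print(elemental_ratio("transformer_blocks.0.attn2.to_k.weight", "OUT06", specs))  # -> 0.2
```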
46 | ## XY plot 47 | Several XY plots are available for elemental merging. Input examples can be found in sample.txt. 48 | #### elemental 49 | Creates XY plots for multiple elemental merges. Separate the sets of elements from each other with blank lines. 50 | The image at the top of this document is the result of executing sample1 of sample.txt. 51 | 52 | #### pinpoint element 53 | Creates an XY plot with different values for a specific element. Do the same with elements as with pinpoint Blocks, but specify alpha for the opposite axis. Separate elements with a newline or comma. 54 | The following image shows the result of running sample3 of sample.txt. 55 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample3.jpg) 56 | 57 | #### effective elemental checker 58 | Outputs the effect of each element as a difference. Optionally, an animated GIF and a csv file can be output. The gif and csv files are created in the diff folder, which is created under the ModelA-ModelB folder in the output folder. If file names are duplicated, the files are renamed and saved, but since this gets confusing as files accumulate, it is recommended to rename the diff folder to something appropriate. 59 | Separate entries with a newline or comma. Use alpha for the opposite axis and enter a single value. This is useful for seeing the effect of an element, but it is also possible to see the effect of a block by not specifying an element, and you may use it that way more often. 60 | The following image shows the result of running sample5 of sample.txt. 61 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-1.jpg) 62 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-2.jpg) 63 | 64 | ### List of elements 65 | Basically, it seems that attn is responsible for the face and clothing information. The IN07, OUT03, OUT04, and OUT05 layers seem to have a particularly strong influence. Since the degree of influence often differs depending on the block, it does not seem to make sense to change the same element in multiple blocks at the same time. 66 | No element exists where it is marked null.
67 | 68 | ||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11 69 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 70 | op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null 71 | op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null 72 | emb_layers.1.bias|null|||null|||null|||null|null||||||||||||||| 73 | emb_layers.1.weight|null|||null|||null|||null|null||||||||||||||| 74 | in_layers.0.bias|null|||null|||null|||null|null||||||||||||||| 75 | in_layers.0.weight|null|||null|||null|||null|null||||||||||||||| 76 | in_layers.2.bias|null|||null|||null|||null|null||||||||||||||| 77 | in_layers.2.weight|null|||null|||null|||null|null||||||||||||||| 78 | out_layers.0.bias|null|||null|||null|||null|null||||||||||||||| 79 | out_layers.0.weight|null|||null|||null|||null|null||||||||||||||| 80 | out_layers.3.bias|null|||null|||null|||null|null||||||||||||||| 81 | out_layers.3.weight|null|||null|||null|||null|null||||||||||||||| 82 | skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null|||||||||||| 83 | skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null|||||||||||| 84 | norm.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 85 | norm.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 86 | proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 87 | proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 88 | proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 89 | proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 90 | transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 91 | transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 92 | transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 93 | transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 94 | transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 95 | transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 96 | transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 97 | transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 98 | transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 99 | transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 100 | transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 101 | transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 102 | transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 103 | transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 104 | transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 105 | 
transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 106 | transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 107 | transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 108 | transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 109 | transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 110 | conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null 111 | conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null 112 | 0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 113 | 0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 114 | 2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 115 | 2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 116 | time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 117 | time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 118 | time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 119 | time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 120 | -------------------------------------------------------------------------------- /elemental_ja.md: -------------------------------------------------------------------------------- 1 | # Elemental Merge 2 | - 階層マージを越えた階層マージです 3 | 4 | 階層マージでは25の階層ごとにマージ比率を変えることができますが、階層もまた複数の要素で構成されており、要素ごとに比率を変えることも原理的には可能です。可能ですが、要素の数は600以上にもなり人の手で扱えるのかは疑問でしたが実装してみました。いきなり要素ごとのマージは推奨されません。階層マージにおいて解決不可能な問題が生じたときに最終調節手段として使うことをおすすめします。 5 | 次の画像はOUT05層の要素を変えた結果です。左端はマージ無し。2番目はOUT05層すべて(つまりは普通の階層マージ),以降が要素マージです。下表のとおり、attn2などの中にはさらに複数の要素が含まれます。 6 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample1.jpg) 7 | 8 | ## 使い方 9 | 要素マージは通常マージ、階層マージ時どちらの場合でも有効で、最後に計算されるために、階層マージで指定した値は上書きされることに注意してください。 10 | 11 | Elemental Mergeで設定します。ここにテキストが設定されていると自動的に適応されるので注意して下さい。各要素は下表のとおりですが、各要素のフルネームを入力する必要はありません。 12 | ちゃんと効果が現れるかどうかはprint changeチェックを有効にすることで確認できます。このチェックを有効にするとマージ時にコマンドプロンプト画面に適用された要素が表示されます。 13 | 部分一致で指定が可能です。 14 | ### 書式 15 | 階層:要素:比率,階層:要素:比率,... 
16 | または 17 | 階層:要素:比率 18 | 階層:要素:比率 19 | 階層:要素:比率 20 | 21 | カンマまたは改行で区切ることで複数の指定が可能です。カンマと改行は混在しても問題ありません。 22 | 階層は大文字でBASE,IN00-M00-OUT11まで指定でます。空欄にするとすべての階層に適用されます。スペースで区切ることで複数の階層を指定できます。 23 | 要素も同様でスペースで区切ることで複数の要素を指定できます。 24 | 部分一致で判断するので、例えば「attn」と入力するとattn1,attn2両方が変化します。「attn2」の場合はattn2のみ。さらに細かく指定したい場合は「attn2.to_out」などと入力します。 25 | 26 | OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5 27 | 28 | と入力すると、OUT03,OUT04,OUT05層のattn2が含まれる要素及びattn1.to_outの比率が0.5になります。 29 | 要素の欄を空欄にすると指定階層のすべての要素が変わり、階層マージと同じ効果になります。 30 | 指定が重複する場合、後に入力された方が優先されます。 31 | 32 | OUT06:attn:0.5,OUT06:attn2.to_k:0.2 33 | 34 | と入力した場合、OUT06層のattn2.to_k以外のattnは0.5,attn2.to_kのみ0.2となります。 35 | 36 | 最初にNOTと入力することで効果範囲を反転させることができます。 37 | これは階層・要素別に設定できます。 38 | 39 | NOT OUT04:attn:1 40 | 41 | と入力するとOUT04層以外の層のattnに比率1が設定されます。 42 | 43 | OUT05:NOT attn proj:0.2 44 | 45 | とすると、OUT05層のattnとproj以外の層が0.2になります。 46 | 47 | ## XY plot 48 | elemental用のXY plotを複数用意しています。入力例はsample.txtにあります。 49 | #### elemental 50 | 複数の要素マージについてXY plotを作成します。要素同士は空行で区切ってください。 51 | トップ画像はsample.txtのsample1を実行した結果です。 52 | 53 | #### pinpoint element 54 | 特定の要素について値を変えてXY plotを作成します。pinpoint Blocksと同じことを要素で行います。反対の軸にはalphaを指定してください。要素同士は改行またはカンマで区切ります。 55 | 以下の画像はsample.txtのsample3を実行した結果です。 56 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample3.jpg) 57 | 58 | #### effective elenemtal checker 59 | 各要素の影響度を差分として出力します。オプションでanime gif、csvファイルを出力できます。gif.csvファイルはoutputフォルダにModelAとModelBから作られるフォルダ下に作成されるdiffフォルダに作成されます。ファイル名が重複する場合名前を変えて保存しますが、増えてくるとややこしいのでdiffフォルダを適当な名前に変えることをおすすめします。 60 | 改行またはカンマで区切ります。反対の軸はalphaを使用し、単一の値を入力してください。これは要素の効果を見るのにも有効ですが、要素を指定しないことで階層の効果を見ることも可能なので、そちらの使い方をする場合が多いかもしれません。 61 | 以下の画像はsample.txtのsample5を実行した結果です。 62 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-1.jpg) 63 | ![](https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/images/sample5-2.jpg) 64 | 65 | 66 | ### 要素一覧 67 | 基本的にはattnが顔や服装の情報を担っているようです。特にIN07,OUT03,OUT04,OUT05層の影響度が強いようです。階層によって影響度が異なることが多いので複数の層の同じ要素を同時に変化させることは意味が無いように思えます。 68 | nullと書かれた場所には要素が存在しません。 69 | 70 | ||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11 71 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 72 | op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null 73 | op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null 74 | emb_layers.1.bias|null|||null|||null|||null|null||||||||||||||| 75 | emb_layers.1.weight|null|||null|||null|||null|null||||||||||||||| 76 | in_layers.0.bias|null|||null|||null|||null|null||||||||||||||| 77 | in_layers.0.weight|null|||null|||null|||null|null||||||||||||||| 78 | in_layers.2.bias|null|||null|||null|||null|null||||||||||||||| 79 | in_layers.2.weight|null|||null|||null|||null|null||||||||||||||| 80 | out_layers.0.bias|null|||null|||null|||null|null||||||||||||||| 81 | out_layers.0.weight|null|||null|||null|||null|null||||||||||||||| 82 | out_layers.3.bias|null|||null|||null|||null|null||||||||||||||| 83 | out_layers.3.weight|null|||null|||null|||null|null||||||||||||||| 84 | skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null|||||||||||| 85 | skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null|||||||||||| 86 | 
norm.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 87 | norm.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 88 | proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 89 | proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 90 | proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 91 | proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 92 | transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 93 | transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 94 | transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 95 | transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 96 | transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 97 | transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 98 | transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 99 | transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 100 | transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 101 | transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 102 | transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 103 | transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 104 | transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 105 | transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 106 | transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 107 | transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 108 | transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 109 | transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 110 | transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null||||||||| 111 | transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null||||||||| 112 | conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null 113 | conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null 114 | 0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 115 | 0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 116 | 2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 117 | 2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 118 | 
time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 119 | time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 120 | time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 121 | time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null| 122 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | import importlib 3 | from packaging.version import Version 4 | from packaging.requirements import Requirement 5 | 6 | def is_installed(pip_package): 7 | """ 8 | Check if a package is installed and meets version requirements specified in pip-style format. 9 | 10 | Args: 11 | pip_package (str): Package name in pip-style format (e.g., "numpy>=1.22.0"). 12 | 13 | Returns: 14 | bool: True if the package is installed and meets the version requirement, False otherwise. 15 | """ 16 | try: 17 | # Parse the pip-style package name and version constraints 18 | requirement = Requirement(pip_package) 19 | package_name = requirement.name 20 | specifier = requirement.specifier # e.g., >=1.22.0 21 | 22 | # Check if the package is installed 23 | dist = importlib.metadata.distribution(package_name) 24 | installed_version = Version(dist.version) 25 | 26 | # Check version constraints 27 | if specifier.contains(installed_version): 28 | return True 29 | else: 30 | print(f"Installed version of {package_name} ({installed_version}) does not satisfy the requirement ({specifier}).") 31 | return False 32 | except importlib.metadata.PackageNotFoundError: 33 | print(f"Package {pip_package} is not installed.") 34 | return False 35 | 36 | requirements = [ 37 | "diffusers==0.31.0", 38 | "scikit-learn", 39 | ] 40 | 41 | for module in requirements: 42 | if not is_installed(module): 43 | launch.run_pip(f"install {module}", module) -------------------------------------------------------------------------------- /scripts/A1111/lora_patches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import networks 4 | from modules import patches 5 | 6 | 7 | class LoraPatches: 8 | def __init__(self): 9 | self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) 10 | self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) 11 | self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) 12 | self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) 13 | self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) 14 | self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) 15 | self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) 16 | self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', 
networks.network_LayerNorm_load_state_dict) 17 | self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) 18 | self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) 19 | 20 | def undo(self): 21 | self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') 22 | self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') 23 | self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') 24 | self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') 25 | self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') 26 | self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') 27 | self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') 28 | self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') 29 | self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') 30 | self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') 31 | 32 | -------------------------------------------------------------------------------- /scripts/A1111/lyco_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_weight_cp(t, wa, wb): 5 | temp = torch.einsum('i j k l, j r -> i r k l', t, wb) 6 | return torch.einsum('i j k l, i r -> r j k l', temp, wa) 7 | 8 | 9 | def rebuild_conventional(up, down, shape, dyn_dim=None): 10 | up = cpufloat(up.reshape(up.size(0), -1)) 11 | down = cpufloat(down.reshape(down.size(0), -1)) 12 | if dyn_dim is not None: 13 | up = up[:, :dyn_dim] 14 | down = down[:dyn_dim, :] 15 | return (up @ down).reshape(shape) 16 | 17 | 18 | def rebuild_cp_decomposition(up, down, mid): 19 | up = cpufloat(up.reshape(up.size(0), -1)) 20 | down = cpufloat(down.reshape(down.size(0), -1)) 21 | mid = cpufloat(mid) 22 | return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) 23 | 24 | 25 | # copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py 26 | def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: 27 | ''' 28 | return a tuple of two values: the input dimension decomposed by the number closest to factor. 29 | The second value is greater than or equal to the first value. 30 | 31 | In LoRA with Kronecker Product, the first value is a value for weight scale. 32 | The second value is a value for weight. 33 | 34 | Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. 35 | 36 | examples) 37 | factor 38 | -1 2 4 8 16 ...
39 | 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 40 | 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 41 | 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 42 | 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 43 | 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 44 | 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 45 | ''' 46 | 47 | if factor > 0 and (dimension % factor) == 0: 48 | m = factor 49 | n = dimension // factor 50 | if m > n: 51 | n, m = m, n 52 | return m, n 53 | if factor < 0: 54 | factor = dimension 55 | m, n = 1, dimension 56 | length = m + n 57 | while m < n: 58 | new_m = m + 1 59 | while dimension % new_m != 0: 60 | new_m += 1 61 | new_n = dimension // new_m 62 | if new_m + new_n > length or new_m > factor: 63 | break 64 | else: 65 | m, n = new_m, new_n 66 | if m > n: 67 | n, m = m, n 68 | return m, n 69 | 70 | def cpufloat(module): 71 | if module is None: return module # guard for None (None対策) 72 | return module.to(torch.float) if module.device.type == "cpu" else module -------------------------------------------------------------------------------- /scripts/A1111/network.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | from collections import namedtuple 4 | import enum 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | try: 11 | from modules import sd_models, cache, errors, hashes, shared 12 | except: 13 | pass 14 | 15 | class QkvLinear(torch.nn.Linear): 16 | pass 17 | 18 | NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) 19 | 20 | metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} 21 | 22 | 23 | class SdVersion(enum.Enum): 24 | Unknown = 1 25 | SD1 = 2 26 | SD2 = 3 27 | SDXL = 4 28 | 29 | 30 | class NetworkOnDisk: 31 | def __init__(self, name, filename): 32 | self.name = name 33 | self.filename = filename 34 | self.metadata = {} 35 | self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" 36 | 37 | def read_metadata(): 38 | metadata = sd_models.read_metadata_from_safetensors(filename) 39 | 40 | return metadata 41 | 42 | if self.is_safetensors: 43 | try: 44 | self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) 45 | except Exception as e: 46 | errors.display(e, f"reading lora {filename}") 47 | 48 | if self.metadata: 49 | m = {} 50 | for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): 51 | m[k] = v 52 | 53 | self.metadata = m 54 | 55 | self.alias = self.metadata.get('ss_output_name', self.name) 56 | 57 | self.hash = None 58 | self.shorthash = None 59 | self.set_hash( 60 | self.metadata.get('sshs_model_hash') or 61 | hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or 62 | '' 63 | ) 64 | 65 | self.sd_version = self.detect_version() 66 | 67 | def detect_version(self): 68 | if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): 69 | return SdVersion.SDXL 70 | elif str(self.metadata.get('ss_v2', "")) == "True": 71 | return SdVersion.SD2 72 | elif len(self.metadata): 73 | return SdVersion.SD1 74 | 75 | return SdVersion.Unknown 76 | 77 | def set_hash(self, v): 78 | self.hash = v 79 | self.shorthash = self.hash[0:12] 80 | 81 | if self.shorthash: 82 | import networks 83 | networks.available_network_hash_lookup[self.shorthash] = self 84 | 85 | def
read_hash(self): 86 | if not self.hash: 87 | self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') 88 | 89 | def get_alias(self): 90 | import networks 91 | if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: 92 | return self.name 93 | else: 94 | return self.alias 95 | 96 | 97 | class Network: # LoraModule 98 | def __init__(self, name, network_on_disk: NetworkOnDisk): 99 | self.name = name 100 | self.network_on_disk = network_on_disk 101 | self.te_multiplier = 1.0 102 | self.unet_multiplier = 1.0 103 | self.dyn_dim = None 104 | self.modules = {} 105 | self.bundle_embeddings = {} 106 | self.mtime = None 107 | 108 | self.mentioned_name = None 109 | """the text that was used to add the network to prompt - can be either name or an alias""" 110 | 111 | 112 | class ModuleType: 113 | def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: 114 | return None 115 | 116 | 117 | class NetworkModule: 118 | def __init__(self, net: Network, weights: NetworkWeights): 119 | self.network = net 120 | self.network_key = weights.network_key 121 | self.sd_key = weights.sd_key 122 | self.sd_module = weights.sd_module 123 | 124 | if isinstance(self.sd_module, QkvLinear): 125 | s = self.sd_module.weight.shape 126 | self.shape = (s[0] // 3, s[1]) 127 | elif hasattr(self.sd_module, 'weight'): 128 | self.shape = self.sd_module.weight.shape 129 | elif isinstance(self.sd_module, nn.MultiheadAttention): 130 | # For now, only self-attn use Pytorch's MHA 131 | # So assume all qkvo proj have same shape 132 | self.shape = self.sd_module.out_proj.weight.shape 133 | else: 134 | self.shape = None 135 | 136 | self.ops = None 137 | self.extra_kwargs = {} 138 | if isinstance(self.sd_module, nn.Conv2d): 139 | self.ops = F.conv2d 140 | self.extra_kwargs = { 141 | 'stride': self.sd_module.stride, 142 | 'padding': self.sd_module.padding 143 | } 144 | elif isinstance(self.sd_module, nn.Linear): 145 | self.ops = F.linear 146 | elif isinstance(self.sd_module, nn.LayerNorm): 147 | self.ops = F.layer_norm 148 | self.extra_kwargs = { 149 | 'normalized_shape': self.sd_module.normalized_shape, 150 | 'eps': self.sd_module.eps 151 | } 152 | elif isinstance(self.sd_module, nn.GroupNorm): 153 | self.ops = F.group_norm 154 | self.extra_kwargs = { 155 | 'num_groups': self.sd_module.num_groups, 156 | 'eps': self.sd_module.eps 157 | } 158 | 159 | self.dim = None 160 | self.bias = weights.w.get("bias") 161 | self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None 162 | self.scale = weights.w["scale"].item() if "scale" in weights.w else None 163 | 164 | self.dora_scale = weights.w.get("dora_scale", None) 165 | self.dora_norm_dims = len(self.shape) - 1 166 | 167 | def multiplier(self): 168 | if 'transformer' in self.sd_key[:20]: 169 | return self.network.te_multiplier 170 | else: 171 | return self.network.unet_multiplier 172 | 173 | def calc_scale(self): 174 | if self.scale is not None: 175 | return self.scale 176 | if self.dim is not None and self.alpha is not None: 177 | return self.alpha / self.dim 178 | 179 | return 1.0 180 | 181 | def apply_weight_decompose(self, updown, orig_weight): 182 | # Match the device/dtype 183 | orig_weight = orig_weight.to(updown.dtype) 184 | dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype) 185 | updown = updown.to(orig_weight.device) 186 | 187 | merged_scale1 = updown + orig_weight 188 | merged_scale1_norm = ( 189 | 
merged_scale1.transpose(0, 1) 190 | .reshape(merged_scale1.shape[1], -1) 191 | .norm(dim=1, keepdim=True) 192 | .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims) 193 | .transpose(0, 1) 194 | ) 195 | 196 | dora_merged = ( 197 | merged_scale1 * (dora_scale / merged_scale1_norm) 198 | ) 199 | final_updown = dora_merged - orig_weight 200 | return final_updown 201 | 202 | def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): 203 | if self.bias is not None: 204 | updown = updown.reshape(self.bias.shape) 205 | updown += self.bias.to(orig_weight.device, dtype=updown.dtype) 206 | updown = updown.reshape(output_shape) 207 | 208 | if len(output_shape) == 4: 209 | updown = updown.reshape(output_shape) 210 | 211 | if orig_weight.size().numel() == updown.size().numel(): 212 | updown = updown.reshape(orig_weight.shape) 213 | 214 | if ex_bias is not None: 215 | ex_bias = ex_bias * self.multiplier() 216 | 217 | updown = updown * self.calc_scale() 218 | 219 | if self.dora_scale is not None: 220 | updown = self.apply_weight_decompose(updown, orig_weight) 221 | 222 | return updown * self.multiplier(), ex_bias 223 | 224 | def calc_updown(self, target): 225 | raise NotImplementedError() 226 | 227 | def forward(self, x, y): 228 | """A general forward implementation for all modules""" 229 | if self.ops is None: 230 | raise NotImplementedError() 231 | else: 232 | updown, ex_bias = self.calc_updown(self.sd_module.weight) 233 | return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) 234 | 235 | -------------------------------------------------------------------------------- /scripts/A1111/network_full.py: -------------------------------------------------------------------------------- 1 | import scripts.A1111.network as network 2 | 3 | 4 | class ModuleTypeFull(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["diff"]): 7 | return NetworkModuleFull(net, weights) 8 | 9 | return None 10 | 11 | 12 | class NetworkModuleFull(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | self.weight = weights.w.get("diff") 17 | self.ex_bias = weights.w.get("diff_b") 18 | 19 | def calc_updown(self, orig_weight): 20 | output_shape = self.weight.shape 21 | updown = self.weight.to(orig_weight.device) 22 | if self.ex_bias is not None: 23 | ex_bias = self.ex_bias.to(orig_weight.device) 24 | else: 25 | ex_bias = None 26 | 27 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) 28 | -------------------------------------------------------------------------------- /scripts/A1111/network_glora.py: -------------------------------------------------------------------------------- 1 | 2 | import scripts.A1111.network as network 3 | 4 | class ModuleTypeGLora(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): 7 | return NetworkModuleGLora(net, weights) 8 | 9 | return None 10 | 11 | # adapted from https://github.com/KohakuBlueleaf/LyCORIS 12 | class NetworkModuleGLora(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | if hasattr(self.sd_module, 'weight'): 17 | self.shape = self.sd_module.weight.shape 18 | 19 | self.w1a = weights.w["a1.weight"] 20 | 
self.w1b = weights.w["b1.weight"] 21 | self.w2a = weights.w["a2.weight"] 22 | self.w2b = weights.w["b2.weight"] 23 | 24 | def calc_updown(self, orig_weight): 25 | w1a = self.w1a.to(orig_weight.device) 26 | w1b = self.w1b.to(orig_weight.device) 27 | w2a = self.w2a.to(orig_weight.device) 28 | w2b = self.w2b.to(orig_weight.device) 29 | 30 | output_shape = [w1a.size(0), w1b.size(1)] 31 | updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) 32 | 33 | return self.finalize_updown(updown, orig_weight, output_shape) 34 | -------------------------------------------------------------------------------- /scripts/A1111/network_hada.py: -------------------------------------------------------------------------------- 1 | import scripts.A1111.lyco_helpers as lyco_helpers 2 | import scripts.A1111.network as network 3 | 4 | 5 | class ModuleTypeHada(network.ModuleType): 6 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 7 | if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): 8 | return NetworkModuleHada(net, weights) 9 | 10 | return None 11 | 12 | 13 | class NetworkModuleHada(network.NetworkModule): 14 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 15 | super().__init__(net, weights) 16 | 17 | if hasattr(self.sd_module, 'weight'): 18 | self.shape = self.sd_module.weight.shape 19 | 20 | self.w1a = weights.w["hada_w1_a"] 21 | self.w1b = weights.w["hada_w1_b"] 22 | self.dim = self.w1b.shape[0] 23 | self.w2a = weights.w["hada_w2_a"] 24 | self.w2b = weights.w["hada_w2_b"] 25 | 26 | self.t1 = weights.w.get("hada_t1") 27 | self.t2 = weights.w.get("hada_t2") 28 | 29 | def calc_updown(self, orig_weight): 30 | w1a = self.w1a.to(orig_weight.device) 31 | w1b = self.w1b.to(orig_weight.device) 32 | w2a = self.w2a.to(orig_weight.device) 33 | w2b = self.w2b.to(orig_weight.device) 34 | 35 | output_shape = [w1a.size(0), w1b.size(1)] 36 | 37 | if self.t1 is not None: 38 | output_shape = [w1a.size(1), w1b.size(1)] 39 | t1 = self.t1.to(orig_weight.device) 40 | updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) 41 | output_shape += t1.shape[2:] 42 | else: 43 | if len(w1b.shape) == 4: 44 | output_shape += w1b.shape[2:] 45 | updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) 46 | 47 | if self.t2 is not None: 48 | t2 = self.t2.to(orig_weight.device) 49 | updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) 50 | else: 51 | updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) 52 | 53 | updown = updown1 * updown2 54 | 55 | return self.finalize_updown(updown, orig_weight, output_shape) 56 | -------------------------------------------------------------------------------- /scripts/A1111/network_ia3.py: -------------------------------------------------------------------------------- 1 | import scripts.A1111.network as network 2 | 3 | 4 | class ModuleTypeIa3(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["weight"]): 7 | return NetworkModuleIa3(net, weights) 8 | 9 | return None 10 | 11 | 12 | class NetworkModuleIa3(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | self.w = weights.w["weight"] 17 | self.on_input = weights.w["on_input"].item() 18 | 19 | def calc_updown(self, orig_weight): 20 | w = self.w.to(orig_weight.device) 21 | 22 | output_shape = [w.size(0), orig_weight.size(1)] 23 | if 
self.on_input: 24 | output_shape.reverse() 25 | else: 26 | w = w.reshape(-1, 1) 27 | 28 | updown = orig_weight * w 29 | 30 | return self.finalize_updown(updown, orig_weight, output_shape) 31 | -------------------------------------------------------------------------------- /scripts/A1111/network_lokr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import scripts.A1111.lyco_helpers as lyco_helpers 4 | import scripts.A1111.network as network 5 | 6 | 7 | class ModuleTypeLokr(network.ModuleType): 8 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 9 | has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) 10 | has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) 11 | if has_1 and has_2: 12 | return NetworkModuleLokr(net, weights) 13 | 14 | return None 15 | 16 | 17 | def make_kron(orig_shape, w1, w2): 18 | if len(w2.shape) == 4: 19 | w1 = w1.unsqueeze(2).unsqueeze(2) 20 | w2 = w2.contiguous() 21 | return torch.kron(w1, w2).reshape(orig_shape) 22 | 23 | 24 | class NetworkModuleLokr(network.NetworkModule): 25 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 26 | super().__init__(net, weights) 27 | 28 | self.w1 = weights.w.get("lokr_w1") 29 | self.w1a = weights.w.get("lokr_w1_a") 30 | self.w1b = weights.w.get("lokr_w1_b") 31 | self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim 32 | self.w2 = weights.w.get("lokr_w2") 33 | self.w2a = weights.w.get("lokr_w2_a") 34 | self.w2b = weights.w.get("lokr_w2_b") 35 | self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim 36 | self.t2 = weights.w.get("lokr_t2") 37 | 38 | def calc_updown(self, orig_weight): 39 | if self.w1 is not None: 40 | w1 = self.w1.to(orig_weight.device) 41 | else: 42 | w1a = self.w1a.to(orig_weight.device) 43 | w1b = self.w1b.to(orig_weight.device) 44 | w1 = w1a @ w1b 45 | 46 | if self.w2 is not None: 47 | w2 = self.w2.to(orig_weight.device) 48 | elif self.t2 is None: 49 | w2a = self.w2a.to(orig_weight.device) 50 | w2b = self.w2b.to(orig_weight.device) 51 | w2 = w2a @ w2b 52 | else: 53 | t2 = self.t2.to(orig_weight.device) 54 | w2a = self.w2a.to(orig_weight.device) 55 | w2b = self.w2b.to(orig_weight.device) 56 | w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) 57 | 58 | output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] 59 | if len(orig_weight.shape) == 4: 60 | output_shape = orig_weight.shape 61 | 62 | updown = make_kron(output_shape, w1, w2) 63 | 64 | return self.finalize_updown(updown, orig_weight, output_shape) 65 | -------------------------------------------------------------------------------- /scripts/A1111/network_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import scripts.A1111.lyco_helpers as lyco_helpers 4 | import scripts.A1111.network as network 5 | from modules import devices 6 | from modules.ui import versions_html 7 | 8 | forge = "forge" in versions_html() 9 | class QkvLinear(torch.nn.Linear): 10 | pass 11 | 12 | class ModuleTypeLora(network.ModuleType): 13 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 14 | if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): 15 | return NetworkModuleLora(net, weights) 16 | 17 | if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]): 18 | w = weights.w.copy() 19 | weights.w.clear() 20 | weights.w.update({"lora_up.weight": 
w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]}) 21 | 22 | return NetworkModuleLora(net, weights) 23 | 24 | return None 25 | 26 | 27 | class NetworkModuleLora(network.NetworkModule): 28 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 29 | super().__init__(net, weights) 30 | 31 | self.up_model = self.create_module(weights.w, "lora_up.weight") 32 | self.down_model = self.create_module(weights.w, "lora_down.weight") 33 | self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) 34 | 35 | self.dim = weights.w["lora_down.weight"].shape[0] 36 | 37 | def create_module(self, weights, key, none_ok=False): 38 | weight = weights.get(key) 39 | 40 | if weight is None and none_ok: 41 | return None 42 | 43 | is_linear = "Linear" in str(type(self.sd_module)) if forge else type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, QkvLinear] 44 | is_conv = "Conv2d" in str(type(self.sd_module)) if forge else type(self.sd_module) in [torch.nn.Conv2d] 45 | 46 | if is_linear: 47 | weight = weight.reshape(weight.shape[0], -1) 48 | module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) 49 | elif is_conv and key == "lora_down.weight" or key == "dyn_up": 50 | if len(weight.shape) == 2: 51 | weight = weight.reshape(weight.shape[0], -1, 1, 1) 52 | 53 | if weight.shape[2] != 1 or weight.shape[3] != 1: 54 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) 55 | else: 56 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) 57 | elif is_conv and key == "lora_mid.weight": 58 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) 59 | elif is_conv and key == "lora_up.weight" or key == "dyn_down": 60 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) 61 | else: 62 | raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') 63 | 64 | with torch.no_grad(): 65 | if weight.shape != module.weight.shape: 66 | weight = weight.reshape(module.weight.shape) 67 | module.weight.copy_(weight) 68 | 69 | module.to(device=devices.cpu, dtype=devices.dtype) 70 | module.weight.requires_grad_(False) 71 | 72 | return module 73 | 74 | def calc_updown(self, orig_weight): 75 | up = self.up_model.weight.to(orig_weight.device) 76 | down = self.down_model.weight.to(orig_weight.device) 77 | 78 | output_shape = [up.size(0), down.size(1)] 79 | if self.mid_model is not None: 80 | # cp-decomposition 81 | mid = self.mid_model.weight.to(orig_weight.device) 82 | updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) 83 | output_shape += mid.shape[2:] 84 | else: 85 | if len(down.shape) == 4: 86 | output_shape += down.shape[2:] 87 | updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) 88 | 89 | return self.finalize_updown(updown, orig_weight, output_shape) 90 | 91 | def forward(self, x, y): 92 | self.up_model.to(device=devices.device) 93 | self.down_model.to(device=devices.device) 94 | 95 | return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() 96 | 97 | 98 | -------------------------------------------------------------------------------- /scripts/A1111/network_norm.py: 
-------------------------------------------------------------------------------- 1 | import scripts.A1111.network as network 2 | 3 | class ModuleTypeNorm(network.ModuleType): 4 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 5 | if all(x in weights.w for x in ["w_norm", "b_norm"]): 6 | return NetworkModuleNorm(net, weights) 7 | 8 | return None 9 | 10 | 11 | class NetworkModuleNorm(network.NetworkModule): 12 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 13 | super().__init__(net, weights) 14 | 15 | self.w_norm = weights.w.get("w_norm") 16 | self.b_norm = weights.w.get("b_norm") 17 | 18 | def calc_updown(self, orig_weight): 19 | output_shape = self.w_norm.shape 20 | updown = self.w_norm.to(orig_weight.device) 21 | 22 | if self.b_norm is not None: 23 | ex_bias = self.b_norm.to(orig_weight.device) 24 | else: 25 | ex_bias = None 26 | 27 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) 28 | -------------------------------------------------------------------------------- /scripts/A1111/network_oft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scripts.A1111.network as network 3 | from einops import rearrange 4 | 5 | 6 | class ModuleTypeOFT(network.ModuleType): 7 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 8 | if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): 9 | return NetworkModuleOFT(net, weights) 10 | 11 | return None 12 | 13 | # Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py 14 | # and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py 15 | class NetworkModuleOFT(network.NetworkModule): 16 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 17 | 18 | super().__init__(net, weights) 19 | 20 | self.lin_module = None 21 | self.org_module: list[torch.Module] = [self.sd_module] 22 | 23 | self.scale = 1.0 24 | self.is_R = False 25 | self.is_boft = False 26 | 27 | # kohya-ss/New LyCORIS OFT/BOFT 28 | if "oft_blocks" in weights.w.keys(): 29 | self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) 30 | self.alpha = weights.w.get("alpha", None) # alpha is constraint 31 | self.dim = self.oft_blocks.shape[0] # lora dim 32 | # Old LyCORIS OFT 33 | elif "oft_diag" in weights.w.keys(): 34 | self.is_R = True 35 | self.oft_blocks = weights.w["oft_diag"] 36 | # self.alpha is unused 37 | self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) 38 | 39 | is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] 40 | is_conv = type(self.sd_module) in [torch.nn.Conv2d] 41 | is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported 42 | 43 | if is_linear: 44 | self.out_dim = self.sd_module.out_features 45 | elif is_conv: 46 | self.out_dim = self.sd_module.out_channels 47 | elif is_other_linear: 48 | self.out_dim = self.sd_module.embed_dim 49 | 50 | # LyCORIS BOFT 51 | if self.oft_blocks.dim() == 4: 52 | self.is_boft = True 53 | self.rescale = weights.w.get('rescale', None) 54 | if self.rescale is not None and not is_other_linear: 55 | self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) 56 | 57 | self.num_blocks = self.dim 58 | self.block_size = self.out_dim // self.dim 
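# Illustration with assumed sizes (not taken from any real checkpoint): for a
# Linear layer with out_dim = 320 and a kohya-style OFT of dim = 4, the two
# lines above give num_blocks = 4 and block_size = 80, i.e. a block-diagonal
# orthogonal transform acting on 4 independent 80-dim slices of the output.
# The COFT constraint computed next caps the norm of the skew-symmetric
# generator so each learned rotation stays near the identity.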
59 | self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim 60 | if self.is_R: 61 | self.constraint = None 62 | self.block_size = self.dim 63 | self.num_blocks = self.out_dim // self.dim 64 | elif self.is_boft: 65 | self.boft_m = self.oft_blocks.shape[0] 66 | self.num_blocks = self.oft_blocks.shape[1] 67 | self.block_size = self.oft_blocks.shape[2] 68 | self.boft_b = self.block_size 69 | 70 | def calc_updown(self, orig_weight): 71 | oft_blocks = self.oft_blocks.to(orig_weight.device) 72 | eye = torch.eye(self.block_size, device=oft_blocks.device) 73 | 74 | if not self.is_R: 75 | block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix 76 | if self.constraint != 0: 77 | norm_Q = torch.norm(block_Q.flatten()) 78 | new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) 79 | block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) 80 | oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) 81 | 82 | R = oft_blocks.to(orig_weight.device) 83 | 84 | if not self.is_boft: 85 | # This errors out for MultiheadAttention, might need to be handled up-stream 86 | merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) 87 | merged_weight = torch.einsum( 88 | 'k n m, k n ... -> k m ...', 89 | R, 90 | merged_weight 91 | ) 92 | merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') 93 | else: 94 | # TODO: determine correct value for scale 95 | scale = 1.0 96 | m = self.boft_m 97 | b = self.boft_b 98 | r_b = b // 2 99 | inp = orig_weight 100 | for i in range(m): 101 | bi = R[i] # b_num, b_size, b_size 102 | if i == 0: 103 | # Apply multiplier/scale and rescale into first weight 104 | bi = bi * scale + (1 - scale) * eye 105 | inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) 106 | inp = rearrange(inp, "(d b) ... -> d b ...", b=b) 107 | inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) 108 | inp = rearrange(inp, "d b ... -> (d b) ...") 109 | inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) 110 | merged_weight = inp 111 | 112 | # Rescale mechanism 113 | if self.rescale is not None: 114 | merged_weight = self.rescale.to(merged_weight) * merged_weight 115 | 116 | updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) 117 | output_shape = orig_weight.shape 118 | return self.finalize_updown(updown, orig_weight, output_shape) 119 | -------------------------------------------------------------------------------- /scripts/GenParamGetter.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import scripts.mergers.components as components 3 | from scripts.mergers.mergers import smergegen, simggen 4 | from scripts.mergers.xyplot import numanager 5 | try: 6 | from scripts.mergers.pluslora import frompromptf 7 | except ImportError as e: 8 | try: 9 | import transformers 10 | transformers_version = transformers.__version__ 11 | except ImportError: 12 | transformers_version = "not installed" 13 | 14 | try: 15 | import diffusers 16 | diffusers_version = diffusers.__version__ 17 | except ImportError: 18 | diffusers_version = "not installed" 19 | 20 | print( 21 | f"Version error: Failed to import module.\n" 22 | f"Transformers version: {transformers_version}\n" 23 | f"Diffusers version: {diffusers_version}\n" 24 | "Please ensure compatibility between these packages." 
25 | ) 26 | raise e 27 | 28 | from modules import scripts, script_callbacks 29 | 30 | class GenParamGetter(scripts.Script): 31 | txt2img_gen_button = None 32 | img2img_gen_button = None 33 | 34 | events_assigned = False 35 | 36 | def title(self): 37 | return "Super Merger Generation Parameter Getter" 38 | 39 | def show(self, is_img2img): 40 | return scripts.AlwaysVisible 41 | 42 | def get_wanted_params(params,wanted): 43 | output = [] 44 | for target in wanted: 45 | if target is None: 46 | output.append(params[0]) 47 | continue 48 | for param in params: 49 | if hasattr(param,"label"): 50 | if param.label == target: 51 | output.append(param) 52 | return output 53 | 54 | def after_component(self, component: gr.components.Component, **_kwargs): 55 | """Find generate button""" 56 | if component.elem_id == "txt2img_generate": 57 | GenParamGetter.txt2img_gen_button = component 58 | elif component.elem_id == "img2img_generate": 59 | GenParamGetter.img2img_gen_button = component 60 | 61 | def get_components_by_ids(root: gr.Blocks, ids: list[int]): 62 | components: list[gr.Blocks] = [] 63 | 64 | if root._id in ids: 65 | components.append(root) 66 | ids = [_id for _id in ids if _id != root._id] 67 | 68 | if hasattr(root,"children"): 69 | for block in root.children: 70 | components.extend(GenParamGetter.get_components_by_ids(block, ids)) 71 | return components 72 | 73 | def compare_components_with_ids(components: list[gr.Blocks], ids: list[int]): 74 | 75 | try: 76 | return len(components) == len(ids) and all(component._id == _id for component, _id in zip(components, ids)) 77 | except: 78 | return False 79 | 80 | def get_params_components(demo: gr.Blocks, app): 81 | for _id, _is_txt2img in zip([GenParamGetter.txt2img_gen_button._id, GenParamGetter.img2img_gen_button._id], [True, False]): 82 | if hasattr(demo,"dependencies"): 83 | dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]] 84 | g4 = False 85 | else: 86 | dependencies: list[dict] = [x for x in demo.config["dependencies"] if x["targets"][0][1] == "click" and _id in x["targets"][0]] 87 | g4 = True 88 | 89 | dependency: dict = None 90 | 91 | for d in dependencies: 92 | if len(d["outputs"]) == 4: 93 | dependency = d 94 | 95 | if g4: 96 | params = [demo.blocks[x] for x in dependency['inputs']] 97 | if _is_txt2img: 98 | components.paramsnames = [x.label if hasattr(x,"label") else "None" for x in params] 99 | 100 | if _is_txt2img: 101 | components.txt2img_params = params 102 | else: 103 | components.img2img_params = params 104 | else: 105 | params = [params for params in demo.fns if GenParamGetter.compare_components_with_ids(params.inputs, dependency["inputs"])] 106 | 107 | if _is_txt2img: 108 | components.paramsnames = [x.label if hasattr(x,"label") else "None" for x in params[0].inputs] 109 | 110 | if _is_txt2img: 111 | components.txt2img_params = params[0].inputs 112 | else: 113 | components.img2img_params = params[0].inputs 114 | 115 | 116 | if not GenParamGetter.events_assigned: 117 | with demo: 118 | components.merge.click( 119 | fn=smergegen, 120 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dfalse,*components.txt2img_params], 121 | outputs=[components.submit_result,components.currentmodel] 122 | ) 123 | 124 | components.mergeandgen.click( 125 | fn=smergegen, 126 | 
inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dtrue,*components.txt2img_params], 127 | outputs=[components.submit_result,components.currentmodel,*components.imagegal] 128 | ) 129 | 130 | components.gen.click( 131 | fn=simggen, 132 | inputs=[*components.genparams,*components.hiresfix,components.currentmodel,components.id_sets,gr.Textbox(value="No id",visible=False),*components.txt2img_params], 133 | outputs=[*components.imagegal], 134 | ) 135 | 136 | components.merge2.click( 137 | fn=smergegen, 138 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dfalse,*components.txt2img_params], 139 | outputs=[components.submit_result,components.currentmodel] 140 | ) 141 | 142 | components.mergeandgen2.click( 143 | fn=smergegen, 144 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dtrue,*components.txt2img_params], 145 | outputs=[components.submit_result,components.currentmodel,*components.imagegal] 146 | ) 147 | 148 | components.gen2.click( 149 | fn=simggen, 150 | inputs=[*components.genparams,*components.hiresfix,components.currentmodel,components.id_sets,gr.Textbox(value="No id",visible=False),*components.txt2img_params], 151 | outputs=[*components.imagegal], 152 | ) 153 | 154 | 155 | components.s_reserve.click( 156 | fn=numanager, 157 | inputs=[gr.Textbox(value="reserve",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params], 158 | outputs=[components.numaframe] 159 | ) 160 | 161 | components.s_reserve1.click( 162 | fn=numanager, 163 | inputs=[gr.Textbox(value="reserve",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params], 164 | outputs=[components.numaframe] 165 | ) 166 | 167 | components.gengrid.click( 168 | fn=numanager, 169 | inputs=[gr.Textbox(value="normal",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params], 170 | outputs=[components.submit_result,components.currentmodel,*components.imagegal], 171 | ) 172 | 173 | components.s_startreserve.click( 174 | fn=numanager, 175 | inputs=[gr.Textbox(value=" ",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params], 176 | outputs=[components.submit_result,components.currentmodel,*components.imagegal], 177 | ) 178 | 179 | components.rand_merge.click( 180 | fn=numanager, 181 | inputs=[gr.Textbox(value="random",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params], 182 | outputs=[components.submit_result,components.currentmodel,*components.imagegal], 183 | ) 184 | 185 | components.frompromptb.click( 186 | fn=frompromptf, 187 | inputs=[*components.txt2img_params], 188 | outputs=components.sml_loranames, 189 | ) 190 | GenParamGetter.events_assigned = True 191 | 192 | if __package__ == "GenParamGetter": 193 | script_callbacks.on_app_started(GenParamGetter.get_params_components) 194 | -------------------------------------------------------------------------------- /scripts/Roboto-Regular.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/352c0047ec852fd7000835a861bf8cc7a907e04b/scripts/Roboto-Regular.ttf -------------------------------------------------------------------------------- /scripts/kohyas/extract_lora_from_models.py: -------------------------------------------------------------------------------- 1 | # extract approximating LoRA by svd from two SD models 2 | # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py 3 | # Thanks to cloneofsimo! 4 | 5 | import argparse 6 | import json 7 | import os 8 | import time 9 | import torch 10 | from safetensors.torch import load_file, save_file 11 | from tqdm import tqdm 12 | from scripts.kohyas import sai_model_spec,model_util,sdxl_model_util,lora 13 | 14 | 15 | CLAMP_QUANTILE = 0.99 16 | MIN_DIFF = 1e-1 17 | 18 | 19 | def save_to_file(file_name, model, state_dict, dtype): 20 | if dtype is not None: 21 | for key in list(state_dict.keys()): 22 | if type(state_dict[key]) == torch.Tensor: 23 | state_dict[key] = state_dict[key].to(dtype) 24 | 25 | if os.path.splitext(file_name)[1] == ".safetensors": 26 | save_file(model, file_name) 27 | else: 28 | torch.save(model, file_name) 29 | 30 | 31 | def svd(args): 32 | def str_to_dtype(p): 33 | if p == "float": 34 | return torch.float 35 | if p == "fp16": 36 | return torch.float16 37 | if p == "bf16": 38 | return torch.bfloat16 39 | return None 40 | 41 | assert args.v2 != args.sdxl or ( 42 | not args.v2 and not args.sdxl 43 | ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" 44 | if args.v_parameterization is None: 45 | args.v_parameterization = args.v2 46 | 47 | save_dtype = str_to_dtype(args.save_precision) 48 | 49 | # load models 50 | if not args.sdxl: 51 | print(f"loading original SD model : {args.model_org}") 52 | text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org, del_ids=True) 53 | text_encoders_o = [text_encoder_o] 54 | print(f"loading tuned SD model : {args.model_tuned}") 55 | text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned, del_ids=True) 56 | text_encoders_t = [text_encoder_t] 57 | model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) 58 | else: 59 | print(f"loading original SDXL model : {args.model_org}") 60 | text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( 61 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" 62 | ) 63 | text_encoders_o = [text_encoder_o1, text_encoder_o2] 64 | print(f"loading original SDXL model : {args.model_tuned}") 65 | text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( 66 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" 67 | ) 68 | text_encoders_t = [text_encoder_t1, text_encoder_t2] 69 | model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 70 | 71 | # create LoRA network to extract weights: Use dim (rank) as alpha 72 | if args.conv_dim is None: 73 | kwargs = {} 74 | else: 75 | kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} 76 | 77 | lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) 78 | lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) 79 | assert 
len(lora_network_o.text_encoder_loras) == len( 80 | lora_network_t.text_encoder_loras 81 | ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " 82 | 83 | # get diffs 84 | diffs = {} 85 | text_encoder_different = False 86 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): 87 | lora_name = lora_o.lora_name 88 | module_o = lora_o.org_module 89 | module_t = lora_t.org_module 90 | diff = args.alpha * module_t.weight - args.beta * module_o.weight 91 | 92 | # Text Encoder might be same 93 | if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: 94 | text_encoder_different = True 95 | print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") 96 | 97 | diff = diff.float() 98 | diffs[lora_name] = diff 99 | 100 | if not text_encoder_different: 101 | print("Text encoder is same. Extract U-Net only.") 102 | lora_network_o.text_encoder_loras = [] 103 | diffs = {} 104 | 105 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): 106 | lora_name = lora_o.lora_name 107 | module_o = lora_o.org_module 108 | module_t = lora_t.org_module 109 | diff = args.alpha * module_t.weight - args.beta * module_o.weight 110 | diff = diff.float() 111 | 112 | if args.device: 113 | diff = diff.to(args.device) 114 | 115 | diffs[lora_name] = diff 116 | 117 | # make LoRA with svd 118 | print("calculating by svd") 119 | lora_weights = {} 120 | with torch.no_grad(): 121 | for lora_name, mat in tqdm(list(diffs.items())): 122 | # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 123 | conv2d = len(mat.size()) == 4 124 | kernel_size = None if not conv2d else mat.size()[2:4] 125 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 126 | 127 | rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim 128 | out_dim, in_dim = mat.size()[0:2] 129 | 130 | if args.device: 131 | mat = mat.to(args.device) 132 | 133 | # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) 134 | rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 135 | 136 | if conv2d: 137 | if conv2d_3x3: 138 | mat = mat.flatten(start_dim=1) 139 | else: 140 | mat = mat.squeeze() 141 | 142 | U, S, Vh = torch.linalg.svd(mat) 143 | 144 | U = U[:, :rank] 145 | S = S[:rank] 146 | U = U @ torch.diag(S) 147 | 148 | Vh = Vh[:rank, :] 149 | 150 | dist = torch.cat([U.flatten(), Vh.flatten()]) 151 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 152 | low_val = -hi_val 153 | 154 | U = U.clamp(low_val, hi_val) 155 | Vh = Vh.clamp(low_val, hi_val) 156 | 157 | if conv2d: 158 | U = U.reshape(out_dim, rank, 1, 1) 159 | Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) 160 | 161 | U = U.to("cpu").contiguous() 162 | Vh = Vh.to("cpu").contiguous() 163 | 164 | lora_weights[lora_name] = (U, Vh) 165 | 166 | # make state dict for LoRA 167 | lora_sd = {} 168 | for lora_name, (up_weight, down_weight) in lora_weights.items(): 169 | lora_sd[lora_name + ".lora_up.weight"] = up_weight 170 | lora_sd[lora_name + ".lora_down.weight"] = down_weight 171 | lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) 172 | 173 | # load state dict to LoRA and save it 174 | lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd) 175 | lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict 176 | 177 | info = 
lora_network_save.load_state_dict(lora_sd) 178 | print(f"Loading extracted LoRA weights: {info}") 179 | 180 | dir_name = os.path.dirname(args.save_to) 181 | if dir_name and not os.path.exists(dir_name): 182 | os.makedirs(dir_name, exist_ok=True) 183 | 184 | # minimum metadata 185 | net_kwargs = {} 186 | if args.conv_dim is not None: 187 | net_kwargs["conv_dim"] = args.conv_dim 188 | net_kwargs["conv_alpha"] = args.conv_dim 189 | 190 | metadata = { 191 | "ss_v2": str(args.v2), 192 | "ss_base_model_version": model_version, 193 | "ss_network_module": "networks.lora", 194 | "ss_network_dim": str(args.dim), 195 | "ss_network_alpha": str(args.dim), 196 | "ss_network_args": json.dumps(net_kwargs), 197 | } 198 | 199 | if not args.no_metadata: 200 | title = os.path.splitext(os.path.basename(args.save_to))[0] 201 | sai_metadata = sai_model_spec.build_metadata( 202 | None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title 203 | ) 204 | metadata.update(sai_metadata) 205 | 206 | lora_network_save.save_weights(args.save_to, save_dtype, metadata) 207 | return f"LoRA weights are saved to: {args.save_to}" 208 | 209 | 210 | -------------------------------------------------------------------------------- /scripts/kohyas/ipex/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import contextlib 4 | import torch 5 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 6 | from .hijacks import ipex_hijacks 7 | from .attention import attention_init 8 | 9 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 10 | 11 | def ipex_init(): # pylint: disable=too-many-statements 12 | try: 13 | #Replace cuda with xpu: 14 | torch.cuda.current_device = torch.xpu.current_device 15 | torch.cuda.current_stream = torch.xpu.current_stream 16 | torch.cuda.device = torch.xpu.device 17 | torch.cuda.device_count = torch.xpu.device_count 18 | torch.cuda.device_of = torch.xpu.device_of 19 | torch.cuda.get_device_name = torch.xpu.get_device_name 20 | torch.cuda.get_device_properties = torch.xpu.get_device_properties 21 | torch.cuda.init = torch.xpu.init 22 | torch.cuda.is_available = torch.xpu.is_available 23 | torch.cuda.is_initialized = torch.xpu.is_initialized 24 | torch.cuda.is_current_stream_capturing = lambda: False 25 | torch.cuda.set_device = torch.xpu.set_device 26 | torch.cuda.stream = torch.xpu.stream 27 | torch.cuda.synchronize = torch.xpu.synchronize 28 | torch.cuda.Event = torch.xpu.Event 29 | torch.cuda.Stream = torch.xpu.Stream 30 | torch.cuda.FloatTensor = torch.xpu.FloatTensor 31 | torch.Tensor.cuda = torch.Tensor.xpu 32 | torch.Tensor.is_cuda = torch.Tensor.is_xpu 33 | torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock 34 | torch.cuda._initialized = torch.xpu.lazy_init._initialized 35 | torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker 36 | torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls 37 | torch.cuda._tls = torch.xpu.lazy_init._tls 38 | torch.cuda.threading = torch.xpu.lazy_init.threading 39 | torch.cuda.traceback = torch.xpu.lazy_init.traceback 40 | torch.cuda.Optional = torch.xpu.Optional 41 | torch.cuda.__cached__ = torch.xpu.__cached__ 42 | torch.cuda.__loader__ = torch.xpu.__loader__ 43 | torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage 44 | torch.cuda.Tuple = torch.xpu.Tuple 45 | torch.cuda.streams = torch.xpu.streams 46 | torch.cuda._lazy_new = torch.xpu._lazy_new 47 | 
torch.cuda.FloatStorage = torch.xpu.FloatStorage 48 | torch.cuda.Any = torch.xpu.Any 49 | torch.cuda.__doc__ = torch.xpu.__doc__ 50 | torch.cuda.default_generators = torch.xpu.default_generators 51 | torch.cuda.HalfTensor = torch.xpu.HalfTensor 52 | torch.cuda._get_device_index = torch.xpu._get_device_index 53 | torch.cuda.__path__ = torch.xpu.__path__ 54 | torch.cuda.Device = torch.xpu.Device 55 | torch.cuda.IntTensor = torch.xpu.IntTensor 56 | torch.cuda.ByteStorage = torch.xpu.ByteStorage 57 | torch.cuda.set_stream = torch.xpu.set_stream 58 | torch.cuda.BoolStorage = torch.xpu.BoolStorage 59 | torch.cuda.os = torch.xpu.os 60 | torch.cuda.torch = torch.xpu.torch 61 | torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage 62 | torch.cuda.Union = torch.xpu.Union 63 | torch.cuda.DoubleTensor = torch.xpu.DoubleTensor 64 | torch.cuda.ShortTensor = torch.xpu.ShortTensor 65 | torch.cuda.LongTensor = torch.xpu.LongTensor 66 | torch.cuda.IntStorage = torch.xpu.IntStorage 67 | torch.cuda.LongStorage = torch.xpu.LongStorage 68 | torch.cuda.__annotations__ = torch.xpu.__annotations__ 69 | torch.cuda.__package__ = torch.xpu.__package__ 70 | torch.cuda.__builtins__ = torch.xpu.__builtins__ 71 | torch.cuda.CharTensor = torch.xpu.CharTensor 72 | torch.cuda.List = torch.xpu.List 73 | torch.cuda._lazy_init = torch.xpu._lazy_init 74 | torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor 75 | torch.cuda.DoubleStorage = torch.xpu.DoubleStorage 76 | torch.cuda.ByteTensor = torch.xpu.ByteTensor 77 | torch.cuda.StreamContext = torch.xpu.StreamContext 78 | torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage 79 | torch.cuda.ShortStorage = torch.xpu.ShortStorage 80 | torch.cuda._lazy_call = torch.xpu._lazy_call 81 | torch.cuda.HalfStorage = torch.xpu.HalfStorage 82 | torch.cuda.random = torch.xpu.random 83 | torch.cuda._device = torch.xpu._device 84 | torch.cuda.classproperty = torch.xpu.classproperty 85 | torch.cuda.__name__ = torch.xpu.__name__ 86 | torch.cuda._device_t = torch.xpu._device_t 87 | torch.cuda.warnings = torch.xpu.warnings 88 | torch.cuda.__spec__ = torch.xpu.__spec__ 89 | torch.cuda.BoolTensor = torch.xpu.BoolTensor 90 | torch.cuda.CharStorage = torch.xpu.CharStorage 91 | torch.cuda.__file__ = torch.xpu.__file__ 92 | torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork 93 | #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing 94 | 95 | #Memory: 96 | torch.cuda.memory = torch.xpu.memory 97 | if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): 98 | torch.xpu.empty_cache = lambda: None 99 | torch.cuda.empty_cache = torch.xpu.empty_cache 100 | torch.cuda.memory_stats = torch.xpu.memory_stats 101 | torch.cuda.memory_summary = torch.xpu.memory_summary 102 | torch.cuda.memory_snapshot = torch.xpu.memory_snapshot 103 | torch.cuda.memory_allocated = torch.xpu.memory_allocated 104 | torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated 105 | torch.cuda.memory_reserved = torch.xpu.memory_reserved 106 | torch.cuda.memory_cached = torch.xpu.memory_reserved 107 | torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved 108 | torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved 109 | torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats 110 | torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats 111 | torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats 112 | torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict 113 | 
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats 114 | 115 | #RNG: 116 | torch.cuda.get_rng_state = torch.xpu.get_rng_state 117 | torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all 118 | torch.cuda.set_rng_state = torch.xpu.set_rng_state 119 | torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all 120 | torch.cuda.manual_seed = torch.xpu.manual_seed 121 | torch.cuda.manual_seed_all = torch.xpu.manual_seed_all 122 | torch.cuda.seed = torch.xpu.seed 123 | torch.cuda.seed_all = torch.xpu.seed_all 124 | torch.cuda.initial_seed = torch.xpu.initial_seed 125 | 126 | #AMP: 127 | torch.cuda.amp = torch.xpu.amp 128 | if not hasattr(torch.cuda.amp, "common"): 129 | torch.cuda.amp.common = contextlib.nullcontext() 130 | torch.cuda.amp.common.amp_definitely_not_available = lambda: False 131 | try: 132 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler 133 | except Exception: # pylint: disable=broad-exception-caught 134 | try: 135 | from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error 136 | gradscaler_init() 137 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler 138 | except Exception: # pylint: disable=broad-exception-caught 139 | torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler 140 | 141 | #C 142 | torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream 143 | ipex._C._DeviceProperties.major = 2023 144 | ipex._C._DeviceProperties.minor = 2 145 | 146 | #Fix functions with ipex: 147 | torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] 148 | torch._utils._get_available_device_type = lambda: "xpu" 149 | torch.has_cuda = True 150 | torch.cuda.has_half = True 151 | torch.cuda.is_bf16_supported = lambda *args, **kwargs: True 152 | torch.cuda.is_fp16_supported = lambda *args, **kwargs: True 153 | torch.version.cuda = "11.7" 154 | torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] 155 | torch.cuda.get_device_properties.major = 11 156 | torch.cuda.get_device_properties.minor = 7 157 | torch.cuda.ipc_collect = lambda *args, **kwargs: None 158 | torch.cuda.utilization = lambda *args, **kwargs: 0 159 | if hasattr(torch.xpu, 'getDeviceIdListForCard'): 160 | torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard 161 | torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard 162 | else: 163 | torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card 164 | torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card 165 | 166 | ipex_hijacks() 167 | attention_init() 168 | try: 169 | from .diffusers import ipex_diffusers 170 | ipex_diffusers() 171 | except Exception: # pylint: disable=broad-exception-caught 172 | pass 173 | except Exception as e: 174 | return False, e 175 | return True, None 176 | -------------------------------------------------------------------------------- /scripts/kohyas/ipex/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 3 | 4 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 5 | 6 | original_torch_bmm = torch.bmm 7 | def torch_bmm(input, mat2, *, out=None): 8 | if input.dtype != mat2.dtype: 9 | mat2 = mat2.to(input.dtype) 10 | 11 | #ARC GPUs can't allocate more than 4GB 
to a single block, Slice it: 12 | batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] 13 | block_multiply = input.element_size() 14 | slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply 15 | block_size = batch_size_attention * slice_block_size 16 | 17 | split_slice_size = batch_size_attention 18 | if block_size > 4: 19 | do_split = True 20 | #Find something divisible with the input_tokens 21 | while (split_slice_size * slice_block_size) > 4: 22 | split_slice_size = split_slice_size // 2 23 | if split_slice_size <= 1: 24 | split_slice_size = 1 25 | break 26 | else: 27 | do_split = False 28 | 29 | split_2_slice_size = input_tokens 30 | if split_slice_size * slice_block_size > 4: 31 | slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply 32 | do_split_2 = True 33 | #Find something divisible with the input_tokens 34 | while (split_2_slice_size * slice_block_size2) > 4: 35 | split_2_slice_size = split_2_slice_size // 2 36 | if split_2_slice_size <= 1: 37 | split_2_slice_size = 1 38 | break 39 | else: 40 | do_split_2 = False 41 | 42 | if do_split: 43 | hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) 44 | for i in range(batch_size_attention // split_slice_size): 45 | start_idx = i * split_slice_size 46 | end_idx = (i + 1) * split_slice_size 47 | if do_split_2: 48 | for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name 49 | start_idx_2 = i2 * split_2_slice_size 50 | end_idx_2 = (i2 + 1) * split_2_slice_size 51 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( 52 | input[start_idx:end_idx, start_idx_2:end_idx_2], 53 | mat2[start_idx:end_idx, start_idx_2:end_idx_2], 54 | out=out 55 | ) 56 | else: 57 | hidden_states[start_idx:end_idx] = original_torch_bmm( 58 | input[start_idx:end_idx], 59 | mat2[start_idx:end_idx], 60 | out=out 61 | ) 62 | else: 63 | return original_torch_bmm(input, mat2, out=out) 64 | return hidden_states 65 | 66 | original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention 67 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): 68 | #ARC GPUs can't allocate more than 4GB to a single block, Slice it: 69 | if len(query.shape) == 3: 70 | batch_size_attention, query_tokens, shape_four = query.shape 71 | shape_one = 1 72 | no_shape_one = True 73 | else: 74 | shape_one, batch_size_attention, query_tokens, shape_four = query.shape 75 | no_shape_one = False 76 | 77 | block_multiply = query.element_size() 78 | slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply 79 | block_size = batch_size_attention * slice_block_size 80 | 81 | split_slice_size = batch_size_attention 82 | if block_size > 4: 83 | do_split = True 84 | #Find something divisible with the shape_one 85 | while (split_slice_size * slice_block_size) > 4: 86 | split_slice_size = split_slice_size // 2 87 | if split_slice_size <= 1: 88 | split_slice_size = 1 89 | break 90 | else: 91 | do_split = False 92 | 93 | split_2_slice_size = query_tokens 94 | if split_slice_size * slice_block_size > 4: 95 | slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply 96 | do_split_2 = True 97 | #Find something divisible with the batch_size_attention 98 | while (split_2_slice_size * slice_block_size2) > 4: 99 | split_2_slice_size = split_2_slice_size // 2 100 | if split_2_slice_size 
<= 1: 101 | split_2_slice_size = 1 102 | break 103 | else: 104 | do_split_2 = False 105 | 106 | if do_split: 107 | hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) 108 | for i in range(batch_size_attention // split_slice_size): 109 | start_idx = i * split_slice_size 110 | end_idx = (i + 1) * split_slice_size 111 | if do_split_2: 112 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name 113 | start_idx_2 = i2 * split_2_slice_size 114 | end_idx_2 = (i2 + 1) * split_2_slice_size 115 | if no_shape_one: 116 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( 117 | query[start_idx:end_idx, start_idx_2:end_idx_2], 118 | key[start_idx:end_idx, start_idx_2:end_idx_2], 119 | value[start_idx:end_idx, start_idx_2:end_idx_2], 120 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, 121 | dropout_p=dropout_p, is_causal=is_causal 122 | ) 123 | else: 124 | hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( 125 | query[:, start_idx:end_idx, start_idx_2:end_idx_2], 126 | key[:, start_idx:end_idx, start_idx_2:end_idx_2], 127 | value[:, start_idx:end_idx, start_idx_2:end_idx_2], 128 | attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, 129 | dropout_p=dropout_p, is_causal=is_causal 130 | ) 131 | else: 132 | if no_shape_one: 133 | hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( 134 | query[start_idx:end_idx], 135 | key[start_idx:end_idx], 136 | value[start_idx:end_idx], 137 | attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, 138 | dropout_p=dropout_p, is_causal=is_causal 139 | ) 140 | else: 141 | hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( 142 | query[:, start_idx:end_idx], 143 | key[:, start_idx:end_idx], 144 | value[:, start_idx:end_idx], 145 | attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, 146 | dropout_p=dropout_p, is_causal=is_causal 147 | ) 148 | else: 149 | return original_scaled_dot_product_attention( 150 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal 151 | ) 152 | return hidden_states 153 | 154 | def attention_init(): 155 | #ARC GPUs can't allocate more than 4GB to a single block: 156 | torch.bmm = torch_bmm 157 | torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention 158 | -------------------------------------------------------------------------------- /scripts/kohyas/ipex/diffusers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 3 | import diffusers #0.21.1 # pylint: disable=import-error 4 | from diffusers.models.attention_processor import Attention 5 | 6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 7 | 8 | class SlicedAttnProcessor: # pylint: disable=too-few-public-methods 9 | r""" 10 | Processor for implementing sliced attention. 11 | 12 | Args: 13 | slice_size (`int`, *optional*): 14 | The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and 15 | `attention_head_dim` must be a multiple of the `slice_size`. 
16 | """ 17 | 18 | def __init__(self, slice_size): 19 | self.slice_size = slice_size 20 | 21 | def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches 22 | residual = hidden_states 23 | 24 | input_ndim = hidden_states.ndim 25 | 26 | if input_ndim == 4: 27 | batch_size, channel, height, width = hidden_states.shape 28 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 29 | 30 | batch_size, sequence_length, _ = ( 31 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 32 | ) 33 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 34 | 35 | if attn.group_norm is not None: 36 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 37 | 38 | query = attn.to_q(hidden_states) 39 | dim = query.shape[-1] 40 | query = attn.head_to_batch_dim(query) 41 | 42 | if encoder_hidden_states is None: 43 | encoder_hidden_states = hidden_states 44 | elif attn.norm_cross: 45 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 46 | 47 | key = attn.to_k(encoder_hidden_states) 48 | value = attn.to_v(encoder_hidden_states) 49 | key = attn.head_to_batch_dim(key) 50 | value = attn.head_to_batch_dim(value) 51 | 52 | batch_size_attention, query_tokens, shape_three = query.shape 53 | hidden_states = torch.zeros( 54 | (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype 55 | ) 56 | 57 | #ARC GPUs can't allocate more than 4GB to a single block, Slice it: 58 | block_multiply = query.element_size() 59 | slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply 60 | block_size = query_tokens * slice_block_size 61 | split_2_slice_size = query_tokens 62 | if block_size > 4: 63 | do_split_2 = True 64 | #Find something divisible with the query_tokens 65 | while (split_2_slice_size * slice_block_size) > 4: 66 | split_2_slice_size = split_2_slice_size // 2 67 | if split_2_slice_size <= 1: 68 | split_2_slice_size = 1 69 | break 70 | else: 71 | do_split_2 = False 72 | 73 | for i in range(batch_size_attention // self.slice_size): 74 | start_idx = i * self.slice_size 75 | end_idx = (i + 1) * self.slice_size 76 | 77 | if do_split_2: 78 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name 79 | start_idx_2 = i2 * split_2_slice_size 80 | end_idx_2 = (i2 + 1) * split_2_slice_size 81 | 82 | query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] 83 | key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] 84 | attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None 85 | 86 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 87 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) 88 | 89 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice 90 | else: 91 | query_slice = query[start_idx:end_idx] 92 | key_slice = key[start_idx:end_idx] 93 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None 94 | 95 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 96 | 97 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 98 | 99 | hidden_states[start_idx:end_idx] = attn_slice 100 | 101 | hidden_states = attn.batch_to_head_dim(hidden_states) 102 | 103 | # 
linear proj 104 | hidden_states = attn.to_out[0](hidden_states) 105 | # dropout 106 | hidden_states = attn.to_out[1](hidden_states) 107 | 108 | if input_ndim == 4: 109 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 110 | 111 | if attn.residual_connection: 112 | hidden_states = hidden_states + residual 113 | 114 | hidden_states = hidden_states / attn.rescale_output_factor 115 | 116 | return hidden_states 117 | 118 | def ipex_diffusers(): 119 | #ARC GPUs can't allocate more than 4GB to a single block: 120 | diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor 121 | -------------------------------------------------------------------------------- /scripts/kohyas/ipex/gradscaler.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 4 | import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import 5 | 6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long 7 | 8 | OptState = ipex.cpu.autocast._grad_scaler.OptState 9 | _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator 10 | _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state 11 | 12 | def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument 13 | per_device_inv_scale = _MultiDeviceReplicator(inv_scale) 14 | per_device_found_inf = _MultiDeviceReplicator(found_inf) 15 | 16 | # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. 17 | # There could be hundreds of grads, so we'd like to iterate through them just once. 18 | # However, we don't know their devices or dtypes in advance. 19 | 20 | # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict 21 | # Google says mypy struggles with defaultdicts type annotations. 22 | per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] 23 | # sync grad to master weight 24 | if hasattr(optimizer, "sync_grad"): 25 | optimizer.sync_grad() 26 | with torch.no_grad(): 27 | for group in optimizer.param_groups: 28 | for param in group["params"]: 29 | if param.grad is None: 30 | continue 31 | if (not allow_fp16) and param.grad.dtype == torch.float16: 32 | raise ValueError("Attempting to unscale FP16 gradients.") 33 | if param.grad.is_sparse: 34 | # is_coalesced() == False means the sparse grad has values with duplicate indices. 35 | # coalesce() deduplicates indices and adds all values that have the same index. 36 | # For scaled fp16 values, there's a good chance coalescing will cause overflow, 37 | # so we should check the coalesced _values(). 38 | if param.grad.dtype is torch.float16: 39 | param.grad = param.grad.coalesce() 40 | to_unscale = param.grad._values() 41 | else: 42 | to_unscale = param.grad 43 | 44 | # -: is there a way to split by device and dtype without appending in the inner loop? 
45 | to_unscale = to_unscale.to("cpu") 46 | per_device_and_dtype_grads[to_unscale.device][ 47 | to_unscale.dtype 48 | ].append(to_unscale) 49 | 50 | for _, per_dtype_grads in per_device_and_dtype_grads.items(): 51 | for grads in per_dtype_grads.values(): 52 | core._amp_foreach_non_finite_check_and_unscale_( 53 | grads, 54 | per_device_found_inf.get("cpu"), 55 | per_device_inv_scale.get("cpu"), 56 | ) 57 | 58 | return per_device_found_inf._per_device_tensors 59 | 60 | def unscale_(self, optimizer): 61 | """ 62 | Divides ("unscales") the optimizer's gradient tensors by the scale factor. 63 | :meth:`unscale_` is optional, serving cases where you need to 64 | :ref:`modify or inspect gradients` 65 | between the backward pass(es) and :meth:`step`. 66 | If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. 67 | Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: 68 | ... 69 | scaler.scale(loss).backward() 70 | scaler.unscale_(optimizer) 71 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 72 | scaler.step(optimizer) 73 | scaler.update() 74 | Args: 75 | optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. 76 | .. warning:: 77 | :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, 78 | and only after all gradients for that optimizer's assigned parameters have been accumulated. 79 | Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. 80 | .. warning:: 81 | :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. 82 | """ 83 | if not self._enabled: 84 | return 85 | 86 | self._check_scale_growth_tracker("unscale_") 87 | 88 | optimizer_state = self._per_optimizer_states[id(optimizer)] 89 | 90 | if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise 91 | raise RuntimeError( 92 | "unscale_() has already been called on this optimizer since the last update()." 93 | ) 94 | elif optimizer_state["stage"] is OptState.STEPPED: 95 | raise RuntimeError("unscale_() is being called after step().") 96 | 97 | # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. 98 | assert self._scale is not None 99 | inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) 100 | found_inf = torch.full( 101 | (1,), 0.0, dtype=torch.float32, device=self._scale.device 102 | ) 103 | 104 | optimizer_state["found_inf_per_device"] = self._unscale_grads_( 105 | optimizer, inv_scale, found_inf, False 106 | ) 107 | optimizer_state["stage"] = OptState.UNSCALED 108 | 109 | def update(self, new_scale=None): 110 | """ 111 | Updates the scale factor. 112 | If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` 113 | to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, 114 | the scale is multiplied by ``growth_factor`` to increase it. 115 | Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not 116 | used directly, it's used to fill GradScaler's internal scale tensor. So if 117 | ``new_scale`` was a tensor, later in-place changes to that tensor will not further 118 | affect the scale GradScaler uses internally.) 119 | Args: 120 | new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. 121 | .. 
warning:: 122 | :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has 123 | been invoked for all optimizers used this iteration. 124 | """ 125 | if not self._enabled: 126 | return 127 | 128 | _scale, _growth_tracker = self._check_scale_growth_tracker("update") 129 | 130 | if new_scale is not None: 131 | # Accept a new user-defined scale. 132 | if isinstance(new_scale, float): 133 | self._scale.fill_(new_scale) # type: ignore[union-attr] 134 | else: 135 | reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." 136 | assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] 137 | assert new_scale.numel() == 1, reason 138 | assert new_scale.requires_grad is False, reason 139 | self._scale.copy_(new_scale) # type: ignore[union-attr] 140 | else: 141 | # Consume shared inf/nan data collected from optimizers to update the scale. 142 | # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. 143 | found_infs = [ 144 | found_inf.to(device="cpu", non_blocking=True) 145 | for state in self._per_optimizer_states.values() 146 | for found_inf in state["found_inf_per_device"].values() 147 | ] 148 | 149 | assert len(found_infs) > 0, "No inf checks were recorded prior to update." 150 | 151 | found_inf_combined = found_infs[0] 152 | if len(found_infs) > 1: 153 | for i in range(1, len(found_infs)): 154 | found_inf_combined += found_infs[i] 155 | 156 | to_device = _scale.device 157 | _scale = _scale.to("cpu") 158 | _growth_tracker = _growth_tracker.to("cpu") 159 | 160 | core._amp_update_scale_( 161 | _scale, 162 | _growth_tracker, 163 | found_inf_combined, 164 | self._growth_factor, 165 | self._backoff_factor, 166 | self._growth_interval, 167 | ) 168 | 169 | _scale = _scale.to(to_device) 170 | _growth_tracker = _growth_tracker.to(to_device) 171 | # To prepare for next iteration, clear the data collected from optimizers this iteration. 
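# Presumably, as in PyTorch's reference GradScaler, _refresh_per_optimizer_state
# returns a fresh {"stage": OptState.READY, "found_inf_per_device": {}} entry,
# so every optimizer id seen next iteration starts with clean bookkeeping.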
172 | self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) 173 | 174 | def gradscaler_init(): 175 | torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler 176 | torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ 177 | torch.xpu.amp.GradScaler.unscale_ = unscale_ 178 | torch.xpu.amp.GradScaler.update = update 179 | return torch.xpu.amp.GradScaler 180 | -------------------------------------------------------------------------------- /scripts/kohyas/ipex/hijacks.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import importlib 3 | import torch 4 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import 5 | 6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return 7 | 8 | class CondFunc: # pylint: disable=missing-class-docstring 9 | def __new__(cls, orig_func, sub_func, cond_func): 10 | self = super(CondFunc, cls).__new__(cls) 11 | if isinstance(orig_func, str): 12 | func_path = orig_func.split('.') 13 | for i in range(len(func_path)-1, -1, -1): 14 | try: 15 | resolved_obj = importlib.import_module('.'.join(func_path[:i])) 16 | break 17 | except ImportError: 18 | pass 19 | for attr_name in func_path[i:-1]: 20 | resolved_obj = getattr(resolved_obj, attr_name) 21 | orig_func = getattr(resolved_obj, func_path[-1]) 22 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) 23 | self.__init__(orig_func, sub_func, cond_func) 24 | return lambda *args, **kwargs: self(*args, **kwargs) 25 | def __init__(self, orig_func, sub_func, cond_func): 26 | self.__orig_func = orig_func 27 | self.__sub_func = sub_func 28 | self.__cond_func = cond_func 29 | def __call__(self, *args, **kwargs): 30 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): 31 | return self.__sub_func(self.__orig_func, *args, **kwargs) 32 | else: 33 | return self.__orig_func(*args, **kwargs) 34 | 35 | _utils = torch.utils.data._utils 36 | def _shutdown_workers(self): 37 | if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: 38 | return 39 | if hasattr(self, "_shutdown") and not self._shutdown: 40 | self._shutdown = True 41 | try: 42 | if hasattr(self, '_pin_memory_thread'): 43 | self._pin_memory_thread_done_event.set() 44 | self._worker_result_queue.put((None, None)) 45 | self._pin_memory_thread.join() 46 | self._worker_result_queue.cancel_join_thread() 47 | self._worker_result_queue.close() 48 | self._workers_done_event.set() 49 | for worker_id in range(len(self._workers)): 50 | if self._persistent_workers or self._workers_status[worker_id]: 51 | self._mark_worker_as_unavailable(worker_id, shutdown=True) 52 | for w in self._workers: # pylint: disable=invalid-name 53 | w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) 54 | for q in self._index_queues: # pylint: disable=invalid-name 55 | q.cancel_join_thread() 56 | q.close() 57 | finally: 58 | if self._worker_pids_set: 59 | torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) 60 | self._worker_pids_set = False 61 | for w in self._workers: # pylint: disable=invalid-name 62 | if w.is_alive(): 63 | w.terminate() 64 | 65 | class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods 66 | def __new__(cls, module, device_ids=None, output_device=None, dim=0): # 
pylint: disable=unused-argument 67 | if isinstance(device_ids, list) and len(device_ids) > 1: 68 | print("IPEX backend doesn't support DataParallel on multiple XPU devices") 69 | return module.to("xpu") 70 | 71 | def return_null_context(*args, **kwargs): # pylint: disable=unused-argument 72 | return contextlib.nullcontext() 73 | 74 | def check_device(device): 75 | return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) 76 | 77 | def return_xpu(device): 78 | return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" 79 | 80 | def ipex_no_cuda(orig_func, *args, **kwargs): 81 | torch.cuda.is_available = lambda: False 82 | orig_func(*args, **kwargs) 83 | torch.cuda.is_available = torch.xpu.is_available 84 | 85 | original_autocast = torch.autocast 86 | def ipex_autocast(*args, **kwargs): 87 | if len(args) > 0 and args[0] == "cuda": 88 | return original_autocast("xpu", *args[1:], **kwargs) 89 | else: 90 | return original_autocast(*args, **kwargs) 91 | 92 | original_torch_cat = torch.cat 93 | def torch_cat(tensor, *args, **kwargs): 94 | if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): 95 | return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) 96 | else: 97 | return original_torch_cat(tensor, *args, **kwargs) 98 | 99 | original_interpolate = torch.nn.functional.interpolate 100 | def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments 101 | if antialias or align_corners is not None: 102 | return_device = tensor.device 103 | return_dtype = tensor.dtype 104 | return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, 105 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) 106 | else: 107 | return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, 108 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) 109 | 110 | original_linalg_solve = torch.linalg.solve 111 | def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name 112 | if A.device != torch.device("cpu") or B.device != torch.device("cpu"): 113 | return_device = A.device 114 | return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) 115 | else: 116 | return original_linalg_solve(A, B, *args, **kwargs) 117 | 118 | def ipex_hijacks(): 119 | CondFunc('torch.Tensor.to', 120 | lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), 121 | lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) 122 | CondFunc('torch.Tensor.cuda', 123 | lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), 124 | lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) 125 | CondFunc('torch.empty', 126 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 127 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 128 | CondFunc('torch.load', 
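# CondFunc (defined above) routes a call to sub_func(orig_func, *args, **kwargs)
# whenever cond_func(...) is truthy, and to the untouched original otherwise.
# As a sketch of the hijack below: torch.load("model.pt", map_location="cuda:0")
# is rewritten to load onto "xpu:0" via return_xpu.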
129 | lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), 130 | lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) 131 | CondFunc('torch.randn', 132 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 133 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 134 | CondFunc('torch.ones', 135 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 136 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 137 | CondFunc('torch.zeros', 138 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 139 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 140 | CondFunc('torch.tensor', 141 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 142 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 143 | CondFunc('torch.linspace', 144 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), 145 | lambda orig_func, *args, device=None, **kwargs: check_device(device)) 146 | 147 | CondFunc('torch.Generator', 148 | lambda orig_func, device=None: torch.xpu.Generator(device), 149 | lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") 150 | 151 | CondFunc('torch.batch_norm', 152 | lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, 153 | weight if weight is not None else torch.ones(input.size()[1], device=input.device), 154 | bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), 155 | lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) 156 | CondFunc('torch.instance_norm', 157 | lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, 158 | weight if weight is not None else torch.ones(input.size()[1], device=input.device), 159 | bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), 160 | lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) 161 | 162 | #Functions with dtype errors: 163 | CondFunc('torch.nn.modules.GroupNorm.forward', 164 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), 165 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype) 166 | CondFunc('torch.nn.modules.linear.Linear.forward', 167 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), 168 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype) 169 | CondFunc('torch.nn.modules.conv.Conv2d.forward', 170 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), 171 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype) 172 | CondFunc('torch.nn.functional.layer_norm', 173 | lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: 174 | orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), 175 | lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: 176 | weight is not None and input.dtype != weight.data.dtype) 177 | 178 | #Diffusers Float64 (ARC GPUs doesn't support double or Float64): 179 | if not torch.xpu.has_fp64_dtype(): 180 | CondFunc('torch.from_numpy', 181 | lambda 
orig_func, ndarray: orig_func(ndarray.astype('float32')), 182 | lambda orig_func, ndarray: ndarray.dtype == float) 183 | 184 | #Broken functions when torch.cuda.is_available is True: 185 | CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', 186 | lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), 187 | lambda orig_func, *args, **kwargs: True) 188 | 189 | #Functions that make compile mad with CondFunc: 190 | torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers 191 | torch.nn.DataParallel = DummyDataParallel 192 | torch.autocast = ipex_autocast 193 | torch.cat = torch_cat 194 | torch.linalg.solve = linalg_solve 195 | torch.nn.functional.interpolate = interpolate 196 | torch.backends.cuda.sdp_kernel = return_null_context 197 | -------------------------------------------------------------------------------- /scripts/kohyas/merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import time 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | from scripts.kohyas import sai_model_spec, train_util 8 | from scripts.kohyas import model_util, lora 9 | 10 | def load_state_dict(file_name, dtype): 11 | if os.path.splitext(file_name)[1] == ".safetensors": 12 | sd = load_file(file_name) 13 | metadata = train_util.load_metadata_from_safetensors(file_name) 14 | else: 15 | sd = torch.load(file_name, map_location="cpu") 16 | metadata = {} 17 | 18 | for key in list(sd.keys()): 19 | if type(sd[key]) == torch.Tensor: 20 | sd[key] = sd[key].to(dtype) 21 | 22 | return sd, metadata 23 | 24 | 25 | def save_to_file(file_name, model, state_dict, dtype, metadata): 26 | if dtype is not None: 27 | for key in list(state_dict.keys()): 28 | if type(state_dict[key]) == torch.Tensor: 29 | state_dict[key] = state_dict[key].to(dtype) 30 | 31 | if os.path.splitext(file_name)[1] == ".safetensors": 32 | save_file(model, file_name, metadata=metadata) 33 | else: 34 | torch.save(model, file_name) 35 | 36 | 37 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): 38 | text_encoder.to(merge_dtype) 39 | unet.to(merge_dtype) 40 | 41 | # create module map 42 | name_to_module = {} 43 | for i, root_module in enumerate([text_encoder, unet]): 44 | if i == 0: 45 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER 46 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 47 | else: 48 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 49 | target_replace_modules = ( 50 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 51 | ) 52 | 53 | for name, module in root_module.named_modules(): 54 | if module.__class__.__name__ in target_replace_modules: 55 | for child_name, child_module in module.named_modules(): 56 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 57 | lora_name = prefix + "." + name + "." 
+ child_name 58 | lora_name = lora_name.replace(".", "_") 59 | name_to_module[lora_name] = child_module 60 | 61 | for model, ratio in zip(models, ratios): 62 | print(f"loading: {model}") 63 | lora_sd, _ = load_state_dict(model, merge_dtype) 64 | 65 | print(f"merging...") 66 | for key in lora_sd.keys(): 67 | if "lora_down" in key: 68 | up_key = key.replace("lora_down", "lora_up") 69 | alpha_key = key[: key.index("lora_down")] + "alpha" 70 | 71 | # find original module for this lora 72 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 73 | if module_name not in name_to_module: 74 | print(f"no module found for LoRA weight: {key}") 75 | continue 76 | module = name_to_module[module_name] 77 | # print(f"apply {key} to {module}") 78 | 79 | down_weight = lora_sd[key] 80 | up_weight = lora_sd[up_key] 81 | 82 | dim = down_weight.size()[0] 83 | alpha = lora_sd.get(alpha_key, dim) 84 | scale = alpha / dim 85 | 86 | # W <- W + U * D 87 | weight = module.weight 88 | if len(weight.size()) == 2: 89 | # linear 90 | if len(up_weight.size()) == 4: # use linear projection mismatch 91 | up_weight = up_weight.squeeze(3).squeeze(2) 92 | down_weight = down_weight.squeeze(3).squeeze(2) 93 | weight = weight + ratio * (up_weight @ down_weight) * scale 94 | elif down_weight.size()[2:4] == (1, 1): 95 | # conv2d 1x1 96 | weight = ( 97 | weight 98 | + ratio 99 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 100 | * scale 101 | ) 102 | else: 103 | # conv2d 3x3 104 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 105 | # print(conved.size(), weight.size(), module.stride, module.padding) 106 | weight = weight + ratio * conved * scale 107 | 108 | module.weight = torch.nn.Parameter(weight) 109 | 110 | 111 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): 112 | base_alphas = {} # alpha for merged model 113 | base_dims = {} 114 | 115 | merged_sd = {} 116 | v2 = None 117 | base_model = None 118 | for model, ratio in zip(models, ratios): 119 | print(f"loading: {model}") 120 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype) 121 | 122 | if lora_metadata is not None: 123 | if v2 is None: 124 | v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string 125 | if base_model is None: 126 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) 127 | 128 | # get alpha and dim 129 | alphas = {} # alpha for current model 130 | dims = {} # dims for current model 131 | for key in lora_sd.keys(): 132 | if "alpha" in key: 133 | lora_module_name = key[: key.rfind(".alpha")] 134 | alpha = float(lora_sd[key].detach().numpy()) 135 | alphas[lora_module_name] = alpha 136 | if lora_module_name not in base_alphas: 137 | base_alphas[lora_module_name] = alpha 138 | elif "lora_down" in key: 139 | lora_module_name = key[: key.rfind(".lora_down")] 140 | dim = lora_sd[key].size()[0] 141 | dims[lora_module_name] = dim 142 | if lora_module_name not in base_dims: 143 | base_dims[lora_module_name] = dim 144 | 145 | for lora_module_name in dims.keys(): 146 | if lora_module_name not in alphas: 147 | alpha = dims[lora_module_name] 148 | alphas[lora_module_name] = alpha 149 | if lora_module_name not in base_alphas: 150 | base_alphas[lora_module_name] = alpha 151 | 152 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") 153 | 154 | # merge 155 | print(f"merging...") 156 | for key in lora_sd.keys(): 157 | if 
"alpha" in key: 158 | continue 159 | if "lora_up" in key and concat: 160 | concat_dim = 1 161 | elif "lora_down" in key and concat: 162 | concat_dim = 0 163 | else: 164 | concat_dim = None 165 | 166 | lora_module_name = key[: key.rfind(".lora_")] 167 | 168 | base_alpha = base_alphas[lora_module_name] 169 | alpha = alphas[lora_module_name] 170 | 171 | scale = math.sqrt(alpha / base_alpha) * ratio 172 | scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 173 | 174 | if key in merged_sd: 175 | assert ( 176 | merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None 177 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 178 | if concat_dim is not None: 179 | merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) 180 | else: 181 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale 182 | else: 183 | merged_sd[key] = lora_sd[key] * scale 184 | 185 | # set alpha to sd 186 | for lora_module_name, alpha in base_alphas.items(): 187 | key = lora_module_name + ".alpha" 188 | merged_sd[key] = torch.tensor(alpha) 189 | if shuffle: 190 | key_down = lora_module_name + ".lora_down.weight" 191 | key_up = lora_module_name + ".lora_up.weight" 192 | dim = merged_sd[key_down].shape[0] 193 | perm = torch.randperm(dim) 194 | merged_sd[key_down] = merged_sd[key_down][perm] 195 | merged_sd[key_up] = merged_sd[key_up][:,perm] 196 | 197 | print("merged model") 198 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") 199 | 200 | # check all dims are same 201 | dims_list = list(set(base_dims.values())) 202 | alphas_list = list(set(base_alphas.values())) 203 | all_same_dims = True 204 | all_same_alphas = True 205 | for dims in dims_list: 206 | if dims != dims_list[0]: 207 | all_same_dims = False 208 | break 209 | for alphas in alphas_list: 210 | if alphas != alphas_list[0]: 211 | all_same_alphas = False 212 | break 213 | 214 | # build minimum metadata 215 | dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" 216 | alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" 217 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) 218 | 219 | return merged_sd, metadata, v2 == "True" 220 | 221 | 222 | def merge(args): 223 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 224 | 225 | def str_to_dtype(p): 226 | if p == "float": 227 | return torch.float 228 | if p == "fp16": 229 | return torch.float16 230 | if p == "bf16": 231 | return torch.bfloat16 232 | return None 233 | 234 | merge_dtype = str_to_dtype(args.precision) 235 | save_dtype = str_to_dtype(args.save_precision) 236 | if save_dtype is None: 237 | save_dtype = merge_dtype 238 | 239 | if args.sd_model is not None: 240 | print(f"loading SD model: {args.sd_model}") 241 | 242 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) 243 | 244 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) 245 | 246 | if args.no_metadata: 247 | sai_metadata = None 248 | else: 249 | merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) 250 | title = os.path.splitext(os.path.basename(args.save_to))[0] 251 | sai_metadata = sai_model_spec.build_metadata( 252 | None, 253 | args.v2, 254 | args.v2, 255 | False, 256 | False, 257 | False, 258 | time.time(), 259 | title=title, 260 | merged_from=merged_from, 
261 | is_stable_diffusion_ckpt=True, 262 | ) 263 | if args.v2: 264 | # TODO read sai modelspec 265 | print( 266 | "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" 267 | ) 268 | 269 | print(f"saving SD model to: {args.save_to}") 270 | model_util.save_stable_diffusion_checkpoint( 271 | args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae 272 | ) 273 | else: 274 | state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) 275 | 276 | print(f"calculating hashes and creating metadata...") 277 | 278 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 279 | metadata["sshs_model_hash"] = model_hash 280 | metadata["sshs_legacy_hash"] = legacy_hash 281 | 282 | if not args.no_metadata: 283 | merged_from = sai_model_spec.build_merged_from(args.models) 284 | title = os.path.splitext(os.path.basename(args.save_to))[0] 285 | sai_metadata = sai_model_spec.build_metadata( 286 | state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from 287 | ) 288 | if v2: 289 | # TODO read sai modelspec 290 | print( 291 | "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" 292 | ) 293 | metadata.update(sai_metadata) 294 | 295 | print(f"saving model to: {args.save_to}") 296 | save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) 297 | 298 | 299 | def setup_parser() -> argparse.ArgumentParser: 300 | parser = argparse.ArgumentParser() 301 | parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") 302 | parser.add_argument( 303 | "--save_precision", 304 | type=str, 305 | default=None, 306 | choices=[None, "float", "fp16", "bf16"], 307 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 308 | ) 309 | parser.add_argument( 310 | "--precision", 311 | type=str, 312 | default="float", 313 | choices=["float", "fp16", "bf16"], 314 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 315 | ) 316 | parser.add_argument( 317 | "--sd_model", 318 | type=str, 319 | default=None, 320 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", 321 | ) 322 | parser.add_argument( 323 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 324 | ) 325 | parser.add_argument( 326 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 327 | ) 328 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 329 | parser.add_argument( 330 | "--no_metadata", 331 | action="store_true", 332 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " 333 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", 334 | ) 335 | parser.add_argument( 336 | "--concat", 337 | action="store_true", 338 | help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " 339 | + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", 340 | ) 341 | parser.add_argument( 342 | "--shuffle", 343 | action="store_true", 344 | help="shuffle lora weight./ " 345 | + 
"LoRAの重みをシャッフルする", 346 | ) 347 | 348 | return parser 349 | 350 | 351 | if __name__ == "__main__": 352 | parser = setup_parser() 353 | 354 | args = parser.parse_args() 355 | merge(args) 356 | -------------------------------------------------------------------------------- /scripts/kohyas/sai_model_spec.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/Stability-AI/ModelSpec 2 | import datetime 3 | import hashlib 4 | from io import BytesIO 5 | import os 6 | from typing import List, Optional, Tuple, Union 7 | import safetensors 8 | 9 | r""" 10 | # Metadata Example 11 | metadata = { 12 | # === Must === 13 | "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec 14 | "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID 15 | "modelspec.implementation": "sgm", 16 | "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc 17 | # === Should === 18 | "modelspec.author": "Example Corp", # Your name or company name 19 | "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know 20 | "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created 21 | # === Can === 22 | "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. 23 | "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model 24 | } 25 | """ 26 | 27 | BASE_METADATA = { 28 | # === Must === 29 | "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec 30 | "modelspec.architecture": None, 31 | "modelspec.implementation": None, 32 | "modelspec.title": None, 33 | "modelspec.resolution": None, 34 | # === Should === 35 | "modelspec.description": None, 36 | "modelspec.author": None, 37 | "modelspec.date": None, 38 | # === Can === 39 | "modelspec.license": None, 40 | "modelspec.tags": None, 41 | "modelspec.merged_from": None, 42 | "modelspec.prediction_type": None, 43 | "modelspec.timestep_range": None, 44 | "modelspec.encoder_layer": None, 45 | } 46 | 47 | # 別に使うやつだけ定義 48 | MODELSPEC_TITLE = "modelspec.title" 49 | 50 | ARCH_SD_V1 = "stable-diffusion-v1" 51 | ARCH_SD_V2_512 = "stable-diffusion-v2-512" 52 | ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" 53 | ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" 54 | 55 | ADAPTER_LORA = "lora" 56 | ADAPTER_TEXTUAL_INVERSION = "textual-inversion" 57 | 58 | IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" 59 | IMPL_DIFFUSERS = "diffusers" 60 | 61 | PRED_TYPE_EPSILON = "epsilon" 62 | PRED_TYPE_V = "v" 63 | 64 | 65 | def load_bytes_in_safetensors(tensors): 66 | bytes = safetensors.torch.save(tensors) 67 | b = BytesIO(bytes) 68 | 69 | b.seek(0) 70 | header = b.read(8) 71 | n = int.from_bytes(header, "little") 72 | 73 | offset = n + 8 74 | b.seek(offset) 75 | 76 | return b.read() 77 | 78 | 79 | def precalculate_safetensors_hashes(state_dict): 80 | # calculate each tensor one by one to reduce memory usage 81 | hash_sha256 = hashlib.sha256() 82 | for tensor in state_dict.values(): 83 | single_tensor_sd = {"tensor": tensor} 84 | bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) 85 | hash_sha256.update(bytes_for_tensor) 86 | 87 | return f"0x{hash_sha256.hexdigest()}" 88 | 89 | 90 | def update_hash_sha256(metadata: dict, 
state_dict: dict): 91 | raise NotImplementedError 92 | 93 | 94 | def build_metadata( 95 | state_dict: Optional[dict], 96 | v2: bool, 97 | v_parameterization: bool, 98 | sdxl: bool, 99 | lora: bool, 100 | textual_inversion: bool, 101 | timestamp: float, 102 | title: Optional[str] = None, 103 | reso: Optional[Union[int, Tuple[int, int]]] = None, 104 | is_stable_diffusion_ckpt: Optional[bool] = None, 105 | author: Optional[str] = None, 106 | description: Optional[str] = None, 107 | license: Optional[str] = None, 108 | tags: Optional[str] = None, 109 | merged_from: Optional[str] = None, 110 | timesteps: Optional[Tuple[int, int]] = None, 111 | clip_skip: Optional[int] = None, 112 | ): 113 | # if state_dict is None, hash is not calculated 114 | 115 | metadata = {} 116 | metadata.update(BASE_METADATA) 117 | 118 | # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する 119 | # if state_dict is not None: 120 | # hash = precalculate_safetensors_hashes(state_dict) 121 | # metadata["modelspec.hash_sha256"] = hash 122 | 123 | if sdxl: 124 | arch = ARCH_SD_XL_V1_BASE 125 | elif v2: 126 | if v_parameterization: 127 | arch = ARCH_SD_V2_768_V 128 | else: 129 | arch = ARCH_SD_V2_512 130 | else: 131 | arch = ARCH_SD_V1 132 | 133 | if lora: 134 | arch += f"/{ADAPTER_LORA}" 135 | elif textual_inversion: 136 | arch += f"/{ADAPTER_TEXTUAL_INVERSION}" 137 | 138 | metadata["modelspec.architecture"] = arch 139 | 140 | if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: 141 | is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion 142 | 143 | if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: 144 | # Stable Diffusion ckpt, TI, SDXL LoRA 145 | impl = IMPL_STABILITY_AI 146 | else: 147 | # v1/v2 LoRA or Diffusers 148 | impl = IMPL_DIFFUSERS 149 | metadata["modelspec.implementation"] = impl 150 | 151 | if title is None: 152 | if lora: 153 | title = "LoRA" 154 | elif textual_inversion: 155 | title = "TextualInversion" 156 | else: 157 | title = "Checkpoint" 158 | title += f"@{timestamp}" 159 | metadata[MODELSPEC_TITLE] = title 160 | 161 | if author is not None: 162 | metadata["modelspec.author"] = author 163 | else: 164 | del metadata["modelspec.author"] 165 | 166 | if description is not None: 167 | metadata["modelspec.description"] = description 168 | else: 169 | del metadata["modelspec.description"] 170 | 171 | if merged_from is not None: 172 | metadata["modelspec.merged_from"] = merged_from 173 | else: 174 | del metadata["modelspec.merged_from"] 175 | 176 | if license is not None: 177 | metadata["modelspec.license"] = license 178 | else: 179 | del metadata["modelspec.license"] 180 | 181 | if tags is not None: 182 | metadata["modelspec.tags"] = tags 183 | else: 184 | del metadata["modelspec.tags"] 185 | 186 | # remove microsecond from time 187 | int_ts = int(timestamp) 188 | 189 | # time to iso-8601 compliant date 190 | date = datetime.datetime.fromtimestamp(int_ts).isoformat() 191 | metadata["modelspec.date"] = date 192 | 193 | if reso is not None: 194 | # comma separated to tuple 195 | if isinstance(reso, str): 196 | reso = tuple(map(int, reso.split(","))) 197 | if len(reso) == 1: 198 | reso = (reso[0], reso[0]) 199 | else: 200 | # resolution is defined in dataset, so use default 201 | if sdxl: 202 | reso = 1024 203 | elif v2 and v_parameterization: 204 | reso = 768 205 | else: 206 | reso = 512 207 | if isinstance(reso, int): 208 | reso = (reso, reso) 209 | 210 | metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" 211 | 212 | if 
v_parameterization: 213 | metadata["modelspec.prediction_type"] = PRED_TYPE_V 214 | else: 215 | metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON 216 | 217 | if timesteps is not None: 218 | if isinstance(timesteps, str) or isinstance(timesteps, int): 219 | timesteps = (timesteps, timesteps) 220 | if len(timesteps) == 1: 221 | timesteps = (timesteps[0], timesteps[0]) 222 | metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" 223 | else: 224 | del metadata["modelspec.timestep_range"] 225 | 226 | if clip_skip is not None: 227 | metadata["modelspec.encoder_layer"] = f"{clip_skip}" 228 | else: 229 | del metadata["modelspec.encoder_layer"] 230 | 231 | # # assert all values are filled 232 | # assert all([v is not None for v in metadata.values()]), metadata 233 | if not all([v is not None for v in metadata.values()]): 234 | print(f"Internal error: some metadata values are None: {metadata}") 235 | 236 | return metadata 237 | 238 | 239 | # region utils 240 | 241 | 242 | def get_title(metadata: dict) -> Optional[str]: 243 | return metadata.get(MODELSPEC_TITLE, None) 244 | 245 | 246 | def load_metadata_from_safetensors(model: str) -> dict: 247 | if not model.endswith(".safetensors"): 248 | return {} 249 | 250 | with safetensors.safe_open(model, framework="pt") as f: 251 | metadata = f.metadata() 252 | if metadata is None: 253 | metadata = {} 254 | return metadata 255 | 256 | 257 | def build_merged_from(models: List[str]) -> str: 258 | def get_title(model: str): 259 | metadata = load_metadata_from_safetensors(model) 260 | title = metadata.get(MODELSPEC_TITLE, None) 261 | if title is None: 262 | title = os.path.splitext(os.path.basename(model))[0] # use filename 263 | return title 264 | 265 | titles = [get_title(model) for model in models] 266 | return ", ".join(titles) 267 | 268 | 269 | # endregion 270 | 271 | 272 | r""" 273 | if __name__ == "__main__": 274 | import argparse 275 | import torch 276 | from safetensors.torch import load_file 277 | from library import train_util 278 | 279 | parser = argparse.ArgumentParser() 280 | parser.add_argument("--ckpt", type=str, required=True) 281 | args = parser.parse_args() 282 | 283 | print(f"Loading {args.ckpt}") 284 | state_dict = load_file(args.ckpt) 285 | 286 | print(f"Calculating metadata") 287 | metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) 288 | print(metadata) 289 | del state_dict 290 | 291 | # by reference implementation 292 | with open(args.ckpt, mode="rb") as file_data: 293 | file_hash = hashlib.sha256() 294 | head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix 295 | header = json.loads(file_data.read(head_len[0])) # header itself, json string 296 | content = ( 297 | file_data.read() 298 | ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. 
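# Layout assumed by this reference implementation (it matches
# load_bytes_in_safetensors above): an 8-byte little-endian length prefix,
# a JSON header of that length, then the raw tensor bytes; the hash covers
# only the tensor bytes, not the header.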
299 | file_hash.update(content) 300 | # ===== Update the hash for modelspec ===== 301 | by_ref = f"0x{file_hash.hexdigest()}" 302 | print(by_ref) 303 | print("is same?", by_ref == metadata["modelspec.hash_sha256"]) 304 | 305 | """ 306 | -------------------------------------------------------------------------------- /scripts/kohyas/sdxl_merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import time 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | from scripts.kohyas import sai_model_spec, sdxl_model_util, train_util, lora 9 | 10 | def load_state_dict(file_name, dtype): 11 | if os.path.splitext(file_name)[1] == ".safetensors": 12 | sd = load_file(file_name) 13 | metadata = train_util.load_metadata_from_safetensors(file_name) 14 | else: 15 | sd = torch.load(file_name, map_location="cpu") 16 | metadata = {} 17 | 18 | for key in list(sd.keys()): 19 | if type(sd[key]) == torch.Tensor: 20 | sd[key] = sd[key].to(dtype) 21 | 22 | return sd, metadata 23 | 24 | 25 | def save_to_file(file_name, model, state_dict, dtype, metadata): 26 | if dtype is not None: 27 | for key in list(state_dict.keys()): 28 | if type(state_dict[key]) == torch.Tensor: 29 | state_dict[key] = state_dict[key].to(dtype) 30 | 31 | if os.path.splitext(file_name)[1] == ".safetensors": 32 | save_file(model, file_name, metadata=metadata) 33 | else: 34 | torch.save(model, file_name) 35 | 36 | 37 | def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): 38 | text_encoder1.to(merge_dtype) 39 | text_encoder2.to(merge_dtype) 40 | unet.to(merge_dtype) 41 | 42 | # create module map 43 | name_to_module = {} 44 | for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): 45 | if i <= 1: 46 | if i == 0: 47 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 48 | else: 49 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 50 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE 51 | else: 52 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET 53 | target_replace_modules = ( 54 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 55 | ) 56 | 57 | for name, module in root_module.named_modules(): 58 | if module.__class__.__name__ in target_replace_modules: 59 | for child_name, child_module in module.named_modules(): 60 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 61 | lora_name = prefix + "." + name + "." 
+ child_name 62 | lora_name = lora_name.replace(".", "_") 63 | name_to_module[lora_name] = child_module 64 | 65 | for model, ratio in zip(models, ratios): 66 | print(f"loading: {model}") 67 | lora_sd, _ = load_state_dict(model, merge_dtype) 68 | 69 | print(f"merging...") 70 | for key in tqdm(lora_sd.keys()): 71 | if "lora_down" in key: 72 | up_key = key.replace("lora_down", "lora_up") 73 | alpha_key = key[: key.index("lora_down")] + "alpha" 74 | 75 | # find original module for this lora 76 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 77 | if module_name not in name_to_module: 78 | print(f"no module found for LoRA weight: {key}") 79 | continue 80 | module = name_to_module[module_name] 81 | # print(f"apply {key} to {module}") 82 | 83 | down_weight = lora_sd[key] 84 | up_weight = lora_sd[up_key] 85 | 86 | dim = down_weight.size()[0] 87 | alpha = lora_sd.get(alpha_key, dim) 88 | scale = alpha / dim 89 | 90 | # W <- W + U * D 91 | weight = module.weight 92 | # print(module_name, down_weight.size(), up_weight.size()) 93 | if len(weight.size()) == 2: 94 | # linear 95 | weight = weight + ratio * (up_weight @ down_weight) * scale 96 | elif down_weight.size()[2:4] == (1, 1): 97 | # conv2d 1x1 98 | weight = ( 99 | weight 100 | + ratio 101 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 102 | * scale 103 | ) 104 | else: 105 | # conv2d 3x3 106 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 107 | # print(conved.size(), weight.size(), module.stride, module.padding) 108 | weight = weight + ratio * conved * scale 109 | 110 | module.weight = torch.nn.Parameter(weight) 111 | 112 | 113 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): 114 | base_alphas = {} # alpha for merged model 115 | base_dims = {} 116 | 117 | merged_sd = {} 118 | v2 = None 119 | base_model = None 120 | for model, ratio in zip(models, ratios): 121 | print(f"loading: {model}") 122 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype) 123 | 124 | if lora_metadata is not None: 125 | if v2 is None: 126 | v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず 127 | if base_model is None: 128 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) 129 | 130 | # get alpha and dim 131 | alphas = {} # alpha for current model 132 | dims = {} # dims for current model 133 | for key in lora_sd.keys(): 134 | if "alpha" in key: 135 | lora_module_name = key[: key.rfind(".alpha")] 136 | alpha = float(lora_sd[key].detach().numpy()) 137 | alphas[lora_module_name] = alpha 138 | if lora_module_name not in base_alphas: 139 | base_alphas[lora_module_name] = alpha 140 | elif "lora_down" in key: 141 | lora_module_name = key[: key.rfind(".lora_down")] 142 | dim = lora_sd[key].size()[0] 143 | dims[lora_module_name] = dim 144 | if lora_module_name not in base_dims: 145 | base_dims[lora_module_name] = dim 146 | 147 | for lora_module_name in dims.keys(): 148 | if lora_module_name not in alphas: 149 | alpha = dims[lora_module_name] 150 | alphas[lora_module_name] = alpha 151 | if lora_module_name not in base_alphas: 152 | base_alphas[lora_module_name] = alpha 153 | 154 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") 155 | 156 | # merge 157 | print(f"merging...") 158 | for key in tqdm(lora_sd.keys()): 159 | if "alpha" in key: 160 | continue 161 | 162 | if "lora_up" in key and 
concat: 163 | concat_dim = 1 164 | elif "lora_down" in key and concat: 165 | concat_dim = 0 166 | else: 167 | concat_dim = None 168 | 169 | lora_module_name = key[: key.rfind(".lora_")] 170 | 171 | base_alpha = base_alphas[lora_module_name] 172 | alpha = alphas[lora_module_name] 173 | 174 | scale = math.sqrt(alpha / base_alpha) * ratio 175 | scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 176 | 177 | if key in merged_sd: 178 | assert ( 179 | merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None 180 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" 181 | if concat_dim is not None: 182 | merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) 183 | else: 184 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale 185 | else: 186 | merged_sd[key] = lora_sd[key] * scale 187 | 188 | # set alpha to sd 189 | for lora_module_name, alpha in base_alphas.items(): 190 | key = lora_module_name + ".alpha" 191 | merged_sd[key] = torch.tensor(alpha) 192 | if shuffle: 193 | key_down = lora_module_name + ".lora_down.weight" 194 | key_up = lora_module_name + ".lora_up.weight" 195 | dim = merged_sd[key_down].shape[0] 196 | perm = torch.randperm(dim) 197 | merged_sd[key_down] = merged_sd[key_down][perm] 198 | merged_sd[key_up] = merged_sd[key_up][:,perm] 199 | 200 | print("merged model") 201 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") 202 | 203 | # check all dims are same 204 | dims_list = list(set(base_dims.values())) 205 | alphas_list = list(set(base_alphas.values())) 206 | all_same_dims = True 207 | all_same_alphas = True 208 | for dims in dims_list: 209 | if dims != dims_list[0]: 210 | all_same_dims = False 211 | break 212 | for alphas in alphas_list: 213 | if alphas != alphas_list[0]: 214 | all_same_alphas = False 215 | break 216 | 217 | # build minimum metadata 218 | dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" 219 | alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" 220 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) 221 | 222 | return merged_sd, metadata 223 | 224 | 225 | def merge(args): 226 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 227 | 228 | def str_to_dtype(p): 229 | if p == "float": 230 | return torch.float 231 | if p == "fp16": 232 | return torch.float16 233 | if p == "bf16": 234 | return torch.bfloat16 235 | return None 236 | 237 | merge_dtype = str_to_dtype(args.precision) 238 | save_dtype = str_to_dtype(args.save_precision) 239 | if save_dtype is None: 240 | save_dtype = merge_dtype 241 | 242 | if args.sd_model is not None: 243 | print(f"loading SD model: {args.sd_model}") 244 | 245 | ( 246 | text_model1, 247 | text_model2, 248 | vae, 249 | unet, 250 | logit_scale, 251 | ckpt_info, 252 | ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") 253 | 254 | merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) 255 | 256 | if args.no_metadata: 257 | sai_metadata = None 258 | else: 259 | merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) 260 | title = os.path.splitext(os.path.basename(args.save_to))[0] 261 | sai_metadata = sai_model_spec.build_metadata( 262 | None, False, False, True, False, False, time.time(), title=title, 
merged_from=merged_from 263 | ) 264 | 265 | print(f"saving SD model to: {args.save_to}") 266 | sdxl_model_util.save_stable_diffusion_checkpoint( 267 | args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype 268 | ) 269 | else: 270 | state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) 271 | 272 | print(f"calculating hashes and creating metadata...") 273 | 274 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 275 | metadata["sshs_model_hash"] = model_hash 276 | metadata["sshs_legacy_hash"] = legacy_hash 277 | 278 | if not args.no_metadata: 279 | merged_from = sai_model_spec.build_merged_from(args.models) 280 | title = os.path.splitext(os.path.basename(args.save_to))[0] 281 | sai_metadata = sai_model_spec.build_metadata( 282 | state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from 283 | ) 284 | metadata.update(sai_metadata) 285 | 286 | print(f"saving model to: {args.save_to}") 287 | save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) 288 | 289 | 290 | def setup_parser() -> argparse.ArgumentParser: 291 | parser = argparse.ArgumentParser() 292 | parser.add_argument( 293 | "--save_precision", 294 | type=str, 295 | default=None, 296 | choices=[None, "float", "fp16", "bf16"], 297 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 298 | ) 299 | parser.add_argument( 300 | "--precision", 301 | type=str, 302 | default="float", 303 | choices=["float", "fp16", "bf16"], 304 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 305 | ) 306 | parser.add_argument( 307 | "--sd_model", 308 | type=str, 309 | default=None, 310 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", 311 | ) 312 | parser.add_argument( 313 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 314 | ) 315 | parser.add_argument( 316 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 317 | ) 318 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 319 | parser.add_argument( 320 | "--no_metadata", 321 | action="store_true", 322 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " 323 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", 324 | ) 325 | parser.add_argument( 326 | "--concat", 327 | action="store_true", 328 | help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " 329 | + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", 330 | ) 331 | parser.add_argument( 332 | "--shuffle", 333 | action="store_true", 334 | help="shuffle lora weight./ " 335 | + "LoRAの重みをシャッフルする", 336 | ) 337 | 338 | return parser 339 | 340 | 341 | if __name__ == "__main__": 342 | parser = setup_parser() 343 | 344 | args = parser.parse_args() 345 | merge(args) 346 | -------------------------------------------------------------------------------- /scripts/kohyas/svd_merge_lora.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import os 4 | import time 5 | import torch 6 | from safetensors.torch import load_file, save_file 7 | from tqdm import tqdm 8 | from 
scripts.kohyas import sai_model_spec, train_util 9 | 10 | 11 | CLAMP_QUANTILE = 0.99 12 | 13 | 14 | def load_state_dict(file_name, dtype): 15 | if os.path.splitext(file_name)[1] == ".safetensors": 16 | sd = load_file(file_name) 17 | metadata = train_util.load_metadata_from_safetensors(file_name) 18 | else: 19 | sd = torch.load(file_name, map_location="cpu") 20 | metadata = {} 21 | 22 | for key in list(sd.keys()): 23 | if type(sd[key]) == torch.Tensor: 24 | sd[key] = sd[key].to(dtype) 25 | 26 | return sd, metadata 27 | 28 | 29 | def save_to_file(file_name, state_dict, dtype, metadata): 30 | if dtype is not None: 31 | for key in list(state_dict.keys()): 32 | if type(state_dict[key]) == torch.Tensor: 33 | state_dict[key] = state_dict[key].to(dtype) 34 | 35 | if os.path.splitext(file_name)[1] == ".safetensors": 36 | save_file(state_dict, file_name, metadata=metadata) 37 | else: 38 | torch.save(state_dict, file_name) 39 | 40 | 41 | def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): 42 | print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") 43 | merged_sd = {} 44 | v2 = None 45 | base_model = None 46 | for model, ratio in zip(models, ratios): 47 | print(f"loading: {model}") 48 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype) 49 | 50 | if lora_metadata is not None: 51 | if v2 is None: 52 | v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string 53 | if base_model is None: 54 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) 55 | 56 | # merge 57 | print(f"merging...") 58 | for key in tqdm(list(lora_sd.keys())): 59 | if "lora_down" not in key: 60 | continue 61 | 62 | lora_module_name = key[: key.rfind(".lora_down")] 63 | 64 | down_weight = lora_sd[key] 65 | network_dim = down_weight.size()[0] 66 | 67 | up_weight = lora_sd[lora_module_name + ".lora_up.weight"] 68 | alpha = lora_sd.get(lora_module_name + ".alpha", network_dim) 69 | 70 | in_dim = down_weight.size()[1] 71 | out_dim = up_weight.size()[0] 72 | conv2d = len(down_weight.size()) == 4 73 | kernel_size = None if not conv2d else down_weight.size()[2:4] 74 | # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) 75 | 76 | # make original weight if not exist 77 | if lora_module_name not in merged_sd: 78 | weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype) 79 | if device: 80 | weight = weight.to(device) 81 | else: 82 | weight = merged_sd[lora_module_name] 83 | 84 | # merge to weight 85 | if device: 86 | up_weight = up_weight.to(device) 87 | down_weight = down_weight.to(device) 88 | 89 | # W <- W + U * D 90 | scale = alpha / network_dim 91 | 92 | if device and isinstance(scale, torch.Tensor): # alpha loaded from the file is a 0-dim tensor; the int fallback has no .to() 93 | scale = scale.to(device) 94 | 95 | if not conv2d: # linear 96 | weight = weight + ratio * (up_weight @ down_weight) * scale 97 | elif kernel_size == (1, 1): 98 | weight = ( 99 | weight 100 | + ratio 101 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 102 | * scale 103 | ) 104 | else: 105 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) 106 | weight = weight + ratio * conved * scale 107 | 108 | merged_sd[lora_module_name] = weight 109 | 110 | # extract from merged weights 111 | print("extract new lora...") 112 | merged_lora_sd = {} 113 | with torch.no_grad(): 114 | for lora_module_name, mat in tqdm(list(merged_sd.items())): 115 | conv2d = len(mat.size()) == 4 116 
| kernel_size = None if not conv2d else mat.size()[2:4] 117 | conv2d_3x3 = conv2d and kernel_size != (1, 1) 118 | out_dim, in_dim = mat.size()[0:2] 119 | 120 | if conv2d: 121 | if conv2d_3x3: 122 | mat = mat.flatten(start_dim=1) 123 | else: 124 | mat = mat.squeeze() 125 | 126 | module_new_rank = new_conv_rank if conv2d_3x3 else new_rank 127 | module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim 128 | 129 | U, S, Vh = torch.linalg.svd(mat) 130 | 131 | U = U[:, :module_new_rank] 132 | S = S[:module_new_rank] 133 | U = U @ torch.diag(S) 134 | 135 | Vh = Vh[:module_new_rank, :] 136 | 137 | dist = torch.cat([U.flatten(), Vh.flatten()]) 138 | hi_val = torch.quantile(dist, CLAMP_QUANTILE) 139 | low_val = -hi_val 140 | 141 | U = U.clamp(low_val, hi_val) 142 | Vh = Vh.clamp(low_val, hi_val) 143 | 144 | if conv2d: 145 | U = U.reshape(out_dim, module_new_rank, 1, 1) 146 | Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1]) 147 | 148 | up_weight = U 149 | down_weight = Vh 150 | 151 | merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous() 152 | merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous() 153 | merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank) 154 | 155 | # build minimum metadata 156 | dims = f"{new_rank}" 157 | alphas = f"{new_rank}" 158 | if new_conv_rank is not None: 159 | network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank} 160 | else: 161 | network_args = None 162 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args) 163 | 164 | return merged_lora_sd, metadata, v2 == "True", base_model 165 | 166 | 167 | def merge(args): 168 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" 169 | 170 | def str_to_dtype(p): 171 | if p == "float": 172 | return torch.float 173 | if p == "fp16": 174 | return torch.float16 175 | if p == "bf16": 176 | return torch.bfloat16 177 | return None 178 | 179 | merge_dtype = str_to_dtype(args.precision) 180 | save_dtype = str_to_dtype(args.save_precision) 181 | if save_dtype is None: 182 | save_dtype = merge_dtype 183 | 184 | new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank 185 | state_dict, metadata, v2, base_model = merge_lora_models( 186 | args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype 187 | ) 188 | 189 | print(f"calculating hashes and creating metadata...") 190 | 191 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) 192 | metadata["sshs_model_hash"] = model_hash 193 | metadata["sshs_legacy_hash"] = legacy_hash 194 | 195 | if not args.no_metadata: 196 | is_sdxl = base_model is not None and base_model.lower().startswith("sdxl") 197 | merged_from = sai_model_spec.build_merged_from(args.models) 198 | title = os.path.splitext(os.path.basename(args.save_to))[0] 199 | sai_metadata = sai_model_spec.build_metadata( 200 | state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from 201 | ) 202 | if v2: 203 | # TODO read sai modelspec 204 | print( 205 | "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" 206 | ) 207 | metadata.update(sai_metadata) 208 | 209 | print(f"saving model to: {args.save_to}") 210 | save_to_file(args.save_to, 
state_dict, save_dtype, metadata) 211 | 212 | 213 | def setup_parser() -> argparse.ArgumentParser: 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument( 216 | "--save_precision", 217 | type=str, 218 | default=None, 219 | choices=[None, "float", "fp16", "bf16"], 220 | help="precision in saving, same as merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", 221 | ) 222 | parser.add_argument( 223 | "--precision", 224 | type=str, 225 | default="float", 226 | choices=["float", "fp16", "bf16"], 227 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", 228 | ) 229 | parser.add_argument( 230 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" 231 | ) 232 | parser.add_argument( 233 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" 234 | ) 235 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") 236 | parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") 237 | parser.add_argument( 238 | "--new_conv_rank", 239 | type=int, 240 | default=None, 241 | help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", 242 | ) 243 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") 244 | parser.add_argument( 245 | "--no_metadata", 246 | action="store_true", 247 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " 248 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", 249 | ) 250 | 251 | return parser 252 | 253 | 254 | if __name__ == "__main__": 255 | parser = setup_parser() 256 | 257 | args = parser.parse_args() 258 | merge(args) 259 | -------------------------------------------------------------------------------- /scripts/mbwpresets_master.txt: -------------------------------------------------------------------------------- 1 | preset_name preset_weights 2 | GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0 3 | GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0 4 | FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25 5 | FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75 6 | WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1 7 | WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1 8 | WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1 9 | WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1 10 | MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0 11 | OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1 12 | OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1 13 | OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1 14 | RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0 15 | RING08_5 
0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0 16 | RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0 17 | RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0 18 | SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1 19 | REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0 20 | SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0 21 | R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1 22 | SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1 23 | R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0 24 | SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0 25 | R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1 26 | SMOOTHSTEP/2 
0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0 27 | R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1 28 | SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1 29 | R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0 30 | SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0 31 | R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1 32 | COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0 33 | REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1 34 | TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1 35 | TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0 36 | FAKE_CUBIC_HERMITE 
0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1 37 | FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0 38 | ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 39 | ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 40 | -------------------------------------------------------------------------------- /scripts/mergers/bcolors.py: -------------------------------------------------------------------------------- 1 | class bcolors: 2 | HEADER = '\033[95m' 3 | OKBLUE = '\033[94m' 4 | OKCYAN = '\033[96m' 5 | OKGREEN = '\033[92m' 6 | WARNING = '\033[93m' 7 | FAIL = '\033[91m' 8 | ENDC = '\033[0m' 9 | BOLD = '\033[1m' 10 | UNDERLINE = '\033[4m' -------------------------------------------------------------------------------- /scripts/mergers/components.py: -------------------------------------------------------------------------------- 1 | merge = None 2 | mergeandgen = None 3 | gen = None 4 | s_reserve = None 5 | s_reserve1 = None 6 | gengrid = None 7 | s_startreserve = None 8 | rand_merge = None 9 | paramsnames = None 10 | frompromptf = None 11 | 12 | msettings = None 13 | esettings1 = None 14 | genparams = None 15 | hiresfix = None 16 | lucks = None 17 | currentmodel = None 18 | dfalse = None 19 | dtrue = None 20 | id_sets = None 21 | xysettings = None 22 | 23 | submit_result = None 24 | imagegal = None 25 | numaframe = None 26 | 27 | txt2img_params = [] 28 | img2img_params = [] -------------------------------------------------------------------------------- /scripts/mergers/model_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import safetensors.torch 4 | import threading 5 | from modules import shared, sd_hijack, sd_models 6 | from modules.sd_models import load_model 7 | import json 8 | 9 | try: 10 | from modules import sd_models_xl 11 | xl = True 12 | except ImportError: # older webui builds do not ship sd_models_xl 13 | xl = False 14 | 15 | def prune_model(model, isxl=False): 16 | keys = list(model.keys()) 17 | base_prefix = "conditioner." if isxl else "cond_stage_model." 18 | for k in keys: 19 | if "diffusion_model." not in k and "first_stage_model."
not in k and base_prefix not in k: 20 | model.pop(k, None) 21 | return model 22 | 23 | def to_half(sd): 24 | for key in sd.keys(): 25 | if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}: 26 | sd[key] = sd[key].half() 27 | return sd 28 | 29 | def savemodel(state_dict,currentmodel,fname,savesets,metadata={}): 30 | other_dict = {} 31 | if state_dict is None: 32 | if shared.sd_model and shared.sd_model.sd_checkpoint_info: 33 | metadata = shared.sd_model.sd_checkpoint_info.metadata.copy() 34 | else: 35 | return "Current model is not a valid merged model" 36 | 37 | checkpoint_info = shared.sd_model.sd_checkpoint_info 38 | # check if current merged model is a fake checkpoint_info 39 | if checkpoint_info is not None: 40 | filename = checkpoint_info.filename 41 | name = os.path.basename(filename) 42 | info = sd_models.get_closet_checkpoint_match(name) 43 | if info == checkpoint_info: 44 | # this is a valid checkpoint_info 45 | # no need to save 46 | return "Current model is not a merged model, or the model has already been saved" 47 | 48 | # prepare metadata 49 | save_metadata = "save metadata" in savesets 50 | if save_metadata: 51 | metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"]) 52 | else: 53 | metadata = {"format": "pt"} 54 | 55 | if shared.sd_model is not None: 56 | print("loading from shared.sd_model...") 57 | 58 | # restore the text encoder 59 | sd_hijack.model_hijack.undo_hijack(shared.sd_model) 60 | 61 | for name,module in shared.sd_model.named_modules(): 62 | if hasattr(module,"network_weights_backup"): 63 | module = network_restore_weights_from_backup(module) 64 | 65 | state_dict = shared.sd_model.state_dict() 66 | for key in list(state_dict.keys()): 67 | if key in POPKEYS: 68 | other_dict[key] = state_dict[key] 69 | del state_dict[key] 70 | 71 | sd_hijack.model_hijack.hijack(shared.sd_model) 72 | else: 73 | return "No current loaded model found" 74 | 75 | # name_for_extra was set with the current model 76 | currentmodel = checkpoint_info.name_for_extra 77 | 78 | if "fp16" in savesets: 79 | pre = ".fp16" 80 | else: pre = "" 81 | ext = ".safetensors" if "safetensors" in savesets else ".ckpt" 82 | 83 | # is it an inpainting or instruct-pix2pix model? 84 | if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys(): 85 | shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape 86 | if shape[1] == 9: 87 | pre += "-inpainting" 88 | if shape[1] == 8: 89 | pre += "-instruct-pix2pix" 90 | 91 | if not fname or fname == "": 92 | model_name = os.path.basename(currentmodel) 93 | model_name = model_name.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_") 94 | fname = model_name + pre + ext 95 | if fname[0] == "_": 96 | fname = fname[1:] 97 | else: 98 | fname = fname if ext in fname else fname + pre + ext 99 | 100 | fname = os.path.join(shared.cmd_opts.ckpt_dir if shared.cmd_opts.ckpt_dir is not None else sd_models.model_path, fname) 101 | fname = fname.replace("ProgramFiles_x86_","Program Files (x86)") 102 | 103 | if len(fname) > 255: 104 | fname = fname.replace(ext, "") # strip the extension before truncating, then re-append it 105 | fname = fname[:240] + ext 106 | 107 | # check if output file already exists 108 | if os.path.isfile(fname) and "overwrite" not in savesets: 109 | _err_msg = f"Output file ({fname}) already exists and was not saved" 110 | print(_err_msg) 111 | return _err_msg 112 | 113 | print("Saving...") 114 | isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict 115 | if isxl: 116 | # prune shared-memory tensors; "cond_stage_model."
prefixed base tensors share memory with "conditioner." prefixed tensors 117 | for key in list(state_dict.keys()): 118 | if "cond_stage_model." in key: 119 | del state_dict[key] 120 | 121 | if "fp16" in savesets: 122 | state_dict = to_half(state_dict) 123 | if "prune" in savesets: 124 | state_dict = prune_model(state_dict, isxl) 125 | 126 | # make every tensor contiguous to avoid a safetensors error 127 | print("Check contiguous...") 128 | for key in state_dict.keys(): 129 | v = state_dict[key] 130 | v = v.contiguous() 131 | state_dict[key] = v 132 | 133 | try: 134 | if ext == ".safetensors": 135 | safetensors.torch.save_file(state_dict, fname, metadata=metadata) 136 | else: 137 | torch.save(state_dict, fname) 138 | except Exception as e: 139 | print(f"ERROR: Could not save: {fname}, error: {e}") 140 | return f"ERROR: Could not save: {fname}, error: {e}" 141 | print("Done!") 142 | if other_dict: 143 | for key in other_dict.keys(): 144 | state_dict[key] = other_dict[key] 145 | del other_dict 146 | load_model(checkpoint_info, already_loaded_state_dict=state_dict) 147 | return "Merged model saved in " + fname 148 | 149 | def filenamecutter(name, model_a=False): 150 | if name == "" or name == []: return 151 | checkpoint_info = sd_models.get_closet_checkpoint_match(name) 152 | name = os.path.splitext(checkpoint_info.filename)[0] 153 | 154 | if not model_a: 155 | name = os.path.basename(name) 156 | return name 157 | 158 | from typing import Union 159 | 160 | def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): 161 | weights_backup = getattr(self, "network_weights_backup", None) 162 | bias_backup = getattr(self, "network_bias_backup", None) 163 | 164 | if weights_backup is None and bias_backup is None: 165 | return self 166 | 167 | with torch.no_grad(): 168 | if weights_backup is not None: 169 | if isinstance(self, torch.nn.MultiheadAttention): 170 | self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad)) 171 | self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad)) 172 | else: 173 | self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad)) 174 | 175 | if bias_backup is not None: 176 | if isinstance(self, torch.nn.MultiheadAttention): 177 | self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad)) 178 | else: 179 | self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad)) 180 | else: 181 | if isinstance(self, torch.nn.MultiheadAttention): 182 | self.out_proj.bias = None 183 | else: 184 | self.bias = None 185 | return self 186 | 187 | def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): 188 | self.network_current_names = () 189 | self.network_weights_backup = None 190 | self.network_bias_backup = None 191 | 192 | POPKEYS = [ 193 | "betas", 194 | "alphas_cumprod", 195 | "alphas_cumprod_prev", 196 | "sqrt_alphas_cumprod", 197 | "sqrt_one_minus_alphas_cumprod", 198 | "log_one_minus_alphas_cumprod", 199 | "sqrt_recip_alphas_cumprod", 200 | "sqrt_recipm1_alphas_cumprod", 201 | "posterior_variance", 202 | "posterior_log_variance_clipped", 203 | "posterior_mean_coef1", 204 | "posterior_mean_coef2", 205 | ] 206 | -------------------------------------------------------------------------------- /style.css: 
-------------------------------------------------------------------------------- 1 | button.reset { 2 | min-width: min(100px,100%); 3 | } 4 | 5 | button.compact_button { 6 | min-width: min(100px,100%); 7 | } 8 | --------------------------------------------------------------------------------
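Note: the per-module rank reduction in scripts/kohyas/svd_merge_lora.py above (truncated SVD followed by quantile clamping) can be tried in isolation. The following is a minimal, self-contained sketch of that technique rather than code from the repository; the tensor shape, the rank of 16, and the 0.99 clamp quantile are assumptions chosen for illustration, since the script takes its rank from --new_rank and defines its own clamp constant.

import torch

CLAMP_QUANTILE = 0.99  # assumed value for this sketch

def svd_reduce(mat: torch.Tensor, new_rank: int):
    # rank reduction for one 2D weight matrix, mirroring the script above
    out_dim, in_dim = mat.shape
    rank = min(new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dims

    U, S, Vh = torch.linalg.svd(mat)
    U = U[:, :rank] @ torch.diag(S[:rank])  # fold singular values into the up factor
    Vh = Vh[:rank, :]

    # clamp outliers at the chosen quantile of the combined factor values
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)

delta = torch.randn(320, 768)  # hypothetical merged weight delta for one module
up, down = svd_reduce(delta, new_rank=16)
print(up.shape, down.shape)  # torch.Size([320, 16]) torch.Size([16, 768])
print(f"relative error: {((delta - up @ down).norm() / delta.norm()).item():.3f}")

Because the input LoRAs may have different ranks, re-factorizing each merged per-module matrix with an SVD is what lets the script emit a single LoRA at the requested rank; the quantile clamp keeps rare outlier values in the factors within a reasonable numeric range before the weights are saved.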