├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── README_ja.md
├── calcmode_en.md
├── calcmode_ja.md
├── changelog.md
├── elemental_en.md
├── elemental_ja.md
├── install.py
├── sample.txt
├── scripts
│   ├── A1111
│   │   ├── lora_patches.py
│   │   ├── lyco_helpers.py
│   │   ├── network.py
│   │   ├── network_full.py
│   │   ├── network_glora.py
│   │   ├── network_hada.py
│   │   ├── network_ia3.py
│   │   ├── network_lokr.py
│   │   ├── network_lora.py
│   │   ├── network_norm.py
│   │   ├── network_oft.py
│   │   └── networks.py
│   ├── GenParamGetter.py
│   ├── Roboto-Regular.ttf
│   ├── kohyas
│   │   ├── extract_lora_from_models.py
│   │   ├── ipex
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── diffusers.py
│   │   │   ├── gradscaler.py
│   │   │   └── hijacks.py
│   │   ├── lora.py
│   │   ├── merge_lora.py
│   │   ├── model_util.py
│   │   ├── original_unet.py
│   │   ├── sai_model_spec.py
│   │   ├── sdxl_merge_lora.py
│   │   ├── sdxl_model_util.py
│   │   ├── sdxl_original_unet.py
│   │   ├── svd_merge_lora.py
│   │   └── train_util.py
│   ├── mbwpresets_master.txt
│   ├── mergers
│   │   ├── bcolors.py
│   │   ├── components.py
│   │   ├── mergers.py
│   │   ├── model_util.py
│   │   ├── pluslora.py
│   │   └── xyplot.py
│   └── supermerger.py
└── style.css
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [hako-mikan] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | changelog.m
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/README_ja.md:
--------------------------------------------------------------------------------
1 | # SuperMerger
2 | - Model merge extension for [AUTOMATIC1111's stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
3 | - Merged models can be used for generation directly, without saving them
4 |
5 | [English README](README.md)
6 | [Overview](#overview)
7 | [Sponsor](https://github.com/sponsors/hako-mikan)
8 |
9 |
10 | # Overview
11 | With this extension, a merged model can be loaded as the generation model without being saved.
12 | Previously, a merged model had to be saved once and deleted if you did not like it; this extension avoids that wear on your HDD or SSD.
13 | Because it saves the time needed to save and load models, generating continuously while changing the ratios greatly improves the efficiency of model merging.
14 |
15 | # Contents
16 | - [Merge Models](#merge-models)
17 | - [Merge Block Weight](#merge-block-weight)
18 | - [XYZ Plot](#xyz-plot)
19 | - [Adjust](#adjust)
20 | - [Let the Dice Roll](#let-the-dice-roll)
21 | - [Elemental Merge](#elemental-merge)
22 | - [Generation Parameters](#generation-parameters)
23 | - [LoRA](#lora)
24 | - [Merge LoRAs](#merge-loras)
25 | - [Merge to Checkpoint](#merge-to-checkpoint)
26 | - [Extract from Checkpoints](#extract-from-checkpoints)
27 | - [Other Tabs](#other-tabs)
28 |
29 | - [Calcmode](calcmode_ja.md)
30 | - [Elemental Merge](elemental_ja.md)
31 |
32 |
33 | # Recent Update
34 | 2023.11.10
35 | - LoRA extraction script updated: SDXL and 2.X are now supported
36 | - Merging LoRA into SD2.X models is now supported
37 | - Bug fixes
38 | - Added a new calculation method, extract merge (checkpoint/LoRA). Thanks to [subaqua](https://github.com/sbq0), who created it.
39 |
40 | ## Changes to the calculation functions
41 | In several modes, the functions used for the calculation have been changed, which speeds up the merge calculation. The results are the same; if you do not get the same results, check the "use old calc method" option. The affected modes are:
42 | Weight Sum: normal, cosineA, cosineB
43 | Sum Twice: normal
44 | Thanks to [wkpark](https://github.com/wkpark), who proposed this.
45 |
46 | **Caution!**
47 | Merging XL models requires at least 64 GB of CPU memory. Even with 64 GB, the system may become unstable depending on what other software is running, so work in a state where it is acceptable for the system to crash. I ran into a blue screen for the first time in a long while.
48 |
49 | ## Known issue
50 | If other extensions (such as sd-webui-prompt-all-in-one) are installed at the same time, enabling the option that automatically opens the browser at startup may make operation unstable. This is most likely a Gradio problem and is difficult to fix, so please disable that option.
51 |
52 | All updates can be found [here](changelog.md).
53 |
54 | # Usage
55 | ## Merge Models
56 | The model merged here is loaded as the generation model of the Web UI. The model shown at the top left does not change, but the merged model really is loaded, and it stays loaded until another model is selected from the model selection at the top left.
57 | ### Basic Usage
58 | Select models A/B/(C), the merge mode, and alpha (beta), then press Merge/Merge and Gen to start the merging process. In the case of Merge and Gen, generation is carried out using the prompt and other settings specified in txt2img. The Gen button only generates images, and the Stop button interrupts the merging process.
59 |
60 |
61 | ### Load Settings From:
62 | Loads settings from the merge log. The log is updated every time a merge is performed, and sequential IDs starting from 1 are assigned. "-1" corresponds to the settings of the last merge, "-2" to the one before that. The merge log is saved in extension/sd-webui-supermerger/mergehistory.csv. You can browse and search it in the History tab; searches can be combined with and/or, separated by half-width spaces.
63 | ### Clear Cache
64 | If the Web UI's model cache feature is enabled, SuperMerger creates its own model cache to speed up consecutive merges. Models are cached separately from the Web UI's cache function. Use this button to delete the cache after use. The cache is created in RAM, not VRAM.
65 |
66 | # Settings
67 | ## Merge mode
68 | ### Weight sum
69 | Normal merge. alpha is used. At alpha = 0 the result is Model A, at alpha = 1 it is Model B.
70 | ### Add difference
71 | Difference merge.
72 | ### Triple sum
73 | Merges three models at the same time. alpha and beta are used. I added this feature because there were three model selection windows, and it seems to work properly. It can also be used with MBW; enter the MBW alpha and beta for each.
74 | ### sum Twice
75 | Performs Weight sum twice. alpha and beta are used. It can also be used in MBW mode; enter the MBW alpha and beta for each.
76 |
77 | ### use MBW
78 | When checked, block-by-block merging (merge block weight) is enabled. Set the per-block ratios with the sliders or a preset at the bottom.
79 |
80 |
81 | ### Merge mode
82 | #### Weight sum $(1-\alpha) A + \alpha B$
83 | Normal merge. alpha is used. At $\alpha$ = 0 the result is Model A, at $\alpha$ = 1 it is Model B.
84 | #### Add difference $A + \alpha (B-C)$
85 | Adds the difference. When MBW is enabled, the MBW base value is used as $\alpha$.
86 | #### Triple sum $(1-\alpha - \beta) A + \alpha B + \beta C$
87 | Merges three models at once. $\alpha$ and $\beta$ are used. I added this feature because there were three model selection windows, but I am not sure whether it actually works effectively.
88 | #### sum Twice $(1-\beta)((1-\alpha)A+\alpha B)+\beta C$
89 | Performs Weight sum twice. $\alpha$ and $\beta$ are used.
90 |
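As a rough per-tensor illustration of these formulas (a sketch only, not the extension's actual implementation; `a`, `b`, `c` stand for corresponding tensors taken from models A/B/C):

```
def merge_tensor(mode, a, b, c=None, alpha=0.5, beta=0.25):
    # Minimal sketch of the four merge modes applied to a single tensor.
    if mode == "Weight sum":      # (1 - alpha) * A + alpha * B
        return (1 - alpha) * a + alpha * b
    if mode == "Add difference":  # A + alpha * (B - C)
        return a + alpha * (b - c)
    if mode == "Triple sum":      # (1 - alpha - beta) * A + alpha * B + beta * C
        return (1 - alpha - beta) * a + alpha * b + beta * c
    if mode == "sum Twice":       # (1 - beta) * ((1 - alpha) * A + alpha * B) + beta * C
        return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c
    raise ValueError(mode)
```
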
91 | ### calcmode
92 | See [this page](calcmode_ja.md) for details of each calculation method. The correspondence between calculation methods and merge modes is shown below.
93 | | Calcmode | Description | Merge Mode |
94 | |----|----|----|
95 | |normal | Normal calculation method | ALL |
96 | |cosineA | Calculates so as to minimize loss during merging, using model A as the reference. | Weight sum |
97 | |cosineB | Calculates so as to minimize loss during merging, using model B as the reference. | Weight sum |
98 | |trainDifference | 'Trains' the difference as if fine-tuning it onto model A. | Add difference |
99 | |smoothAdd | Adds the difference in a way that combines the benefits of median and Gaussian filters. | Add difference |
100 | |smoothAdd MT| Uses multithreading to speed up the calculation. | Add difference |
101 | |extract | Extracts and adds what models B and C have in common or not in common. | Add difference |
102 | |tensor| Instead of adding, swaps whole tensors at the given ratio. | Weight sum |
103 | |tensor2 | When a tensor has large dimensions, the swap is done along the second dimension. | Weight sum |
104 | |self | Multiplies the model itself by $\alpha$. | Weight sum |
105 |
106 | ### use MBW
107 | Enables block-by-block (MBW) merging. Set the weights in Merge Block Weight. When this is enabled, alpha and beta are ignored.
108 |
109 | ### Options
110 | | Setting | Description |
111 | |-----------------|---------------------------------------------------|
112 | | save model | Saves the model after merging. |
113 | | overwrite | Enables overwriting of existing models. |
114 | | safetensors | Saves in safetensors format. |
115 | | fp16 | Saves in half precision. |
116 | | save metadata | Saves merge information in the metadata when saving (safetensors only). |
117 | | prune | Removes unnecessary data from the model when saving. |
118 | | Reset CLIP ids | Resets the CLIP ids. |
119 | | use old calc method | Uses the old calculation functions. |
120 | | debug | Outputs debug information to the CMD window. |
121 |
122 | ### save merged model ID to
123 | You can choose whether to save the merge ID to the generated image or to the PNG info.
124 | ### Additional Value
125 | Currently only effective when 'extract' is selected as the calcmode.
126 | ### Custom Name (Optional)
127 | You can set the model name. If it is not set, the name is determined automatically.
128 | ### Bake in VAE
129 | When the model is saved, the selected VAE is baked into it.
130 |
131 | ### Save current Merge
132 | Saves the currently loaded model. This is useful when merging takes a long time because of PC specs or other issues.
133 |
134 | ## Merge Block Weight
135 | This is a merging technique in which the ratio can be set for each block. Because each block may correspond to things such as background rendering, characters, or art style, you can create a wide variety of merged models by changing the per-block ratios.
136 |
137 | The blocks differ depending on the SD version, as shown below.
138 |
139 | BASE refers to the text encoder and affects things such as the response to the prompt. IN through OUT refer to the U-Net, which handles image generation.
140 |
141 | Stable diffusion 1.X, 2.X
142 | |1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|21|22|23|24|25|26|
143 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
144 | |BASE|IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|MID|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11|
145 |
146 | Stable diffusion XL
147 | |1|2|3|4|5|6|7|8|9|10|11|12|13|14|15|16|17|18|19|20|
148 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
149 | |BASE|IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|MID|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|
150 |
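For reference, the following sketch shows how a checkpoint key could be assigned to one of these blocks, assuming the standard SD1.X state-dict key layout (an illustration only, not this extension's internal code):

```
def block_of(key: str):
    """Roughly map an SD1.X state-dict key to one of the 26 MBW blocks."""
    if key.startswith("cond_stage_model."):                      # text encoder
        return "BASE"
    if key.startswith("model.diffusion_model.input_blocks."):
        return f"IN{int(key.split('.')[3]):02}"
    if key.startswith("model.diffusion_model.middle_block."):
        return "M00"
    if key.startswith("model.diffusion_model.output_blocks."):
        return f"OUT{int(key.split('.')[3]):02}"
    return None  # time_embed, conv in/out, VAE, etc. are handled separately
```
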
151 | ## XYZ Plot
153 | Generates merged images continuously. Works with all merge modes.
154 | ### alpha, beta
155 | Changes the values of alpha and beta.
156 | ### alpha and beta
157 | Changes alpha and beta at the same time. Separate alpha and beta with a single space, and separate entries with commas. If only one number is entered, the same value is used for both alpha and beta.
158 | Example: 0, 0.5 0.1, 0.3 0.4, 0.5
159 | ### MBW
160 | Performs block-by-block merges. Enter the ratios separated by line breaks. Presets can be used, but be sure to **separate them with line breaks**. For Triple and Twice, enter two lines as one set; an odd number of lines results in an error.
161 | ### seed
162 | Changes the seed. Entering -1 uses a fixed seed in the direction of the opposite axis.
163 | ### model_A, B, C
164 | Changes the models. The model selected in the model selection window is ignored.
165 | ### pinpoint block
166 | Changes only specific blocks in MBW. Select alpha or beta for the opposite axis. If you enter a block name, only the alpha (beta) of that block is changed. Separate entries with commas, as with the other types; multiple blocks can be changed at the same time by separating them with spaces or hyphens. Entering NOT at the beginning inverts the effect.
167 | #### Input example
168 | IN01, OUT10 OUT11, OUT03-OUT06, OUT07-OUT11, NOT M00 OUT03-OUT06
169 | In this case:
170 | - 1: Only IN01 changes
171 | - 2: OUT10 and OUT11 change
172 | - 3: OUT03 through OUT06 change
173 | - 4: OUT07 through OUT11 change
174 | - 5: Everything except M00 and OUT03 through OUT06 changes
175 |
176 | Be careful not to forget the leading zeros.
177 | 
178 |
179 | Block IDs (uppercase only)
180 | BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11
181 |
182 | XL model
183 | BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08
184 |
185 | ### calcmode
186 | Changes the calculation method. Pay attention to which merge modes each method can be applied to. Separate entries with commas.
187 |
188 | ### prompt
189 | You can change the prompt. The txt2img prompt is ignored; the negative prompt remains in effect.
190 | Note that entries must be **separated by line breaks**.
191 |
192 | ### Reserve XY plot
193 | The Reserve XY plot button does not run the plot immediately; it reserves an XY plot with the settings in effect when the button is pressed. A reserved XY plot starts after the normal XY plot finishes, or when the Start XY plot button in the Reservation tab is pressed. Reservations can be made at any time, whether or not an XY plot is running. The reservation list is not updated automatically, so use the reload button. If an error occurs, that plot is discarded and the next reservation is executed. Images are not displayed until all reservations have finished, but grids for plots marked Finished are already generated and can be viewed with Image Browser and the like.
194 | You can also split the run into separate reservations at any point by using "|".
195 | For example, entering 0.1,0.2,0.3,0.4,0.5|0.6,0.7,0.8,0.9,1.0 is split into the two reservations
196 |
197 | 0.1,0.2,0.3,0.4,0.5
198 | 0.6,0.7,0.8,0.9,1.0
199 | which are then executed in turn. This is useful when there are too many entries and the grid would otherwise become too large.
200 |
201 | ## Adjust
202 | This corrects the detail and color tone of a model, using a mechanism different from LoRA. By adjusting the input and output points of the U-Net, the detail and color tone of the image can be adjusted.
203 | 
204 | ## Usage
205 |
206 | Enter values directly into the text box, or set them with the sliders and press the up button to reflect them in the text box. Fields left blank are ignored.
207 | Enter eight numbers separated by commas.
208 |
209 | ```
210 | 0,0,0,0,0,0,0,0
211 | ```
212 | This is the default; the effect appears as you shift these values.
213 |
214 | ### Each setting value
216 | The eight numbers correspond to:
217 | 1. Detail/Noise
218 | 2. Detail/Noise
219 | 3. Detail/Noise
220 | 4. Contrast/Detail
221 | 5. Brightness
222 | 6. Color tone 1 (Cyan-Red)
223 | 7. Color tone 2 (Magenta-Green)
224 | 8. Color tone 3 (Yellow-Blue)
225 |
226 | Note that more detail inevitably means more noise. Also, when using Hires.fix the output may differ, so it is recommended to test with the settings you actually use.
227 |
228 | Values up to about +/-5 should be fine, but this depends on the model. Positive values increase the amount of detail. There are three color tones, which roughly correspond to color balance.
229 | ### 1,2,3 Detail/Noise
230 | 1 corresponds to the input of the U-Net. Adjusting it changes the amount of detail in the image; the composition changes more easily here than with OUT. Negative values make the image flatter and slightly blurrier; positive values increase detail and make it noisier. Note that an image that is noisy in normal generation can still come out clean with hires.fix. 2 and 3 correspond to OUT.
231 | 
232 |
233 | ### 4. Contrast/Detail
234 | Adjusting this changes the contrast and brightness, and with them the amount of detail. Looking at the samples is the quickest way to understand.
235 | 
236 |
237 | ### 5,6,7,8 Brightness, Color Tone
238 | Corrects brightness and color tone. These roughly correspond to color balance.
239 | 
240 |
241 | ## Let the Dice roll
242 | Determines the merge ratios randomly. Multiple random merges can be performed at once. Ratios can be randomized per block and per element.
243 | ## Usage
244 | Used from "Let the Dice roll". Select a `Random Mode` and press `Run Rand`; weights are set randomly for the number of times given by `Num of challenge` and images are generated. Generation runs in XYZ mode, so the `STOP` button works. Set `Seed for Random Ratio` to `-1`; if Num of challenge is 2 or more, it is set to -1 automatically. Using the same seed gives reproducible results. If more than 10 images are generated, the grid automatically becomes two-dimensional. In `Settings`, `alpha` and `beta` are randomized when checked. For Elemental, `beta` is disabled.
245 |
246 | ## Modes
247 | ### R,U,X
248 | Random weights are set for all 26 blocks. The difference between `R`, `U`, and `X` is the range of the random values. For X, the range is specified per block with `lower limit` ~ `upper limit`.
249 | R : 0 ~ 1
250 | U : -0.5 ~ 1.5
251 | X : lower limit ~ upper limit
252 | ### ER,EU,EX
253 | Random weights are set for all elements. The difference between `ER`, `EU`, and `EX` is the range of the random values. For EX, the range is specified per element with `lower limit` ~ `upper limit`.
254 |
255 | ### custom
256 | Specifies which blocks are randomized, using `custom`.
257 | `R`, `U`, `X`, `ER`, `EU`, and `EX` can be used.
258 | Example:
259 | ```
260 | U,0,0,0,0,0,0,0,0,0,0,0,0,R,R,R,R,R,R,R,R,R,R,R,R,R
261 | U,0,0,0,0,0,0,0,0,0,0,0,0,ER,0,0,0,0,0,0,X,0,0,0,0,0
262 | ```
263 | ### XYZ mode
264 | Available by setting the type to `random`. Enter the number of randomizations, and that many elements are set on the axis.
265 | ```
266 | X type : seed, -1,-1,-1
267 | Y type : random, 5
268 | ```
269 | This produces a 3×5 grid, generated with models whose weights were randomized 5 times. Configure the randomization in the random panel; this does not work properly if it is set to `off`.
270 |
271 | ### Settings
272 | - `round` sets the number of decimal places used for rounding. The default is 3, giving values like 0.123.
273 | - `save E-list` saves the Elemental keys and ratios in csv format to `script/data/`.
274 |
275 | ## Elemental Merge
276 | See [here](elemental_ja.md).
277 |
278 |
279 |
280 | ## Generation Parameters
281 |
282 | Image generation conditions can also be set here. Values set here take priority.
283 | ## Include/Exclude
284 | You can set which blocks are merged (or not merged). With Include, only the blocks set here are merged; Exclude is the opposite, and only the blocks set here are excluded from the merge. If "print" is checked, you can confirm on the command prompt screen whether blocks were excluded. If "Adjust" is checked, only the elements used by Adjust are merged/excluded. Strings such as `attn` can also be specified; in that case only elements containing `attn` are merged/excluded. Separate strings with commas.
285 | ## unload button
286 | Unloads the currently loaded model. This is used to free GPU memory when using the kohya-ss GUI. Once the model has been unloaded, image generation is no longer possible; select a model again if you want to generate images.
287 |
288 |
289 | ## LoRA
290 | LoRA-related functions. They are basically the same as the kohya-ss scripts, but block-by-block merging is supported. Merging for the V2.X series is not supported at this time.
291 |
292 | Note: Because LyCORIS has a special structure, only single-LoRA merging is supported, and only ratios of 1 or 0 can be used for a single merge. With other values the result will not match the block-weighted LoRA result even with "same to Strength".
293 | LoCon matches reasonably well even with non-integer ratios.
294 |
295 | Merging LoCon/LyCORIS into a model requires web-ui 1.5 or later.
296 | | 1.X,2.X | LoRA | LoCon | LyCORIS |
297 | |----------|-------|-------|---------|
298 | | Merge to Model | Yes | Yes | Yes |
299 | | Merge LoRAs | Yes | Yes | No |
300 | | Apply Block Weight(single)|Yes|Yes|Yes|
301 | | Extract From Models | Yes | No | No |
302 |
303 | | XL | LoRA | LoCon | LyCORIS |
304 | |----------|-------|-------|---------|
305 | | Merge to Model | Yes | Yes | Yes |
306 | | Merge LoRAs | Yes | Yes | No |
307 | | Extract From Models | Yes | No | No |
308 |
309 |
310 | ### Merge LoRAs
311 | Merges one or more LoRAs together. The latest kohya-ss script is used, so LoRAs with different dimensions can also be merged. Note, however, that the LoRA is recalculated when converting dimensions, so the generated images may differ significantly.
312 |
313 | The calculate dimension button computes and displays the dimension of each LoRA and enables the sort function. The calculation takes a while; even around 50 LoRAs can take tens of seconds. Newly merged LoRAs do not appear in the list, so press the reload button. Recalculating dimensions only processes newly added LoRAs.
314 |
315 | ### Merge to Checkpoint
316 | Merge LoRAs into a model. Multiple LoRAs can be merged at the same time.
317 | Enter LoRA name1:ratio1:block1,LoRA name2:ratio2:block2,...
318 | LoRA can also be used alone. The ":block" part can be omitted. The ratio can be any value, including negative values. There is no restriction that the total must sum to 1 (of course, if it greatly exceeds 1, it will break down).
319 | To use ":block", use a block preset name from the bottom list of presets, or create your own. Ex:
320 | ```
321 | LoRAname1:ratio1
322 | LoRAname1:ratio1:ALL
323 | LoRAname1:ratio1:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0
324 | ```
325 |
326 | ### Extract from checkpoints
327 | Generates a LoRA from the difference between two models.
328 | If a dimension is specified, the LoRA is created with that dimension; if not specified, 128 is used.
329 | The blend ratio can be adjusted with alpha and beta: $(\alpha A - \beta B)$. alpha, beta = 1 corresponds to normal LoRA creation.
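The idea behind the extraction can be sketched as a low-rank approximation of the weighted difference for each layer (a simplified illustration using torch.svd_lowrank, not the kohya-based script actually used here):

```
import torch

def extract_layer(wa, wb, dim=128, alpha=1.0, beta=1.0):
    """Rank-`dim` approximation of (alpha*A - beta*B) for one 2D weight tensor."""
    diff = (alpha * wa - beta * wb).float().reshape(wa.shape[0], -1)
    u, s, v = torch.svd_lowrank(diff, q=dim)   # diff ≈ u @ diag(s) @ v.T
    lora_up = u * s.sqrt()                     # (out_features, dim)
    lora_down = (v * s.sqrt()).T               # (dim, in_features)
    return lora_up, lora_down                  # lora_up @ lora_down ≈ diff
```
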
330 |
331 | ### Extract from two LoRAs
332 | See [here](calcmode_ja.md#extractlora).
333 |
334 | ### Metadata
335 | #### create new
336 | Creates new, minimal metadata. Only dim, alpha, the base model version, the filename, and the network type are written.
337 | #### merge
338 | The information of each LoRA is preserved, and the tags are merged.
339 | (The information of each LoRA is not visible in the Web UI's simple metadata viewer.)
340 | #### save all
341 | The information of each LoRA is preserved.
342 | (The information of each LoRA is not visible in the Web UI's simple metadata viewer.)
343 | #### use first LoRA
344 | Copies the information of the first LoRA as-is.
345 |
346 | ### Get Ratios from Prompt
347 | Reads LoRA ratio settings from the prompt field. This also includes LoRA Block Weight settings, so you can merge with them directly.
348 |
349 | ### Difference between Normal Merge and SAME TO STRENGTH
350 | If the "same to Strength" option is not used, the result is the same as a merge with the kohya-ss script. In that case the result differs from applying the LoRA in the Web UI, as shown in the figure below. This comes from the formula used when the LoRA is incorporated into the U-Net: the kohya-ss script multiplies the ratio in as-is, but in the application formula the ratio effectively gets squared, so setting the ratio to a value other than 1, or to a negative value, gives a result different from the Strength used at application time. With the "same to Strength" option, the square root of the ratio is applied at merge time, so that Strength and ratio have the same meaning at application time, and negative values also work as expected. If you are not doing additional training, the "same to Strength" option should be fine, but if you plan to train further on the merged LoRA, it may be better not to use it.
351 |
352 | The figure below shows images generated with normal application, with the same to Strength option, and with a normal merge, using a merge of the figma and ukiyoE LoRAs. You can see that with a normal merge, even negative ratios are squared and end up positive.
353 | 
354 |
355 | ## Other tabs
356 | ## Analysis
357 | Analyzes the difference between two models. Select the models you want to compare as model A and model B.
358 | ### Mode
359 |
360 | ASimilarity mode compares tensors calculated from q, k, and v. The other modes calculate from the cosine similarity of each element. In modes other than ASimilarity, the calculated difference tends to come out smaller. ASimilarity mode gives results closer to the differences in the output images, so it should generally be used.
361 | This ASimilarity analysis was created by extending the [ASimilarity script](https://huggingface.co/JosephusCheung/ASimilarityCalculatior).
362 |
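As a minimal sketch of the per-tensor cosine-similarity comparison that the non-ASimilarity modes are based on (illustrative only; the extension's own key filtering and per-block aggregation differ):

```
import torch
import torch.nn.functional as F

def cosine_diff(sd_a, sd_b):
    """Cosine similarity per shared key of two state dicts (1.0 = same direction)."""
    sims = {}
    for key, ta in sd_a.items():
        tb = sd_b.get(key)
        if tb is None or ta.shape != tb.shape:
            continue
        sims[key] = F.cosine_similarity(ta.flatten().float(),
                                        tb.flatten().float(), dim=0).item()
    return sims
```
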
363 | ### Block Method
364 | This is how the per-block ratio is calculated in modes other than ASimilarity. Mean uses the average, min uses the minimum, and attn2 outputs the attn2 value as the block's result.
365 |
366 | ## History
367 | You can search the merge history. The search supports both "and" and "or" searches.
368 |
369 | ## Elements
370 | You can get the list of elements contained in a model, their block assignments, and their tensor sizes.
371 |
372 | ## Acknowledgements
373 | This script uses parts of the scripts by [kohya](https://github.com/kohya-ss) and [bbc-mc](https://github.com/bbc-mc). I also thank everyone who has contributed to the development of this extension.
374 |
--------------------------------------------------------------------------------
/changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | ### 2023.10.15
3 | The Adjust feature was improved: it now uses color controls like CD Tuner ("Brightness, Cyan-Red, Magenta-Green, Yellow-Blue").
4 | Other bug fixes.
5 | Added an option to automatically reset CLIP ids.
9 |
10 | ### update 2023.10.03
11 | [Merge function of metadata for LoRAs is changed.](#Metadata)
12 |
13 | ### update 2023.09.02.1900(JST)
14 | The specifications regarding model caching have changed. If you have set up model caching, the model cache on the web-ui side used to be used, but due to changes in the web-ui specifications this is no longer possible. Therefore, I have disabled the model cache on the web-ui side and made changes to cache the model on the SuperMerger side instead. As a result, if model caching is enabled, the model cache will remain after using SuperMerger. Please clear the cache using the "Clear cache" button to free up memory after using SuperMerger.
19 |
20 |
21 | Bug fixes:
22 | - Bug occurring when selecting Seed in XYZ plot
23 | - Merging LoRA into a checkpoint did not work
24 | - Denoising Strength was written to PNG info even when Hires fix was not enabled
25 | - Bug when using LOWVRAM/MEDVRAM
30 |
31 | ### update 2023.08.31
32 |
33 | - Almost all txt2img tab settings are now available in generation
34 | Thanks! [Filexor](https://github.com/Filexor)
35 |
36 | - Support for XL models
37 |
38 | XL capabilities at the moment:
39 | Merge / Block merge
40 | Merging LoRA into the model (supported within a few days)
41 |
42 | Cannot be done:
43 | Creating LoRA from model differences (TBD)
46 |
47 | ### update 2023.07.07.2000(JST)
48 | - add new feature:[Random merge](#random-merge)
49 | - add new feature:[Adjust detail/colors](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/elemental_en.md#adjust)
50 |
51 | ### update 2023.06.28.2000(JST)
52 | - add Image Generation Parameters(prompt,seed,etc.)
53 | for Vlad fork users, use this panel
54 |
55 | ### update 2023.06.27.2030
56 | - Add new calcmode "trainDifference"[detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#trainDifference) (Thanks [SwiftIllusion](https://github.com/SwiftIllusion))
57 | - Add Z axis for XY plot
58 | - Add Analysis tab for calculating the difference of models (thanks [Mohamed-Desouki](https://github.com/Mohamed-Desouki))
59 |
60 | ### update 2023.06.24.0300(JST)
61 | - VAE bake feature added
62 | - support inpainting/pix2pix
63 | Thanks [wkpark](https://github.com/wkpark)
64 |
65 | ### update 2023.05.02.1900(JST)
66 | - bug fix : Resolved conflict with wildcard in dynamic prompt
67 | - new feature : restore face and tile option added
68 |
69 | ### update 2023.04.19.2030(JST)
70 | - New feature, optimization using cosine similarity method updated [detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#cosine)
71 | - New feature, tensor merge added [detail here](https://github.com/hako-mikan/sd-webui-supermerger/blob/main/calcmode_en.md#tensor)
72 | - New XY plot type : calcmode,prompt
73 |
88 | ### bug fix 2023.02.19.2330(JST)
89 | Several bugs have been fixed
90 | - Error when LOWRAM option is enabled
91 | - Error on Linux
92 | - XY plot did not finish properly
93 | - Error when setting unused models
94 |
95 | ### update to version 3 2023.02.17.2020(JST)
96 | - Added LoRA related functions
97 | - Logs can now be saved and settings can be recalled.
98 | - Save in safetensors and fp16 format is now supported.
99 | - Weight presets are now supported.
100 | - Reservation of XY plots is now possible.
101 |
102 | ### bug fix 2023.01.29.0000(JST)
103 | Fixed a problem where pinpoint blocks could not be used on the X axis.
104 | Fixed a problem where Triple and Twice could not be used when pinpoint blocks were selected.
105 | Fixed a problem where MBW could not be used with some axis types when using XY plot.
109 |
110 | ### bug fixed 2023.01.28.0100(JST)
111 | Fixed a problem where the "Save current model" button did not work properly in MBW mode.
112 | Fixed an error that occurred when saving a file with too long a file name.
115 |
116 | ### bug fixed 2023.01.26.2100(JST)
117 | Fixed a bug where the MBW axis type could not be used in the XY plot.
119 |
120 | ### updated 2023.01.25.0000(JST)
121 | Added several features:
122 | - Added new merge modes "Triple sum" and "sum Twice"
123 | - Added XY plot
126 |
127 | ### bug fixed 2023.01.20.2350(JST)
128 | Fixed a problem where png info could not be saved properly.
130 |
--------------------------------------------------------------------------------
/elemental_en.md:
--------------------------------------------------------------------------------
1 | # Elemental Merge
2 | - This is a block-by-block merge that goes beyond block-by-block merge.
3 |
4 | In a block-by-block merge, the merge ratio can be changed for each of the 25 blocks, but a block itself also consists of multiple elements, and in principle it is possible to change the ratio for each element. Possible, yes, but with more than 600 elements it was doubtful whether this could be handled by hand; we implemented it anyway. I do not recommend merging element by element out of the blue. It is best used as a final adjustment when a problem arises that cannot be solved by block-by-block merging.
5 | The following images show the result of changing the elements in the OUT05 block. The leftmost one is without merging, the second one changes all of OUT05 (i.e., normal block-by-block merging), and the rest are element merges. As shown in the table below, entries such as attn2 in turn contain several elements.
6 | 
7 | ## Usage
8 | Note that elemental merging is effective for both normal and block-by-block merging, and is computed last, so it will overwrite values specified for block-by-block merging.
9 |
10 | Set it in Elemental Merge. Note that if text is set here, it will be applied automatically. Each element is listed in the table below, but it is not necessary to enter the full name of each element.
11 | You can check to see if the effect is properly applied by activating "print change" check. If this check is enabled, the applied elements will be displayed on the command prompt screen during the merge.
12 |
13 | ### Format
14 | Blocks:Element:Ratio, Blocks:Element:Ratio,...
15 | or
16 | Blocks:Element:Ratio
17 | Blocks:Element:Ratio
18 | Blocks:Element:Ratio
19 |
20 | Multiple specifications can be made by separating them with commas or newlines. Commas and newlines may be mixed.
21 | Blocks can be specified in uppercase from BASE,IN00-M00-OUT11. If left blank, all Blocks will be applied. Multiple Blocks can be specified by separating them with a space.
22 | Similarly, multiple elements can be specified by separating them with a space.
23 | Partial matching is used, so for example, typing "attn" will change both attn1 and attn2, and typing "attn2" will change only attn2. If you want to specify more details, enter "attn2.to_out" and so on.
24 |
25 | OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5
26 |
27 | With this input, the ratio of elements containing attn2, and of attn1.to_out, in the OUT03, OUT04 and OUT05 blocks will be 0.5.
28 | If the element column is left blank, all elements in the specified Blocks will change, and the effect will be the same as a block-by-block merge.
29 | If there are duplicate specifications, the one entered later takes precedence.
30 |
31 | OUT06:attn:0.5,OUT06:attn2.to_k:0.2
32 |
33 | is entered, attn other than attn2.to_k in the OUT06 layer will be 0.5, and only attn2.to_k will be 0.2.
34 |
35 | You can invert the effect by first entering NOT.
36 | This can be set by Blocks and Element.
37 |
38 | NOT OUT04:attn:1
39 |
40 | will set a ratio of 1 for attn in all Blocks except the OUT04 layer.
41 |
42 | OUT05:NOT attn proj:0.2
43 |
44 | will set all elements except attn and proj in the OUT05 layer to 0.2.
45 |
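A minimal sketch of how such a specification could be parsed and matched against an element key (illustrative only; the extension's own parser may differ):

```
def parse_field(field):
    """Split a Blocks or Element field into tokens; a leading NOT inverts it."""
    tokens = field.split()
    invert = bool(tokens) and tokens[0] == "NOT"
    return (tokens[1:] if invert else tokens), invert

def field_matches(value, tokens, invert, partial=False):
    """Blank fields match everything; element names use partial matching."""
    if not tokens:
        hit = True
    elif partial:
        hit = any(t in value for t in tokens)
    else:
        hit = value in tokens
    return not hit if invert else hit

# "OUT05:NOT attn proj:0.2" applies 0.2 to an element (block, name) when
# field_matches(block, *parse_field("OUT05")) and
# field_matches(name, *parse_field("NOT attn proj"), partial=True) are both True.
```
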
46 | ## XY plot
47 | Several XY plots for elemental merge are available. Input examples can be found in sample.txt.
48 | #### elemental
49 | Creates XY plots for multiple elemental merges. Elements should be separated from each other by blank lines.
50 | The following image is the result of executing sample1 of sample.txt.
51 |
52 | #### pinpoint element
53 | Creates an XY plot with different values for a specific element. Do the same with elements as with Pinpoint Blocks, but specify alpha for the opposite axis. Separate elements with a new line or comma.
54 | The following image shows the result of running sample 3 of sample.txt.
55 | 
56 |
57 | #### effective elemental checker
58 | Outputs the influence of each element as a difference. Optionally, an animated gif and a csv file can also be output. The gif and csv files are created in a diff folder, under the folder named after Model A and Model B in the output folder. If file names would be duplicated, the files are renamed before saving, but since this gets confusing as files accumulate, it is recommended to rename the diff folder to something appropriate.
59 | Separate the entries with a new line or comma. Use alpha for the opposite axis and enter a single value. This is useful for seeing the effect of an element, but it is also possible to see the effect of a block by not specifying an element, and you may end up using it that way more often.
60 | The following image shows the result of running sample5 of sample.txt.
61 | 
62 | 
63 |
64 | ### List of elements
65 | Basically, it seems that attn is responsible for the face and clothing information. The IN07, OUT03, OUT04, and OUT05 layers seem to have a particularly strong influence. It does not seem to make sense to change the same element in multiple Blocks at the same time, since the degree of influence often differs depending on the Blocks.
66 | No element exists where it is marked null.
67 |
68 | ||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11
69 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
70 | op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
71 | op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
72 | emb_layers.1.bias|null|||null|||null|||null|null|||||||||||||||
73 | emb_layers.1.weight|null|||null|||null|||null|null|||||||||||||||
74 | in_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
75 | in_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
76 | in_layers.2.bias|null|||null|||null|||null|null|||||||||||||||
77 | in_layers.2.weight|null|||null|||null|||null|null|||||||||||||||
78 | out_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
79 | out_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
80 | out_layers.3.bias|null|||null|||null|||null|null|||||||||||||||
81 | out_layers.3.weight|null|||null|||null|||null|null|||||||||||||||
82 | skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
83 | skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
84 | norm.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
85 | norm.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
86 | proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
87 | proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
88 | proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
89 | proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
90 | transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
91 | transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
92 | transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
93 | transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
94 | transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
95 | transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
96 | transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
97 | transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
98 | transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
99 | transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
100 | transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
101 | transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
102 | transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
103 | transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
104 | transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
105 | transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
106 | transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
107 | transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
108 | transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
109 | transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
110 | conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
111 | conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
112 | 0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
113 | 0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
114 | 2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
115 | 2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
116 | time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
117 | time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
118 | time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
119 | time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
120 |
--------------------------------------------------------------------------------
/elemental_ja.md:
--------------------------------------------------------------------------------
1 | # Elemental Merge
2 | - A block-by-block merge that goes beyond block-by-block merging
3 |
4 | In a block-by-block merge, the merge ratio can be changed for each of the 25 blocks, but a block itself also consists of multiple elements, and in principle it is possible to change the ratio for each element. Possible, yes, but with more than 600 elements it was doubtful whether this could be handled by hand; we implemented it anyway. Merging element by element out of the blue is not recommended; it is best used as a final adjustment when a problem arises that cannot be solved by block-by-block merging.
5 | The following images show the result of changing the elements in the OUT05 block. The leftmost one is without merging, the second one changes all of OUT05 (i.e., a normal block-by-block merge), and the rest are element merges. As shown in the table below, entries such as attn2 in turn contain several elements.
6 | 
7 |
8 | ## Usage
9 | Elemental merging is effective for both normal and block-by-block merging; because it is calculated last, it overwrites the values specified in the block-by-block merge.
10 |
11 | Set it in Elemental Merge. Note that if text is set here, it is applied automatically. The elements are listed in the table below, but you do not need to enter the full name of each element.
12 | You can check whether the effect is actually applied by enabling the "print change" check. When it is enabled, the applied elements are shown on the command prompt screen during the merge.
13 | Partial matching can be used.
14 | ### Format
15 | Blocks:Element:Ratio, Blocks:Element:Ratio,...
16 | or
17 | Blocks:Element:Ratio
18 | Blocks:Element:Ratio
19 | Blocks:Element:Ratio
20 |
21 | Multiple specifications are possible by separating them with commas or line breaks; commas and line breaks can be mixed.
22 | Blocks can be specified in uppercase from BASE, IN00-M00-OUT11. If left blank, all blocks are targeted. Multiple blocks can be specified by separating them with spaces.
23 | The same goes for elements: multiple elements can be specified by separating them with spaces.
24 | Partial matching is used, so for example entering "attn" changes both attn1 and attn2, while "attn2" changes only attn2. To be more specific, enter something like "attn2.to_out".
25 |
26 | OUT03 OUT04 OUT05:attn2 attn1.to_out:0.5
27 |
28 | With this input, the ratio of elements containing attn2, and of attn1.to_out, in the OUT03, OUT04 and OUT05 blocks becomes 0.5.
29 | If the element field is left blank, all elements in the specified blocks change, which has the same effect as a block-by-block merge.
30 | If specifications overlap, the one entered later takes precedence.
31 |
32 | OUT06:attn:0.5,OUT06:attn2.to_k:0.2
33 |
34 | With this input, attn in the OUT06 block other than attn2.to_k becomes 0.5, and only attn2.to_k becomes 0.2.
35 |
36 | Entering NOT at the beginning inverts the range of effect.
37 | This can be set per block and per element.
38 |
39 | NOT OUT04:attn:1
40 |
41 | sets a ratio of 1 for attn in all blocks except OUT04.
42 |
43 | OUT05:NOT attn proj:0.2
44 |
45 | sets all elements except attn and proj in the OUT05 block to 0.2.
46 |
47 | ## XY plot
48 | Several XY plots are available for elemental merging. Input examples can be found in sample.txt.
49 | #### elemental
50 | Creates an XY plot for multiple elemental merges. Separate the entries with blank lines.
51 | The image at the top of this page is the result of running sample1 of sample.txt.
52 |
53 | #### pinpoint element
54 | Creates an XY plot that varies the value of a specific element. This is the same as pinpoint Blocks, but for elements. Specify alpha for the opposite axis. Separate elements with a line break or comma.
55 | The following image is the result of running sample3 of sample.txt.
56 | 
57 |
58 | #### effective elemental checker
59 | Outputs the influence of each element as a difference. Optionally, an animated gif and a csv file can be output. The gif and csv files are created in a diff folder, under the folder made from Model A and Model B in the output folder. If file names would be duplicated, the files are renamed before saving, but since this gets confusing as files accumulate, it is recommended to rename the diff folder to something appropriate.
60 | Separate the entries with a line break or comma. Use alpha for the opposite axis and enter a single value. This is useful for seeing the effect of an element, but you can also see the effect of a block by not specifying an element, and you may end up using it that way more often.
61 | The following image is the result of running sample5 of sample.txt.
62 | 
63 | 
64 |
65 |
66 | ### List of elements
67 | Basically, attn seems to carry the face and clothing information. The IN07, OUT03, OUT04 and OUT05 blocks seem to have a particularly strong influence. Since the degree of influence often differs between blocks, changing the same element in multiple blocks at the same time does not seem to make much sense.
68 | No element exists where null is written.
69 |
70 | ||IN00|IN01|IN02|IN03|IN04|IN05|IN06|IN07|IN08|IN09|IN10|IN11|M00|M00|OUT00|OUT01|OUT02|OUT03|OUT04|OUT05|OUT06|OUT07|OUT08|OUT09|OUT10|OUT11
71 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
72 | op.bias|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
73 | op.weight|null|null|null||null|null||null|null||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null
74 | emb_layers.1.bias|null|||null|||null|||null|null|||||||||||||||
75 | emb_layers.1.weight|null|||null|||null|||null|null|||||||||||||||
76 | in_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
77 | in_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
78 | in_layers.2.bias|null|||null|||null|||null|null|||||||||||||||
79 | in_layers.2.weight|null|||null|||null|||null|null|||||||||||||||
80 | out_layers.0.bias|null|||null|||null|||null|null|||||||||||||||
81 | out_layers.0.weight|null|||null|||null|||null|null|||||||||||||||
82 | out_layers.3.bias|null|||null|||null|||null|null|||||||||||||||
83 | out_layers.3.weight|null|||null|||null|||null|null|||||||||||||||
84 | skip_connection.bias|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
85 | skip_connection.weight|null|null|null|null||null|null||null|null|null|null|null|null||||||||||||
86 | norm.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
87 | norm.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
88 | proj_in.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
89 | proj_in.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
90 | proj_out.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
91 | proj_out.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
92 | transformer_blocks.0.attn1.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
93 | transformer_blocks.0.attn1.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
94 | transformer_blocks.0.attn1.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
95 | transformer_blocks.0.attn1.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
96 | transformer_blocks.0.attn1.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
97 | transformer_blocks.0.attn2.to_k.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
98 | transformer_blocks.0.attn2.to_out.0.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
99 | transformer_blocks.0.attn2.to_out.0.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
100 | transformer_blocks.0.attn2.to_q.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
101 | transformer_blocks.0.attn2.to_v.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
102 | transformer_blocks.0.ff.net.0.proj.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
103 | transformer_blocks.0.ff.net.0.proj.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
104 | transformer_blocks.0.ff.net.2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
105 | transformer_blocks.0.ff.net.2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
106 | transformer_blocks.0.norm1.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
107 | transformer_blocks.0.norm1.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
108 | transformer_blocks.0.norm2.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
109 | transformer_blocks.0.norm2.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
110 | transformer_blocks.0.norm3.bias|null|||null|||null|||null|null|null||null|null|null|null|||||||||
111 | transformer_blocks.0.norm3.weight|null|||null|||null|||null|null|null||null|null|null|null|||||||||
112 | conv.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
113 | conv.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null||null|null||null|null||null|null|null
114 | 0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
115 | 0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
116 | 2.bias|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
117 | 2.weight|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
118 | time_embed.0.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
119 | time_embed.0.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
120 | time_embed.2.weight||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
121 | time_embed.2.bias||null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|null|
122 |
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import launch
2 | import importlib.metadata
3 | from packaging.version import Version
4 | from packaging.requirements import Requirement
5 |
6 | def is_installed(pip_package):
7 | """
8 | Check if a package is installed and meets version requirements specified in pip-style format.
9 |
10 | Args:
11 | pip_package (str): Package name in pip-style format (e.g., "numpy>=1.22.0").
12 |
13 | Returns:
14 | bool: True if the package is installed and meets the version requirement, False otherwise.
15 | """
16 | try:
17 | # Parse the pip-style package name and version constraints
18 | requirement = Requirement(pip_package)
19 | package_name = requirement.name
20 | specifier = requirement.specifier # e.g., >=1.22.0
21 |
22 | # Check if the package is installed
23 | dist = importlib.metadata.distribution(package_name)
24 | installed_version = Version(dist.version)
25 |
26 | # Check version constraints
27 | if specifier.contains(installed_version):
28 | return True
29 | else:
30 | print(f"Installed version of {package_name} ({installed_version}) does not satisfy the requirement ({specifier}).")
31 | return False
32 | except importlib.metadata.PackageNotFoundError:
33 | print(f"Package {pip_package} is not installed.")
34 | return False
35 |
36 | requirements = [
37 | "diffusers==0.31.0",
38 | "scikit-learn",
39 | ]
40 |
41 | for module in requirements:
42 | if not is_installed(module):
43 | launch.run_pip(f"install {module}", module)
--------------------------------------------------------------------------------
/scripts/A1111/lora_patches.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import networks
4 | from modules import patches
5 |
6 |
7 | class LoraPatches:
8 | def __init__(self):
9 | self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
10 | self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
11 | self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
12 | self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
13 | self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
14 | self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
15 | self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
16 | self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
17 | self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
18 | self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
19 |
20 | def undo(self):
21 | self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
22 | self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
23 | self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
24 | self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
25 | self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
26 | self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
27 | self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
28 | self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
29 | self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
30 | self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
31 |
32 |
--------------------------------------------------------------------------------
/scripts/A1111/lyco_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def make_weight_cp(t, wa, wb):
5 | temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
6 | return torch.einsum('i j k l, i r -> r j k l', temp, wa)
7 |
8 |
9 | def rebuild_conventional(up, down, shape, dyn_dim=None):
10 | up = cpufloat(up.reshape(up.size(0), -1))
11 | down = cpufloat(down.reshape(down.size(0), -1))
12 | if dyn_dim is not None:
13 | up = up[:, :dyn_dim]
14 | down = down[:dyn_dim, :]
15 | return (up @ down).reshape(shape)
16 |
17 |
18 | def rebuild_cp_decomposition(up, down, mid):
19 | up = cpufloat(up.reshape(up.size(0), -1))
20 | down = cpufloat(down.reshape(down.size(0), -1))
21 | mid = cpufloat(mid)
22 | return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
23 |
24 |
25 | # copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
26 | def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
27 | '''
28 |     return a tuple of two values: the input dimension decomposed by the number closest to factor;
29 |     the second value is greater than or equal to the first value.
30 |
31 |     In LoRA with Kronecker product, the first value is a value for weight scale.
32 |     The second value is a value for weight.
33 |
34 | Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
35 |
36 | examples)
37 | factor
38 | -1 2 4 8 16 ...
39 | 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
40 | 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
41 | 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
42 | 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
43 | 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
44 | 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
45 | '''
46 |
47 | if factor > 0 and (dimension % factor) == 0:
48 | m = factor
49 | n = dimension // factor
50 | if m > n:
51 | n, m = m, n
52 | return m, n
53 | if factor < 0:
54 | factor = dimension
55 | m, n = 1, dimension
56 | length = m + n
57 |     while m < n:
58 |         new_m = m + 1
59 |         while dimension % new_m != 0:
60 |             new_m += 1
61 |         new_n = dimension // new_m
62 |         if new_m + new_n > length or new_m > factor:
63 |             break
64 |         else:
65 |             m, n = new_m, new_n
66 | if m > n:
67 | n, m = m, n
68 | return m, n
69 |
70 | def cpufloat(module):
71 | if module is None: return module #None対策
72 | return module.to(torch.float) if module.device.type == "cpu" else module
--------------------------------------------------------------------------------
/scripts/A1111/network.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os
3 | from collections import namedtuple
4 | import enum
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | try:
11 | from modules import sd_models, cache, errors, hashes, shared
12 | except:
13 | pass
14 |
15 | class QkvLinear(torch.nn.Linear):
16 | pass
17 |
18 | NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
19 |
20 | metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
21 |
22 |
23 | class SdVersion(enum.Enum):
24 | Unknown = 1
25 | SD1 = 2
26 | SD2 = 3
27 | SDXL = 4
28 |
29 |
30 | class NetworkOnDisk:
31 | def __init__(self, name, filename):
32 | self.name = name
33 | self.filename = filename
34 | self.metadata = {}
35 | self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
36 |
37 | def read_metadata():
38 | metadata = sd_models.read_metadata_from_safetensors(filename)
39 |
40 | return metadata
41 |
42 | if self.is_safetensors:
43 | try:
44 | self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
45 | except Exception as e:
46 | errors.display(e, f"reading lora {filename}")
47 |
48 | if self.metadata:
49 | m = {}
50 | for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
51 | m[k] = v
52 |
53 | self.metadata = m
54 |
55 | self.alias = self.metadata.get('ss_output_name', self.name)
56 |
57 | self.hash = None
58 | self.shorthash = None
59 | self.set_hash(
60 | self.metadata.get('sshs_model_hash') or
61 | hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
62 | ''
63 | )
64 |
65 | self.sd_version = self.detect_version()
66 |
67 | def detect_version(self):
68 | if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
69 | return SdVersion.SDXL
70 | elif str(self.metadata.get('ss_v2', "")) == "True":
71 | return SdVersion.SD2
72 | elif len(self.metadata):
73 | return SdVersion.SD1
74 |
75 | return SdVersion.Unknown
76 |
77 | def set_hash(self, v):
78 | self.hash = v
79 | self.shorthash = self.hash[0:12]
80 |
81 | if self.shorthash:
82 | import networks
83 | networks.available_network_hash_lookup[self.shorthash] = self
84 |
85 | def read_hash(self):
86 | if not self.hash:
87 | self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
88 |
89 | def get_alias(self):
90 | import networks
91 | if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
92 | return self.name
93 | else:
94 | return self.alias
95 |
96 |
97 | class Network: # LoraModule
98 | def __init__(self, name, network_on_disk: NetworkOnDisk):
99 | self.name = name
100 | self.network_on_disk = network_on_disk
101 | self.te_multiplier = 1.0
102 | self.unet_multiplier = 1.0
103 | self.dyn_dim = None
104 | self.modules = {}
105 | self.bundle_embeddings = {}
106 | self.mtime = None
107 |
108 | self.mentioned_name = None
109 | """the text that was used to add the network to prompt - can be either name or an alias"""
110 |
111 |
112 | class ModuleType:
113 | def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
114 | return None
115 |
116 |
117 | class NetworkModule:
118 | def __init__(self, net: Network, weights: NetworkWeights):
119 | self.network = net
120 | self.network_key = weights.network_key
121 | self.sd_key = weights.sd_key
122 | self.sd_module = weights.sd_module
123 |
124 | if isinstance(self.sd_module, QkvLinear):
125 | s = self.sd_module.weight.shape
126 | self.shape = (s[0] // 3, s[1])
127 | elif hasattr(self.sd_module, 'weight'):
128 | self.shape = self.sd_module.weight.shape
129 | elif isinstance(self.sd_module, nn.MultiheadAttention):
130 | # For now, only self-attention uses PyTorch's MHA,
131 | # so assume all qkvo projections have the same shape
132 | self.shape = self.sd_module.out_proj.weight.shape
133 | else:
134 | self.shape = None
135 |
136 | self.ops = None
137 | self.extra_kwargs = {}
138 | if isinstance(self.sd_module, nn.Conv2d):
139 | self.ops = F.conv2d
140 | self.extra_kwargs = {
141 | 'stride': self.sd_module.stride,
142 | 'padding': self.sd_module.padding
143 | }
144 | elif isinstance(self.sd_module, nn.Linear):
145 | self.ops = F.linear
146 | elif isinstance(self.sd_module, nn.LayerNorm):
147 | self.ops = F.layer_norm
148 | self.extra_kwargs = {
149 | 'normalized_shape': self.sd_module.normalized_shape,
150 | 'eps': self.sd_module.eps
151 | }
152 | elif isinstance(self.sd_module, nn.GroupNorm):
153 | self.ops = F.group_norm
154 | self.extra_kwargs = {
155 | 'num_groups': self.sd_module.num_groups,
156 | 'eps': self.sd_module.eps
157 | }
158 |
159 | self.dim = None
160 | self.bias = weights.w.get("bias")
161 | self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
162 | self.scale = weights.w["scale"].item() if "scale" in weights.w else None
163 |
164 | self.dora_scale = weights.w.get("dora_scale", None)
165 | self.dora_norm_dims = len(self.shape) - 1
166 |
167 | def multiplier(self):
168 | if 'transformer' in self.sd_key[:20]:
169 | return self.network.te_multiplier
170 | else:
171 | return self.network.unet_multiplier
172 |
173 | def calc_scale(self):
174 | if self.scale is not None:
175 | return self.scale
176 | if self.dim is not None and self.alpha is not None:
177 | return self.alpha / self.dim
178 |
179 | return 1.0
180 |
181 | def apply_weight_decompose(self, updown, orig_weight):
182 | # Match the device/dtype
183 | orig_weight = orig_weight.to(updown.dtype)
184 | dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
185 | updown = updown.to(orig_weight.device)
186 |
187 | merged_scale1 = updown + orig_weight
188 | merged_scale1_norm = (
189 | merged_scale1.transpose(0, 1)
190 | .reshape(merged_scale1.shape[1], -1)
191 | .norm(dim=1, keepdim=True)
192 | .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
193 | .transpose(0, 1)
194 | )
195 |
196 | dora_merged = (
197 | merged_scale1 * (dora_scale / merged_scale1_norm)
198 | )
199 | final_updown = dora_merged - orig_weight
200 | return final_updown
201 |
202 | def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
203 | if self.bias is not None:
204 | updown = updown.reshape(self.bias.shape)
205 | updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
206 | updown = updown.reshape(output_shape)
207 |
208 | if len(output_shape) == 4:
209 | updown = updown.reshape(output_shape)
210 |
211 | if orig_weight.size().numel() == updown.size().numel():
212 | updown = updown.reshape(orig_weight.shape)
213 |
214 | if ex_bias is not None:
215 | ex_bias = ex_bias * self.multiplier()
216 |
217 | updown = updown * self.calc_scale()
218 |
219 | if self.dora_scale is not None:
220 | updown = self.apply_weight_decompose(updown, orig_weight)
221 |
222 | return updown * self.multiplier(), ex_bias
223 |
224 | def calc_updown(self, target):
225 | raise NotImplementedError()
226 |
227 | def forward(self, x, y):
228 | """A general forward implementation for all modules"""
229 | if self.ops is None:
230 | raise NotImplementedError()
231 | else:
232 | updown, ex_bias = self.calc_updown(self.sd_module.weight)
233 | return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
234 |
235 |
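Editor's note: a minimal sketch, separate from network.py, of how multiplier(), calc_scale() and finalize_updown() combine: the merged weight is roughly W' = W + multiplier * scale * updown (leaving the optional DoRA renormalisation aside). The helper name apply_updown below is hypothetical and only for illustration.

import torch

def apply_updown(orig_weight, updown, alpha=None, dim=None, multiplier=1.0):
    # scale = alpha / dim when both are present, mirroring calc_scale(); 1.0 otherwise
    scale = (alpha / dim) if (alpha is not None and dim is not None) else 1.0
    return orig_weight + updown.reshape(orig_weight.shape) * scale * multiplier

w = torch.randn(320, 320)
delta = torch.randn(320, 320) * 0.01
merged = apply_updown(w, delta, alpha=8, dim=16, multiplier=0.7)  # W + 0.7 * 0.5 * delta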
--------------------------------------------------------------------------------
/scripts/A1111/network_full.py:
--------------------------------------------------------------------------------
1 | import scripts.A1111.network as network
2 |
3 |
4 | class ModuleTypeFull(network.ModuleType):
5 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
6 | if all(x in weights.w for x in ["diff"]):
7 | return NetworkModuleFull(net, weights)
8 |
9 | return None
10 |
11 |
12 | class NetworkModuleFull(network.NetworkModule):
13 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
14 | super().__init__(net, weights)
15 |
16 | self.weight = weights.w.get("diff")
17 | self.ex_bias = weights.w.get("diff_b")
18 |
19 | def calc_updown(self, orig_weight):
20 | output_shape = self.weight.shape
21 | updown = self.weight.to(orig_weight.device)
22 | if self.ex_bias is not None:
23 | ex_bias = self.ex_bias.to(orig_weight.device)
24 | else:
25 | ex_bias = None
26 |
27 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
28 |
--------------------------------------------------------------------------------
/scripts/A1111/network_glora.py:
--------------------------------------------------------------------------------
1 |
2 | import scripts.A1111.network as network
3 |
4 | class ModuleTypeGLora(network.ModuleType):
5 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
6 | if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
7 | return NetworkModuleGLora(net, weights)
8 |
9 | return None
10 |
11 | # adapted from https://github.com/KohakuBlueleaf/LyCORIS
12 | class NetworkModuleGLora(network.NetworkModule):
13 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
14 | super().__init__(net, weights)
15 |
16 | if hasattr(self.sd_module, 'weight'):
17 | self.shape = self.sd_module.weight.shape
18 |
19 | self.w1a = weights.w["a1.weight"]
20 | self.w1b = weights.w["b1.weight"]
21 | self.w2a = weights.w["a2.weight"]
22 | self.w2b = weights.w["b2.weight"]
23 |
24 | def calc_updown(self, orig_weight):
25 | w1a = self.w1a.to(orig_weight.device)
26 | w1b = self.w1b.to(orig_weight.device)
27 | w2a = self.w2a.to(orig_weight.device)
28 | w2b = self.w2b.to(orig_weight.device)
29 |
30 | output_shape = [w1a.size(0), w1b.size(1)]
31 | updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
32 |
33 | return self.finalize_updown(updown, orig_weight, output_shape)
34 |
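Editor's note: a worked example, not part of network_glora.py, of the GLoRA delta computed in calc_updown() above; the factor shapes are assumptions chosen only so the expression type-checks.

import torch

out_dim, in_dim, rank = 8, 6, 2
W  = torch.randn(out_dim, in_dim)
a1 = torch.randn(rank, in_dim)    # plays the role of weights.w["a1.weight"] (assumed shape)
b1 = torch.randn(rank, in_dim)    # "b1.weight" (assumed shape)
a2 = torch.randn(in_dim, rank)    # "a2.weight" (assumed shape)
b2 = torch.randn(out_dim, rank)   # "b2.weight" (assumed shape)

updown = (b2 @ b1) + (W @ a2) @ a1   # same expression as NetworkModuleGLora.calc_updown
print(updown.shape)                  # torch.Size([8, 6]) -- matches W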
--------------------------------------------------------------------------------
/scripts/A1111/network_hada.py:
--------------------------------------------------------------------------------
1 | import scripts.A1111.lyco_helpers as lyco_helpers
2 | import scripts.A1111.network as network
3 |
4 |
5 | class ModuleTypeHada(network.ModuleType):
6 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
7 | if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
8 | return NetworkModuleHada(net, weights)
9 |
10 | return None
11 |
12 |
13 | class NetworkModuleHada(network.NetworkModule):
14 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
15 | super().__init__(net, weights)
16 |
17 | if hasattr(self.sd_module, 'weight'):
18 | self.shape = self.sd_module.weight.shape
19 |
20 | self.w1a = weights.w["hada_w1_a"]
21 | self.w1b = weights.w["hada_w1_b"]
22 | self.dim = self.w1b.shape[0]
23 | self.w2a = weights.w["hada_w2_a"]
24 | self.w2b = weights.w["hada_w2_b"]
25 |
26 | self.t1 = weights.w.get("hada_t1")
27 | self.t2 = weights.w.get("hada_t2")
28 |
29 | def calc_updown(self, orig_weight):
30 | w1a = self.w1a.to(orig_weight.device)
31 | w1b = self.w1b.to(orig_weight.device)
32 | w2a = self.w2a.to(orig_weight.device)
33 | w2b = self.w2b.to(orig_weight.device)
34 |
35 | output_shape = [w1a.size(0), w1b.size(1)]
36 |
37 | if self.t1 is not None:
38 | output_shape = [w1a.size(1), w1b.size(1)]
39 | t1 = self.t1.to(orig_weight.device)
40 | updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
41 | output_shape += t1.shape[2:]
42 | else:
43 | if len(w1b.shape) == 4:
44 | output_shape += w1b.shape[2:]
45 | updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
46 |
47 | if self.t2 is not None:
48 | t2 = self.t2.to(orig_weight.device)
49 | updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
50 | else:
51 | updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
52 |
53 | updown = updown1 * updown2
54 |
55 | return self.finalize_updown(updown, orig_weight, output_shape)
56 |
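Editor's note: a worked example, not part of network_hada.py, of the LoHa reconstruction above for the non-CP branch: the delta is the element-wise (Hadamard) product of two low-rank matrices.

import torch

out_dim, in_dim, rank = 8, 6, 2
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)

updown = (w1a @ w1b) * (w2a @ w2b)   # element-wise product, as in calc_updown()
print(updown.shape)                  # torch.Size([8, 6])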
--------------------------------------------------------------------------------
/scripts/A1111/network_ia3.py:
--------------------------------------------------------------------------------
1 | import scripts.A1111.network as network
2 |
3 |
4 | class ModuleTypeIa3(network.ModuleType):
5 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
6 | if all(x in weights.w for x in ["weight"]):
7 | return NetworkModuleIa3(net, weights)
8 |
9 | return None
10 |
11 |
12 | class NetworkModuleIa3(network.NetworkModule):
13 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
14 | super().__init__(net, weights)
15 |
16 | self.w = weights.w["weight"]
17 | self.on_input = weights.w["on_input"].item()
18 |
19 | def calc_updown(self, orig_weight):
20 | w = self.w.to(orig_weight.device)
21 |
22 | output_shape = [w.size(0), orig_weight.size(1)]
23 | if self.on_input:
24 | output_shape.reverse()
25 | else:
26 | w = w.reshape(-1, 1)
27 |
28 | updown = orig_weight * w
29 |
30 | return self.finalize_updown(updown, orig_weight, output_shape)
31 |
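Editor's note: a small sketch, not part of network_ia3.py, of the branch above where on_input is False: the learned vector rescales each output row of the original weight instead of adding a low-rank delta.

import torch

orig = torch.randn(8, 6)
vec = torch.randn(8)                 # stands in for weights.w["weight"]
updown = orig * vec.reshape(-1, 1)   # broadcasts across columns, as in calc_updown()
print(updown.shape)                  # torch.Size([8, 6])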
--------------------------------------------------------------------------------
/scripts/A1111/network_lokr.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import scripts.A1111.lyco_helpers as lyco_helpers
4 | import scripts.A1111.network as network
5 |
6 |
7 | class ModuleTypeLokr(network.ModuleType):
8 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
9 | has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
10 | has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
11 | if has_1 and has_2:
12 | return NetworkModuleLokr(net, weights)
13 |
14 | return None
15 |
16 |
17 | def make_kron(orig_shape, w1, w2):
18 | if len(w2.shape) == 4:
19 | w1 = w1.unsqueeze(2).unsqueeze(2)
20 | w2 = w2.contiguous()
21 | return torch.kron(w1, w2).reshape(orig_shape)
22 |
23 |
24 | class NetworkModuleLokr(network.NetworkModule):
25 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
26 | super().__init__(net, weights)
27 |
28 | self.w1 = weights.w.get("lokr_w1")
29 | self.w1a = weights.w.get("lokr_w1_a")
30 | self.w1b = weights.w.get("lokr_w1_b")
31 | self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
32 | self.w2 = weights.w.get("lokr_w2")
33 | self.w2a = weights.w.get("lokr_w2_a")
34 | self.w2b = weights.w.get("lokr_w2_b")
35 | self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
36 | self.t2 = weights.w.get("lokr_t2")
37 |
38 | def calc_updown(self, orig_weight):
39 | if self.w1 is not None:
40 | w1 = self.w1.to(orig_weight.device)
41 | else:
42 | w1a = self.w1a.to(orig_weight.device)
43 | w1b = self.w1b.to(orig_weight.device)
44 | w1 = w1a @ w1b
45 |
46 | if self.w2 is not None:
47 | w2 = self.w2.to(orig_weight.device)
48 | elif self.t2 is None:
49 | w2a = self.w2a.to(orig_weight.device)
50 | w2b = self.w2b.to(orig_weight.device)
51 | w2 = w2a @ w2b
52 | else:
53 | t2 = self.t2.to(orig_weight.device)
54 | w2a = self.w2a.to(orig_weight.device)
55 | w2b = self.w2b.to(orig_weight.device)
56 | w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
57 |
58 | output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
59 | if len(orig_weight.shape) == 4:
60 | output_shape = orig_weight.shape
61 |
62 | updown = make_kron(output_shape, w1, w2)
63 |
64 | return self.finalize_updown(updown, orig_weight, output_shape)
65 |
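Editor's note: a worked example, not part of network_lokr.py, of make_kron() above: the LoKr delta is a Kronecker product of two small factors, so a 2x3 factor and a 4x2 factor reconstruct an 8x6 matrix.

import torch

w1 = torch.randn(2, 3)
w2 = torch.randn(4, 2)
updown = torch.kron(w1, w2).reshape(8, 6)   # output_shape = [2 * 4, 3 * 2]
print(updown.shape)                         # torch.Size([8, 6])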
--------------------------------------------------------------------------------
/scripts/A1111/network_lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import scripts.A1111.lyco_helpers as lyco_helpers
4 | import scripts.A1111.network as network
5 | from modules import devices
6 | from modules.ui import versions_html
7 |
8 | forge = "forge" in versions_html()
9 | class QkvLinear(torch.nn.Linear):
10 | pass
11 |
12 | class ModuleTypeLora(network.ModuleType):
13 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
14 | if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
15 | return NetworkModuleLora(net, weights)
16 |
17 | if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
18 | w = weights.w.copy()
19 | weights.w.clear()
20 | weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
21 |
22 | return NetworkModuleLora(net, weights)
23 |
24 | return None
25 |
26 |
27 | class NetworkModuleLora(network.NetworkModule):
28 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
29 | super().__init__(net, weights)
30 |
31 | self.up_model = self.create_module(weights.w, "lora_up.weight")
32 | self.down_model = self.create_module(weights.w, "lora_down.weight")
33 | self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
34 |
35 | self.dim = weights.w["lora_down.weight"].shape[0]
36 |
37 | def create_module(self, weights, key, none_ok=False):
38 | weight = weights.get(key)
39 |
40 | if weight is None and none_ok:
41 | return None
42 |
43 | is_linear = "Linear" in str(type(self.sd_module)) if forge else type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, QkvLinear]
44 | is_conv = "Conv2d" in str(type(self.sd_module)) if forge else type(self.sd_module) in [torch.nn.Conv2d]
45 |
46 | if is_linear:
47 | weight = weight.reshape(weight.shape[0], -1)
48 | module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
49 | elif is_conv and key == "lora_down.weight" or key == "dyn_up":
50 | if len(weight.shape) == 2:
51 | weight = weight.reshape(weight.shape[0], -1, 1, 1)
52 |
53 | if weight.shape[2] != 1 or weight.shape[3] != 1:
54 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
55 | else:
56 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
57 | elif is_conv and key == "lora_mid.weight":
58 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
59 | elif is_conv and key == "lora_up.weight" or key == "dyn_down":
60 | module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
61 | else:
62 | raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
63 |
64 | with torch.no_grad():
65 | if weight.shape != module.weight.shape:
66 | weight = weight.reshape(module.weight.shape)
67 | module.weight.copy_(weight)
68 |
69 | module.to(device=devices.cpu, dtype=devices.dtype)
70 | module.weight.requires_grad_(False)
71 |
72 | return module
73 |
74 | def calc_updown(self, orig_weight):
75 | up = self.up_model.weight.to(orig_weight.device)
76 | down = self.down_model.weight.to(orig_weight.device)
77 |
78 | output_shape = [up.size(0), down.size(1)]
79 | if self.mid_model is not None:
80 | # cp-decomposition
81 | mid = self.mid_model.weight.to(orig_weight.device)
82 | updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
83 | output_shape += mid.shape[2:]
84 | else:
85 | if len(down.shape) == 4:
86 | output_shape += down.shape[2:]
87 | updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
88 |
89 | return self.finalize_updown(updown, orig_weight, output_shape)
90 |
91 | def forward(self, x, y):
92 | self.up_model.to(device=devices.device)
93 | self.down_model.to(device=devices.device)
94 |
95 | return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
96 |
97 |
98 |
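Editor's note: a minimal sketch, not part of network_lora.py, of the conventional LoRA reconstruction used by calc_updown() when there is no mid/CP weight: delta = up @ down, scaled by alpha / dim (see calc_scale() in network.py).

import torch

out_dim, in_dim, rank, alpha = 8, 6, 2, 1.0
up = torch.randn(out_dim, rank)    # "lora_up.weight"
down = torch.randn(rank, in_dim)   # "lora_down.weight"

updown = (up @ down) * (alpha / rank)
print(updown.shape)                # torch.Size([8, 6])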
--------------------------------------------------------------------------------
/scripts/A1111/network_norm.py:
--------------------------------------------------------------------------------
1 | import scripts.A1111.network as network
2 |
3 | class ModuleTypeNorm(network.ModuleType):
4 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
5 | if all(x in weights.w for x in ["w_norm", "b_norm"]):
6 | return NetworkModuleNorm(net, weights)
7 |
8 | return None
9 |
10 |
11 | class NetworkModuleNorm(network.NetworkModule):
12 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
13 | super().__init__(net, weights)
14 |
15 | self.w_norm = weights.w.get("w_norm")
16 | self.b_norm = weights.w.get("b_norm")
17 |
18 | def calc_updown(self, orig_weight):
19 | output_shape = self.w_norm.shape
20 | updown = self.w_norm.to(orig_weight.device)
21 |
22 | if self.b_norm is not None:
23 | ex_bias = self.b_norm.to(orig_weight.device)
24 | else:
25 | ex_bias = None
26 |
27 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
28 |
--------------------------------------------------------------------------------
/scripts/A1111/network_oft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import scripts.A1111.network as network
3 | from einops import rearrange
4 |
5 |
6 | class ModuleTypeOFT(network.ModuleType):
7 | def create_module(self, net: network.Network, weights: network.NetworkWeights):
8 | if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
9 | return NetworkModuleOFT(net, weights)
10 |
11 | return None
12 |
13 | # Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
14 | # and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
15 | class NetworkModuleOFT(network.NetworkModule):
16 | def __init__(self, net: network.Network, weights: network.NetworkWeights):
17 |
18 | super().__init__(net, weights)
19 |
20 | self.lin_module = None
21 | self.org_module: list[torch.nn.Module] = [self.sd_module]
22 |
23 | self.scale = 1.0
24 | self.is_R = False
25 | self.is_boft = False
26 |
27 | # kohya-ss/New LyCORIS OFT/BOFT
28 | if "oft_blocks" in weights.w.keys():
29 | self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
30 | self.alpha = weights.w.get("alpha", None) # alpha is constraint
31 | self.dim = self.oft_blocks.shape[0] # lora dim
32 | # Old LyCORIS OFT
33 | elif "oft_diag" in weights.w.keys():
34 | self.is_R = True
35 | self.oft_blocks = weights.w["oft_diag"]
36 | # self.alpha is unused
37 | self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
38 |
39 | is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
40 | is_conv = type(self.sd_module) in [torch.nn.Conv2d]
41 | is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
42 |
43 | if is_linear:
44 | self.out_dim = self.sd_module.out_features
45 | elif is_conv:
46 | self.out_dim = self.sd_module.out_channels
47 | elif is_other_linear:
48 | self.out_dim = self.sd_module.embed_dim
49 |
50 | # LyCORIS BOFT
51 | if self.oft_blocks.dim() == 4:
52 | self.is_boft = True
53 | self.rescale = weights.w.get('rescale', None)
54 | if self.rescale is not None and not is_other_linear:
55 | self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
56 |
57 | self.num_blocks = self.dim
58 | self.block_size = self.out_dim // self.dim
59 | self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
60 | if self.is_R:
61 | self.constraint = None
62 | self.block_size = self.dim
63 | self.num_blocks = self.out_dim // self.dim
64 | elif self.is_boft:
65 | self.boft_m = self.oft_blocks.shape[0]
66 | self.num_blocks = self.oft_blocks.shape[1]
67 | self.block_size = self.oft_blocks.shape[2]
68 | self.boft_b = self.block_size
69 |
70 | def calc_updown(self, orig_weight):
71 | oft_blocks = self.oft_blocks.to(orig_weight.device)
72 | eye = torch.eye(self.block_size, device=oft_blocks.device)
73 |
74 | if not self.is_R:
75 | block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
76 | if self.constraint != 0:
77 | norm_Q = torch.norm(block_Q.flatten())
78 | new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
79 | block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
80 | oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
81 |
82 | R = oft_blocks.to(orig_weight.device)
83 |
84 | if not self.is_boft:
85 | # This errors out for MultiheadAttention, might need to be handled up-stream
86 | merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
87 | merged_weight = torch.einsum(
88 | 'k n m, k n ... -> k m ...',
89 | R,
90 | merged_weight
91 | )
92 | merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
93 | else:
94 | # TODO: determine correct value for scale
95 | scale = 1.0
96 | m = self.boft_m
97 | b = self.boft_b
98 | r_b = b // 2
99 | inp = orig_weight
100 | for i in range(m):
101 | bi = R[i] # b_num, b_size, b_size
102 | if i == 0:
103 | # Apply multiplier/scale and rescale into first weight
104 | bi = bi * scale + (1 - scale) * eye
105 | inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
106 | inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
107 | inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
108 | inp = rearrange(inp, "d b ... -> (d b) ...")
109 | inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
110 | merged_weight = inp
111 |
112 | # Rescale mechanism
113 | if self.rescale is not None:
114 | merged_weight = self.rescale.to(merged_weight) * merged_weight
115 |
116 | updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
117 | output_shape = orig_weight.shape
118 | return self.finalize_updown(updown, orig_weight, output_shape)
119 |
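Editor's note: a small numeric check, not part of network_oft.py, of the non-BOFT path in calc_updown() above: each block is made skew-symmetric and passed through a Cayley transform, which yields an orthogonal matrix that is then applied to a slice of the original weight.

import torch

block = torch.randn(4, 4) * 0.1
Q = block - block.transpose(-1, -2)        # skew-symmetric, as in calc_updown()
eye = torch.eye(4)
R = (eye + Q) @ torch.linalg.inv(eye - Q)  # Cayley transform -> orthogonal
print(torch.allclose(R @ R.T, eye, atol=1e-5))   # True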
--------------------------------------------------------------------------------
/scripts/GenParamGetter.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import scripts.mergers.components as components
3 | from scripts.mergers.mergers import smergegen, simggen
4 | from scripts.mergers.xyplot import numanager
5 | try:
6 | from scripts.mergers.pluslora import frompromptf
7 | except ImportError as e:
8 | try:
9 | import transformers
10 | transformers_version = transformers.__version__
11 | except ImportError:
12 | transformers_version = "not installed"
13 |
14 | try:
15 | import diffusers
16 | diffusers_version = diffusers.__version__
17 | except ImportError:
18 | diffusers_version = "not installed"
19 |
20 | print(
21 | f"Version error: Failed to import module.\n"
22 | f"Transformers version: {transformers_version}\n"
23 | f"Diffusers version: {diffusers_version}\n"
24 | "Please ensure compatibility between these packages."
25 | )
26 | raise e
27 |
28 | from modules import scripts, script_callbacks
29 |
30 | class GenParamGetter(scripts.Script):
31 | txt2img_gen_button = None
32 | img2img_gen_button = None
33 |
34 | events_assigned = False
35 |
36 | def title(self):
37 | return "Super Marger Generation Parameter Getter"
38 |
39 | def show(self, is_img2img):
40 | return scripts.AlwaysVisible
41 |
42 | def get_wanted_params(params,wanted):
43 | output = []
44 | for target in wanted:
45 | if target is None:
46 | output.append(params[0])
47 | continue
48 | for param in params:
49 | if hasattr(param,"label"):
50 | if param.label == target:
51 | output.append(param)
52 | return output
53 |
54 | def after_component(self, component: gr.components.Component, **_kwargs):
55 | """Find generate button"""
56 | if component.elem_id == "txt2img_generate":
57 | GenParamGetter.txt2img_gen_button = component
58 | elif component.elem_id == "img2img_generate":
59 | GenParamGetter.img2img_gen_button = component
60 |
61 | def get_components_by_ids(root: gr.Blocks, ids: list[int]):
62 | components: list[gr.Blocks] = []
63 |
64 | if root._id in ids:
65 | components.append(root)
66 | ids = [_id for _id in ids if _id != root._id]
67 |
68 | if hasattr(root,"children"):
69 | for block in root.children:
70 | components.extend(GenParamGetter.get_components_by_ids(block, ids))
71 | return components
72 |
73 | def compare_components_with_ids(components: list[gr.Blocks], ids: list[int]):
74 |
75 | try:
76 | return len(components) == len(ids) and all(component._id == _id for component, _id in zip(components, ids))
77 | except:
78 | return False
79 |
80 | def get_params_components(demo: gr.Blocks, app):
81 | for _id, _is_txt2img in zip([GenParamGetter.txt2img_gen_button._id, GenParamGetter.img2img_gen_button._id], [True, False]):
82 | if hasattr(demo,"dependencies"):
83 | dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]]
84 | g4 = False
85 | else:
86 | dependencies: list[dict] = [x for x in demo.config["dependencies"] if x["targets"][0][1] == "click" and _id in x["targets"][0]]
87 | g4 = True
88 |
89 | dependency: dict = None
90 |
91 | for d in dependencies:
92 | if len(d["outputs"]) == 4:
93 | dependency = d
94 |
95 | if g4:
96 | params = [demo.blocks[x] for x in dependency['inputs']]
97 | if _is_txt2img:
98 | components.paramsnames = [x.label if hasattr(x,"label") else "None" for x in params]
99 |
100 | if _is_txt2img:
101 | components.txt2img_params = params
102 | else:
103 | components.img2img_params = params
104 | else:
105 | params = [params for params in demo.fns if GenParamGetter.compare_components_with_ids(params.inputs, dependency["inputs"])]
106 |
107 | if _is_txt2img:
108 | components.paramsnames = [x.label if hasattr(x,"label") else "None" for x in params[0].inputs]
109 |
110 | if _is_txt2img:
111 | components.txt2img_params = params[0].inputs
112 | else:
113 | components.img2img_params = params[0].inputs
114 |
115 |
116 | if not GenParamGetter.events_assigned:
117 | with demo:
118 | components.merge.click(
119 | fn=smergegen,
120 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dfalse,*components.txt2img_params],
121 | outputs=[components.submit_result,components.currentmodel]
122 | )
123 |
124 | components.mergeandgen.click(
125 | fn=smergegen,
126 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dtrue,*components.txt2img_params],
127 | outputs=[components.submit_result,components.currentmodel,*components.imagegal]
128 | )
129 |
130 | components.gen.click(
131 | fn=simggen,
132 | inputs=[*components.genparams,*components.hiresfix,components.currentmodel,components.id_sets,gr.Textbox(value="No id",visible=False),*components.txt2img_params],
133 | outputs=[*components.imagegal],
134 | )
135 |
136 | components.merge2.click(
137 | fn=smergegen,
138 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dfalse,*components.txt2img_params],
139 | outputs=[components.submit_result,components.currentmodel]
140 | )
141 |
142 | components.mergeandgen2.click(
143 | fn=smergegen,
144 | inputs=[*components.msettings,components.esettings1,*components.genparams,*components.hiresfix,*components.lucks,components.currentmodel,components.dtrue,*components.txt2img_params],
145 | outputs=[components.submit_result,components.currentmodel,*components.imagegal]
146 | )
147 |
148 | components.gen2.click(
149 | fn=simggen,
150 | inputs=[*components.genparams,*components.hiresfix,components.currentmodel,components.id_sets,gr.Textbox(value="No id",visible=False),*components.txt2img_params],
151 | outputs=[*components.imagegal],
152 | )
153 |
154 |
155 | components.s_reserve.click(
156 | fn=numanager,
157 | inputs=[gr.Textbox(value="reserve",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params],
158 | outputs=[components.numaframe]
159 | )
160 |
161 | components.s_reserve1.click(
162 | fn=numanager,
163 | inputs=[gr.Textbox(value="reserve",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params],
164 | outputs=[components.numaframe]
165 | )
166 |
167 | components.gengrid.click(
168 | fn=numanager,
169 | inputs=[gr.Textbox(value="normal",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params],
170 | outputs=[components.submit_result,components.currentmodel,*components.imagegal],
171 | )
172 |
173 | components.s_startreserve.click(
174 | fn=numanager,
175 | inputs=[gr.Textbox(value=" ",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params],
176 | outputs=[components.submit_result,components.currentmodel,*components.imagegal],
177 | )
178 |
179 | components.rand_merge.click(
180 | fn=numanager,
181 | inputs=[gr.Textbox(value="random",visible=False),*components.xysettings,*components.msettings,*components.genparams,*components.hiresfix,*components.lucks,*components.txt2img_params],
182 | outputs=[components.submit_result,components.currentmodel,*components.imagegal],
183 | )
184 |
185 | components.frompromptb.click(
186 | fn=frompromptf,
187 | inputs=[*components.txt2img_params],
188 | outputs=components.sml_loranames,
189 | )
190 | GenParamGetter.events_assigned = True
191 |
192 | if __package__ == "GenParamGetter":
193 | script_callbacks.on_app_started(GenParamGetter.get_params_components)
194 |
--------------------------------------------------------------------------------
/scripts/Roboto-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hako-mikan/sd-webui-supermerger/352c0047ec852fd7000835a861bf8cc7a907e04b/scripts/Roboto-Regular.ttf
--------------------------------------------------------------------------------
/scripts/kohyas/extract_lora_from_models.py:
--------------------------------------------------------------------------------
1 | # extract approximating LoRA by svd from two SD models
2 | # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3 | # Thanks to cloneofsimo!
4 |
5 | import argparse
6 | import json
7 | import os
8 | import time
9 | import torch
10 | from safetensors.torch import load_file, save_file
11 | from tqdm import tqdm
12 | from scripts.kohyas import sai_model_spec,model_util,sdxl_model_util,lora
13 |
14 |
15 | CLAMP_QUANTILE = 0.99
16 | MIN_DIFF = 1e-1
17 |
18 |
19 | def save_to_file(file_name, model, state_dict, dtype):
20 | if dtype is not None:
21 | for key in list(state_dict.keys()):
22 | if type(state_dict[key]) == torch.Tensor:
23 | state_dict[key] = state_dict[key].to(dtype)
24 |
25 | if os.path.splitext(file_name)[1] == ".safetensors":
26 | save_file(model, file_name)
27 | else:
28 | torch.save(model, file_name)
29 |
30 |
31 | def svd(args):
32 | def str_to_dtype(p):
33 | if p == "float":
34 | return torch.float
35 | if p == "fp16":
36 | return torch.float16
37 | if p == "bf16":
38 | return torch.bfloat16
39 | return None
40 |
41 | assert args.v2 != args.sdxl or (
42 | not args.v2 and not args.sdxl
43 | ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
44 | if args.v_parameterization is None:
45 | args.v_parameterization = args.v2
46 |
47 | save_dtype = str_to_dtype(args.save_precision)
48 |
49 | # load models
50 | if not args.sdxl:
51 | print(f"loading original SD model : {args.model_org}")
52 | text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org, del_ids=True)
53 | text_encoders_o = [text_encoder_o]
54 | print(f"loading tuned SD model : {args.model_tuned}")
55 | text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned, del_ids=True)
56 | text_encoders_t = [text_encoder_t]
57 | model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
58 | else:
59 | print(f"loading original SDXL model : {args.model_org}")
60 | text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
61 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
62 | )
63 | text_encoders_o = [text_encoder_o1, text_encoder_o2]
64 | print(f"loading original SDXL model : {args.model_tuned}")
65 | text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
66 | sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
67 | )
68 | text_encoders_t = [text_encoder_t1, text_encoder_t2]
69 | model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
70 |
71 | # create LoRA network to extract weights: Use dim (rank) as alpha
72 | if args.conv_dim is None:
73 | kwargs = {}
74 | else:
75 | kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
76 |
77 | lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
78 | lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
79 | assert len(lora_network_o.text_encoder_loras) == len(
80 | lora_network_t.text_encoder_loras
81 | ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
82 |
83 | # get diffs
84 | diffs = {}
85 | text_encoder_different = False
86 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
87 | lora_name = lora_o.lora_name
88 | module_o = lora_o.org_module
89 | module_t = lora_t.org_module
90 | diff = args.alpha * module_t.weight - args.beta * module_o.weight
91 |
92 | # Text Encoder might be same
93 | if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
94 | text_encoder_different = True
95 | print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
96 |
97 | diff = diff.float()
98 | diffs[lora_name] = diff
99 |
100 | if not text_encoder_different:
101 | print("Text encoder is same. Extract U-Net only.")
102 | lora_network_o.text_encoder_loras = []
103 | diffs = {}
104 |
105 | for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
106 | lora_name = lora_o.lora_name
107 | module_o = lora_o.org_module
108 | module_t = lora_t.org_module
109 | diff = args.alpha * module_t.weight - args.beta * module_o.weight
110 | diff = diff.float()
111 |
112 | if args.device:
113 | diff = diff.to(args.device)
114 |
115 | diffs[lora_name] = diff
116 |
117 | # make LoRA with svd
118 | print("calculating by svd")
119 | lora_weights = {}
120 | with torch.no_grad():
121 | for lora_name, mat in tqdm(list(diffs.items())):
122 | # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
123 | conv2d = len(mat.size()) == 4
124 | kernel_size = None if not conv2d else mat.size()[2:4]
125 | conv2d_3x3 = conv2d and kernel_size != (1, 1)
126 |
127 | rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
128 | out_dim, in_dim = mat.size()[0:2]
129 |
130 | if args.device:
131 | mat = mat.to(args.device)
132 |
133 | # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
134 | rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
135 |
136 | if conv2d:
137 | if conv2d_3x3:
138 | mat = mat.flatten(start_dim=1)
139 | else:
140 | mat = mat.squeeze()
141 |
142 | U, S, Vh = torch.linalg.svd(mat)
143 |
144 | U = U[:, :rank]
145 | S = S[:rank]
146 | U = U @ torch.diag(S)
147 |
148 | Vh = Vh[:rank, :]
149 |
150 | dist = torch.cat([U.flatten(), Vh.flatten()])
151 | hi_val = torch.quantile(dist, CLAMP_QUANTILE)
152 | low_val = -hi_val
153 |
154 | U = U.clamp(low_val, hi_val)
155 | Vh = Vh.clamp(low_val, hi_val)
156 |
157 | if conv2d:
158 | U = U.reshape(out_dim, rank, 1, 1)
159 | Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
160 |
161 | U = U.to("cpu").contiguous()
162 | Vh = Vh.to("cpu").contiguous()
163 |
164 | lora_weights[lora_name] = (U, Vh)
165 |
166 | # make state dict for LoRA
167 | lora_sd = {}
168 | for lora_name, (up_weight, down_weight) in lora_weights.items():
169 | lora_sd[lora_name + ".lora_up.weight"] = up_weight
170 | lora_sd[lora_name + ".lora_down.weight"] = down_weight
171 | lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
172 |
173 | # load state dict to LoRA and save it
174 | lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
175 | lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
176 |
177 | info = lora_network_save.load_state_dict(lora_sd)
178 | print(f"Loading extracted LoRA weights: {info}")
179 |
180 | dir_name = os.path.dirname(args.save_to)
181 | if dir_name and not os.path.exists(dir_name):
182 | os.makedirs(dir_name, exist_ok=True)
183 |
184 | # minimum metadata
185 | net_kwargs = {}
186 | if args.conv_dim is not None:
187 | net_kwargs["conv_dim"] = args.conv_dim
188 | net_kwargs["conv_alpha"] = args.conv_dim
189 |
190 | metadata = {
191 | "ss_v2": str(args.v2),
192 | "ss_base_model_version": model_version,
193 | "ss_network_module": "networks.lora",
194 | "ss_network_dim": str(args.dim),
195 | "ss_network_alpha": str(args.dim),
196 | "ss_network_args": json.dumps(net_kwargs),
197 | }
198 |
199 | if not args.no_metadata:
200 | title = os.path.splitext(os.path.basename(args.save_to))[0]
201 | sai_metadata = sai_model_spec.build_metadata(
202 | None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
203 | )
204 | metadata.update(sai_metadata)
205 |
206 | lora_network_save.save_weights(args.save_to, save_dtype, metadata)
207 | return f"LoRA weights are saved to: {args.save_to}"
208 |
209 |
210 |
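Editor's note: a toy version, separate from extract_lora_from_models.py, of the SVD step in svd() above: a weight difference is approximated by a rank-limited product whose two factors become the lora_up / lora_down weights.

import torch

diff = torch.randn(16, 12)                # stands in for a module's weight diff
rank = 4
U, S, Vh = torch.linalg.svd(diff)
up = U[:, :rank] @ torch.diag(S[:rank])   # -> "lora_up.weight"
down = Vh[:rank, :]                       # -> "lora_down.weight"
print((diff - up @ down).norm() / diff.norm())   # relative error of the rank-4 fit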
--------------------------------------------------------------------------------
/scripts/kohyas/ipex/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import contextlib
4 | import torch
5 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6 | from .hijacks import ipex_hijacks
7 | from .attention import attention_init
8 |
9 | # pylint: disable=protected-access, missing-function-docstring, line-too-long
10 |
11 | def ipex_init(): # pylint: disable=too-many-statements
12 | try:
13 | #Replace cuda with xpu:
14 | torch.cuda.current_device = torch.xpu.current_device
15 | torch.cuda.current_stream = torch.xpu.current_stream
16 | torch.cuda.device = torch.xpu.device
17 | torch.cuda.device_count = torch.xpu.device_count
18 | torch.cuda.device_of = torch.xpu.device_of
19 | torch.cuda.get_device_name = torch.xpu.get_device_name
20 | torch.cuda.get_device_properties = torch.xpu.get_device_properties
21 | torch.cuda.init = torch.xpu.init
22 | torch.cuda.is_available = torch.xpu.is_available
23 | torch.cuda.is_initialized = torch.xpu.is_initialized
24 | torch.cuda.is_current_stream_capturing = lambda: False
25 | torch.cuda.set_device = torch.xpu.set_device
26 | torch.cuda.stream = torch.xpu.stream
27 | torch.cuda.synchronize = torch.xpu.synchronize
28 | torch.cuda.Event = torch.xpu.Event
29 | torch.cuda.Stream = torch.xpu.Stream
30 | torch.cuda.FloatTensor = torch.xpu.FloatTensor
31 | torch.Tensor.cuda = torch.Tensor.xpu
32 | torch.Tensor.is_cuda = torch.Tensor.is_xpu
33 | torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
34 | torch.cuda._initialized = torch.xpu.lazy_init._initialized
35 | torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
36 | torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
37 | torch.cuda._tls = torch.xpu.lazy_init._tls
38 | torch.cuda.threading = torch.xpu.lazy_init.threading
39 | torch.cuda.traceback = torch.xpu.lazy_init.traceback
40 | torch.cuda.Optional = torch.xpu.Optional
41 | torch.cuda.__cached__ = torch.xpu.__cached__
42 | torch.cuda.__loader__ = torch.xpu.__loader__
43 | torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
44 | torch.cuda.Tuple = torch.xpu.Tuple
45 | torch.cuda.streams = torch.xpu.streams
46 | torch.cuda._lazy_new = torch.xpu._lazy_new
47 | torch.cuda.FloatStorage = torch.xpu.FloatStorage
48 | torch.cuda.Any = torch.xpu.Any
49 | torch.cuda.__doc__ = torch.xpu.__doc__
50 | torch.cuda.default_generators = torch.xpu.default_generators
51 | torch.cuda.HalfTensor = torch.xpu.HalfTensor
52 | torch.cuda._get_device_index = torch.xpu._get_device_index
53 | torch.cuda.__path__ = torch.xpu.__path__
54 | torch.cuda.Device = torch.xpu.Device
55 | torch.cuda.IntTensor = torch.xpu.IntTensor
56 | torch.cuda.ByteStorage = torch.xpu.ByteStorage
57 | torch.cuda.set_stream = torch.xpu.set_stream
58 | torch.cuda.BoolStorage = torch.xpu.BoolStorage
59 | torch.cuda.os = torch.xpu.os
60 | torch.cuda.torch = torch.xpu.torch
61 | torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
62 | torch.cuda.Union = torch.xpu.Union
63 | torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
64 | torch.cuda.ShortTensor = torch.xpu.ShortTensor
65 | torch.cuda.LongTensor = torch.xpu.LongTensor
66 | torch.cuda.IntStorage = torch.xpu.IntStorage
67 | torch.cuda.LongStorage = torch.xpu.LongStorage
68 | torch.cuda.__annotations__ = torch.xpu.__annotations__
69 | torch.cuda.__package__ = torch.xpu.__package__
70 | torch.cuda.__builtins__ = torch.xpu.__builtins__
71 | torch.cuda.CharTensor = torch.xpu.CharTensor
72 | torch.cuda.List = torch.xpu.List
73 | torch.cuda._lazy_init = torch.xpu._lazy_init
74 | torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
75 | torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
76 | torch.cuda.ByteTensor = torch.xpu.ByteTensor
77 | torch.cuda.StreamContext = torch.xpu.StreamContext
78 | torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
79 | torch.cuda.ShortStorage = torch.xpu.ShortStorage
80 | torch.cuda._lazy_call = torch.xpu._lazy_call
81 | torch.cuda.HalfStorage = torch.xpu.HalfStorage
82 | torch.cuda.random = torch.xpu.random
83 | torch.cuda._device = torch.xpu._device
84 | torch.cuda.classproperty = torch.xpu.classproperty
85 | torch.cuda.__name__ = torch.xpu.__name__
86 | torch.cuda._device_t = torch.xpu._device_t
87 | torch.cuda.warnings = torch.xpu.warnings
88 | torch.cuda.__spec__ = torch.xpu.__spec__
89 | torch.cuda.BoolTensor = torch.xpu.BoolTensor
90 | torch.cuda.CharStorage = torch.xpu.CharStorage
91 | torch.cuda.__file__ = torch.xpu.__file__
92 | torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
93 | #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
94 |
95 | #Memory:
96 | torch.cuda.memory = torch.xpu.memory
97 | if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
98 | torch.xpu.empty_cache = lambda: None
99 | torch.cuda.empty_cache = torch.xpu.empty_cache
100 | torch.cuda.memory_stats = torch.xpu.memory_stats
101 | torch.cuda.memory_summary = torch.xpu.memory_summary
102 | torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
103 | torch.cuda.memory_allocated = torch.xpu.memory_allocated
104 | torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
105 | torch.cuda.memory_reserved = torch.xpu.memory_reserved
106 | torch.cuda.memory_cached = torch.xpu.memory_reserved
107 | torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
108 | torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
109 | torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
110 | torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
111 | torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
112 | torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
113 | torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
114 |
115 | #RNG:
116 | torch.cuda.get_rng_state = torch.xpu.get_rng_state
117 | torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
118 | torch.cuda.set_rng_state = torch.xpu.set_rng_state
119 | torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
120 | torch.cuda.manual_seed = torch.xpu.manual_seed
121 | torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
122 | torch.cuda.seed = torch.xpu.seed
123 | torch.cuda.seed_all = torch.xpu.seed_all
124 | torch.cuda.initial_seed = torch.xpu.initial_seed
125 |
126 | #AMP:
127 | torch.cuda.amp = torch.xpu.amp
128 | if not hasattr(torch.cuda.amp, "common"):
129 | torch.cuda.amp.common = contextlib.nullcontext()
130 | torch.cuda.amp.common.amp_definitely_not_available = lambda: False
131 | try:
132 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
133 | except Exception: # pylint: disable=broad-exception-caught
134 | try:
135 | from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
136 | gradscaler_init()
137 | torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
138 | except Exception: # pylint: disable=broad-exception-caught
139 | torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
140 |
141 | #C
142 | torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
143 | ipex._C._DeviceProperties.major = 2023
144 | ipex._C._DeviceProperties.minor = 2
145 |
146 | #Fix functions with ipex:
147 | torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
148 | torch._utils._get_available_device_type = lambda: "xpu"
149 | torch.has_cuda = True
150 | torch.cuda.has_half = True
151 | torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
152 | torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
153 | torch.version.cuda = "11.7"
154 | torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
155 | torch.cuda.get_device_properties.major = 11
156 | torch.cuda.get_device_properties.minor = 7
157 | torch.cuda.ipc_collect = lambda *args, **kwargs: None
158 | torch.cuda.utilization = lambda *args, **kwargs: 0
159 | if hasattr(torch.xpu, 'getDeviceIdListForCard'):
160 | torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
161 | torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
162 | else:
163 | torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
164 | torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
165 |
166 | ipex_hijacks()
167 | attention_init()
168 | try:
169 | from .diffusers import ipex_diffusers
170 | ipex_diffusers()
171 | except Exception: # pylint: disable=broad-exception-caught
172 | pass
173 | except Exception as e:
174 | return False, e
175 | return True, None
176 |
--------------------------------------------------------------------------------
/scripts/kohyas/ipex/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
3 |
4 | # pylint: disable=protected-access, missing-function-docstring, line-too-long
5 |
6 | original_torch_bmm = torch.bmm
7 | def torch_bmm(input, mat2, *, out=None):
8 | if input.dtype != mat2.dtype:
9 | mat2 = mat2.to(input.dtype)
10 |
11 | #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
12 | batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
13 | block_multiply = input.element_size()
14 | slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
15 | block_size = batch_size_attention * slice_block_size
16 |
17 | split_slice_size = batch_size_attention
18 | if block_size > 4:
19 | do_split = True
20 | #Find something divisible with the batch_size_attention
21 | while (split_slice_size * slice_block_size) > 4:
22 | split_slice_size = split_slice_size // 2
23 | if split_slice_size <= 1:
24 | split_slice_size = 1
25 | break
26 | else:
27 | do_split = False
28 |
29 | split_2_slice_size = input_tokens
30 | if split_slice_size * slice_block_size > 4:
31 | slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
32 | do_split_2 = True
33 | #Find something divisible with the input_tokens
34 | while (split_2_slice_size * slice_block_size2) > 4:
35 | split_2_slice_size = split_2_slice_size // 2
36 | if split_2_slice_size <= 1:
37 | split_2_slice_size = 1
38 | break
39 | else:
40 | do_split_2 = False
41 |
42 | if do_split:
43 | hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
44 | for i in range(batch_size_attention // split_slice_size):
45 | start_idx = i * split_slice_size
46 | end_idx = (i + 1) * split_slice_size
47 | if do_split_2:
48 | for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
49 | start_idx_2 = i2 * split_2_slice_size
50 | end_idx_2 = (i2 + 1) * split_2_slice_size
51 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
52 | input[start_idx:end_idx, start_idx_2:end_idx_2],
53 | mat2[start_idx:end_idx, start_idx_2:end_idx_2],
54 | out=out
55 | )
56 | else:
57 | hidden_states[start_idx:end_idx] = original_torch_bmm(
58 | input[start_idx:end_idx],
59 | mat2[start_idx:end_idx],
60 | out=out
61 | )
62 | else:
63 | return original_torch_bmm(input, mat2, out=out)
64 | return hidden_states
65 |
66 | original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
67 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
68 | #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
69 | if len(query.shape) == 3:
70 | batch_size_attention, query_tokens, shape_four = query.shape
71 | shape_one = 1
72 | no_shape_one = True
73 | else:
74 | shape_one, batch_size_attention, query_tokens, shape_four = query.shape
75 | no_shape_one = False
76 |
77 | block_multiply = query.element_size()
78 | slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
79 | block_size = batch_size_attention * slice_block_size
80 |
81 | split_slice_size = batch_size_attention
82 | if block_size > 4:
83 | do_split = True
84 | #Find something divisible with the batch_size_attention
85 | while (split_slice_size * slice_block_size) > 4:
86 | split_slice_size = split_slice_size // 2
87 | if split_slice_size <= 1:
88 | split_slice_size = 1
89 | break
90 | else:
91 | do_split = False
92 |
93 | split_2_slice_size = query_tokens
94 | if split_slice_size * slice_block_size > 4:
95 | slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
96 | do_split_2 = True
97 | #Find something divisible with the query_tokens
98 | while (split_2_slice_size * slice_block_size2) > 4:
99 | split_2_slice_size = split_2_slice_size // 2
100 | if split_2_slice_size <= 1:
101 | split_2_slice_size = 1
102 | break
103 | else:
104 | do_split_2 = False
105 |
106 | if do_split:
107 | hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
108 | for i in range(batch_size_attention // split_slice_size):
109 | start_idx = i * split_slice_size
110 | end_idx = (i + 1) * split_slice_size
111 | if do_split_2:
112 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
113 | start_idx_2 = i2 * split_2_slice_size
114 | end_idx_2 = (i2 + 1) * split_2_slice_size
115 | if no_shape_one:
116 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
117 | query[start_idx:end_idx, start_idx_2:end_idx_2],
118 | key[start_idx:end_idx, start_idx_2:end_idx_2],
119 | value[start_idx:end_idx, start_idx_2:end_idx_2],
120 | attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
121 | dropout_p=dropout_p, is_causal=is_causal
122 | )
123 | else:
124 | hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
125 | query[:, start_idx:end_idx, start_idx_2:end_idx_2],
126 | key[:, start_idx:end_idx, start_idx_2:end_idx_2],
127 | value[:, start_idx:end_idx, start_idx_2:end_idx_2],
128 | attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
129 | dropout_p=dropout_p, is_causal=is_causal
130 | )
131 | else:
132 | if no_shape_one:
133 | hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
134 | query[start_idx:end_idx],
135 | key[start_idx:end_idx],
136 | value[start_idx:end_idx],
137 | attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
138 | dropout_p=dropout_p, is_causal=is_causal
139 | )
140 | else:
141 | hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
142 | query[:, start_idx:end_idx],
143 | key[:, start_idx:end_idx],
144 | value[:, start_idx:end_idx],
145 | attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
146 | dropout_p=dropout_p, is_causal=is_causal
147 | )
148 | else:
149 | return original_scaled_dot_product_attention(
150 | query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
151 | )
152 | return hidden_states
153 |
154 | def attention_init():
155 | #ARC GPUs can't allocate more than 4GB to a single block:
156 | torch.bmm = torch_bmm
157 | torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
158 |
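Editor's note: a quick equivalence check, not part of attention.py, of the slicing idea behind torch_bmm() above: computing a batched matmul in chunks along the batch dimension gives the same result as one large call, which is why the hijack can split the work to stay under the 4GB-per-allocation limit.

import torch

a = torch.randn(8, 16, 32)
b = torch.randn(8, 32, 24)
chunks = [torch.bmm(a[i:i + 2], b[i:i + 2]) for i in range(0, 8, 2)]
print(torch.allclose(torch.cat(chunks, dim=0), torch.bmm(a, b), atol=1e-5))   # True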
--------------------------------------------------------------------------------
/scripts/kohyas/ipex/diffusers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
3 | import diffusers #0.21.1 # pylint: disable=import-error
4 | from diffusers.models.attention_processor import Attention
5 |
6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long
7 |
8 | class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
9 | r"""
10 | Processor for implementing sliced attention.
11 |
12 | Args:
13 | slice_size (`int`, *optional*):
14 | The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
15 | `attention_head_dim` must be a multiple of the `slice_size`.
16 | """
17 |
18 | def __init__(self, slice_size):
19 | self.slice_size = slice_size
20 |
21 | def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
22 | residual = hidden_states
23 |
24 | input_ndim = hidden_states.ndim
25 |
26 | if input_ndim == 4:
27 | batch_size, channel, height, width = hidden_states.shape
28 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
29 |
30 | batch_size, sequence_length, _ = (
31 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
32 | )
33 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
34 |
35 | if attn.group_norm is not None:
36 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
37 |
38 | query = attn.to_q(hidden_states)
39 | dim = query.shape[-1]
40 | query = attn.head_to_batch_dim(query)
41 |
42 | if encoder_hidden_states is None:
43 | encoder_hidden_states = hidden_states
44 | elif attn.norm_cross:
45 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
46 |
47 | key = attn.to_k(encoder_hidden_states)
48 | value = attn.to_v(encoder_hidden_states)
49 | key = attn.head_to_batch_dim(key)
50 | value = attn.head_to_batch_dim(value)
51 |
52 | batch_size_attention, query_tokens, shape_three = query.shape
53 | hidden_states = torch.zeros(
54 | (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
55 | )
56 |
57 | #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
58 | block_multiply = query.element_size()
59 | slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
60 | block_size = query_tokens * slice_block_size
61 | split_2_slice_size = query_tokens
62 | if block_size > 4:
63 | do_split_2 = True
64 | #Find something divisible with the query_tokens
65 | while (split_2_slice_size * slice_block_size) > 4:
66 | split_2_slice_size = split_2_slice_size // 2
67 | if split_2_slice_size <= 1:
68 | split_2_slice_size = 1
69 | break
70 | else:
71 | do_split_2 = False
72 |
73 | for i in range(batch_size_attention // self.slice_size):
74 | start_idx = i * self.slice_size
75 | end_idx = (i + 1) * self.slice_size
76 |
77 | if do_split_2:
78 | for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
79 | start_idx_2 = i2 * split_2_slice_size
80 | end_idx_2 = (i2 + 1) * split_2_slice_size
81 |
82 | query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
83 | key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
84 | attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
85 |
86 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
87 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
88 |
89 | hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
90 | else:
91 | query_slice = query[start_idx:end_idx]
92 | key_slice = key[start_idx:end_idx]
93 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
94 |
95 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
96 |
97 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
98 |
99 | hidden_states[start_idx:end_idx] = attn_slice
100 |
101 | hidden_states = attn.batch_to_head_dim(hidden_states)
102 |
103 | # linear proj
104 | hidden_states = attn.to_out[0](hidden_states)
105 | # dropout
106 | hidden_states = attn.to_out[1](hidden_states)
107 |
108 | if input_ndim == 4:
109 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
110 |
111 | if attn.residual_connection:
112 | hidden_states = hidden_states + residual
113 |
114 | hidden_states = hidden_states / attn.rescale_output_factor
115 |
116 | return hidden_states
117 |
118 | def ipex_diffusers():
119 | #ARC GPUs can't allocate more than 4GB to a single block:
120 | diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
121 |
--------------------------------------------------------------------------------
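Note: ipex_diffusers() above swaps diffusers' SlicedAttnProcessor for the Arc-safe version defined in this file. A minimal sketch, assuming diffusers and an XPU build of intel_extension_for_pytorch are installed; the pipeline calls in the comments are illustrative only:

    from scripts.kohyas.ipex.diffusers import ipex_diffusers

    ipex_diffusers()  # replace diffusers.models.attention_processor.SlicedAttnProcessor

    # Any pipeline that later enables attention slicing picks up the patched processor, e.g.:
    #   pipe = StableDiffusionPipeline.from_pretrained("some/model").to("xpu")
    #   pipe.enable_attention_slicing()
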
/scripts/kohyas/ipex/gradscaler.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4 | import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5 |
6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long
7 |
8 | OptState = ipex.cpu.autocast._grad_scaler.OptState
9 | _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
10 | _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
11 |
12 | def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
13 | per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
14 | per_device_found_inf = _MultiDeviceReplicator(found_inf)
15 |
16 | # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
17 | # There could be hundreds of grads, so we'd like to iterate through them just once.
18 | # However, we don't know their devices or dtypes in advance.
19 |
20 | # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
21 | # Google says mypy struggles with defaultdicts type annotations.
22 | per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
23 | # sync grad to master weight
24 | if hasattr(optimizer, "sync_grad"):
25 | optimizer.sync_grad()
26 | with torch.no_grad():
27 | for group in optimizer.param_groups:
28 | for param in group["params"]:
29 | if param.grad is None:
30 | continue
31 | if (not allow_fp16) and param.grad.dtype == torch.float16:
32 | raise ValueError("Attempting to unscale FP16 gradients.")
33 | if param.grad.is_sparse:
34 | # is_coalesced() == False means the sparse grad has values with duplicate indices.
35 | # coalesce() deduplicates indices and adds all values that have the same index.
36 | # For scaled fp16 values, there's a good chance coalescing will cause overflow,
37 | # so we should check the coalesced _values().
38 | if param.grad.dtype is torch.float16:
39 | param.grad = param.grad.coalesce()
40 | to_unscale = param.grad._values()
41 | else:
42 | to_unscale = param.grad
43 |
44 |                 # TODO: is there a way to split by device and dtype without appending in the inner loop?
45 | to_unscale = to_unscale.to("cpu")
46 | per_device_and_dtype_grads[to_unscale.device][
47 | to_unscale.dtype
48 | ].append(to_unscale)
49 |
50 | for _, per_dtype_grads in per_device_and_dtype_grads.items():
51 | for grads in per_dtype_grads.values():
52 | core._amp_foreach_non_finite_check_and_unscale_(
53 | grads,
54 | per_device_found_inf.get("cpu"),
55 | per_device_inv_scale.get("cpu"),
56 | )
57 |
58 | return per_device_found_inf._per_device_tensors
59 |
60 | def unscale_(self, optimizer):
61 | """
62 | Divides ("unscales") the optimizer's gradient tensors by the scale factor.
63 | :meth:`unscale_` is optional, serving cases where you need to
64 | :ref:`modify or inspect gradients`
65 | between the backward pass(es) and :meth:`step`.
66 | If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
67 | Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
68 | ...
69 | scaler.scale(loss).backward()
70 | scaler.unscale_(optimizer)
71 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
72 | scaler.step(optimizer)
73 | scaler.update()
74 | Args:
75 | optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
76 | .. warning::
77 | :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
78 | and only after all gradients for that optimizer's assigned parameters have been accumulated.
79 | Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
80 | .. warning::
81 | :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
82 | """
83 | if not self._enabled:
84 | return
85 |
86 | self._check_scale_growth_tracker("unscale_")
87 |
88 | optimizer_state = self._per_optimizer_states[id(optimizer)]
89 |
90 | if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
91 | raise RuntimeError(
92 | "unscale_() has already been called on this optimizer since the last update()."
93 | )
94 | elif optimizer_state["stage"] is OptState.STEPPED:
95 | raise RuntimeError("unscale_() is being called after step().")
96 |
97 | # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
98 | assert self._scale is not None
99 | inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
100 | found_inf = torch.full(
101 | (1,), 0.0, dtype=torch.float32, device=self._scale.device
102 | )
103 |
104 | optimizer_state["found_inf_per_device"] = self._unscale_grads_(
105 | optimizer, inv_scale, found_inf, False
106 | )
107 | optimizer_state["stage"] = OptState.UNSCALED
108 |
109 | def update(self, new_scale=None):
110 | """
111 | Updates the scale factor.
112 | If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
113 | to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
114 | the scale is multiplied by ``growth_factor`` to increase it.
115 | Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
116 | used directly, it's used to fill GradScaler's internal scale tensor. So if
117 | ``new_scale`` was a tensor, later in-place changes to that tensor will not further
118 | affect the scale GradScaler uses internally.)
119 | Args:
120 | new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
121 | .. warning::
122 | :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
123 | been invoked for all optimizers used this iteration.
124 | """
125 | if not self._enabled:
126 | return
127 |
128 | _scale, _growth_tracker = self._check_scale_growth_tracker("update")
129 |
130 | if new_scale is not None:
131 | # Accept a new user-defined scale.
132 | if isinstance(new_scale, float):
133 | self._scale.fill_(new_scale) # type: ignore[union-attr]
134 | else:
135 | reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
136 | assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
137 | assert new_scale.numel() == 1, reason
138 | assert new_scale.requires_grad is False, reason
139 | self._scale.copy_(new_scale) # type: ignore[union-attr]
140 | else:
141 | # Consume shared inf/nan data collected from optimizers to update the scale.
142 | # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
143 | found_infs = [
144 | found_inf.to(device="cpu", non_blocking=True)
145 | for state in self._per_optimizer_states.values()
146 | for found_inf in state["found_inf_per_device"].values()
147 | ]
148 |
149 | assert len(found_infs) > 0, "No inf checks were recorded prior to update."
150 |
151 | found_inf_combined = found_infs[0]
152 | if len(found_infs) > 1:
153 | for i in range(1, len(found_infs)):
154 | found_inf_combined += found_infs[i]
155 |
156 | to_device = _scale.device
157 | _scale = _scale.to("cpu")
158 | _growth_tracker = _growth_tracker.to("cpu")
159 |
160 | core._amp_update_scale_(
161 | _scale,
162 | _growth_tracker,
163 | found_inf_combined,
164 | self._growth_factor,
165 | self._backoff_factor,
166 | self._growth_interval,
167 | )
168 |
169 | _scale = _scale.to(to_device)
170 | _growth_tracker = _growth_tracker.to(to_device)
171 | # To prepare for next iteration, clear the data collected from optimizers this iteration.
172 | self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
173 |
174 | def gradscaler_init():
175 | torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
176 | torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
177 | torch.xpu.amp.GradScaler.unscale_ = unscale_
178 | torch.xpu.amp.GradScaler.update = update
179 | return torch.xpu.amp.GradScaler
180 |
--------------------------------------------------------------------------------
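Note: gradscaler_init() rebinds torch.xpu.amp.GradScaler to IPEX's CPU-autocast scaler with the patched _unscale_grads_, unscale_ and update defined above. A minimal training-step sketch under that assumption; model, optimizer and batch are placeholders:

    import torch
    from scripts.kohyas.ipex.gradscaler import gradscaler_init

    GradScaler = gradscaler_init()
    scaler = GradScaler()

    # inside the training loop:
    #   with torch.autocast("xpu", dtype=torch.float16):
    #       loss = model(batch).mean()
    #   scaler.scale(loss).backward()
    #   scaler.unscale_(optimizer)                           # optional, e.g. before grad clipping
    #   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    #   scaler.step(optimizer)
    #   scaler.update()
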
/scripts/kohyas/ipex/hijacks.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import importlib
3 | import torch
4 | import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
5 |
6 | # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
7 |
8 | class CondFunc: # pylint: disable=missing-class-docstring
9 | def __new__(cls, orig_func, sub_func, cond_func):
10 | self = super(CondFunc, cls).__new__(cls)
11 | if isinstance(orig_func, str):
12 | func_path = orig_func.split('.')
13 | for i in range(len(func_path)-1, -1, -1):
14 | try:
15 | resolved_obj = importlib.import_module('.'.join(func_path[:i]))
16 | break
17 | except ImportError:
18 | pass
19 | for attr_name in func_path[i:-1]:
20 | resolved_obj = getattr(resolved_obj, attr_name)
21 | orig_func = getattr(resolved_obj, func_path[-1])
22 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
23 | self.__init__(orig_func, sub_func, cond_func)
24 | return lambda *args, **kwargs: self(*args, **kwargs)
25 | def __init__(self, orig_func, sub_func, cond_func):
26 | self.__orig_func = orig_func
27 | self.__sub_func = sub_func
28 | self.__cond_func = cond_func
29 | def __call__(self, *args, **kwargs):
30 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
31 | return self.__sub_func(self.__orig_func, *args, **kwargs)
32 | else:
33 | return self.__orig_func(*args, **kwargs)
34 |
35 | _utils = torch.utils.data._utils
36 | def _shutdown_workers(self):
37 | if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
38 | return
39 | if hasattr(self, "_shutdown") and not self._shutdown:
40 | self._shutdown = True
41 | try:
42 | if hasattr(self, '_pin_memory_thread'):
43 | self._pin_memory_thread_done_event.set()
44 | self._worker_result_queue.put((None, None))
45 | self._pin_memory_thread.join()
46 | self._worker_result_queue.cancel_join_thread()
47 | self._worker_result_queue.close()
48 | self._workers_done_event.set()
49 | for worker_id in range(len(self._workers)):
50 | if self._persistent_workers or self._workers_status[worker_id]:
51 | self._mark_worker_as_unavailable(worker_id, shutdown=True)
52 | for w in self._workers: # pylint: disable=invalid-name
53 | w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
54 | for q in self._index_queues: # pylint: disable=invalid-name
55 | q.cancel_join_thread()
56 | q.close()
57 | finally:
58 | if self._worker_pids_set:
59 | torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
60 | self._worker_pids_set = False
61 | for w in self._workers: # pylint: disable=invalid-name
62 | if w.is_alive():
63 | w.terminate()
64 |
65 | class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
66 | def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
67 | if isinstance(device_ids, list) and len(device_ids) > 1:
68 | print("IPEX backend doesn't support DataParallel on multiple XPU devices")
69 | return module.to("xpu")
70 |
71 | def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
72 | return contextlib.nullcontext()
73 |
74 | def check_device(device):
75 | return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
76 |
77 | def return_xpu(device):
78 | return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
79 |
80 | def ipex_no_cuda(orig_func, *args, **kwargs):
81 | torch.cuda.is_available = lambda: False
82 | orig_func(*args, **kwargs)
83 | torch.cuda.is_available = torch.xpu.is_available
84 |
85 | original_autocast = torch.autocast
86 | def ipex_autocast(*args, **kwargs):
87 | if len(args) > 0 and args[0] == "cuda":
88 | return original_autocast("xpu", *args[1:], **kwargs)
89 | else:
90 | return original_autocast(*args, **kwargs)
91 |
92 | original_torch_cat = torch.cat
93 | def torch_cat(tensor, *args, **kwargs):
94 | if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
95 | return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
96 | else:
97 | return original_torch_cat(tensor, *args, **kwargs)
98 |
99 | original_interpolate = torch.nn.functional.interpolate
100 | def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
101 | if antialias or align_corners is not None:
102 | return_device = tensor.device
103 | return_dtype = tensor.dtype
104 | return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
105 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
106 | else:
107 | return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
108 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
109 |
110 | original_linalg_solve = torch.linalg.solve
111 | def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
112 | if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
113 | return_device = A.device
114 | return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
115 | else:
116 | return original_linalg_solve(A, B, *args, **kwargs)
117 |
118 | def ipex_hijacks():
119 | CondFunc('torch.Tensor.to',
120 | lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
121 | lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
122 | CondFunc('torch.Tensor.cuda',
123 | lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
124 | lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
125 | CondFunc('torch.empty',
126 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
127 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
128 | CondFunc('torch.load',
129 | lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
130 | lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
131 | CondFunc('torch.randn',
132 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
133 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
134 | CondFunc('torch.ones',
135 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
136 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
137 | CondFunc('torch.zeros',
138 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
139 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
140 | CondFunc('torch.tensor',
141 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
142 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
143 | CondFunc('torch.linspace',
144 | lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
145 | lambda orig_func, *args, device=None, **kwargs: check_device(device))
146 |
147 | CondFunc('torch.Generator',
148 | lambda orig_func, device=None: torch.xpu.Generator(device),
149 | lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
150 |
151 | CondFunc('torch.batch_norm',
152 | lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
153 | weight if weight is not None else torch.ones(input.size()[1], device=input.device),
154 | bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
155 | lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
156 | CondFunc('torch.instance_norm',
157 | lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
158 | weight if weight is not None else torch.ones(input.size()[1], device=input.device),
159 | bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
160 | lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
161 |
162 | #Functions with dtype errors:
163 | CondFunc('torch.nn.modules.GroupNorm.forward',
164 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
165 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
166 | CondFunc('torch.nn.modules.linear.Linear.forward',
167 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
168 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
169 | CondFunc('torch.nn.modules.conv.Conv2d.forward',
170 | lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
171 | lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
172 | CondFunc('torch.nn.functional.layer_norm',
173 | lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
174 | orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
175 | lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
176 | weight is not None and input.dtype != weight.data.dtype)
177 |
178 | #Diffusers Float64 (ARC GPUs doesn't support double or Float64):
179 | if not torch.xpu.has_fp64_dtype():
180 | CondFunc('torch.from_numpy',
181 | lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
182 | lambda orig_func, ndarray: ndarray.dtype == float)
183 |
184 | #Broken functions when torch.cuda.is_available is True:
185 | CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
186 | lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
187 | lambda orig_func, *args, **kwargs: True)
188 |
189 | #Functions that make compile mad with CondFunc:
190 | torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
191 | torch.nn.DataParallel = DummyDataParallel
192 | torch.autocast = ipex_autocast
193 | torch.cat = torch_cat
194 | torch.linalg.solve = linalg_solve
195 | torch.nn.functional.interpolate = interpolate
196 | torch.backends.cuda.sdp_kernel = return_null_context
197 |
--------------------------------------------------------------------------------
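Note: CondFunc replaces a dotted attribute path with a wrapper that routes the call to sub_func whenever cond_func returns True and to the original function otherwise; ipex_hijacks() uses it to reroute CUDA-targeted calls to XPU. A hypothetical extra hijack in the same style (torch.arange is not patched above; it is shown only to illustrate the pattern):

    from scripts.kohyas.ipex.hijacks import CondFunc, check_device, return_xpu

    # Route torch.arange to XPU whenever a CUDA device is requested, mirroring the
    # torch.zeros / torch.ones entries registered in ipex_hijacks().
    CondFunc('torch.arange',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
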
/scripts/kohyas/merge_lora.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import os
4 | import time
5 | import torch
6 | from safetensors.torch import load_file, save_file
7 | from scripts.kohyas import sai_model_spec, train_util
8 | from scripts.kohyas import model_util, lora
9 |
10 | def load_state_dict(file_name, dtype):
11 | if os.path.splitext(file_name)[1] == ".safetensors":
12 | sd = load_file(file_name)
13 | metadata = train_util.load_metadata_from_safetensors(file_name)
14 | else:
15 | sd = torch.load(file_name, map_location="cpu")
16 | metadata = {}
17 |
18 | for key in list(sd.keys()):
19 | if type(sd[key]) == torch.Tensor:
20 | sd[key] = sd[key].to(dtype)
21 |
22 | return sd, metadata
23 |
24 |
25 | def save_to_file(file_name, model, state_dict, dtype, metadata):
26 | if dtype is not None:
27 | for key in list(state_dict.keys()):
28 | if type(state_dict[key]) == torch.Tensor:
29 | state_dict[key] = state_dict[key].to(dtype)
30 |
31 | if os.path.splitext(file_name)[1] == ".safetensors":
32 | save_file(model, file_name, metadata=metadata)
33 | else:
34 | torch.save(model, file_name)
35 |
36 |
37 | def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
38 | text_encoder.to(merge_dtype)
39 | unet.to(merge_dtype)
40 |
41 | # create module map
42 | name_to_module = {}
43 | for i, root_module in enumerate([text_encoder, unet]):
44 | if i == 0:
45 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
46 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
47 | else:
48 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET
49 | target_replace_modules = (
50 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
51 | )
52 |
53 | for name, module in root_module.named_modules():
54 | if module.__class__.__name__ in target_replace_modules:
55 | for child_name, child_module in module.named_modules():
56 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
57 | lora_name = prefix + "." + name + "." + child_name
58 | lora_name = lora_name.replace(".", "_")
59 | name_to_module[lora_name] = child_module
60 |
61 | for model, ratio in zip(models, ratios):
62 | print(f"loading: {model}")
63 | lora_sd, _ = load_state_dict(model, merge_dtype)
64 |
65 | print(f"merging...")
66 | for key in lora_sd.keys():
67 | if "lora_down" in key:
68 | up_key = key.replace("lora_down", "lora_up")
69 | alpha_key = key[: key.index("lora_down")] + "alpha"
70 |
71 | # find original module for this lora
72 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
73 | if module_name not in name_to_module:
74 | print(f"no module found for LoRA weight: {key}")
75 | continue
76 | module = name_to_module[module_name]
77 | # print(f"apply {key} to {module}")
78 |
79 | down_weight = lora_sd[key]
80 | up_weight = lora_sd[up_key]
81 |
82 | dim = down_weight.size()[0]
83 | alpha = lora_sd.get(alpha_key, dim)
84 | scale = alpha / dim
85 |
86 | # W <- W + U * D
87 | weight = module.weight
88 | if len(weight.size()) == 2:
89 | # linear
90 |                     if len(up_weight.size()) == 4:  # conv-shaped LoRA weight applied to a Linear layer: drop the trailing 1x1 dims
91 | up_weight = up_weight.squeeze(3).squeeze(2)
92 | down_weight = down_weight.squeeze(3).squeeze(2)
93 | weight = weight + ratio * (up_weight @ down_weight) * scale
94 | elif down_weight.size()[2:4] == (1, 1):
95 | # conv2d 1x1
96 | weight = (
97 | weight
98 | + ratio
99 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
100 | * scale
101 | )
102 | else:
103 | # conv2d 3x3
104 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
105 | # print(conved.size(), weight.size(), module.stride, module.padding)
106 | weight = weight + ratio * conved * scale
107 |
108 | module.weight = torch.nn.Parameter(weight)
109 |
110 |
111 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
112 | base_alphas = {} # alpha for merged model
113 | base_dims = {}
114 |
115 | merged_sd = {}
116 | v2 = None
117 | base_model = None
118 | for model, ratio in zip(models, ratios):
119 | print(f"loading: {model}")
120 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
121 |
122 | if lora_metadata is not None:
123 | if v2 is None:
124 |                 v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns a string
125 | if base_model is None:
126 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
127 |
128 | # get alpha and dim
129 | alphas = {} # alpha for current model
130 | dims = {} # dims for current model
131 | for key in lora_sd.keys():
132 | if "alpha" in key:
133 | lora_module_name = key[: key.rfind(".alpha")]
134 | alpha = float(lora_sd[key].detach().numpy())
135 | alphas[lora_module_name] = alpha
136 | if lora_module_name not in base_alphas:
137 | base_alphas[lora_module_name] = alpha
138 | elif "lora_down" in key:
139 | lora_module_name = key[: key.rfind(".lora_down")]
140 | dim = lora_sd[key].size()[0]
141 | dims[lora_module_name] = dim
142 | if lora_module_name not in base_dims:
143 | base_dims[lora_module_name] = dim
144 |
145 | for lora_module_name in dims.keys():
146 | if lora_module_name not in alphas:
147 | alpha = dims[lora_module_name]
148 | alphas[lora_module_name] = alpha
149 | if lora_module_name not in base_alphas:
150 | base_alphas[lora_module_name] = alpha
151 |
152 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
153 |
154 | # merge
155 | print(f"merging...")
156 | for key in lora_sd.keys():
157 | if "alpha" in key:
158 | continue
159 | if "lora_up" in key and concat:
160 | concat_dim = 1
161 | elif "lora_down" in key and concat:
162 | concat_dim = 0
163 | else:
164 | concat_dim = None
165 |
166 | lora_module_name = key[: key.rfind(".lora_")]
167 |
168 | base_alpha = base_alphas[lora_module_name]
169 | alpha = alphas[lora_module_name]
170 |
171 | scale = math.sqrt(alpha / base_alpha) * ratio
172 |             scale = abs(scale) if "lora_up" in key else scale # support negative weights (the sign is kept on lora_down only)
173 |
174 | if key in merged_sd:
175 | assert (
176 | merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
177 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
178 | if concat_dim is not None:
179 | merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
180 | else:
181 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
182 | else:
183 | merged_sd[key] = lora_sd[key] * scale
184 |
185 | # set alpha to sd
186 | for lora_module_name, alpha in base_alphas.items():
187 | key = lora_module_name + ".alpha"
188 | merged_sd[key] = torch.tensor(alpha)
189 | if shuffle:
190 | key_down = lora_module_name + ".lora_down.weight"
191 | key_up = lora_module_name + ".lora_up.weight"
192 | dim = merged_sd[key_down].shape[0]
193 | perm = torch.randperm(dim)
194 | merged_sd[key_down] = merged_sd[key_down][perm]
195 | merged_sd[key_up] = merged_sd[key_up][:,perm]
196 |
197 | print("merged model")
198 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
199 |
200 | # check all dims are same
201 | dims_list = list(set(base_dims.values()))
202 | alphas_list = list(set(base_alphas.values()))
203 | all_same_dims = True
204 | all_same_alphas = True
205 | for dims in dims_list:
206 | if dims != dims_list[0]:
207 | all_same_dims = False
208 | break
209 | for alphas in alphas_list:
210 | if alphas != alphas_list[0]:
211 | all_same_alphas = False
212 | break
213 |
214 | # build minimum metadata
215 | dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
216 | alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
217 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
218 |
219 | return merged_sd, metadata, v2 == "True"
220 |
221 |
222 | def merge(args):
223 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
224 |
225 | def str_to_dtype(p):
226 | if p == "float":
227 | return torch.float
228 | if p == "fp16":
229 | return torch.float16
230 | if p == "bf16":
231 | return torch.bfloat16
232 | return None
233 |
234 | merge_dtype = str_to_dtype(args.precision)
235 | save_dtype = str_to_dtype(args.save_precision)
236 | if save_dtype is None:
237 | save_dtype = merge_dtype
238 |
239 | if args.sd_model is not None:
240 | print(f"loading SD model: {args.sd_model}")
241 |
242 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
243 |
244 | merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
245 |
246 | if args.no_metadata:
247 | sai_metadata = None
248 | else:
249 | merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
250 | title = os.path.splitext(os.path.basename(args.save_to))[0]
251 | sai_metadata = sai_model_spec.build_metadata(
252 | None,
253 | args.v2,
254 | args.v2,
255 | False,
256 | False,
257 | False,
258 | time.time(),
259 | title=title,
260 | merged_from=merged_from,
261 | is_stable_diffusion_ckpt=True,
262 | )
263 | if args.v2:
264 | # TODO read sai modelspec
265 | print(
266 | "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
267 | )
268 |
269 | print(f"saving SD model to: {args.save_to}")
270 | model_util.save_stable_diffusion_checkpoint(
271 | args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
272 | )
273 | else:
274 | state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
275 |
276 | print(f"calculating hashes and creating metadata...")
277 |
278 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
279 | metadata["sshs_model_hash"] = model_hash
280 | metadata["sshs_legacy_hash"] = legacy_hash
281 |
282 | if not args.no_metadata:
283 | merged_from = sai_model_spec.build_merged_from(args.models)
284 | title = os.path.splitext(os.path.basename(args.save_to))[0]
285 | sai_metadata = sai_model_spec.build_metadata(
286 | state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
287 | )
288 | if v2:
289 | # TODO read sai modelspec
290 | print(
291 | "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
292 | )
293 | metadata.update(sai_metadata)
294 |
295 | print(f"saving model to: {args.save_to}")
296 | save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
297 |
298 |
299 | def setup_parser() -> argparse.ArgumentParser:
300 | parser = argparse.ArgumentParser()
301 | parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
302 | parser.add_argument(
303 | "--save_precision",
304 | type=str,
305 | default=None,
306 | choices=[None, "float", "fp16", "bf16"],
307 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
308 | )
309 | parser.add_argument(
310 | "--precision",
311 | type=str,
312 | default="float",
313 | choices=["float", "fp16", "bf16"],
314 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
315 | )
316 | parser.add_argument(
317 | "--sd_model",
318 | type=str,
319 | default=None,
320 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
321 | )
322 | parser.add_argument(
323 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
324 | )
325 | parser.add_argument(
326 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
327 | )
328 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
329 | parser.add_argument(
330 | "--no_metadata",
331 | action="store_true",
332 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
333 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
334 | )
335 | parser.add_argument(
336 | "--concat",
337 | action="store_true",
338 | help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
339 | + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
340 | )
341 | parser.add_argument(
342 | "--shuffle",
343 | action="store_true",
344 |         help="shuffle lora weights / "
345 | + "LoRAの重みをシャッフルする",
346 | )
347 |
348 | return parser
349 |
350 |
351 | if __name__ == "__main__":
352 | parser = setup_parser()
353 |
354 | args = parser.parse_args()
355 | merge(args)
356 |
--------------------------------------------------------------------------------
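Note: a minimal sketch of driving merge_lora.py programmatically rather than from the shell; all file paths are placeholders and the ratios are illustrative. The flags match the parser defined in setup_parser() above:

    from scripts.kohyas.merge_lora import setup_parser, merge

    args = setup_parser().parse_args([
        "--sd_model", "base_sd15.safetensors",   # omit to merge the LoRAs into a single LoRA instead
        "--models", "loraA.safetensors", "loraB.safetensors",
        "--ratios", "0.8", "0.5",
        "--save_to", "merged.safetensors",
        "--save_precision", "fp16",
    ])
    merge(args)
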
/scripts/kohyas/sai_model_spec.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/Stability-AI/ModelSpec
2 | import datetime
3 | import hashlib
4 | from io import BytesIO
5 | import os
6 | from typing import List, Optional, Tuple, Union
7 | import safetensors
8 |
9 | r"""
10 | # Metadata Example
11 | metadata = {
12 | # === Must ===
13 | "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
14 | "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
15 | "modelspec.implementation": "sgm",
16 | "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
17 | # === Should ===
18 | "modelspec.author": "Example Corp", # Your name or company name
19 | "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
20 | "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
21 | # === Can ===
22 | "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
23 | "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
24 | }
25 | """
26 |
27 | BASE_METADATA = {
28 | # === Must ===
29 | "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
30 | "modelspec.architecture": None,
31 | "modelspec.implementation": None,
32 | "modelspec.title": None,
33 | "modelspec.resolution": None,
34 | # === Should ===
35 | "modelspec.description": None,
36 | "modelspec.author": None,
37 | "modelspec.date": None,
38 | # === Can ===
39 | "modelspec.license": None,
40 | "modelspec.tags": None,
41 | "modelspec.merged_from": None,
42 | "modelspec.prediction_type": None,
43 | "modelspec.timestep_range": None,
44 | "modelspec.encoder_layer": None,
45 | }
46 |
47 | # define only the keys that are used elsewhere
48 | MODELSPEC_TITLE = "modelspec.title"
49 |
50 | ARCH_SD_V1 = "stable-diffusion-v1"
51 | ARCH_SD_V2_512 = "stable-diffusion-v2-512"
52 | ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
53 | ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
54 |
55 | ADAPTER_LORA = "lora"
56 | ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
57 |
58 | IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
59 | IMPL_DIFFUSERS = "diffusers"
60 |
61 | PRED_TYPE_EPSILON = "epsilon"
62 | PRED_TYPE_V = "v"
63 |
64 |
65 | def load_bytes_in_safetensors(tensors):
66 | bytes = safetensors.torch.save(tensors)
67 | b = BytesIO(bytes)
68 |
69 | b.seek(0)
70 | header = b.read(8)
71 | n = int.from_bytes(header, "little")
72 |
73 | offset = n + 8
74 | b.seek(offset)
75 |
76 | return b.read()
77 |
78 |
79 | def precalculate_safetensors_hashes(state_dict):
80 | # calculate each tensor one by one to reduce memory usage
81 | hash_sha256 = hashlib.sha256()
82 | for tensor in state_dict.values():
83 | single_tensor_sd = {"tensor": tensor}
84 | bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
85 | hash_sha256.update(bytes_for_tensor)
86 |
87 | return f"0x{hash_sha256.hexdigest()}"
88 |
89 |
90 | def update_hash_sha256(metadata: dict, state_dict: dict):
91 | raise NotImplementedError
92 |
93 |
94 | def build_metadata(
95 | state_dict: Optional[dict],
96 | v2: bool,
97 | v_parameterization: bool,
98 | sdxl: bool,
99 | lora: bool,
100 | textual_inversion: bool,
101 | timestamp: float,
102 | title: Optional[str] = None,
103 | reso: Optional[Union[int, Tuple[int, int]]] = None,
104 | is_stable_diffusion_ckpt: Optional[bool] = None,
105 | author: Optional[str] = None,
106 | description: Optional[str] = None,
107 | license: Optional[str] = None,
108 | tags: Optional[str] = None,
109 | merged_from: Optional[str] = None,
110 | timesteps: Optional[Tuple[int, int]] = None,
111 | clip_skip: Optional[int] = None,
112 | ):
113 | # if state_dict is None, hash is not calculated
114 |
115 | metadata = {}
116 | metadata.update(BASE_METADATA)
117 |
118 |     # TODO: implement this once a correct hash calculation that does not consume extra memory is found
119 | # if state_dict is not None:
120 | # hash = precalculate_safetensors_hashes(state_dict)
121 | # metadata["modelspec.hash_sha256"] = hash
122 |
123 | if sdxl:
124 | arch = ARCH_SD_XL_V1_BASE
125 | elif v2:
126 | if v_parameterization:
127 | arch = ARCH_SD_V2_768_V
128 | else:
129 | arch = ARCH_SD_V2_512
130 | else:
131 | arch = ARCH_SD_V1
132 |
133 | if lora:
134 | arch += f"/{ADAPTER_LORA}"
135 | elif textual_inversion:
136 | arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
137 |
138 | metadata["modelspec.architecture"] = arch
139 |
140 | if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
141 | is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
142 |
143 | if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
144 | # Stable Diffusion ckpt, TI, SDXL LoRA
145 | impl = IMPL_STABILITY_AI
146 | else:
147 | # v1/v2 LoRA or Diffusers
148 | impl = IMPL_DIFFUSERS
149 | metadata["modelspec.implementation"] = impl
150 |
151 | if title is None:
152 | if lora:
153 | title = "LoRA"
154 | elif textual_inversion:
155 | title = "TextualInversion"
156 | else:
157 | title = "Checkpoint"
158 | title += f"@{timestamp}"
159 | metadata[MODELSPEC_TITLE] = title
160 |
161 | if author is not None:
162 | metadata["modelspec.author"] = author
163 | else:
164 | del metadata["modelspec.author"]
165 |
166 | if description is not None:
167 | metadata["modelspec.description"] = description
168 | else:
169 | del metadata["modelspec.description"]
170 |
171 | if merged_from is not None:
172 | metadata["modelspec.merged_from"] = merged_from
173 | else:
174 | del metadata["modelspec.merged_from"]
175 |
176 | if license is not None:
177 | metadata["modelspec.license"] = license
178 | else:
179 | del metadata["modelspec.license"]
180 |
181 | if tags is not None:
182 | metadata["modelspec.tags"] = tags
183 | else:
184 | del metadata["modelspec.tags"]
185 |
186 | # remove microsecond from time
187 | int_ts = int(timestamp)
188 |
189 | # time to iso-8601 compliant date
190 | date = datetime.datetime.fromtimestamp(int_ts).isoformat()
191 | metadata["modelspec.date"] = date
192 |
193 | if reso is not None:
194 | # comma separated to tuple
195 | if isinstance(reso, str):
196 | reso = tuple(map(int, reso.split(",")))
197 | if len(reso) == 1:
198 | reso = (reso[0], reso[0])
199 | else:
200 | # resolution is defined in dataset, so use default
201 | if sdxl:
202 | reso = 1024
203 | elif v2 and v_parameterization:
204 | reso = 768
205 | else:
206 | reso = 512
207 | if isinstance(reso, int):
208 | reso = (reso, reso)
209 |
210 | metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
211 |
212 | if v_parameterization:
213 | metadata["modelspec.prediction_type"] = PRED_TYPE_V
214 | else:
215 | metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
216 |
217 | if timesteps is not None:
218 | if isinstance(timesteps, str) or isinstance(timesteps, int):
219 | timesteps = (timesteps, timesteps)
220 | if len(timesteps) == 1:
221 | timesteps = (timesteps[0], timesteps[0])
222 | metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
223 | else:
224 | del metadata["modelspec.timestep_range"]
225 |
226 | if clip_skip is not None:
227 | metadata["modelspec.encoder_layer"] = f"{clip_skip}"
228 | else:
229 | del metadata["modelspec.encoder_layer"]
230 |
231 | # # assert all values are filled
232 | # assert all([v is not None for v in metadata.values()]), metadata
233 | if not all([v is not None for v in metadata.values()]):
234 | print(f"Internal error: some metadata values are None: {metadata}")
235 |
236 | return metadata
237 |
238 |
239 | # region utils
240 |
241 |
242 | def get_title(metadata: dict) -> Optional[str]:
243 | return metadata.get(MODELSPEC_TITLE, None)
244 |
245 |
246 | def load_metadata_from_safetensors(model: str) -> dict:
247 | if not model.endswith(".safetensors"):
248 | return {}
249 |
250 | with safetensors.safe_open(model, framework="pt") as f:
251 | metadata = f.metadata()
252 | if metadata is None:
253 | metadata = {}
254 | return metadata
255 |
256 |
257 | def build_merged_from(models: List[str]) -> str:
258 | def get_title(model: str):
259 | metadata = load_metadata_from_safetensors(model)
260 | title = metadata.get(MODELSPEC_TITLE, None)
261 | if title is None:
262 | title = os.path.splitext(os.path.basename(model))[0] # use filename
263 | return title
264 |
265 | titles = [get_title(model) for model in models]
266 | return ", ".join(titles)
267 |
268 |
269 | # endregion
270 |
271 |
272 | r"""
273 | if __name__ == "__main__":
274 | import argparse
275 | import torch
276 | from safetensors.torch import load_file
277 | from library import train_util
278 |
279 | parser = argparse.ArgumentParser()
280 | parser.add_argument("--ckpt", type=str, required=True)
281 | args = parser.parse_args()
282 |
283 | print(f"Loading {args.ckpt}")
284 | state_dict = load_file(args.ckpt)
285 |
286 | print(f"Calculating metadata")
287 | metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
288 | print(metadata)
289 | del state_dict
290 |
291 | # by reference implementation
292 | with open(args.ckpt, mode="rb") as file_data:
293 | file_hash = hashlib.sha256()
294 | head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
295 | header = json.loads(file_data.read(head_len[0])) # header itself, json string
296 | content = (
297 | file_data.read()
298 | ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
299 | file_hash.update(content)
300 | # ===== Update the hash for modelspec =====
301 | by_ref = f"0x{file_hash.hexdigest()}"
302 | print(by_ref)
303 | print("is same?", by_ref == metadata["modelspec.hash_sha256"])
304 |
305 | """
306 |
--------------------------------------------------------------------------------
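Note: a minimal sketch of build_metadata() for a v1 LoRA; the title and merged_from values are placeholders, and passing state_dict=None skips hash calculation as noted in the function:

    import time
    from scripts.kohyas.sai_model_spec import build_metadata

    metadata = build_metadata(
        None,        # state_dict: no hash is calculated when None
        False,       # v2
        False,       # v_parameterization
        False,       # sdxl
        True,        # lora
        False,       # textual_inversion
        time.time(),
        title="ExampleLoRA",
        merged_from="loraA, loraB",
    )
    # metadata["modelspec.architecture"] is "stable-diffusion-v1/lora" for these flags.
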
/scripts/kohyas/sdxl_merge_lora.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import os
4 | import time
5 | import torch
6 | from safetensors.torch import load_file, save_file
7 | from tqdm import tqdm
8 | from scripts.kohyas import sai_model_spec, sdxl_model_util, train_util, lora
9 |
10 | def load_state_dict(file_name, dtype):
11 | if os.path.splitext(file_name)[1] == ".safetensors":
12 | sd = load_file(file_name)
13 | metadata = train_util.load_metadata_from_safetensors(file_name)
14 | else:
15 | sd = torch.load(file_name, map_location="cpu")
16 | metadata = {}
17 |
18 | for key in list(sd.keys()):
19 | if type(sd[key]) == torch.Tensor:
20 | sd[key] = sd[key].to(dtype)
21 |
22 | return sd, metadata
23 |
24 |
25 | def save_to_file(file_name, model, state_dict, dtype, metadata):
26 | if dtype is not None:
27 | for key in list(state_dict.keys()):
28 | if type(state_dict[key]) == torch.Tensor:
29 | state_dict[key] = state_dict[key].to(dtype)
30 |
31 | if os.path.splitext(file_name)[1] == ".safetensors":
32 | save_file(model, file_name, metadata=metadata)
33 | else:
34 | torch.save(model, file_name)
35 |
36 |
37 | def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
38 | text_encoder1.to(merge_dtype)
39 |     text_encoder2.to(merge_dtype)
40 | unet.to(merge_dtype)
41 |
42 | # create module map
43 | name_to_module = {}
44 | for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
45 | if i <= 1:
46 | if i == 0:
47 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
48 | else:
49 | prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
50 | target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
51 | else:
52 | prefix = lora.LoRANetwork.LORA_PREFIX_UNET
53 | target_replace_modules = (
54 | lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
55 | )
56 |
57 | for name, module in root_module.named_modules():
58 | if module.__class__.__name__ in target_replace_modules:
59 | for child_name, child_module in module.named_modules():
60 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
61 | lora_name = prefix + "." + name + "." + child_name
62 | lora_name = lora_name.replace(".", "_")
63 | name_to_module[lora_name] = child_module
64 |
65 | for model, ratio in zip(models, ratios):
66 | print(f"loading: {model}")
67 | lora_sd, _ = load_state_dict(model, merge_dtype)
68 |
69 | print(f"merging...")
70 | for key in tqdm(lora_sd.keys()):
71 | if "lora_down" in key:
72 | up_key = key.replace("lora_down", "lora_up")
73 | alpha_key = key[: key.index("lora_down")] + "alpha"
74 |
75 | # find original module for this lora
76 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
77 | if module_name not in name_to_module:
78 | print(f"no module found for LoRA weight: {key}")
79 | continue
80 | module = name_to_module[module_name]
81 | # print(f"apply {key} to {module}")
82 |
83 | down_weight = lora_sd[key]
84 | up_weight = lora_sd[up_key]
85 |
86 | dim = down_weight.size()[0]
87 | alpha = lora_sd.get(alpha_key, dim)
88 | scale = alpha / dim
89 |
90 | # W <- W + U * D
91 | weight = module.weight
92 | # print(module_name, down_weight.size(), up_weight.size())
93 | if len(weight.size()) == 2:
94 | # linear
95 | weight = weight + ratio * (up_weight @ down_weight) * scale
96 | elif down_weight.size()[2:4] == (1, 1):
97 | # conv2d 1x1
98 | weight = (
99 | weight
100 | + ratio
101 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
102 | * scale
103 | )
104 | else:
105 | # conv2d 3x3
106 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
107 | # print(conved.size(), weight.size(), module.stride, module.padding)
108 | weight = weight + ratio * conved * scale
109 |
110 | module.weight = torch.nn.Parameter(weight)
111 |
112 |
113 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
114 | base_alphas = {} # alpha for merged model
115 | base_dims = {}
116 |
117 | merged_sd = {}
118 | v2 = None
119 | base_model = None
120 | for model, ratio in zip(models, ratios):
121 | print(f"loading: {model}")
122 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
123 |
124 | if lora_metadata is not None:
125 | if v2 is None:
126 |                 v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None)  # returns a string; SDXL has no v2, so this should be False
127 | if base_model is None:
128 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
129 |
130 | # get alpha and dim
131 | alphas = {} # alpha for current model
132 | dims = {} # dims for current model
133 | for key in lora_sd.keys():
134 | if "alpha" in key:
135 | lora_module_name = key[: key.rfind(".alpha")]
136 | alpha = float(lora_sd[key].detach().numpy())
137 | alphas[lora_module_name] = alpha
138 | if lora_module_name not in base_alphas:
139 | base_alphas[lora_module_name] = alpha
140 | elif "lora_down" in key:
141 | lora_module_name = key[: key.rfind(".lora_down")]
142 | dim = lora_sd[key].size()[0]
143 | dims[lora_module_name] = dim
144 | if lora_module_name not in base_dims:
145 | base_dims[lora_module_name] = dim
146 |
147 | for lora_module_name in dims.keys():
148 | if lora_module_name not in alphas:
149 | alpha = dims[lora_module_name]
150 | alphas[lora_module_name] = alpha
151 | if lora_module_name not in base_alphas:
152 | base_alphas[lora_module_name] = alpha
153 |
154 | print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
155 |
156 | # merge
157 | print(f"merging...")
158 | for key in tqdm(lora_sd.keys()):
159 | if "alpha" in key:
160 | continue
161 |
162 | if "lora_up" in key and concat:
163 | concat_dim = 1
164 | elif "lora_down" in key and concat:
165 | concat_dim = 0
166 | else:
167 | concat_dim = None
168 |
169 | lora_module_name = key[: key.rfind(".lora_")]
170 |
171 | base_alpha = base_alphas[lora_module_name]
172 | alpha = alphas[lora_module_name]
173 |
174 | scale = math.sqrt(alpha / base_alpha) * ratio
175 |             scale = abs(scale) if "lora_up" in key else scale # support negative weights (the sign is kept on lora_down only)
176 |
177 | if key in merged_sd:
178 | assert (
179 | merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
180 | ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
181 | if concat_dim is not None:
182 | merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
183 | else:
184 | merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
185 | else:
186 | merged_sd[key] = lora_sd[key] * scale
187 |
188 | # set alpha to sd
189 | for lora_module_name, alpha in base_alphas.items():
190 | key = lora_module_name + ".alpha"
191 | merged_sd[key] = torch.tensor(alpha)
192 | if shuffle:
193 | key_down = lora_module_name + ".lora_down.weight"
194 | key_up = lora_module_name + ".lora_up.weight"
195 | dim = merged_sd[key_down].shape[0]
196 | perm = torch.randperm(dim)
197 | merged_sd[key_down] = merged_sd[key_down][perm]
198 | merged_sd[key_up] = merged_sd[key_up][:,perm]
199 |
200 | print("merged model")
201 | print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
202 |
203 | # check all dims are same
204 | dims_list = list(set(base_dims.values()))
205 | alphas_list = list(set(base_alphas.values()))
206 | all_same_dims = True
207 | all_same_alphas = True
208 | for dims in dims_list:
209 | if dims != dims_list[0]:
210 | all_same_dims = False
211 | break
212 | for alphas in alphas_list:
213 | if alphas != alphas_list[0]:
214 | all_same_alphas = False
215 | break
216 |
217 | # build minimum metadata
218 | dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
219 | alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
220 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
221 |
222 | return merged_sd, metadata
223 |
224 |
225 | def merge(args):
226 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
227 |
228 | def str_to_dtype(p):
229 | if p == "float":
230 | return torch.float
231 | if p == "fp16":
232 | return torch.float16
233 | if p == "bf16":
234 | return torch.bfloat16
235 | return None
236 |
237 | merge_dtype = str_to_dtype(args.precision)
238 | save_dtype = str_to_dtype(args.save_precision)
239 | if save_dtype is None:
240 | save_dtype = merge_dtype
241 |
242 | if args.sd_model is not None:
243 | print(f"loading SD model: {args.sd_model}")
244 |
245 | (
246 | text_model1,
247 | text_model2,
248 | vae,
249 | unet,
250 | logit_scale,
251 | ckpt_info,
252 | ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
253 |
254 | merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
255 |
256 | if args.no_metadata:
257 | sai_metadata = None
258 | else:
259 | merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
260 | title = os.path.splitext(os.path.basename(args.save_to))[0]
261 | sai_metadata = sai_model_spec.build_metadata(
262 | None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
263 | )
264 |
265 | print(f"saving SD model to: {args.save_to}")
266 | sdxl_model_util.save_stable_diffusion_checkpoint(
267 | args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
268 | )
269 | else:
270 | state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
271 |
272 | print(f"calculating hashes and creating metadata...")
273 |
274 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
275 | metadata["sshs_model_hash"] = model_hash
276 | metadata["sshs_legacy_hash"] = legacy_hash
277 |
278 | if not args.no_metadata:
279 | merged_from = sai_model_spec.build_merged_from(args.models)
280 | title = os.path.splitext(os.path.basename(args.save_to))[0]
281 | sai_metadata = sai_model_spec.build_metadata(
282 | state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
283 | )
284 | metadata.update(sai_metadata)
285 |
286 | print(f"saving model to: {args.save_to}")
287 | save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
288 |
289 |
290 | def setup_parser() -> argparse.ArgumentParser:
291 | parser = argparse.ArgumentParser()
292 | parser.add_argument(
293 | "--save_precision",
294 | type=str,
295 | default=None,
296 | choices=[None, "float", "fp16", "bf16"],
297 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
298 | )
299 | parser.add_argument(
300 | "--precision",
301 | type=str,
302 | default="float",
303 | choices=["float", "fp16", "bf16"],
304 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
305 | )
306 | parser.add_argument(
307 | "--sd_model",
308 | type=str,
309 | default=None,
310 | help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
311 | )
312 | parser.add_argument(
313 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
314 | )
315 | parser.add_argument(
316 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
317 | )
318 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
319 | parser.add_argument(
320 | "--no_metadata",
321 | action="store_true",
322 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
323 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
324 | )
325 | parser.add_argument(
326 | "--concat",
327 | action="store_true",
328 | help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
329 | + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
330 | )
331 | parser.add_argument(
332 | "--shuffle",
333 | action="store_true",
334 | help="shuffle lora weight./ "
335 | + "LoRAの重みをシャッフルする",
336 | )
337 |
338 | return parser
339 |
340 |
341 | if __name__ == "__main__":
342 | parser = setup_parser()
343 |
344 | args = parser.parse_args()
345 | merge(args)
346 |
--------------------------------------------------------------------------------
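Note (usage sketch, not part of the repository): the file above ends here. It defines setup_parser() and merge(), so the merge can be driven from Python as well as from the command line. The module name below (scripts.kohyas.sdxl_merge_lora) and all checkpoint/LoRA file names are assumptions for illustration only.

    # Hypothetical driver for the SDXL LoRA merge script above.
    # Module path and file names are placeholders, not repository facts.
    from scripts.kohyas import sdxl_merge_lora  # assumed module name

    parser = sdxl_merge_lora.setup_parser()
    args = parser.parse_args([
        "--sd_model", "sdxl_base.safetensors",   # omit this flag to merge LoRAs with each other instead
        "--models", "styleA.safetensors", "styleB.safetensors",
        "--ratios", "0.7", "0.3",
        "--precision", "float",
        "--save_precision", "fp16",
        "--save_to", "merged.safetensors",
    ])
    sdxl_merge_lora.merge(args)

When --sd_model is omitted, the else branch above (merge_lora_models with --concat/--shuffle) produces a merged LoRA file instead of a merged checkpoint.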
/scripts/kohyas/svd_merge_lora.py:
--------------------------------------------------------------------------------
1 | import math
2 | import argparse
3 | import os
4 | import time
5 | import torch
6 | from safetensors.torch import load_file, save_file
7 | from tqdm import tqdm
8 | from scripts.kohyas import sai_model_spec, train_util
9 |
10 |
11 | CLAMP_QUANTILE = 0.99
12 |
13 |
14 | def load_state_dict(file_name, dtype):
15 | if os.path.splitext(file_name)[1] == ".safetensors":
16 | sd = load_file(file_name)
17 | metadata = train_util.load_metadata_from_safetensors(file_name)
18 | else:
19 | sd = torch.load(file_name, map_location="cpu")
20 | metadata = {}
21 |
22 | for key in list(sd.keys()):
23 | if type(sd[key]) == torch.Tensor:
24 | sd[key] = sd[key].to(dtype)
25 |
26 | return sd, metadata
27 |
28 |
29 | def save_to_file(file_name, state_dict, dtype, metadata):
30 | if dtype is not None:
31 | for key in list(state_dict.keys()):
32 | if type(state_dict[key]) == torch.Tensor:
33 | state_dict[key] = state_dict[key].to(dtype)
34 |
35 | if os.path.splitext(file_name)[1] == ".safetensors":
36 | save_file(state_dict, file_name, metadata=metadata)
37 | else:
38 | torch.save(state_dict, file_name)
39 |
40 |
41 | def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
42 | print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
43 | merged_sd = {}
44 | v2 = None
45 | base_model = None
46 | for model, ratio in zip(models, ratios):
47 | print(f"loading: {model}")
48 | lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
49 |
50 | if lora_metadata is not None:
51 | if v2 is None:
52 | v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
53 | if base_model is None:
54 | base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
55 |
56 | # merge
57 | print(f"merging...")
58 | for key in tqdm(list(lora_sd.keys())):
59 | if "lora_down" not in key:
60 | continue
61 |
62 | lora_module_name = key[: key.rfind(".lora_down")]
63 |
64 | down_weight = lora_sd[key]
65 | network_dim = down_weight.size()[0]
66 |
67 | up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
68 | alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
69 |
70 | in_dim = down_weight.size()[1]
71 | out_dim = up_weight.size()[0]
72 | conv2d = len(down_weight.size()) == 4
73 | kernel_size = None if not conv2d else down_weight.size()[2:4]
74 | # print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
75 |
76 | # make original weight if not exist
77 | if lora_module_name not in merged_sd:
78 | weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
79 | if device:
80 | weight = weight.to(device)
81 | else:
82 | weight = merged_sd[lora_module_name]
83 |
84 | # merge to weight
85 | if device:
86 | up_weight = up_weight.to(device)
87 | down_weight = down_weight.to(device)
88 |
89 | # W <- W + U * D
90 | scale = alpha / network_dim
91 |
 92 |             if device and isinstance(scale, torch.Tensor):  # scale may be a plain float when the .alpha key is absent
93 | scale = scale.to(device)
94 |
95 | if not conv2d: # linear
96 | weight = weight + ratio * (up_weight @ down_weight) * scale
97 | elif kernel_size == (1, 1):
98 | weight = (
99 | weight
100 | + ratio
101 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
102 | * scale
103 | )
104 | else:
105 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
106 | weight = weight + ratio * conved * scale
107 |
108 | merged_sd[lora_module_name] = weight
109 |
110 | # extract from merged weights
111 | print("extract new lora...")
112 | merged_lora_sd = {}
113 | with torch.no_grad():
114 | for lora_module_name, mat in tqdm(list(merged_sd.items())):
115 | conv2d = len(mat.size()) == 4
116 | kernel_size = None if not conv2d else mat.size()[2:4]
117 | conv2d_3x3 = conv2d and kernel_size != (1, 1)
118 | out_dim, in_dim = mat.size()[0:2]
119 |
120 | if conv2d:
121 | if conv2d_3x3:
122 | mat = mat.flatten(start_dim=1)
123 | else:
124 | mat = mat.squeeze()
125 |
126 | module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
127 | module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
128 |
129 | U, S, Vh = torch.linalg.svd(mat)
130 |
131 | U = U[:, :module_new_rank]
132 | S = S[:module_new_rank]
133 | U = U @ torch.diag(S)
134 |
135 | Vh = Vh[:module_new_rank, :]
136 |
137 | dist = torch.cat([U.flatten(), Vh.flatten()])
138 | hi_val = torch.quantile(dist, CLAMP_QUANTILE)
139 | low_val = -hi_val
140 |
141 | U = U.clamp(low_val, hi_val)
142 | Vh = Vh.clamp(low_val, hi_val)
143 |
144 | if conv2d:
145 | U = U.reshape(out_dim, module_new_rank, 1, 1)
146 | Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
147 |
148 | up_weight = U
149 | down_weight = Vh
150 |
151 | merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
152 | merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
153 | merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
154 |
155 | # build minimum metadata
156 | dims = f"{new_rank}"
157 | alphas = f"{new_rank}"
158 | if new_conv_rank is not None:
159 | network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
160 | else:
161 | network_args = None
162 | metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
163 |
164 | return merged_lora_sd, metadata, v2 == "True", base_model
165 |
166 |
167 | def merge(args):
168 | assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
169 |
170 | def str_to_dtype(p):
171 | if p == "float":
172 | return torch.float
173 | if p == "fp16":
174 | return torch.float16
175 | if p == "bf16":
176 | return torch.bfloat16
177 | return None
178 |
179 | merge_dtype = str_to_dtype(args.precision)
180 | save_dtype = str_to_dtype(args.save_precision)
181 | if save_dtype is None:
182 | save_dtype = merge_dtype
183 |
184 | new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
185 | state_dict, metadata, v2, base_model = merge_lora_models(
186 | args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
187 | )
188 |
189 | print(f"calculating hashes and creating metadata...")
190 |
191 | model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
192 | metadata["sshs_model_hash"] = model_hash
193 | metadata["sshs_legacy_hash"] = legacy_hash
194 |
195 | if not args.no_metadata:
196 | is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
197 | merged_from = sai_model_spec.build_merged_from(args.models)
198 | title = os.path.splitext(os.path.basename(args.save_to))[0]
199 | sai_metadata = sai_model_spec.build_metadata(
200 | state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
201 | )
202 | if v2:
203 | # TODO read sai modelspec
204 | print(
205 | "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
206 | )
207 | metadata.update(sai_metadata)
208 |
209 | print(f"saving model to: {args.save_to}")
210 | save_to_file(args.save_to, state_dict, save_dtype, metadata)
211 |
212 |
213 | def setup_parser() -> argparse.ArgumentParser:
214 | parser = argparse.ArgumentParser()
215 | parser.add_argument(
216 | "--save_precision",
217 | type=str,
218 | default=None,
219 | choices=[None, "float", "fp16", "bf16"],
220 | help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
221 | )
222 | parser.add_argument(
223 | "--precision",
224 | type=str,
225 | default="float",
226 | choices=["float", "fp16", "bf16"],
227 | help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
228 | )
229 | parser.add_argument(
230 | "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
231 | )
232 | parser.add_argument(
233 | "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
234 | )
235 | parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
236 | parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
237 | parser.add_argument(
238 | "--new_conv_rank",
239 | type=int,
240 | default=None,
241 | help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
242 | )
243 | parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
244 | parser.add_argument(
245 | "--no_metadata",
246 | action="store_true",
247 | help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
248 | + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
249 | )
250 |
251 | return parser
252 |
253 |
254 | if __name__ == "__main__":
255 | parser = setup_parser()
256 |
257 | args = parser.parse_args()
258 | merge(args)
259 |
--------------------------------------------------------------------------------
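Note (usage sketch, not part of the repository): svd_merge_lora.py first accumulates each LoRA's up*down delta into full-size weights and then re-extracts a LoRA of the requested rank via torch.linalg.svd, clamping the factors at the 0.99 quantile (CLAMP_QUANTILE). A hypothetical invocation, with placeholder file names and ranks:

    # Hypothetical driver for scripts/kohyas/svd_merge_lora.py.
    from scripts.kohyas import svd_merge_lora

    parser = svd_merge_lora.setup_parser()
    args = parser.parse_args([
        "--models", "loraA.safetensors", "loraB.safetensors",
        "--ratios", "0.6", "0.4",
        "--new_rank", "16",          # rank of the re-extracted LoRA
        "--new_conv_rank", "8",      # rank for Conv2d 3x3 modules; defaults to new_rank
        "--device", "cuda",
        "--save_to", "merged_r16.safetensors",
    ])
    svd_merge_lora.merge(args)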
/scripts/mbwpresets_master.txt:
--------------------------------------------------------------------------------
1 | preset_name preset_weights
2 | GRAD_V 0,1,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0
3 | GRAD_A 0,0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5,0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0,0.9166666667,0.8333333333,0.75,0.6666666667,0.5833333333,0.5,0.4166666667,0.3333333333,0.25,0.1666666667,0.0833333333,0
4 | FLAT_25 0,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25,0.25
5 | FLAT_75 0,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75,0.75
6 | WRAP08 0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1
7 | WRAP12 0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1
8 | WRAP14 0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
9 | WRAP16 0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1
10 | MID12_50 0,0,0,0,0,0,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0,0,0,0,0,0
11 | OUT07 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1
12 | OUT12 0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1
13 | OUT12_5 0,0,0,0,0,0,0,0,0,0,0,0,0,0.5,1,1,1,1,1,1,1,1,1,1,1,1
14 | RING08_SOFT 0,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0,0.5,1,1,1,0.5,0,0,0,0,0
15 | RING08_5 0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0
16 | RING10_5 0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0
17 | RING10_3 0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0,0,0,0
18 | SMOOTHSTEP 0,0,0.00506365740740741,0.0196759259259259,0.04296875,0.0740740740740741,0.112123842592593,0.15625,0.205584490740741,0.259259259259259,0.31640625,0.376157407407407,0.437644675925926,0.5,0.562355324074074,0.623842592592592,0.68359375,0.740740740740741,0.794415509259259,0.84375,0.887876157407408,0.925925925925926,0.95703125,0.980324074074074,0.994936342592593,1
19 | REVERSE-SMOOTHSTEP 0,1,0.994936342592593,0.980324074074074,0.95703125,0.925925925925926,0.887876157407407,0.84375,0.794415509259259,0.740740740740741,0.68359375,0.623842592592593,0.562355324074074,0.5,0.437644675925926,0.376157407407408,0.31640625,0.259259259259259,0.205584490740741,0.15625,0.112123842592592,0.0740740740740742,0.0429687499999996,0.0196759259259258,0.00506365740740744,0
20 | SMOOTHSTEP*2 0,0,0.0101273148148148,0.0393518518518519,0.0859375,0.148148148148148,0.224247685185185,0.3125,0.411168981481482,0.518518518518519,0.6328125,0.752314814814815,0.875289351851852,1.,0.875289351851852,0.752314814814815,0.6328125,0.518518518518519,0.411168981481481,0.3125,0.224247685185184,0.148148148148148,0.0859375,0.0393518518518512,0.0101273148148153,0
21 | R_SMOOTHSTEP*2 0,1,0.989872685185185,0.960648148148148,0.9140625,0.851851851851852,0.775752314814815,0.6875,0.588831018518519,0.481481481481481,0.3671875,0.247685185185185,0.124710648148148,0.,0.124710648148148,0.247685185185185,0.3671875,0.481481481481481,0.588831018518519,0.6875,0.775752314814816,0.851851851851852,0.9140625,0.960648148148149,0.989872685185185,1
22 | SMOOTHSTEP*3 0,0,0.0151909722222222,0.0590277777777778,0.12890625,0.222222222222222,0.336371527777778,0.46875,0.616753472222222,0.777777777777778,0.94921875,0.871527777777778,0.687065972222222,0.5,0.312934027777778,0.128472222222222,0.0507812500000004,0.222222222222222,0.383246527777778,0.53125,0.663628472222223,0.777777777777778,0.87109375,0.940972222222222,0.984809027777777,1
23 | R_SMOOTHSTEP*3 0,1,0.984809027777778,0.940972222222222,0.87109375,0.777777777777778,0.663628472222222,0.53125,0.383246527777778,0.222222222222222,0.05078125,0.128472222222222,0.312934027777778,0.5,0.687065972222222,0.871527777777778,0.94921875,0.777777777777778,0.616753472222222,0.46875,0.336371527777777,0.222222222222222,0.12890625,0.0590277777777777,0.0151909722222232,0
24 | SMOOTHSTEP*4 0,0,0.0202546296296296,0.0787037037037037,0.171875,0.296296296296296,0.44849537037037,0.625,0.822337962962963,0.962962962962963,0.734375,0.49537037037037,0.249421296296296,0.,0.249421296296296,0.495370370370371,0.734375000000001,0.962962962962963,0.822337962962962,0.625,0.448495370370369,0.296296296296297,0.171875,0.0787037037037024,0.0202546296296307,0
25 | R_SMOOTHSTEP*4 0,1,0.97974537037037,0.921296296296296,0.828125,0.703703703703704,0.55150462962963,0.375,0.177662037037037,0.0370370370370372,0.265625,0.50462962962963,0.750578703703704,1.,0.750578703703704,0.504629629629629,0.265624999999999,0.0370370370370372,0.177662037037038,0.375,0.551504629629631,0.703703703703703,0.828125,0.921296296296298,0.979745370370369,1
26 | SMOOTHSTEP/2 0,0,0.0196759259259259,0.0740740740740741,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1.,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740741,0.0196759259259259,0
27 | R_SMOOTHSTEP/2 0,1,0.980324074074074,0.925925925925926,0.84375,0.740740740740741,0.623842592592593,0.5,0.376157407407407,0.259259259259259,0.15625,0.0740740740740742,0.0196759259259256,0.,0.0196759259259256,0.0740740740740742,0.15625,0.259259259259259,0.376157407407407,0.5,0.623842592592593,0.740740740740741,0.84375,0.925925925925926,0.980324074074074,1
28 | SMOOTHSTEP/3 0,0,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1
29 | R_SMOOTHSTEP/3 0,1,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0.,0.04296875,0.15625,0.31640625,0.5,0.68359375,0.84375,0.95703125,1.,0.95703125,0.84375,0.68359375,0.5,0.31640625,0.15625,0.04296875,0
30 | SMOOTHSTEP/4 0,0,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0.,0.0740740740740741,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740741,0
31 | R_SMOOTHSTEP/4 0,1,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1.,0.925925925925926,0.740740740740741,0.5,0.259259259259259,0.0740740740740742,0.,0.0740740740740742,0.259259259259259,0.5,0.740740740740741,0.925925925925926,1
32 | COSINE 0,1,0.995722430686905,0.982962913144534,0.961939766255643,0.933012701892219,0.896676670145617,0.853553390593274,0.80438071450436,0.75,0.691341716182545,0.62940952255126,0.565263096110026,0.5,0.434736903889974,0.37059047744874,0.308658283817455,0.25,0.195619285495639,0.146446609406726,0.103323329854382,0.0669872981077805,0.0380602337443566,0.0170370868554658,0.00427756931309475,0
33 | REVERSE_COSINE 0,0,0.00427756931309475,0.0170370868554659,0.0380602337443566,0.0669872981077808,0.103323329854383,0.146446609406726,0.19561928549564,0.25,0.308658283817455,0.37059047744874,0.434736903889974,0.5,0.565263096110026,0.62940952255126,0.691341716182545,0.75,0.804380714504361,0.853553390593274,0.896676670145618,0.933012701892219,0.961939766255643,0.982962913144534,0.995722430686905,1
34 | TRUE_CUBIC_HERMITE 0,0,0.199031876929012,0.325761959876543,0.424641927083333,0.498456790123457,0.549991560570988,0.58203125,0.597360869984568,0.598765432098765,0.589029947916667,0.570939429012346,0.547278886959876,0.520833333333333,0.49438777970679,0.470727237654321,0.45263671875,0.442901234567901,0.444305796682099,0.459635416666667,0.491675106095678,0.543209876543211,0.617024739583333,0.715904706790124,0.842634789737655,1
35 | TRUE_REVERSE_CUBIC_HERMITE 0,1,0.800968123070988,0.674238040123457,0.575358072916667,0.501543209876543,0.450008439429012,0.41796875,0.402639130015432,0.401234567901235,0.410970052083333,0.429060570987654,0.452721113040124,0.479166666666667,0.50561222029321,0.529272762345679,0.54736328125,0.557098765432099,0.555694203317901,0.540364583333333,0.508324893904322,0.456790123456789,0.382975260416667,0.284095293209876,0.157365210262345,0
36 | FAKE_CUBIC_HERMITE 0,0,0.157576195987654,0.28491512345679,0.384765625,0.459876543209877,0.512996720679012,0.546875,0.564260223765432,0.567901234567901,0.560546875,0.544945987654321,0.523847415123457,0.5,0.476152584876543,0.455054012345679,0.439453125,0.432098765432099,0.435739776234568,0.453125,0.487003279320987,0.540123456790124,0.615234375,0.71508487654321,0.842423804012347,1
37 | FAKE_REVERSE_CUBIC_HERMITE 0,1,0.842423804012346,0.71508487654321,0.615234375,0.540123456790123,0.487003279320988,0.453125,0.435739776234568,0.432098765432099,0.439453125,0.455054012345679,0.476152584876543,0.5,0.523847415123457,0.544945987654321,0.560546875,0.567901234567901,0.564260223765432,0.546875,0.512996720679013,0.459876543209876,0.384765625,0.28491512345679,0.157576195987653,0
38 | ALL_A 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
39 | ALL_B 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
40 |
--------------------------------------------------------------------------------
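Note (usage sketch, not part of the repository): each row of mbwpresets_master.txt holds a preset name followed by 26 comma-separated block weights. A minimal reader, assuming the name and the weight list are separated by whitespace (a tab in the source file):

    # Hypothetical parser for the preset table above; path is illustrative.
    def load_mbw_presets(path):
        presets = {}
        with open(path, encoding="utf-8") as f:
            next(f)  # skip the "preset_name preset_weights" header
            for line in f:
                line = line.strip()
                if not line:
                    continue
                name, weights = line.split(None, 1)
                presets[name] = [float(w) for w in weights.split(",")]
        return presets

    presets = load_mbw_presets("scripts/mbwpresets_master.txt")
    assert len(presets["GRAD_V"]) == 26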
/scripts/mergers/bcolors.py:
--------------------------------------------------------------------------------
1 | class bcolors:
2 | HEADER = '\033[95m'
3 | OKBLUE = '\033[94m'
4 | OKCYAN = '\033[96m'
5 | OKGREEN = '\033[92m'
6 | WARNING = '\033[93m'
7 | FAIL = '\033[91m'
8 | ENDC = '\033[0m'
9 | BOLD = '\033[1m'
10 | UNDERLINE = '\033[4m'
--------------------------------------------------------------------------------
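Note (usage sketch, not part of the repository): bcolors simply collects ANSI escape codes for coloring console output, e.g.:

    # Illustrative use of the ANSI codes defined above; messages are placeholders.
    from scripts.mergers.bcolors import bcolors

    print(f"{bcolors.OKGREEN}merge finished{bcolors.ENDC}")
    print(f"{bcolors.WARNING}ratios do not sum to 1.0{bcolors.ENDC}")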
/scripts/mergers/components.py:
--------------------------------------------------------------------------------
1 | merge = None
2 | mergeandgen = None
3 | gen = None
4 | s_reserve = None
5 | s_reserve1 = None
6 | gengrid = None
7 | s_startreserve = None
8 | rand_merge = None
9 | paramsnames = None
10 | frompromptf = None
11 |
12 | msettings = None
13 | esettings1 = None
14 | genparams = None
15 | hiresfix = None
16 | lucks = None
17 | currentmodel = None
18 | dfalse = None
19 | dtrue = None
20 | id_sets = None
21 | xysettings = None
22 |
23 | submit_result = None
24 | imagegal = None
25 | numaframe = None
26 |
27 | txt2img_params = []
28 | img2img_params = []
--------------------------------------------------------------------------------
/scripts/mergers/model_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import safetensors.torch
4 | import threading
5 | from modules import shared, sd_hijack, sd_models
6 | from modules.sd_models import load_model
7 | import json
8 |
9 | try:
10 | from modules import sd_models_xl
11 | xl = True
12 | except:
13 | xl = False
14 |
15 | def prune_model(model, isxl=False):
16 | keys = list(model.keys())
17 | base_prefix = "conditioner." if isxl else "cond_stage_model."
18 | for k in keys:
19 | if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k:
20 | model.pop(k, None)
21 | return model
22 |
23 | def to_half(sd):
24 | for key in sd.keys():
25 | if 'model' in key and sd[key].dtype in {torch.float32, torch.float64, torch.bfloat16}:
26 | sd[key] = sd[key].half()
27 | return sd
28 |
29 | def savemodel(state_dict,currentmodel,fname,savesets,metadata={}):
30 | other_dict = {}
31 | if state_dict is None:
32 | if shared.sd_model and shared.sd_model.sd_checkpoint_info:
33 | metadata = shared.sd_model.sd_checkpoint_info.metadata.copy()
34 | else:
35 | return "Current model is not a valid merged model"
36 |
37 | checkpoint_info = shared.sd_model.sd_checkpoint_info
38 | # check if current merged model is a fake checkpoint_info
39 | if checkpoint_info is not None:
40 | filename = checkpoint_info.filename
41 | name = os.path.basename(filename)
42 | info = sd_models.get_closet_checkpoint_match(name)
43 | if info == checkpoint_info:
44 | # this is a valid checkpoint_info
45 | # no need to save
46 | return "Current model is not a merged model or you've already saved model"
47 |
48 | # prepare metadata
49 | save_metadata = "save metadata" in savesets
50 | if save_metadata:
51 | metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
52 | else:
53 | metadata = {"format": "pt"}
54 |
55 | if shared.sd_model is not None:
56 | print("load from shared.sd_model..")
57 |
58 | # restore textencoder
59 | sd_hijack.model_hijack.undo_hijack(shared.sd_model)
60 |
61 | for name,module in shared.sd_model.named_modules():
62 | if hasattr(module,"network_weights_backup"):
63 | module = network_restore_weights_from_backup(module)
64 |
65 | state_dict = shared.sd_model.state_dict()
66 | for key in list(state_dict.keys()):
67 | if key in POPKEYS:
68 | other_dict[key] = state_dict[key]
69 | del state_dict[key]
70 |
71 | sd_hijack.model_hijack.hijack(shared.sd_model)
72 | else:
73 | return "No current loaded model found"
74 |
75 | # name_for_extra was set with the currentmodel
76 | currentmodel = checkpoint_info.name_for_extra
77 |
78 | if "fp16" in savesets:
79 | pre = ".fp16"
80 | else:pre = ""
81 | ext = ".safetensors" if "safetensors" in savesets else ".ckpt"
82 |
 83 |     # is it an inpainting or instruct-pix2pix model?
84 | if "model.diffusion_model.input_blocks.0.0.weight" in state_dict.keys():
85 | shape = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape
86 | if shape[1] == 9:
87 | pre += "-inpainting"
88 | if shape[1] == 8:
89 | pre += "-instruct-pix2pix"
90 |
91 | if not fname or fname == "":
92 | model_name = os.path.basename(currentmodel)
93 | model_name = model_name.replace(" ", "").replace(",", "_").replace("(", "_").replace(")", "_")
94 | fname = model_name + pre + ext
95 | if fname[0] == "_":
96 | fname = fname[1:]
97 | else:
 98 |         fname = fname if ext in fname else fname + pre + ext
99 |
100 | fname = os.path.join(shared.cmd_opts.ckpt_dir if shared.cmd_opts.ckpt_dir is not None else sd_models.model_path, fname)
101 | fname = fname.replace("ProgramFiles_x86_","Program Files (x86)")
102 |
103 | if len(fname) > 255:
104 |         fname = fname.replace(ext, "")
105 |         fname = fname[:240] + ext
106 |
107 | # check if output file already exists
108 |     if os.path.isfile(fname) and "overwrite" not in savesets:
109 |         _err_msg = f"Output file ({fname}) already exists and was not saved"
110 | print(_err_msg)
111 | return _err_msg
112 |
113 | print("Saving...")
114 | isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict
115 | if isxl:
116 |         # prune shared-memory tensors: "cond_stage_model."-prefixed base tensors share memory with "conditioner."-prefixed tensors
117 | for key in list(state_dict.keys()):
118 | if "cond_stage_model." in key:
119 | del state_dict[key]
120 |
121 | if "fp16" in savesets:
122 | state_dict = to_half(state_dict)
123 | if "prune" in savesets:
124 | state_dict = prune_model(state_dict, isxl)
125 |
126 | # for safetensors contiguous error
127 | print("Check contiguous...")
128 | for key in state_dict.keys():
129 | v = state_dict[key]
130 | v = v.contiguous()
131 | state_dict[key] = v
132 |
133 | try:
134 | if ext == ".safetensors":
135 | safetensors.torch.save_file(state_dict, fname, metadata=metadata)
136 | else:
137 | torch.save(state_dict, fname)
138 | except Exception as e:
139 | print(f"ERROR: Couldn't saved:{fname},ERROR is {e}")
140 | return f"ERROR: Couldn't saved:{fname},ERROR is {e}"
141 | print("Done!")
142 | if other_dict:
143 | for key in other_dict.keys():
144 | state_dict[key] = other_dict[key]
145 | del other_dict
146 | load_model(checkpoint_info, already_loaded_state_dict=state_dict)
147 | return "Merged model saved in "+fname
148 |
149 | def filenamecutter(name,model_a = False):
150 | if name =="" or name ==[]: return
151 | checkpoint_info = sd_models.get_closet_checkpoint_match(name)
152 | name= os.path.splitext(checkpoint_info.filename)[0]
153 |
154 | if not model_a:
155 | name = os.path.basename(name)
156 | return name
157 |
158 | from typing import Union
159 |
160 | def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
161 | weights_backup = getattr(self, "network_weights_backup", None)
162 | bias_backup = getattr(self, "network_bias_backup", None)
163 |
164 | if weights_backup is None and bias_backup is None:
165 | return self
166 |
167 | with torch.no_grad():
168 | if weights_backup is not None:
169 | if isinstance(self, torch.nn.MultiheadAttention):
170 | self.in_proj_weight = torch.nn.Parameter(weights_backup[0].detach().requires_grad_(self.in_proj_weight.requires_grad))
171 | self.out_proj.weight = torch.nn.Parameter(weights_backup[1].detach().requires_grad_(self.out_proj.weight.requires_grad))
172 | else:
173 | self.weight = torch.nn.Parameter(weights_backup.detach().requires_grad_(self.weight.requires_grad))
174 |
175 | if bias_backup is not None:
176 | if isinstance(self, torch.nn.MultiheadAttention):
177 | self.out_proj.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.out_proj.bias.requires_grad))
178 | else:
179 | self.bias = torch.nn.Parameter(bias_backup.detach().requires_grad_(self.bias.requires_grad))
180 | else:
181 | if isinstance(self, torch.nn.MultiheadAttention):
182 | self.out_proj.bias = None
183 | else:
184 | self.bias = None
185 | return self
186 |
187 | def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
188 | self.network_current_names = ()
189 | self.network_weights_backup = None
190 | self.network_bias_backup = None
191 |
192 | POPKEYS=[
193 | "betas",
194 | "alphas_cumprod",
195 | "alphas_cumprod_prev",
196 | "sqrt_alphas_cumprod",
197 | "sqrt_one_minus_alphas_cumprod",
198 | "log_one_minus_alphas_cumprod",
199 | "sqrt_recip_alphas_cumprod",
200 | "sqrt_recipm1_alphas_cumprod",
201 | "posterior_variance",
202 | "posterior_log_variance_clipped",
203 | "posterior_mean_coef1",
204 | "posterior_mean_coef2",
205 | ]
206 |
--------------------------------------------------------------------------------
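Note (usage sketch, not part of the repository): to_half() and prune_model() above operate on a plain state dict, so they can also be applied to an already-saved checkpoint. The sketch below assumes it runs inside the webui's Python environment (the module imports modules.shared at import time); file names are placeholders.

    # Hypothetical post-processing of a saved checkpoint using the helpers above.
    import safetensors.torch
    from scripts.mergers.model_util import to_half, prune_model

    sd = safetensors.torch.load_file("merged.safetensors")
    sd = prune_model(sd, isxl=False)  # keep only UNet / VAE / text-encoder keys
    sd = to_half(sd)                  # cast float32/float64/bfloat16 model weights to fp16
    safetensors.torch.save_file(sd, "merged.fp16.safetensors", metadata={"format": "pt"})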
/style.css:
--------------------------------------------------------------------------------
1 | button.reset {
2 | min-width: min(100px,100%);
3 | }
4 |
5 | button.compact_button {
6 | min-width: min(100px,100%);
7 | }
8 |
--------------------------------------------------------------------------------