├── .gitignore ├── README.md ├── README_each.ja.md ├── README_each.md ├── csv └── preset.tsv ├── javascript └── js_mbw_each.js ├── misc ├── bw01-1.png ├── bw02.png ├── bw03.png ├── bw04.png ├── bw05.png ├── bw06.png ├── bw07.png ├── bw08.png ├── bw09.png ├── bw10.png ├── each │ ├── each_00.png │ └── each_01.png ├── preset_grid │ ├── COSINE.PNG │ ├── FAKE_CUBIC_HERMITE.PNG │ ├── FAKE_REVERSE_CUBIC_HERMITE.PNG │ ├── FLAT_25.PNG │ ├── FLAT_75.PNG │ ├── GRAD_A.PNG │ ├── GRAD_V.PNG │ ├── MID12_50.PNG │ ├── OUT07.PNG │ ├── OUT12.PNG │ ├── OUT12_5.PNG │ ├── REVERSE_COSINE.PNG │ ├── REVERSE_SMOOTHSTEP.PNG │ ├── RING08_5.PNG │ ├── RING08_SOFT.PNG │ ├── RING10_3.PNG │ ├── RING10_5.PNG │ ├── R_SMOOTHSTEP_2.PNG │ ├── R_SMOOTHSTEP_3.PNG │ ├── R_SMOOTHSTEP_4.PNG │ ├── R_SMOOTHSTEPx2.PNG │ ├── R_SMOOTHSTEPx3.PNG │ ├── R_SMOOTHSTEPx4.PNG │ ├── SMOOTHSTEP.PNG │ ├── SMOOTHSTEP_2.PNG │ ├── SMOOTHSTEP_3.PNG │ ├── SMOOTHSTEP_4.PNG │ ├── SMOOTHSTEPx2.PNG │ ├── SMOOTHSTEPx3.PNG │ ├── SMOOTHSTEPx4.PNG │ ├── TRUE_CUBIC_HERMITE.PNG │ ├── TRUE_REVERSE_CUBIC_HERMITE.PNG │ ├── WRAP08.PNG │ ├── WRAP12.PNG │ ├── WRAP14.PNG │ └── WRAP16.PNG ├── xy_plus-0000-40-7_1.png └── xy_plus-0000-40-7_1_2.png ├── scripts ├── mbw │ ├── merge_block_weighted.py │ └── ui_mbw.py ├── mbw_each │ ├── merge_block_weighted_mod.py │ └── ui_mbw_each.py ├── mbw_util │ ├── merge_history.py │ └── preset_weights.py └── merge_block_weighted_extension.py └── style.css /.gitignore: -------------------------------------------------------------------------------- 1 | /csv/history.tsv 2 | /csv/preset_own.tsv 3 | 4 | # 5 | _* 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Merge Block Weighted - GUI 2 | 3 | - This is Extension for [AUTOMATIC1111's Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 4 | - Implementation GUI of [Merge Block Weighted] (https://note.com/kohya_ss/n/n9a485a066d5b) idea by kohya_ss 5 | - change some part of script to adjust for AUTO1111, basic method is not changed. 
6 | 7 | # Recent Update 8 | 9 | - 2023/01/12: Add some function 10 | 11 | - Save as half 12 | - Save as safetensors 13 | - Select of "Skip/Reset CKIP `position_ids`" 14 | - [[調査] Smile Test: Elysium_Anime_V3 問題を調べる #3|bbcmc|note](https://note.com/bbcmc/n/n12c05bf109cc) 15 | 16 | - 2022/12/25: Add new feature and new UI 17 | 18 | - Read "README" [English](README_each.md)/[日本語](README_each.ja.md) 19 | 20 | # 21 | 22 | # What is this 23 | 24 | ![](misc/bw01-1.png) 25 | 26 | ## Table of contents 27 | 28 | 29 | 30 | * [How to Install](#how-to-install) 31 | 32 | * [How to use](#how-to-use) 33 | 34 | * [Select `model_A` and `model_B`, and input `Output model name`](#select-model_a-and-model_b-and-input-output-model-name) 35 | * [Set merge ratio for each block of U-Net](#set-merge-ratio-for-each-block-of-u-net) 36 | * [Setting values](#setting-values) 37 | * [base_alpha](#base_alpha) 38 | * [Other settings](#other-settings) 39 | * [Save as half / safetensors](#save-as-half--safetensors) 40 | * [Skip/Reset CLIP `position_ids`key value](#skipreset-clip-position-ids-key-value) 41 | 42 | * [Other function](#other-function) 43 | 44 | * [Save Merge Log](#save-merge-log) 45 | 46 | * [Sample/Example](#sampleexample) 47 | 48 | * [result (x/y)](#result-xy) 49 | * [後述1: weight1](#%E5%BE%8C%E8%BF%B01-weight1) 50 | * [後述2: weight2](#%E5%BE%8C%E8%BF%B02-weight2) 51 | 52 | * [Preset's grids](#presets-grids) 53 | 54 | * [Examples of Sigmoid-like Functions](#examples-of-sigmoid-like-functions) 55 | 56 | * [Special Thanks](#special-thanks) 57 | 58 | 59 | 60 | ## How to Install 61 | 62 | - Go to `Extensions` tab on your web UI 63 | - `Install from URL` with this repo URL 64 | - Install 65 | - Restart Web UI 66 | 67 | ## How to use 68 | 69 | ### Select `model_A` and `model_B`, and input `Output model name` 70 | 71 | ![](misc/bw02.png) 72 | 73 | - if checkpoint is updated, push `Reload Checkpoint` button to reload Dropdown choises. 74 | 75 | ### Set merge ratio for each block of U-Net 76 | 77 | - Select Presets by Dropdown 78 | 79 | ![](misc/bw08.png) 80 | 81 | You can manage presets on tsv file (tab separated file) at `extention//csv/preset.tsv` 82 | ![](misc/bw06.png) 83 | 84 | - or Input at GUI Slider 85 | 86 | ![](misc/bw03.png) 87 | 88 | - "INxx" is input blocks. 12 blocks 89 | 90 | - "M00" is middle block. 1 block 91 | 92 | - "OUTxx" is output blocks. 12 blocks 93 | 94 | ![](misc/bw04.png) 95 | 96 | - You can write your weights in "Textbox" and "Apply block weight from text" 97 | 98 | - Weights must have 25 values and comma separated 99 | 100 | ## Setting values 101 | 102 | ![](misc/bw05.png) 103 | 104 | ### base_alpha 105 | 106 | - set "base_alpha" 107 | 108 | | base_alpha | | 109 | | ---------- | ----------------------------------------------------------------- | 110 | | 0 | merged model uses (Text Encoder、Auto Encoder) 100% from `model_A` | 111 | | 1 | marged model uses (Text Encoder、Auto Encoder) 100% from `model_B` | 112 | 113 | ### Other settings 114 | 115 | | Settings | | 116 | | ---------------------------- | -------------------------------------------------------------- | 117 | | verbose console output | Check true, if want to see some additional info on CLI | 118 | | Allow overwrite output-model | Check true, if allow overwrite model file which has same name. | 119 | 120 | - Merged output is saved in normal "Model" folder. 
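Putting the settings above together: the merge is a per-key linear interpolation between `model_A` and `model_B`. U-Net keys use the block weight of the block they belong to (IN00-IN11, M00, OUT00-OUT11), and the remaining keys (Text Encoder etc.) use `base_alpha`. A simplified sketch of that logic with hypothetical helper names (condensed from `scripts/mbw/merge_block_weighted.py`; position_ids handling, half-precision saving and the second copy stage are omitted):

```python
# Simplified sketch of the merge logic (hypothetical helpers, not the extension's actual code).
import re

NUM_TOTAL_BLOCKS = 25  # IN00-IN11 (12) + M00 (1) + OUT00-OUT11 (12)
RE_INP = re.compile(r"\.input_blocks\.(\d+)\.")
RE_MID = re.compile(r"\.middle_block\.(\d+)\.")
RE_OUT = re.compile(r"\.output_blocks\.(\d+)\.")

def block_index(key):
    """Return 0-24 for U-Net keys, or -1 for keys that should use base_alpha."""
    if "model.diffusion_model." not in key:
        return -1
    if "time_embed" in key:
        return 0                      # grouped with the first input block
    if ".out." in key:
        return NUM_TOTAL_BLOCKS - 1   # grouped with the last output block
    for regex, offset in ((RE_INP, 0), (RE_MID, 12), (RE_OUT, 13)):
        m = regex.search(key)
        if m:
            return offset if regex is RE_MID else offset + int(m.group(1))
    return -1

def merge_state_dicts(theta_a, theta_b, weights, base_alpha):
    """weights: 25 floats; 0.0 keeps model_A, 1.0 takes model_B. Non-U-Net keys use base_alpha."""
    merged = {}
    for key, tensor_a in theta_a.items():
        if key not in theta_b:
            merged[key] = tensor_a
            continue
        idx = block_index(key)
        alpha = weights[idx] if idx >= 0 else base_alpha
        merged[key] = (1 - alpha) * tensor_a + alpha * theta_b[key]
    return merged
```

With `weights = [0.5] * 25` and `base_alpha = 0`, this reproduces an even 50% U-Net merge that keeps `model_A`'s Text Encoder, matching the base_alpha table above.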
121 | 122 | ### Save as half / safetensors 123 | 124 | ![](misc/bw09.png) 125 | 126 | - Settings related to saving 127 | 128 | - "Save as half" means the model is saved as float16 129 | 130 | - "Save as safetensors". If the output file extension is set to `.safetensors`, the model is automatically saved in safetensors format regardless of this setting. 131 | 132 | ### Skip/Reset CLIP `position_ids` key value 133 | 134 | ![](misc/bw10.png) 135 | 136 | - This option selects how the `position_ids` value in CLIP is treated. 137 | - The values in this key control how your prompt tokens are matched to embeddings. 138 | - I tried to find the cause of the 'some models ignore the No.1 token (word)' problem, and wrote a report about it. ([[調査] Smile Test: Elysium_Anime_V3 問題を調べる #3|bbcmc|note](https://note.com/bbcmc/n/n12c05bf109cc)) 139 | - Arenatemp has already inspected the internals of such models and published an extension that fixes this CLIP key problem. See also, 140 | - [stable-diffusion-webui-model-toolkit](https://github.com/arenatemp/stable-diffusion-webui-model-toolkit) 141 | - MBW is also affected by this problem, because some models may (potentially) have this issue and carry the trouble over into the merged model. 142 | 143 | | Select | Effect | 144 | | ------ | --------------------------------------------------------------- | 145 | | None | Do nothing about the key; normal merge | 146 | | Skip | Skip the `position_ids` key so it is not merged; the value from Model A is used. | 147 | | Reset | Replace the `position_ids` values with tensor([[range(77)]]) | 148 | 149 | ## Other function 150 | 151 | ### Save Merge Log 152 | 153 | - Saves a log of each merge operation, as below, 154 | ![](misc/bw07.png) 155 | 156 | - The log is saved at `extension//csv/history.tsv` 157 | 158 | ## Sample/Example 159 | 160 | - Reproducing kohya_ss's test 161 | 162 | - Compare the results of SD15 and WD13 (Stable Diffusion 1.5 and WD 1.3) 163 | - Note: the original article used SD14 (WD13 is based on SD14) 164 | - see also [Stable DiffusionのモデルをU-Netの深さに応じて比率を変えてマージする|Kohya S.|note](https://note.com/kohya_ss/n/n9a485a066d5b) 165 | 166 | - Prepare models as below, 167 | 168 | | Model Name | | 169 | | --------------- | ----------------------------------------------------------------- | 170 | | sd-v1.5-pruned | Stable Diffusion v1.5 | 171 | | wd-v1.3-float32 | wd v1.3-float32 | 172 | | SD15-WD13-ws50 | Conventional merge:
SD15 + WD13, 0.5 # Weighted sum 0.5 | 173 | | bw-merge1-2-2 | Merge Block Weighted
SD15 and WD13. base_alpha=1
weights: weight1 (listed below) | 174 | | bw-merge2-2-2 | Merge Block Weighted
SD15 and WD13. base_alpha=0
weights: weight2 (listed below) | 175 | 176 | - Generation Info used for the test. Seeds are 1 to 4 (four runs) 177 | 178 | ``` 179 | masterpiece, best quality, beautiful anime girl, school uniform, strong rim light, intense shadows, highly detailed, cinematic lighting, taken by Canon EOS 5D Simga Art Lens 50mm f1.8 ISO 100 Shutter Speed 1000 180 | Negative prompt: lowres, bad anatomy, bad hands, error, missing fingers, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry 181 | Steps: 40, Sampler: Euler a, CFG scale: 7, Seed: 1, Face restoration: CodeFormer, Size: 512x512, Batch size: 4 182 | ``` 183 | 184 | ### result (x/y) 185 | 186 | ![](misc/xy_plus-0000-40-7_1.png) 187 | 188 | - Observed trends: 189 | 190 | - with bw-merge1, facial features become slightly more anime-like (compared to sd15-wd13-ws50) 191 | - with bw-merge2, slightly more realistic (the eyes at seed=3 are especially good) 192 | 193 | - Broadly, the results go in the same direction as kohya_ss's, so the implementation is judged to be working correctly 194 | 195 | ### 後述1: weight1 196 | 197 | ``` 198 | 1, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 199 | 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 200 | 0.0833333333, 201 | 0, 202 | 0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5, 203 | 0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667,1.0 204 | ``` 205 | 206 | ### 後述2: weight2 207 | 208 | ``` 209 | 0,0.0833333333,0.1666666667,0.25,0.3333333333,0.4166666667,0.5, 210 | 0.5833333333,0.6666666667,0.75,0.8333333333,0.9166666667, 211 | 1.0, 212 | 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 213 | 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 214 | 0.0833333333, 0 215 | ``` 216 | 217 | ## Preset's grids 218 | 219 |

220 | 221 |
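The curves behind these grid images (GRAD, SMOOTHSTEP, COSINE, ...) are simply 25-value weight lists sampled from smooth functions, which the next subsection formalizes. As a rough, hypothetical illustration (not code from this extension, and not guaranteed to match the shipped `csv/preset.tsv` values exactly):

```python
# Hypothetical generator for 25-value weight lists (illustrative only).
import math

N = 25  # IN00-IN11, M00, OUT00-OUT11

def as_preset(fn):
    # x = a / 24 for block index a = 0..24 (one reading of `x=a(S)100` with S = 100/24 below)
    return ",".join(str(round(fn(a / (N - 1)), 10)) for a in range(N))

grad_v     = as_preset(lambda x: abs(1 - 2 * x))                   # 1 -> 0 -> 1 ramp, same shape as weight1 above
smoothstep = as_preset(lambda x: 3 * x**2 - 2 * x**3)              # y = 3x^2 - 2x^3
cosine     = as_preset(lambda x: (1 - math.cos(x * math.pi)) / 2)  # cosine ease-in/out

print(smoothstep)
```

The resulting string can be pasted into the "Weight values" textbox, or saved as a `preset_name<TAB>weights` row in `csv/preset_own.tsv` so it shows up in the Preset Weights dropdown.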

222 | 223 | #### Examples of Sigmoid-like Functions 224 | 225 | ``` 226 | a∈{0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18;19;20;21;22;23;24} 227 | S=100/24` - Steps 228 | `𝝅` - number Pi 229 | `Q=2` - Ratio 230 | ``` 231 | 232 | | name | equation | 233 | | -------------------- | ----------------------------------------------------------------------------------------------- | 234 | | `Cosine` | `x=a(S)100` & `y=(1-COS((x-1)*𝝅))/2` | 235 | | `Smoothstep` | `x=a(S)100` & `y=3x^2-2x^3` | 236 | | `Smoothstep*Q` | `x=a(S)100` & ( when `x∈<0;0.5>` , `y=Q(3x^2-2x^3)` ∨ when `x∈(0.5;1>` , `y=2-Q(3x^2-2x^3` ) | 237 | | `Smoothstep\Q` | ( when `a<=12` , `x=a(S/Q)100` ∨ when `12= NUM_TOTAL_BLOCKS: 124 | print(f"error. illegal block index: {key}") 125 | return False, "" 126 | if weight_index >= 0: 127 | current_alpha = weights[weight_index] 128 | dprint(f"weighted '{key}': {current_alpha}", verbose) 129 | else: 130 | count_target_of_basealpha = count_target_of_basealpha + 1 131 | dprint(f"base_alpha applied: [{key}]", verbose) 132 | 133 | theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] 134 | 135 | if save_as_half: 136 | theta_0[key] = theta_0[key].half() 137 | 138 | else: 139 | dprint(f" key - {key}", verbose) 140 | 141 | dprint(f"-- start Stage 2/2 --", verbose) 142 | for key in tqdm(theta_1.keys(), desc="Stage 2/2"): 143 | if "model" in key and key not in theta_0: 144 | 145 | if KEY_POSITION_IDS in key: 146 | if skip_position_ids == 1: 147 | print(f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}") 148 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 149 | continue 150 | elif skip_position_ids == 2: 151 | theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) 152 | print(f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}") 153 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 154 | continue 155 | else: 156 | print(f" modelB: 'position_ids' key found. do nothing : {skip_position_ids}") 157 | 158 | dprint(f" key : {key}", verbose) 159 | theta_0.update({key:theta_1[key]}) 160 | 161 | if save_as_half: 162 | theta_0[key] = theta_0[key].half() 163 | 164 | else: 165 | dprint(f" key - {key}", verbose) 166 | 167 | print("Saving...") 168 | 169 | _, extension = os.path.splitext(output_file) 170 | if extension.lower() == ".safetensors" or save_as_safetensors: 171 | if save_as_safetensors and extension.lower() != ".safetensors": 172 | output_file = output_file + ".safetensors" 173 | import safetensors.torch 174 | safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"}) 175 | else: 176 | torch.save({"state_dict": theta_0}, output_file) 177 | 178 | print("Done!") 179 | 180 | return True, f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times." 181 | -------------------------------------------------------------------------------- /scripts/mbw/ui_mbw.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | import re 4 | 5 | from modules import sd_models, shared 6 | from tqdm import tqdm 7 | try: 8 | from modules import hashes 9 | from modules.sd_models import CheckpointInfo 10 | except: 11 | pass 12 | 13 | from scripts.mbw.merge_block_weighted import merge 14 | from scripts.mbw_util.preset_weights import PresetWeights 15 | from scripts.mbw_util.merge_history import MergeHistory 16 | 17 | presetWeights = PresetWeights() 18 | mergeHistory = MergeHistory() 19 | 20 | 21 | def on_ui_tabs(): 22 | with gr.Column(): 23 | with gr.Row(): 24 | with gr.Column(variant="panel"): 25 | html_output_block_weight_info = gr.HTML() 26 | with gr.Row(): 27 | btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") 28 | btn_clear_weight = gr.Button(value="Clear values") 29 | btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") 30 | with gr.Column(): 31 | dd_preset_weight = gr.Dropdown(label="Preset Weights", choices=presetWeights.get_preset_name_list()) 32 | txt_block_weight = gr.Text(label="Weight values", placeholder="Put weight sets. float number x 25") 33 | btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") 34 | with gr.Row(): 35 | sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0) 36 | chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) 37 | chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) 38 | with gr.Row(): 39 | with gr.Column(scale=3): 40 | with gr.Row(): 41 | chk_save_as_half = gr.Checkbox(label="Save as half", value=False) 42 | chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False) 43 | with gr.Column(scale=4): 44 | radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids", choices=["None", "Skip", "Force Reset"], value="None", type="index") 45 | with gr.Row(): 46 | model_A = gr.Dropdown(label="Model A", choices=sd_models.checkpoint_tiles()) 47 | model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles()) 48 | txt_model_O = gr.Text(label="Output Model Name") 49 | with gr.Row(): 50 | with gr.Column(): 51 | sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5) 52 | sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5) 53 | sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5) 54 | sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5) 55 | sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5) 56 | sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5) 57 | sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5) 58 | sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5) 59 | sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5) 60 | sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5) 61 | sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5) 62 | sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5) 63 | with gr.Column(): 64 | gr.Slider(visible=False) 65 | gr.Slider(visible=False) 66 | 
gr.Slider(visible=False) 67 | gr.Slider(visible=False) 68 | gr.Slider(visible=False) 69 | gr.Slider(visible=False) 70 | gr.Slider(visible=False) 71 | gr.Slider(visible=False) 72 | gr.Slider(visible=False) 73 | gr.Slider(visible=False) 74 | gr.Slider(visible=False) 75 | sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="mbw_sl_M00") 76 | with gr.Column(): 77 | sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5) 78 | sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5) 79 | sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5) 80 | sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5) 81 | sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5) 82 | sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5) 83 | sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5) 84 | sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5) 85 | sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5) 86 | sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5) 87 | sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5) 88 | sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5) 89 | 90 | sl_IN = [ 91 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 92 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11] 93 | sl_MID = [sl_M_00] 94 | sl_OUT = [ 95 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 96 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11] 97 | 98 | # Events 99 | def onclick_btn_do_merge_block_weighted( 100 | model_A, model_B, 101 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 102 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, 103 | sl_M_00, 104 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 105 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, 106 | txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite, 107 | chk_save_as_safetensors, chk_save_as_half, 108 | radio_position_ids 109 | ): 110 | 111 | # debug output 112 | print( "#### Merge Block Weighted ####") 113 | 114 | _weights = ",".join( 115 | [str(x) for x in [ 116 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 117 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, 118 | sl_M_00, 119 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 120 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11 121 | ]]) 122 | # 123 | if not model_A or not model_B: 124 | return gr.update(value=f"ERROR: model not found. 
[{model_A}][{model_B}]") 125 | 126 | # 127 | # Prepare params before run merge 128 | # 129 | 130 | # generate output file name from param 131 | model_A_info = sd_models.get_closet_checkpoint_match(model_A) 132 | if model_A_info: 133 | _model_A_name = model_A_info.model_name 134 | else: 135 | _model_A_name = "" 136 | model_B_info = sd_models.get_closet_checkpoint_match(model_B) 137 | if model_B_info: 138 | _model_B_info = model_B_info.model_name 139 | else: 140 | _model_B_info = "" 141 | 142 | def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False): 143 | output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename) 144 | filename_body, filename_ext = os.path.splitext(output_filename) 145 | _ret = output_filename 146 | _footer = "-half" if save_as_half else "" 147 | if filename_ext in [".safetensors", ".ckpt"]: 148 | _ret = f"{filename_body}{_footer}{filename_ext}" 149 | elif save_as_safetensors: 150 | _ret = f"{output_filename}{_footer}.safetensors" 151 | else: 152 | _ret = f"{output_filename}{_footer}.ckpt" 153 | return _ret 154 | 155 | model_O = f"bw-merge-{_model_A_name}-{_model_B_info}-{sl_base_alpha}.ckpt" if txt_model_O == "" else txt_model_O 156 | model_O = validate_output_filename(model_O, save_as_safetensors=chk_save_as_safetensors, save_as_half=chk_save_as_half) 157 | 158 | _output = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) 159 | 160 | if not chk_allow_overwrite: 161 | if os.path.exists(_output): 162 | _err_msg = f"ERROR: output_file already exists. overwrite not allowed. abort." 163 | print(_err_msg) 164 | return gr.update(value=f"{_err_msg} [{_output}]") 165 | print(f" model_0 : {model_A}") 166 | print(f" model_1 : {model_B}") 167 | print(f" base_alpha : {sl_base_alpha}") 168 | print(f" output_file: {_output}") 169 | print(f" weights : {_weights}") 170 | print(f" skip ids : {radio_position_ids} : 0:None, 1:Skip, 2:Reset") 171 | 172 | result, ret_message = merge(weights=_weights, model_0=model_A, model_1=model_B, allow_overwrite=chk_allow_overwrite, 173 | base_alpha=sl_base_alpha, output_file=_output, verbose=chk_verbose_mbw, 174 | save_as_safetensors=chk_save_as_safetensors, 175 | save_as_half=chk_save_as_half, 176 | skip_position_ids=radio_position_ids 177 | ) 178 | 179 | if result: 180 | ret_html = "merged.
" \ 181 | + f"{model_A}
" \ 182 | + f"{model_B}
" \ 183 | + f"{model_O}
" \ 184 | + f"base_alpha={sl_base_alpha}
" \ 185 | + f"Weight_values={_weights}
" 186 | print("merged.") 187 | else: 188 | ret_html = ret_message 189 | print("merge failed.") 190 | 191 | # save log to history.tsv 192 | sd_models.list_models() 193 | model_A_info = sd_models.get_closet_checkpoint_match(model_A) 194 | model_B_info = sd_models.get_closet_checkpoint_match(model_B) 195 | model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(_output)) 196 | if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: 197 | model_O_info:CheckpointInfo = model_O_info 198 | model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) 199 | _names = presetWeights.find_names_by_weight(_weights) 200 | if _names and len(_names) > 0: 201 | weight_name = _names[0] 202 | else: 203 | weight_name = "" 204 | 205 | def model_name(model_info): 206 | return model_info.name if hasattr(model_info, "name") else model_info.title 207 | def model_sha256(model_info): 208 | return model_info.sha256 if hasattr(model_info, "sha256") else "" 209 | mergeHistory.add_history( 210 | model_name(model_A_info), 211 | model_A_info.hash, 212 | model_sha256(model_A_info), 213 | model_name(model_B_info), 214 | model_B_info.hash, 215 | model_sha256(model_B_info), 216 | model_name(model_O_info), 217 | model_O_info.hash, 218 | model_sha256(model_O_info), 219 | sl_base_alpha, 220 | _weights, 221 | "", 222 | weight_name 223 | ) 224 | 225 | return gr.update(value=f"{ret_html}") 226 | btn_do_merge_block_weighted.click( 227 | fn=onclick_btn_do_merge_block_weighted, 228 | inputs=[model_A, model_B] 229 | + sl_IN + sl_MID + sl_OUT 230 | + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] 231 | + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], 232 | outputs=[html_output_block_weight_info] 233 | ) 234 | 235 | btn_clear_weight.click( 236 | fn=lambda: [gr.update(value=0.5) for _ in range(25)], 237 | inputs=[], 238 | outputs=[ 239 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 240 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, 241 | sl_M_00, 242 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 243 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, 244 | ] 245 | ) 246 | 247 | def on_change_dd_preset_weight(dd_preset_weight): 248 | _weights = presetWeights.find_weight_by_name(dd_preset_weight) 249 | _ret = on_btn_apply_block_weight_from_txt(_weights) 250 | return [gr.update(value=_weights)] + _ret 251 | dd_preset_weight.change( 252 | fn=on_change_dd_preset_weight, 253 | inputs=[dd_preset_weight], 254 | outputs=[txt_block_weight, 255 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 256 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, 257 | sl_M_00, 258 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 259 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, 260 | ] 261 | ) 262 | 263 | def on_btn_reload_checkpoint_mbw(): 264 | sd_models.list_models() 265 | return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] 266 | btn_reload_checkpoint_mbw.click( 267 | fn=on_btn_reload_checkpoint_mbw, 268 | inputs=[], 269 | outputs=[model_A, model_B] 270 | ) 271 | 272 | def on_btn_apply_block_weight_from_txt(txt_block_weight): 273 | if not txt_block_weight or txt_block_weight == "": 274 | return [gr.update() for _ in range(25)] 275 | _list = [x.strip() for x in txt_block_weight.split(",")] 276 | if(len(_list) != 25): 277 | return [gr.update() for _ in range(25)] 278 | return 
[gr.update(value=x) for x in _list] 279 | btn_apply_block_weithg_from_txt.click( 280 | fn=on_btn_apply_block_weight_from_txt, 281 | inputs=[txt_block_weight], 282 | outputs=[ 283 | sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05, 284 | sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11, 285 | sl_M_00, 286 | sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05, 287 | sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11, 288 | ] 289 | ) 290 | 291 | -------------------------------------------------------------------------------- /scripts/mbw_each/merge_block_weighted_mod.py: -------------------------------------------------------------------------------- 1 | # from https://note.com/kohya_ss/n/n9a485a066d5b 2 | # kohya_ss 3 | # original code: https://github.com/eyriewow/merge-models 4 | 5 | # use them as base of this code 6 | # 2022/12/15 7 | # bbc-mc 8 | 9 | import os 10 | import argparse 11 | import re 12 | import torch 13 | from tqdm import tqdm 14 | 15 | from modules import sd_models, shared 16 | 17 | 18 | NUM_INPUT_BLOCKS = 12 19 | NUM_MID_BLOCK = 1 20 | NUM_OUTPUT_BLOCKS = 12 21 | NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS 22 | 23 | KEY_POSITION_IDS = "cond_stage_model.transformer.text_model.embeddings.position_ids" 24 | 25 | def dprint(str, flg): 26 | if flg: 27 | print(str) 28 | 29 | 30 | def merge(weight_A:list, weight_B:list, model_0, model_1, device="cpu", base_alpha=0.5, 31 | output_file="", allow_overwrite=False, verbose=False, 32 | save_as_safetensors=False, 33 | save_as_half=False, 34 | skip_position_ids=0, 35 | ): 36 | 37 | def _check_arg_weight(weight): 38 | if weight is None: 39 | return None 40 | else: 41 | _weight = [float(w) for w in weight.split(",")] 42 | if len(_weight) != NUM_TOTAL_BLOCKS: 43 | return None 44 | else: 45 | return _weight 46 | 47 | weight_A = _check_arg_weight(weight_A) 48 | if weight_A is None: 49 | _err_msg = f"Weight A invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}." 50 | print(_err_msg) 51 | return False, _err_msg 52 | weight_B = _check_arg_weight(weight_B) 53 | if weight_B is None: 54 | _err_msg = f"Weight B invalid. program abort. weights value must be {NUM_TOTAL_BLOCKS}." 55 | print(_err_msg) 56 | return False, _err_msg 57 | 58 | device = device if device in ["cpu", "cuda"] else "cpu" 59 | 60 | alpha = base_alpha 61 | 62 | _footer = "-half" if save_as_half else "" 63 | _footer = f"{_footer}.safetensors" if save_as_safetensors else f"{_footer}.ckpt" 64 | if not output_file or output_file == "": 65 | output_file = f'bw-{model_0}-{model_1}-{str(alpha)[2:] + "0"}{_footer}' 66 | 67 | # check if output file already exists 68 | if os.path.isfile(output_file) and not allow_overwrite: 69 | _err_msg = f"Exiting... 
[{output_file}]" 70 | print(_err_msg) 71 | return False, _err_msg 72 | 73 | def load_model(_model, _device): 74 | model_info = sd_models.get_closet_checkpoint_match(_model) 75 | if model_info: 76 | model_file = model_info.filename 77 | else: 78 | return None 79 | cache_enabled = shared.opts.sd_checkpoint_cache > 0 80 | if cache_enabled and model_info in sd_models.checkpoints_loaded: 81 | print(" load from cache") 82 | return sd_models.checkpoints_loaded[model_info].copy() 83 | else: 84 | print(" loading ...") 85 | return sd_models.read_state_dict(model_file, map_location=_device) 86 | 87 | print("loading", model_0) 88 | theta_0 = load_model(model_0, device) 89 | 90 | print("loading", model_1) 91 | theta_1 = load_model(model_1, device) 92 | 93 | re_inp = re.compile(r'\.input_blocks\.(\d+)\.') # 12 94 | re_mid = re.compile(r'\.middle_block\.(\d+)\.') # 1 95 | re_out = re.compile(r'\.output_blocks\.(\d+)\.') # 12 96 | 97 | dprint(f"-- start Stage 1/2 --", verbose) 98 | count_target_of_basealpha = 0 99 | for key in (tqdm(theta_0.keys(), desc="Stage 1/2") if not verbose else theta_0.keys()): 100 | if "model" in key and key in theta_1: 101 | 102 | if KEY_POSITION_IDS in key: 103 | if skip_position_ids == 1: 104 | print(f" modelA: skip 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}") 105 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 106 | continue 107 | elif skip_position_ids == 2: 108 | theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) 109 | print(f" modelA: reset 'position_ids': dtype:{theta_0[KEY_POSITION_IDS].dtype}") 110 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 111 | continue 112 | else: 113 | print(f" modelA: key found. do nothing: dtype:{theta_0[KEY_POSITION_IDS].dtype}") 114 | 115 | dprint(f" key : {key}", verbose) 116 | 117 | current_alpha_A = 1 - alpha 118 | current_alpha_B = alpha 119 | current_alpha_I = 0 120 | 121 | # check weighted and U-Net or not 122 | if weight_A is not None and 'model.diffusion_model.' in key: 123 | # check block index 124 | weight_index = -1 125 | 126 | if 'time_embed' in key: 127 | weight_index = 0 # before input blocks 128 | elif '.out.' in key: 129 | weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks 130 | else: 131 | m = re_inp.search(key) 132 | if m: 133 | inp_idx = int(m.groups()[0]) 134 | weight_index = inp_idx 135 | else: 136 | m = re_mid.search(key) 137 | if m: 138 | weight_index = NUM_INPUT_BLOCKS 139 | else: 140 | m = re_out.search(key) 141 | if m: 142 | out_idx = int(m.groups()[0]) 143 | weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + out_idx 144 | 145 | if weight_index >= NUM_TOTAL_BLOCKS: 146 | print(f"error. 
illegal block index: {key}") 147 | if weight_index >= 0: 148 | current_alpha_A = weight_A[weight_index] 149 | current_alpha_B = weight_B[weight_index] 150 | current_alpha_I = 1 - current_alpha_A - current_alpha_B 151 | if verbose: 152 | print(f"weighted '{key}': A{current_alpha_A} B{current_alpha_B} I{current_alpha_I}") 153 | 154 | # create I tensor 155 | tensor_I_0 = torch.zeros_like(theta_0[key], dtype=theta_0[key].dtype) 156 | _var1 = current_alpha_I * tensor_I_0 157 | _var2 = current_alpha_A * theta_0[key] 158 | _var3 = current_alpha_B * theta_1[key] 159 | theta_0[key] = _var1 + _var2 + _var3 160 | 161 | # theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key] 162 | if save_as_half: 163 | theta_0[key] = theta_0[key].half() 164 | else: 165 | dprint(f" key - {key}", verbose) 166 | 167 | dprint(f"-- start Stage 2/2 --", verbose) 168 | for key in tqdm(theta_1.keys(), desc="Stage 2/2"): 169 | if "model" in key and key not in theta_0: 170 | 171 | if KEY_POSITION_IDS in key: 172 | if skip_position_ids == 1: 173 | print(f" modelB: skip 'position_ids' : {theta_0[KEY_POSITION_IDS].dtype}") 174 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 175 | continue 176 | elif skip_position_ids == 2: 177 | theta_0[key] = torch.tensor([list(range(77))], dtype=torch.int64) 178 | print(f" modelB: reset 'position_ids': {theta_0[KEY_POSITION_IDS].dtype}") 179 | dprint(f"{theta_0[KEY_POSITION_IDS]}", verbose) 180 | continue 181 | else: 182 | print(f" modelB: key found. do nothing : {skip_position_ids}") 183 | 184 | dprint(f" key : {key}", verbose) 185 | theta_0.update({key:theta_1[key]}) 186 | 187 | if save_as_half: 188 | theta_0[key] = theta_0[key].half() 189 | 190 | else: 191 | dprint(f" key - {key}", verbose) 192 | 193 | print("Saving...") 194 | 195 | _, extension = os.path.splitext(output_file) 196 | if extension.lower() == ".safetensors" or save_as_safetensors: 197 | if save_as_safetensors and extension.lower() != ".safetensors": 198 | output_file = output_file + ".safetensors" 199 | import safetensors.torch 200 | safetensors.torch.save_file(theta_0, output_file, metadata={"format": "pt"}) 201 | else: 202 | torch.save({"state_dict": theta_0}, output_file) 203 | 204 | print("Done!") 205 | 206 | return True, f"{output_file}
base_alpha applied [{count_target_of_basealpha}] times." 207 | -------------------------------------------------------------------------------- /scripts/mbw_each/ui_mbw_each.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | import re 4 | 5 | from modules import sd_models, shared 6 | from tqdm import tqdm 7 | try: 8 | from modules import hashes 9 | from modules.sd_models import CheckpointInfo 10 | except: 11 | pass 12 | 13 | from scripts.mbw_each.merge_block_weighted_mod import merge 14 | from scripts.mbw_util.preset_weights import PresetWeights 15 | from scripts.mbw_util.merge_history import MergeHistory 16 | 17 | presetWeights = PresetWeights() 18 | mergeHistory = MergeHistory() 19 | 20 | 21 | def on_ui_tabs(): 22 | with gr.Column(): 23 | with gr.Row(): 24 | with gr.Column(variant="panel"): 25 | with gr.Row(): 26 | txt_multi_process_cmd = gr.TextArea(label="Multi Proc Cmd", placeholder="Keep empty if dont use.") 27 | html_output_block_weight_info = gr.HTML() 28 | with gr.Row(): 29 | btn_do_merge_block_weighted = gr.Button(value="Run Merge", variant="primary") 30 | btn_clear_weighted = gr.Button(value="Clear values") 31 | btn_reload_checkpoint_mbw = gr.Button(value="Reload checkpoint") 32 | with gr.Column(): 33 | dd_preset_weight = gr.Dropdown(label="Preset_Weights", choices=presetWeights.get_preset_name_list()) 34 | txt_block_weight = gr.Text(label="Weight_values", placeholder="Put weight sets. float number x 25") 35 | btn_apply_block_weithg_from_txt = gr.Button(value="Apply block weight from text", variant="primary") 36 | with gr.Row(): 37 | sl_base_alpha = gr.Slider(label="base_alpha", minimum=0, maximum=1, step=0.01, value=0) 38 | chk_verbose_mbw = gr.Checkbox(label="verbose console output", value=False) 39 | chk_allow_overwrite = gr.Checkbox(label="Allow overwrite output-model", value=False) 40 | with gr.Row(): 41 | with gr.Column(scale=3): 42 | with gr.Row(): 43 | chk_save_as_half = gr.Checkbox(label="Save as half", value=False) 44 | chk_save_as_safetensors = gr.Checkbox(label="Save as safetensors", value=False) 45 | with gr.Column(scale=4): 46 | radio_position_ids = gr.Radio(label="Skip/Reset CLIP position_ids", choices=["None", "Skip", "Force Reset"], value="None", type="index") 47 | with gr.Row(): 48 | dd_model_A = gr.Dropdown(label="Model_A", choices=sd_models.checkpoint_tiles()) 49 | dd_model_B = gr.Dropdown(label="Model_B", choices=sd_models.checkpoint_tiles()) 50 | txt_model_O = gr.Text(label="(O)Output Model Name") 51 | with gr.Row(): 52 | with gr.Column(): 53 | sl_IN_A_00 = gr.Slider(label="IN_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_00") 54 | sl_IN_A_01 = gr.Slider(label="IN_A_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_01") 55 | sl_IN_A_02 = gr.Slider(label="IN_A_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_02") 56 | sl_IN_A_03 = gr.Slider(label="IN_A_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_03") 57 | sl_IN_A_04 = gr.Slider(label="IN_A_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_04") 58 | sl_IN_A_05 = gr.Slider(label="IN_A_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_05") 59 | sl_IN_A_06 = gr.Slider(label="IN_A_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_06") 60 | sl_IN_A_07 = gr.Slider(label="IN_A_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_07") 61 | sl_IN_A_08 = gr.Slider(label="IN_A_08", minimum=0, maximum=1, 
step=0.01, value=0.5, elem_id="sl_IN_A_08") 62 | sl_IN_A_09 = gr.Slider(label="IN_A_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_09") 63 | sl_IN_A_10 = gr.Slider(label="IN_A_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_10") 64 | sl_IN_A_11 = gr.Slider(label="IN_A_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_A_11") 65 | with gr.Column(): 66 | sl_IN_B_00 = gr.Slider(label="IN_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_00") 67 | sl_IN_B_01 = gr.Slider(label="IN_B_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_01") 68 | sl_IN_B_02 = gr.Slider(label="IN_B_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_02") 69 | sl_IN_B_03 = gr.Slider(label="IN_B_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_03") 70 | sl_IN_B_04 = gr.Slider(label="IN_B_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_04") 71 | sl_IN_B_05 = gr.Slider(label="IN_B_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_05") 72 | sl_IN_B_06 = gr.Slider(label="IN_B_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_06") 73 | sl_IN_B_07 = gr.Slider(label="IN_B_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_07") 74 | sl_IN_B_08 = gr.Slider(label="IN_B_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_08") 75 | sl_IN_B_09 = gr.Slider(label="IN_B_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_09") 76 | sl_IN_B_10 = gr.Slider(label="IN_B_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_10") 77 | sl_IN_B_11 = gr.Slider(label="IN_B_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_IN_B_11") 78 | with gr.Column(): 79 | gr.Slider(visible=False) 80 | gr.Slider(visible=False) 81 | gr.Slider(visible=False) 82 | gr.Slider(visible=False) 83 | gr.Slider(visible=False) 84 | gr.Slider(visible=False) 85 | gr.Slider(visible=False) 86 | gr.Slider(visible=False) 87 | gr.Slider(visible=False) 88 | gr.Slider(visible=False) 89 | gr.Slider(visible=False) 90 | sl_M_A_00 = gr.Slider(label="M_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_M_A_00") 91 | with gr.Column(): 92 | gr.Slider(visible=False) 93 | gr.Slider(visible=False) 94 | gr.Slider(visible=False) 95 | gr.Slider(visible=False) 96 | gr.Slider(visible=False) 97 | gr.Slider(visible=False) 98 | gr.Slider(visible=False) 99 | gr.Slider(visible=False) 100 | gr.Slider(visible=False) 101 | gr.Slider(visible=False) 102 | gr.Slider(visible=False) 103 | sl_M_B_00 = gr.Slider(label="M_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_M_B_00") 104 | with gr.Column(): 105 | sl_OUT_A_11 = gr.Slider(label="OUT_A_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_11") 106 | sl_OUT_A_10 = gr.Slider(label="OUT_A_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_10") 107 | sl_OUT_A_09 = gr.Slider(label="OUT_A_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_09") 108 | sl_OUT_A_08 = gr.Slider(label="OUT_A_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_08") 109 | sl_OUT_A_07 = gr.Slider(label="OUT_A_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_07") 110 | sl_OUT_A_06 = gr.Slider(label="OUT_A_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_06") 111 | sl_OUT_A_05 = gr.Slider(label="OUT_A_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_05") 112 | sl_OUT_A_04 = gr.Slider(label="OUT_A_04", minimum=0, 
maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_04") 113 | sl_OUT_A_03 = gr.Slider(label="OUT_A_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_03") 114 | sl_OUT_A_02 = gr.Slider(label="OUT_A_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_02") 115 | sl_OUT_A_01 = gr.Slider(label="OUT_A_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_01") 116 | sl_OUT_A_00 = gr.Slider(label="OUT_A_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_A_00") 117 | with gr.Column(): 118 | sl_OUT_B_11 = gr.Slider(label="OUT_B_11", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_11") 119 | sl_OUT_B_10 = gr.Slider(label="OUT_B_10", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_10") 120 | sl_OUT_B_09 = gr.Slider(label="OUT_B_09", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_09") 121 | sl_OUT_B_08 = gr.Slider(label="OUT_B_08", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_08") 122 | sl_OUT_B_07 = gr.Slider(label="OUT_B_07", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_07") 123 | sl_OUT_B_06 = gr.Slider(label="OUT_B_06", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_06") 124 | sl_OUT_B_05 = gr.Slider(label="OUT_B_05", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_05") 125 | sl_OUT_B_04 = gr.Slider(label="OUT_B_04", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_04") 126 | sl_OUT_B_03 = gr.Slider(label="OUT_B_03", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_03") 127 | sl_OUT_B_02 = gr.Slider(label="OUT_B_02", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_02") 128 | sl_OUT_B_01 = gr.Slider(label="OUT_B_01", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_01") 129 | sl_OUT_B_00 = gr.Slider(label="OUT_B_00", minimum=0, maximum=1, step=0.01, value=0.5, elem_id="sl_OUT_B_00") 130 | 131 | # Footer 132 | gr.HTML( 133 | """ 134 |

135 | Merge Block Weighted extension by bbc_mc
136 | MBW Each is an experimental feature with NO PROOF of effectiveness.
137 | You can try it on your own, to dig deeper into the Abyss ...
138 |

139 | """ 140 | ) 141 | 142 | sl_A_IN = [ 143 | sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, 144 | sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11] 145 | sl_A_MID = [sl_M_A_00] 146 | sl_A_OUT = [ 147 | sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, 148 | sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11] 149 | 150 | sl_B_IN = [ 151 | sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, 152 | sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11] 153 | sl_B_MID = [sl_M_B_00] 154 | sl_B_OUT = [ 155 | sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, 156 | sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11] 157 | 158 | 159 | # Events 160 | def onclick_btn_do_merge_block_weighted( 161 | dd_model_A, dd_model_B, txt_multi_process_cmd, 162 | sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, 163 | sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, 164 | sl_M_A_00, 165 | sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, 166 | sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, 167 | sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, 168 | sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, 169 | sl_M_B_00, 170 | sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, 171 | sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, 172 | txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite, 173 | chk_save_as_safetensors, chk_save_as_half, 174 | radio_position_ids 175 | ): 176 | base_alpha = sl_base_alpha 177 | _weight_A = ",".join( 178 | [str(x) for x in [ 179 | sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, 180 | sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, 181 | sl_M_A_00, 182 | sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, 183 | sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, 184 | ]]) 185 | _weight_B = ",".join( 186 | [str(x) for x in [ 187 | sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, 188 | sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, 189 | sl_M_B_00, 190 | sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, 191 | sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, 192 | ]]) 193 | 194 | # debug output 195 | print( "#### Merge Block Weighted : Each ####") 196 | 197 | if (not dd_model_A or not dd_model_B) and txt_multi_process_cmd == "": 198 | _err_msg = f"ERROR: model not found. [{dd_model_A}][{dd_model_B}]" 199 | print(_err_msg) 200 | return gr.update(value=_err_msg) 201 | 202 | ret_html = "" 203 | if txt_multi_process_cmd != "": 204 | # need multi-merge 205 | _lines = txt_multi_process_cmd.split('\n') 206 | print(f"check multi-merge. 
{len(_lines)} lines found.") 207 | for line_index, _line in enumerate(_lines): 208 | if _line == "": 209 | continue 210 | print(f"\n== merge line {line_index+1}/{len(_lines)} ==") 211 | _items = [x.strip() for x in _line.split(",") if x != ""] 212 | if len(_items) > 0: 213 | ret_html += _run_merge( 214 | weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, 215 | allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, 216 | verbose=chk_verbose_mbw, 217 | params=_items, 218 | save_as_safetensors=chk_save_as_safetensors, 219 | save_as_half=chk_save_as_half, 220 | skip_position_ids=radio_position_ids 221 | ) 222 | else: 223 | _ret = f" multi-merge text found, but invalid params. skipped :[{_line}]" 224 | ret_html += _ret 225 | print(_ret) 226 | else: 227 | # normal merge 228 | ret_html += _run_merge( 229 | weight_A=_weight_A, weight_B=_weight_B, model_0=dd_model_A, model_1=dd_model_B, 230 | allow_overwrite=chk_allow_overwrite, base_alpha=base_alpha, model_Output=txt_model_O, 231 | verbose=chk_verbose_mbw, 232 | save_as_safetensors=chk_save_as_safetensors, 233 | save_as_half=chk_save_as_half, 234 | skip_position_ids=radio_position_ids 235 | ) 236 | 237 | sd_models.list_models() 238 | print( "#### All merge process done. ####") 239 | 240 | return gr.update(value=f"{ret_html}") 241 | btn_do_merge_block_weighted.click( 242 | fn=onclick_btn_do_merge_block_weighted, 243 | inputs=[dd_model_A, dd_model_B, txt_multi_process_cmd] 244 | + sl_A_IN + sl_A_MID + sl_A_OUT + sl_B_IN + sl_B_MID + sl_B_OUT 245 | + [txt_model_O, sl_base_alpha, chk_verbose_mbw, chk_allow_overwrite] 246 | + [chk_save_as_safetensors, chk_save_as_half, radio_position_ids], 247 | outputs=[html_output_block_weight_info] 248 | ) 249 | 250 | def _run_merge(weight_A, weight_B, model_0, model_1, allow_overwrite=False, base_alpha=0, 251 | model_Output="", verbose=False, params=[], 252 | save_as_safetensors=False, 253 | save_as_half=False, 254 | skip_position_ids=0, 255 | ): 256 | 257 | def validate_output_filename(output_filename, save_as_safetensors=False, save_as_half=False): 258 | output_filename = re.sub(r'[\\|:|?|"|<|>|\|\*]', '-', output_filename) 259 | filename_body, filename_ext = os.path.splitext(output_filename) 260 | _ret = output_filename 261 | _footer = "-half" if save_as_half else "" 262 | if filename_ext in [".safetensors", ".ckpt"]: 263 | _ret = f"{filename_body}{_footer}{filename_ext}" 264 | elif save_as_safetensors: 265 | _ret = f"{output_filename}{_footer}.safetensors" 266 | else: 267 | _ret = f"{output_filename}{_footer}.ckpt" 268 | return _ret 269 | 270 | model_O = "" 271 | 272 | if params and len(params) > 0: 273 | for _item in params: 274 | # expect "O=merge/test02, IN_B_00 = 0.12345" as params=["O=merge/test02", "IN_B_00 = 0.12345"] 275 | if len(_item.split("=")) == 2: 276 | _item_l = _item.split("=")[0].strip() 277 | _item_r = _item.split("=")[1].strip() 278 | if _item_r != "": 279 | if _item_l.lower() == "model_a" or _item_l.lower() == "model_b": 280 | _model_info = sd_models.get_closet_checkpoint_match(_item_r) 281 | if _model_info: 282 | _model_name = _model_info.title.split(" ")[0] 283 | if _model_name and _model_name.strip() != "": 284 | if _item_l.lower() == "model_a": 285 | print(f" * Model changed: {model_0} -> {_model_info.title}") 286 | model_0 = _model_info.title 287 | elif _item_l.lower() == "model_b": 288 | print(f" * Model changed: {model_1} -> {_model_info.title}") 289 | model_1 = _model_info.title 290 | 291 | elif _item_l.lower() == 
"preset_weights": 292 | _weights = presetWeights.find_weight_by_name(_item_r) 293 | if _weights != "" and len(_weights.split(',')) == 25: 294 | print(f" * Weights changed by preset-name: {_item_r}") 295 | weight_B = _weights 296 | weight_A = ",".join([str(1-float(x)) for x in _weights.split(',')]) 297 | else: 298 | print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") 299 | 300 | elif _item_l.lower() == "weight_values": 301 | _weights = _item_r.strip() 302 | if _weights != "" and len(_weights.split(' ')) == 25: # this is work-around to use space as separator. Double-meaning issue on commna which already used as value separator and weights separator. 303 | print(f" * Weights changed: {_item_r}") 304 | weight_B = _weights 305 | weight_A = ",".join([str(1-float(x)) for x in _weights.split(' ')]) 306 | else: 307 | print(f" * Weights change :canceled: [{_item_r}][{_weights}][{len(_weights.split(','))}]") 308 | 309 | elif _item_l.lower() == "base_alpha": 310 | if float(_item_r) >= 0: 311 | print(f" * base_alpha changed: {base_alpha} -> {_item_r}") 312 | base_alpha = float(_item_r) 313 | 314 | elif _item_l.upper() == "O": 315 | if _item_r.strip() != "": 316 | _ret = validate_output_filename(_item_r.strip(), save_as_safetensors=save_as_safetensors, save_as_half=save_as_half) 317 | print(f" * Output filename changed:[{model_O}] -> [{_ret}]") 318 | model_O = _ret 319 | 320 | elif len(_item_l.split("_")) == 3: 321 | _IMO = _item_l.split("_")[0] 322 | _AB = _item_l.split("_")[1] 323 | _NUM = _item_l.split("_")[2] 324 | 325 | _index = int(_NUM) 326 | _index = _index + 0 if _IMO == "IN" else _index 327 | _index = _index + 12 if _IMO == "M" else _index 328 | _index = _index + 13 if _IMO == "OUT" else _index 329 | 330 | def _apply_val(key, weight, index, new_value): 331 | _weight = [x.strip() for x in weight.split(",")] 332 | _new_weight = _weight[:] 333 | _new_weight[index] = new_value 334 | _new_weight = ",".join(_new_weight) 335 | print(f" * weight_{key} changed:[{weight}]") 336 | print(f" -> [{_new_weight}]") 337 | return _new_weight 338 | 339 | if _AB == "A": 340 | weight_A = _apply_val(_AB, weight_A, _index, _item_r) 341 | elif _AB == "B": 342 | weight_B = _apply_val(_AB, weight_B, _index, _item_r) 343 | else: 344 | print(f" * Waring: uncaught param found. ignored. [{_item_l}][{_item_r}]") 345 | 346 | # 347 | # Prepare params before run merge 348 | # 349 | 350 | # generate output file name from param 351 | model_A_info = sd_models.get_closet_checkpoint_match(model_0) 352 | _model_A_name = "" if not model_A_info else model_A_info.filename 353 | 354 | model_B_info = sd_models.get_closet_checkpoint_match(model_1) 355 | _model_B_name = "" if not model_B_info else model_B_info.filename 356 | 357 | if model_O == "": 358 | _a = os.path.splitext(os.path.basename(_model_A_name))[0] 359 | _b = os.path.splitext(os.path.basename(_model_B_name))[0] 360 | model_O = f"bw-merge-{_a}-{_b}-{base_alpha}" if model_Output == "" else model_Output 361 | model_O = validate_output_filename(model_O, save_as_safetensors=save_as_safetensors, save_as_half=save_as_half) 362 | output_file = os.path.join(shared.cmd_opts.ckpt_dir or sd_models.model_path, model_O) 363 | # 364 | # Check params 365 | # 366 | if not os.path.exists(os.path.dirname(output_file)): 367 | _err_msg = f"WARNING: target path not found: {os.path.dirname(output_file)}. skipped." 368 | print(_err_msg) 369 | return _err_msg + "
" 370 | if not allow_overwrite: 371 | if os.path.exists(output_file): 372 | _err_msg = f"WARNING: output_file already exists. overwrite not allowed. skipped." 373 | print(_err_msg) 374 | return _err_msg + "
" 375 | 376 | # debug output 377 | print(f" model_0 : {model_0}") 378 | print(f" model_1 : {model_1}") 379 | print(f" model_Out : {model_O}") 380 | print(f" base_alpha : {base_alpha}") 381 | print(f" output_file: {output_file}") 382 | print(f" weight_A : {weight_A}") 383 | print(f" weight_B : {weight_B}") 384 | print(f" half : {save_as_half}") 385 | print(f" skip ids : {skip_position_ids} : 0:None, 1:Skip, 2:Reset") 386 | 387 | result, ret_message = merge( 388 | weight_A=weight_A, weight_B=weight_B, model_0=model_0, model_1=model_1, 389 | allow_overwrite=allow_overwrite, base_alpha=base_alpha, output_file=output_file, 390 | verbose=verbose, 391 | save_as_safetensors=save_as_safetensors, 392 | save_as_half=save_as_half, 393 | skip_position_ids=skip_position_ids, 394 | ) 395 | if result: 396 | ret_html = f"merged. {model_0} + {model_1} = {model_O}
" 397 | print("merged.") 398 | else: 399 | ret_html = ret_message 400 | print("merge failed.") 401 | 402 | 403 | # save log to history.tsv 404 | sd_models.list_models() 405 | model_A_info = sd_models.get_closet_checkpoint_match(model_0) 406 | model_B_info = sd_models.get_closet_checkpoint_match(model_1) 407 | model_O_info = sd_models.get_closet_checkpoint_match(os.path.basename(output_file)) 408 | if hasattr(model_O_info, "sha256") and model_O_info.sha256 is None: 409 | model_O_info:CheckpointInfo = model_O_info 410 | model_O_info.sha256 = hashes.sha256(model_O_info.filename, "checkpoint/" + model_O_info.title) 411 | _names = presetWeights.find_names_by_weight(weight_B) 412 | if _names and len(_names) > 0: 413 | weight_name = _names[0] 414 | else: 415 | weight_name = "" 416 | 417 | def model_name(model_info): 418 | return model_info.name if hasattr(model_info, "name") else model_info.title 419 | def model_sha256(model_info): 420 | return model_info.sha256 if hasattr(model_info, "sha256") else "" 421 | mergeHistory.add_history( 422 | model_name(model_A_info), 423 | model_A_info.hash, 424 | model_sha256(model_A_info), 425 | model_name(model_B_info), 426 | model_B_info.hash, 427 | model_sha256(model_B_info), 428 | model_name(model_O_info), 429 | model_O_info.hash, 430 | model_sha256(model_O_info), 431 | base_alpha, 432 | weight_A, 433 | weight_B, 434 | weight_name 435 | ) 436 | 437 | return ret_html 438 | 439 | btn_clear_weighted.click( 440 | fn=lambda: [gr.update(value=0.5) for _ in range(25*2)], 441 | inputs=[], 442 | outputs=[ 443 | sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, 444 | sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, 445 | sl_M_A_00, 446 | sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, 447 | sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, 448 | sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, 449 | sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, 450 | sl_M_B_00, 451 | sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, 452 | sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, 453 | ] 454 | ) 455 | 456 | def on_change_dd_preset_weight(dd_preset_weight): 457 | _weights = presetWeights.find_weight_by_name(dd_preset_weight) 458 | _ret = on_btn_apply_block_weight_from_txt(_weights) 459 | return [gr.update(value=_weights)] + _ret 460 | dd_preset_weight.change( 461 | fn=on_change_dd_preset_weight, 462 | inputs=[dd_preset_weight], 463 | outputs=[ 464 | txt_block_weight, 465 | sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05, 466 | sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11, 467 | sl_M_A_00, 468 | sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05, 469 | sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11, 470 | sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05, 471 | sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11, 472 | sl_M_B_00, 473 | sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05, 474 | sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11, 475 | ] 476 | ) 477 | 478 | def on_btn_reload_checkpoint_mbw(): 479 | sd_models.list_models() 480 | return [gr.update(choices=sd_models.checkpoint_tiles()), gr.update(choices=sd_models.checkpoint_tiles())] 481 | 
    btn_reload_checkpoint_mbw.click(
482 |         fn=on_btn_reload_checkpoint_mbw,
483 |         inputs=[],
484 |         outputs=[dd_model_A, dd_model_B]
485 |     )
486 |
487 |     def on_btn_apply_block_weight_from_txt(txt_block_weight):
488 |         if not txt_block_weight or txt_block_weight == "":
489 |             return [gr.update() for _ in range(25*2)]
490 |         _list = [x.strip() for x in txt_block_weight.split(",")]
491 |         if(len(_list) != 25):
492 |             return [gr.update() for _ in range(25*2)]
493 |         return [gr.update(value=str(1-float(x))) for x in _list] + [gr.update(value=x) for x in _list]
494 |     btn_apply_block_weithg_from_txt.click(
495 |         fn=on_btn_apply_block_weight_from_txt,
496 |         inputs=[txt_block_weight],
497 |         outputs=[
498 |             sl_IN_A_00, sl_IN_A_01, sl_IN_A_02, sl_IN_A_03, sl_IN_A_04, sl_IN_A_05,
499 |             sl_IN_A_06, sl_IN_A_07, sl_IN_A_08, sl_IN_A_09, sl_IN_A_10, sl_IN_A_11,
500 |             sl_M_A_00,
501 |             sl_OUT_A_00, sl_OUT_A_01, sl_OUT_A_02, sl_OUT_A_03, sl_OUT_A_04, sl_OUT_A_05,
502 |             sl_OUT_A_06, sl_OUT_A_07, sl_OUT_A_08, sl_OUT_A_09, sl_OUT_A_10, sl_OUT_A_11,
503 |             sl_IN_B_00, sl_IN_B_01, sl_IN_B_02, sl_IN_B_03, sl_IN_B_04, sl_IN_B_05,
504 |             sl_IN_B_06, sl_IN_B_07, sl_IN_B_08, sl_IN_B_09, sl_IN_B_10, sl_IN_B_11,
505 |             sl_M_B_00,
506 |             sl_OUT_B_00, sl_OUT_B_01, sl_OUT_B_02, sl_OUT_B_03, sl_OUT_B_04, sl_OUT_B_05,
507 |             sl_OUT_B_06, sl_OUT_B_07, sl_OUT_B_08, sl_OUT_B_09, sl_OUT_B_10, sl_OUT_B_11,
508 |         ]
509 |     )
510 |
--------------------------------------------------------------------------------
/scripts/mbw_util/merge_history.py:
--------------------------------------------------------------------------------
1 | #
2 | #
3 | #
4 | import os
5 | import datetime
6 | from csv import DictWriter, DictReader
7 | import shutil
8 |
9 | from modules import scripts
10 |
11 |
12 | CSV_FILE_ROOT = "csv/"
13 | CSV_FILE_PATH = "csv/history.tsv"
14 | HEADERS = [
15 |     "model_A", "model_A_hash", "model_A_sha256",
16 |     "model_B", "model_B_hash", "model_B_sha256",
17 |     "model_O", "model_O_hash", "model_O_sha256",
18 |     "base_alpha", "weight_name", "weight_values", "weight_values2", "datetime"]
19 | path_root = scripts.basedir()
20 |
21 |
22 | class MergeHistory():
23 |     def __init__(self):
24 |         self.fileroot = os.path.join(path_root, CSV_FILE_ROOT)
25 |         self.filepath = os.path.join(path_root, CSV_FILE_PATH)
26 |         if not os.path.exists(self.fileroot):
27 |             os.mkdir(self.fileroot)
28 |         if os.path.exists(self.filepath):
29 |             self.update_header()
30 |
31 |     def add_history(self,
32 |                     model_A_name, model_A_hash, model_A_sha256,
33 |                     model_B_name, model_B_hash, model_B_sha256,
34 |                     model_O_name, model_O_hash, model_O_sha256,
35 |                     sl_base_alpha,
36 |                     weight_value_A,
37 |                     weight_value_B,
38 |                     weight_name=""):
39 |         _history_dict = {}
40 |         _history_dict.update({
41 |             "model_A": model_A_name,
42 |             "model_A_hash": model_A_hash,
43 |             "model_A_sha256": model_A_sha256,
44 |             "model_B": model_B_name,
45 |             "model_B_hash": model_B_hash,
46 |             "model_B_sha256": model_B_sha256,
47 |             "model_O": model_O_name,
48 |             "model_O_hash": model_O_hash,
49 |             "model_O_sha256": model_O_sha256,
50 |             "base_alpha": sl_base_alpha,
51 |             "weight_name": weight_name,
52 |             "weight_values": weight_value_A,
53 |             "weight_values2": weight_value_B,
54 |             "datetime": f"{datetime.datetime.now()}"
55 |         })
56 |
57 |         if not os.path.exists(self.filepath):
58 |             with open(self.filepath, "w", newline="", encoding="utf-8") as f:
59 |                 dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
60 |                 dw.writeheader()
61 |         # save to file
62 |         with open(self.filepath, "a", newline="", encoding='utf-8') as f:
63 |             dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
64 |             dw.writerow(_history_dict)
65 |
66 |     def update_header(self):
67 |         hist_data = []
68 |         if os.path.exists(self.filepath):
69 |             # check header in case HEADERS updated
70 |             with open(self.filepath, "r", newline="", encoding="utf-8") as f:
71 |                 dr = DictReader(f, delimiter='\t')
72 |                 new_header = [ x for x in HEADERS if x not in dr.fieldnames ]
73 |                 if len(new_header) > 0:
74 |                     # need update.
75 |                     hist_data = [ x for x in dr]
76 |         # apply change
77 |         if len(hist_data) > 0:
78 |             # backup before change
79 |             shutil.copy(self.filepath, self.filepath + ".bak")
80 |             with open(self.filepath, "w", newline="", encoding="utf-8") as f:
81 |                 dw = DictWriter(f, fieldnames=HEADERS, delimiter='\t')
82 |                 dw.writeheader()
83 |                 dw.writerows(hist_data)
84 |
--------------------------------------------------------------------------------
/scripts/mbw_util/preset_weights.py:
--------------------------------------------------------------------------------
1 | #
2 | #
3 | #
4 | import os
5 | from csv import DictReader
6 |
7 | from modules import scripts
8 |
9 |
10 | CSV_FILE_PATH = "csv/preset.tsv"
11 | MYPRESET_PATH = "csv/preset_own.tsv"
12 | HEADER = ["preset_name", "preset_weights"]
13 | path_root = scripts.basedir()
14 |
15 |
16 | class PresetWeights():
17 |     def __init__(self):
18 |         self.presets = {}
19 |
20 |         if os.path.exists(os.path.join(path_root, MYPRESET_PATH)):
21 |             with open(os.path.join(path_root, MYPRESET_PATH), "r") as f:
22 |                 reader = DictReader(f, delimiter="\t")
23 |                 lines_dict = [row for row in reader]
24 |                 for line_dict in lines_dict:
25 |                     _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
26 |                     self.presets.update({line_dict["preset_name"]: _w})
27 |
28 |         with open(os.path.join(path_root, CSV_FILE_PATH), "r") as f:
29 |             reader = DictReader(f, delimiter="\t")
30 |             lines_dict = [row for row in reader]
31 |             for line_dict in lines_dict:
32 |                 _w = ",".join([f"{x.strip()}" for x in line_dict["preset_weights"].split(",")])
33 |                 self.presets.update({line_dict["preset_name"]: _w})
34 |
35 |     def get_preset_name_list(self):
36 |         return [k for k in self.presets.keys()]
37 |
38 |     def find_weight_by_name(self, preset_name=""):
39 |         if preset_name and preset_name != "" and preset_name in self.presets.keys():
40 |             return self.presets.get(preset_name, ",".join(["0.5" for _ in range(25)]))
41 |         else:
42 |             return ""
43 |
44 |     def find_names_by_weight(self, weights=""):
45 |         if weights and weights != "":
46 |             if weights in self.presets.values():
47 |                 return [k for k, v in self.presets.items() if v == weights]
48 |             else:
49 |                 _val = ",".join([f"{x.strip()}" for x in weights.split(",")])
50 |                 if _val in self.presets.values():
51 |                     return [k for k, v in self.presets.items() if v == _val]
52 |                 else:
53 |                     return []
54 |         else:
55 |             return []
56 |
--------------------------------------------------------------------------------
/scripts/merge_block_weighted_extension.py:
--------------------------------------------------------------------------------
1 | # Merge block weighted Board
2 | #
3 | # extension of AUTOMATIC1111 web ui
4 | #
5 | # 2022/12/14 bbc_mc
6 | #
7 |
8 | import os
9 | import gradio as gr
10 |
11 | from modules import script_callbacks
12 |
13 |
14 | from scripts.mbw import ui_mbw
15 | from scripts.mbw_each import ui_mbw_each
16 |
17 |
18 | #
19 | # UI callback
20 | #
21 | def on_ui_tabs():
22 |
23 |     with gr.Blocks() as main_block:
24 |         with gr.Tab("MBW", elem_id="tab_mbw"):
25 |             ui_mbw.on_ui_tabs()
26 |
27 |         with gr.Tab("MBW Each", elem_id="tab_mbw_each"):
28 |             ui_mbw_each.on_ui_tabs()
29 |
30 |     # return required as (gradio_component, title, elem_id)
31 |     return (main_block, "Merge Block Weighted", "merge_block_weighted"),
32 |
33 | # on_UI
34 | script_callbacks.on_ui_tabs(on_ui_tabs)
35 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | #mbw_sl_M00, #mbw_sl_a_M00, #mbw_sl_b_M00 {
2 |     bottom:0;
3 |     position:absolute;
4 | }
5 |
6 | /*
7 | #sl_IN_A_00, #sl_IN_A_01, #sl_IN_A_02, #sl_IN_A_03, #sl_IN_A_04, #sl_IN_A_05, #sl_IN_A_06, #sl_IN_A_07, #sl_IN_A_08, #sl_IN_A_09, #sl_IN_A_10, #sl_IN_A_11 {
8 |     width: 220;
9 | }
10 |
11 | #sl_IN_B_00, #sl_IN_B_01, #sl_IN_B_02, #sl_IN_B_03, #sl_IN_B_04, #sl_IN_B_05, #sl_IN_B_06, #sl_IN_B_07, #sl_IN_B_08, #sl_IN_B_09, #sl_IN_B_10, #sl_IN_B_11 {
12 |     width: 220;
13 | }
14 | */
15 |
16 | #sl_M_A_00, #sl_M_B_00 {
17 |     bottom:0;
18 |     position:absolute;
19 | }
20 |
21 | /*
22 | #sl_OUT_A_00, #sl_OUT_A_01, #sl_OUT_A_02, #sl_OUT_A_03, #sl_OUT_A_04, #sl_OUT_A_05, #sl_OUT_A_06, #sl_OUT_A_07, #sl_OUT_A_08, #sl_OUT_A_09, #sl_OUT_A_10, #sl_OUT_A_11 {
23 |     width: 220;
24 | }
25 |
26 | #sl_OUT_B_00, #sl_OUT_B_01, #sl_OUT_B_02, #sl_OUT_B_03, #sl_OUT_B_04, #sl_OUT_B_05, #sl_OUT_B_06, #sl_OUT_B_07, #sl_OUT_B_08, #sl_OUT_B_09, #sl_OUT_B_10, #sl_OUT_B_11 {
27 |     width: 220;
28 | }
29 | */
--------------------------------------------------------------------------------