├── .gitignore
├── README.md
└── scripts
    ├── convert.py
    └── ui.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
.vscode
__pycache__

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# sd-webui-model-converter

Model conversion extension for [AUTOMATIC1111's stable diffusion webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui).

## Features

- convert to precisions: fp32, fp16, bf16, float8_e4m3fn, float8_e5m2
- model pruning: no-ema, ema-only
- checkpoint format conversion: ckpt, safetensors
- convert/copy/delete individual parts of the model: unet, text encoder (CLIP), VAE
- Fix CLIP
- Force CLIP position_id to int64 before convert

### Fix CLIP

Sometimes the CLIP position_ids become incorrect due to model merging; Anything-v3 is one example.

This option resets the CLIP position_ids to `torch.Tensor([list(range(77))]).to(torch.int64)`.
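
In code, the fix amounts to comparing the stored tensor against the sequence 0..76 and overwriting it on mismatch. Below is a condensed sketch of the same check that `fix_model` in `scripts/convert.py` performs (the function name `fix_clip_position_ids` is just for illustration):

```python
import torch

POSITION_ID_KEY = "cond_stage_model.transformer.text_model.embeddings.position_ids"

def fix_clip_position_ids(state_dict):
    # correct position ids are simply 0..76 as a (1, 77) int64 tensor
    correct = torch.Tensor([list(range(77))]).to(torch.int64)
    now = state_dict[POSITION_ID_KEY].to(torch.int64)
    broken = [i for i in range(77) if correct[0, i] != now[0, i]]
    if broken:
        state_dict[POSITION_ID_KEY] = correct
        print(f"fixed broken indices: {broken}")
    return state_dict
```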

### Force CLIP position_id to int64 before convert

If you use this extension to convert a model to fp16 and the model has incorrect CLIP position_ids, the loss of precision during compression may coincidentally round the offset back to the correct values.

If you do not want this coincidental fix (fixing it alters the model, and even though the correction is accurate, not everyone prefers the most correct result, right? :P), use this option. It forces the CLIP position_ids to int64 before conversion and retains the incorrect values.
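
To see why an fp16 conversion can accidentally repair a drifted index, consider this minimal sketch; the drifted value `40.9961` is invented for illustration:

```python
import torch

# hypothetical broken position_ids: index 41 drifted during a merge
# and was stored as float32 (41 became 40.9961)
broken = torch.arange(77, dtype=torch.float32).unsqueeze(0)
broken[0, 41] = 40.9961

# converting to fp16 first: the fp16 grid near 41 has spacing 1/32,
# so 40.9961 rounds to 41.0 and the later int64 cast "repairs" the index
print(broken.half().to(torch.int64)[0, 41])  # tensor(41)

# forcing int64 before any precision change truncates 40.9961 to 40,
# preserving the broken value, which is what this option does
print(broken.to(torch.int64)[0, 41])  # tensor(40)
```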

--------------------------------------------------------------------------------
/scripts/convert.py:
--------------------------------------------------------------------------------
import os
import tqdm
import torch
import safetensors.torch
from torch import Tensor
from modules import shared
from modules import sd_models, sd_vae

# position_ids in CLIP is int64 and model_ema.num_updates is int32;
# the conversion sets below list only float source dtypes, so integer
# tensors like these pass through untouched.
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
dtypes_to_float8_e4m3fn = {torch.float32, torch.float64, torch.bfloat16, torch.float16}
dtypes_to_float8_e5m2 = {torch.float32, torch.float64, torch.bfloat16, torch.float16}


class MockModelInfo:
    """Minimal stand-in for webui's checkpoint info, built from a bare file path."""

    def __init__(self, model_path: str) -> None:
        self.filepath = model_path
        self.filename: str = os.path.basename(model_path)
        self.model_name: str = os.path.splitext(self.filename)[0]


def conv_fp16(t: Tensor):
    return t.half() if t.dtype in dtypes_to_fp16 else t


def conv_bf16(t: Tensor):
    return t.bfloat16() if t.dtype in dtypes_to_bf16 else t


def conv_float8_e4m3fn(t: Tensor):
    return t.to(torch.float8_e4m3fn) if t.dtype in dtypes_to_float8_e4m3fn else t


def conv_float8_e5m2(t: Tensor):
    return t.to(torch.float8_e5m2) if t.dtype in dtypes_to_float8_e5m2 else t


def conv_full(t):
    return t


_g_precision_func = {
    "full": conv_full,
    "fp32": conv_full,
    "fp16": conv_fp16,
    "bf16": conv_bf16,
    "float8_e4m3fn": conv_float8_e4m3fn,
    "float8_e5m2": conv_float8_e5m2,
}


def check_weight_type(k: str) -> str:
    if k.startswith("model.diffusion_model"):
        return "unet"
    elif k.startswith("first_stage_model"):
        return "vae"
    elif k.startswith("cond_stage_model") or k.startswith("conditioner.embedders"):
        return "clip"
    return "other"


def load_model(path):
    if path.endswith(".safetensors"):
        m = safetensors.torch.load_file(path, device="cpu")
    else:
        m = torch.load(path, map_location="cpu")
    state_dict = m["state_dict"] if "state_dict" in m else m
    return state_dict


def fix_model(model, fix_clip=False, force_position_id=False):
    # code from model-toolkit
    nai_keys = {
        'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
        'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
        'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
    }
    position_id_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"

    # rename NovelAI-style keys to the standard text_model layout
    for k in list(model.keys()):
        for r in nai_keys:
            if isinstance(k, str) and k.startswith(r):
                new_key = k.replace(r, nai_keys[r])
                model[new_key] = model[k]
                del model[k]
                print(f"[Converter] Fixed novelai error key {k}")
                break

    if force_position_id and position_id_key in model:
        model[position_id_key] = model[position_id_key].to(torch.int64)

    if fix_clip:
        if position_id_key in model:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = model[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                model[position_id_key] = correct
                print(f"[Converter] Fixed broken clip\n{broken}")
            else:
                print("[Converter] Clip in this model is fine, skip fixing...")
        else:
            print("[Converter] Missing position id in model, trying to fix...")
            model[position_id_key] = torch.Tensor([list(range(77))]).to(torch.int64)

    return model


def is_sdxl_model(model):
    # SDXL checkpoints store their text encoders under "conditioner.embedders"
    for k in list(model.keys()):
        if k.startswith("conditioner.embedders"):
            return True
    return False


def convert_warp(
        path_mode, model_name, model_path, directory,
        *args
):
    match path_mode:
        case 0:  # single process
            if model_info := sd_models.checkpoints_list.get(model_name, None):
                return do_convert(MockModelInfo(model_info.filename), *args)
            return "Error: model not found"

        case 1:  # input file path
            if os.path.exists(model_path):
                return do_convert(MockModelInfo(model_path), *args)
            return f'Error: model path "{model_path}" does not exist'

        case 2:  # batch from directory
            if not os.path.isdir(directory):
                return f'Error: path "{directory}" does not exist or is not a directory'

            if not (files := [f for f in os.listdir(directory) if f.endswith(".ckpt") or f.endswith(".safetensors")]):
                return "Error: no model found in directory"

            # remove custom filename in batch processing (args[3] is custom_name)
            _args = list(args)
            _args[3] = ""

            for m in files:
                do_convert(MockModelInfo(os.path.join(directory, m)), *_args)

            return "Batch processing done"

        case _:
            return f"Error: unknown mode {path_mode}"


def do_convert(model_info: MockModelInfo,
               checkpoint_formats,
               precision, conv_type, custom_name,
               bake_in_vae,
               unet_conv, text_encoder_conv, vae_conv, others_conv,
               fix_clip, force_position_id, delete_known_junk_data):
    if len(checkpoint_formats) == 0:
        return "Error: choose at least one model save format"

    extra_opt = {
        "unet": unet_conv,
        "clip": text_encoder_conv,
        "vae": vae_conv,
        "other": others_conv
    }
    shared.state.begin()
    shared.state.job = 'model-convert'
    shared.state.textinfo = f"Loading {model_info.filename}..."
    print(f"[Converter] Loading {model_info.filename}...")

    ok = {}
    state_dict = load_model(model_info.filepath)
    is_sdxl = is_sdxl_model(state_dict)

    if not is_sdxl:
        fix_model(state_dict, fix_clip=fix_clip, force_position_id=force_position_id)

    if precision.startswith("float8"):
        assert torch.__version__ >= "2.1.0", "PyTorch 2.1.0 or newer is required for float8 conversion"

    conv_func = _g_precision_func[precision]

    def _hf(wk: str, t: Tensor):
        # convert/copy/delete a single weight according to its part's setting
        if not isinstance(t, Tensor):
            return
        weight_type = check_weight_type(wk)
        conv_t = extra_opt[weight_type]
        if conv_t == "convert":
            ok[wk] = conv_func(t)
        elif conv_t == "copy":
            ok[wk] = t
        elif conv_t == "delete":
            return

    print("[Converter] Converting model...")

    if conv_type == "ema-only":
        for k in tqdm.tqdm(state_dict):
            # map a base key "model.xxx.yyy" to its EMA counterpart
            # "model_ema.xxxyyy" (EMA keys drop the dots after the prefix)
            ema_k = "model_ema." + k[6:].replace(".", "")
            if ema_k in state_dict:
                _hf(k, state_dict[ema_k])
            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
                _hf(k, state_dict[k])
    elif conv_type == "no-ema":
        for k, v in tqdm.tqdm(state_dict.items()):
            if "model_ema." not in k:
                _hf(k, v)
    else:
        for k, v in tqdm.tqdm(state_dict.items()):
            _hf(k, v)

    if delete_known_junk_data:
        known_junk_data_prefix = [
            "embedding_manager.embedder.",
            "lora_te_text_model",
            "control_model."
        ]
        need_delete = []
        for key in ok.keys():
            for jk in known_junk_data_prefix:
                if key.startswith(jk):
                    need_delete.append(key)

        for k in need_delete:
            del ok[k]

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"[Converter] Baking in VAE from {bake_in_vae_filename}")
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for k, v in vae_dict.items():
            _hf(k, v)

        del vae_dict

    output = ""
    ckpt_dir = os.path.dirname(model_info.filepath)
    save_name = f"{model_info.model_name}-{precision}"
    if conv_type != "disabled":
        save_name += f"-{conv_type}"

    if fix_clip:
        save_name += "-clip-fix"

    if custom_name != "":
        save_name = custom_name

    for fmt in checkpoint_formats:
        ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
        _save_name = save_name + ext

        save_path = os.path.join(ckpt_dir, _save_name)
        print(f"[Converter] Saving to {save_path}...")

        if fmt == "safetensors":
            safetensors.torch.save_file(ok, save_path)
        else:
            torch.save({"state_dict": ok}, save_path)
        output += f"Checkpoint saved to {save_path}\n"

    shared.state.end()
    return output.rstrip("\n")

--------------------------------------------------------------------------------
/scripts/ui.py:
--------------------------------------------------------------------------------
import gradio as gr
from modules import script_callbacks
from modules import sd_models, sd_vae
from modules.ui import create_refresh_button
from scripts import convert


def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}
"update"} 10 | 11 | 12 | def add_tab(): 13 | with gr.Blocks(analytics_enabled=False) as ui: 14 | with gr.Row(equal_height=True): 15 | with gr.Column(variant='panel'): 16 | gr.HTML(value="
Converted checkpoints will be saved in your checkpoint directory.
") 17 | with gr.Tabs(): 18 | with gr.TabItem(label='Single process') as single_process: 19 | with gr.Row(): 20 | model_name = gr.Dropdown(sd_models.checkpoint_tiles(), 21 | elem_id="model_converter_model_name", 22 | label="Model") 23 | create_refresh_button(model_name, sd_models.list_models, 24 | lambda: {"choices": sd_models.checkpoint_tiles()}, 25 | "refresh_checkpoint_Z") 26 | custom_name = gr.Textbox(label="Custom Name (Optional)") 27 | 28 | with gr.TabItem(label='Input file path') as input_file_path: 29 | with gr.Row(): 30 | model_path = gr.Textbox(label="model path") 31 | 32 | with gr.TabItem(label='Batch from directory') as batch_from_directory: 33 | with gr.Row(): 34 | input_directory = gr.Textbox(label="Input Directory") 35 | 36 | with gr.Row(): 37 | precision = gr.Radio(choices=["fp32", "fp16", "bf16", "float8_e4m3fn","float8_e5m2"], value="fp16", label="Precision") 38 | m_type = gr.Radio(choices=["disabled", "no-ema", "ema-only"], value="disabled", label="Pruning Methods") 39 | 40 | with gr.Row(): 41 | checkpoint_formats = gr.CheckboxGroup(choices=["ckpt", "safetensors"], value=["safetensors"], label="Checkpoint Format") 42 | show_extra_options = gr.Checkbox(label="Show extra options", value=False) 43 | 44 | with gr.Row(): 45 | bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE") 46 | create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "model_converter_refresh_bake_in_vae") 47 | 48 | with gr.Row(): 49 | force_position_id = gr.Checkbox(label="Force CLIP position_id to int64 before convert", value=True) 50 | fix_clip = gr.Checkbox(label="Fix clip", value=False) 51 | delete_known_junk_data = gr.Checkbox(label="Delete known junk data", value=False) 52 | 53 | with gr.Row(visible=False) as extra_options: 54 | specific_part_conv = ["copy", "convert", "delete"] 55 | unet_conv = gr.Dropdown(specific_part_conv, value="convert", label="unet") 56 | text_encoder_conv = gr.Dropdown(specific_part_conv, value="convert", label="text encoder") 57 | vae_conv = gr.Dropdown(specific_part_conv, value="convert", label="vae") 58 | others_conv = gr.Dropdown(specific_part_conv, value="convert", label="others") 59 | 60 | model_converter_convert = gr.Button(elem_id="model_converter_convert", variant='primary') 61 | 62 | with gr.Column(variant='panel'): 63 | submit_result = gr.Textbox(elem_id="model_converter_result", show_label=False) 64 | 65 | path_mode = gr.Number(value=0, visible=False) 66 | for i, tab in enumerate([single_process, input_file_path, batch_from_directory]): 67 | tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[path_mode]) 68 | 69 | show_extra_options.change( 70 | fn=lambda x: gr_show(x), 71 | inputs=[show_extra_options], 72 | outputs=[extra_options], 73 | ) 74 | 75 | model_converter_convert.click( 76 | fn=convert.convert_warp, 77 | inputs=[ 78 | path_mode, 79 | model_name, 80 | model_path, 81 | input_directory, 82 | checkpoint_formats, 83 | precision, m_type, custom_name, 84 | bake_in_vae, 85 | unet_conv, 86 | text_encoder_conv, 87 | vae_conv, 88 | others_conv, 89 | fix_clip, 90 | force_position_id, 91 | delete_known_junk_data 92 | ], 93 | outputs=[submit_result] 94 | ) 95 | 96 | return [(ui, "Model Converter", "model_converter")] 97 | 98 | 99 | script_callbacks.on_ui_tabs(add_tab) 100 | --------------------------------------------------------------------------------