├── .gitignore ├── LICENSE ├── README.md ├── lora-keyword.txt ├── model-keyword-user.txt ├── model-keyword.txt └── scripts └── model_keyword.py /.gitignore: -------------------------------------------------------------------------------- 1 | custom-mappings.txt 2 | custom-mappings-backup.txt 3 | settings.txt 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 ChunKoo Park 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # model-keyword 2 | Automatic1111 WEBUI extension to autofill keyword(trigger words) for custom stable diffusion models and LORA models. 
3 | 4 | [model-keyword-github.webm](https://user-images.githubusercontent.com/1288793/205525862-a8eaebfe-1860-41d1-bc66-335896b467dd.webm) 5 | 6 | ## Installation 7 | 8 | Copy the url of the repository ( https://github.com/mix1009/model-keyword ) into the extension tab and press "Install" 9 | 10 | Screenshot 2022-12-01 at 12 14 25 PM 11 | 12 | From "Extensions/Installed" tab press "Apply and restart UI". 13 | Screenshot 2022-12-01 at 12 18 43 PM 14 | 15 | ## Usage 16 | From txt2img, img2img tab, "Model Keyword" section is added. Model keyword extension is enabled by default. Click Model Keyword or triangle to reveal options. 17 | 18 | ![model-keyword-open](https://user-images.githubusercontent.com/1288793/212831258-0eea1dc8-9b67-4395-9368-fd69eebe4fb0.png) 19 | 20 | When generating image, the extension will try to figure out which model is used and insert matching keyword to the prompt: 21 | model-keyword-usage 22 | 23 | 24 | 25 | ### Keyword placement 26 | Screenshot 2022-12-01 at 12 26 41 PM 27 | 28 | 29 | ### Multiple keywords 30 | Screenshot 2022-12-01 at 12 27 00 PM 31 | 32 | 1) keyword1, keyword2 - use all keywords separated by comma 33 | 2) random - choose one random keyword 34 | 3) iterate - iterate through each keyword for each image generation 35 | * If sd-dynamic-prompts extension is installed, iterate will not work properly. Please disable Dynamic Prompts. 36 | * Alternatively, you can rename model-keyword to sd-model-keyword in the extensions folder. It will change the order of the extension and fix the bug. 37 | 4) keyword1 - use first keyword 38 | 5) keyword2 - use second keyword (if it exists) 39 | 40 | ### LORA model support 41 | LORA 42 | 43 | 1) Select LORA model from the Model dropdown. 44 | 2) Keywords dropdown list should contain keywords for the selected model. 45 | 3) Limitation: You can only select one model, and select one keyword or all keywords. 46 | 4) If you are using multiple LORA models, please check https://lora-help.coolai.app/ . 
47 | 48 | ## Add Custom Mappings 49 | custom_mappings 50 | 51 | * "Check" -> outputs model filename, hash, matching keyword(s), and source of match in result. 52 | * "Save" -> save custom mapping with keyword. (Fill keyword) 53 | * "Delete" -> deletes custom mapping for model if it's available. 54 | 55 | * Mappings are saved in custom-mappings.txt 56 | * If previous mapping is found, save overwrites it. 57 | * do NOT edit model-keyword.txt . It can be overwritten or cause conflict while upgrading. 58 | * hash value for model has been changed in webui(2023-01-14), this extension uses old hash value. Old hash value is no longer displayed in webui. 59 | 60 | 61 | -------------------------------------------------------------------------------- /model-keyword-user.txt: -------------------------------------------------------------------------------- 1 | # This file is no longer used. 2 | # User settings are now saved in custom-mappings.txt 3 | # Instead of manually editing, use the UI to add custom mappings. 
# /scripts/model_keyword.py
"""Web UI script that inserts model, LORA and embedding trigger words into prompts.

Keyword databases are CSV files stored in ``scripts.basedir()``. Each row is
``hash, keyword[|keyword...][, filename]``; rows whose first column starts
with ``#`` are comments. ``model-keyword.txt`` / ``lora-keyword.txt`` are the
bundled databases, ``custom-mappings.txt`` / ``lora-keyword-user.txt`` hold
the mappings saved from the UI.
"""
import modules.scripts as scripts
import gradio as gr
import csv
import os
from collections import defaultdict

import modules.shared as shared
import difflib
import random
import glob
import hashlib
import shutil
import fnmatch

scripts_dir = scripts.basedir()
kw_idx = 0    # position used by the "iterate" mode for model keywords
lora_idx = 0  # position used by the "< iterate >" mode for LORA keywords
hash_dict = None
hash_dict_modified = None
lora_hash_dict = None
lora_hash_dict_modified = None

# filename -> legacy hash, so each file is hashed only once
model_hash_dict = {}


def str_simularity(a, b):
    """Return the ``difflib`` similarity ratio of two strings."""
    return difflib.SequenceMatcher(None, a, b).ratio()


def get_old_model_hash(filename):
    """Return the legacy 8-character hash of a model file.

    The hash is the first 8 hex digits of the SHA-256 of the 0x10000 bytes
    read at offset 0x100000. Results are cached in ``model_hash_dict``.

    Returns:
        The hash string, or ``'NOFILE'`` when the file does not exist.
    """
    if filename in model_hash_dict:
        return model_hash_dict[filename]
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
    except FileNotFoundError:
        return 'NOFILE'
    digest = m.hexdigest()[0:8]
    model_hash_dict[filename] = digest
    return digest


def find_files(directory, exts):
    """Yield paths, relative to *directory*, of files having one of *exts*."""
    for root, _dirs, files in os.walk(directory):
        for ext in exts:
            for filename in fnmatch.filter(files, f'*.{ext}'):
                yield os.path.relpath(os.path.join(root, filename), directory)


def _ensure_user_file(path):
    """Create an (almost) empty user mapping file so it can be stat'ed."""
    if not os.path.exists(path):
        with open(path, 'w', encoding='utf-8') as f:
            f.write('\n')


def _mtime_key(default_file, user_file):
    """Return a string that changes whenever either mapping file is modified."""
    return str(os.stat(default_file).st_mtime) + '_' + str(os.stat(user_file).st_mtime)


def _parse_mapping_file(path, idx, target):
    """Append the ``(keyword, ckptname, idx)`` entries of *path* to *target*.

    *target* maps the lower-cased hash to a list of entries. Comment rows and
    rows with fewer than two columns are skipped.
    """
    if not os.path.exists(path):
        return
    with open(path, newline='', encoding='utf-8') as csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 2:
                continue
            mhash = row[0].strip(' ')
            if mhash.startswith('#'):
                continue
            kw = row[1].strip(' ')
            ckptname = 'default' if len(row) <= 2 else row[2].strip(' ')
            target[mhash.lower()].append((kw, ckptname, idx))


def _best_entry(entries, name):
    """Return the entry whose ckptname is most similar to *name*.

    Ties go to the later entry, so a user mapping (parsed after the default
    database) wins over a default one. Returns ``None`` for no entries.
    """
    if len(entries) == 1:
        return entries[0]
    found = None
    max_sim = 0.0
    for entry in entries:
        sim = str_simularity(entry[1], name)
        if sim >= max_sim:
            max_sim = sim
            found = entry
    return found


def _rewrite_user_mappings(user_file, backup_file, model_hash, model_ckpt, new_row=None):
    """Drop the mapping for (*model_hash*, *model_ckpt*) and optionally add one.

    Every other row, comments included, is kept. The file is rewritten (after
    a best-effort copy to *backup_file*) only when something changed. Rows are
    written with ``csv.writer`` so fields containing commas survive a
    round-trip through ``csv.reader``.

    Returns:
        The last removed row, or ``None`` when no row matched.
    """
    kept = []
    removed = None
    if os.path.exists(user_file):
        with open(user_file, newline='', encoding='utf-8') as csvfile:
            for row in csv.reader(csvfile):
                if not row:
                    continue
                mhash = row[0].strip(' ')
                if not mhash.startswith('#'):
                    ckptname = None if len(row) <= 2 else row[2].strip(' ')
                    if mhash.lower() == model_hash.lower() and ckptname == model_ckpt:
                        removed = row
                        continue
                kept.append(row)

    if new_row is not None:
        kept.append(new_row)
    elif removed is None:
        return None

    try:
        shutil.copy(user_file, backup_file)
    except OSError:
        # The backup is a convenience only (e.g. nothing to copy on first save).
        pass
    with open(user_file, 'w', newline='', encoding='utf-8') as csvfile:
        csv.writer(csvfile, lineterminator='\n').writerows(kept)
    return removed


def _current_checkpoint():
    """Return ``(basename, legacy hash)`` of the currently loaded checkpoint."""
    filename = shared.sd_model.sd_checkpoint_info.filename
    return os.path.basename(filename), get_old_model_hash(filename)


def _lora_path(lora_model):
    """Return the path of *lora_model* inside the configured LORA directory."""
    return f'{shared.cmd_opts.lora_dir}/{lora_model}'


def load_hash_dict():
    """Return the model mapping ``hash -> [(keyword, ckptname, idx)]``.

    ``idx`` is 0 for the default database and 1 for the user file. The parsed
    result is cached and rebuilt when either file's mtime changes.
    """
    global hash_dict, hash_dict_modified

    default_file = f'{scripts_dir}/model-keyword.txt'
    user_file = f'{scripts_dir}/custom-mappings.txt'
    _ensure_user_file(user_file)

    modified = _mtime_key(default_file, user_file)
    if hash_dict is None or hash_dict_modified != modified:
        fresh = defaultdict(list)
        _parse_mapping_file(default_file, 0, fresh)  # 0 for default_file
        _parse_mapping_file(user_file, 1, fresh)     # 1 for user_file
        hash_dict = fresh
        hash_dict_modified = modified

    return hash_dict


def load_lora_hash_dict():
    """Return the LORA mapping ``hash -> [(keyword, ckptname, idx)]``.

    Same caching and ``idx`` convention as ``load_hash_dict``.
    """
    global lora_hash_dict, lora_hash_dict_modified

    default_file = f'{scripts_dir}/lora-keyword.txt'
    user_file = f'{scripts_dir}/lora-keyword-user.txt'
    _ensure_user_file(user_file)

    modified = _mtime_key(default_file, user_file)
    if lora_hash_dict is None or lora_hash_dict_modified != modified:
        fresh = defaultdict(list)
        _parse_mapping_file(default_file, 0, fresh)  # 0 for default_file
        _parse_mapping_file(user_file, 1, fresh)     # 1 for user_file
        lora_hash_dict = fresh
        lora_hash_dict_modified = modified

    return lora_hash_dict


def get_keyword_for_model(model_hash, model_ckpt, return_entry=False):
    """Look up the keyword string registered for a checkpoint.

    Args:
        model_hash: legacy hash as returned by ``get_old_model_hash``.
        model_ckpt: checkpoint file name, used to choose between several
            entries that share the same hash.
        return_entry: return the whole ``(keyword, ckptname, idx)`` tuple
            instead of only the keyword string.

    Returns:
        The keyword string (or entry tuple), or ``None`` when nothing matches.
    """
    # .get() so that a miss does not create an empty list in the defaultdict
    found = _best_entry(load_hash_dict().get(model_hash, ()), model_ckpt)
    if return_entry:
        return found
    return found[0] if found else None


def _get_keywords_for_lora(lora_model, return_entry=False):
    """LORA counterpart of ``get_keyword_for_model``; hashes the file itself."""
    lora_model_hash = get_old_model_hash(_lora_path(lora_model))
    found = _best_entry(load_lora_hash_dict().get(lora_model_hash, ()), lora_model)
    if return_entry:
        return found
    return found[0] if found else None


def get_lora_keywords(lora_model, keyword_only=False):
    """Return the keyword choices for *lora_model*.

    With ``keyword_only`` the bare keyword list is returned when keywords
    exist. Otherwise the list starts with ``"None"`` and, when there are
    several keywords, also offers the comma-joined combination plus the
    ``'< iterate >'`` and ``'< random >'`` pseudo choices.
    """
    lora_keywords = ["None"]
    if lora_model != 'None':
        kws = _get_keywords_for_lora(lora_model)
        if kws:
            words = [x.strip() for x in kws.split('|')]
            if keyword_only:
                return words
            if len(words) > 1:
                words.insert(0, ', '.join(words))
                words.append('< iterate >')
                words.append('< random >')
            lora_keywords.extend(words)

    return lora_keywords


settings = None


def save_settings(m):
    """Merge *m* into the settings and write them to ``settings.txt``."""
    global settings

    if settings is None:
        settings = get_settings()
    settings.update(m)

    settings_file = f'{scripts_dir}/settings.txt'
    lines = [f'{k}={v}' for k, v in settings.items()]
    with open(settings_file, 'w') as file:
        file.write('\n'.join(lines) + '\n')


def get_settings():
    """Return the settings dict, loading ``settings.txt`` over the defaults.

    Values read from the file are strings, except ``lora_multiplier`` which is
    converted back to ``float`` because it feeds a slider.
    """
    global settings
    if settings:
        return settings

    settings = {
        'is_enabled': 'True',
        'keyword_placement': 'keyword prompt',
        'multiple_keywords': 'keyword1, keyword2',
        'ti_keywords': 'None',
        'keyword_order': 'textual inversion first',
        'lora_model': 'None',
        'lora_multiplier': 0.7,
        'lora_keywords': 'None',
    }

    settings_file = f'{scripts_dir}/settings.txt'
    if os.path.exists(settings_file):
        with open(settings_file, newline='') as file:
            for line in file.read().split('\n'):
                k, sep, v = line.partition('=')
                if not sep:
                    continue
                settings[k] = v.strip()

    try:
        settings['lora_multiplier'] = float(settings['lora_multiplier'])
    except (TypeError, ValueError):
        settings['lora_multiplier'] = 0.7

    return settings


class Script(scripts.Script):
    """Always-visible script adding the "Model Keyword" section to the UI."""

    def title(self):
        return "Model keyword"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the Gradio controls and return those passed to ``process``."""
        empty_keyword_msg = "Fill keyword(trigger word) or keywords separated by | (pipe character)"

        def get_embeddings():
            return [os.path.basename(x) for x in glob.glob(f'{shared.cmd_opts.embeddings_dir}/*.pt')]

        def get_loras():
            return sorted(find_files(shared.cmd_opts.lora_dir, ['safetensors', 'ckpt', 'pt']), key=str.casefold)

        def update_keywords():
            model_ckpt, model_hash = _current_checkpoint()
            kws = get_keyword_for_model(model_hash, model_ckpt)
            mk_choices = ["keyword1, keyword2", "random", "iterate"]
            if kws:
                mk_choices.extend([x.strip() for x in kws.split('|')])
            else:
                mk_choices.extend(["keyword1", "keyword2"])
            return gr.Dropdown.update(choices=mk_choices)

        def update_embeddings():
            return gr.Dropdown.update(choices=["None"] + get_embeddings())

        def update_loras():
            return gr.Dropdown.update(choices=["None"] + get_loras())

        def update_lora_keywords(lora_model):
            return gr.Dropdown.update(choices=get_lora_keywords(lora_model))

        def check_keyword():
            model_ckpt, model_hash = _current_checkpoint()
            entry = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
            if entry:
                kw = entry[0]
                src = 'custom-mappings.txt' if entry[2] == 1 else 'model-keyword.txt (default database)'
                return f"filename={model_ckpt}\nhash={model_hash}\nkeyword={kw}\nmatch from {src}"
            return f"filename={model_ckpt}\nhash={model_hash}\nno match"

        def delete_keyword():
            model_ckpt, model_hash = _current_checkpoint()
            found = _rewrite_user_mappings(
                f'{scripts_dir}/custom-mappings.txt',
                f'{scripts_dir}/custom-mappings-backup.txt',
                model_hash, model_ckpt)
            if found:
                return f'deleted entry: {found}'
            return 'no custom mapping found'

        def add_custom(txt):
            txt = txt.strip()
            if len(txt) == 0:
                return empty_keyword_msg
            model_ckpt, model_hash = _current_checkpoint()
            _rewrite_user_mappings(
                f'{scripts_dir}/custom-mappings.txt',
                f'{scripts_dir}/custom-mappings-backup.txt',
                model_hash, model_ckpt,
                new_row=[model_hash, txt, model_ckpt])
            return 'added: ' + f'{model_hash}, {txt}, {model_ckpt}'

        def lora_result(message, lora_model):
            # Both LORA buttons update the result box AND the keywords
            # dropdown, so every return path must supply two values.
            return [message, gr.Dropdown.update(choices=get_lora_keywords(lora_model))]

        def delete_lora_keyword(lora_model):
            model_hash = get_old_model_hash(_lora_path(lora_model))
            found = _rewrite_user_mappings(
                f'{scripts_dir}/lora-keyword-user.txt',
                f'{scripts_dir}/lora-keyword-user-backup.txt',
                model_hash, lora_model)
            if found:
                return lora_result(f'deleted entry: {found}', lora_model)
            return lora_result('no custom mapping found', lora_model)

        def add_lora_keyword(txt, lora_model):
            txt = txt.strip()
            if len(txt) == 0:
                return lora_result(empty_keyword_msg, lora_model)
            if lora_model == 'None':
                # Without a model there is no file to hash; do not store a
                # mapping under the 'NOFILE' placeholder hash.
                return lora_result('Select a LORA model first.', lora_model)
            model_hash = get_old_model_hash(_lora_path(lora_model))
            _rewrite_user_mappings(
                f'{scripts_dir}/lora-keyword-user.txt',
                f'{scripts_dir}/lora-keyword-user-backup.txt',
                model_hash, lora_model,
                new_row=[model_hash, txt, lora_model])
            return lora_result('added: ' + f'{model_hash}, {txt}, {lora_model}', lora_model)

        settings = get_settings()

        # Callables are passed as `value=` for the components below.
        def cb_enabled():
            return settings['is_enabled'] == 'True'

        def cb_keyword_placement():
            return settings['keyword_placement']

        def cb_multiple_keywords():
            return settings['multiple_keywords']

        def cb_ti_keywords():
            return settings['ti_keywords']

        def cb_lora_model():
            return settings['lora_model']

        def cb_lora_multiplier():
            return settings['lora_multiplier']

        def cb_lora_keywords():
            return settings['lora_keywords']

        def cb_keyword_order():
            return settings['keyword_order']

        refresh_icon = '\U0001f504'
        with gr.Group():
            with gr.Accordion('Model Keyword', open=False):

                is_enabled = gr.Checkbox(label='Model Keyword Enabled ', value=cb_enabled)
                is_enabled.do_not_save_to_config = True

                placement_choices = ["keyword prompt", "prompt keyword", "keyword, prompt", "prompt, keyword"]
                keyword_placement = gr.Dropdown(choices=placement_choices,
                                                value=cb_keyword_placement,
                                                label='Keyword placement: ')
                keyword_placement.do_not_save_to_config = True

                mk_choices = ["keyword1, keyword2", "random", "iterate", "keyword1", "keyword2"]

                with gr.Row(equal_height=True):
                    multiple_keywords = gr.Dropdown(choices=mk_choices,
                                                    value=cb_multiple_keywords,
                                                    label='Multiple keywords: ')
                    multiple_keywords.do_not_save_to_config = True

                    refresh_btn = gr.Button(value=refresh_icon, elem_id='mk_refresh_btn_random_seed')  # XXX _random_seed workaround.
                    refresh_btn.click(update_keywords, inputs=None, outputs=multiple_keywords)

                ti_choices = ["None"] + get_embeddings()
                with gr.Row(equal_height=True):
                    ti_keywords = gr.Dropdown(choices=ti_choices,
                                              value=cb_ti_keywords,
                                              label='Textual Inversion (Embedding): ')
                    ti_keywords.do_not_save_to_config = True
                    refresh_btn = gr.Button(value=refresh_icon, elem_id='ti_refresh_btn_random_seed')  # XXX _random_seed workaround.
                    refresh_btn.click(update_embeddings, inputs=None, outputs=ti_keywords)

                keyword_order = gr.Dropdown(choices=["textual inversion first", "model keyword first"],
                                            value=cb_keyword_order,
                                            label='Keyword order: ')
                keyword_order.do_not_save_to_config = True

                with gr.Accordion('LORA', open=True):
                    lora_choices = ["None"] + get_loras()
                    lora_kw_choices = get_lora_keywords(settings['lora_model'])

                    with gr.Row(equal_height=True):
                        lora_model = gr.Dropdown(choices=lora_choices,
                                                 value=cb_lora_model,
                                                 label='Model: ')
                        lora_model.do_not_save_to_config = True
                        lora_refresh_btn = gr.Button(value=refresh_icon, elem_id='lora_m_refresh_btn_random_seed')  # XXX _random_seed workaround.
                        lora_refresh_btn.click(update_loras, inputs=None, outputs=lora_model)

                    lora_multiplier = gr.Slider(minimum=0, maximum=2, step=0.01, value=cb_lora_multiplier, label="multiplier")
                    with gr.Row(equal_height=True):
                        lora_keywords = gr.Dropdown(choices=lora_kw_choices,
                                                    value=cb_lora_keywords,
                                                    label='keywords: ')
                        lora_keywords.do_not_save_to_config = True

                    lora_model.change(fn=update_lora_keywords, inputs=lora_model, outputs=lora_keywords)
                    info = gr.HTML("<p>Add custom keyword(trigger word) mapping for selected LORA model.</p>")
                    lora_text_input = gr.Textbox(placeholder="Keyword or keywords separated by |", label="Keyword(trigger word)")
                    with gr.Row():
                        add_mappings = gr.Button(value='Save')
                        delete_mappings = gr.Button(value='Delete')
                    lora_text_output = gr.Textbox(interactive=False, label='result')
                    add_mappings.click(add_lora_keyword, inputs=[lora_text_input, lora_model], outputs=[lora_text_output, lora_keywords])
                    delete_mappings.click(delete_lora_keyword, inputs=lora_model, outputs=[lora_text_output, lora_keywords])

                with gr.Accordion('Add Custom Mappings', open=False):
                    info = gr.HTML("<p>Add custom keyword(trigger word) mapping for current model. Custom mappings are saved to extensions/model-keyword/custom-mappings.txt</p>")
                    text_input = gr.Textbox(placeholder="Keyword or keywords separated by |", label="Keyword(trigger word)")
                    with gr.Row():
                        check_mappings = gr.Button(value='Check')
                        add_mappings = gr.Button(value='Save')
                        delete_mappings = gr.Button(value='Delete')

                    text_output = gr.Textbox(interactive=False, label='result')

                    add_mappings.click(add_custom, inputs=text_input, outputs=text_output)
                    check_mappings.click(check_keyword, inputs=None, outputs=text_output)
                    delete_mappings.click(delete_keyword, inputs=None, outputs=text_output)

        return [is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order, lora_model, lora_multiplier, lora_keywords]

    def process(self, p, is_enabled, keyword_placement, multiple_keywords, ti_keywords, keyword_order, lora_model, lora_multiplier, lora_keywords):
        """Persist the UI selections and, when enabled, rewrite the prompts.

        ``p.prompt`` and every entry of ``p.all_prompts`` get the selected
        model keyword, LORA keyword and embedding name added, plus the LORA
        activation tag when a LORA model is selected.
        """
        global hash_dict

        # A keyword left over from a previously selected LORA model is dropped.
        if lora_model != 'None' and lora_keywords not in get_lora_keywords(lora_model):
            lora_keywords = 'None'

        save_settings({
            'is_enabled': f'{is_enabled}',
            'keyword_placement': keyword_placement,
            'multiple_keywords': multiple_keywords,
            'ti_keywords': ti_keywords,
            'keyword_order': keyword_order,
            'lora_model': lora_model,
            'lora_multiplier': lora_multiplier,
            'lora_keywords': lora_keywords,
        })

        if not is_enabled:
            hash_dict = None  # forces a reload of the mapping files next time
            return

        model_ckpt, model_hash = _current_checkpoint()

        def new_prompt(prompt, kw, no_iter=False):
            """Return *prompt* combined with the selected keywords.

            ``no_iter`` keeps the "iterate" counters unchanged.
            """
            global kw_idx, lora_idx
            if kw:
                kws = kw.split('|')
                if len(kws) > 1:
                    kws = [x.strip(' ') for x in kws]
                    if multiple_keywords == "keyword1, keyword2":
                        kw = ', '.join(kws)
                    elif multiple_keywords == "random":
                        kw = random.choice(kws)
                    elif multiple_keywords == "iterate":
                        kw = kws[kw_idx % len(kws)]
                        if not no_iter:
                            kw_idx += 1
                    elif multiple_keywords == "keyword1":
                        kw = kws[0]
                    elif multiple_keywords == "keyword2":
                        kw = kws[1]
                    elif multiple_keywords in kws:
                        kw = multiple_keywords
                    else:
                        kw = kws[0]

            ti = None
            if ti_keywords != 'None':
                # embedding file name without its extension
                ti = os.path.splitext(ti_keywords)[0]

            lora = None
            if lora_keywords != 'None' and lora_model != 'None':
                lora = lora_keywords
                try:
                    if lora == '< iterate >':
                        loras = get_lora_keywords(lora_model, keyword_only=True)
                        lora = loras[lora_idx % len(loras)]
                        if not no_iter:
                            lora_idx += 1
                    elif lora == '< random >':
                        loras = get_lora_keywords(lora_model, keyword_only=True)
                        lora = random.choice(loras)
                except Exception:
                    # Best effort: generation must not fail because of the
                    # keyword lookup, and the UI placeholder text must not
                    # leak into the prompt.
                    lora = None

            if keyword_order == 'model keyword first':
                arr = [kw, lora, ti]
            else:
                arr = [ti, lora, kw]
            arr = [x for x in arr if x is not None]

            if keyword_placement.startswith('keyword'):
                arr.append(prompt)
            else:
                arr.insert(0, prompt)

            if lora_model != 'None':
                # bare file name: no directory part, no extension
                lora_name = os.path.splitext(lora_model)[0]
                lora_name = lora_name.replace('\\', '/')
                lora_name = lora_name.split('/')[-1]
                # prompt tag that activates the LORA network
                arr.insert(0, f'<lora:{lora_name}:{lora_multiplier}>')

            if ',' in keyword_placement:
                return ', '.join(arr)
            return ' '.join(arr)

        kw = get_keyword_for_model(model_hash, model_ckpt)

        if kw is not None or ti_keywords != 'None' or lora_model != 'None':
            p.prompt = new_prompt(p.prompt, kw, no_iter=True)
            p.all_prompts = [new_prompt(prompt, kw) for prompt in p.all_prompts]


from fastapi import FastAPI, Response, Query, Body
from fastapi.responses import JSONResponse


def model_keyword_api(_: gr.Blocks, app: FastAPI):
    """Register the ``/model_keyword/get_keywords`` endpoint on *app*."""
    @app.get("/model_keyword/get_keywords")
    async def get_keywords():
        model_ckpt, model_hash = _current_checkpoint()
        r = get_keyword_for_model(model_hash, model_ckpt, return_entry=True)
        if r is None:
            return {"keywords": [], "model": model_ckpt, "hash": model_hash, "match_source": "no match"}
        kws = [x.strip() for x in r[0].split('|')]
        match_source = "model-keyword.txt" if r[2] == 0 else "custom-mappings.txt"
        return {"keywords": kws, "model": model_ckpt, "hash": model_hash, "match_source": match_source}


try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(model_keyword_api)
except (ImportError, AttributeError):
    # NOTE(review): presumably guards web UI builds without this callback;
    # the HTTP endpoint is optional, the script works without it.
    pass