├── .gitignore ├── LICENSE.txt ├── README.md ├── explanation.html ├── install.py ├── javascript └── promptgen.js ├── requirements.txt ├── screenshot.png ├── scripts └── promptgen.py └── style.css /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /models 3 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AUTOMATIC1111 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt generator 2 | An extension for [webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) that lets you generate prompts. 3 | 4 | ![](screenshot.png) 5 | -------------------------------------------------------------------------------- /explanation.html: -------------------------------------------------------------------------------- 1 |
[explanation.html: the HTML markup of this file was lost in extraction; only the heading text "Information" survives. Judging from style.css (#promptgen_explanation) and add_tab() in scripts/promptgen.py, which loads the file as a footer under the options, it renders an "Information" panel containing a table that describes the generation settings.]
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import launch
2 | import os
3 | 
4 | current_dir = os.path.dirname(os.path.realpath(__file__))
5 | req_file = os.path.join(current_dir, "requirements.txt")
6 | 
7 | with open(req_file) as file:
8 | for lib in file:
9 | lib = lib.strip()
10 | if not launch.is_installed(lib.split("==")[0].strip()):  # is_installed() expects a bare package name, not a "==version" pin
11 | launch.run_pip(
12 | f"install {lib}",
13 | f"danbooru-tag-gen requirement: {lib}")
14 | 
--------------------------------------------------------------------------------
/javascript/promptgen.js:
--------------------------------------------------------------------------------
1 | 
2 | function promptgen_send_to(where, text){
3 | let textarea = gradioApp().querySelector('#promptgen_selected_text textarea')
4 | textarea.value = text
5 | updateInput(textarea)
6 | 
7 | gradioApp().querySelector('#promptgen_send_to_'+where).click()
8 | 
9 | where == 'txt2img' ? switch_to_txt2img() : switch_to_img2img()
10 | }
11 | 
12 | function promptgen_send_to_txt2img(text){ promptgen_send_to('txt2img', text) }
13 | function promptgen_send_to_img2img(text){ promptgen_send_to('img2img', text) }
14 | 
15 | function submit_promptgen(){
16 | var id = randomId()
17 | requestProgress(id, gradioApp().getElementById('promptgen_results_column'), null, function(){})
18 | 
19 | var res = create_submit_args(arguments)
20 | res[0] = id
21 | return res
22 | }
23 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.30.1
2 | auto_gptq==0.2.2
3 | 
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qwopqwop200/stable-diffusion-webui-promptgen-danbooru/7aace7ffd8071190c8b1a4f31e3e37060d9615e1/screenshot.png
--------------------------------------------------------------------------------
/scripts/promptgen.py:
--------------------------------------------------------------------------------
1 | import html
2 | import os
3 | import time
4 | 
5 | import torch
6 | import transformers
7 | from transformers import AutoTokenizer
8 | from auto_gptq import AutoGPTQForCausalLM
9 | 
10 | from modules import shared, generation_parameters_copypaste
11 | 
12 | from modules import scripts, script_callbacks, devices, ui
13 | import gradio as gr
14 | 
15 | from modules.ui_components import FormRow
16 | 
17 | 
18 | class Model:
19 | name = None
20 | model = None
21 | tokenizer = None
22 | 
23 | 
24 | available_models = []
25 | current = Model()
26 | 
27 | base_dir = scripts.basedir()
28 | models_dir = os.path.join(base_dir, "models")
29 | 
30 | 
31 | def device():
32 | return devices.cpu if shared.opts.promptgen_device == 'cpu' else devices.device
33 | 
34 | 
35 | def list_available_models():  # local folders under models/ plus names from the promptgen_names setting
36 | available_models.clear()
37 | 
38 | os.makedirs(models_dir, exist_ok=True)
39 | 
40 | for dirname in os.listdir(models_dir):
41 | if os.path.isdir(os.path.join(models_dir, dirname)):
42 | available_models.append(dirname)
43 | 
44 | for name in [x.strip() for x in shared.opts.promptgen_names.split(",")]:
45 | if not name:
46 | continue
47 | 
48 | available_models.append(name)
49 | 
50 | 
51 | def get_model_path(name):
52 | dirname = os.path.join(models_dir, name)
53 | if not os.path.isdir(dirname):
54 | return name
55 | 
56 | return dirname 57 | 58 | 59 | def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p): 60 | top_p = float(top_p) if sampling_mode == 'Top P' else None 61 | top_k = int(top_k) if sampling_mode == 'Top K' else None 62 | 63 | outputs = current.model.generate( 64 | input_ids, 65 | do_sample=True, 66 | temperature=max(float(temperature), 1e-6), 67 | repetition_penalty=repetition_penalty, 68 | length_penalty=length_penalty, 69 | top_p=top_p, 70 | top_k=top_k, 71 | num_beams=int(num_beams), 72 | min_length=min_length, 73 | max_length=max_length, 74 | pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id 75 | ) 76 | texts = current.tokenizer.batch_decode(outputs, skip_special_tokens=True) 77 | return texts 78 | 79 | 80 | def model_selection_changed(model_name): 81 | if model_name == "None": 82 | current.tokenizer = None 83 | current.model = None 84 | current.name = None 85 | 86 | devices.torch_gc() 87 | 88 | 89 | def generate(id_task, model_name, batch_count, batch_size, text, *args): 90 | shared.state.textinfo = "Loading model..." 91 | shared.state.job_count = batch_count 92 | model_name = 'qwopqwop/danbooru-llama-gptq' 93 | 94 | if current.name != model_name: 95 | current.tokenizer = None 96 | current.model = None 97 | current.name = None 98 | 99 | if model_name != 'None': 100 | model = AutoGPTQForCausalLM.from_quantized("qwopqwop/danbooru-llama-gptq").model 101 | current.model = model 102 | 103 | DEFAULT_PAD_TOKEN = "[PAD]" 104 | 105 | tokenizer = AutoTokenizer.from_pretrained("pinkmanlove/llama-7b-hf", use_fast=False) 106 | 107 | def smart_tokenizer_and_embedding_resize( 108 | special_tokens_dict, 109 | tokenizer, 110 | model, 111 | ): 112 | """Resize tokenizer and embedding. 113 | 114 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
115 | """ 116 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 117 | model.resize_token_embeddings(len(tokenizer)) 118 | 119 | if num_new_tokens > 0: 120 | input_embeddings = model.get_input_embeddings().weight.data 121 | output_embeddings = model.get_output_embeddings().weight.data 122 | 123 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 124 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 125 | 126 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 127 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 128 | 129 | if tokenizer._pad_token is None: 130 | smart_tokenizer_and_embedding_resize( 131 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 132 | tokenizer=tokenizer, 133 | model=model) 134 | 135 | tokenizer.add_special_tokens({"eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), 136 | "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), 137 | "unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id),}) 138 | 139 | current.tokenizer = tokenizer 140 | current.name = model_name 141 | 142 | assert current.model, 'No model available' 143 | assert current.tokenizer, 'No tokenizer available' 144 | 145 | current.model.to(device()) 146 | 147 | shared.state.textinfo = "" 148 | 149 | input_ids = current.tokenizer(text, return_tensors="pt").input_ids 150 | if input_ids.shape[1] == 0: 151 | input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long) 152 | input_ids = input_ids.to(device()) 153 | input_ids = input_ids.repeat((batch_size, 1)) 154 | 155 | markup = '' 156 | 157 | index = 0 158 | for i in range(batch_count): 159 | texts = generate_batch(input_ids, *args) 160 | shared.state.nextjob() 161 | for generated_text in texts: 162 | index += 1 163 | markup += f""" 164 | 165 | 170 | 174 | 175 | """ 176 | 177 | markup += '
166 |
167 |

{html.escape(generated_text)}

168 |
169 |
171 | to txt2img 172 | to img2img 173 |
' 178 | 179 | return markup, '' 180 | 181 | 182 | def find_prompts(fields): 183 | field_prompt = [x for x in fields if x[1] == "Prompt"][0] 184 | field_negative_prompt = [x for x in fields if x[1] == "Negative prompt"][0] 185 | return [field_prompt[0], field_negative_prompt[0]] 186 | 187 | 188 | def send_prompts(text): 189 | params = generation_parameters_copypaste.parse_generation_parameters(text) 190 | negative_prompt = params.get("Negative prompt", "") 191 | return params.get("Prompt", ""), negative_prompt or gr.update() 192 | 193 | 194 | def add_tab(): 195 | list_available_models() 196 | 197 | with gr.Blocks(analytics_enabled=False) as tab: 198 | with gr.Row(): 199 | with gr.Column(scale=80): 200 | prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt (press Ctrl+Enter or Alt+Enter to generate)").style(container=False) 201 | with gr.Column(scale=10): 202 | submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary') 203 | 204 | with gr.Row(elem_id="promptgen_main"): 205 | with gr.Column(variant="compact"): 206 | selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False) 207 | send_to_txt2img = gr.Button(elem_id='promptgen_send_to_txt2img', visible=False) 208 | send_to_img2img = gr.Button(elem_id='promptgen_send_to_img2img', visible=False) 209 | 210 | with FormRow(): 211 | model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0], choices=["None"] + available_models) 212 | 213 | with FormRow(): 214 | sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"]) 215 | top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1) 216 | top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001) 217 | 218 | with gr.Row(): 219 | num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1) 220 | temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01) 221 | repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01) 222 | 223 | with FormRow(): 224 | length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1) 225 | min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1) 226 | max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1) 227 | 228 | with FormRow(): 229 | batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1) 230 | batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1) 231 | 232 | with open(os.path.join(base_dir, "explanation.html"), encoding="utf8") as file: 233 | footer = file.read() 234 | gr.HTML(footer) 235 | 236 | with gr.Column(): 237 | with gr.Group(elem_id="promptgen_results_column"): 238 | res = gr.HTML() 239 | res_info = gr.HTML() 240 | 241 | submit.click( 242 | fn=ui.wrap_gradio_gpu_call(generate, extra_outputs=['']), 243 | _js="submit_promptgen", 244 | inputs=[model_selection, model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, 
temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p, ], 245 | outputs=[res, res_info] 246 | ) 247 | 248 | model_selection.change( 249 | fn=model_selection_changed, 250 | inputs=[model_selection], 251 | outputs=[], 252 | ) 253 | 254 | send_to_txt2img.click( 255 | fn=send_prompts, 256 | inputs=[selected_text], 257 | outputs=find_prompts(ui.txt2img_paste_fields) 258 | ) 259 | 260 | send_to_img2img.click( 261 | fn=send_prompts, 262 | inputs=[selected_text], 263 | outputs=find_prompts(ui.img2img_paste_fields) 264 | ) 265 | 266 | return [(tab, "Promptgen", "promptgen")] 267 | 268 | 269 | def on_ui_settings(): 270 | section = ("promptgen", "Promptgen") 271 | 272 | shared.opts.add_option("promptgen_names", shared.OptionInfo("qwopqwop/danbooru-llama-gptq", section=section)) 273 | shared.opts.add_option("promptgen_device", shared.OptionInfo("gpu", "Device to use for text generation", gr.Radio, {"choices": ["gpu"]}, section=section)) 274 | 275 | def on_unload(): 276 | current.model = None 277 | current.tokenizer = None 278 | 279 | 280 | script_callbacks.on_ui_tabs(add_tab) 281 | script_callbacks.on_ui_settings(on_ui_settings) 282 | script_callbacks.on_script_unloaded(on_unload) 283 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | 2 | #promptgen_generate{ 3 | height: 100% 4 | } 5 | 6 | #promptgen_main{ 7 | margin-top: 1em; 8 | } 9 | 10 | #tab_promptgen table tr{ 11 | height: 1px; 12 | } 13 | 14 | #tab_promptgen table tr td{ 15 | height: 100%; 16 | padding: 0.3em; 17 | } 18 | 19 | #tab_promptgen .prompt{ 20 | border: 1px solid rgba(128, 128, 128, 0.2); 21 | height: 100%; 22 | } 23 | 24 | #tab_promptgen .prompt p{ 25 | white-space: pre-line; 26 | } 27 | 28 | #tab_promptgen .sendto{ 29 | width: 8em; 30 | } 31 | 32 | #tab_promptgen .sendto a{ 33 | cursor: pointer; 34 | display: block; 35 | margin: 0.2em; 36 | padding: 0.4em; 37 | } 38 | 39 | #tab_promptgen .gr-form{ 40 | border: none; 41 | padding-bottom: 0.5em; 42 | } 43 | 44 | #promptgen_explanation table{ 45 | border-collapse: collapse; 46 | } 47 | 48 | #promptgen_explanation table td, #promptgen_explanation table th{ 49 | border: 1px solid rgba(128,128,128,0.1); 50 | vertical-align: top; 51 | 52 | } 53 | 54 | 55 | --------------------------------------------------------------------------------
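
For reference, here is a minimal standalone sketch of the generation path that scripts/promptgen.py wires into the webui: load the GPTQ-quantized llama from qwopqwop/danbooru-llama-gptq, pair it with the pinkmanlove/llama-7b-hf tokenizer, and sample danbooru-style tags. The starting tags, the "cuda" device string, the pad-token shortcut, and the fixed sampling settings are illustrative assumptions; the extension itself adds a "[PAD]" token (resizing the embeddings) and takes these settings from the UI sliders.

# Standalone sketch, not part of this repository: roughly what generate() + generate_batch() do.
# Assumes a CUDA GPU and the packages pinned in requirements.txt.
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

# Load the quantized model the same way generate() does, plus the matching llama tokenizer.
tokenizer = AutoTokenizer.from_pretrained("pinkmanlove/llama-7b-hf", use_fast=False)
model = AutoGPTQForCausalLM.from_quantized("qwopqwop/danbooru-llama-gptq").model
model.to("cuda")  # the extension's settings only offer "gpu"

if tokenizer.pad_token is None:
    # simpler stand-in for the "[PAD]" + embedding-resize step used in promptgen.py
    tokenizer.pad_token = tokenizer.eos_token

prompt = "1girl, solo"  # hypothetical starting tags; the extension falls back to the BOS token for an empty prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(
    input_ids,
    do_sample=True,
    temperature=1.0,
    top_k=12,          # mirrors the UI defaults: Top K sampling with k=12
    min_length=20,
    max_length=150,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

In the extension, the equivalent call goes through generate_batch(), which additionally handles Top P, beam count, batch size, and the repetition/length penalties set in the UI.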