├── CODEOWNERS
├── localizations
│   └── Put localization files here.txt
├── textual_inversion_templates
│   ├── none.txt
│   ├── style.txt
│   ├── subject.txt
│   ├── hypernetwork.txt
│   ├── style_filewords.txt
│   └── subject_filewords.txt
├── embeddings
│   └── Place Textual Inversion embeddings here.txt
├── models
│   ├── Stable-diffusion
│   │   └── Put Stable Diffusion checkpoints here.txt
│   └── deepbooru
│       └── Put your deepbooru release project folder here.txt
├── screenshot.png
├── txt2img_Screenshot.png
├── webui-user.bat
├── .pylintrc
├── modules
│   ├── textual_inversion
│   │   ├── test_embedding.png
│   │   ├── ui.py
│   │   ├── learn_schedule.py
│   │   ├── dataset.py
│   │   └── preprocess.py
│   ├── errors.py
│   ├── face_restoration.py
│   ├── ngrok.py
│   ├── artists.py
│   ├── localization.py
│   ├── safety.py
│   ├── paths.py
│   ├── hypernetworks
│   │   └── ui.py
│   ├── devices.py
│   ├── ldsr_model.py
│   ├── txt2img.py
│   ├── api
│   │   ├── api.py
│   │   └── processing.py
│   ├── memmon.py
│   ├── bsrgan_model.py
│   ├── esrgan_model_arch.py
│   ├── scunet_model.py
│   ├── generation_parameters_copypaste.py
│   ├── masking.py
│   ├── lowvram.py
│   ├── upscaler.py
│   ├── styles.py
│   ├── bsrgan_model_arch.py
│   ├── gfpgan_model.py
│   ├── safe.py
│   ├── img2img.py
│   ├── realesrgan_model.py
│   ├── modelloader.py
│   ├── swinir_model.py
│   ├── codeformer_model.py
│   ├── esrgan_model.py
│   └── deepbooru.py
├── javascript
│   ├── textualInversion.js
│   ├── imageParams.js
│   ├── imageMaskFix.js
│   ├── notification.js
│   ├── dragdrop.js
│   ├── edit-attention.js
│   ├── aspectRatioOverlay.js
│   ├── localization.js
│   ├── progressbar.js
│   ├── contextMenus.js
│   └── ui.js
├── environment-wsl2.yaml
├── requirements.txt
├── .gitignore
├── requirements_versions.txt
├── scripts
│   ├── custom_code.py
│   ├── loopback.py
│   ├── prompt_matrix.py
│   ├── sd_upscale.py
│   ├── prompts_from_file.py
│   └── poor_mans_outpainting.py
├── webui-user.sh
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── feature_request.yml
│   │   └── bug_report.yml
│   ├── workflows
│   │   └── on_pull_request.yaml
│   └── PULL_REQUEST_TEMPLATE
│       └── pull_request_template.md
├── webui.bat
├── script.js
├── webui.sh
└── webui.py

/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @AUTOMATIC1111
2 |
--------------------------------------------------------------------------------
/localizations/Put localization files here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/textual_inversion_templates/none.txt:
--------------------------------------------------------------------------------
1 | picture
2 |
--------------------------------------------------------------------------------
/embeddings/Place Textual Inversion embeddings here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/deepbooru/Put your deepbooru release project folder here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/C43H66N12O12S2/stable-diffusion-webui/HEAD/screenshot.png
--------------------------------------------------------------------------------
/txt2img_Screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/C43H66N12O12S2/stable-diffusion-webui/HEAD/txt2img_Screenshot.png -------------------------------------------------------------------------------- /webui-user.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON= 4 | set GIT= 5 | set VENV_DIR= 6 | set COMMANDLINE_ARGS= 7 | 8 | call webui.bat 9 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html 2 | [MESSAGES CONTROL] 3 | disable=C,R,W,E,I 4 | -------------------------------------------------------------------------------- /modules/textual_inversion/test_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/C43H66N12O12S2/stable-diffusion-webui/HEAD/modules/textual_inversion/test_embedding.png -------------------------------------------------------------------------------- /javascript/textualInversion.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | function start_training_textual_inversion(){ 4 | requestProgress('ti') 5 | gradioApp().querySelector('#ti_error').innerHTML='' 6 | 7 | return args_to_array(arguments) 8 | } 9 | -------------------------------------------------------------------------------- /environment-wsl2.yaml: -------------------------------------------------------------------------------- 1 | name: automatic 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - pip=22.2.2 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 -------------------------------------------------------------------------------- /modules/errors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | 4 | 5 | def run(code, task): 6 | try: 7 | code() 8 | except Exception as e: 9 | print(f"{task}: {type(e).__name__}", file=sys.stderr) 10 | print(traceback.format_exc(), file=sys.stderr) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | basicsr 2 | diffusers 3 | fairscale==0.4.4 4 | fonts 5 | font-roboto 6 | gfpgan 7 | gradio==3.5 8 | invisible-watermark 9 | numpy 10 | omegaconf 11 | piexif 12 | Pillow 13 | pytorch_lightning 14 | realesrgan 15 | scikit-image>=0.19 16 | timm==0.4.12 17 | transformers==4.19.2 18 | torch 19 | einops 20 | jsonmerge 21 | clean-fid 22 | resize-right 23 | torchdiffeq 24 | kornia 25 | lark 26 | inflection 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.pth 4 | /ESRGAN/* 5 | /SwinIR/* 6 | /repositories 7 | /venv 8 | /tmp 9 | /model.ckpt 10 | /models/**/* 11 | /GFPGANv1.3.pth 12 | /gfpgan/weights/*.pth 13 | /ui-config.json 14 | /outputs 15 | /config.json 16 | /log 17 | /webui.settings.bat 18 | /embeddings 19 | /styles.csv 20 | /params.txt 21 | /styles.csv.bak 22 | /webui-user.bat 23 | /webui-user.sh 24 | /interrogate 25 | /user.css 26 | /.idea 27 | notification.mp3 28 
| /SwinIR 29 | /textual_inversion 30 | .vscode -------------------------------------------------------------------------------- /requirements_versions.txt: -------------------------------------------------------------------------------- 1 | transformers==4.19.2 2 | diffusers==0.3.0 3 | basicsr==1.4.2 4 | gfpgan==1.3.8 5 | gradio==3.5 6 | numpy==1.23.3 7 | Pillow==9.2.0 8 | realesrgan==0.3.0 9 | torch 10 | omegaconf==2.2.3 11 | pytorch_lightning==1.7.6 12 | scikit-image==0.19.2 13 | fonts 14 | font-roboto 15 | timm==0.6.7 16 | fairscale==0.4.9 17 | piexif==1.1.3 18 | einops==0.4.1 19 | jsonmerge==1.8.0 20 | clean-fid==0.1.29 21 | resize-right==0.0.2 22 | torchdiffeq==0.2.3 23 | kornia==0.6.7 24 | lark==1.1.2 25 | inflection==0.5.1 26 | -------------------------------------------------------------------------------- /modules/face_restoration.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | 3 | 4 | class FaceRestoration: 5 | def name(self): 6 | return "None" 7 | 8 | def restore(self, np_image): 9 | return np_image 10 | 11 | 12 | def restore_faces(np_image): 13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] 14 | if len(face_restorers) == 0: 15 | return np_image 16 | 17 | face_restorer = face_restorers[0] 18 | 19 | return face_restorer.restore(np_image) 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/style.txt: -------------------------------------------------------------------------------- 1 | a painting, art by [name] 2 | a rendering, art by [name] 3 | a cropped painting, art by [name] 4 | the painting, art by [name] 5 | a clean painting, art by [name] 6 | a dirty painting, art by [name] 7 | a dark painting, art by [name] 8 | a picture, art by [name] 9 | a cool painting, art by [name] 10 | a close-up painting, art by [name] 11 | a bright painting, art by [name] 12 | a cropped painting, art by [name] 13 | a good painting, art by [name] 14 | a close-up painting, art by [name] 15 | a rendition, art by [name] 16 | a nice painting, art by [name] 17 | a small painting, art by [name] 18 | a weird painting, art by [name] 19 | a large painting, art by [name] 20 | -------------------------------------------------------------------------------- /modules/ngrok.py: -------------------------------------------------------------------------------- 1 | from pyngrok import ngrok, conf, exception 2 | 3 | 4 | def connect(token, port, region): 5 | if token == None: 6 | token = 'None' 7 | config = conf.PyngrokConfig( 8 | auth_token=token, region=region 9 | ) 10 | try: 11 | public_url = ngrok.connect(port, pyngrok_config=config).public_url 12 | except exception.PyngrokNgrokError: 13 | print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' 14 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') 15 | else: 16 | print(f'ngrok connected to localhost:{port}! 
URL: {public_url}\n' 17 | 'You can use this link after the launch is complete.') 18 | -------------------------------------------------------------------------------- /javascript/imageParams.js: -------------------------------------------------------------------------------- 1 | window.onload = (function(){ 2 | window.addEventListener('drop', e => { 3 | const target = e.composedPath()[0]; 4 | const idx = selected_gallery_index(); 5 | if (target.placeholder.indexOf("Prompt") == -1) return; 6 | 7 | let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; 8 | 9 | e.stopPropagation(); 10 | e.preventDefault(); 11 | const imgParent = gradioApp().getElementById(prompt_target); 12 | const files = e.dataTransfer.files; 13 | const fileInput = imgParent.querySelector('input[type="file"]'); 14 | if ( fileInput ) { 15 | fileInput.files = files; 16 | fileInput.dispatchEvent(new Event('change')); 17 | } 18 | }); 19 | }); 20 | -------------------------------------------------------------------------------- /modules/artists.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import csv 3 | from collections import namedtuple 4 | 5 | Artist = namedtuple("Artist", ['name', 'weight', 'category']) 6 | 7 | 8 | class ArtistsDatabase: 9 | def __init__(self, filename): 10 | self.cats = set() 11 | self.artists = [] 12 | 13 | if not os.path.exists(filename): 14 | return 15 | 16 | with open(filename, "r", newline='', encoding="utf8") as file: 17 | reader = csv.DictReader(file) 18 | 19 | for row in reader: 20 | artist = Artist(row["artist"], float(row["score"]), row["category"]) 21 | self.artists.append(artist) 22 | self.cats.add(artist.category) 23 | 24 | def categories(self): 25 | return sorted(self.cats) 26 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name] 2 | a rendering of a [name] 3 | a cropped photo of the [name] 4 | the photo of a [name] 5 | a photo of a clean [name] 6 | a photo of a dirty [name] 7 | a dark photo of the [name] 8 | a photo of my [name] 9 | a photo of the cool [name] 10 | a close-up photo of a [name] 11 | a bright photo of the [name] 12 | a cropped photo of a [name] 13 | a photo of the [name] 14 | a good photo of the [name] 15 | a photo of one [name] 16 | a close-up photo of the [name] 17 | a rendition of the [name] 18 | a photo of the clean [name] 19 | a rendition of a [name] 20 | a photo of a nice [name] 21 | a good photo of a [name] 22 | a photo of the nice [name] 23 | a photo of the small [name] 24 | a photo of the weird [name] 25 | a photo of the large [name] 26 | a photo of a cool [name] 27 | a photo of a small [name] 28 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import traceback 5 | 6 | localizations = {} 7 | 8 | 9 | def list_localizations(dirname): 10 | localizations.clear() 11 | 12 | for file in os.listdir(dirname): 13 | fn, ext = os.path.splitext(file) 14 | if ext.lower() != ".json": 15 | continue 16 | 17 | localizations[fn] = os.path.join(dirname, file) 18 | 19 | 20 | def localization_js(current_localization_name): 21 | fn = localizations.get(current_localization_name, None) 22 | data = {} 23 | if fn is not None: 24 | 
try: 25 | with open(fn, "r", encoding="utf8") as file: 26 | data = json.load(file) 27 | except Exception: 28 | print(f"Error loading localization from {fn}:", file=sys.stderr) 29 | print(traceback.format_exc(), file=sys.stderr) 30 | 31 | return f"var localization = {json.dumps(data)}\n" 32 | -------------------------------------------------------------------------------- /textual_inversion_templates/hypernetwork.txt: -------------------------------------------------------------------------------- 1 | a photo of a [filewords] 2 | a rendering of a [filewords] 3 | a cropped photo of the [filewords] 4 | the photo of a [filewords] 5 | a photo of a clean [filewords] 6 | a photo of a dirty [filewords] 7 | a dark photo of the [filewords] 8 | a photo of my [filewords] 9 | a photo of the cool [filewords] 10 | a close-up photo of a [filewords] 11 | a bright photo of the [filewords] 12 | a cropped photo of a [filewords] 13 | a photo of the [filewords] 14 | a good photo of the [filewords] 15 | a photo of one [filewords] 16 | a close-up photo of the [filewords] 17 | a rendition of the [filewords] 18 | a photo of the clean [filewords] 19 | a rendition of a [filewords] 20 | a photo of a nice [filewords] 21 | a good photo of a [filewords] 22 | a photo of the nice [filewords] 23 | a photo of the small [filewords] 24 | a photo of the weird [filewords] 25 | a photo of the large [filewords] 26 | a photo of a cool [filewords] 27 | a photo of a small [filewords] 28 | -------------------------------------------------------------------------------- /textual_inversion_templates/style_filewords.txt: -------------------------------------------------------------------------------- 1 | a painting of [filewords], art by [name] 2 | a rendering of [filewords], art by [name] 3 | a cropped painting of [filewords], art by [name] 4 | the painting of [filewords], art by [name] 5 | a clean painting of [filewords], art by [name] 6 | a dirty painting of [filewords], art by [name] 7 | a dark painting of [filewords], art by [name] 8 | a picture of [filewords], art by [name] 9 | a cool painting of [filewords], art by [name] 10 | a close-up painting of [filewords], art by [name] 11 | a bright painting of [filewords], art by [name] 12 | a cropped painting of [filewords], art by [name] 13 | a good painting of [filewords], art by [name] 14 | a close-up painting of [filewords], art by [name] 15 | a rendition of [filewords], art by [name] 16 | a nice painting of [filewords], art by [name] 17 | a small painting of [filewords], art by [name] 18 | a weird painting of [filewords], art by [name] 19 | a large painting of [filewords], art by [name] 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject_filewords.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name], [filewords] 2 | a rendering of a [name], [filewords] 3 | a cropped photo of the [name], [filewords] 4 | the photo of a [name], [filewords] 5 | a photo of a clean [name], [filewords] 6 | a photo of a dirty [name], [filewords] 7 | a dark photo of the [name], [filewords] 8 | a photo of my [name], [filewords] 9 | a photo of the cool [name], [filewords] 10 | a close-up photo of a [name], [filewords] 11 | a bright photo of the [name], [filewords] 12 | a cropped photo of a [name], [filewords] 13 | a photo of the [name], [filewords] 14 | a good photo of the [name], [filewords] 15 | a photo of one [name], [filewords] 16 | a close-up photo of the [name], [filewords] 17 | 
a rendition of the [name], [filewords] 18 | a photo of the clean [name], [filewords] 19 | a rendition of a [name], [filewords] 20 | a photo of a nice [name], [filewords] 21 | a good photo of a [name], [filewords] 22 | a photo of the nice [name], [filewords] 23 | a photo of the small [name], [filewords] 24 | a photo of the weird [name], [filewords] 25 | a photo of the large [name], [filewords] 26 | a photo of a cool [name], [filewords] 27 | a photo of a small [name], [filewords] 28 | -------------------------------------------------------------------------------- /scripts/custom_code.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | 4 | from modules.processing import Processed 5 | from modules.shared import opts, cmd_opts, state 6 | 7 | class Script(scripts.Script): 8 | 9 | def title(self): 10 | return "Custom code" 11 | 12 | 13 | def show(self, is_img2img): 14 | return cmd_opts.allow_code 15 | 16 | def ui(self, is_img2img): 17 | code = gr.Textbox(label="Python code", visible=False, lines=1) 18 | 19 | return [code] 20 | 21 | 22 | def run(self, p, code): 23 | assert cmd_opts.allow_code, '--allow-code option must be enabled' 24 | 25 | display_result_data = [[], -1, ""] 26 | 27 | def display(imgs, s=display_result_data[1], i=display_result_data[2]): 28 | display_result_data[0] = imgs 29 | display_result_data[1] = s 30 | display_result_data[2] = i 31 | 32 | from types import ModuleType 33 | compiled = compile(code, '', 'exec') 34 | module = ModuleType("testmodule") 35 | module.__dict__.update(globals()) 36 | module.p = p 37 | module.display = display 38 | exec(compiled, module.__dict__) 39 | 40 | return Processed(p, *display_result_data) 41 | 42 | -------------------------------------------------------------------------------- /webui-user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ######################################################### 3 | # Uncomment and change the variables below to your need:# 4 | ######################################################### 5 | 6 | # Install directory without trailing slash 7 | #install_dir="/home/$(whoami)" 8 | 9 | # Name of the subdirectory 10 | #clone_dir="stable-diffusion-webui" 11 | 12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" 13 | export COMMANDLINE_ARGS="" 14 | 15 | # python3 executable 16 | #python_cmd="python3" 17 | 18 | # git executable 19 | #export GIT="git" 20 | 21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 22 | #venv_dir="venv" 23 | 24 | # script to launch to start the app 25 | #export LAUNCH_SCRIPT="launch.py" 26 | 27 | # install command for torch 28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" 29 | 30 | # Requirements file to use for stable-diffusion-webui 31 | #export REQS_FILE="requirements_versions.txt" 32 | 33 | # Fixed git repos 34 | #export K_DIFFUSION_PACKAGE="" 35 | #export GFPGAN_PACKAGE="" 36 | 37 | # Fixed git commits 38 | #export STABLE_DIFFUSION_COMMIT_HASH="" 39 | #export TAMING_TRANSFORMERS_COMMIT_HASH="" 40 | #export CODEFORMER_COMMIT_HASH="" 41 | #export BLIP_COMMIT_HASH="" 42 | 43 | ########################################### 44 | -------------------------------------------------------------------------------- /modules/textual_inversion/ui.py: 
-------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | 5 | import modules.textual_inversion.textual_inversion 6 | import modules.textual_inversion.preprocess 7 | from modules import sd_hijack, shared 8 | 9 | 10 | def create_embedding(name, initialization_text, nvpt, overwrite_old): 11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) 12 | 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() 14 | 15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" 16 | 17 | 18 | def preprocess(*args): 19 | modules.textual_inversion.preprocess.preprocess(*args) 20 | 21 | return "Preprocessing finished.", "" 22 | 23 | 24 | def train_embedding(*args): 25 | 26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' 27 | 28 | try: 29 | sd_hijack.undo_optimizations() 30 | 31 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) 32 | 33 | res = f""" 34 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 35 | Embedding saved to {html.escape(filename)} 36 | """ 37 | return res, "" 38 | except Exception: 39 | raise 40 | finally: 41 | sd_hijack.apply_optimizations() 42 | 43 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | title: "[Feature Request]: " 4 | labels: ["suggestion"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do ? 22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here. 
41 | -------------------------------------------------------------------------------- /javascript/imageMaskFix.js: -------------------------------------------------------------------------------- 1 | /** 2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 3 | * @see https://github.com/gradio-app/gradio/issues/1721 4 | */ 5 | window.addEventListener( 'resize', () => imageMaskResize()); 6 | function imageMaskResize() { 7 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); 8 | if ( ! canvases.length ) { 9 | canvases_fixed = false; 10 | window.removeEventListener( 'resize', imageMaskResize ); 11 | return; 12 | } 13 | 14 | const wrapper = canvases[0].closest('.touch-none'); 15 | const previewImage = wrapper.previousElementSibling; 16 | 17 | if ( ! previewImage.complete ) { 18 | previewImage.addEventListener( 'load', () => imageMaskResize()); 19 | return; 20 | } 21 | 22 | const w = previewImage.width; 23 | const h = previewImage.height; 24 | const nw = previewImage.naturalWidth; 25 | const nh = previewImage.naturalHeight; 26 | const portrait = nh > nw; 27 | const factor = portrait; 28 | 29 | const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); 30 | const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); 31 | 32 | wrapper.style.width = `${wW}px`; 33 | wrapper.style.height = `${wH}px`; 34 | wrapper.style.left = `0px`; 35 | wrapper.style.top = `0px`; 36 | 37 | canvases.forEach( c => { 38 | c.style.width = c.style.height = ''; 39 | c.style.maxWidth = '100%'; 40 | c.style.maxHeight = '100%'; 41 | c.style.objectFit = 'contain'; 42 | }); 43 | } 44 | 45 | onUiUpdate(() => imageMaskResize()); 46 | -------------------------------------------------------------------------------- /.github/workflows/on_pull_request.yaml: -------------------------------------------------------------------------------- 1 | # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file 2 | name: Run Linting/Formatting on Pull Requests 3 | 4 | on: 5 | - push 6 | - pull_request 7 | # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs 8 | # if you want to filter out branches, delete the `- pull_request` and uncomment these lines : 9 | # pull_request: 10 | # branches: 11 | # - master 12 | # branches-ignore: 13 | # - development 14 | 15 | jobs: 16 | lint: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout Code 20 | uses: actions/checkout@v3 21 | - name: Set up Python 3.10 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: 3.10.6 25 | - uses: actions/cache@v2 26 | with: 27 | path: ~/.cache/pip 28 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 29 | restore-keys: | 30 | ${{ runner.os }}-pip- 31 | - name: Install PyLint 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install pylint 35 | # This lets PyLint check to see if it can resolve imports 36 | - name: Install dependencies 37 | run : | 38 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" 39 | python launch.py 40 | - name: Analysing the code with pylint 41 | run: | 42 | pylint $(git ls-files '*.py') 43 | -------------------------------------------------------------------------------- /modules/safety.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.stable_diffusion.safety_checker 
import StableDiffusionSafetyChecker
 3 | from transformers import AutoFeatureExtractor
 4 | from PIL import Image
 5 |
 6 | import modules.shared as shared
 7 |
 8 | safety_model_id = "CompVis/stable-diffusion-safety-checker"
 9 | safety_feature_extractor = None
10 | safety_checker = None
11 |
12 | def numpy_to_pil(images):
13 |     """
14 |     Convert a numpy image or a batch of images to a PIL image.
15 |     """
16 |     if images.ndim == 3:
17 |         images = images[None, ...]
18 |     images = (images * 255).round().astype("uint8")
19 |     pil_images = [Image.fromarray(image) for image in images]
20 |
21 |     return pil_images
22 |
23 | # check and replace nsfw content
24 | def check_safety(x_image):
25 |     global safety_feature_extractor, safety_checker
26 |
27 |     if safety_feature_extractor is None:
28 |         safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
29 |         safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
30 |
31 |     safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
32 |     x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
33 |
34 |     return x_checked_image, has_nsfw_concept
35 |
36 |
37 | def censor_batch(x):
38 |     x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
39 |     x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
40 |     x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
41 |
42 |     return x
43 |
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
 1 | @echo off
 2 |
 3 | if not defined PYTHON (set PYTHON=python)
 4 | if not defined VENV_DIR (set VENV_DIR=venv)
 5 |
 6 | set ERROR_REPORTING=FALSE
 7 |
 8 | mkdir tmp 2>NUL
 9 |
10 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
11 | if %ERRORLEVEL% == 0 goto :start_venv
12 | echo Couldn't launch python
13 | goto :show_stdout_stderr
14 |
15 | :start_venv
16 | if [%VENV_DIR%] == [-] goto :skip_venv
17 |
18 | dir %VENV_DIR%\Scripts\Python.exe >tmp/stdout.txt 2>tmp/stderr.txt
19 | if %ERRORLEVEL% == 0 goto :activate_venv
20 |
21 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
22 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
23 | %PYTHON_FULLNAME% -m venv %VENV_DIR% >tmp/stdout.txt 2>tmp/stderr.txt
24 | if %ERRORLEVEL% == 0 goto :activate_venv
25 | echo Unable to create venv in directory %VENV_DIR%
26 | goto :show_stdout_stderr
27 |
28 | :activate_venv
29 | set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
30 | echo venv %PYTHON%
31 | goto :launch
32 |
33 | :skip_venv
34 |
35 | :launch
36 | %PYTHON% launch.py %*
37 | pause
38 | exit /b
39 |
40 | :show_stdout_stderr
41 |
42 | echo.
43 | echo exit code: %errorlevel%
44 |
45 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
46 | if %size% equ 0 goto :show_stderr
47 | echo.
48 | echo stdout:
49 | type tmp\stdout.txt
50 |
51 | :show_stderr
52 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
53 | if %size% equ 0 goto :endofscript
54 | echo.
55 | echo stderr:
56 | type tmp\stderr.txt
57 |
58 | :endofscript
59 |
60 | echo.
61 | echo Launch unsuccessful. Exiting.
62 | pause 63 | -------------------------------------------------------------------------------- /javascript/notification.js: -------------------------------------------------------------------------------- 1 | // Monitors the gallery and sends a browser notification when the leading image is new. 2 | 3 | let lastHeadImg = null; 4 | 5 | notificationButton = null 6 | 7 | onUiUpdate(function(){ 8 | if(notificationButton == null){ 9 | notificationButton = gradioApp().getElementById('request_notifications') 10 | 11 | if(notificationButton != null){ 12 | notificationButton.addEventListener('click', function (evt) { 13 | Notification.requestPermission(); 14 | },true); 15 | } 16 | } 17 | 18 | const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden'); 19 | 20 | if (galleryPreviews == null) return; 21 | 22 | const headImg = galleryPreviews[0]?.src; 23 | 24 | if (headImg == null || headImg == lastHeadImg) return; 25 | 26 | lastHeadImg = headImg; 27 | 28 | // play notification sound if available 29 | gradioApp().querySelector('#audio_notification audio')?.play(); 30 | 31 | if (document.hasFocus()) return; 32 | 33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. 34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src)); 35 | 36 | const notification = new Notification( 37 | 'Stable Diffusion', 38 | { 39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`, 40 | icon: headImg, 41 | image: headImg, 42 | } 43 | ); 44 | 45 | notification.onclick = function(_){ 46 | parent.focus(); 47 | this.close(); 48 | }; 49 | }); 50 | -------------------------------------------------------------------------------- /modules/paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import modules.safe 5 | 6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 7 | models_path = os.path.join(script_path, "models") 8 | sys.path.insert(0, script_path) 9 | 10 | # search for directory of stable diffusion in following places 11 | sd_path = None 12 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] 13 | for possible_sd_path in possible_sd_paths: 14 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 15 | sd_path = os.path.abspath(possible_sd_path) 16 | break 17 | 18 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 19 | 20 | path_dirs = [ 21 | (sd_path, 'ldm', 'Stable Diffusion', []), 22 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), 23 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), 24 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), 25 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 26 | ] 27 | 28 | paths = {} 29 | 30 | for d, must_exist, what, options in path_dirs: 31 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 32 | if not os.path.exists(must_exist_path): 33 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 34 | else: 35 | d = os.path.abspath(d) 36 | if "atstart" in options: 37 | sys.path.insert(0, d) 38 | else: 39 | sys.path.append(d) 40 | paths[what] = d 41 | 
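# The registration above is what makes the launcher-cloned repositories
# importable as ordinary packages. A minimal usage sketch (assuming launch.py
# has already fetched the repositories into repositories/):
#
#     import modules.paths  # the sys.path setup runs as an import side effect
#     from ldm.models.diffusion.ddpm import LatentDiffusion  # found via sd_path
#     import k_diffusion.sampling  # importable thanks to the "atstart" entry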
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md:
--------------------------------------------------------------------------------
 1 | # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
 2 |
 3 | If you have a large change, pay special attention to this paragraph:
 4 |
 5 | > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
 6 |
 7 | Otherwise, after making sure you're following the rules described in the wiki page, remove this section and continue on.
 8 |
 9 | **Describe what this pull request is trying to achieve.**
10 |
11 | A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
12 |
13 | **Additional notes and description of your changes**
14 |
15 | More technical discussion about your changes goes here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
16 |
17 | **Environment this was tested in**
18 |
19 | List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
20 | - OS: [e.g. Windows, Linux]
21 | - Browser: [e.g. Chrome, Safari]
22 | - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
23 |
24 | **Screenshots or videos of your changes**
25 |
26 | If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
27 |
28 | This is **required** for anything that touches the user interface.
-------------------------------------------------------------------------------- /modules/hypernetworks/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | import os 3 | import re 4 | 5 | import gradio as gr 6 | 7 | import modules.textual_inversion.textual_inversion 8 | import modules.textual_inversion.preprocess 9 | from modules import sd_hijack, shared, devices 10 | from modules.hypernetworks import hypernetwork 11 | 12 | 13 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None): 14 | fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") 15 | if not overwrite_old: 16 | assert not os.path.exists(fn), f"file {fn} already exists" 17 | 18 | if type(layer_structure) == str: 19 | layer_structure = [float(x.strip()) for x in layer_structure.split(",")] 20 | 21 | hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( 22 | name=name, 23 | enable_sizes=[int(x) for x in enable_sizes], 24 | layer_structure=layer_structure, 25 | add_layer_norm=add_layer_norm, 26 | activation_func=activation_func, 27 | ) 28 | hypernet.save(fn) 29 | 30 | shared.reload_hypernetworks() 31 | 32 | return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" 33 | 34 | 35 | def train_hypernetwork(*args): 36 | 37 | initial_hypernetwork = shared.loaded_hypernetwork 38 | 39 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 40 | 41 | try: 42 | sd_hijack.undo_optimizations() 43 | 44 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) 45 | 46 | res = f""" 47 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 48 | Hypernetwork saved to {html.escape(filename)} 49 | """ 50 | return res, "" 51 | except Exception: 52 | raise 53 | finally: 54 | shared.loaded_hypernetwork = initial_hypernetwork 55 | shared.sd_model.cond_stage_model.to(devices.device) 56 | shared.sd_model.first_stage_model.to(devices.device) 57 | sd_hijack.apply_optimizations() 58 | 59 | -------------------------------------------------------------------------------- /modules/devices.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import torch 4 | 5 | from modules import errors 6 | 7 | # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility 8 | has_mps = getattr(torch, 'has_mps', False) 9 | 10 | cpu = torch.device("cpu") 11 | 12 | 13 | def get_optimal_device(): 14 | if torch.cuda.is_available(): 15 | return torch.device("cuda") 16 | 17 | if has_mps: 18 | return torch.device("mps") 19 | 20 | return cpu 21 | 22 | 23 | def torch_gc(): 24 | if torch.cuda.is_available(): 25 | torch.cuda.empty_cache() 26 | torch.cuda.ipc_collect() 27 | 28 | 29 | def enable_tf32(): 30 | if torch.cuda.is_available(): 31 | torch.backends.cuda.matmul.allow_tf32 = True 32 | torch.backends.cudnn.allow_tf32 = True 33 | 34 | 35 | errors.run(enable_tf32, "Enabling TF32") 36 | 37 | device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() 38 | dtype = torch.float16 39 | dtype_vae = torch.float16 40 | 41 | def randn(seed, shape): 42 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. 
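# As a workaround, the functions below generate the noise on the CPU (seeding a
# CPU generator where a seed is given) and move the result to the target
# device, which keeps seeded results reproducible when running on MPS.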
43 | if device.type == 'mps': 44 | generator = torch.Generator(device=cpu) 45 | generator.manual_seed(seed) 46 | noise = torch.randn(shape, generator=generator, device=cpu).to(device) 47 | return noise 48 | 49 | torch.manual_seed(seed) 50 | return torch.randn(shape, device=device) 51 | 52 | 53 | def randn_without_seed(shape): 54 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. 55 | if device.type == 'mps': 56 | generator = torch.Generator(device=cpu) 57 | noise = torch.randn(shape, generator=generator, device=cpu).to(device) 58 | return noise 59 | 60 | return torch.randn(shape, device=device) 61 | 62 | 63 | def autocast(disable=False): 64 | from modules import shared 65 | 66 | if disable: 67 | return contextlib.nullcontext() 68 | 69 | if dtype == torch.float32 or shared.cmd_opts.precision == "full": 70 | return contextlib.nullcontext() 71 | 72 | return torch.autocast("cuda") 73 | -------------------------------------------------------------------------------- /modules/ldsr_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | from modules.upscaler import Upscaler, UpscalerData 8 | from modules.ldsr_model_arch import LDSR 9 | from modules import shared 10 | 11 | 12 | class UpscalerLDSR(Upscaler): 13 | def __init__(self, user_path): 14 | self.name = "LDSR" 15 | self.user_path = user_path 16 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" 17 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" 18 | super().__init__() 19 | scaler_data = UpscalerData("LDSR", None, self) 20 | self.scalers = [scaler_data] 21 | 22 | def load_model(self, path: str): 23 | # Remove incorrect project.yaml file if too big 24 | yaml_path = os.path.join(self.model_path, "project.yaml") 25 | old_model_path = os.path.join(self.model_path, "model.pth") 26 | new_model_path = os.path.join(self.model_path, "model.ckpt") 27 | if os.path.exists(yaml_path): 28 | statinfo = os.stat(yaml_path) 29 | if statinfo.st_size >= 10485760: 30 | print("Removing invalid LDSR YAML file.") 31 | os.remove(yaml_path) 32 | if os.path.exists(old_model_path): 33 | print("Renaming model from model.pth to model.ckpt") 34 | os.rename(old_model_path, new_model_path) 35 | model = load_file_from_url(url=self.model_url, model_dir=self.model_path, 36 | file_name="model.ckpt", progress=True) 37 | yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, 38 | file_name="project.yaml", progress=True) 39 | 40 | try: 41 | return LDSR(model, yaml) 42 | 43 | except Exception: 44 | print("Error importing LDSR:", file=sys.stderr) 45 | print(traceback.format_exc(), file=sys.stderr) 46 | return None 47 | 48 | def do_upscale(self, img, path): 49 | ldsr = self.load_model(path) 50 | if ldsr is None: 51 | print("NO LDSR!") 52 | return img 53 | ddim_steps = shared.opts.ldsr_steps 54 | return ldsr.super_resolution(img, ddim_steps, self.scale) 55 | -------------------------------------------------------------------------------- /modules/textual_inversion/learn_schedule.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | 4 | class LearnScheduleIterator: 5 | def __init__(self, learn_rate, max_steps, cur_step=0): 6 | """ 7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 
1e-5:10000 until 10000 8 | """ 9 | 10 | pairs = learn_rate.split(',') 11 | self.rates = [] 12 | self.it = 0 13 | self.maxit = 0 14 | for i, pair in enumerate(pairs): 15 | tmp = pair.split(':') 16 | if len(tmp) == 2: 17 | step = int(tmp[1]) 18 | if step > cur_step: 19 | self.rates.append((float(tmp[0]), min(step, max_steps))) 20 | self.maxit += 1 21 | if step > max_steps: 22 | return 23 | elif step == -1: 24 | self.rates.append((float(tmp[0]), max_steps)) 25 | self.maxit += 1 26 | return 27 | else: 28 | self.rates.append((float(tmp[0]), max_steps)) 29 | self.maxit += 1 30 | return 31 | 32 | def __iter__(self): 33 | return self 34 | 35 | def __next__(self): 36 | if self.it < self.maxit: 37 | self.it += 1 38 | return self.rates[self.it - 1] 39 | else: 40 | raise StopIteration 41 | 42 | 43 | class LearnRateScheduler: 44 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): 45 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) 46 | (self.learn_rate, self.end_step) = next(self.schedules) 47 | self.verbose = verbose 48 | 49 | if self.verbose: 50 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}') 51 | 52 | self.finished = False 53 | 54 | def apply(self, optimizer, step_number): 55 | if step_number <= self.end_step: 56 | return 57 | 58 | try: 59 | (self.learn_rate, self.end_step) = next(self.schedules) 60 | except Exception: 61 | self.finished = True 62 | return 63 | 64 | if self.verbose: 65 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') 66 | 67 | for pg in optimizer.param_groups: 68 | pg['lr'] = self.learn_rate 69 | 70 | -------------------------------------------------------------------------------- /modules/txt2img.py: -------------------------------------------------------------------------------- 1 | import modules.scripts 2 | from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images 3 | from modules.shared import opts, cmd_opts 4 | import modules.shared as shared 5 | import modules.processing as processing 6 | from modules.ui import plaintext_to_html 7 | 8 | 9 | def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): 10 | p = StableDiffusionProcessingTxt2Img( 11 | sd_model=shared.sd_model, 12 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, 13 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, 14 | prompt=prompt, 15 | styles=[prompt_style, prompt_style2], 16 | negative_prompt=negative_prompt, 17 | seed=seed, 18 | subseed=subseed, 19 | subseed_strength=subseed_strength, 20 | seed_resize_from_h=seed_resize_from_h, 21 | seed_resize_from_w=seed_resize_from_w, 22 | seed_enable_extras=seed_enable_extras, 23 | sampler_index=sampler_index, 24 | batch_size=batch_size, 25 | n_iter=n_iter, 26 | steps=steps, 27 | cfg_scale=cfg_scale, 28 | width=width, 29 | height=height, 30 | restore_faces=restore_faces, 31 | tiling=tiling, 32 | enable_hr=enable_hr, 33 | denoising_strength=denoising_strength if enable_hr else None, 34 | firstphase_width=firstphase_width if enable_hr else None, 35 | 
firstphase_height=firstphase_height if enable_hr else None, 36 | ) 37 | 38 | if cmd_opts.enable_console_prompts: 39 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) 40 | 41 | processed = modules.scripts.scripts_txt2img.run(p, *args) 42 | 43 | if processed is None: 44 | processed = process_images(p) 45 | 46 | shared.total_tqdm.clear() 47 | 48 | generation_info_js = processed.js() 49 | if opts.samples_log_stdout: 50 | print(generation_info_js) 51 | 52 | if opts.do_not_show_images: 53 | processed.images = [] 54 | 55 | return processed.images, generation_info_js, plaintext_to_html(processed.info) 56 | 57 | -------------------------------------------------------------------------------- /modules/api/api.py: -------------------------------------------------------------------------------- 1 | from modules.api.processing import StableDiffusionProcessingAPI 2 | from modules.processing import StableDiffusionProcessingTxt2Img, process_images 3 | from modules.sd_samplers import all_samplers 4 | from modules.extras import run_pnginfo 5 | import modules.shared as shared 6 | import uvicorn 7 | from fastapi import Body, APIRouter, HTTPException 8 | from fastapi.responses import JSONResponse 9 | from pydantic import BaseModel, Field, Json 10 | import json 11 | import io 12 | import base64 13 | 14 | sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) 15 | 16 | class TextToImageResponse(BaseModel): 17 | images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") 18 | parameters: Json 19 | info: Json 20 | 21 | 22 | class Api: 23 | def __init__(self, app, queue_lock): 24 | self.router = APIRouter() 25 | self.app = app 26 | self.queue_lock = queue_lock 27 | self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) 28 | 29 | def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): 30 | sampler_index = sampler_to_index(txt2imgreq.sampler_index) 31 | 32 | if sampler_index is None: 33 | raise HTTPException(status_code=404, detail="Sampler not found") 34 | 35 | populate = txt2imgreq.copy(update={ # Override __init__ params 36 | "sd_model": shared.sd_model, 37 | "sampler_index": sampler_index[0], 38 | "do_not_save_samples": True, 39 | "do_not_save_grid": True 40 | } 41 | ) 42 | p = StableDiffusionProcessingTxt2Img(**vars(populate)) 43 | # Override object param 44 | with self.queue_lock: 45 | processed = process_images(p) 46 | 47 | b64images = [] 48 | for i in processed.images: 49 | buffer = io.BytesIO() 50 | i.save(buffer, format="png") 51 | b64images.append(base64.b64encode(buffer.getvalue())) 52 | 53 | return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) 54 | 55 | 56 | 57 | def img2imgapi(self): 58 | raise NotImplementedError 59 | 60 | def extrasapi(self): 61 | raise NotImplementedError 62 | 63 | def pnginfoapi(self): 64 | raise NotImplementedError 65 | 66 | def launch(self, server_name, port): 67 | self.app.include_router(self.router) 68 | uvicorn.run(self.app, host=server_name, port=port) 69 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | function gradioApp(){ 2 | return document.getElementsByTagName('gradio-app')[0].shadowRoot; 3 | } 4 | 5 | function get_uiCurrentTab() { 6 | return gradioApp().querySelector('.tabs button:not(.border-transparent)') 7 | } 
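// These two helpers underpin every other frontend script: Gradio renders the
// whole UI inside a <gradio-app> web component, so DOM lookups must go through
// its shadowRoot, and the active tab button is the only one without the
// transparent (inactive) border styling.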
8 | 9 | function get_uiCurrentTabContent() { 10 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') 11 | } 12 | 13 | uiUpdateCallbacks = [] 14 | uiTabChangeCallbacks = [] 15 | let uiCurrentTab = null 16 | 17 | function onUiUpdate(callback){ 18 | uiUpdateCallbacks.push(callback) 19 | } 20 | function onUiTabChange(callback){ 21 | uiTabChangeCallbacks.push(callback) 22 | } 23 | 24 | function runCallback(x, m){ 25 | try { 26 | x(m) 27 | } catch (e) { 28 | (console.error || console.log).call(console, e.message, e); 29 | } 30 | } 31 | function executeCallbacks(queue, m) { 32 | queue.forEach(function(x){runCallback(x, m)}) 33 | } 34 | 35 | document.addEventListener("DOMContentLoaded", function() { 36 | var mutationObserver = new MutationObserver(function(m){ 37 | executeCallbacks(uiUpdateCallbacks, m); 38 | const newTab = get_uiCurrentTab(); 39 | if ( newTab && ( newTab !== uiCurrentTab ) ) { 40 | uiCurrentTab = newTab; 41 | executeCallbacks(uiTabChangeCallbacks); 42 | } 43 | }); 44 | mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) 45 | }); 46 | 47 | /** 48 | * Add a ctrl+enter as a shortcut to start a generation 49 | */ 50 | document.addEventListener('keydown', function(e) { 51 | var handled = false; 52 | if (e.key !== undefined) { 53 | if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 54 | } else if (e.keyCode !== undefined) { 55 | if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 56 | } 57 | if (handled) { 58 | button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); 59 | if (button) { 60 | button.click(); 61 | } 62 | e.preventDefault(); 63 | } 64 | }) 65 | 66 | /** 67 | * checks that a UI element is not in another hidden element or tab content 68 | */ 69 | function uiElementIsVisible(el) { 70 | let isVisible = !el.closest('.\\!hidden'); 71 | if ( ! isVisible ) { 72 | return false; 73 | } 74 | 75 | while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) { 76 | if ( ! 
isVisible ) {
77 |             return false;
78 |         } else if ( el.parentElement ) {
79 |             el = el.parentElement
80 |         } else {
81 |             break;
82 |         }
83 |     }
84 |     return isVisible;
85 | }
--------------------------------------------------------------------------------
/modules/memmon.py:
--------------------------------------------------------------------------------
 1 | import threading
 2 | import time
 3 | from collections import defaultdict
 4 |
 5 | import torch
 6 |
 7 |
 8 | class MemUsageMonitor(threading.Thread):
 9 |     run_flag = None
10 |     device = None
11 |     disabled = False
12 |     opts = None
13 |     data = None
14 |
15 |     def __init__(self, name, device, opts):
16 |         threading.Thread.__init__(self)
17 |         self.name = name
18 |         self.device = device
19 |         self.opts = opts
20 |
21 |         self.daemon = True
22 |         self.run_flag = threading.Event()
23 |         self.data = defaultdict(int)
24 |
25 |         try:
26 |             torch.cuda.mem_get_info()
27 |             torch.cuda.memory_stats(self.device)
28 |         except Exception as e:  # AMD or whatever
29 |             print(f"Warning: caught exception '{e}', memory monitor disabled")
30 |             self.disabled = True
31 |
32 |     def run(self):
33 |         if self.disabled:
34 |             return
35 |
36 |         while True:
37 |             self.run_flag.wait()
38 |
39 |             torch.cuda.reset_peak_memory_stats()
40 |             self.data.clear()
41 |
42 |             if self.opts.memmon_poll_rate <= 0:
43 |                 self.run_flag.clear()
44 |                 continue
45 |
46 |             self.data["min_free"] = torch.cuda.mem_get_info()[0]
47 |
48 |             while self.run_flag.is_set():
49 |                 free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
50 |                 self.data["min_free"] = min(self.data["min_free"], free)
51 |
52 |                 time.sleep(1 / self.opts.memmon_poll_rate)
53 |
54 |     def dump_debug(self):
55 |         print(self, 'recorded data:')
56 |         for k, v in self.read().items():
57 |             print(k, -(v // -(1024 ** 2)))
58 |
59 |         print(self, 'raw torch memory stats:')
60 |         tm = torch.cuda.memory_stats(self.device)
61 |         for k, v in tm.items():
62 |             if 'bytes' not in k:
63 |                 continue
64 |             print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65 |
66 |         print(torch.cuda.memory_summary())
67 |
68 |     def monitor(self):
69 |         self.run_flag.set()
70 |
71 |     def read(self):
72 |         if not self.disabled:
73 |             free, total = torch.cuda.mem_get_info()
74 |             self.data["total"] = total
75 |
76 |             torch_stats = torch.cuda.memory_stats(self.device)
77 |             self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
78 |             self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
79 |             self.data["system_peak"] = total - self.data["min_free"]
80 |
81 |         return self.data
82 |
83 |     def stop(self):
84 |         self.run_flag.clear()
85 |         return self.read()
86 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
 1 | name: Bug Report
 2 | description: You think something is broken in the UI
 3 | title: "[Bug]: "
 4 | labels: ["bug-report"]
 5 |
 6 | body:
 7 |   - type: checkboxes
 8 |     attributes:
 9 |       label: Is there an existing issue for this?
10 |       description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
11 |       options:
12 |         - label: I have searched the existing issues and checked the recent builds/commits
13 |           required: true
14 |   - type: markdown
15 |     attributes:
16 |       value: |
17 |         *Please fill this form with as much information as possible; don't forget to fill in "What OS..." and "What browsers", and provide screenshots if possible.*
18 |   - type: textarea
19 |     id: what-did
20 |     attributes:
21 |       label: What happened?
22 |       description: Tell us what happened in a very clear and simple way
23 |     validations:
24 |       required: true
25 |   - type: textarea
26 |     id: steps
27 |     attributes:
28 |       label: Steps to reproduce the problem
29 |       description: Please provide us with precise step-by-step information on how to reproduce the bug
30 |       value: |
31 |         1. Go to ....
32 |         2. Press ....
33 |         3. ...
34 |     validations:
35 |       required: true
36 |   - type: textarea
37 |     id: what-should
38 |     attributes:
39 |       label: What should have happened?
40 |       description: Tell us what you think the normal behavior should be
41 |     validations:
42 |       required: true
43 |   - type: input
44 |     id: commit
45 |     attributes:
46 |       label: Commit where the problem happens
47 |       description: Which commit are you running? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
48 |   - type: dropdown
49 |     id: platforms
50 |     attributes:
51 |       label: What platforms do you use to access the UI?
52 |       multiple: true
53 |       options:
54 |         - Windows
55 |         - Linux
56 |         - MacOS
57 |         - iOS
58 |         - Android
59 |         - Other/Cloud
60 |   - type: dropdown
61 |     id: browsers
62 |     attributes:
63 |       label: What browsers do you use to access the UI?
64 |       multiple: true
65 |       options:
66 |         - Mozilla Firefox
67 |         - Google Chrome
68 |         - Brave
69 |         - Apple Safari
70 |         - Microsoft Edge
71 |   - type: textarea
72 |     id: cmdargs
73 |     attributes:
74 |       label: Command Line Arguments
75 |       description: Are you using any launch parameters/command line arguments (modified webui-user.bat or webui-user.sh)? If yes, please write them below
76 |       render: Shell
77 |   - type: textarea
78 |     id: misc
79 |     attributes:
80 |       label: Additional information, context and logs
81 |       description: Please provide us with any relevant additional info, context or log output.
82 | -------------------------------------------------------------------------------- /modules/bsrgan_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | import traceback 4 | 5 | import PIL.Image 6 | import numpy as np 7 | import torch 8 | from basicsr.utils.download_util import load_file_from_url 9 | 10 | import modules.upscaler 11 | from modules import devices, modelloader 12 | from modules.bsrgan_model_arch import RRDBNet 13 | 14 | 15 | class UpscalerBSRGAN(modules.upscaler.Upscaler): 16 | def __init__(self, dirname): 17 | self.name = "BSRGAN" 18 | self.model_name = "BSRGAN 4x" 19 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" 20 | self.user_path = dirname 21 | super().__init__() 22 | model_paths = self.find_models(ext_filter=[".pt", ".pth"]) 23 | scalers = [] 24 | if len(model_paths) == 0: 25 | scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) 26 | scalers.append(scaler_data) 27 | for file in model_paths: 28 | if "http" in file: 29 | name = self.model_name 30 | else: 31 | name = modelloader.friendly_name(file) 32 | try: 33 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) 34 | scalers.append(scaler_data) 35 | except Exception: 36 | print(f"Error loading BSRGAN model: {file}", file=sys.stderr) 37 | print(traceback.format_exc(), file=sys.stderr) 38 | self.scalers = scalers 39 | 40 | def do_upscale(self, img: PIL.Image, selected_file): 41 | torch.cuda.empty_cache() 42 | model = self.load_model(selected_file) 43 | if model is None: 44 | return img 45 | model.to(devices.device_bsrgan) 46 | torch.cuda.empty_cache() 47 | img = np.array(img) 48 | img = img[:, :, ::-1] 49 | img = np.moveaxis(img, 2, 0) / 255 50 | img = torch.from_numpy(img).float() 51 | img = img.unsqueeze(0).to(devices.device_bsrgan) 52 | with torch.no_grad(): 53 | output = model(img) 54 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy() 55 | output = 255. * np.moveaxis(output, 0, 2) 56 | output = output.astype(np.uint8) 57 | output = output[:, :, ::-1] 58 | torch.cuda.empty_cache() 59 | return PIL.Image.fromarray(output, 'RGB') 60 | 61 | def load_model(self, path: str): 62 | if "http" in path: 63 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, 64 | progress=True) 65 | else: 66 | filename = path 67 | if not os.path.exists(filename) or filename is None: 68 | print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr) 69 | return None 70 | model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network 71 | model.load_state_dict(torch.load(filename), strict=True) 72 | model.eval() 73 | for k, v in model.named_parameters(): 74 | v.requires_grad = False 75 | return model 76 | 77 | -------------------------------------------------------------------------------- /javascript/dragdrop.js: -------------------------------------------------------------------------------- 1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard 2 | 3 | function isValidImageList( files ) { 4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); 5 | } 6 | 7 | function dropReplaceImage( imgWrap, files ) { 8 | if ( ! 
isValidImageList( files ) ) { 9 | return; 10 | } 11 | 12 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); 13 | const callback = () => { 14 | const fileInput = imgWrap.querySelector('input[type="file"]'); 15 | if ( fileInput ) { 16 | fileInput.files = files; 17 | fileInput.dispatchEvent(new Event('change')); 18 | } 19 | }; 20 | 21 | if ( imgWrap.closest('#pnginfo_image') ) { 22 | // special treatment for PNG Info tab, wait for fetch request to finish 23 | const oldFetch = window.fetch; 24 | window.fetch = async (input, options) => { 25 | const response = await oldFetch(input, options); 26 | if ( 'api/predict/' === input ) { 27 | const content = await response.text(); 28 | window.fetch = oldFetch; 29 | window.requestAnimationFrame( () => callback() ); 30 | return new Response(content, { 31 | status: response.status, 32 | statusText: response.statusText, 33 | headers: response.headers 34 | }) 35 | } 36 | return response; 37 | }; 38 | } else { 39 | window.requestAnimationFrame( () => callback() ); 40 | } 41 | } 42 | 43 | window.document.addEventListener('dragover', e => { 44 | const target = e.composedPath()[0]; 45 | const imgWrap = target.closest('[data-testid="image"]'); 46 | if ( !imgWrap && (!target.placeholder || target.placeholder.indexOf("Prompt") == -1)) { // guard: target may have no placeholder (e.g. not a textarea) 47 | return; 48 | } 49 | e.stopPropagation(); 50 | e.preventDefault(); 51 | e.dataTransfer.dropEffect = 'copy'; 52 | }); 53 | 54 | window.document.addEventListener('drop', e => { 55 | const target = e.composedPath()[0]; 56 | if (!target.placeholder || target.placeholder.indexOf("Prompt") == -1) { 57 | return; 58 | } 59 | const imgWrap = target.closest('[data-testid="image"]'); 60 | if ( !imgWrap ) { 61 | return; 62 | } 63 | e.stopPropagation(); 64 | e.preventDefault(); 65 | const files = e.dataTransfer.files; 66 | dropReplaceImage( imgWrap, files ); 67 | }); 68 | 69 | window.addEventListener('paste', e => { 70 | const files = e.clipboardData.files; 71 | if ( ! isValidImageList( files ) ) { 72 | return; 73 | } 74 | 75 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] 76 | .filter(el => uiElementIsVisible(el)); 77 | if ( ! visibleImageFields.length ) { 78 | return; 79 | } 80 | 81 | const firstFreeImageField = visibleImageFields 82 | .filter(el => el.querySelector('input[type=file]'))?.[0]; 83 | 84 | dropReplaceImage( 85 | firstFreeImageField ? 86 | firstFreeImageField : 87 | visibleImageFields[visibleImageFields.length - 1] 88 | , files ); 89 | }); 90 | -------------------------------------------------------------------------------- /modules/esrgan_model_arch.py: -------------------------------------------------------------------------------- 1 | # this file is taken from https://github.com/xinntao/ESRGAN 2 | 3 | import functools 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def make_layer(block, n_layers): 10 | layers = [] 11 | for _ in range(n_layers): 12 | layers.append(block()) 13 | return nn.Sequential(*layers) 14 | 15 | 16 | class ResidualDenseBlock_5C(nn.Module): 17 | def __init__(self, nf=64, gc=32, bias=True): 18 | super(ResidualDenseBlock_5C, self).__init__() 19 | # gc: growth channel, i.e.
intermediate channels 20 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 21 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 22 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 23 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 24 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 25 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 26 | 27 | # initialization 28 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 29 | 30 | def forward(self, x): 31 | x1 = self.lrelu(self.conv1(x)) 32 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 33 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 34 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 35 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 36 | return x5 * 0.2 + x 37 | 38 | 39 | class RRDB(nn.Module): 40 | '''Residual in Residual Dense Block''' 41 | 42 | def __init__(self, nf, gc=32): 43 | super(RRDB, self).__init__() 44 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 45 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 46 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 47 | 48 | def forward(self, x): 49 | out = self.RDB1(x) 50 | out = self.RDB2(out) 51 | out = self.RDB3(out) 52 | return out * 0.2 + x 53 | 54 | 55 | class RRDBNet(nn.Module): 56 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 57 | super(RRDBNet, self).__init__() 58 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 59 | 60 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 61 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 62 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | #### upsampling 64 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 66 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 67 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 68 | 69 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 70 | 71 | def forward(self, x): 72 | fea = self.conv_first(x) 73 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 74 | fea = fea + trunk 75 | 76 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 77 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 78 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 79 | 80 | return out 81 | -------------------------------------------------------------------------------- /scripts/loopback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | import modules.scripts as scripts 5 | import gradio as gr 6 | 7 | from modules import processing, shared, sd_samplers, images 8 | from modules.processing import Processed 9 | from modules.sd_samplers import samplers 10 | from modules.shared import opts, cmd_opts, state 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "Loopback" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4) 21 | denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1) 22 | 23 | return [loops, denoising_strength_change_factor] 24 | 25 | def run(self, p, loops, denoising_strength_change_factor): 26 | processing.fix_seed(p) 27 | batch_count = p.n_iter 28 | p.extra_generation_params = { 29 | "Denoising strength change 
factor": denoising_strength_change_factor, 30 | } 31 | 32 | p.batch_size = 1 33 | p.n_iter = 1 34 | 35 | output_images, info = None, None 36 | initial_seed = None 37 | initial_info = None 38 | 39 | grids = [] 40 | all_images = [] 41 | original_init_image = p.init_images 42 | state.job_count = loops * batch_count 43 | 44 | initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] 45 | 46 | for n in range(batch_count): 47 | history = [] 48 | 49 | # Reset to original init image at the start of each batch 50 | p.init_images = original_init_image 51 | 52 | for i in range(loops): 53 | p.n_iter = 1 54 | p.batch_size = 1 55 | p.do_not_save_grid = True 56 | 57 | if opts.img2img_color_correction: 58 | p.color_corrections = initial_color_corrections 59 | 60 | state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" 61 | 62 | processed = processing.process_images(p) 63 | 64 | if initial_seed is None: 65 | initial_seed = processed.seed 66 | initial_info = processed.info 67 | 68 | init_img = processed.images[0] 69 | 70 | p.init_images = [init_img] 71 | p.seed = processed.seed + 1 72 | p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) 73 | history.append(processed.images[0]) 74 | 75 | grid = images.image_grid(history, rows=1) 76 | if opts.grid_save: 77 | images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) 78 | 79 | grids.append(grid) 80 | all_images += history 81 | 82 | if opts.return_grid: 83 | all_images = grids + all_images 84 | 85 | processed = Processed(p, all_images, initial_seed, initial_info) 86 | 87 | return processed 88 | -------------------------------------------------------------------------------- /scripts/prompt_matrix.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import namedtuple 3 | from copy import copy 4 | import random 5 | 6 | import modules.scripts as scripts 7 | import gradio as gr 8 | 9 | from modules import images 10 | from modules.processing import process_images, Processed 11 | from modules.shared import opts, cmd_opts, state 12 | import modules.sd_samplers 13 | 14 | 15 | def draw_xy_grid(xs, ys, x_label, y_label, cell): 16 | res = [] 17 | 18 | ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] 19 | hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] 20 | 21 | first_pocessed = None 22 | 23 | state.job_count = len(xs) * len(ys) 24 | 25 | for iy, y in enumerate(ys): 26 | for ix, x in enumerate(xs): 27 | state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" 28 | 29 | processed = cell(x, y) 30 | if first_pocessed is None: 31 | first_pocessed = processed 32 | 33 | res.append(processed.images[0]) 34 | 35 | grid = images.image_grid(res, rows=len(ys)) 36 | grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) 37 | 38 | first_pocessed.images = [grid] 39 | 40 | return first_pocessed 41 | 42 | 43 | class Script(scripts.Script): 44 | def title(self): 45 | return "Prompt matrix" 46 | 47 | def ui(self, is_img2img): 48 | put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False) 49 | 50 | return [put_at_start] 51 | 52 | def run(self, p, put_at_start): 53 | modules.processing.fix_seed(p) 54 | 55 | original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt 56 | 57 | all_prompts = [] 58 | prompt_matrix_parts = 
original_prompt.split("|") 59 | combination_count = 2 ** (len(prompt_matrix_parts) - 1) 60 | for combination_num in range(combination_count): 61 | selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] 62 | 63 | if put_at_start: 64 | selected_prompts = selected_prompts + [prompt_matrix_parts[0]] 65 | else: 66 | selected_prompts = [prompt_matrix_parts[0]] + selected_prompts 67 | 68 | all_prompts.append(", ".join(selected_prompts)) 69 | 70 | p.n_iter = math.ceil(len(all_prompts) / p.batch_size) 71 | p.do_not_save_grid = True 72 | 73 | print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") 74 | 75 | p.prompt = all_prompts 76 | p.seed = [p.seed for _ in all_prompts] 77 | p.prompt_for_display = original_prompt 78 | processed = process_images(p) 79 | 80 | grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) 81 | grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) 82 | processed.images.insert(0, grid) 83 | 84 | if opts.grid_save: 85 | images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p) 86 | 87 | return processed 88 | -------------------------------------------------------------------------------- /javascript/edit-attention.js: -------------------------------------------------------------------------------- 1 | addEventListener('keydown', (event) => { 2 | let target = event.originalTarget || event.composedPath()[0]; 3 | if (!target.hasAttribute("placeholder")) return; 4 | if (!target.placeholder.toLowerCase().includes("prompt")) return; 5 | if (! (event.metaKey || event.ctrlKey)) return; 6 | 7 | 8 | let plus = "ArrowUp" 9 | let minus = "ArrowDown" 10 | if (event.key != plus && event.key != minus) return; 11 | 12 | let selectionStart = target.selectionStart; 13 | let selectionEnd = target.selectionEnd; 14 | // If the user hasn't selected anything, let's select their current parenthesis block 15 | if (selectionStart === selectionEnd) { 16 | // Find opening parenthesis around current cursor 17 | const before = target.value.substring(0, selectionStart); 18 | let beforeParen = before.lastIndexOf("("); 19 | if (beforeParen == -1) return; 20 | let beforeParenClose = before.lastIndexOf(")"); 21 | while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { 22 | beforeParen = before.lastIndexOf("(", beforeParen - 1); 23 | beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1); 24 | } 25 | 26 | // Find closing parenthesis around current cursor 27 | const after = target.value.substring(selectionStart); 28 | let afterParen = after.indexOf(")"); 29 | if (afterParen == -1) return; 30 | let afterParenOpen = after.indexOf("("); 31 | while (afterParenOpen !== -1 && afterParen > afterParenOpen) { 32 | afterParen = after.indexOf(")", afterParen + 1); 33 | afterParenOpen = after.indexOf("(", afterParenOpen + 1); 34 | } 35 | if (beforeParen === -1 || afterParen === -1) return; 36 | 37 | // Set the selection to the text between the parenthesis 38 | const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen); 39 | const lastColon = parenContent.lastIndexOf(":"); 40 | selectionStart = beforeParen + 1; 41 | selectionEnd = selectionStart + lastColon; 42 | target.setSelectionRange(selectionStart, selectionEnd); 43 | } 44 | 45 | event.preventDefault(); 46 | 47 | if (selectionStart == 0 || 
target.value[selectionStart - 1] != "(") { 48 | target.value = target.value.slice(0, selectionStart) + 49 | "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" + 50 | target.value.slice(selectionEnd); 51 | 52 | target.focus(); 53 | target.selectionStart = selectionStart + 1; 54 | target.selectionEnd = selectionEnd + 1; 55 | 56 | } else { 57 | end = target.value.slice(selectionEnd + 1).indexOf(")") + 1; 58 | weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end)); 59 | if (isNaN(weight)) return; 60 | if (event.key == minus) weight -= 0.1; 61 | if (event.key == plus) weight += 0.1; 62 | 63 | weight = parseFloat(weight.toPrecision(12)); 64 | 65 | target.value = target.value.slice(0, selectionEnd + 1) + 66 | weight + 67 | target.value.slice(selectionEnd + 1 + end - 1); 68 | 69 | target.focus(); 70 | target.selectionStart = selectionStart; 71 | target.selectionEnd = selectionEnd; 72 | } 73 | // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its 74 | // internal Svelte data binding remains in sync. 75 | target.dispatchEvent(new Event("input", { bubbles: true })); 76 | }); 77 | -------------------------------------------------------------------------------- /modules/scunet_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | import traceback 4 | 5 | import PIL.Image 6 | import numpy as np 7 | import torch 8 | from basicsr.utils.download_util import load_file_from_url 9 | 10 | import modules.upscaler 11 | from modules import devices, modelloader 12 | from modules.scunet_model_arch import SCUNet as net 13 | 14 | 15 | class UpscalerScuNET(modules.upscaler.Upscaler): 16 | def __init__(self, dirname): 17 | self.name = "ScuNET" 18 | self.model_name = "ScuNET GAN" 19 | self.model_name2 = "ScuNET PSNR" 20 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" 21 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" 22 | self.user_path = dirname 23 | super().__init__() 24 | model_paths = self.find_models(ext_filter=[".pth"]) 25 | scalers = [] 26 | add_model2 = True 27 | for file in model_paths: 28 | if "http" in file: 29 | name = self.model_name 30 | else: 31 | name = modelloader.friendly_name(file) 32 | if name == self.model_name2 or file == self.model_url2: 33 | add_model2 = False 34 | try: 35 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) 36 | scalers.append(scaler_data) 37 | except Exception: 38 | print(f"Error loading ScuNET model: {file}", file=sys.stderr) 39 | print(traceback.format_exc(), file=sys.stderr) 40 | if add_model2: 41 | scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) 42 | scalers.append(scaler_data2) 43 | self.scalers = scalers 44 | 45 | def do_upscale(self, img: PIL.Image, selected_file): 46 | torch.cuda.empty_cache() 47 | 48 | model = self.load_model(selected_file) 49 | if model is None: 50 | return img 51 | 52 | device = devices.device_scunet 53 | img = np.array(img) 54 | img = img[:, :, ::-1] 55 | img = np.moveaxis(img, 2, 0) / 255 56 | img = torch.from_numpy(img).float() 57 | img = img.unsqueeze(0).to(device) 58 | 59 | img = img.to(device) 60 | with torch.no_grad(): 61 | output = model(img) 62 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy() 63 | output = 255. 
* np.moveaxis(output, 0, 2) 64 | output = output.astype(np.uint8) 65 | output = output[:, :, ::-1] 66 | torch.cuda.empty_cache() 67 | return PIL.Image.fromarray(output, 'RGB') 68 | 69 | def load_model(self, path: str): 70 | device = devices.device_scunet 71 | if "http" in path: 72 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, 73 | progress=True) 74 | else: 75 | filename = path 76 | if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: 77 | print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) 78 | return None 79 | 80 | model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) 81 | model.load_state_dict(torch.load(filename), strict=True) 82 | model.eval() 83 | for k, v in model.named_parameters(): 84 | v.requires_grad = False 85 | model = model.to(device) 86 | 87 | return model 88 | 89 | -------------------------------------------------------------------------------- /modules/generation_parameters_copypaste.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import gradio as gr 4 | from modules.shared import script_path 5 | from modules import shared 6 | 7 | re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" 8 | re_param = re.compile(re_param_code) 9 | re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") 10 | re_imagesize = re.compile(r"^(\d+)x(\d+)$") 11 | type_of_gr_update = type(gr.update()) 12 | 13 | 14 | def parse_generation_parameters(x: str): 15 | """parses generation parameters string, the one you see in text field under the picture in UI: 16 | ``` 17 | girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate 18 | Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing 19 | Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b 20 | ``` 21 | 22 | returns a dict with field values 23 | """ 24 | 25 | res = {} 26 | 27 | prompt = "" 28 | negative_prompt = "" 29 | 30 | done_with_prompt = False 31 | 32 | *lines, lastline = x.strip().split("\n") 33 | if not re_params.match(lastline): 34 | lines.append(lastline) 35 | lastline = '' 36 | 37 | for i, line in enumerate(lines): 38 | line = line.strip() 39 | if line.startswith("Negative prompt:"): 40 | done_with_prompt = True 41 | line = line[16:].strip() 42 | 43 | if done_with_prompt: 44 | negative_prompt += ("" if negative_prompt == "" else "\n") + line 45 | else: 46 | prompt += ("" if prompt == "" else "\n") + line 47 | 48 | res["Prompt"] = prompt 49 | res["Negative prompt"] = negative_prompt 50 | 51 | for k, v in re_param.findall(lastline): 52 | m = re_imagesize.match(v) 53 | if m is not None: 54 | res[k+"-1"] = m.group(1) 55 | res[k+"-2"] = m.group(2) 56 | else: 57 | res[k] = v 58 | 59 | return res 60 | 61 | 62 | def connect_paste(button, paste_fields, input_comp, js=None): 63 | def paste_func(prompt): 64 | if not prompt and not shared.cmd_opts.hide_ui_dir_config: 65 | filename = os.path.join(script_path, "params.txt") 66 | if os.path.exists(filename): 67 | with open(filename, "r", encoding="utf8") as file: 68 | prompt = file.read() 69 | 70 | params = parse_generation_parameters(prompt) 71 | res = [] 72 | 73 | for output, key in paste_fields: 74 | if 
callable(key): 75 | v = key(params) 76 | else: 77 | v = params.get(key, None) 78 | 79 | if v is None: 80 | res.append(gr.update()) 81 | elif isinstance(v, type_of_gr_update): 82 | res.append(v) 83 | else: 84 | try: 85 | valtype = type(output.value) 86 | val = valtype(v) 87 | res.append(gr.update(value=val)) 88 | except Exception: 89 | res.append(gr.update()) 90 | 91 | return res 92 | 93 | button.click( 94 | fn=paste_func, 95 | _js=js, 96 | inputs=[input_comp], 97 | outputs=[x[0] for x in paste_fields], 98 | ) 99 | -------------------------------------------------------------------------------- /modules/api/processing.py: -------------------------------------------------------------------------------- 1 | from inflection import underscore 2 | from typing import Any, Dict, Optional 3 | from pydantic import BaseModel, Field, create_model 4 | from modules.processing import StableDiffusionProcessingTxt2Img 5 | import inspect 6 | 7 | 8 | API_NOT_ALLOWED = [ 9 | "self", 10 | "kwargs", 11 | "sd_model", 12 | "outpath_samples", 13 | "outpath_grids", 14 | "sampler_index", 15 | "do_not_save_samples", 16 | "do_not_save_grid", 17 | "extra_generation_params", 18 | "overlay_images", 19 | "do_not_reload_embeddings", 20 | "seed_enable_extras", 21 | "prompt_for_display", 22 | "sampler_noise_scheduler_override", 23 | "ddim_discretize" 24 | ] 25 | 26 | class ModelDef(BaseModel): 27 | """Assistance Class for Pydantic Dynamic Model Generation""" 28 | 29 | field: str 30 | field_alias: str 31 | field_type: Any 32 | field_value: Any 33 | 34 | 35 | class PydanticModelGenerator: 36 | """ 37 | Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: 38 | source_data is a snapshot of the default values produced by the class 39 | params are the names of the actual keys required by __init__ 40 | """ 41 | 42 | def __init__( 43 | self, 44 | model_name: str = None, 45 | class_instance = None, 46 | additional_fields = None, 47 | ): 48 | def field_type_generator(k, v): 49 | # field_type = str if not overrides.get(k) else overrides[k]["type"] 50 | # print(k, v.annotation, v.default) 51 | field_type = v.annotation 52 | 53 | return Optional[field_type] 54 | 55 | def merge_class_params(class_): 56 | all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) 57 | parameters = {} 58 | for classes in all_classes: 59 | parameters = {**parameters, **inspect.signature(classes.__init__).parameters} 60 | return parameters 61 | 62 | 63 | self._model_name = model_name 64 | self._class_data = merge_class_params(class_instance) 65 | self._model_def = [ 66 | ModelDef( 67 | field=underscore(k), 68 | field_alias=k, 69 | field_type=field_type_generator(k, v), 70 | field_value=v.default 71 | ) 72 | for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED 73 | ] 74 | 75 | for fields in additional_fields: 76 | self._model_def.append(ModelDef( 77 | field=underscore(fields["key"]), 78 | field_alias=fields["key"], 79 | field_type=fields["type"], 80 | field_value=fields["default"])) 81 | 82 | def generate_model(self): 83 | """ 84 | Creates a pydantic BaseModel 85 | from the json and overrides provided at initialization 86 | """ 87 | fields = { 88 | d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def 89 | } 90 | DynamicModel = create_model(self._model_name, **fields) 91 | DynamicModel.__config__.allow_population_by_field_name = True 92 | DynamicModel.__config__.allow_mutation = True 93 | return DynamicModel 94 | 95 | StableDiffusionProcessingAPI = 
PydanticModelGenerator( 96 | "StableDiffusionProcessingTxt2Img", 97 | StableDiffusionProcessingTxt2Img, 98 | [{"key": "sampler_index", "type": str, "default": "Euler"}] 99 | ).generate_model()  # builds the txt2img request model: its fields mirror StableDiffusionProcessingTxt2Img.__init__, plus the extra sampler_index field -------------------------------------------------------------------------------- /modules/masking.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageOps 2 | 3 | 4 | def get_crop_region(mask, pad=0): 5 | """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 6 | For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)""" 7 | 8 | h, w = mask.shape 9 | 10 | crop_left = 0 11 | for i in range(w): 12 | if not (mask[:, i] == 0).all(): 13 | break 14 | crop_left += 1 15 | 16 | crop_right = 0 17 | for i in reversed(range(w)): 18 | if not (mask[:, i] == 0).all(): 19 | break 20 | crop_right += 1 21 | 22 | crop_top = 0 23 | for i in range(h): 24 | if not (mask[i] == 0).all(): 25 | break 26 | crop_top += 1 27 | 28 | crop_bottom = 0 29 | for i in reversed(range(h)): 30 | if not (mask[i] == 0).all(): 31 | break 32 | crop_bottom += 1 33 | 34 | return ( 35 | int(max(crop_left-pad, 0)), 36 | int(max(crop_top-pad, 0)), 37 | int(min(w - crop_right + pad, w)), 38 | int(min(h - crop_bottom + pad, h)) 39 | ) 40 | 41 | 42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): 43 | """expands the crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region. 44 | For example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128.""" 45 | 46 | x1, y1, x2, y2 = crop_region 47 | 48 | ratio_crop_region = (x2 - x1) / (y2 - y1) 49 | ratio_processing = processing_width / processing_height 50 | 51 | if ratio_crop_region > ratio_processing: 52 | desired_height = (x2 - x1) / ratio_processing  # height needed to match the processing aspect ratio 53 | desired_height_diff = int(desired_height - (y2-y1)) 54 | y1 -= desired_height_diff//2 55 | y2 += desired_height_diff - desired_height_diff//2 56 | if y2 >= image_height: 57 | diff = y2 - image_height 58 | y2 -= diff 59 | y1 -= diff 60 | if y1 < 0: 61 | y2 -= y1 62 | y1 = 0 63 | if y2 >= image_height: 64 | y2 = image_height 65 | else: 66 | desired_width = (y2 - y1) * ratio_processing 67 | desired_width_diff = int(desired_width - (x2-x1)) 68 | x1 -= desired_width_diff//2 69 | x2 += desired_width_diff - desired_width_diff//2 70 | if x2 >= image_width: 71 | diff = x2 - image_width 72 | x2 -= diff 73 | x1 -= diff 74 | if x1 < 0: 75 | x2 -= x1 76 | x1 = 0 77 | if x2 >= image_width: 78 | x2 = image_width 79 | 80 | return x1, y1, x2, y2 81 | 82 | 83 | def fill(image, mask): 84 | """fills masked regions with colors from the image using blur.
Not extremely effective.""" 85 | 86 | image_mod = Image.new('RGBA', (image.width, image.height)) 87 | 88 | image_masked = Image.new('RGBa', (image.width, image.height)) 89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L'))) 90 | 91 | image_masked = image_masked.convert('RGBa') 92 | 93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: 94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') 95 | for _ in range(repeats): 96 | image_mod.alpha_composite(blurred) 97 | 98 | return image_mod.convert("RGB") 99 | 100 | -------------------------------------------------------------------------------- /modules/lowvram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.devices import get_optimal_device 3 | 4 | module_in_gpu = None 5 | cpu = torch.device("cpu") 6 | device = gpu = get_optimal_device() 7 | 8 | 9 | def send_everything_to_cpu(): 10 | global module_in_gpu 11 | 12 | if module_in_gpu is not None: 13 | module_in_gpu.to(cpu) 14 | 15 | module_in_gpu = None 16 | 17 | 18 | def setup_for_low_vram(sd_model, use_medvram): 19 | parents = {} 20 | 21 | def send_me_to_gpu(module, _): 22 | """send this module to the GPU, and send whatever tracked module was previously on the GPU back to the CPU; 23 | we add this as a forward_pre_hook to many modules, so that at any moment all but one of them 24 | are on the CPU 25 | """ 26 | global module_in_gpu 27 | 28 | module = parents.get(module, module) 29 | 30 | if module_in_gpu == module: 31 | return 32 | 33 | if module_in_gpu is not None: 34 | module_in_gpu.to(cpu) 35 | 36 | module.to(gpu) 37 | module_in_gpu = module 38 | 39 | # see below for register_forward_pre_hook; 40 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is 41 | # useless here, and we just replace those methods 42 | def first_stage_model_encode_wrap(self, encoder, x): 43 | send_me_to_gpu(self, None) 44 | return encoder(x) 45 | 46 | def first_stage_model_decode_wrap(self, decoder, z): 47 | send_me_to_gpu(self, None) 48 | return decoder(z) 49 | 50 | # remove the three big modules (cond, first_stage, and unet) from the model, then 51 | # send the model to the GPU and put the modules back; the modules end up on the CPU.
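# illustration of the trick (hypothetical submodule name; nn.Module.to() only moves modules that are still attached): #     kept = model.big_submodule; model.big_submodule = None #     model.to(gpu)                 # everything except big_submodule moves to the GPU #     model.big_submodule = kept    # big_submodule is reattached but stays on the CPU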
52 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model 53 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None 54 | sd_model.to(device) 55 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored 56 | 57 | # register hooks for the first two models 58 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) 59 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) 60 | sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) 61 | sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) 62 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model 63 | 64 | if use_medvram: 65 | sd_model.model.register_forward_pre_hook(send_me_to_gpu) 66 | else: 67 | diff_model = sd_model.model.diffusion_model 68 | 69 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules 70 | # so that only one of them is on the GPU at a time 71 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed 72 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None 73 | sd_model.model.to(device) 74 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored 75 | 76 | # install hooks for the submodules of the third model 77 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) 78 | for block in diff_model.input_blocks: 79 | block.register_forward_pre_hook(send_me_to_gpu) 80 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) 81 | for block in diff_model.output_blocks: 82 | block.register_forward_pre_hook(send_me_to_gpu) 83 | -------------------------------------------------------------------------------- /scripts/sd_upscale.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image 6 | 7 | from modules import processing, shared, sd_samplers, images, devices 8 | from modules.processing import Processed 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "SD upscale" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | info = gr.HTML("

<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>
") 21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) 22 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False) 23 | 24 | return [info, overlap, upscaler_index] 25 | 26 | def run(self, p, _, overlap, upscaler_index): 27 | processing.fix_seed(p) 28 | upscaler = shared.sd_upscalers[upscaler_index] 29 | 30 | p.extra_generation_params["SD upscale overlap"] = overlap 31 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name 32 | 33 | initial_info = None 34 | seed = p.seed 35 | 36 | init_img = p.init_images[0] 37 | 38 | if(upscaler.name != "None"): 39 | img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path) 40 | else: 41 | img = init_img 42 | 43 | devices.torch_gc() 44 | 45 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) 46 | 47 | batch_size = p.batch_size 48 | upscale_count = p.n_iter 49 | p.n_iter = 1 50 | p.do_not_save_grid = True 51 | p.do_not_save_samples = True 52 | 53 | work = [] 54 | 55 | for y, h, row in grid.tiles: 56 | for tiledata in row: 57 | work.append(tiledata[2]) 58 | 59 | batch_count = math.ceil(len(work) / batch_size) 60 | state.job_count = batch_count * upscale_count 61 | 62 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") 63 | 64 | result_images = [] 65 | for n in range(upscale_count): 66 | start_seed = seed + n 67 | p.seed = start_seed 68 | 69 | work_results = [] 70 | for i in range(batch_count): 71 | p.batch_size = batch_size 72 | p.init_images = work[i*batch_size:(i+1)*batch_size] 73 | 74 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" 75 | processed = processing.process_images(p) 76 | 77 | if initial_info is None: 78 | initial_info = processed.info 79 | 80 | p.seed = processed.seed + 1 81 | work_results += processed.images 82 | 83 | image_index = 0 84 | for y, h, row in grid.tiles: 85 | for tiledata in row: 86 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 87 | image_index += 1 88 | 89 | combined_image = images.combine_grid(grid) 90 | result_images.append(combined_image) 91 | 92 | if opts.samples_save: 93 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) 94 | 95 | processed = Processed(p, result_images, seed, initial_info) 96 | 97 | return processed 98 | -------------------------------------------------------------------------------- /modules/upscaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import PIL 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | import modules.shared 10 | from modules import modelloader, shared 11 | 12 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 13 | from modules.paths import models_path 14 | 15 | 16 | class Upscaler: 17 | name = None 18 | model_path = None 19 | model_name = None 20 | model_url = None 21 | enable = True 22 | filter = None 23 | model = None 24 | user_path = None 25 | scalers: [] 26 | tile = True 27 | 28 | def __init__(self, create_dirs=False): 29 | self.mod_pad_h = None 30 | self.tile_size = modules.shared.opts.ESRGAN_tile 31 | self.tile_pad = 
modules.shared.opts.ESRGAN_tile_overlap 32 | self.device = modules.shared.device 33 | self.img = None 34 | self.output = None 35 | self.scale = 1 36 | self.half = not modules.shared.cmd_opts.no_half 37 | self.pre_pad = 0 38 | self.mod_scale = None 39 | 40 | if self.model_path is None and self.name: 41 | self.model_path = os.path.join(models_path, self.name) 42 | if self.model_path and create_dirs: 43 | os.makedirs(self.model_path, exist_ok=True) 44 | 45 | try: 46 | import cv2 47 | self.can_tile = True 48 | except: 49 | pass 50 | 51 | @abstractmethod 52 | def do_upscale(self, img: PIL.Image, selected_model: str): 53 | return img 54 | 55 | def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): 56 | self.scale = scale 57 | dest_w = img.width * scale 58 | dest_h = img.height * scale 59 | for i in range(3): 60 | if img.width >= dest_w and img.height >= dest_h: 61 | break 62 | img = self.do_upscale(img, selected_model) 63 | if img.width != dest_w or img.height != dest_h: 64 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 65 | 66 | return img 67 | 68 | @abstractmethod 69 | def load_model(self, path: str): 70 | pass 71 | 72 | def find_models(self, ext_filter=None) -> list: 73 | return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) 74 | 75 | def update_status(self, prompt): 76 | print(f"\nextras: {prompt}", file=shared.progress_print_out) 77 | 78 | 79 | class UpscalerData: 80 | name = None 81 | data_path = None 82 | scale: int = 4 83 | scaler: Upscaler = None 84 | model: None 85 | 86 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): 87 | self.name = name 88 | self.data_path = path 89 | self.scaler = upscaler 90 | self.scale = scale 91 | self.model = model 92 | 93 | 94 | class UpscalerNone(Upscaler): 95 | name = "None" 96 | scalers = [] 97 | 98 | def load_model(self, path): 99 | pass 100 | 101 | def do_upscale(self, img, selected_model=None): 102 | return img 103 | 104 | def __init__(self, dirname=None): 105 | super().__init__(False) 106 | self.scalers = [UpscalerData("None", None, self)] 107 | 108 | 109 | class UpscalerLanczos(Upscaler): 110 | scalers = [] 111 | 112 | def do_upscale(self, img, selected_model=None): 113 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 114 | 115 | def load_model(self, _): 116 | pass 117 | 118 | def __init__(self, dirname=None): 119 | super().__init__(False) 120 | self.name = "Lanczos" 121 | self.scalers = [UpscalerData("Lanczos", None, self)] 122 | 123 | -------------------------------------------------------------------------------- /modules/styles.py: -------------------------------------------------------------------------------- 1 | # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime 2 | from __future__ import annotations 3 | 4 | import csv 5 | import os 6 | import os.path 7 | import typing 8 | import collections.abc as abc 9 | import tempfile 10 | import shutil 11 | 12 | if typing.TYPE_CHECKING: 13 | # Only import this when code is being type-checked, it doesn't have any effect at runtime 14 | from .processing import StableDiffusionProcessing 15 | 16 | 17 | class PromptStyle(typing.NamedTuple): 18 | name: str 19 | prompt: str 20 | negative_prompt: str 21 | 22 | 23 | def merge_prompts(style_prompt: str, prompt: str) -> str: 24 | if "{prompt}" in style_prompt: 25 | res = style_prompt.replace("{prompt}", prompt) 26 | 
else: 27 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 28 | res = ", ".join(parts) 29 | 30 | return res 31 | 32 | 33 | def apply_styles_to_prompt(prompt, styles): 34 | for style in styles: 35 | prompt = merge_prompts(style, prompt) 36 | 37 | return prompt 38 | 39 | 40 | class StyleDatabase: 41 | def __init__(self, path: str): 42 | self.no_style = PromptStyle("None", "", "") 43 | self.styles = {"None": self.no_style} 44 | 45 | if not os.path.exists(path): 46 | return 47 | 48 | with open(path, "r", encoding="utf-8-sig", newline='') as file: 49 | reader = csv.DictReader(file) 50 | for row in reader: 51 | # Support loading old CSV format with "name, text"-columns 52 | prompt = row["prompt"] if "prompt" in row else row["text"] 53 | negative_prompt = row.get("negative_prompt", "") 54 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) 55 | 56 | def get_style_prompts(self, styles): 57 | return [self.styles.get(x, self.no_style).prompt for x in styles] 58 | 59 | def get_negative_style_prompts(self, styles): 60 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 61 | 62 | def apply_styles_to_prompt(self, prompt, styles): 63 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) 64 | 65 | def apply_negative_styles_to_prompt(self, prompt, styles): 66 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) 67 | 68 | def apply_styles(self, p: StableDiffusionProcessing) -> None: 69 | if isinstance(p.prompt, list): 70 | p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt] 71 | else: 72 | p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles) 73 | 74 | if isinstance(p.negative_prompt, list): 75 | p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt] 76 | else: 77 | p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) 78 | 79 | def save_styles(self, path: str) -> None: 80 | # Write to temporary file first, so we don't nuke the file if something goes wrong 81 | fd, temp_path = tempfile.mkstemp(".csv") 82 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: 83 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, 84 | # and collections.NamedTuple has explicit documentation for accessing _fields. 
Same goes for _asdict() 85 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) 86 | writer.writeheader() 87 | writer.writerows(style._asdict() for k, style in self.styles.items()) 88 | 89 | # Always keep a backup file around 90 | if os.path.exists(path): 91 | shutil.move(path, path + ".bak") 92 | shutil.move(temp_path, path) 93 | -------------------------------------------------------------------------------- /modules/bsrgan_model_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | 8 | def initialize_weights(net_l, scale=1): 9 | if not isinstance(net_l, list): 10 | net_l = [net_l] 11 | for net in net_l: 12 | for m in net.modules(): 13 | if isinstance(m, nn.Conv2d): 14 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 15 | m.weight.data *= scale # for residual block 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | m.weight.data *= scale 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif isinstance(m, nn.BatchNorm2d): 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def make_layer(block, n_layers): 29 | layers = [] 30 | for _ in range(n_layers): 31 | layers.append(block()) 32 | return nn.Sequential(*layers) 33 | 34 | 35 | class ResidualDenseBlock_5C(nn.Module): 36 | def __init__(self, nf=64, gc=32, bias=True): 37 | super(ResidualDenseBlock_5C, self).__init__() 38 | # gc: growth channel, i.e. intermediate channels 39 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 40 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 41 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 42 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 43 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 44 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 45 | 46 | # initialization 47 | initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 48 | 49 | def forward(self, x): 50 | x1 = self.lrelu(self.conv1(x)) 51 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 52 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 53 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 54 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 55 | return x5 * 0.2 + x 56 | 57 | 58 | class RRDB(nn.Module): 59 | '''Residual in Residual Dense Block''' 60 | 61 | def __init__(self, nf, gc=32): 62 | super(RRDB, self).__init__() 63 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 64 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 65 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 66 | 67 | def forward(self, x): 68 | out = self.RDB1(x) 69 | out = self.RDB2(out) 70 | out = self.RDB3(out) 71 | return out * 0.2 + x 72 | 73 | 74 | class RRDBNet(nn.Module): 75 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): 76 | super(RRDBNet, self).__init__() 77 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 78 | self.sf = sf 79 | 80 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 81 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 82 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 83 | #### upsampling 84 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 85 | if self.sf==4: 86 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 87 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 
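# note: sf is the upscaling factor; forward() below runs the second 2x nearest-neighbor upsample stage (upconv2) only when sf == 4, so sf == 2 checkpoints upscale once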
88 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 89 | 90 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 91 | 92 | def forward(self, x): 93 | fea = self.conv_first(x) 94 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 95 | fea = fea + trunk 96 | 97 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 98 | if self.sf==4: 99 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 100 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 101 | 102 | return out -------------------------------------------------------------------------------- /javascript/aspectRatioOverlay.js: -------------------------------------------------------------------------------- 1 | 2 | let currentWidth = null; 3 | let currentHeight = null; 4 | let arFrameTimeout = setTimeout(function(){},0); 5 | 6 | function dimensionChange(e,dimname){ 7 | 8 | if(dimname == 'Width'){ 9 | currentWidth = e.target.value*1.0 10 | } 11 | if(dimname == 'Height'){ 12 | currentHeight = e.target.value*1.0 13 | } 14 | 15 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 16 | 17 | if(!inImg2img){ 18 | return; 19 | } 20 | 21 | var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200') 22 | if(img2imgMode){ 23 | img2imgMode=img2imgMode.innerText 24 | }else{ 25 | return; 26 | } 27 | 28 | var redrawImage = gradioApp().querySelector('div[data-testid=image] img'); 29 | var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img') 30 | 31 | var targetElement = null; 32 | 33 | if(img2imgMode=='img2img' && redrawImage){ 34 | targetElement = redrawImage; 35 | }else if(img2imgMode=='Inpaint' && inpaintImage){ 36 | targetElement = inpaintImage; 37 | } 38 | 39 | if(targetElement){ 40 | 41 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 42 | if(!arPreviewRect){ 43 | arPreviewRect = document.createElement('div') 44 | arPreviewRect.id = "imageARPreview"; 45 | gradioApp().getRootNode().appendChild(arPreviewRect) 46 | } 47 | 48 | 49 | 50 | var viewportOffset = targetElement.getBoundingClientRect(); 51 | 52 | viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) 53 | 54 | scaledx = targetElement.naturalWidth*viewportscale 55 | scaledy = targetElement.naturalHeight*viewportscale 56 | 57 | cleintRectTop = (viewportOffset.top+window.scrollY) 58 | cleintRectLeft = (viewportOffset.left+window.scrollX) 59 | cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) 60 | cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) 61 | 62 | viewRectTop = cleintRectCentreY-(scaledy/2) 63 | viewRectLeft = cleintRectCentreX-(scaledx/2) 64 | arRectWidth = scaledx 65 | arRectHeight = scaledy 66 | 67 | arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) 68 | arscaledx = currentWidth*arscale 69 | arscaledy = currentHeight*arscale 70 | 71 | arRectTop = cleintRectCentreY-(arscaledy/2) 72 | arRectLeft = cleintRectCentreX-(arscaledx/2) 73 | arRectWidth = arscaledx 74 | arRectHeight = arscaledy 75 | 76 | arPreviewRect.style.top = arRectTop+'px'; 77 | arPreviewRect.style.left = arRectLeft+'px'; 78 | arPreviewRect.style.width = arRectWidth+'px'; 79 | arPreviewRect.style.height = arRectHeight+'px'; 80 | 81 | clearTimeout(arFrameTimeout); 82 | arFrameTimeout = setTimeout(function(){ 83 | arPreviewRect.style.display = 'none'; 84 | },2000); 85 | 
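// show the preview immediately; the timer armed just above auto-hides it 2 seconds after the last width/height change (each change cancels and re-arms the timer)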
86 | arPreviewRect.style.display = 'block'; 87 | 88 | } 89 | 90 | } 91 | 92 | 93 | onUiUpdate(function(){ 94 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 95 | if(arPreviewRect){ 96 | arPreviewRect.style.display = 'none'; 97 | } 98 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 99 | if(inImg2img){ 100 | let inputs = gradioApp().querySelectorAll('input'); 101 | inputs.forEach(function(e){ 102 | let parentLabel = e.parentElement.querySelector('label') 103 | if(parentLabel && parentLabel.innerText){ 104 | if(!e.classList.contains('scrollwatch')){ 105 | if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){ 106 | e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} ) 107 | e.classList.add('scrollwatch') 108 | } 109 | if(parentLabel.innerText == 'Width'){ 110 | currentWidth = e.value*1.0 111 | } 112 | if(parentLabel.innerText == 'Height'){ 113 | currentHeight = e.value*1.0 114 | } 115 | } 116 | } 117 | }) 118 | } 119 | }); 120 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import facexlib 6 | import gfpgan 7 | 8 | import modules.face_restoration 9 | from modules import shared, devices, modelloader 10 | from modules.paths import models_path 11 | 12 | model_dir = "GFPGAN" 13 | user_path = None 14 | model_path = os.path.join(models_path, model_dir) 15 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 16 | have_gfpgan = False 17 | loaded_gfpgan_model = None 18 | 19 | 20 | def gfpgann(): 21 | global loaded_gfpgan_model 22 | global model_path 23 | if loaded_gfpgan_model is not None: 24 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) 25 | return loaded_gfpgan_model 26 | 27 | if gfpgan_constructor is None: 28 | return None 29 | 30 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") 31 | if len(models) == 1 and "http" in models[0]: 32 | model_file = models[0] 33 | elif len(models) != 0: 34 | latest_file = max(models, key=os.path.getctime) 35 | model_file = latest_file 36 | else: 37 | print("Unable to load gfpgan model!") 38 | return None 39 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) 40 | loaded_gfpgan_model = model 41 | 42 | return model 43 | 44 | 45 | def send_model_to(model, device): 46 | model.gfpgan.to(device) 47 | model.face_helper.face_det.to(device) 48 | model.face_helper.face_parse.to(device) 49 | 50 | 51 | def gfpgan_fix_faces(np_image): 52 | model = gfpgann() 53 | if model is None: 54 | return np_image 55 | 56 | send_model_to(model, devices.device_gfpgan) 57 | 58 | np_image_bgr = np_image[:, :, ::-1] 59 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) 60 | np_image = gfpgan_output_bgr[:, :, ::-1] 61 | 62 | model.face_helper.clean_all() 63 | 64 | if shared.opts.face_restoration_unload: 65 | send_model_to(model, devices.cpu) 66 | 67 | return np_image 68 | 69 | 70 | gfpgan_constructor = None 71 | 72 | 73 | def setup_model(dirname): 74 | global model_path 75 | if not os.path.exists(model_path): 76 | os.makedirs(model_path) 77 | 78 | try: 79 | from gfpgan import GFPGANer 80 | from facexlib import detection, parsing 81 | global user_path 
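# the module-level globals declared here are populated below only if the gfpgan/facexlib imports above succeed; gfpgann() checks gfpgan_constructor before constructing a restorer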
82 | global have_gfpgan 83 | global gfpgan_constructor 84 | 85 | load_file_from_url_orig = gfpgan.utils.load_file_from_url 86 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url 87 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url 88 | 89 | def my_load_file_from_url(**kwargs): 90 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) 91 | 92 | def facex_load_file_from_url(**kwargs): 93 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) 94 | 95 | def facex_load_file_from_url2(**kwargs): 96 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) 97 | 98 | gfpgan.utils.load_file_from_url = my_load_file_from_url 99 | facexlib.detection.load_file_from_url = facex_load_file_from_url 100 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 101 | user_path = dirname 102 | have_gfpgan = True 103 | gfpgan_constructor = GFPGANer 104 | 105 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): 106 | def name(self): 107 | return "GFPGAN" 108 | 109 | def restore(self, np_image): 110 | return gfpgan_fix_faces(np_image) 111 | 112 | shared.face_restorers.append(FaceRestorerGFPGAN()) 113 | except Exception: 114 | print("Error setting up GFPGAN:", file=sys.stderr) 115 | print(traceback.format_exc(), file=sys.stderr) 116 | -------------------------------------------------------------------------------- /javascript/localization.js: -------------------------------------------------------------------------------- 1 | 2 | // localization = {} -- the dict with translations is created by the backend 3 | 4 | ignore_ids_for_localization={ 5 | setting_sd_hypernetwork: 'OPTION', 6 | setting_sd_model_checkpoint: 'OPTION', 7 | setting_realesrgan_enabled_models: 'OPTION', 8 | modelmerger_primary_model_name: 'OPTION', 9 | modelmerger_secondary_model_name: 'OPTION', 10 | modelmerger_tertiary_model_name: 'OPTION', 11 | train_embedding: 'OPTION', 12 | train_hypernetwork: 'OPTION', 13 | txt2img_style_index: 'OPTION', 14 | txt2img_style2_index: 'OPTION', 15 | img2img_style_index: 'OPTION', 16 | img2img_style2_index: 'OPTION', 17 | setting_random_artist_categories: 'SPAN', 18 | setting_face_restoration_model: 'SPAN', 19 | setting_realesrgan_enabled_models: 'SPAN', 20 | extras_upscaler_1: 'SPAN', 21 | extras_upscaler_2: 'SPAN', 22 | } 23 | 24 | re_num = /^[\.\d]+$/ 25 | re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u 26 | 27 | original_lines = {} 28 | translated_lines = {} 29 | 30 | function textNodesUnder(el){ 31 | var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); 32 | while(n=walk.nextNode()) a.push(n); 33 | return a; 34 | } 35 | 36 | function canBeTranslated(node, text){ 37 | if(! text) return false; 38 | if(! node.parentElement) return false; 39 | 40 | parentType = node.parentElement.nodeName 41 | if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; 42 | 43 | if (parentType=='OPTION' || parentType=='SPAN'){ 44 | pnode = node 45 | for(var level=0; level<4; level++){ 46 | pnode = pnode.parentElement 47 | if(! pnode) break; 48 | 49 | if(ignore_ids_for_localization[pnode.id] == parentType) return false; 50 | } 51 | } 52 | 53 | if(re_num.test(text)) return false; 54 | if(re_emoji.test(text)) return false; 55 | return true 56 | } 57 | 58 | function getTranslation(text){ 59 | if(! 
text) return undefined 60 | 61 | if(translated_lines[text] === undefined){ 62 | original_lines[text] = 1 63 | } 64 | 65 | tl = localization[text] 66 | if(tl !== undefined){ 67 | translated_lines[tl] = 1 68 | } 69 | 70 | return tl 71 | } 72 | 73 | function processTextNode(node){ 74 | text = node.textContent.trim() 75 | 76 | if(! canBeTranslated(node, text)) return 77 | 78 | tl = getTranslation(text) 79 | if(tl !== undefined){ 80 | node.textContent = tl 81 | } 82 | } 83 | 84 | function processNode(node){ 85 | if(node.nodeType == 3){ 86 | processTextNode(node) 87 | return 88 | } 89 | 90 | if(node.title){ 91 | tl = getTranslation(node.title) 92 | if(tl !== undefined){ 93 | node.title = tl 94 | } 95 | } 96 | 97 | if(node.placeholder){ 98 | tl = getTranslation(node.placeholder) 99 | if(tl !== undefined){ 100 | node.placeholder = tl 101 | } 102 | } 103 | 104 | textNodesUnder(node).forEach(function(node){ 105 | processTextNode(node) 106 | }) 107 | } 108 | 109 | function dumpTranslations(){ 110 | dumped = {} 111 | 112 | Object.keys(original_lines).forEach(function(text){ 113 | if(dumped[text] !== undefined) return 114 | 115 | dumped[text] = localization[text] || text 116 | }) 117 | 118 | return dumped 119 | } 120 | 121 | onUiUpdate(function(m){ 122 | m.forEach(function(mutation){ 123 | mutation.addedNodes.forEach(function(node){ 124 | processNode(node) 125 | }) 126 | }); 127 | }) 128 | 129 | 130 | document.addEventListener("DOMContentLoaded", function() { 131 | processNode(gradioApp()) 132 | }) 133 | 134 | function download_localization() { 135 | text = JSON.stringify(dumpTranslations(), null, 4) 136 | 137 | var element = document.createElement('a'); 138 | element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); 139 | element.setAttribute('download', "localization.json"); 140 | element.style.display = 'none'; 141 | document.body.appendChild(element); 142 | 143 | element.click(); 144 | 145 | document.body.removeChild(element); 146 | } 147 | -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ################################################# 3 | # Please do not make any changes to this file, # 4 | # change the variables in webui-user.sh instead # 5 | ################################################# 6 | # Read variables from webui-user.sh 7 | # shellcheck source=/dev/null 8 | if [[ -f webui-user.sh ]] 9 | then 10 | source ./webui-user.sh 11 | fi 12 | 13 | # Set defaults 14 | # Install directory without trailing slash 15 | if [[ -z "${install_dir}" ]] 16 | then 17 | install_dir="/home/$(whoami)" 18 | fi 19 | 20 | # Name of the subdirectory (defaults to stable-diffusion-webui) 21 | if [[ -z "${clone_dir}" ]] 22 | then 23 | clone_dir="stable-diffusion-webui" 24 | fi 25 | 26 | # python3 executable 27 | if [[ -z "${python_cmd}" ]] 28 | then 29 | python_cmd="python3" 30 | fi 31 | 32 | # git executable 33 | if [[ -z "${GIT}" ]] 34 | then 35 | export GIT="git" 36 | fi 37 | 38 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 39 | if [[ -z "${venv_dir}" ]] 40 | then 41 | venv_dir="venv" 42 | fi 43 | 44 | if [[ -z "${LAUNCH_SCRIPT}" ]] 45 | then 46 | LAUNCH_SCRIPT="launch.py" 47 | fi 48 | 49 | # Disable sentry logging 50 | export ERROR_REPORTING=FALSE 51 | 52 | # Do not reinstall existing pip packages on Debian/Ubuntu 53 | export PIP_IGNORE_INSTALLED=0 54 | 55 | # Pretty print 56 | 
delimiter="################################################################" 57 | 58 | printf "\n%s\n" "${delimiter}" 59 | printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" 60 | printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" 61 | printf "\n%s\n" "${delimiter}" 62 | 63 | # Do not run as root 64 | if [[ $(id -u) -eq 0 ]] 65 | then 66 | printf "\n%s\n" "${delimiter}" 67 | printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" 68 | printf "\n%s\n" "${delimiter}" 69 | exit 1 70 | else 71 | printf "\n%s\n" "${delimiter}" 72 | printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)" 73 | printf "\n%s\n" "${delimiter}" 74 | fi 75 | 76 | if [[ -d .git ]] 77 | then 78 | printf "\n%s\n" "${delimiter}" 79 | printf "Repo already cloned, using it as install directory" 80 | printf "\n%s\n" "${delimiter}" 81 | install_dir="${PWD}/../" 82 | clone_dir="${PWD##*/}" 83 | fi 84 | 85 | # Check prerequisites 86 | for preq in "${GIT}" "${python_cmd}" 87 | do 88 | if ! hash "${preq}" &>/dev/null 89 | then 90 | printf "\n%s\n" "${delimiter}" 91 | printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}" 92 | printf "\n%s\n" "${delimiter}" 93 | exit 1 94 | fi 95 | done 96 | 97 | if ! "${python_cmd}" -c "import venv" &>/dev/null 98 | then 99 | printf "\n%s\n" "${delimiter}" 100 | printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" 101 | printf "\n%s\n" "${delimiter}" 102 | exit 1 103 | fi 104 | 105 | printf "\n%s\n" "${delimiter}" 106 | printf "Clone or update stable-diffusion-webui" 107 | printf "\n%s\n" "${delimiter}" 108 | cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } 109 | if [[ -d "${clone_dir}" ]] 110 | then 111 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 112 | "${GIT}" pull 113 | else 114 | "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" 115 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 116 | fi 117 | 118 | printf "\n%s\n" "${delimiter}" 119 | printf "Create and activate python venv" 120 | printf "\n%s\n" "${delimiter}" 121 | cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 122 | if [[ ! -d "${venv_dir}" ]] 123 | then 124 | "${python_cmd}" -m venv "${venv_dir}" 125 | first_launch=1 126 | fi 127 | # shellcheck source=/dev/null 128 | if [[ -f "${venv_dir}"/bin/activate ]] 129 | then 130 | source "${venv_dir}"/bin/activate 131 | else 132 | printf "\n%s\n" "${delimiter}" 133 | printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" 134 | printf "\n%s\n" "${delimiter}" 135 | exit 1 136 | fi 137 | 138 | printf "\n%s\n" "${delimiter}" 139 | printf "Launching launch.py..." 
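# Any arguments passed to webui.sh are forwarded unchanged to launch.py via "$@" below.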
140 | printf "\n%s\n" "${delimiter}" 141 | "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" 142 | -------------------------------------------------------------------------------- /modules/safe.py: -------------------------------------------------------------------------------- 1 | # this code is adapted from the script contributed by anon from /h/ 2 | 3 | import io 4 | import pickle 5 | import collections 6 | import sys 7 | import traceback 8 | 9 | import torch 10 | import numpy 11 | import _codecs 12 | import zipfile 13 | import re 14 | 15 | 16 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage 17 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage 18 | 19 | 20 | def encode(*args): 21 | out = _codecs.encode(*args) 22 | return out 23 | 24 | 25 | class RestrictedUnpickler(pickle.Unpickler): 26 | def persistent_load(self, saved_id): 27 | assert saved_id[0] == 'storage' 28 | return TypedStorage() 29 | 30 | def find_class(self, module, name): 31 | if module == 'collections' and name == 'OrderedDict': 32 | return getattr(collections, name) 33 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: 34 | return getattr(torch._utils, name) 35 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: 36 | return getattr(torch, name) 37 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']: 38 | return getattr(torch.nn.modules.container, name) 39 | if module == 'numpy.core.multiarray' and name == 'scalar': 40 | return numpy.core.multiarray.scalar 41 | if module == 'numpy' and name == 'dtype': 42 | return numpy.dtype 43 | if module == '_codecs' and name == 'encode': 44 | return encode 45 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': 46 | import pytorch_lightning.callbacks 47 | return pytorch_lightning.callbacks.model_checkpoint 48 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': 49 | import pytorch_lightning.callbacks.model_checkpoint 50 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 51 | if module == "__builtin__" and name == 'set': 52 | return set 53 | 54 | # Forbid everything else. 
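# Anything not matched by the allow-list above (e.g. os.system or builtins.eval smuggled into a pickle) falls through to here and aborts the load.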
55 | raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") 56 | 57 | 58 | allowed_zip_names = ["archive/data.pkl", "archive/version"] 59 | allowed_zip_names_re = re.compile(r"^archive/data/\d+$") 60 | 61 | 62 | def check_zip_filenames(filename, names): 63 | for name in names: 64 | if name in allowed_zip_names: 65 | continue 66 | if allowed_zip_names_re.match(name): 67 | continue 68 | 69 | raise Exception(f"bad file inside {filename}: {name}") 70 | 71 | 72 | def check_pt(filename): 73 | try: 74 | 75 | # new pytorch format is a zip file 76 | with zipfile.ZipFile(filename) as z: 77 | check_zip_filenames(filename, z.namelist()) 78 | 79 | with z.open('archive/data.pkl') as file: 80 | unpickler = RestrictedUnpickler(file) 81 | unpickler.load() 82 | 83 | except zipfile.BadZipfile: 84 | 85 | # if it's not a zip file, it's an old pytorch format, with five objects written to pickle 86 | with open(filename, "rb") as file: 87 | unpickler = RestrictedUnpickler(file) 88 | for i in range(5): 89 | unpickler.load() 90 | 91 | 92 | def load(filename, *args, **kwargs): 93 | from modules import shared 94 | 95 | try: 96 | if not shared.cmd_opts.disable_safe_unpickle: 97 | check_pt(filename) 98 | 99 | except pickle.UnpicklingError: 100 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 101 | print(traceback.format_exc(), file=sys.stderr) 102 | print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) 103 | print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) 104 | return None 105 | 106 | except Exception: 107 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 108 | print(traceback.format_exc(), file=sys.stderr) 109 | print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) 110 | print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) 111 | return None 112 | 113 | return unsafe_torch_load(filename, *args, **kwargs) 114 | 115 | 116 | unsafe_torch_load = torch.load 117 | torch.load = load 118 | -------------------------------------------------------------------------------- /modules/textual_inversion/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | import random 10 | import tqdm 11 | from modules import devices, shared 12 | import re 13 | 14 | re_numbers_at_start = re.compile(r"^[-\d]+\s*") 15 | 16 | 17 | class DatasetEntry: 18 | def __init__(self, filename=None, latent=None, filename_text=None): 19 | self.filename = filename 20 | self.latent = latent 21 | self.filename_text = filename_text 22 | self.cond = None 23 | self.cond_text = None 24 | 25 | 26 | class PersonalizedBase(Dataset): 27 | def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): 28 | re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None 29 | 30 | self.placeholder_token = placeholder_token 31 | 32 | self.batch_size = batch_size 33 | self.width = width 34 | self.height = height 35 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 36 | 37 | self.dataset = [] 38 | 39 | with
open(template_file, "r") as file: 40 | lines = [x.strip() for x in file.readlines()] 41 | 42 | self.lines = lines 43 | 44 | assert data_root, 'dataset directory not specified' 45 | 46 | cond_model = shared.sd_model.cond_stage_model 47 | 48 | self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] 49 | print("Preparing dataset...") 50 | for path in tqdm.tqdm(self.image_paths): 51 | try: 52 | image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) 53 | except Exception: 54 | continue 55 | 56 | text_filename = os.path.splitext(path)[0] + ".txt" 57 | filename = os.path.basename(path) 58 | 59 | if os.path.exists(text_filename): 60 | with open(text_filename, "r", encoding="utf8") as file: 61 | filename_text = file.read() 62 | else: 63 | filename_text = os.path.splitext(filename)[0] 64 | filename_text = re.sub(re_numbers_at_start, '', filename_text) 65 | if re_word: 66 | tokens = re_word.findall(filename_text) 67 | filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) 68 | 69 | npimage = np.array(image).astype(np.uint8) 70 | npimage = (npimage / 127.5 - 1.0).astype(np.float32) 71 | 72 | torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) 73 | torchdata = torch.moveaxis(torchdata, 2, 0) 74 | 75 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() 76 | init_latent = init_latent.to(devices.cpu) 77 | 78 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) 79 | 80 | if include_cond: 81 | entry.cond_text = self.create_text(filename_text) 82 | entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) 83 | 84 | self.dataset.append(entry) 85 | 86 | assert len(self.dataset) > 0, "No images have been found in the dataset." 
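# Virtual epoch length: every image is seen `repeats` times and __getitem__ returns batch_size entries per index, hence the integer division below.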
87 | self.length = len(self.dataset) * repeats // batch_size 88 | 89 | self.initial_indexes = np.arange(len(self.dataset)) 90 | self.indexes = None 91 | self.shuffle() 92 | 93 | def shuffle(self): 94 | self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] 95 | 96 | def create_text(self, filename_text): 97 | text = random.choice(self.lines) 98 | text = text.replace("[name]", self.placeholder_token) 99 | text = text.replace("[filewords]", filename_text) 100 | return text 101 | 102 | def __len__(self): 103 | return self.length 104 | 105 | def __getitem__(self, i): 106 | res = [] 107 | 108 | for j in range(self.batch_size): 109 | position = i * self.batch_size + j 110 | if position % len(self.indexes) == 0: 111 | self.shuffle() 112 | 113 | index = self.indexes[position % len(self.indexes)] 114 | entry = self.dataset[index] 115 | 116 | if entry.cond is None: 117 | entry.cond_text = self.create_text(entry.filename_text) 118 | 119 | res.append(entry) 120 | 121 | return res 122 | -------------------------------------------------------------------------------- /modules/textual_inversion/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image, ImageOps 3 | import platform 4 | import sys 5 | import tqdm 6 | import time 7 | 8 | from modules import shared, images 9 | from modules.shared import opts, cmd_opts 10 | if cmd_opts.deepdanbooru: 11 | import modules.deepbooru as deepbooru 12 | 13 | 14 | def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): 15 | try: 16 | if process_caption: 17 | shared.interrogator.load() 18 | 19 | if process_caption_deepbooru: 20 | db_opts = deepbooru.create_deepbooru_opts() 21 | db_opts[deepbooru.OPT_INCLUDE_RANKS] = False 22 | deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) 23 | 24 | preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru) 25 | 26 | finally: 27 | 28 | if process_caption: 29 | shared.interrogator.send_blip_to_ram() 30 | 31 | if process_caption_deepbooru: 32 | deepbooru.release_process() 33 | 34 | 35 | 36 | def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): 37 | width = process_width 38 | height = process_height 39 | src = os.path.abspath(process_src) 40 | dst = os.path.abspath(process_dst) 41 | 42 | assert src != dst, 'same directory specified as source and destination' 43 | 44 | os.makedirs(dst, exist_ok=True) 45 | 46 | files = os.listdir(src) 47 | 48 | shared.state.textinfo = "Preprocessing..." 
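# One UI job per source file; shared.state.nextjob() at the bottom of the loop below advances the progress bar.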
49 | shared.state.job_count = len(files) 50 | 51 | def save_pic_with_caption(image, index, existing_caption=None): 52 | caption = "" 53 | 54 | if process_caption: 55 | caption += shared.interrogator.generate_caption(image) 56 | 57 | if process_caption_deepbooru: 58 | if len(caption) > 0: 59 | caption += ", " 60 | caption += deepbooru.get_tags_from_process(image) 61 | 62 | filename_part = filename 63 | filename_part = os.path.splitext(filename_part)[0] 64 | filename_part = os.path.basename(filename_part) 65 | 66 | basename = f"{index:05}-{subindex[0]}-{filename_part}" 67 | image.save(os.path.join(dst, f"{basename}.png")) 68 | 69 | if preprocess_txt_action == 'prepend' and existing_caption: 70 | caption = existing_caption + ' ' + caption 71 | elif preprocess_txt_action == 'append' and existing_caption: 72 | caption = caption + ' ' + existing_caption 73 | elif preprocess_txt_action == 'copy' and existing_caption: 74 | caption = existing_caption 75 | 76 | caption = caption.strip() 77 | 78 | if len(caption) > 0: 79 | with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: 80 | file.write(caption) 81 | 82 | subindex[0] += 1 83 | 84 | def save_pic(image, index, existing_caption=None): 85 | save_pic_with_caption(image, index, existing_caption=existing_caption) 86 | 87 | if process_flip: 88 | save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) 89 | 90 | for index, imagefile in enumerate(tqdm.tqdm(files)): 91 | subindex = [0] 92 | filename = os.path.join(src, imagefile) 93 | try: 94 | img = Image.open(filename).convert("RGB") 95 | except Exception: 96 | continue 97 | 98 | existing_caption = None 99 | 100 | try: 101 | existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() 102 | except Exception as e: 103 | print(e) 104 | 105 | if shared.state.interrupted: 106 | break 107 | 108 | ratio = img.height / img.width 109 | is_tall = ratio > 1.35 110 | is_wide = ratio < 1 / 1.35 111 | 112 | if process_split and is_tall: 113 | img = img.resize((width, height * img.height // img.width)) 114 | 115 | top = img.crop((0, 0, width, height)) 116 | save_pic(top, index, existing_caption=existing_caption) 117 | 118 | bot = img.crop((0, img.height - height, width, img.height)) 119 | save_pic(bot, index, existing_caption=existing_caption) 120 | elif process_split and is_wide: 121 | img = img.resize((width * img.width // img.height, height)) 122 | 123 | left = img.crop((0, 0, width, height)) 124 | save_pic(left, index, existing_caption=existing_caption) 125 | 126 | right = img.crop((img.width - width, 0, img.width, height)) 127 | save_pic(right, index, existing_caption=existing_caption) 128 | else: 129 | img = images.resize_image(1, img, width, height) 130 | save_pic(img, index, existing_caption=existing_caption) 131 | 132 | shared.state.nextjob() 133 | -------------------------------------------------------------------------------- /scripts/prompts_from_file.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import sys 5 | import traceback 6 | import shlex 7 | 8 | import modules.scripts as scripts 9 | import gradio as gr 10 | 11 | from modules.processing import Processed, process_images 12 | from PIL import Image 13 | from modules.shared import opts, cmd_opts, state 14 | 15 | 16 | def process_string_tag(tag): 17 | return tag 18 | 19 | 20 | def process_int_tag(tag): 21 | return int(tag) 22 | 23 | 24 | def process_float_tag(tag): 25 | return float(tag) 26 | 27 | 
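# These converters turn the string value of each "--key value" pair into the right Python type; cmdargs() below looks them up in the prompt_tags table.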
28 | def process_boolean_tag(tag): 29 | return tag == "true" 30 | 31 | 32 | prompt_tags = { 33 | "sd_model": None, 34 | "outpath_samples": process_string_tag, 35 | "outpath_grids": process_string_tag, 36 | "prompt_for_display": process_string_tag, 37 | "prompt": process_string_tag, 38 | "negative_prompt": process_string_tag, 39 | "styles": process_string_tag, 40 | "seed": process_int_tag, 41 | "subseed_strength": process_float_tag, 42 | "subseed": process_int_tag, 43 | "seed_resize_from_h": process_int_tag, 44 | "seed_resize_from_w": process_int_tag, 45 | "sampler_index": process_int_tag, 46 | "batch_size": process_int_tag, 47 | "n_iter": process_int_tag, 48 | "steps": process_int_tag, 49 | "cfg_scale": process_float_tag, 50 | "width": process_int_tag, 51 | "height": process_int_tag, 52 | "restore_faces": process_boolean_tag, 53 | "tiling": process_boolean_tag, 54 | "do_not_save_samples": process_boolean_tag, 55 | "do_not_save_grid": process_boolean_tag 56 | } 57 | 58 | 59 | def cmdargs(line): 60 | args = shlex.split(line) 61 | pos = 0 62 | res = {} 63 | 64 | while pos < len(args): 65 | arg = args[pos] 66 | 67 | assert arg.startswith("--"), f'must start with "--": {arg}' 68 | tag = arg[2:] 69 | 70 | func = prompt_tags.get(tag, None) 71 | assert func, f'unknown commandline option: {arg}' 72 | 73 | assert pos+1 < len(args), f'missing argument for command line option {arg}' 74 | 75 | val = args[pos+1] 76 | 77 | res[tag] = func(val) 78 | 79 | pos += 2 80 | 81 | return res 82 | 83 | 84 | class Script(scripts.Script): 85 | def title(self): 86 | return "Prompts from file or textbox" 87 | 88 | def ui(self, is_img2img): 89 | # This checkbox would look nicer as two tabs, but there are two problems: 90 | # 1) There is a bug in Gradio 3.3 that prevents visibility from working on Tabs 91 | # 2) Even with Gradio 3.3.1, returning a control (like Tabs) that can't be used as input 92 | # causes an AttributeError: 'Tabs' object has no attribute 'preprocess' assert, 93 | # due to the way Script assumes all controls returned can be used as inputs. 94 | # Therefore, there's no good way to use grouping components right now, 95 | # so we will use a checkbox!
:) 96 | checkbox_txt = gr.Checkbox(label="Show Textbox", value=False) 97 | file = gr.File(label="File with inputs", type='bytes') 98 | prompt_txt = gr.TextArea(label="Prompts") 99 | checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt]) 100 | return [checkbox_txt, file, prompt_txt] 101 | 102 | def on_show(self, checkbox_txt, file, prompt_txt): 103 | return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ] 104 | 105 | def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): 106 | if checkbox_txt: 107 | lines = [x.strip() for x in prompt_txt.splitlines()] 108 | else: 109 | lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")] 110 | lines = [x for x in lines if len(x) > 0] 111 | 112 | p.do_not_save_grid = True 113 | 114 | job_count = 0 115 | jobs = [] 116 | 117 | for line in lines: 118 | if "--" in line: 119 | try: 120 | args = cmdargs(line) 121 | except Exception: 122 | print(f"Error parsing line {line} as commandline:", file=sys.stderr) 123 | print(traceback.format_exc(), file=sys.stderr) 124 | args = {"prompt": line} 125 | else: 126 | args = {"prompt": line} 127 | 128 | n_iter = args.get("n_iter", 1) 129 | if n_iter != 1: 130 | job_count += n_iter 131 | else: 132 | job_count += 1 133 | 134 | jobs.append(args) 135 | 136 | print(f"Will process {len(lines)} lines in {job_count} jobs.") 137 | state.job_count = job_count 138 | 139 | images = [] 140 | for n, args in enumerate(jobs): 141 | state.job = f"{state.job_no + 1} out of {state.job_count}" 142 | 143 | copy_p = copy.copy(p) 144 | for k, v in args.items(): 145 | setattr(copy_p, k, v) 146 | 147 | proc = process_images(copy_p) 148 | images += proc.images 149 | 150 | return Processed(p, images, p.seed, "") 151 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | import importlib 5 | import signal 6 | import threading 7 | from fastapi import FastAPI 8 | from fastapi.middleware.gzip import GZipMiddleware 9 | 10 | from modules.paths import script_path 11 | 12 | from modules import devices, sd_samplers 13 | import modules.codeformer_model as codeformer 14 | import modules.extras 15 | import modules.face_restoration 16 | import modules.gfpgan_model as gfpgan 17 | import modules.img2img 18 | 19 | import modules.lowvram 20 | import modules.paths 21 | import modules.scripts 22 | import modules.sd_hijack 23 | import modules.sd_models 24 | import modules.shared as shared 25 | import modules.txt2img 26 | 27 | import modules.ui 28 | from modules import devices 29 | from modules import modelloader 30 | from modules.paths import script_path 31 | from modules.shared import cmd_opts 32 | import modules.hypernetworks.hypernetwork 33 | 34 | queue_lock = threading.Lock() 35 | 36 | 37 | def wrap_queued_call(func): 38 | def f(*args, **kwargs): 39 | with queue_lock: 40 | res = func(*args, **kwargs) 41 | 42 | return res 43 | 44 | return f 45 | 46 | 47 | def wrap_gradio_gpu_call(func, extra_outputs=None): 48 | def f(*args, **kwargs): 49 | devices.torch_gc() 50 | 51 | shared.state.sampling_step = 0 52 | shared.state.job_count = -1 53 | shared.state.job_no = 0 54 | shared.state.job_timestamp = shared.state.get_job_timestamp() 55 | shared.state.current_latent = None 56 | shared.state.current_image
= None 57 | shared.state.current_image_sampling_step = 0 58 | shared.state.skipped = False 59 | shared.state.interrupted = False 60 | shared.state.textinfo = None 61 | 62 | with queue_lock: 63 | res = func(*args, **kwargs) 64 | 65 | shared.state.job = "" 66 | shared.state.job_count = 0 67 | 68 | devices.torch_gc() 69 | 70 | return res 71 | 72 | return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) 73 | 74 | def initialize(): 75 | modelloader.cleanup_models() 76 | modules.sd_models.setup_model() 77 | codeformer.setup_model(cmd_opts.codeformer_models_path) 78 | gfpgan.setup_model(cmd_opts.gfpgan_models_path) 79 | shared.face_restorers.append(modules.face_restoration.FaceRestoration()) 80 | modelloader.load_upscalers() 81 | 82 | modules.scripts.load_scripts(os.path.join(script_path, "scripts")) 83 | 84 | shared.sd_model = modules.sd_models.load_model() 85 | shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) 86 | shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) 87 | shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) 88 | 89 | # make the program just exit at ctrl+c without waiting for anything 90 | def sigint_handler(sig, frame): 91 | print(f'Interrupted with signal {sig} in {frame}') 92 | os._exit(0) 93 | 94 | signal.signal(signal.SIGINT, sigint_handler) 95 | 96 | 97 | def create_api(app): 98 | from modules.api.api import Api 99 | api = Api(app, queue_lock) 100 | return api 101 | 102 | def wait_on_server(demo=None): 103 | while 1: 104 | time.sleep(0.5) 105 | if demo and getattr(demo, 'do_restart', False): 106 | time.sleep(0.5) 107 | demo.close() 108 | time.sleep(0.5) 109 | break 110 | 111 | def api_only(): 112 | initialize() 113 | 114 | app = FastAPI() 115 | app.add_middleware(GZipMiddleware, minimum_size=1000) 116 | api = create_api(app) 117 | 118 | api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) 119 | 120 | 121 | def webui(): 122 | launch_api = cmd_opts.api 123 | initialize() 124 | 125 | while 1: 126 | demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) 127 | 128 | app, local_url, share_url = demo.launch( 129 | share=cmd_opts.share, 130 | server_name="0.0.0.0" if cmd_opts.listen else None, 131 | server_port=cmd_opts.port, 132 | debug=cmd_opts.gradio_debug, 133 | auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, 134 | inbrowser=cmd_opts.autolaunch, 135 | prevent_thread_lock=True 136 | ) 137 | 138 | app.add_middleware(GZipMiddleware, minimum_size=1000) 139 | 140 | if (launch_api): 141 | create_api(app) 142 | 143 | wait_on_server(demo) 144 | 145 | sd_samplers.set_samplers() 146 | 147 | print('Reloading Custom Scripts') 148 | modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) 149 | print('Reloading modules: modules.ui') 150 | importlib.reload(modules.ui) 151 | print('Refreshing Model List') 152 | modules.sd_models.list_models() 153 | print('Restarting Gradio') 154 | 155 | 156 | 157 | task = [] 158 | if __name__ == "__main__": 159 | if cmd_opts.nowebui: 160 | api_only() 161 | else: 162 | webui() 163 | -------------------------------------------------------------------------------- /javascript/progressbar.js: -------------------------------------------------------------------------------- 1 
| // code related to showing and updating the progress bar displayed while the image is being generated 2 | global_progressbars = {} 3 | galleries = {} 4 | galleryObservers = {} 5 | 6 | function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ 7 | var progressbar = gradioApp().getElementById(id_progressbar) 8 | var skip = id_skip ? gradioApp().getElementById(id_skip) : null 9 | var interrupt = gradioApp().getElementById(id_interrupt) 10 | 11 | if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ 12 | if(progressbar.innerText){ 13 | let newtitle = 'Stable Diffusion - ' + progressbar.innerText 14 | if(document.title != newtitle){ 15 | document.title = newtitle; 16 | } 17 | }else{ 18 | let newtitle = 'Stable Diffusion' 19 | if(document.title != newtitle){ 20 | document.title = newtitle; 21 | } 22 | } 23 | } 24 | 25 | if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ 26 | global_progressbars[id_progressbar] = progressbar 27 | 28 | var mutationObserver = new MutationObserver(function(m){ 29 | preview = gradioApp().getElementById(id_preview) 30 | gallery = gradioApp().getElementById(id_gallery) 31 | 32 | if(preview != null && gallery != null){ 33 | preview.style.width = gallery.clientWidth + "px" 34 | preview.style.height = gallery.clientHeight + "px" 35 | 36 | //only watch the gallery if there is a generation process going on 37 | check_gallery(id_gallery); 38 | 39 | var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; 40 | if(!progressDiv){ 41 | if (skip) { 42 | skip.style.display = "none" 43 | } 44 | interrupt.style.display = "none" 45 | 46 | //disconnect the observer once generation has finished, so the user can close the selected image if they want 47 | if (galleryObservers[id_gallery]) { 48 | galleryObservers[id_gallery].disconnect(); 49 | galleries[id_gallery] = null; 50 | } 51 | } 52 | 53 | 54 | } 55 | 56 | window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) 57 | }); 58 | mutationObserver.observe( progressbar, { childList:true, subtree:true }) 59 | } 60 | } 61 | 62 | function check_gallery(id_gallery){ 63 | let gallery = gradioApp().getElementById(id_gallery) 64 | // if the gallery has not changed, there is no need to set up the observer again.
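// One MutationObserver per gallery re-selects the image the user had open after Gradio re-renders the gallery; check_progressbar() above disconnects it once generation finishes.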
65 | if (gallery && galleries[id_gallery] !== gallery){ 66 | galleries[id_gallery] = gallery; 67 | if(galleryObservers[id_gallery]){ 68 | galleryObservers[id_gallery].disconnect(); 69 | } 70 | let prevSelectedIndex = selected_gallery_index(); 71 | galleryObservers[id_gallery] = new MutationObserver(function (){ 72 | let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') 73 | let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') 74 | if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { 75 | // automatically re-open the previously selected index (if it exists) 76 | activeElement = gradioApp().activeElement; 77 | 78 | galleryButtons[prevSelectedIndex].click(); 79 | showGalleryImage(); 80 | 81 | if(activeElement){ 82 | // I fought this for about an hour; I don't know why the focus is lost or why this helps recover it. 83 | // If someone has a better solution, please, by all means 84 | setTimeout(function() { activeElement.focus() }, 1); 85 | } 86 | } 87 | }) 88 | galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) 89 | } 90 | } 91 | 92 | onUiUpdate(function(){ 93 | check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') 94 | check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') 95 | check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') 96 | }) 97 | 98 | function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ 99 | btn = gradioApp().getElementById(id_part+"_check_progress"); 100 | if(btn==null) return; 101 | 102 | btn.click(); 103 | var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; 104 | var skip = id_skip ?
gradioApp().getElementById(id_skip) : null 105 | var interrupt = gradioApp().getElementById(id_interrupt) 106 | if(progressDiv && interrupt){ 107 | if (skip) { 108 | skip.style.display = "block" 109 | } 110 | interrupt.style.display = "block" 111 | } 112 | } 113 | 114 | function requestProgress(id_part){ 115 | btn = gradioApp().getElementById(id_part+"_check_progress_initial"); 116 | if(btn==null) return; 117 | 118 | btn.click(); 119 | } 120 | -------------------------------------------------------------------------------- /modules/img2img.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | import traceback 5 | 6 | import numpy as np 7 | from PIL import Image, ImageOps, ImageChops 8 | 9 | from modules import devices 10 | from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images 11 | from modules.shared import opts, state 12 | import modules.shared as shared 13 | import modules.processing as processing 14 | from modules.ui import plaintext_to_html 15 | import modules.images as images 16 | import modules.scripts 17 | 18 | 19 | def process_batch(p, input_dir, output_dir, args): 20 | processing.fix_seed(p) 21 | 22 | images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] 23 | 24 | print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") 25 | 26 | save_normally = output_dir == '' 27 | 28 | p.do_not_save_grid = True 29 | p.do_not_save_samples = not save_normally 30 | 31 | state.job_count = len(images) * p.n_iter 32 | 33 | for i, image in enumerate(images): 34 | state.job = f"{i+1} out of {len(images)}" 35 | if state.skipped: 36 | state.skipped = False 37 | 38 | if state.interrupted: 39 | break 40 | 41 | img = Image.open(image) 42 | p.init_images = [img] * p.batch_size 43 | 44 | proc = modules.scripts.scripts_img2img.run(p, *args) 45 | if proc is None: 46 | proc = process_images(p) 47 | 48 | for n, processed_image in enumerate(proc.images): 49 | filename = os.path.basename(image) 50 | 51 | if n > 0: 52 | left, right = os.path.splitext(filename) 53 | filename = f"{left}-{n}{right}" 54 | 55 | if not save_normally: 56 | processed_image.save(os.path.join(output_dir, filename)) 57 | 58 | 59 | def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): 60 | is_inpaint = mode == 1 61 | is_batch = mode == 2 62 | 63 | if is_inpaint: 64 | if mask_mode == 0: 65 | image = init_img_with_mask['image'] 66 | mask = init_img_with_mask['mask'] 67 | alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') 68 | mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') 69 | image = image.convert('RGB') 70 | else: 71 | image = init_img_inpaint 72 | mask = init_mask_inpaint 73 | else: 74 | image = init_img 75 | mask = None 76 | 77 | assert 
0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' 78 | 79 | p = StableDiffusionProcessingImg2Img( 80 | sd_model=shared.sd_model, 81 | outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, 82 | outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, 83 | prompt=prompt, 84 | negative_prompt=negative_prompt, 85 | styles=[prompt_style, prompt_style2], 86 | seed=seed, 87 | subseed=subseed, 88 | subseed_strength=subseed_strength, 89 | seed_resize_from_h=seed_resize_from_h, 90 | seed_resize_from_w=seed_resize_from_w, 91 | seed_enable_extras=seed_enable_extras, 92 | sampler_index=sampler_index, 93 | batch_size=batch_size, 94 | n_iter=n_iter, 95 | steps=steps, 96 | cfg_scale=cfg_scale, 97 | width=width, 98 | height=height, 99 | restore_faces=restore_faces, 100 | tiling=tiling, 101 | init_images=[image], 102 | mask=mask, 103 | mask_blur=mask_blur, 104 | inpainting_fill=inpainting_fill, 105 | resize_mode=resize_mode, 106 | denoising_strength=denoising_strength, 107 | inpaint_full_res=inpaint_full_res, 108 | inpaint_full_res_padding=inpaint_full_res_padding, 109 | inpainting_mask_invert=inpainting_mask_invert, 110 | ) 111 | 112 | if shared.cmd_opts.enable_console_prompts: 113 | print(f"\nimg2img: {prompt}", file=shared.progress_print_out) 114 | 115 | p.extra_generation_params["Mask blur"] = mask_blur 116 | 117 | if is_batch: 118 | assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" 119 | 120 | process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) 121 | 122 | processed = Processed(p, [], p.seed, "") 123 | else: 124 | processed = modules.scripts.scripts_img2img.run(p, *args) 125 | if processed is None: 126 | processed = process_images(p) 127 | 128 | shared.total_tqdm.clear() 129 | 130 | generation_info_js = processed.js() 131 | if opts.samples_log_stdout: 132 | print(generation_info_js) 133 | 134 | if opts.do_not_show_images: 135 | processed.images = [] 136 | 137 | return processed.images, generation_info_js, plaintext_to_html(processed.info) 138 | -------------------------------------------------------------------------------- /modules/realesrgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from realesrgan import RealESRGANer 9 | 10 | from modules.upscaler import Upscaler, UpscalerData 11 | from modules.shared import cmd_opts, opts 12 | 13 | 14 | class UpscalerRealESRGAN(Upscaler): 15 | def __init__(self, path): 16 | self.name = "RealESRGAN" 17 | self.user_path = path 18 | super().__init__() 19 | try: 20 | from basicsr.archs.rrdbnet_arch import RRDBNet 21 | from realesrgan import RealESRGANer 22 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 23 | self.enable = True 24 | self.scalers = [] 25 | scalers = self.load_models(path) 26 | for scaler in scalers: 27 | if scaler.name in opts.realesrgan_enabled_models: 28 | self.scalers.append(scaler) 29 | 30 | except Exception: 31 | print("Error importing Real-ESRGAN:", file=sys.stderr) 32 | print(traceback.format_exc(), file=sys.stderr) 33 | self.enable = False 34 | self.scalers = [] 35 | 36 | def do_upscale(self, img, path): 37 | if not self.enable: 38 | return img 39 | 40 | info = self.load_model(path) 41 | if not os.path.exists(info.data_path): 42 | print("Unable to load RealESRGAN model: %s" % info.name) 43 | return 
img 44 | 45 | upsampler = RealESRGANer( 46 | scale=info.scale, 47 | model_path=info.data_path, 48 | model=info.model(), 49 | half=not cmd_opts.no_half, 50 | tile=opts.ESRGAN_tile, 51 | tile_pad=opts.ESRGAN_tile_overlap, 52 | ) 53 | 54 | upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] 55 | 56 | image = Image.fromarray(upsampled) 57 | return image 58 | 59 | def load_model(self, path): 60 | try: 61 | info = None 62 | for scaler in self.scalers: 63 | if scaler.data_path == path: 64 | info = scaler 65 | 66 | if info is None: 67 | print(f"Unable to find model info: {path}") 68 | return None 69 | 70 | model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) 71 | info.data_path = model_file 72 | return info 73 | except Exception as e: 74 | print(f"Error loading Real-ESRGAN model: {e}", file=sys.stderr) 75 | print(traceback.format_exc(), file=sys.stderr) 76 | return None 77 | 78 | def load_models(self, _): 79 | return get_realesrgan_models(self) 80 | 81 | 82 | def get_realesrgan_models(scaler): 83 | try: 84 | from basicsr.archs.rrdbnet_arch import RRDBNet 85 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 86 | models = [ 87 | UpscalerData( 88 | name="R-ESRGAN General 4xV3", 89 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 90 | scale=4, 91 | upscaler=scaler, 92 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 93 | ), 94 | UpscalerData( 95 | name="R-ESRGAN General WDN 4xV3", 96 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", 97 | scale=4, 98 | upscaler=scaler, 99 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 100 | ), 101 | UpscalerData( 102 | name="R-ESRGAN AnimeVideo", 103 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", 104 | scale=4, 105 | upscaler=scaler, 106 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') 107 | ), 108 | UpscalerData( 109 | name="R-ESRGAN 4x+", 110 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 111 | scale=4, 112 | upscaler=scaler, 113 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) 114 | ), 115 | UpscalerData( 116 | name="R-ESRGAN 4x+ Anime6B", 117 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 118 | scale=4, 119 | upscaler=scaler, 120 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 121 | ), 122 | UpscalerData( 123 | name="R-ESRGAN 2x+", 124 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", 125 | scale=2, 126 | upscaler=scaler, 127 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 128 | ), 129 | ] 130 | return models 131 | except Exception as e: 132 | print("Error making Real-ESRGAN models list:", file=sys.stderr) 133 | print(traceback.format_exc(), file=sys.stderr) 134 | -------------------------------------------------------------------------------- /scripts/poor_mans_outpainting.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio
as gr 5 | from PIL import Image, ImageDraw 6 | 7 | from modules import images, processing, devices 8 | from modules.processing import Processed, process_images 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | 13 | class Script(scripts.Script): 14 | def title(self): 15 | return "Poor man's outpainting" 16 | 17 | def show(self, is_img2img): 18 | return is_img2img 19 | 20 | def ui(self, is_img2img): 21 | if not is_img2img: 22 | return None 23 | 24 | pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) 25 | mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) 26 | inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) 27 | direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) 28 | 29 | return [pixels, mask_blur, inpainting_fill, direction] 30 | 31 | def run(self, p, pixels, mask_blur, inpainting_fill, direction): 32 | initial_seed = None 33 | initial_info = None 34 | 35 | p.mask_blur = mask_blur * 2 36 | p.inpainting_fill = inpainting_fill 37 | p.inpaint_full_res = False 38 | 39 | left = pixels if "left" in direction else 0 40 | right = pixels if "right" in direction else 0 41 | up = pixels if "up" in direction else 0 42 | down = pixels if "down" in direction else 0 43 | 44 | init_img = p.init_images[0] 45 | target_w = math.ceil((init_img.width + left + right) / 64) * 64 46 | target_h = math.ceil((init_img.height + up + down) / 64) * 64 47 | 48 | if left > 0: 49 | left = left * (target_w - init_img.width) // (left + right) 50 | if right > 0: 51 | right = target_w - init_img.width - left 52 | 53 | if up > 0: 54 | up = up * (target_h - init_img.height) // (up + down) 55 | 56 | if down > 0: 57 | down = target_h - init_img.height - up 58 | 59 | img = Image.new("RGB", (target_w, target_h)) 60 | img.paste(init_img, (left, up)) 61 | 62 | mask = Image.new("L", (img.width, img.height), "white") 63 | draw = ImageDraw.Draw(mask) 64 | draw.rectangle(( 65 | left + (mask_blur * 2 if left > 0 else 0), 66 | up + (mask_blur * 2 if up > 0 else 0), 67 | mask.width - right - (mask_blur * 2 if right > 0 else 0), 68 | mask.height - down - (mask_blur * 2 if down > 0 else 0) 69 | ), fill="black") 70 | 71 | latent_mask = Image.new("L", (img.width, img.height), "white") 72 | latent_draw = ImageDraw.Draw(latent_mask) 73 | latent_draw.rectangle(( 74 | left + (mask_blur//2 if left > 0 else 0), 75 | up + (mask_blur//2 if up > 0 else 0), 76 | mask.width - right - (mask_blur//2 if right > 0 else 0), 77 | mask.height - down - (mask_blur//2 if down > 0 else 0) 78 | ), fill="black") 79 | 80 | devices.torch_gc() 81 | 82 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) 83 | grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 84 | grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 85 | 86 | p.n_iter = 1 87 | p.batch_size = 1 88 | p.do_not_save_grid = True 89 | p.do_not_save_samples = True 90 | 91 | work = [] 92 | work_mask = [] 93 | work_latent_mask = [] 94 | work_results = [] 95 | 96 | for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): 97 | for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): 98 | x, w = tiledata[0:2] 99 | 100 | if x 
>= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 101 | continue 102 | 103 | work.append(tiledata[2]) 104 | work_mask.append(tiledata_mask[2]) 105 | work_latent_mask.append(tiledata_latent_mask[2]) 106 | 107 | batch_count = len(work) 108 | print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") 109 | 110 | state.job_count = batch_count 111 | 112 | for i in range(batch_count): 113 | p.init_images = [work[i]] 114 | p.image_mask = work_mask[i] 115 | p.latent_mask = work_latent_mask[i] 116 | 117 | state.job = f"Batch {i + 1} out of {batch_count}" 118 | processed = process_images(p) 119 | 120 | if initial_seed is None: 121 | initial_seed = processed.seed 122 | initial_info = processed.info 123 | 124 | p.seed = processed.seed + 1 125 | work_results += processed.images 126 | 127 | 128 | image_index = 0 129 | for y, h, row in grid.tiles: 130 | for tiledata in row: 131 | x, w = tiledata[0:2] 132 | 133 | if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 134 | continue 135 | 136 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 137 | image_index += 1 138 | 139 | combined_image = images.combine_grid(grid) 140 | 141 | if opts.samples_save: 142 | images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p) 143 | 144 | processed = Processed(p, [combined_image], initial_seed, initial_info) 145 | 146 | return processed 147 | 148 | -------------------------------------------------------------------------------- /modules/modelloader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import importlib 5 | from urllib.parse import urlparse 6 | 7 | from basicsr.utils.download_util import load_file_from_url 8 | from modules import shared 9 | from modules.upscaler import Upscaler 10 | from modules.paths import script_path, models_path 11 | 12 | 13 | def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: 14 | """ 15 | A one-and-done loader that tries to find the desired models in the specified directories. 16 | 17 | @param download_name: Specify to download from model_url immediately. 18 | @param model_url: If no other models are found, this will be downloaded on upscale. 19 | @param model_path: The location to store/find models in. 20 | @param command_path: A command-line argument to search for models in first.
21 | @param ext_filter: An optional list of filename extensions to filter by 22 | @return: A list of paths containing the desired model(s) 23 | """ 24 | output = [] 25 | 26 | if ext_filter is None: 27 | ext_filter = [] 28 | 29 | try: 30 | places = [] 31 | 32 | if command_path is not None and command_path != model_path: 33 | pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') 34 | if os.path.exists(pretrained_path): 35 | print(f"Appending path: {pretrained_path}") 36 | places.append(pretrained_path) 37 | elif os.path.exists(command_path): 38 | places.append(command_path) 39 | 40 | places.append(model_path) 41 | 42 | for place in places: 43 | if os.path.exists(place): 44 | for file in glob.iglob(place + '**/**', recursive=True): 45 | full_path = file 46 | if os.path.isdir(full_path): 47 | continue 48 | if len(ext_filter) != 0: 49 | model_name, extension = os.path.splitext(file) 50 | if extension not in ext_filter: 51 | continue 52 | if file not in output: 53 | output.append(full_path) 54 | 55 | if model_url is not None and len(output) == 0: 56 | if download_name is not None: 57 | dl = load_file_from_url(model_url, model_path, True, download_name) 58 | output.append(dl) 59 | else: 60 | output.append(model_url) 61 | 62 | except Exception: 63 | pass 64 | 65 | return output 66 | 67 | 68 | def friendly_name(file: str): 69 | if "http" in file: 70 | file = urlparse(file).path 71 | 72 | file = os.path.basename(file) 73 | model_name, extension = os.path.splitext(file) 74 | return model_name 75 | 76 | 77 | def cleanup_models(): 78 | # This code could probably be more efficient if we used a tuple list or something to store the src/destinations 79 | # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler 80 | # somehow auto-register and just do these things... 
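# Each src/dest pair below migrates weights from a legacy location in the repo root into the consolidated models/ tree.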
81 | root_path = script_path 82 | src_path = models_path 83 | dest_path = os.path.join(models_path, "Stable-diffusion") 84 | move_files(src_path, dest_path, ".ckpt") 85 | src_path = os.path.join(root_path, "ESRGAN") 86 | dest_path = os.path.join(models_path, "ESRGAN") 87 | move_files(src_path, dest_path) 88 | src_path = os.path.join(root_path, "gfpgan") 89 | dest_path = os.path.join(models_path, "GFPGAN") 90 | move_files(src_path, dest_path) 91 | src_path = os.path.join(root_path, "SwinIR") 92 | dest_path = os.path.join(models_path, "SwinIR") 93 | move_files(src_path, dest_path) 94 | src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") 95 | dest_path = os.path.join(models_path, "LDSR") 96 | move_files(src_path, dest_path) 97 | 98 | 99 | def move_files(src_path: str, dest_path: str, ext_filter: str = None): 100 | try: 101 | if not os.path.exists(dest_path): 102 | os.makedirs(dest_path) 103 | if os.path.exists(src_path): 104 | for file in os.listdir(src_path): 105 | fullpath = os.path.join(src_path, file) 106 | if os.path.isfile(fullpath): 107 | if ext_filter is not None: 108 | if ext_filter not in file: 109 | continue 110 | print(f"Moving {file} from {src_path} to {dest_path}.") 111 | try: 112 | shutil.move(fullpath, dest_path) 113 | except: 114 | pass 115 | if len(os.listdir(src_path)) == 0: 116 | print(f"Removing empty folder: {src_path}") 117 | shutil.rmtree(src_path, True) 118 | except: 119 | pass 120 | 121 | 122 | def load_upscalers(): 123 | sd = shared.script_path 124 | # We can only do this 'magic' method to dynamically load upscalers if they are referenced, 125 | # so we'll try to import any _model.py files before looking in __subclasses__ 126 | modules_dir = os.path.join(sd, "modules") 127 | for file in os.listdir(modules_dir): 128 | if "_model.py" in file: 129 | model_name = file.replace("_model.py", "") 130 | full_model = f"modules.{model_name}_model" 131 | try: 132 | importlib.import_module(full_model) 133 | except: 134 | pass 135 | datas = [] 136 | c_o = vars(shared.cmd_opts) 137 | for cls in Upscaler.__subclasses__(): 138 | name = cls.__name__ 139 | module_name = cls.__module__ 140 | module = importlib.import_module(module_name) 141 | class_ = getattr(module, name) 142 | cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" 143 | opt_string = None 144 | try: 145 | if cmd_name in c_o: 146 | opt_string = c_o[cmd_name] 147 | except: 148 | pass 149 | scaler = class_(opt_string) 150 | for child in scaler.scalers: 151 | datas.append(child) 152 | 153 | shared.sd_upscalers = datas 154 | -------------------------------------------------------------------------------- /javascript/contextMenus.js: -------------------------------------------------------------------------------- 1 | 2 | contextMenuInit = function(){ 3 | let eventListenerApplied=false; 4 | let menuSpecs = new Map(); 5 | 6 | const uid = function(){ 7 | return Date.now().toString(36) + Math.random().toString(36).substr(2); 8 | } 9 | 10 | function showContextMenu(event,element,menuEntries){ 11 | let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; 12 | let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; 13 | 14 | let oldMenu = gradioApp().querySelector('#context-menu') 15 | if(oldMenu){ 16 | oldMenu.remove() 17 | } 18 | 19 | let tabButton = uiCurrentTab 20 | let baseStyle = window.getComputedStyle(tabButton) 21 | 22 | const contextMenu = document.createElement('nav') 23 | contextMenu.id = 
"context-menu" 24 | contextMenu.style.background = baseStyle.background 25 | contextMenu.style.color = baseStyle.color 26 | contextMenu.style.fontFamily = baseStyle.fontFamily 27 | contextMenu.style.top = posy+'px' 28 | contextMenu.style.left = posx+'px' 29 | 30 | 31 | 32 | const contextMenuList = document.createElement('ul') 33 | contextMenuList.className = 'context-menu-items'; 34 | contextMenu.append(contextMenuList); 35 | 36 | menuEntries.forEach(function(entry){ 37 | let contextMenuEntry = document.createElement('a') 38 | contextMenuEntry.innerHTML = entry['name'] 39 | contextMenuEntry.addEventListener("click", function(e) { 40 | entry['func'](); 41 | }) 42 | contextMenuList.append(contextMenuEntry); 43 | 44 | }) 45 | 46 | gradioApp().getRootNode().appendChild(contextMenu) 47 | 48 | let menuWidth = contextMenu.offsetWidth + 4; 49 | let menuHeight = contextMenu.offsetHeight + 4; 50 | 51 | let windowWidth = window.innerWidth; 52 | let windowHeight = window.innerHeight; 53 | 54 | if ( (windowWidth - posx) < menuWidth ) { 55 | contextMenu.style.left = windowWidth - menuWidth + "px"; 56 | } 57 | 58 | if ( (windowHeight - posy) < menuHeight ) { 59 | contextMenu.style.top = windowHeight - menuHeight + "px"; 60 | } 61 | 62 | } 63 | 64 | function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ 65 | 66 | currentItems = menuSpecs.get(targetEmementSelector) 67 | 68 | if(!currentItems){ 69 | currentItems = [] 70 | menuSpecs.set(targetEmementSelector,currentItems); 71 | } 72 | let newItem = {'id':targetEmementSelector+'_'+uid(), 73 | 'name':entryName, 74 | 'func':entryFunction, 75 | 'isNew':true} 76 | 77 | currentItems.push(newItem) 78 | return newItem['id'] 79 | } 80 | 81 | function removeContextMenuOption(uid){ 82 | menuSpecs.forEach(function(v,k) { 83 | let index = -1 84 | v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) 85 | if(index>=0){ 86 | v.splice(index, 1); 87 | } 88 | }) 89 | } 90 | 91 | function addContextMenuEventListener(){ 92 | if(eventListenerApplied){ 93 | return; 94 | } 95 | gradioApp().addEventListener("click", function(e) { 96 | let source = e.composedPath()[0] 97 | if(source.id && source.id.indexOf('check_progress')>-1){ 98 | return 99 | } 100 | 101 | let oldMenu = gradioApp().querySelector('#context-menu') 102 | if(oldMenu){ 103 | oldMenu.remove() 104 | } 105 | }); 106 | gradioApp().addEventListener("contextmenu", function(e) { 107 | let oldMenu = gradioApp().querySelector('#context-menu') 108 | if(oldMenu){ 109 | oldMenu.remove() 110 | } 111 | menuSpecs.forEach(function(v,k) { 112 | if(e.composedPath()[0].matches(k)){ 113 | showContextMenu(e,e.composedPath()[0],v) 114 | e.preventDefault() 115 | return 116 | } 117 | }) 118 | }); 119 | eventListenerApplied=true 120 | 121 | } 122 | 123 | return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] 124 | } 125 | 126 | initResponse = contextMenuInit(); 127 | appendContextMenuOption = initResponse[0]; 128 | removeContextMenuOption = initResponse[1]; 129 | addContextMenuEventListener = initResponse[2]; 130 | 131 | (function(){ 132 | //Start example Context Menu Items 133 | let generateOnRepeat = function(genbuttonid,interruptbuttonid){ 134 | let genbutton = gradioApp().querySelector(genbuttonid); 135 | let interruptbutton = gradioApp().querySelector(interruptbuttonid); 136 | if(!interruptbutton.offsetParent){ 137 | genbutton.click(); 138 | } 139 | clearInterval(window.generateOnRepeatInterval) 140 | window.generateOnRepeatInterval = setInterval(function(){ 141 | 
if(!interruptbutton.offsetParent){ 142 | genbutton.click(); 143 | } 144 | }, 145 | 500) 146 | } 147 | 148 | appendContextMenuOption('#txt2img_generate','Generate forever',function(){ 149 | generateOnRepeat('#txt2img_generate','#txt2img_interrupt'); 150 | }) 151 | appendContextMenuOption('#img2img_generate','Generate forever',function(){ 152 | generateOnRepeat('#img2img_generate','#img2img_interrupt'); 153 | }) 154 | 155 | let cancelGenerateForever = function(){ 156 | clearInterval(window.generateOnRepeatInterval) 157 | } 158 | 159 | appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) 160 | appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) 161 | appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) 162 | appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) 163 | 164 | appendContextMenuOption('#roll','Roll three', 165 | function(){ 166 | let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); 167 | setTimeout(function(){rollbutton.click()},100) 168 | setTimeout(function(){rollbutton.click()},200) 169 | setTimeout(function(){rollbutton.click()},300) 170 | } 171 | ) 172 | })(); 173 | //End example Context Menu Items 174 | 175 | onUiUpdate(function(){ 176 | addContextMenuEventListener() 177 | }); 178 | -------------------------------------------------------------------------------- /modules/swinir_model.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from tqdm import tqdm 9 | 10 | from modules import modelloader 11 | from modules.shared import cmd_opts, opts, device 12 | from modules.swinir_model_arch import SwinIR as net 13 | from modules.swinir_model_arch_v2 import Swin2SR as net2 14 | from modules.upscaler import Upscaler, UpscalerData 15 | 16 | precision_scope = ( 17 | torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext 18 | ) 19 | 20 | 21 | class UpscalerSwinIR(Upscaler): 22 | def __init__(self, dirname): 23 | self.name = "SwinIR" 24 | self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ 25 | "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ 26 | "-L_x4_GAN.pth " 27 | self.model_name = "SwinIR 4x" 28 | self.user_path = dirname 29 | super().__init__() 30 | scalers = [] 31 | model_files = self.find_models(ext_filter=[".pt", ".pth"]) 32 | for model in model_files: 33 | if "http" in model: 34 | name = self.model_name 35 | else: 36 | name = modelloader.friendly_name(model) 37 | model_data = UpscalerData(name, model, self) 38 | scalers.append(model_data) 39 | self.scalers = scalers 40 | 41 | def do_upscale(self, img, model_file): 42 | model = self.load_model(model_file) 43 | if model is None: 44 | return img 45 | model = model.to(device) 46 | img = upscale(img, model) 47 | try: 48 | torch.cuda.empty_cache() 49 | except: 50 | pass 51 | return img 52 | 53 | def load_model(self, path, scale=4): 54 | if "http" in path: 55 | dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") 56 | filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) 57 | else: 58 | filename = path 59 | if filename is None or not os.path.exists(filename): 60 | return None 61 | if filename.endswith(".v2.pth"): 62 | model = net2( 63 | 
upscale=scale, 64 | in_chans=3, 65 | img_size=64, 66 | window_size=8, 67 | img_range=1.0, 68 | depths=[6, 6, 6, 6, 6, 6], 69 | embed_dim=180, 70 | num_heads=[6, 6, 6, 6, 6, 6], 71 | mlp_ratio=2, 72 | upsampler="nearest+conv", 73 | resi_connection="1conv", 74 | ) 75 | params = None 76 | else: 77 | model = net( 78 | upscale=scale, 79 | in_chans=3, 80 | img_size=64, 81 | window_size=8, 82 | img_range=1.0, 83 | depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], 84 | embed_dim=240, 85 | num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], 86 | mlp_ratio=2, 87 | upsampler="nearest+conv", 88 | resi_connection="3conv", 89 | ) 90 | params = "params_ema" 91 | 92 | pretrained_model = torch.load(filename) 93 | if params is not None: 94 | model.load_state_dict(pretrained_model[params], strict=True) 95 | else: 96 | model.load_state_dict(pretrained_model, strict=True) 97 | if not cmd_opts.no_half: 98 | model = model.half() 99 | return model 100 | 101 | 102 | def upscale( 103 | img, 104 | model, 105 | tile=opts.SWIN_tile, 106 | tile_overlap=opts.SWIN_tile_overlap, 107 | window_size=8, 108 | scale=4, 109 | ): 110 | img = np.array(img) 111 | img = img[:, :, ::-1] 112 | img = np.moveaxis(img, 2, 0) / 255 113 | img = torch.from_numpy(img).float() 114 | img = img.unsqueeze(0).to(device) 115 | with torch.no_grad(), precision_scope("cuda"): 116 | _, _, h_old, w_old = img.size() 117 | h_pad = (h_old // window_size + 1) * window_size - h_old 118 | w_pad = (w_old // window_size + 1) * window_size - w_old 119 | img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] 120 | img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] 121 | output = inference(img, model, tile, tile_overlap, window_size, scale) 122 | output = output[..., : h_old * scale, : w_old * scale] 123 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() 124 | if output.ndim == 3: 125 | output = np.transpose( 126 | output[[2, 1, 0], :, :], (1, 2, 0) 127 | ) # CHW-RGB to HCW-BGR 128 | output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 129 | return Image.fromarray(output, "RGB") 130 | 131 | 132 | def inference(img, model, tile, tile_overlap, window_size, scale): 133 | # test the image tile by tile 134 | b, c, h, w = img.size() 135 | tile = min(tile, h, w) 136 | assert tile % window_size == 0, "tile size should be a multiple of window_size" 137 | sf = scale 138 | 139 | stride = tile - tile_overlap 140 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile] 141 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile] 142 | E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) 143 | W = torch.zeros_like(E, dtype=torch.half, device=device) 144 | 145 | with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: 146 | for h_idx in h_idx_list: 147 | for w_idx in w_idx_list: 148 | in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] 149 | out_patch = model(in_patch) 150 | out_patch_mask = torch.ones_like(out_patch) 151 | 152 | E[ 153 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 154 | ].add_(out_patch) 155 | W[ 156 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 157 | ].add_(out_patch_mask) 158 | pbar.update(1) 159 | output = E.div_(W) 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /modules/codeformer_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 
4 | 5 | import cv2 6 | import torch 7 | 8 | import modules.face_restoration 9 | import modules.shared 10 | from modules import shared, devices, modelloader 11 | from modules.paths import script_path, models_path 12 | 13 | # codeformer people made a choice to include modified basicsr library to their project which makes 14 | # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. 15 | # I am making a choice to include some files from codeformer to work around this issue. 16 | model_dir = "Codeformer" 17 | model_path = os.path.join(models_path, model_dir) 18 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 19 | 20 | have_codeformer = False 21 | codeformer = None 22 | 23 | 24 | def setup_model(dirname): 25 | global model_path 26 | if not os.path.exists(model_path): 27 | os.makedirs(model_path) 28 | 29 | path = modules.paths.paths.get("CodeFormer", None) 30 | if path is None: 31 | return 32 | 33 | try: 34 | from torchvision.transforms.functional import normalize 35 | from modules.codeformer.codeformer_arch import CodeFormer 36 | from basicsr.utils.download_util import load_file_from_url 37 | from basicsr.utils import imwrite, img2tensor, tensor2img 38 | from facelib.utils.face_restoration_helper import FaceRestoreHelper 39 | from modules.shared import cmd_opts 40 | 41 | net_class = CodeFormer 42 | 43 | class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): 44 | def name(self): 45 | return "CodeFormer" 46 | 47 | def __init__(self, dirname): 48 | self.net = None 49 | self.face_helper = None 50 | self.cmd_dir = dirname 51 | 52 | def create_models(self): 53 | 54 | if self.net is not None and self.face_helper is not None: 55 | self.net.to(devices.device_codeformer) 56 | return self.net, self.face_helper 57 | model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth') 58 | if len(model_paths) != 0: 59 | ckpt_path = model_paths[0] 60 | else: 61 | print("Unable to load codeformer model.") 62 | return None, None 63 | net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) 64 | checkpoint = torch.load(ckpt_path)['params_ema'] 65 | net.load_state_dict(checkpoint) 66 | net.eval() 67 | 68 | face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) 69 | 70 | self.net = net 71 | self.face_helper = face_helper 72 | 73 | return net, face_helper 74 | 75 | def send_model_to(self, device): 76 | self.net.to(device) 77 | self.face_helper.face_det.to(device) 78 | self.face_helper.face_parse.to(device) 79 | 80 | def restore(self, np_image, w=None): 81 | np_image = np_image[:, :, ::-1] 82 | 83 | original_resolution = np_image.shape[0:2] 84 | 85 | self.create_models() 86 | if self.net is None or self.face_helper is None: 87 | return np_image 88 | 89 | self.send_model_to(devices.device_codeformer) 90 | 91 | self.face_helper.clean_all() 92 | self.face_helper.read_image(np_image) 93 | self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) 94 | self.face_helper.align_warp_face() 95 | 96 | for idx, cropped_face in enumerate(self.face_helper.cropped_faces): 97 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) 98 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 99 | cropped_face_t 
= cropped_face_t.unsqueeze(0).to(devices.device_codeformer) 100 | 101 | try: 102 | with torch.no_grad(): 103 | output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] 104 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) 105 | del output 106 | torch.cuda.empty_cache() 107 | except Exception as error: 108 | print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) 109 | restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) 110 | 111 | restored_face = restored_face.astype('uint8') 112 | self.face_helper.add_restored_face(restored_face) 113 | 114 | self.face_helper.get_inverse_affine(None) 115 | 116 | restored_img = self.face_helper.paste_faces_to_input_image() 117 | restored_img = restored_img[:, :, ::-1] 118 | 119 | if original_resolution != restored_img.shape[0:2]: 120 | restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) 121 | 122 | self.face_helper.clean_all() 123 | 124 | if shared.opts.face_restoration_unload: 125 | self.send_model_to(devices.cpu) 126 | 127 | return restored_img 128 | 129 | global have_codeformer 130 | have_codeformer = True 131 | 132 | global codeformer 133 | codeformer = FaceRestorerCodeFormer(dirname) 134 | shared.face_restorers.append(codeformer) 135 | 136 | except Exception: 137 | print("Error setting up CodeFormer:", file=sys.stderr) 138 | print(traceback.format_exc(), file=sys.stderr) 139 | 140 | # sys.path = stored_sys_path 141 | -------------------------------------------------------------------------------- /modules/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from basicsr.utils.download_util import load_file_from_url 7 | 8 | import modules.esrgan_model_arch as arch 9 | from modules import shared, modelloader, images, devices 10 | from modules.upscaler import Upscaler, UpscalerData 11 | from modules.shared import opts 12 | 13 | 14 | def fix_model_layers(crt_model, pretrained_net): 15 | # this code is adapted from https://github.com/xinntao/ESRGAN 16 | if 'conv_first.weight' in pretrained_net: 17 | return pretrained_net 18 | 19 | if 'model.0.weight' not in pretrained_net: 20 | is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] 21 | if is_realesrgan: 22 | raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") 23 | else: 24 | raise Exception("The file is not a ESRGAN model.") 25 | 26 | crt_net = crt_model.state_dict() 27 | load_net_clean = {} 28 | for k, v in pretrained_net.items(): 29 | if k.startswith('module.'): 30 | load_net_clean[k[7:]] = v 31 | else: 32 | load_net_clean[k] = v 33 | pretrained_net = load_net_clean 34 | 35 | tbd = [] 36 | for k, v in crt_net.items(): 37 | tbd.append(k) 38 | 39 | # directly copy 40 | for k, v in crt_net.items(): 41 | if k in pretrained_net and pretrained_net[k].size() == v.size(): 42 | crt_net[k] = pretrained_net[k] 43 | tbd.remove(k) 44 | 45 | crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] 46 | crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] 47 | 48 | for k in tbd.copy(): 49 | if 'RDB' in k: 50 | ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') 51 | if '.weight' in k: 52 | ori_k = ori_k.replace('.weight', '.0.weight') 53 | elif '.bias' in k: 54 | 
ori_k = ori_k.replace('.bias', '.0.bias')
55 | crt_net[k] = pretrained_net[ori_k]
56 | tbd.remove(k)
57 |
58 | crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
59 | crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
60 | crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
61 | crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
62 | crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
63 | crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
64 | crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
65 | crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
66 | crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
67 | crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
68 |
69 | return crt_net
70 |
71 | class UpscalerESRGAN(Upscaler):
72 | def __init__(self, dirname):
73 | self.name = "ESRGAN"
74 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
75 | self.model_name = "ESRGAN_4x"
76 | self.scalers = []
77 | self.user_path = dirname
78 | super().__init__()
79 | model_paths = self.find_models(ext_filter=[".pt", ".pth"])
80 | # if no model files are found on disk, offer the default downloadable model
81 | if len(model_paths) == 0:
82 | scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
83 | self.scalers.append(scaler_data)
84 | for file in model_paths:
85 | if "http" in file:
86 | name = self.model_name
87 | else:
88 | name = modelloader.friendly_name(file)
89 |
90 | scaler_data = UpscalerData(name, file, self, 4)
91 | self.scalers.append(scaler_data)
92 |
93 | def do_upscale(self, img, selected_model):
94 | model = self.load_model(selected_model)
95 | if model is None:
96 | return img
97 | model.to(devices.device_esrgan)
98 | img = esrgan_upscale(model, img)
99 | return img
100 |
101 | def load_model(self, path: str):
102 | if "http" in path:
103 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
104 | file_name="%s.pth" % self.model_name,
105 | progress=True)
106 | else:
107 | filename = path
108 | if filename is None or not os.path.exists(filename):
109 | print("Unable to load %s from %s" % (self.model_name, filename))
110 | return None
111 |
112 | pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
113 | crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
114 |
115 | pretrained_net = fix_model_layers(crt_model, pretrained_net)
116 | crt_model.load_state_dict(pretrained_net)
117 | crt_model.eval()
118 |
119 | return crt_model
120 |
121 |
122 | def upscale_without_tiling(model, img):
123 | img = np.array(img)
124 | img = img[:, :, ::-1]
125 | img = np.moveaxis(img, 2, 0) / 255
126 | img = torch.from_numpy(img).float()
127 | img = img.unsqueeze(0).to(devices.device_esrgan)
128 | with torch.no_grad():
129 | output = model(img)
130 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
131 | output = 255.
* np.moveaxis(output, 0, 2) 132 | output = output.astype(np.uint8) 133 | output = output[:, :, ::-1] 134 | return Image.fromarray(output, 'RGB') 135 | 136 | 137 | def esrgan_upscale(model, img): 138 | if opts.ESRGAN_tile == 0: 139 | return upscale_without_tiling(model, img) 140 | 141 | grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) 142 | newtiles = [] 143 | scale_factor = 1 144 | 145 | for y, h, row in grid.tiles: 146 | newrow = [] 147 | for tiledata in row: 148 | x, w, tile = tiledata 149 | 150 | output = upscale_without_tiling(model, tile) 151 | scale_factor = output.width // tile.width 152 | 153 | newrow.append([x * scale_factor, w * scale_factor, output]) 154 | newtiles.append([y * scale_factor, h * scale_factor, newrow]) 155 | 156 | newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) 157 | output = images.combine_grid(newgrid) 158 | return output 159 | -------------------------------------------------------------------------------- /modules/deepbooru.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from concurrent.futures import ProcessPoolExecutor 3 | import multiprocessing 4 | import time 5 | import re 6 | 7 | re_special = re.compile(r'([\\()])') 8 | 9 | def get_deepbooru_tags(pil_image): 10 | """ 11 | This method is for running only one image at a time for simple use. Used to the img2img interrogate. 12 | """ 13 | from modules import shared # prevents circular reference 14 | 15 | try: 16 | create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) 17 | return get_tags_from_process(pil_image) 18 | finally: 19 | release_process() 20 | 21 | 22 | OPT_INCLUDE_RANKS = "include_ranks" 23 | def create_deepbooru_opts(): 24 | from modules import shared 25 | 26 | return { 27 | "use_spaces": shared.opts.deepbooru_use_spaces, 28 | "use_escape": shared.opts.deepbooru_escape, 29 | "alpha_sort": shared.opts.deepbooru_sort_alpha, 30 | OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks, 31 | } 32 | 33 | 34 | def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): 35 | model, tags = get_deepbooru_tags_model() 36 | while True: # while process is running, keep monitoring queue for new image 37 | pil_image = queue.get() 38 | if pil_image == "QUIT": 39 | break 40 | else: 41 | deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) 42 | 43 | 44 | def create_deepbooru_process(threshold, deepbooru_opts): 45 | """ 46 | Creates deepbooru process. A queue is created to send images into the process. This enables multiple images 47 | to be processed in a row without reloading the model or creating a new process. To return the data, a shared 48 | dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned 49 | to the dictionary and the method adding the image to the queue should wait for this value to be updated with 50 | the tags. 
51 | """ 52 | from modules import shared # prevents circular reference 53 | shared.deepbooru_process_manager = multiprocessing.Manager() 54 | shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() 55 | shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() 56 | shared.deepbooru_process_return["value"] = -1 57 | shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) 58 | shared.deepbooru_process.start() 59 | 60 | 61 | def get_tags_from_process(image): 62 | from modules import shared 63 | 64 | shared.deepbooru_process_return["value"] = -1 65 | shared.deepbooru_process_queue.put(image) 66 | while shared.deepbooru_process_return["value"] == -1: 67 | time.sleep(0.2) 68 | caption = shared.deepbooru_process_return["value"] 69 | shared.deepbooru_process_return["value"] = -1 70 | 71 | return caption 72 | 73 | 74 | def release_process(): 75 | """ 76 | Stops the deepbooru process to return used memory 77 | """ 78 | from modules import shared # prevents circular reference 79 | shared.deepbooru_process_queue.put("QUIT") 80 | shared.deepbooru_process.join() 81 | shared.deepbooru_process_queue = None 82 | shared.deepbooru_process = None 83 | shared.deepbooru_process_return = None 84 | shared.deepbooru_process_manager = None 85 | 86 | def get_deepbooru_tags_model(): 87 | import deepdanbooru as dd 88 | import tensorflow as tf 89 | import numpy as np 90 | this_folder = os.path.dirname(__file__) 91 | model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) 92 | if not os.path.exists(os.path.join(model_path, 'project.json')): 93 | # there is no point importing these every time 94 | import zipfile 95 | from basicsr.utils.download_util import load_file_from_url 96 | load_file_from_url( 97 | r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", 98 | model_path) 99 | with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: 100 | zip_ref.extractall(model_path) 101 | os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) 102 | 103 | tags = dd.project.load_tags_from_project(model_path) 104 | model = dd.project.load_model_from_project( 105 | model_path, compile_model=False 106 | ) 107 | return model, tags 108 | 109 | 110 | def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): 111 | import deepdanbooru as dd 112 | import tensorflow as tf 113 | import numpy as np 114 | 115 | alpha_sort = deepbooru_opts['alpha_sort'] 116 | use_spaces = deepbooru_opts['use_spaces'] 117 | use_escape = deepbooru_opts['use_escape'] 118 | include_ranks = deepbooru_opts['include_ranks'] 119 | 120 | width = model.input_shape[2] 121 | height = model.input_shape[1] 122 | image = np.array(pil_image) 123 | image = tf.image.resize( 124 | image, 125 | size=(height, width), 126 | method=tf.image.ResizeMethod.AREA, 127 | preserve_aspect_ratio=True, 128 | ) 129 | image = image.numpy() # EagerTensor to np.array 130 | image = dd.image.transform_and_pad_image(image, width, height) 131 | image = image / 255.0 132 | image_shape = image.shape 133 | image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) 134 | 135 | y = model.predict(image)[0] 136 | 137 | result_dict = {} 138 | 139 | for i, tag in enumerate(tags): 140 | result_dict[tag] = y[i] 141 | 142 | unsorted_tags_in_theshold = [] 143 | 
result_tags_print = []
144 | for tag in tags:
145 | if result_dict[tag] >= threshold:
146 | if tag.startswith("rating:"):
147 | continue
148 | unsorted_tags_in_theshold.append((result_dict[tag], tag))
149 | result_tags_print.append(f'{result_dict[tag]} {tag}')
150 |
151 | # sort tags
152 | result_tags_out = []
153 | sort_ndx = 0
154 | if alpha_sort:
155 | sort_ndx = 1
156 |
157 | # sort by likelihood in reverse (or alphabetically when alpha_sort is set), and format tag text as requested
158 | unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
159 | for weight, tag in unsorted_tags_in_theshold:
160 | tag_outformat = tag
161 | if use_spaces:
162 | tag_outformat = tag_outformat.replace('_', ' ')
163 | if use_escape:
164 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
165 | if include_ranks:
166 | tag_outformat = f"({tag_outformat}:{weight:.3f})"
167 |
168 | result_tags_out.append(tag_outformat)
169 |
170 | print('\n'.join(sorted(result_tags_print, reverse=True)))
171 |
172 | return ', '.join(result_tags_out)
173 |
--------------------------------------------------------------------------------
/javascript/ui.js:
--------------------------------------------------------------------------------
1 | // various functions for interaction with ui.py, not large enough to warrant putting them in separate files
2 |
3 | function set_theme(theme){
4 | gradioURL = window.location.href
5 | if (!gradioURL.includes('?__theme=')) {
6 | window.location.replace(gradioURL + '?__theme=' + theme);
7 | }
8 | }
9 |
10 | function selected_gallery_index(){
11 | var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
12 | var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
13 |
14 | var result = -1
15 | buttons.forEach(function(v, i){ if(v==button) { result = i } })
16 |
17 | return result
18 | }
19 |
20 | function extract_image_from_gallery(gallery){
21 | if(gallery.length == 1){
22 | return gallery[0]
23 | }
24 |
25 | index = selected_gallery_index()
26 |
27 | if (index < 0 || index >= gallery.length){
28 | return [null]
29 | }
30 |
31 | return gallery[index];
32 | }
33 |
34 | function args_to_array(args){
35 | res = []
36 | for(var i=0;i<args.length;i++){
37 | res.push(args[i])
38 | }
39 | return res
40 | }
41 |
193 | onUiUpdate(function(){
194 | if (!txt2img_textarea) {
195 | txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
196 | txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
197 | }
198 | if (!img2img_textarea) {
199 | img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
200 | img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
201 | }
202 | })
203 |
204 | let txt2img_textarea, img2img_textarea = undefined;
205 | let wait_time = 800
206 | let token_timeout;
207 |
208 | function update_txt2img_tokens(...args) {
209 | update_token_counter("txt2img_token_button")
210 | if (args.length == 2)
211 | return args[0]
212 | return args;
213 | }
214 |
215 | function update_img2img_tokens(...args) {
216 | update_token_counter("img2img_token_button")
217 | if (args.length == 2)
218 | return args[0]
219 | return args;
220 | }
221 |
222 | function update_token_counter(button_id) {
223 | if (token_timeout)
224 | clearTimeout(token_timeout);
225 | token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
226 | }
227 |
228 | function restart_reload(){
229 | document.body.innerHTML='<h1 style="font-family:monospace; margin-top:20%; color:lightgray; text-align:center;">Reloading...</h1>';
230 | setTimeout(function(){location.reload()},2000)
231 | }
232 |
--------------------------------------------------------------------------------
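The comment in load_upscalers() (modules/modelloader.py) explains the discovery trick: a subclass of Upscaler only appears in Upscaler.__subclasses__() after its defining module has been imported, so the loader imports every modules/*_model.py file first and only then walks the subclass registry. A minimal sketch of that pattern, with illustrative names (PluginBase, plugins_dir, package) standing in for the real classes and paths:

import importlib
import os


class PluginBase:
    pass


def load_plugins(plugins_dir: str, package: str = "modules"):
    # Importing a module is enough to register any PluginBase subclass it
    # defines, because Python records subclasses on the base class.
    for file in os.listdir(plugins_dir):
        if file.endswith("_model.py"):
            importlib.import_module(f"{package}.{file[:-3]}")
    # Only after those imports does __subclasses__() see the new classes.
    return [cls() for cls in PluginBase.__subclasses__()]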
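The inference() function in modules/swinir_model.py upscales tile by tile: each overlapping tile's output is summed into the buffer E while a mask of ones is summed into W, and E.div_(W) at the end averages the regions where tiles overlapped. A small NumPy sketch of the same accumulate-and-normalize idea, using an identity "model" and scale factor 1 so the result can be checked directly:

import numpy as np

def tiled_apply(img, tile=8, overlap=4):
    h, w = img.shape
    stride = tile - overlap
    # same index construction as inference(): regular strides plus a final
    # tile flush against the edge so the whole image is covered
    h_idx = list(range(0, h - tile, stride)) + [h - tile]
    w_idx = list(range(0, w - tile, stride)) + [w - tile]
    E = np.zeros_like(img, dtype=np.float64)  # summed tile outputs
    W = np.zeros_like(img, dtype=np.float64)  # how many tiles covered each pixel
    for hi in h_idx:
        for wi in w_idx:
            patch = img[hi:hi + tile, wi:wi + tile]  # "model" = identity here
            E[hi:hi + tile, wi:wi + tile] += patch
            W[hi:hi + tile, wi:wi + tile] += 1.0
    return E / W  # average wherever tiles overlapped

img = np.random.rand(20, 20)
assert np.allclose(tiled_apply(img), img)  # identity model reproduces the input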
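fix_model_layers() in modules/esrgan_model.py converts old-format ESRGAN checkpoints (keys like model.0.weight) to the names the current RRDBNet expects (conv_first.weight and friends), detecting the format from marker keys before renaming. A toy sketch of that detect-and-rename approach; the mapping here is a tiny illustrative subset, not the module's full table:

def translate_state_dict(old_sd: dict) -> dict:
    if "conv_first.weight" in old_sd:
        return old_sd  # already in the new format, nothing to do
    # rename old-format keys; anything not in the table keeps its name
    rename = {
        "model.0.weight": "conv_first.weight",
        "model.0.bias": "conv_first.bias",
        "model.10.weight": "conv_last.weight",
        "model.10.bias": "conv_last.bias",
    }
    return {rename.get(k, k): v for k, v in old_sd.items()}

old = {"model.0.weight": [1], "model.0.bias": [0]}
print(translate_state_dict(old))  # {'conv_first.weight': [1], 'conv_first.bias': [0]}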
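The docstring of create_deepbooru_process() in modules/deepbooru.py describes a simple inter-process protocol: images go to the worker through a Manager queue, results come back through a shared dict, -1 marks "result pending", and the string "QUIT" tells the worker to exit. A self-contained sketch of that handshake, with upper-casing standing in for the real model inference:

import multiprocessing
import time


def worker(queue, ret):
    while True:
        item = queue.get()
        if item == "QUIT":
            break  # sentinel message: shut the worker down
        ret["value"] = item.upper()  # placeholder for model inference


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    queue, ret = manager.Queue(), manager.dict()
    proc = multiprocessing.Process(target=worker, args=(queue, ret))
    proc.start()

    ret["value"] = -1          # -1 means "no result yet"
    queue.put("hello")
    while ret["value"] == -1:  # poll until the worker publishes a result
        time.sleep(0.2)
    print(ret["value"])        # -> HELLO

    queue.put("QUIT")
    proc.join()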