├── test
│   ├── __init__.py
│   ├── test_files
│   │   ├── mask_basic.png
│   │   └── img2img_basic.png
│   ├── server_poll.py
│   ├── extras_test.py
│   ├── img2img_test.py
│   ├── utils_test.py
│   └── txt2img_test.py
├── models
│   ├── VAE
│   │   └── Put VAE here.txt
│   ├── Stable-diffusion
│   │   └── Put Stable Diffusion checkpoints here.txt
│   └── deepbooru
│       └── Put your deepbooru release project folder here.txt
├── extensions
│   └── put extensions here.txt
├── localizations
│   └── Put localization files here.txt
├── textual_inversion_templates
│   ├── none.txt
│   ├── style.txt
│   ├── subject.txt
│   ├── hypernetwork.txt
│   ├── style_filewords.txt
│   └── subject_filewords.txt
├── embeddings
│   └── Place Textual Inversion embeddings here.txt
├── screenshot.png
├── txt2img_Screenshot.png
├── webui-user.bat
├── .pylintrc
├── modules
│   ├── textual_inversion
│   │   ├── test_embedding.png
│   │   ├── ui.py
│   │   ├── learn_schedule.py
│   │   └── dataset.py
│   ├── errors.py
│   ├── face_restoration.py
│   ├── ngrok.py
│   ├── artists.py
│   ├── localization.py
│   ├── safety.py
│   ├── paths.py
│   ├── ldsr_model.py
│   ├── hypernetworks
│   │   └── ui.py
│   ├── txt2img.py
│   ├── extensions.py
│   ├── memmon.py
│   ├── devices.py
│   ├── scunet_model.py
│   ├── masking.py
│   ├── lowvram.py
│   ├── styles.py
│   ├── gfpgan_model.py
│   ├── upscaler.py
│   ├── realesrgan_model.py
│   ├── safe.py
│   ├── img2img.py
│   ├── swinir_model.py
│   ├── codeformer_model.py
│   ├── modelloader.py
│   └── deepbooru.py
├── javascript
│   ├── textualInversion.js
│   ├── imageParams.js
│   ├── extensions.js
│   ├── imageMaskFix.js
│   ├── notification.js
│   ├── dragdrop.js
│   ├── edit-attention.js
│   ├── aspectRatioOverlay.js
│   ├── localization.js
│   ├── contextMenus.js
│   ├── progressbar.js
│   └── ui.js
├── environment-wsl2.yaml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   └── bug_report.yml
│   ├── workflows
│   │   └── on_pull_request.yaml
│   └── PULL_REQUEST_TEMPLATE
│       └── pull_request_template.md
├── requirements.txt
├── .gitignore
├── requirements_versions.txt
├── CODEOWNERS
├── scripts
│   ├── custom_code.py
│   ├── loopback.py
│   ├── prompt_matrix.py
│   ├── sd_upscale.py
│   ├── prompts_from_file.py
│   └── poor_mans_outpainting.py
├── webui-user.sh
├── webui.bat
├── README.md
├── repositories
│   └── stable-diffusion-taiyi
│       └── configs
│           └── stable-diffusion
│               ├── v1-inference.yaml
│               └── v1-inference-en.yaml
├── script.js
├── webui.sh
└── webui.py

/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/VAE/Put VAE here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/extensions/put extensions here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/localizations/Put localization files here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/textual_inversion_templates/none.txt:
--------------------------------------------------------------------------------
1 | picture
2 |
--------------------------------------------------------------------------------
/embeddings/Place Textual Inversion embeddings here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt:
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/deepbooru/Put your deepbooru release project folder here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/screenshot.png -------------------------------------------------------------------------------- /txt2img_Screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/txt2img_Screenshot.png -------------------------------------------------------------------------------- /test/test_files/mask_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/test/test_files/mask_basic.png -------------------------------------------------------------------------------- /webui-user.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON= 4 | set GIT= 5 | set VENV_DIR= 6 | set COMMANDLINE_ARGS= 7 | 8 | call webui.bat 9 | -------------------------------------------------------------------------------- /test/test_files/img2img_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/test/test_files/img2img_basic.png -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html 2 | [MESSAGES CONTROL] 3 | disable=C,R,W,E,I 4 | -------------------------------------------------------------------------------- /modules/textual_inversion/test_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/modules/textual_inversion/test_embedding.png -------------------------------------------------------------------------------- /javascript/textualInversion.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | function start_training_textual_inversion(){ 4 | requestProgress('ti') 5 | gradioApp().querySelector('#ti_error').innerHTML='' 6 | 7 | return args_to_array(arguments) 8 | } 9 | -------------------------------------------------------------------------------- /environment-wsl2.yaml: -------------------------------------------------------------------------------- 1 | name: automatic 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - pip=22.2.2 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: WebUI Community Support 4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions 5 | about: 
Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /modules/errors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | 4 | 5 | def run(code, task): 6 | try: 7 | code() 8 | except Exception as e: 9 | print(f"{task}: {type(e).__name__}", file=sys.stderr) 10 | print(traceback.format_exc(), file=sys.stderr) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | basicsr 2 | diffusers 3 | fairscale==0.4.4 4 | fonts 5 | font-roboto 6 | gfpgan 7 | gradio==3.8 8 | invisible-watermark 9 | numpy 10 | omegaconf 11 | opencv-python 12 | requests 13 | piexif 14 | Pillow 15 | pytorch_lightning==1.7.7 16 | realesrgan 17 | scikit-image>=0.19 18 | timm==0.4.12 19 | transformers==4.19.2 20 | torch 21 | einops 22 | jsonmerge 23 | clean-fid 24 | resize-right 25 | torchdiffeq 26 | kornia 27 | lark 28 | inflection 29 | GitPython 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.pth 4 | /ESRGAN/* 5 | /SwinIR/* 6 | /repositories 7 | /venv 8 | /tmp 9 | /model.ckpt 10 | /models/**/* 11 | /GFPGANv1.3.pth 12 | /gfpgan/weights/*.pth 13 | /ui-config.json 14 | /outputs 15 | /config.json 16 | /log 17 | /webui.settings.bat 18 | /embeddings 19 | /styles.csv 20 | /params.txt 21 | /styles.csv.bak 22 | /webui-user.bat 23 | /webui-user.sh 24 | /interrogate 25 | /user.css 26 | /.idea 27 | notification.mp3 28 | /SwinIR 29 | /textual_inversion 30 | .vscode 31 | /extensions 32 | /test/stdout.txt 33 | /test/stderr.txt 34 | -------------------------------------------------------------------------------- /requirements_versions.txt: -------------------------------------------------------------------------------- 1 | transformers==4.19.2 2 | diffusers==0.3.0 3 | basicsr==1.4.2 4 | gfpgan==1.3.8 5 | gradio==3.8 6 | numpy==1.23.3 7 | Pillow==9.2.0 8 | realesrgan==0.3.0 9 | torch 10 | omegaconf==2.2.3 11 | pytorch_lightning==1.7.6 12 | scikit-image==0.19.2 13 | fonts 14 | font-roboto 15 | timm==0.6.7 16 | fairscale==0.4.9 17 | piexif==1.1.3 18 | einops==0.4.1 19 | jsonmerge==1.8.0 20 | clean-fid==0.1.29 21 | resize-right==0.0.2 22 | torchdiffeq==0.2.3 23 | kornia==0.6.7 24 | lark==1.1.2 25 | inflection==0.5.1 26 | GitPython==3.1.27 27 | -------------------------------------------------------------------------------- /modules/face_restoration.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | 3 | 4 | class FaceRestoration: 5 | def name(self): 6 | return "None" 7 | 8 | def restore(self, np_image): 9 | return np_image 10 | 11 | 12 | def restore_faces(np_image): 13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] 14 | if len(face_restorers) == 0: 15 | return np_image 16 | 17 | face_restorer = face_restorers[0] 18 | 19 | return face_restorer.restore(np_image) 20 | -------------------------------------------------------------------------------- /test/server_poll.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | import time 4 | 5 | 6 | def run_tests(): 7 | 
timeout_threshold = 240 8 | start_time = time.time() 9 | while time.time()-start_time < timeout_threshold: 10 | try: 11 | requests.head("http://localhost:7860/") 12 | break 13 | except requests.exceptions.ConnectionError: 14 | pass 15 | if time.time()-start_time < timeout_threshold: 16 | suite = unittest.TestLoader().discover('', pattern='*_test.py') 17 | result = unittest.TextTestRunner(verbosity=2).run(suite) 18 | else: 19 | print("Launch unsuccessful") 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/style.txt: -------------------------------------------------------------------------------- 1 | a painting, art by [name] 2 | a rendering, art by [name] 3 | a cropped painting, art by [name] 4 | the painting, art by [name] 5 | a clean painting, art by [name] 6 | a dirty painting, art by [name] 7 | a dark painting, art by [name] 8 | a picture, art by [name] 9 | a cool painting, art by [name] 10 | a close-up painting, art by [name] 11 | a bright painting, art by [name] 12 | a cropped painting, art by [name] 13 | a good painting, art by [name] 14 | a close-up painting, art by [name] 15 | a rendition, art by [name] 16 | a nice painting, art by [name] 17 | a small painting, art by [name] 18 | a weird painting, art by [name] 19 | a large painting, art by [name] 20 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @AUTOMATIC1111 2 | 3 | # if you were managing a localization and were removed from this file, this is because 4 | # the intended way to do localizations now is via extensions. See: 5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions 6 | # Make a repo with your localization and since you are still listed as a collaborator 7 | # you can add it to the wiki page yourself. This change is because some people complained 8 | # the git commit log is cluttered with things unrelated to almost everyone and 9 | # because I believe this is the best overall for the project to handle localizations almost 10 | # entirely without my oversight. 11 | 12 | 13 | -------------------------------------------------------------------------------- /modules/ngrok.py: -------------------------------------------------------------------------------- 1 | from pyngrok import ngrok, conf, exception 2 | 3 | 4 | def connect(token, port, region): 5 | if token == None: 6 | token = 'None' 7 | config = conf.PyngrokConfig( 8 | auth_token=token, region=region 9 | ) 10 | try: 11 | public_url = ngrok.connect(port, pyngrok_config=config).public_url 12 | except exception.PyngrokNgrokError: 13 | print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' 14 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') 15 | else: 16 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' 17 | 'You can use this link after the launch is complete.') 18 | -------------------------------------------------------------------------------- /javascript/imageParams.js: -------------------------------------------------------------------------------- 1 | window.onload = (function(){ 2 | window.addEventListener('drop', e => { 3 | const target = e.composedPath()[0]; 4 | const idx = selected_gallery_index(); 5 | if (target.placeholder.indexOf("Prompt") == -1) return; 6 | 7 | let prompt_target = get_tab_index('tabs') == 1 ? 
"img2img_prompt_image" : "txt2img_prompt_image"; 8 | 9 | e.stopPropagation(); 10 | e.preventDefault(); 11 | const imgParent = gradioApp().getElementById(prompt_target); 12 | const files = e.dataTransfer.files; 13 | const fileInput = imgParent.querySelector('input[type="file"]'); 14 | if ( fileInput ) { 15 | fileInput.files = files; 16 | fileInput.dispatchEvent(new Event('change')); 17 | } 18 | }); 19 | }); 20 | -------------------------------------------------------------------------------- /modules/artists.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import csv 3 | from collections import namedtuple 4 | 5 | Artist = namedtuple("Artist", ['name', 'weight', 'category']) 6 | 7 | 8 | class ArtistsDatabase: 9 | def __init__(self, filename): 10 | self.cats = set() 11 | self.artists = [] 12 | 13 | if not os.path.exists(filename): 14 | return 15 | 16 | with open(filename, "r", newline='', encoding="utf8") as file: 17 | reader = csv.DictReader(file) 18 | 19 | for row in reader: 20 | artist = Artist(row["artist"], float(row["score"]), row["category"]) 21 | self.artists.append(artist) 22 | self.cats.add(artist.category) 23 | 24 | def categories(self): 25 | return sorted(self.cats) 26 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name] 2 | a rendering of a [name] 3 | a cropped photo of the [name] 4 | the photo of a [name] 5 | a photo of a clean [name] 6 | a photo of a dirty [name] 7 | a dark photo of the [name] 8 | a photo of my [name] 9 | a photo of the cool [name] 10 | a close-up photo of a [name] 11 | a bright photo of the [name] 12 | a cropped photo of a [name] 13 | a photo of the [name] 14 | a good photo of the [name] 15 | a photo of one [name] 16 | a close-up photo of the [name] 17 | a rendition of the [name] 18 | a photo of the clean [name] 19 | a rendition of a [name] 20 | a photo of a nice [name] 21 | a good photo of a [name] 22 | a photo of the nice [name] 23 | a photo of the small [name] 24 | a photo of the weird [name] 25 | a photo of the large [name] 26 | a photo of a cool [name] 27 | a photo of a small [name] 28 | -------------------------------------------------------------------------------- /test/extras_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | 4 | class TestExtrasWorking(unittest.TestCase): 5 | def setUp(self): 6 | self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" 7 | self.simple_extras = { 8 | "resize_mode": 0, 9 | "show_extras_results": True, 10 | "gfpgan_visibility": 0, 11 | "codeformer_visibility": 0, 12 | "codeformer_weight": 0, 13 | "upscaling_resize": 2, 14 | "upscaling_resize_w": 512, 15 | "upscaling_resize_h": 512, 16 | "upscaling_crop": True, 17 | "upscaler_1": "None", 18 | "upscaler_2": "None", 19 | "extras_upscaler_2_visibility": 0, 20 | "image": "" 21 | } 22 | 23 | 24 | class TestExtrasCorrectness(unittest.TestCase): 25 | pass 26 | 27 | 28 | if __name__ == "__main__": 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /textual_inversion_templates/hypernetwork.txt: -------------------------------------------------------------------------------- 1 | a photo of a [filewords] 2 | a rendering of a [filewords] 3 | a cropped photo of the [filewords] 4 | the photo of a [filewords] 5 | a 
photo of a clean [filewords] 6 | a photo of a dirty [filewords] 7 | a dark photo of the [filewords] 8 | a photo of my [filewords] 9 | a photo of the cool [filewords] 10 | a close-up photo of a [filewords] 11 | a bright photo of the [filewords] 12 | a cropped photo of a [filewords] 13 | a photo of the [filewords] 14 | a good photo of the [filewords] 15 | a photo of one [filewords] 16 | a close-up photo of the [filewords] 17 | a rendition of the [filewords] 18 | a photo of the clean [filewords] 19 | a rendition of a [filewords] 20 | a photo of a nice [filewords] 21 | a good photo of a [filewords] 22 | a photo of the nice [filewords] 23 | a photo of the small [filewords] 24 | a photo of the weird [filewords] 25 | a photo of the large [filewords] 26 | a photo of a cool [filewords] 27 | a photo of a small [filewords] 28 | -------------------------------------------------------------------------------- /textual_inversion_templates/style_filewords.txt: -------------------------------------------------------------------------------- 1 | a painting of [filewords], art by [name] 2 | a rendering of [filewords], art by [name] 3 | a cropped painting of [filewords], art by [name] 4 | the painting of [filewords], art by [name] 5 | a clean painting of [filewords], art by [name] 6 | a dirty painting of [filewords], art by [name] 7 | a dark painting of [filewords], art by [name] 8 | a picture of [filewords], art by [name] 9 | a cool painting of [filewords], art by [name] 10 | a close-up painting of [filewords], art by [name] 11 | a bright painting of [filewords], art by [name] 12 | a cropped painting of [filewords], art by [name] 13 | a good painting of [filewords], art by [name] 14 | a close-up painting of [filewords], art by [name] 15 | a rendition of [filewords], art by [name] 16 | a nice painting of [filewords], art by [name] 17 | a small painting of [filewords], art by [name] 18 | a weird painting of [filewords], art by [name] 19 | a large painting of [filewords], art by [name] 20 | -------------------------------------------------------------------------------- /javascript/extensions.js: -------------------------------------------------------------------------------- 1 | 2 | function extensions_apply(_, _){ 3 | disable = [] 4 | update = [] 5 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ 6 | if(x.name.startsWith("enable_") && ! x.checked) 7 | disable.push(x.name.substr(7)) 8 | 9 | if(x.name.startsWith("update_") && x.checked) 10 | update.push(x.name.substr(7)) 11 | }) 12 | 13 | restart_reload() 14 | 15 | return [JSON.stringify(disable), JSON.stringify(update)] 16 | } 17 | 18 | function extensions_check(){ 19 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ 20 | x.innerHTML = "Loading..." 21 | }) 22 | 23 | return [] 24 | } 25 | 26 | function install_extension_from_index(button, url){ 27 | button.disabled = "disabled" 28 | button.value = "Installing..." 
29 | 30 | textarea = gradioApp().querySelector('#extension_to_install textarea') 31 | textarea.value = url 32 | textarea.dispatchEvent(new Event("input", { bubbles: true })) 33 | 34 | gradioApp().querySelector('#install_extension_button').click() 35 | } 36 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import traceback 5 | 6 | 7 | localizations = {} 8 | 9 | 10 | def list_localizations(dirname): 11 | localizations.clear() 12 | 13 | for file in os.listdir(dirname): 14 | fn, ext = os.path.splitext(file) 15 | if ext.lower() != ".json": 16 | continue 17 | 18 | localizations[fn] = os.path.join(dirname, file) 19 | 20 | from modules import scripts 21 | for file in scripts.list_scripts("localizations", ".json"): 22 | fn, ext = os.path.splitext(file.filename) 23 | localizations[fn] = file.path 24 | 25 | 26 | def localization_js(current_localization_name): 27 | fn = localizations.get(current_localization_name, None) 28 | data = {} 29 | if fn is not None: 30 | try: 31 | with open(fn, "r", encoding="utf8") as file: 32 | data = json.load(file) 33 | except Exception: 34 | print(f"Error loading localization from {fn}:", file=sys.stderr) 35 | print(traceback.format_exc(), file=sys.stderr) 36 | 37 | return f"var localization = {json.dumps(data)}\n" 38 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject_filewords.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name], [filewords] 2 | a rendering of a [name], [filewords] 3 | a cropped photo of the [name], [filewords] 4 | the photo of a [name], [filewords] 5 | a photo of a clean [name], [filewords] 6 | a photo of a dirty [name], [filewords] 7 | a dark photo of the [name], [filewords] 8 | a photo of my [name], [filewords] 9 | a photo of the cool [name], [filewords] 10 | a close-up photo of a [name], [filewords] 11 | a bright photo of the [name], [filewords] 12 | a cropped photo of a [name], [filewords] 13 | a photo of the [name], [filewords] 14 | a good photo of the [name], [filewords] 15 | a photo of one [name], [filewords] 16 | a close-up photo of the [name], [filewords] 17 | a rendition of the [name], [filewords] 18 | a photo of the clean [name], [filewords] 19 | a rendition of a [name], [filewords] 20 | a photo of a nice [name], [filewords] 21 | a good photo of a [name], [filewords] 22 | a photo of the nice [name], [filewords] 23 | a photo of the small [name], [filewords] 24 | a photo of the weird [name], [filewords] 25 | a photo of the large [name], [filewords] 26 | a photo of a cool [name], [filewords] 27 | a photo of a small [name], [filewords] 28 | -------------------------------------------------------------------------------- /scripts/custom_code.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | 4 | from modules.processing import Processed 5 | from modules.shared import opts, cmd_opts, state 6 | 7 | class Script(scripts.Script): 8 | 9 | def title(self): 10 | return "Custom code" 11 | 12 | 13 | def show(self, is_img2img): 14 | return cmd_opts.allow_code 15 | 16 | def ui(self, is_img2img): 17 | code = gr.Textbox(label="Python code", lines=1) 18 | 19 | return [code] 20 | 21 | 22 | def run(self, p, code): 23 | assert cmd_opts.allow_code, 
'--allow-code option must be enabled' 24 | 25 | display_result_data = [[], -1, ""] 26 | 27 | def display(imgs, s=display_result_data[1], i=display_result_data[2]): 28 | display_result_data[0] = imgs 29 | display_result_data[1] = s 30 | display_result_data[2] = i 31 | 32 | from types import ModuleType 33 | compiled = compile(code, '', 'exec') 34 | module = ModuleType("testmodule") 35 | module.__dict__.update(globals()) 36 | module.p = p 37 | module.display = display 38 | exec(compiled, module.__dict__) 39 | 40 | return Processed(p, *display_result_data) 41 | 42 | -------------------------------------------------------------------------------- /webui-user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ######################################################### 3 | # Uncomment and change the variables below to your need:# 4 | ######################################################### 5 | 6 | # Install directory without trailing slash 7 | #install_dir="/home/$(whoami)" 8 | 9 | # Name of the subdirectory 10 | #clone_dir="stable-diffusion-webui" 11 | 12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" 13 | export COMMANDLINE_ARGS="" 14 | 15 | # python3 executable 16 | #python_cmd="python3" 17 | 18 | # git executable 19 | #export GIT="git" 20 | 21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 22 | #venv_dir="venv" 23 | 24 | # script to launch to start the app 25 | #export LAUNCH_SCRIPT="launch.py" 26 | 27 | # install command for torch 28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" 29 | 30 | # Requirements file to use for stable-diffusion-webui 31 | #export REQS_FILE="requirements_versions.txt" 32 | 33 | # Fixed git repos 34 | #export K_DIFFUSION_PACKAGE="" 35 | #export GFPGAN_PACKAGE="" 36 | 37 | # Fixed git commits 38 | #export STABLE_DIFFUSION_COMMIT_HASH="" 39 | #export TAMING_TRANSFORMERS_COMMIT_HASH="" 40 | #export CODEFORMER_COMMIT_HASH="" 41 | #export BLIP_COMMIT_HASH="" 42 | 43 | ########################################### 44 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | title: "[Feature Request]: " 4 | labels: ["suggestion"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do ? 
22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here. 41 | -------------------------------------------------------------------------------- /javascript/imageMaskFix.js: -------------------------------------------------------------------------------- 1 | /** 2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 3 | * @see https://github.com/gradio-app/gradio/issues/1721 4 | */ 5 | window.addEventListener( 'resize', () => imageMaskResize()); 6 | function imageMaskResize() { 7 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); 8 | if ( ! canvases.length ) { 9 | canvases_fixed = false; 10 | window.removeEventListener( 'resize', imageMaskResize ); 11 | return; 12 | } 13 | 14 | const wrapper = canvases[0].closest('.touch-none'); 15 | const previewImage = wrapper.previousElementSibling; 16 | 17 | if ( ! previewImage.complete ) { 18 | previewImage.addEventListener( 'load', () => imageMaskResize()); 19 | return; 20 | } 21 | 22 | const w = previewImage.width; 23 | const h = previewImage.height; 24 | const nw = previewImage.naturalWidth; 25 | const nh = previewImage.naturalHeight; 26 | const portrait = nh > nw; 27 | const factor = portrait; 28 | 29 | const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); 30 | const wH = Math.min(h, portrait ? 
h/nh*nh : w/nw*nh); 31 | 32 | wrapper.style.width = `${wW}px`; 33 | wrapper.style.height = `${wH}px`; 34 | wrapper.style.left = `0px`; 35 | wrapper.style.top = `0px`; 36 | 37 | canvases.forEach( c => { 38 | c.style.width = c.style.height = ''; 39 | c.style.maxWidth = '100%'; 40 | c.style.maxHeight = '100%'; 41 | c.style.objectFit = 'contain'; 42 | }); 43 | } 44 | 45 | onUiUpdate(() => imageMaskResize()); 46 | -------------------------------------------------------------------------------- /.github/workflows/on_pull_request.yaml: -------------------------------------------------------------------------------- 1 | # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file 2 | name: Run Linting/Formatting on Pull Requests 3 | 4 | on: 5 | - push 6 | - pull_request 7 | # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs 8 | # if you want to filter out branches, delete the `- pull_request` and uncomment these lines : 9 | # pull_request: 10 | # branches: 11 | # - master 12 | # branches-ignore: 13 | # - development 14 | 15 | jobs: 16 | lint: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout Code 20 | uses: actions/checkout@v3 21 | - name: Set up Python 3.10 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: 3.10.6 25 | - uses: actions/cache@v2 26 | with: 27 | path: ~/.cache/pip 28 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 29 | restore-keys: | 30 | ${{ runner.os }}-pip- 31 | - name: Install PyLint 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install pylint 35 | # This lets PyLint check to see if it can resolve imports 36 | - name: Install dependencies 37 | run : | 38 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" 39 | python launch.py 40 | - name: Analysing the code with pylint 41 | run: | 42 | pylint $(git ls-files '*.py') 43 | -------------------------------------------------------------------------------- /modules/textual_inversion/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | 5 | import modules.textual_inversion.textual_inversion 6 | import modules.textual_inversion.preprocess 7 | from modules import sd_hijack, shared 8 | 9 | 10 | def create_embedding(name, initialization_text, nvpt, overwrite_old): 11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) 12 | 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() 14 | 15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" 16 | 17 | 18 | def preprocess(*args): 19 | modules.textual_inversion.preprocess.preprocess(*args) 20 | 21 | return "Preprocessing finished.", "" 22 | 23 | 24 | def train_embedding(*args): 25 | 26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' 27 | 28 | apply_optimizations = shared.opts.training_xattention_optimizations 29 | try: 30 | if not apply_optimizations: 31 | sd_hijack.undo_optimizations() 32 | 33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) 34 | 35 | res = f""" 36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 
37 | Embedding saved to {html.escape(filename)} 38 | """ 39 | return res, "" 40 | except Exception: 41 | raise 42 | finally: 43 | if not apply_optimizations: 44 | sd_hijack.apply_optimizations() 45 | 46 | -------------------------------------------------------------------------------- /modules/safety.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 3 | from transformers import AutoFeatureExtractor 4 | from PIL import Image 5 | 6 | import modules.shared as shared 7 | 8 | safety_model_id = "CompVis/stable-diffusion-safety-checker" 9 | safety_feature_extractor = None 10 | safety_checker = None 11 | 12 | def numpy_to_pil(images): 13 | """ 14 | Convert a numpy image or a batch of images to a PIL image. 15 | """ 16 | if images.ndim == 3: 17 | images = images[None, ...] 18 | images = (images * 255).round().astype("uint8") 19 | pil_images = [Image.fromarray(image) for image in images] 20 | 21 | return pil_images 22 | 23 | # check and replace nsfw content 24 | def check_safety(x_image): 25 | global safety_feature_extractor, safety_checker 26 | 27 | if safety_feature_extractor is None: 28 | safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) 29 | safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) 30 | 31 | safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") 32 | x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) 33 | 34 | return x_checked_image, has_nsfw_concept 35 | 36 | 37 | def censor_batch(x): 38 | x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() 39 | x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) 40 | x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) 41 | 42 | return x 43 | -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not defined PYTHON (set PYTHON=python) 4 | if not defined VENV_DIR (set VENV_DIR=venv) 5 | 6 | set ERROR_REPORTING=FALSE 7 | 8 | mkdir tmp 2>NUL 9 | 10 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 11 | if %ERRORLEVEL% == 0 goto :start_venv 12 | echo Couldn't launch python 13 | goto :show_stdout_stderr 14 | 15 | :start_venv 16 | if [%VENV_DIR%] == [-] goto :skip_venv 17 | 18 | dir %VENV_DIR%\Scripts\Python.exe >tmp/stdout.txt 2>tmp/stderr.txt 19 | if %ERRORLEVEL% == 0 goto :activate_venv 20 | 21 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 22 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 23 | %PYTHON_FULLNAME% -m venv %VENV_DIR% >tmp/stdout.txt 2>tmp/stderr.txt 24 | if %ERRORLEVEL% == 0 goto :activate_venv 25 | echo Unable to create venv in directory %VENV_DIR% 26 | goto :show_stdout_stderr 27 | 28 | :activate_venv 29 | set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe" 30 | echo venv %PYTHON% 31 | goto :launch 32 | 33 | :skip_venv 34 | 35 | :launch 36 | %PYTHON% launch.py %* 37 | pause 38 | exit /b 39 | 40 | :show_stdout_stderr 41 | 42 | echo. 43 | echo exit code: %errorlevel% 44 | 45 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 46 | if %size% equ 0 goto :show_stderr 47 | echo. 
48 | echo stdout: 49 | type tmp\stdout.txt 50 | 51 | :show_stderr 52 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 53 | if %size% equ 0 goto :show_stderr 54 | echo. 55 | echo stderr: 56 | type tmp\stderr.txt 57 | 58 | :endofscript 59 | 60 | echo. 61 | echo Launch unsuccessful. Exiting. 62 | pause 63 | -------------------------------------------------------------------------------- /javascript/notification.js: -------------------------------------------------------------------------------- 1 | // Monitors the gallery and sends a browser notification when the leading image is new. 2 | 3 | let lastHeadImg = null; 4 | 5 | notificationButton = null 6 | 7 | onUiUpdate(function(){ 8 | if(notificationButton == null){ 9 | notificationButton = gradioApp().getElementById('request_notifications') 10 | 11 | if(notificationButton != null){ 12 | notificationButton.addEventListener('click', function (evt) { 13 | Notification.requestPermission(); 14 | },true); 15 | } 16 | } 17 | 18 | const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden'); 19 | 20 | if (galleryPreviews == null) return; 21 | 22 | const headImg = galleryPreviews[0]?.src; 23 | 24 | if (headImg == null || headImg == lastHeadImg) return; 25 | 26 | lastHeadImg = headImg; 27 | 28 | // play notification sound if available 29 | gradioApp().querySelector('#audio_notification audio')?.play(); 30 | 31 | if (document.hasFocus()) return; 32 | 33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. 34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src)); 35 | 36 | const notification = new Notification( 37 | 'Stable Diffusion', 38 | { 39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 
's' : ''}`, 40 | icon: headImg, 41 | image: headImg, 42 | } 43 | ); 44 | 45 | notification.onclick = function(_){ 46 | parent.focus(); 47 | this.close(); 48 | }; 49 | }); 50 | -------------------------------------------------------------------------------- /modules/paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import modules.safe 5 | 6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 7 | models_path = os.path.join(script_path, "models") 8 | sys.path.insert(0, script_path) 9 | 10 | # search for directory of stable diffusion in following places 11 | sd_path = None 12 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)] 13 | for possible_sd_path in possible_sd_paths: 14 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 15 | sd_path = os.path.abspath(possible_sd_path) 16 | break 17 | 18 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 19 | 20 | path_dirs = [ 21 | (sd_path, 'ldm', 'Stable Diffusion', []), 22 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), 23 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), 24 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), 25 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 26 | ] 27 | 28 | paths = {} 29 | 30 | for d, must_exist, what, options in path_dirs: 31 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 32 | if not os.path.exists(must_exist_path): 33 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 34 | else: 35 | d = os.path.abspath(d) 36 | if "atstart" in options: 37 | sys.path.insert(0, d) 38 | else: 39 | sys.path.append(d) 40 | paths[what] = d 41 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! 2 | 3 | If you have a large change, pay special attention to this paragraph: 4 | 5 | > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. 6 | 7 | Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. 8 | 9 | **Describe what this pull request is trying to achieve.** 10 | 11 | A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. 12 | 13 | **Additional notes and description of your changes** 14 | 15 | More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. 16 | 17 | **Environment this was tested in** 18 | 19 | List the environment you have developed / tested this on. 
As per the contributing page, changes should be able to work on Windows out of the box. 20 | - OS: [e.g. Windows, Linux] 21 | - Browser [e.g. chrome, safari] 22 | - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] 23 | 24 | **Screenshots or videos of your changes** 25 | 26 | If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. 27 | 28 | This is **required** for anything that touches the user interface. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Taiyi stable-diffusion-webui 2 | Stable Diffusion web UI for Taiyi 3 | 4 | Make sure the requirement at least, very helpful. 5 | 6 | - transformers>=4.24.0 7 | - diffusers>=0.7.2 8 | 9 | ## step 1 10 | 11 | Since Taiyi's text_encoder has been modified (BertModel vs CLIPTextModel), and webui currently only supports stable diffusion in English, it is necessary to use the webui project modified by Fengshenbang's own fork. 12 | 13 | ``` 14 | git clone https://github.com/IDEA-CCNL/stable-diffusion-webui.git 15 | cd stable-diffusion-webui 16 | ``` 17 | 18 | ## step 2 19 | 20 | Run webui's own commands to check and install the environment, webui will pull down the required repositories in the stable-diffusion-webui/repositories directory, this process will take some time. 21 | 22 | ``` 23 | bash webui.sh 24 | ``` 25 | 26 | This script will then automatically download the required files, back up the original v1_inference.yaml file and replace it with the version needed to start our [taiyi model](https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 ). 27 | 28 | **Notice that**, if you choose to redownload the Taiyi model, the total size of all the files needed is over 10G, the step: 29 | "Cloning taiyi_model into repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1..." will take lots of time, please be patient. 30 | 31 | If you have already downloaded our whole Taiyi model in [https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1]() once, follow the path checker and choose "(2)move your downloaded Taiyi model path?", and move your downloaded model folder to ./repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1. 32 | 33 | After all the progress is done, the web-ui service will be started on port 12345. 34 | 35 | 36 |
37 | 38 | You can run the following command to start the web-ui service. 39 | 40 | ``` 41 | bash webui.sh --listen --port 12345 42 | bash webui.sh --ckpt repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/Taiyi-Stable-Diffusion-1B-Chinese-v0.1.ckpt --listen --port 12345 43 | ``` 44 | -------------------------------------------------------------------------------- /test/img2img_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | from gradio.processing_utils import encode_pil_to_base64 4 | from PIL import Image 5 | 6 | 7 | class TestImg2ImgWorking(unittest.TestCase): 8 | def setUp(self): 9 | self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" 10 | self.simple_img2img = { 11 | "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], 12 | "resize_mode": 0, 13 | "denoising_strength": 0.75, 14 | "mask": None, 15 | "mask_blur": 4, 16 | "inpainting_fill": 0, 17 | "inpaint_full_res": False, 18 | "inpaint_full_res_padding": 0, 19 | "inpainting_mask_invert": 0, 20 | "prompt": "example prompt", 21 | "styles": [], 22 | "seed": -1, 23 | "subseed": -1, 24 | "subseed_strength": 0, 25 | "seed_resize_from_h": -1, 26 | "seed_resize_from_w": -1, 27 | "batch_size": 1, 28 | "n_iter": 1, 29 | "steps": 3, 30 | "cfg_scale": 7, 31 | "width": 64, 32 | "height": 64, 33 | "restore_faces": False, 34 | "tiling": False, 35 | "negative_prompt": "", 36 | "eta": 0, 37 | "s_churn": 0, 38 | "s_tmax": 0, 39 | "s_tmin": 0, 40 | "s_noise": 1, 41 | "override_settings": {}, 42 | "sampler_index": "Euler a", 43 | "include_init_images": False 44 | } 45 | 46 | def test_img2img_simple_performed(self): 47 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 48 | 49 | def test_inpainting_masked_performed(self): 50 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) 51 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 52 | 53 | 54 | class TestImg2ImgCorrectness(unittest.TestCase): 55 | pass 56 | 57 | 58 | if __name__ == "__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /modules/ldsr_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | from modules.upscaler import Upscaler, UpscalerData 8 | from modules.ldsr_model_arch import LDSR 9 | from modules import shared 10 | 11 | 12 | class UpscalerLDSR(Upscaler): 13 | def __init__(self, user_path): 14 | self.name = "LDSR" 15 | self.user_path = user_path 16 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" 17 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" 18 | super().__init__() 19 | scaler_data = UpscalerData("LDSR", None, self) 20 | self.scalers = [scaler_data] 21 | 22 | def load_model(self, path: str): 23 | # Remove incorrect project.yaml file if too big 24 | yaml_path = os.path.join(self.model_path, "project.yaml") 25 | old_model_path = os.path.join(self.model_path, "model.pth") 26 | new_model_path = os.path.join(self.model_path, "model.ckpt") 27 | if os.path.exists(yaml_path): 28 | statinfo = os.stat(yaml_path) 29 | if statinfo.st_size >= 10485760: 30 | print("Removing invalid LDSR YAML file.") 31 | os.remove(yaml_path) 32 | if 
os.path.exists(old_model_path): 33 | print("Renaming model from model.pth to model.ckpt") 34 | os.rename(old_model_path, new_model_path) 35 | model = load_file_from_url(url=self.model_url, model_dir=self.model_path, 36 | file_name="model.ckpt", progress=True) 37 | yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, 38 | file_name="project.yaml", progress=True) 39 | 40 | try: 41 | return LDSR(model, yaml) 42 | 43 | except Exception: 44 | print("Error importing LDSR:", file=sys.stderr) 45 | print(traceback.format_exc(), file=sys.stderr) 46 | return None 47 | 48 | def do_upscale(self, img, path): 49 | ldsr = self.load_model(path) 50 | if ldsr is None: 51 | print("NO LDSR!") 52 | return img 53 | ddim_steps = shared.opts.ldsr_steps 54 | return ldsr.super_resolution(img, ddim_steps, self.scale) 55 | -------------------------------------------------------------------------------- /repositories/stable-diffusion-taiyi/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | # cond_stage_config: 70 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.TaiyiCLIPEmbedder 73 | params: 74 | # you can git clone the model and change the version to your local model path 75 | version: your_path/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 76 | max_length: 512 77 | 78 | -------------------------------------------------------------------------------- /repositories/stable-diffusion-taiyi/configs/stable-diffusion/v1-inference-en.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | # cond_stage_config: 70 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | cond_stage_config: 72 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 73 | params: 74 | # you can git clone the model and change the version to your local model path 75 | version: your_path/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1 76 | max_length: 77 77 | 78 | -------------------------------------------------------------------------------- /modules/hypernetworks/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | import os 3 | import re 4 | 5 | import gradio as gr 6 | import modules.textual_inversion.preprocess 7 | import modules.textual_inversion.textual_inversion 8 | from modules import devices, sd_hijack, shared 9 | from modules.hypernetworks import hypernetwork 10 | 11 | not_available = ["hardswish", "multiheadattention"] 12 | keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) 13 | 14 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): 15 | # Remove illegal characters from name. 16 | name = "".join( x for x in name if (x.isalnum() or x in "._- ")) 17 | 18 | fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") 19 | if not overwrite_old: 20 | assert not os.path.exists(fn), f"file {fn} already exists" 21 | 22 | if type(layer_structure) == str: 23 | layer_structure = [float(x.strip()) for x in layer_structure.split(",")] 24 | 25 | hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( 26 | name=name, 27 | enable_sizes=[int(x) for x in enable_sizes], 28 | layer_structure=layer_structure, 29 | activation_func=activation_func, 30 | weight_init=weight_init, 31 | add_layer_norm=add_layer_norm, 32 | use_dropout=use_dropout, 33 | ) 34 | hypernet.save(fn) 35 | 36 | shared.reload_hypernetworks() 37 | 38 | return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" 39 | 40 | 41 | def train_hypernetwork(*args): 42 | 43 | initial_hypernetwork = shared.loaded_hypernetwork 44 | 45 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 46 | 47 | try: 48 | sd_hijack.undo_optimizations() 49 | 50 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) 51 | 52 | res = f""" 53 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 
54 | Hypernetwork saved to {html.escape(filename)} 55 | """ 56 | return res, "" 57 | except Exception: 58 | raise 59 | finally: 60 | shared.loaded_hypernetwork = initial_hypernetwork 61 | shared.sd_model.cond_stage_model.to(devices.device) 62 | shared.sd_model.first_stage_model.to(devices.device) 63 | sd_hijack.apply_optimizations() 64 | 65 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | function gradioApp(){ 2 | return document.getElementsByTagName('gradio-app')[0].shadowRoot; 3 | } 4 | 5 | function get_uiCurrentTab() { 6 | return gradioApp().querySelector('.tabs button:not(.border-transparent)') 7 | } 8 | 9 | function get_uiCurrentTabContent() { 10 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') 11 | } 12 | 13 | uiUpdateCallbacks = [] 14 | uiTabChangeCallbacks = [] 15 | let uiCurrentTab = null 16 | 17 | function onUiUpdate(callback){ 18 | uiUpdateCallbacks.push(callback) 19 | } 20 | function onUiTabChange(callback){ 21 | uiTabChangeCallbacks.push(callback) 22 | } 23 | 24 | function runCallback(x, m){ 25 | try { 26 | x(m) 27 | } catch (e) { 28 | (console.error || console.log).call(console, e.message, e); 29 | } 30 | } 31 | function executeCallbacks(queue, m) { 32 | queue.forEach(function(x){runCallback(x, m)}) 33 | } 34 | 35 | document.addEventListener("DOMContentLoaded", function() { 36 | var mutationObserver = new MutationObserver(function(m){ 37 | executeCallbacks(uiUpdateCallbacks, m); 38 | const newTab = get_uiCurrentTab(); 39 | if ( newTab && ( newTab !== uiCurrentTab ) ) { 40 | uiCurrentTab = newTab; 41 | executeCallbacks(uiTabChangeCallbacks); 42 | } 43 | }); 44 | mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) 45 | }); 46 | 47 | /** 48 | * Add a ctrl+enter as a shortcut to start a generation 49 | */ 50 | document.addEventListener('keydown', function(e) { 51 | var handled = false; 52 | if (e.key !== undefined) { 53 | if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 54 | } else if (e.keyCode !== undefined) { 55 | if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 56 | } 57 | if (handled) { 58 | button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); 59 | if (button) { 60 | button.click(); 61 | } 62 | e.preventDefault(); 63 | } 64 | }) 65 | 66 | /** 67 | * checks that a UI element is not in another hidden element or tab content 68 | */ 69 | function uiElementIsVisible(el) { 70 | let isVisible = !el.closest('.\\!hidden'); 71 | if ( ! isVisible ) { 72 | return false; 73 | } 74 | 75 | while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) { 76 | if ( ! 
isVisible ) { 77 | return false; 78 | } else if ( el.parentElement ) { 79 | el = el.parentElement 80 | } else { 81 | break; 82 | } 83 | } 84 | return isVisible; 85 | } -------------------------------------------------------------------------------- /modules/txt2img.py: -------------------------------------------------------------------------------- 1 | import modules.scripts 2 | from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ 3 | StableDiffusionProcessingImg2Img, process_images 4 | from modules.shared import opts, cmd_opts 5 | import modules.shared as shared 6 | import modules.processing as processing 7 | from modules.ui import plaintext_to_html 8 | 9 | 10 | def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): 11 | p = StableDiffusionProcessingTxt2Img( 12 | sd_model=shared.sd_model, 13 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, 14 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, 15 | prompt=prompt, 16 | styles=[prompt_style, prompt_style2], 17 | negative_prompt=negative_prompt, 18 | seed=seed, 19 | subseed=subseed, 20 | subseed_strength=subseed_strength, 21 | seed_resize_from_h=seed_resize_from_h, 22 | seed_resize_from_w=seed_resize_from_w, 23 | seed_enable_extras=seed_enable_extras, 24 | sampler_index=sampler_index, 25 | batch_size=batch_size, 26 | n_iter=n_iter, 27 | steps=steps, 28 | cfg_scale=cfg_scale, 29 | width=width, 30 | height=height, 31 | restore_faces=restore_faces, 32 | tiling=tiling, 33 | enable_hr=enable_hr, 34 | denoising_strength=denoising_strength if enable_hr else None, 35 | firstphase_width=firstphase_width if enable_hr else None, 36 | firstphase_height=firstphase_height if enable_hr else None, 37 | ) 38 | 39 | p.scripts = modules.scripts.scripts_txt2img 40 | p.script_args = args 41 | 42 | if cmd_opts.enable_console_prompts: 43 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) 44 | 45 | processed = modules.scripts.scripts_txt2img.run(p, *args) 46 | 47 | if processed is None: 48 | processed = process_images(p) 49 | 50 | p.close() 51 | 52 | shared.total_tqdm.clear() 53 | 54 | generation_info_js = processed.js() 55 | if opts.samples_log_stdout: 56 | print(generation_info_js) 57 | 58 | if opts.do_not_show_images: 59 | processed.images = [] 60 | 61 | return processed.images, generation_info_js, plaintext_to_html(processed.info) 62 | -------------------------------------------------------------------------------- /modules/extensions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import git 6 | 7 | from modules import paths, shared 8 | 9 | 10 | extensions = [] 11 | extensions_dir = os.path.join(paths.script_path, "extensions") 12 | 13 | 14 | def active(): 15 | return [x for x in extensions if x.enabled] 16 | 17 | 18 | class Extension: 19 | def __init__(self, name, path, enabled=True): 20 | self.name = name 21 | self.path = path 22 | self.enabled = enabled 23 | self.status = '' 24 | self.can_update = False 25 | 26 | repo = None 27 | try: 28 | if os.path.exists(os.path.join(path, 
".git")): 29 | repo = git.Repo(path) 30 | except Exception: 31 | print(f"Error reading github repository info from {path}:", file=sys.stderr) 32 | print(traceback.format_exc(), file=sys.stderr) 33 | 34 | if repo is None or repo.bare: 35 | self.remote = None 36 | else: 37 | try: 38 | self.remote = next(repo.remote().urls, None) 39 | self.status = 'unknown' 40 | except Exception: 41 | self.remote = None 42 | 43 | def list_files(self, subdir, extension): 44 | from modules import scripts 45 | 46 | dirpath = os.path.join(self.path, subdir) 47 | if not os.path.isdir(dirpath): 48 | return [] 49 | 50 | res = [] 51 | for filename in sorted(os.listdir(dirpath)): 52 | res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename))) 53 | 54 | res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] 55 | 56 | return res 57 | 58 | def check_updates(self): 59 | repo = git.Repo(self.path) 60 | for fetch in repo.remote().fetch("--dry-run"): 61 | if fetch.flags != fetch.HEAD_UPTODATE: 62 | self.can_update = True 63 | self.status = "behind" 64 | return 65 | 66 | self.can_update = False 67 | self.status = "latest" 68 | 69 | def pull(self): 70 | repo = git.Repo(self.path) 71 | repo.remotes.origin.pull() 72 | 73 | 74 | def list_extensions(): 75 | extensions.clear() 76 | 77 | if not os.path.isdir(extensions_dir): 78 | return 79 | 80 | for dirname in sorted(os.listdir(extensions_dir)): 81 | path = os.path.join(extensions_dir, dirname) 82 | if not os.path.isdir(path): 83 | continue 84 | 85 | extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) 86 | extensions.append(extension) 87 | -------------------------------------------------------------------------------- /modules/memmon.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | 8 | class MemUsageMonitor(threading.Thread): 9 | run_flag = None 10 | device = None 11 | disabled = False 12 | opts = None 13 | data = None 14 | 15 | def __init__(self, name, device, opts): 16 | threading.Thread.__init__(self) 17 | self.name = name 18 | self.device = device 19 | self.opts = opts 20 | 21 | self.daemon = True 22 | self.run_flag = threading.Event() 23 | self.data = defaultdict(int) 24 | 25 | try: 26 | torch.cuda.mem_get_info() 27 | torch.cuda.memory_stats(self.device) 28 | except Exception as e: # AMD or whatever 29 | print(f"Warning: caught exception '{e}', memory monitor disabled") 30 | self.disabled = True 31 | 32 | def run(self): 33 | if self.disabled: 34 | return 35 | 36 | while True: 37 | self.run_flag.wait() 38 | 39 | torch.cuda.reset_peak_memory_stats() 40 | self.data.clear() 41 | 42 | if self.opts.memmon_poll_rate <= 0: 43 | self.run_flag.clear() 44 | continue 45 | 46 | self.data["min_free"] = torch.cuda.mem_get_info()[0] 47 | 48 | while self.run_flag.is_set(): 49 | free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? 
50 | self.data["min_free"] = min(self.data["min_free"], free) 51 | 52 | time.sleep(1 / self.opts.memmon_poll_rate) 53 | 54 | def dump_debug(self): 55 | print(self, 'recorded data:') 56 | for k, v in self.read().items(): 57 | print(k, -(v // -(1024 ** 2))) 58 | 59 | print(self, 'raw torch memory stats:') 60 | tm = torch.cuda.memory_stats(self.device) 61 | for k, v in tm.items(): 62 | if 'bytes' not in k: 63 | continue 64 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2))) 65 | 66 | print(torch.cuda.memory_summary()) 67 | 68 | def monitor(self): 69 | self.run_flag.set() 70 | 71 | def read(self): 72 | if not self.disabled: 73 | free, total = torch.cuda.mem_get_info() 74 | self.data["total"] = total 75 | 76 | torch_stats = torch.cuda.memory_stats(self.device) 77 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"] 78 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] 79 | self.data["system_peak"] = total - self.data["min_free"] 80 | 81 | return self.data 82 | 83 | def stop(self): 84 | self.run_flag.clear() 85 | return self.read() 86 | -------------------------------------------------------------------------------- /test/utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | 4 | class UtilsTests(unittest.TestCase): 5 | def setUp(self): 6 | self.url_options = "http://localhost:7860/sdapi/v1/options" 7 | self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" 8 | self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" 9 | self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" 10 | self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" 11 | self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" 12 | self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" 13 | self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" 14 | self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" 15 | self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" 16 | self.url_artists = "http://localhost:7860/sdapi/v1/artists" 17 | 18 | def test_options_get(self): 19 | self.assertEqual(requests.get(self.url_options).status_code, 200) 20 | 21 | def test_options_write(self): 22 | response = requests.get(self.url_options) 23 | self.assertEqual(response.status_code, 200) 24 | 25 | pre_value = response.json()["send_seed"] 26 | 27 | self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) 28 | 29 | response = requests.get(self.url_options) 30 | self.assertEqual(response.status_code, 200) 31 | self.assertEqual(response.json()["send_seed"], not pre_value) 32 | 33 | requests.post(self.url_options, json={"send_seed": pre_value}) 34 | 35 | def test_cmd_flags(self): 36 | self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) 37 | 38 | def test_samplers(self): 39 | self.assertEqual(requests.get(self.url_samplers).status_code, 200) 40 | 41 | def test_upscalers(self): 42 | self.assertEqual(requests.get(self.url_upscalers).status_code, 200) 43 | 44 | def test_sd_models(self): 45 | self.assertEqual(requests.get(self.url_sd_models).status_code, 200) 46 | 47 | def test_hypernetworks(self): 48 | self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200) 49 | 50 | def test_face_restorers(self): 51 | self.assertEqual(requests.get(self.url_face_restorers).status_code, 200) 52 | 53 | def test_realesrgan_models(self): 54 | 
self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200) 55 | 56 | def test_prompt_styles(self): 57 | self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) 58 | 59 | def test_artist_categories(self): 60 | self.assertEqual(requests.get(self.url_artist_categories).status_code, 200) 61 | 62 | def test_artists(self): 63 | self.assertEqual(requests.get(self.url_artists).status_code, 200) -------------------------------------------------------------------------------- /modules/devices.py: -------------------------------------------------------------------------------- 1 | import sys, os, shlex 2 | import contextlib 3 | import torch 4 | from modules import errors 5 | 6 | # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility 7 | has_mps = getattr(torch, 'has_mps', False) 8 | 9 | cpu = torch.device("cpu") 10 | 11 | def extract_device_id(args, name): 12 | for x in range(len(args)): 13 | if name in args[x]: return args[x+1] 14 | return None 15 | 16 | def get_optimal_device(): 17 | if torch.cuda.is_available(): 18 | from modules import shared 19 | 20 | device_id = shared.cmd_opts.device_id 21 | 22 | if device_id is not None: 23 | cuda_device = f"cuda:{device_id}" 24 | return torch.device(cuda_device) 25 | else: 26 | return torch.device("cuda") 27 | 28 | if has_mps: 29 | return torch.device("mps") 30 | 31 | return cpu 32 | 33 | 34 | def torch_gc(): 35 | if torch.cuda.is_available(): 36 | torch.cuda.empty_cache() 37 | torch.cuda.ipc_collect() 38 | 39 | 40 | def enable_tf32(): 41 | if torch.cuda.is_available(): 42 | torch.backends.cuda.matmul.allow_tf32 = True 43 | torch.backends.cudnn.allow_tf32 = True 44 | 45 | 46 | errors.run(enable_tf32, "Enabling TF32") 47 | 48 | device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None 49 | dtype = torch.float16 50 | dtype_vae = torch.float16 51 | 52 | def randn(seed, shape): 53 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. 54 | if device.type == 'mps': 55 | generator = torch.Generator(device=cpu) 56 | generator.manual_seed(seed) 57 | noise = torch.randn(shape, generator=generator, device=cpu).to(device) 58 | return noise 59 | 60 | torch.manual_seed(seed) 61 | return torch.randn(shape, device=device) 62 | 63 | 64 | def randn_without_seed(shape): 65 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. 
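    # As in randn() above, the workaround is to draw the noise on the CPU and then move the
    # resulting tensor to the target device.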
66 | if device.type == 'mps': 67 | generator = torch.Generator(device=cpu) 68 | noise = torch.randn(shape, generator=generator, device=cpu).to(device) 69 | return noise 70 | 71 | return torch.randn(shape, device=device) 72 | 73 | 74 | def autocast(disable=False): 75 | from modules import shared 76 | 77 | if disable: 78 | return contextlib.nullcontext() 79 | 80 | if dtype == torch.float32 or shared.cmd_opts.precision == "full": 81 | return contextlib.nullcontext() 82 | 83 | return torch.autocast("cuda") 84 | 85 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 86 | def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor 87 | def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) 88 | -------------------------------------------------------------------------------- /modules/textual_inversion/learn_schedule.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | 4 | class LearnScheduleIterator: 5 | def __init__(self, learn_rate, max_steps, cur_step=0): 6 | """ 7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 8 | """ 9 | 10 | pairs = learn_rate.split(',') 11 | self.rates = [] 12 | self.it = 0 13 | self.maxit = 0 14 | try: 15 | for i, pair in enumerate(pairs): 16 | if not pair.strip(): 17 | continue 18 | tmp = pair.split(':') 19 | if len(tmp) == 2: 20 | step = int(tmp[1]) 21 | if step > cur_step: 22 | self.rates.append((float(tmp[0]), min(step, max_steps))) 23 | self.maxit += 1 24 | if step > max_steps: 25 | return 26 | elif step == -1: 27 | self.rates.append((float(tmp[0]), max_steps)) 28 | self.maxit += 1 29 | return 30 | else: 31 | self.rates.append((float(tmp[0]), max_steps)) 32 | self.maxit += 1 33 | return 34 | assert self.rates 35 | except (ValueError, AssertionError): 36 | raise Exception('Invalid learning rate schedule. 
It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') 37 | 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | if self.it < self.maxit: 44 | self.it += 1 45 | return self.rates[self.it - 1] 46 | else: 47 | raise StopIteration 48 | 49 | 50 | class LearnRateScheduler: 51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): 52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) 53 | (self.learn_rate, self.end_step) = next(self.schedules) 54 | self.verbose = verbose 55 | 56 | if self.verbose: 57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}') 58 | 59 | self.finished = False 60 | 61 | def apply(self, optimizer, step_number): 62 | if step_number < self.end_step: 63 | return 64 | 65 | try: 66 | (self.learn_rate, self.end_step) = next(self.schedules) 67 | except Exception: 68 | self.finished = True 69 | return 70 | 71 | if self.verbose: 72 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') 73 | 74 | for pg in optimizer.param_groups: 75 | pg['lr'] = self.learn_rate 76 | 77 | -------------------------------------------------------------------------------- /test/txt2img_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | 4 | 5 | class TestTxt2ImgWorking(unittest.TestCase): 6 | def setUp(self): 7 | self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" 8 | self.simple_txt2img = { 9 | "enable_hr": False, 10 | "denoising_strength": 0, 11 | "firstphase_width": 0, 12 | "firstphase_height": 0, 13 | "prompt": "example prompt", 14 | "styles": [], 15 | "seed": -1, 16 | "subseed": -1, 17 | "subseed_strength": 0, 18 | "seed_resize_from_h": -1, 19 | "seed_resize_from_w": -1, 20 | "batch_size": 1, 21 | "n_iter": 1, 22 | "steps": 3, 23 | "cfg_scale": 7, 24 | "width": 64, 25 | "height": 64, 26 | "restore_faces": False, 27 | "tiling": False, 28 | "negative_prompt": "", 29 | "eta": 0, 30 | "s_churn": 0, 31 | "s_tmax": 0, 32 | "s_tmin": 0, 33 | "s_noise": 1, 34 | "sampler_index": "Euler a" 35 | } 36 | 37 | def test_txt2img_simple_performed(self): 38 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 39 | 40 | def test_txt2img_with_negative_prompt_performed(self): 41 | self.simple_txt2img["negative_prompt"] = "example negative prompt" 42 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 43 | 44 | def test_txt2img_not_square_image_performed(self): 45 | self.simple_txt2img["height"] = 128 46 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 47 | 48 | def test_txt2img_with_hrfix_performed(self): 49 | self.simple_txt2img["enable_hr"] = True 50 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 51 | 52 | def test_txt2img_with_restore_faces_performed(self): 53 | self.simple_txt2img["restore_faces"] = True 54 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 55 | 56 | def test_txt2img_with_tiling_faces_performed(self): 57 | self.simple_txt2img["tiling"] = True 58 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 59 | 60 | def test_txt2img_with_vanilla_sampler_performed(self): 61 | self.simple_txt2img["sampler_index"] = "PLMS" 62 | 
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 63 | 64 | def test_txt2img_multiple_batches_performed(self): 65 | self.simple_txt2img["n_iter"] = 2 66 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 67 | 68 | 69 | class TestTxt2ImgCorrectness(unittest.TestCase): 70 | pass 71 | 72 | 73 | if __name__ == "__main__": 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: You think somethings is broken in the UI 3 | title: "[Bug]: " 4 | labels: ["bug-report"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** 18 | - type: textarea 19 | id: what-did 20 | attributes: 21 | label: What happened? 22 | description: Tell us what happened in a very clear and simple way 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: steps 27 | attributes: 28 | label: Steps to reproduce the problem 29 | description: Please provide us with precise step by step information on how to reproduce the bug 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: what-should 38 | attributes: 39 | label: What should have happened? 40 | description: tell what you think the normal behavior should be 41 | validations: 42 | required: true 43 | - type: input 44 | id: commit 45 | attributes: 46 | label: Commit where the problem happens 47 | description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI) 48 | validations: 49 | required: true 50 | - type: dropdown 51 | id: platforms 52 | attributes: 53 | label: What platforms do you use to access UI ? 54 | multiple: true 55 | options: 56 | - Windows 57 | - Linux 58 | - MacOS 59 | - iOS 60 | - Android 61 | - Other/Cloud 62 | - type: dropdown 63 | id: browsers 64 | attributes: 65 | label: What browsers do you use to access the UI ? 66 | multiple: true 67 | options: 68 | - Mozilla Firefox 69 | - Google Chrome 70 | - Brave 71 | - Apple Safari 72 | - Microsoft Edge 73 | - type: textarea 74 | id: cmdargs 75 | attributes: 76 | label: Command Line Arguments 77 | description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below 78 | render: Shell 79 | - type: textarea 80 | id: misc 81 | attributes: 82 | label: Additional information, context and logs 83 | description: Please provide us with any relevant additional info, context or log output. 
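  # Note: unlike the earlier fields, the "Command Line Arguments" and "Additional information"
  # textareas above carry no "validations: required" block, so they are optional when submitting.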
84 | -------------------------------------------------------------------------------- /javascript/dragdrop.js: -------------------------------------------------------------------------------- 1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard 2 | 3 | function isValidImageList( files ) { 4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); 5 | } 6 | 7 | function dropReplaceImage( imgWrap, files ) { 8 | if ( ! isValidImageList( files ) ) { 9 | return; 10 | } 11 | 12 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); 13 | const callback = () => { 14 | const fileInput = imgWrap.querySelector('input[type="file"]'); 15 | if ( fileInput ) { 16 | fileInput.files = files; 17 | fileInput.dispatchEvent(new Event('change')); 18 | } 19 | }; 20 | 21 | if ( imgWrap.closest('#pnginfo_image') ) { 22 | // special treatment for PNG Info tab, wait for fetch request to finish 23 | const oldFetch = window.fetch; 24 | window.fetch = async (input, options) => { 25 | const response = await oldFetch(input, options); 26 | if ( 'api/predict/' === input ) { 27 | const content = await response.text(); 28 | window.fetch = oldFetch; 29 | window.requestAnimationFrame( () => callback() ); 30 | return new Response(content, { 31 | status: response.status, 32 | statusText: response.statusText, 33 | headers: response.headers 34 | }) 35 | } 36 | return response; 37 | }; 38 | } else { 39 | window.requestAnimationFrame( () => callback() ); 40 | } 41 | } 42 | 43 | window.document.addEventListener('dragover', e => { 44 | const target = e.composedPath()[0]; 45 | const imgWrap = target.closest('[data-testid="image"]'); 46 | if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { 47 | return; 48 | } 49 | e.stopPropagation(); 50 | e.preventDefault(); 51 | e.dataTransfer.dropEffect = 'copy'; 52 | }); 53 | 54 | window.document.addEventListener('drop', e => { 55 | const target = e.composedPath()[0]; 56 | if (target.placeholder.indexOf("Prompt") == -1) { 57 | return; 58 | } 59 | const imgWrap = target.closest('[data-testid="image"]'); 60 | if ( !imgWrap ) { 61 | return; 62 | } 63 | e.stopPropagation(); 64 | e.preventDefault(); 65 | const files = e.dataTransfer.files; 66 | dropReplaceImage( imgWrap, files ); 67 | }); 68 | 69 | window.addEventListener('paste', e => { 70 | const files = e.clipboardData.files; 71 | if ( ! isValidImageList( files ) ) { 72 | return; 73 | } 74 | 75 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] 76 | .filter(el => uiElementIsVisible(el)); 77 | if ( ! visibleImageFields.length ) { 78 | return; 79 | } 80 | 81 | const firstFreeImageField = visibleImageFields 82 | .filter(el => el.querySelector('input[type=file]'))?.[0]; 83 | 84 | dropReplaceImage( 85 | firstFreeImageField ? 86 | firstFreeImageField : 87 | visibleImageFields[visibleImageFields.length - 1] 88 | , files ); 89 | }); 90 | -------------------------------------------------------------------------------- /javascript/edit-attention.js: -------------------------------------------------------------------------------- 1 | addEventListener('keydown', (event) => { 2 | let target = event.originalTarget || event.composedPath()[0]; 3 | if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return; 4 | if (! 
(event.metaKey || event.ctrlKey)) return; 5 | 6 | 7 | let plus = "ArrowUp" 8 | let minus = "ArrowDown" 9 | if (event.key != plus && event.key != minus) return; 10 | 11 | let selectionStart = target.selectionStart; 12 | let selectionEnd = target.selectionEnd; 13 | // If the user hasn't selected anything, let's select their current parenthesis block 14 | if (selectionStart === selectionEnd) { 15 | // Find opening parenthesis around current cursor 16 | const before = target.value.substring(0, selectionStart); 17 | let beforeParen = before.lastIndexOf("("); 18 | if (beforeParen == -1) return; 19 | let beforeParenClose = before.lastIndexOf(")"); 20 | while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { 21 | beforeParen = before.lastIndexOf("(", beforeParen - 1); 22 | beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1); 23 | } 24 | 25 | // Find closing parenthesis around current cursor 26 | const after = target.value.substring(selectionStart); 27 | let afterParen = after.indexOf(")"); 28 | if (afterParen == -1) return; 29 | let afterParenOpen = after.indexOf("("); 30 | while (afterParenOpen !== -1 && afterParen > afterParenOpen) { 31 | afterParen = after.indexOf(")", afterParen + 1); 32 | afterParenOpen = after.indexOf("(", afterParenOpen + 1); 33 | } 34 | if (beforeParen === -1 || afterParen === -1) return; 35 | 36 | // Set the selection to the text between the parenthesis 37 | const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen); 38 | const lastColon = parenContent.lastIndexOf(":"); 39 | selectionStart = beforeParen + 1; 40 | selectionEnd = selectionStart + lastColon; 41 | target.setSelectionRange(selectionStart, selectionEnd); 42 | } 43 | 44 | event.preventDefault(); 45 | 46 | if (selectionStart == 0 || target.value[selectionStart - 1] != "(") { 47 | target.value = target.value.slice(0, selectionStart) + 48 | "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" + 49 | target.value.slice(selectionEnd); 50 | 51 | target.focus(); 52 | target.selectionStart = selectionStart + 1; 53 | target.selectionEnd = selectionEnd + 1; 54 | 55 | } else { 56 | end = target.value.slice(selectionEnd + 1).indexOf(")") + 1; 57 | weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end)); 58 | if (isNaN(weight)) return; 59 | if (event.key == minus) weight -= 0.1; 60 | if (event.key == plus) weight += 0.1; 61 | 62 | weight = parseFloat(weight.toPrecision(12)); 63 | 64 | target.value = target.value.slice(0, selectionEnd + 1) + 65 | weight + 66 | target.value.slice(selectionEnd + 1 + end - 1); 67 | 68 | target.focus(); 69 | target.selectionStart = selectionStart; 70 | target.selectionEnd = selectionEnd; 71 | } 72 | // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its 73 | // internal Svelte data binding remains in sync. 
74 | target.dispatchEvent(new Event("input", { bubbles: true })); 75 | }); 76 | -------------------------------------------------------------------------------- /scripts/loopback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | import modules.scripts as scripts 5 | import gradio as gr 6 | 7 | from modules import processing, shared, sd_samplers, images 8 | from modules.processing import Processed 9 | from modules.sd_samplers import samplers 10 | from modules.shared import opts, cmd_opts, state 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "Loopback" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4) 21 | denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1) 22 | 23 | return [loops, denoising_strength_change_factor] 24 | 25 | def run(self, p, loops, denoising_strength_change_factor): 26 | processing.fix_seed(p) 27 | batch_count = p.n_iter 28 | p.extra_generation_params = { 29 | "Denoising strength change factor": denoising_strength_change_factor, 30 | } 31 | 32 | p.batch_size = 1 33 | p.n_iter = 1 34 | 35 | output_images, info = None, None 36 | initial_seed = None 37 | initial_info = None 38 | 39 | grids = [] 40 | all_images = [] 41 | original_init_image = p.init_images 42 | state.job_count = loops * batch_count 43 | 44 | initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] 45 | 46 | for n in range(batch_count): 47 | history = [] 48 | 49 | # Reset to original init image at the start of each batch 50 | p.init_images = original_init_image 51 | 52 | for i in range(loops): 53 | p.n_iter = 1 54 | p.batch_size = 1 55 | p.do_not_save_grid = True 56 | 57 | if opts.img2img_color_correction: 58 | p.color_corrections = initial_color_corrections 59 | 60 | state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" 61 | 62 | processed = processing.process_images(p) 63 | 64 | if initial_seed is None: 65 | initial_seed = processed.seed 66 | initial_info = processed.info 67 | 68 | init_img = processed.images[0] 69 | 70 | p.init_images = [init_img] 71 | p.seed = processed.seed + 1 72 | p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) 73 | history.append(processed.images[0]) 74 | 75 | grid = images.image_grid(history, rows=1) 76 | if opts.grid_save: 77 | images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) 78 | 79 | grids.append(grid) 80 | all_images += history 81 | 82 | if opts.return_grid: 83 | all_images = grids + all_images 84 | 85 | processed = Processed(p, all_images, initial_seed, initial_info) 86 | 87 | return processed 88 | -------------------------------------------------------------------------------- /scripts/prompt_matrix.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import namedtuple 3 | from copy import copy 4 | import random 5 | 6 | import modules.scripts as scripts 7 | import gradio as gr 8 | 9 | from modules import images 10 | from modules.processing import process_images, Processed 11 | from modules.shared import opts, cmd_opts, state 12 | import modules.sd_samplers 13 | 14 | 15 | def 
draw_xy_grid(xs, ys, x_label, y_label, cell): 16 | res = [] 17 | 18 | ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] 19 | hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] 20 | 21 | first_pocessed = None 22 | 23 | state.job_count = len(xs) * len(ys) 24 | 25 | for iy, y in enumerate(ys): 26 | for ix, x in enumerate(xs): 27 | state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" 28 | 29 | processed = cell(x, y) 30 | if first_pocessed is None: 31 | first_pocessed = processed 32 | 33 | res.append(processed.images[0]) 34 | 35 | grid = images.image_grid(res, rows=len(ys)) 36 | grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) 37 | 38 | first_pocessed.images = [grid] 39 | 40 | return first_pocessed 41 | 42 | 43 | class Script(scripts.Script): 44 | def title(self): 45 | return "Prompt matrix" 46 | 47 | def ui(self, is_img2img): 48 | put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False) 49 | 50 | return [put_at_start] 51 | 52 | def run(self, p, put_at_start): 53 | modules.processing.fix_seed(p) 54 | 55 | original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt 56 | 57 | all_prompts = [] 58 | prompt_matrix_parts = original_prompt.split("|") 59 | combination_count = 2 ** (len(prompt_matrix_parts) - 1) 60 | for combination_num in range(combination_count): 61 | selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)] 62 | 63 | if put_at_start: 64 | selected_prompts = selected_prompts + [prompt_matrix_parts[0]] 65 | else: 66 | selected_prompts = [prompt_matrix_parts[0]] + selected_prompts 67 | 68 | all_prompts.append(", ".join(selected_prompts)) 69 | 70 | p.n_iter = math.ceil(len(all_prompts) / p.batch_size) 71 | p.do_not_save_grid = True 72 | 73 | print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.") 74 | 75 | p.prompt = all_prompts 76 | p.seed = [p.seed for _ in all_prompts] 77 | p.prompt_for_display = original_prompt 78 | processed = process_images(p) 79 | 80 | grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) 81 | grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) 82 | processed.images.insert(0, grid) 83 | 84 | if opts.grid_save: 85 | images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p) 86 | 87 | return processed 88 | -------------------------------------------------------------------------------- /modules/scunet_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | import traceback 4 | 5 | import PIL.Image 6 | import numpy as np 7 | import torch 8 | from basicsr.utils.download_util import load_file_from_url 9 | 10 | import modules.upscaler 11 | from modules import devices, modelloader 12 | from modules.scunet_model_arch import SCUNet as net 13 | 14 | 15 | class UpscalerScuNET(modules.upscaler.Upscaler): 16 | def __init__(self, dirname): 17 | self.name = "ScuNET" 18 | self.model_name = "ScuNET GAN" 19 | self.model_name2 = "ScuNET PSNR" 20 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" 21 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" 22 | self.user_path = dirname 23 | super().__init__() 24 | model_paths = 
self.find_models(ext_filter=[".pth"]) 25 | scalers = [] 26 | add_model2 = True 27 | for file in model_paths: 28 | if "http" in file: 29 | name = self.model_name 30 | else: 31 | name = modelloader.friendly_name(file) 32 | if name == self.model_name2 or file == self.model_url2: 33 | add_model2 = False 34 | try: 35 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) 36 | scalers.append(scaler_data) 37 | except Exception: 38 | print(f"Error loading ScuNET model: {file}", file=sys.stderr) 39 | print(traceback.format_exc(), file=sys.stderr) 40 | if add_model2: 41 | scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) 42 | scalers.append(scaler_data2) 43 | self.scalers = scalers 44 | 45 | def do_upscale(self, img: PIL.Image, selected_file): 46 | torch.cuda.empty_cache() 47 | 48 | model = self.load_model(selected_file) 49 | if model is None: 50 | return img 51 | 52 | device = devices.device_scunet 53 | img = np.array(img) 54 | img = img[:, :, ::-1] 55 | img = np.moveaxis(img, 2, 0) / 255 56 | img = torch.from_numpy(img).float() 57 | img = devices.mps_contiguous_to(img.unsqueeze(0), device) 58 | 59 | with torch.no_grad(): 60 | output = model(img) 61 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy() 62 | output = 255. * np.moveaxis(output, 0, 2) 63 | output = output.astype(np.uint8) 64 | output = output[:, :, ::-1] 65 | torch.cuda.empty_cache() 66 | return PIL.Image.fromarray(output, 'RGB') 67 | 68 | def load_model(self, path: str): 69 | device = devices.device_scunet 70 | if "http" in path: 71 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, 72 | progress=True) 73 | else: 74 | filename = path 75 | if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: 76 | print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) 77 | return None 78 | 79 | model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) 80 | model.load_state_dict(torch.load(filename), strict=True) 81 | model.eval() 82 | for k, v in model.named_parameters(): 83 | v.requires_grad = False 84 | model = model.to(device) 85 | 86 | return model 87 | 88 | -------------------------------------------------------------------------------- /modules/masking.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageOps 2 | 3 | 4 | def get_crop_region(mask, pad=0): 5 | """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 
6 | For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" 7 | 8 | h, w = mask.shape 9 | 10 | crop_left = 0 11 | for i in range(w): 12 | if not (mask[:, i] == 0).all(): 13 | break 14 | crop_left += 1 15 | 16 | crop_right = 0 17 | for i in reversed(range(w)): 18 | if not (mask[:, i] == 0).all(): 19 | break 20 | crop_right += 1 21 | 22 | crop_top = 0 23 | for i in range(h): 24 | if not (mask[i] == 0).all(): 25 | break 26 | crop_top += 1 27 | 28 | crop_bottom = 0 29 | for i in reversed(range(h)): 30 | if not (mask[i] == 0).all(): 31 | break 32 | crop_bottom += 1 33 | 34 | return ( 35 | int(max(crop_left-pad, 0)), 36 | int(max(crop_top-pad, 0)), 37 | int(min(w - crop_right + pad, w)), 38 | int(min(h - crop_bottom + pad, h)) 39 | ) 40 | 41 | 42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): 43 | """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region 44 | for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.""" 45 | 46 | x1, y1, x2, y2 = crop_region 47 | 48 | ratio_crop_region = (x2 - x1) / (y2 - y1) 49 | ratio_processing = processing_width / processing_height 50 | 51 | if ratio_crop_region > ratio_processing: 52 | desired_height = (x2 - x1) / ratio_processing 53 | desired_height_diff = int(desired_height - (y2-y1)) 54 | y1 -= desired_height_diff//2 55 | y2 += desired_height_diff - desired_height_diff//2 56 | if y2 >= image_height: 57 | diff = y2 - image_height 58 | y2 -= diff 59 | y1 -= diff 60 | if y1 < 0: 61 | y2 -= y1 62 | y1 -= y1 63 | if y2 >= image_height: 64 | y2 = image_height 65 | else: 66 | desired_width = (y2 - y1) * ratio_processing 67 | desired_width_diff = int(desired_width - (x2-x1)) 68 | x1 -= desired_width_diff//2 69 | x2 += desired_width_diff - desired_width_diff//2 70 | if x2 >= image_width: 71 | diff = x2 - image_width 72 | x2 -= diff 73 | x1 -= diff 74 | if x1 < 0: 75 | x2 -= x1 76 | x1 -= x1 77 | if x2 >= image_width: 78 | x2 = image_width 79 | 80 | return x1, y1, x2, y2 81 | 82 | 83 | def fill(image, mask): 84 | """fills masked regions with colors from image using blur. 
Not extremely effective.""" 85 | 86 | image_mod = Image.new('RGBA', (image.width, image.height)) 87 | 88 | image_masked = Image.new('RGBa', (image.width, image.height)) 89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L'))) 90 | 91 | image_masked = image_masked.convert('RGBa') 92 | 93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: 94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') 95 | for _ in range(repeats): 96 | image_mod.alpha_composite(blurred) 97 | 98 | return image_mod.convert("RGB") 99 | 100 | -------------------------------------------------------------------------------- /javascript/aspectRatioOverlay.js: -------------------------------------------------------------------------------- 1 | 2 | let currentWidth = null; 3 | let currentHeight = null; 4 | let arFrameTimeout = setTimeout(function(){},0); 5 | 6 | function dimensionChange(e, is_width, is_height){ 7 | 8 | if(is_width){ 9 | currentWidth = e.target.value*1.0 10 | } 11 | if(is_height){ 12 | currentHeight = e.target.value*1.0 13 | } 14 | 15 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 16 | 17 | if(!inImg2img){ 18 | return; 19 | } 20 | 21 | var targetElement = null; 22 | 23 | var tabIndex = get_tab_index('mode_img2img') 24 | if(tabIndex == 0){ 25 | targetElement = gradioApp().querySelector('div[data-testid=image] img'); 26 | } else if(tabIndex == 1){ 27 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); 28 | } 29 | 30 | if(targetElement){ 31 | 32 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 33 | if(!arPreviewRect){ 34 | arPreviewRect = document.createElement('div') 35 | arPreviewRect.id = "imageARPreview"; 36 | gradioApp().getRootNode().appendChild(arPreviewRect) 37 | } 38 | 39 | 40 | 41 | var viewportOffset = targetElement.getBoundingClientRect(); 42 | 43 | viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) 44 | 45 | scaledx = targetElement.naturalWidth*viewportscale 46 | scaledy = targetElement.naturalHeight*viewportscale 47 | 48 | cleintRectTop = (viewportOffset.top+window.scrollY) 49 | cleintRectLeft = (viewportOffset.left+window.scrollX) 50 | cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) 51 | cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) 52 | 53 | viewRectTop = cleintRectCentreY-(scaledy/2) 54 | viewRectLeft = cleintRectCentreX-(scaledx/2) 55 | arRectWidth = scaledx 56 | arRectHeight = scaledy 57 | 58 | arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) 59 | arscaledx = currentWidth*arscale 60 | arscaledy = currentHeight*arscale 61 | 62 | arRectTop = cleintRectCentreY-(arscaledy/2) 63 | arRectLeft = cleintRectCentreX-(arscaledx/2) 64 | arRectWidth = arscaledx 65 | arRectHeight = arscaledy 66 | 67 | arPreviewRect.style.top = arRectTop+'px'; 68 | arPreviewRect.style.left = arRectLeft+'px'; 69 | arPreviewRect.style.width = arRectWidth+'px'; 70 | arPreviewRect.style.height = arRectHeight+'px'; 71 | 72 | clearTimeout(arFrameTimeout); 73 | arFrameTimeout = setTimeout(function(){ 74 | arPreviewRect.style.display = 'none'; 75 | },2000); 76 | 77 | arPreviewRect.style.display = 'block'; 78 | 79 | } 80 | 81 | } 82 | 83 | 84 | onUiUpdate(function(){ 85 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 86 | if(arPreviewRect){ 87 | 
arPreviewRect.style.display = 'none'; 88 | } 89 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 90 | if(inImg2img){ 91 | let inputs = gradioApp().querySelectorAll('input'); 92 | inputs.forEach(function(e){ 93 | var is_width = e.parentElement.id == "img2img_width" 94 | var is_height = e.parentElement.id == "img2img_height" 95 | 96 | if((is_width || is_height) && !e.classList.contains('scrollwatch')){ 97 | e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) 98 | e.classList.add('scrollwatch') 99 | } 100 | if(is_width){ 101 | currentWidth = e.value*1.0 102 | } 103 | if(is_height){ 104 | currentHeight = e.value*1.0 105 | } 106 | }) 107 | } 108 | }); 109 | -------------------------------------------------------------------------------- /scripts/sd_upscale.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image 6 | 7 | from modules import processing, shared, sd_samplers, images, devices 8 | from modules.processing import Processed 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "SD upscale" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | info = gr.HTML("

Will upscale the image to twice the dimensions; use width and height sliders to set tile size

") 21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) 22 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") 23 | 24 | return [info, overlap, upscaler_index] 25 | 26 | def run(self, p, _, overlap, upscaler_index): 27 | processing.fix_seed(p) 28 | upscaler = shared.sd_upscalers[upscaler_index] 29 | 30 | p.extra_generation_params["SD upscale overlap"] = overlap 31 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name 32 | 33 | initial_info = None 34 | seed = p.seed 35 | 36 | init_img = p.init_images[0] 37 | 38 | if(upscaler.name != "None"): 39 | img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path) 40 | else: 41 | img = init_img 42 | 43 | devices.torch_gc() 44 | 45 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) 46 | 47 | batch_size = p.batch_size 48 | upscale_count = p.n_iter 49 | p.n_iter = 1 50 | p.do_not_save_grid = True 51 | p.do_not_save_samples = True 52 | 53 | work = [] 54 | 55 | for y, h, row in grid.tiles: 56 | for tiledata in row: 57 | work.append(tiledata[2]) 58 | 59 | batch_count = math.ceil(len(work) / batch_size) 60 | state.job_count = batch_count * upscale_count 61 | 62 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") 63 | 64 | result_images = [] 65 | for n in range(upscale_count): 66 | start_seed = seed + n 67 | p.seed = start_seed 68 | 69 | work_results = [] 70 | for i in range(batch_count): 71 | p.batch_size = batch_size 72 | p.init_images = work[i*batch_size:(i+1)*batch_size] 73 | 74 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" 75 | processed = processing.process_images(p) 76 | 77 | if initial_info is None: 78 | initial_info = processed.info 79 | 80 | p.seed = processed.seed + 1 81 | work_results += processed.images 82 | 83 | image_index = 0 84 | for y, h, row in grid.tiles: 85 | for tiledata in row: 86 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 87 | image_index += 1 88 | 89 | combined_image = images.combine_grid(grid) 90 | result_images.append(combined_image) 91 | 92 | if opts.samples_save: 93 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) 94 | 95 | processed = Processed(p, result_images, seed, initial_info) 96 | 97 | return processed 98 | -------------------------------------------------------------------------------- /modules/lowvram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import devices 3 | 4 | module_in_gpu = None 5 | cpu = torch.device("cpu") 6 | 7 | 8 | def send_everything_to_cpu(): 9 | global module_in_gpu 10 | 11 | if module_in_gpu is not None: 12 | module_in_gpu.to(cpu) 13 | 14 | module_in_gpu = None 15 | 16 | 17 | def setup_for_low_vram(sd_model, use_medvram): 18 | parents = {} 19 | 20 | def send_me_to_gpu(module, _): 21 | """send this module to GPU; send whatever tracked module was previous in GPU to CPU; 22 | we add this as forward_pre_hook to a lot of modules and this way all but one of them will 23 | be in CPU 24 | """ 25 | global module_in_gpu 26 | 27 | module = parents.get(module, module) 28 | 29 | if module_in_gpu == module: 30 | return 31 | 32 | if module_in_gpu is not None: 33 | 
module_in_gpu.to(cpu) 34 | 35 | module.to(devices.device) 36 | module_in_gpu = module 37 | 38 | # see below for register_forward_pre_hook; 39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is 40 | # useless here, and we just replace those methods 41 | 42 | first_stage_model = sd_model.first_stage_model 43 | first_stage_model_encode = sd_model.first_stage_model.encode 44 | first_stage_model_decode = sd_model.first_stage_model.decode 45 | 46 | def first_stage_model_encode_wrap(x): 47 | send_me_to_gpu(first_stage_model, None) 48 | return first_stage_model_encode(x) 49 | 50 | def first_stage_model_decode_wrap(z): 51 | send_me_to_gpu(first_stage_model, None) 52 | return first_stage_model_decode(z) 53 | 54 | # remove three big modules, cond, first_stage, and unet from the model and then 55 | # send the model to GPU. Then put modules back. the modules will be in CPU. 56 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model 57 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None 58 | sd_model.to(devices.device) 59 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored 60 | 61 | # register hooks for those the first two models 62 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) 63 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) 64 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap 65 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap 66 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model 67 | 68 | if use_medvram: 69 | sd_model.model.register_forward_pre_hook(send_me_to_gpu) 70 | else: 71 | diff_model = sd_model.model.diffusion_model 72 | 73 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules 74 | # so that only one of them is in GPU at a time 75 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed 76 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None 77 | sd_model.model.to(devices.device) 78 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored 79 | 80 | # install hooks for bits of third model 81 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) 82 | for block in diff_model.input_blocks: 83 | block.register_forward_pre_hook(send_me_to_gpu) 84 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) 85 | for block in diff_model.output_blocks: 86 | block.register_forward_pre_hook(send_me_to_gpu) 87 | -------------------------------------------------------------------------------- /modules/styles.py: -------------------------------------------------------------------------------- 1 | # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime 2 | from __future__ import annotations 3 | 4 | import csv 5 | import os 6 | import os.path 7 | import typing 8 | import collections.abc as abc 9 | import tempfile 10 | import shutil 11 | 12 | if typing.TYPE_CHECKING: 13 | # Only import this when code is being type-checked, it doesn't have any effect at runtime 14 | from .processing import StableDiffusionProcessing 15 | 16 | 17 | class PromptStyle(typing.NamedTuple): 18 | name: str 19 | prompt: str 20 | negative_prompt: str 21 | 
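# A PromptStyle is a named pair of positive/negative prompt snippets. merge_prompts() below
# combines a style with the user's prompt: if the style text contains a "{prompt}" placeholder,
# the prompt is substituted into it; otherwise the style text is appended, comma-separated.
# Illustrative sketch (example values, not from the original file):
#   merge_prompts("oil painting of {prompt}, 4k", "a cat")  ->  "oil painting of a cat, 4k"
#   merge_prompts("oil painting", "a cat")                  ->  "a cat, oil painting"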
22 | 23 | def merge_prompts(style_prompt: str, prompt: str) -> str: 24 | if "{prompt}" in style_prompt: 25 | res = style_prompt.replace("{prompt}", prompt) 26 | else: 27 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 28 | res = ", ".join(parts) 29 | 30 | return res 31 | 32 | 33 | def apply_styles_to_prompt(prompt, styles): 34 | for style in styles: 35 | prompt = merge_prompts(style, prompt) 36 | 37 | return prompt 38 | 39 | 40 | class StyleDatabase: 41 | def __init__(self, path: str): 42 | self.no_style = PromptStyle("None", "", "") 43 | self.styles = {"None": self.no_style} 44 | 45 | if not os.path.exists(path): 46 | return 47 | 48 | with open(path, "r", encoding="utf-8-sig", newline='') as file: 49 | reader = csv.DictReader(file) 50 | for row in reader: 51 | # Support loading old CSV format with "name, text"-columns 52 | prompt = row["prompt"] if "prompt" in row else row["text"] 53 | negative_prompt = row.get("negative_prompt", "") 54 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) 55 | 56 | def get_style_prompts(self, styles): 57 | return [self.styles.get(x, self.no_style).prompt for x in styles] 58 | 59 | def get_negative_style_prompts(self, styles): 60 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 61 | 62 | def apply_styles_to_prompt(self, prompt, styles): 63 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) 64 | 65 | def apply_negative_styles_to_prompt(self, prompt, styles): 66 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) 67 | 68 | def apply_styles(self, p: StableDiffusionProcessing) -> None: 69 | if isinstance(p.prompt, list): 70 | p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt] 71 | else: 72 | p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles) 73 | 74 | if isinstance(p.negative_prompt, list): 75 | p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt] 76 | else: 77 | p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) 78 | 79 | def save_styles(self, path: str) -> None: 80 | # Write to temporary file first, so we don't nuke the file if something goes wrong 81 | fd, temp_path = tempfile.mkstemp(".csv") 82 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: 83 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, 84 | # and collections.NamedTuple has explicit documentation for accessing _fields. 
Same goes for _asdict() 85 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) 86 | writer.writeheader() 87 | writer.writerows(style._asdict() for k, style in self.styles.items()) 88 | 89 | # Always keep a backup file around 90 | if os.path.exists(path): 91 | shutil.move(path, path + ".bak") 92 | shutil.move(temp_path, path) 93 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import facexlib 6 | import gfpgan 7 | 8 | import modules.face_restoration 9 | from modules import shared, devices, modelloader 10 | from modules.paths import models_path 11 | 12 | model_dir = "GFPGAN" 13 | user_path = None 14 | model_path = os.path.join(models_path, model_dir) 15 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 16 | have_gfpgan = False 17 | loaded_gfpgan_model = None 18 | 19 | 20 | def gfpgann(): 21 | global loaded_gfpgan_model 22 | global model_path 23 | if loaded_gfpgan_model is not None: 24 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) 25 | return loaded_gfpgan_model 26 | 27 | if gfpgan_constructor is None: 28 | return None 29 | 30 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") 31 | if len(models) == 1 and "http" in models[0]: 32 | model_file = models[0] 33 | elif len(models) != 0: 34 | latest_file = max(models, key=os.path.getctime) 35 | model_file = latest_file 36 | else: 37 | print("Unable to load gfpgan model!") 38 | return None 39 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) 40 | loaded_gfpgan_model = model 41 | 42 | return model 43 | 44 | 45 | def send_model_to(model, device): 46 | model.gfpgan.to(device) 47 | model.face_helper.face_det.to(device) 48 | model.face_helper.face_parse.to(device) 49 | 50 | 51 | def gfpgan_fix_faces(np_image): 52 | model = gfpgann() 53 | if model is None: 54 | return np_image 55 | 56 | send_model_to(model, devices.device_gfpgan) 57 | 58 | np_image_bgr = np_image[:, :, ::-1] 59 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) 60 | np_image = gfpgan_output_bgr[:, :, ::-1] 61 | 62 | model.face_helper.clean_all() 63 | 64 | if shared.opts.face_restoration_unload: 65 | send_model_to(model, devices.cpu) 66 | 67 | return np_image 68 | 69 | 70 | gfpgan_constructor = None 71 | 72 | 73 | def setup_model(dirname): 74 | global model_path 75 | if not os.path.exists(model_path): 76 | os.makedirs(model_path) 77 | 78 | try: 79 | from gfpgan import GFPGANer 80 | from facexlib import detection, parsing 81 | global user_path 82 | global have_gfpgan 83 | global gfpgan_constructor 84 | 85 | load_file_from_url_orig = gfpgan.utils.load_file_from_url 86 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url 87 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url 88 | 89 | def my_load_file_from_url(**kwargs): 90 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) 91 | 92 | def facex_load_file_from_url(**kwargs): 93 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) 94 | 95 | def facex_load_file_from_url2(**kwargs): 96 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) 97 | 98 | 
gfpgan.utils.load_file_from_url = my_load_file_from_url 99 | facexlib.detection.load_file_from_url = facex_load_file_from_url 100 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 101 | user_path = dirname 102 | have_gfpgan = True 103 | gfpgan_constructor = GFPGANer 104 | 105 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): 106 | def name(self): 107 | return "GFPGAN" 108 | 109 | def restore(self, np_image): 110 | return gfpgan_fix_faces(np_image) 111 | 112 | shared.face_restorers.append(FaceRestorerGFPGAN()) 113 | except Exception: 114 | print("Error setting up GFPGAN:", file=sys.stderr) 115 | print(traceback.format_exc(), file=sys.stderr) 116 | -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ################################################# 3 | # Please do not make any changes to this file, # 4 | # change the variables in webui-user.sh instead # 5 | ################################################# 6 | # Read variables from webui-user.sh 7 | # shellcheck source=/dev/null 8 | if [[ -f webui-user.sh ]] 9 | then 10 | source ./webui-user.sh 11 | fi 12 | 13 | # Set defaults 14 | # Install directory without trailing slash 15 | if [[ -z "${install_dir}" ]] 16 | then 17 | install_dir="/home/$(whoami)" 18 | fi 19 | 20 | # Name of the subdirectory (defaults to stable-diffusion-webui) 21 | if [[ -z "${clone_dir}" ]] 22 | then 23 | clone_dir="stable-diffusion-webui" 24 | fi 25 | 26 | # python3 executable 27 | if [[ -z "${python_cmd}" ]] 28 | then 29 | python_cmd="python3" 30 | fi 31 | 32 | # git executable 33 | if [[ -z "${GIT}" ]] 34 | then 35 | export GIT="git" 36 | fi 37 | 38 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 39 | if [[ -z "${venv_dir}" ]] 40 | then 41 | venv_dir="venv" 42 | fi 43 | 44 | if [[ -z "${LAUNCH_SCRIPT}" ]] 45 | then 46 | LAUNCH_SCRIPT="launch.py" 47 | fi 48 | 49 | # Disable sentry logging 50 | export ERROR_REPORTING=FALSE 51 | 52 | # Do not reinstall existing pip packages on Debian/Ubuntu 53 | export PIP_IGNORE_INSTALLED=0 54 | 55 | # Pretty print 56 | delimiter="################################################################" 57 | 58 | printf "\n%s\n" "${delimiter}" 59 | printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" 60 | printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" 61 | printf "\n%s\n" "${delimiter}" 62 | 63 | # Do not run as root 64 | if [[ $(id -u) -eq 0 ]] 65 | then 66 | printf "\n%s\n" "${delimiter}" 67 | printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" 68 | printf "\n%s\n" "${delimiter}" 69 | exit 1 70 | else 71 | printf "\n%s\n" "${delimiter}" 72 | printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)" 73 | printf "\n%s\n" "${delimiter}" 74 | fi 75 | 76 | if [[ -d .git ]] 77 | then 78 | printf "\n%s\n" "${delimiter}" 79 | printf "Repo already cloned, using it as install directory" 80 | printf "\n%s\n" "${delimiter}" 81 | install_dir="${PWD}/../" 82 | clone_dir="${PWD##*/}" 83 | fi 84 | 85 | # Check prerequisites 86 | for preq in "${GIT}" "${python_cmd}" 87 | do 88 | if ! hash "${preq}" &>/dev/null 89 | then 90 | printf "\n%s\n" "${delimiter}" 91 | printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}" 92 | printf "\n%s\n" "${delimiter}" 93 | exit 1 94 | fi 95 | done 96 | 97 | if ! 
"${python_cmd}" -c "import venv" &>/dev/null 98 | then 99 | printf "\n%s\n" "${delimiter}" 100 | printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" 101 | printf "\n%s\n" "${delimiter}" 102 | exit 1 103 | fi 104 | 105 | cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } 106 | if [[ -d "${clone_dir}" ]] 107 | then 108 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 109 | else 110 | printf "\n%s\n" "${delimiter}" 111 | printf "Clone stable-diffusion-webui" 112 | printf "\n%s\n" "${delimiter}" 113 | "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" 114 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 115 | fi 116 | 117 | printf "\n%s\n" "${delimiter}" 118 | printf "Create and activate python venv" 119 | printf "\n%s\n" "${delimiter}" 120 | cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 121 | if [[ ! -d "${venv_dir}" ]] 122 | then 123 | "${python_cmd}" -m venv "${venv_dir}" 124 | first_launch=1 125 | fi 126 | # shellcheck source=/dev/null 127 | if [[ -f "${venv_dir}"/bin/activate ]] 128 | then 129 | source "${venv_dir}"/bin/activate 130 | else 131 | printf "\n%s\n" "${delimiter}" 132 | printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" 133 | printf "\n%s\n" "${delimiter}" 134 | exit 1 135 | fi 136 | 137 | printf "\n%s\n" "${delimiter}" 138 | printf "Launching launch.py..." 139 | printf "\n%s\n" "${delimiter}" 140 | "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" 141 | -------------------------------------------------------------------------------- /modules/upscaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import PIL 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | import modules.shared 10 | from modules import modelloader, shared 11 | 12 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 13 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) 14 | from modules.paths import models_path 15 | 16 | 17 | class Upscaler: 18 | name = None 19 | model_path = None 20 | model_name = None 21 | model_url = None 22 | enable = True 23 | filter = None 24 | model = None 25 | user_path = None 26 | scalers: [] 27 | tile = True 28 | 29 | def __init__(self, create_dirs=False): 30 | self.mod_pad_h = None 31 | self.tile_size = modules.shared.opts.ESRGAN_tile 32 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap 33 | self.device = modules.shared.device 34 | self.img = None 35 | self.output = None 36 | self.scale = 1 37 | self.half = not modules.shared.cmd_opts.no_half 38 | self.pre_pad = 0 39 | self.mod_scale = None 40 | 41 | if self.model_path is None and self.name: 42 | self.model_path = os.path.join(models_path, self.name) 43 | if self.model_path and create_dirs: 44 | os.makedirs(self.model_path, exist_ok=True) 45 | 46 | try: 47 | import cv2 48 | self.can_tile = True 49 | except: 50 | pass 51 | 52 | @abstractmethod 53 | def do_upscale(self, img: PIL.Image, selected_model: str): 54 | return img 55 | 56 | def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): 57 | self.scale = scale 58 | 
dest_w = img.width * scale 59 | dest_h = img.height * scale 60 | 61 | for i in range(3): 62 | shape = (img.width, img.height) 63 | 64 | img = self.do_upscale(img, selected_model) 65 | 66 | if shape == (img.width, img.height): 67 | break 68 | 69 | if img.width >= dest_w and img.height >= dest_h: 70 | break 71 | 72 | if img.width != dest_w or img.height != dest_h: 73 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 74 | 75 | return img 76 | 77 | @abstractmethod 78 | def load_model(self, path: str): 79 | pass 80 | 81 | def find_models(self, ext_filter=None) -> list: 82 | return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) 83 | 84 | def update_status(self, prompt): 85 | print(f"\nextras: {prompt}", file=shared.progress_print_out) 86 | 87 | 88 | class UpscalerData: 89 | name = None 90 | data_path = None 91 | scale: int = 4 92 | scaler: Upscaler = None 93 | model: None 94 | 95 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): 96 | self.name = name 97 | self.data_path = path 98 | self.scaler = upscaler 99 | self.scale = scale 100 | self.model = model 101 | 102 | 103 | class UpscalerNone(Upscaler): 104 | name = "None" 105 | scalers = [] 106 | 107 | def load_model(self, path): 108 | pass 109 | 110 | def do_upscale(self, img, selected_model=None): 111 | return img 112 | 113 | def __init__(self, dirname=None): 114 | super().__init__(False) 115 | self.scalers = [UpscalerData("None", None, self)] 116 | 117 | 118 | class UpscalerLanczos(Upscaler): 119 | scalers = [] 120 | 121 | def do_upscale(self, img, selected_model=None): 122 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 123 | 124 | def load_model(self, _): 125 | pass 126 | 127 | def __init__(self, dirname=None): 128 | super().__init__(False) 129 | self.name = "Lanczos" 130 | self.scalers = [UpscalerData("Lanczos", None, self)] 131 | 132 | 133 | class UpscalerNearest(Upscaler): 134 | scalers = [] 135 | 136 | def do_upscale(self, img, selected_model=None): 137 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) 138 | 139 | def load_model(self, _): 140 | pass 141 | 142 | def __init__(self, dirname=None): 143 | super().__init__(False) 144 | self.name = "Nearest" 145 | self.scalers = [UpscalerData("Nearest", None, self)] -------------------------------------------------------------------------------- /modules/textual_inversion/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | import random 10 | import tqdm 11 | from modules import devices, shared 12 | import re 13 | 14 | re_numbers_at_start = re.compile(r"^[-\d]+\s*") 15 | 16 | 17 | class DatasetEntry: 18 | def __init__(self, filename=None, latent=None, filename_text=None): 19 | self.filename = filename 20 | self.latent = latent 21 | self.filename_text = filename_text 22 | self.cond = None 23 | self.cond_text = None 24 | 25 | 26 | class PersonalizedBase(Dataset): 27 | def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): 28 | re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None 
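        # Illustrative example (option values assumed, not necessarily the defaults): with
        # dataset_filename_word_regex = r"\w+" and dataset_filename_join_string = " ", an image named
        # "00012-a cute cat.png" with no matching .txt caption yields the text "a cute cat" further
        # down: re_numbers_at_start strips the leading "00012-", re_word.findall() tokenizes the rest,
        # and the tokens are re-joined with the join string.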
29 | 30 | self.placeholder_token = placeholder_token 31 | 32 | self.batch_size = batch_size 33 | self.width = width 34 | self.height = height 35 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 36 | 37 | self.dataset = [] 38 | 39 | with open(template_file, "r") as file: 40 | lines = [x.strip() for x in file.readlines()] 41 | 42 | self.lines = lines 43 | 44 | assert data_root, 'dataset directory not specified' 45 | assert os.path.isdir(data_root), "Dataset directory doesn't exist" 46 | assert os.listdir(data_root), "Dataset directory is empty" 47 | 48 | cond_model = shared.sd_model.cond_stage_model 49 | 50 | self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] 51 | print("Preparing dataset...") 52 | for path in tqdm.tqdm(self.image_paths): 53 | try: 54 | image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) 55 | except Exception: 56 | continue 57 | 58 | text_filename = os.path.splitext(path)[0] + ".txt" 59 | filename = os.path.basename(path) 60 | 61 | if os.path.exists(text_filename): 62 | with open(text_filename, "r", encoding="utf8") as file: 63 | filename_text = file.read() 64 | else: 65 | filename_text = os.path.splitext(filename)[0] 66 | filename_text = re.sub(re_numbers_at_start, '', filename_text) 67 | if re_word: 68 | tokens = re_word.findall(filename_text) 69 | filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) 70 | 71 | npimage = np.array(image).astype(np.uint8) 72 | npimage = (npimage / 127.5 - 1.0).astype(np.float32) 73 | 74 | torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) 75 | torchdata = torch.moveaxis(torchdata, 2, 0) 76 | 77 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() 78 | init_latent = init_latent.to(devices.cpu) 79 | 80 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) 81 | 82 | if include_cond: 83 | entry.cond_text = self.create_text(filename_text) 84 | entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) 85 | 86 | self.dataset.append(entry) 87 | 88 | assert len(self.dataset) > 0, "No images have been found in the dataset." 
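        # Design note: every entry caches its VAE latent (and, when include_cond is set, the text
        # conditioning) on the CPU during this pass, so the training loop never re-encodes images;
        # `repeats` below only multiplies how often the cached entries are served per epoch.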
89 | self.length = len(self.dataset) * repeats // batch_size 90 | 91 | self.dataset_length = len(self.dataset) 92 | self.indexes = None 93 | self.shuffle() 94 | 95 | def shuffle(self): 96 | self.indexes = np.random.permutation(self.dataset_length) 97 | 98 | def create_text(self, filename_text): 99 | text = random.choice(self.lines) 100 | text = text.replace("[name]", self.placeholder_token) 101 | text = text.replace("[filewords]", filename_text) 102 | return text 103 | 104 | def __len__(self): 105 | return self.length 106 | 107 | def __getitem__(self, i): 108 | res = [] 109 | 110 | for j in range(self.batch_size): 111 | position = i * self.batch_size + j 112 | if position % len(self.indexes) == 0: 113 | self.shuffle() 114 | 115 | index = self.indexes[position % len(self.indexes)] 116 | entry = self.dataset[index] 117 | 118 | if entry.cond is None: 119 | entry.cond_text = self.create_text(entry.filename_text) 120 | 121 | res.append(entry) 122 | 123 | return res 124 | -------------------------------------------------------------------------------- /javascript/localization.js: -------------------------------------------------------------------------------- 1 | 2 | // localization = {} -- the dict with translations is created by the backend 3 | 4 | ignore_ids_for_localization={ 5 | setting_sd_hypernetwork: 'OPTION', 6 | setting_sd_model_checkpoint: 'OPTION', 7 | setting_realesrgan_enabled_models: 'OPTION', 8 | modelmerger_primary_model_name: 'OPTION', 9 | modelmerger_secondary_model_name: 'OPTION', 10 | modelmerger_tertiary_model_name: 'OPTION', 11 | train_embedding: 'OPTION', 12 | train_hypernetwork: 'OPTION', 13 | txt2img_style_index: 'OPTION', 14 | txt2img_style2_index: 'OPTION', 15 | img2img_style_index: 'OPTION', 16 | img2img_style2_index: 'OPTION', 17 | setting_random_artist_categories: 'SPAN', 18 | setting_face_restoration_model: 'SPAN', 19 | setting_realesrgan_enabled_models: 'SPAN', 20 | extras_upscaler_1: 'SPAN', 21 | extras_upscaler_2: 'SPAN', 22 | } 23 | 24 | re_num = /^[\.\d]+$/ 25 | re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u 26 | 27 | original_lines = {} 28 | translated_lines = {} 29 | 30 | function textNodesUnder(el){ 31 | var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); 32 | while(n=walk.nextNode()) a.push(n); 33 | return a; 34 | } 35 | 36 | function canBeTranslated(node, text){ 37 | if(! text) return false; 38 | if(! node.parentElement) return false; 39 | 40 | parentType = node.parentElement.nodeName 41 | if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; 42 | 43 | if (parentType=='OPTION' || parentType=='SPAN'){ 44 | pnode = node 45 | for(var level=0; level<4; level++){ 46 | pnode = pnode.parentElement 47 | if(! pnode) break; 48 | 49 | if(ignore_ids_for_localization[pnode.id] == parentType) return false; 50 | } 51 | } 52 | 53 | if(re_num.test(text)) return false; 54 | if(re_emoji.test(text)) return false; 55 | return true 56 | } 57 | 58 | function getTranslation(text){ 59 | if(! text) return undefined 60 | 61 | if(translated_lines[text] === undefined){ 62 | original_lines[text] = 1 63 | } 64 | 65 | tl = localization[text] 66 | if(tl !== undefined){ 67 | translated_lines[tl] = 1 68 | } 69 | 70 | return tl 71 | } 72 | 73 | function processTextNode(node){ 74 | text = node.textContent.trim() 75 | 76 | if(! 
canBeTranslated(node, text)) return 77 | 78 | tl = getTranslation(text) 79 | if(tl !== undefined){ 80 | node.textContent = tl 81 | } 82 | } 83 | 84 | function processNode(node){ 85 | if(node.nodeType == 3){ 86 | processTextNode(node) 87 | return 88 | } 89 | 90 | if(node.title){ 91 | tl = getTranslation(node.title) 92 | if(tl !== undefined){ 93 | node.title = tl 94 | } 95 | } 96 | 97 | if(node.placeholder){ 98 | tl = getTranslation(node.placeholder) 99 | if(tl !== undefined){ 100 | node.placeholder = tl 101 | } 102 | } 103 | 104 | textNodesUnder(node).forEach(function(node){ 105 | processTextNode(node) 106 | }) 107 | } 108 | 109 | function dumpTranslations(){ 110 | dumped = {} 111 | if (localization.rtl) { 112 | dumped.rtl = true 113 | } 114 | 115 | Object.keys(original_lines).forEach(function(text){ 116 | if(dumped[text] !== undefined) return 117 | 118 | dumped[text] = localization[text] || text 119 | }) 120 | 121 | return dumped 122 | } 123 | 124 | onUiUpdate(function(m){ 125 | m.forEach(function(mutation){ 126 | mutation.addedNodes.forEach(function(node){ 127 | processNode(node) 128 | }) 129 | }); 130 | }) 131 | 132 | 133 | document.addEventListener("DOMContentLoaded", function() { 134 | processNode(gradioApp()) 135 | 136 | if (localization.rtl) { // if the language is from right to left, 137 | (new MutationObserver((mutations, observer) => { // wait for the style to load 138 | mutations.forEach(mutation => { 139 | mutation.addedNodes.forEach(node => { 140 | if (node.tagName === 'STYLE') { 141 | observer.disconnect(); 142 | 143 | for (const x of node.sheet.rules) { // find all rtl media rules 144 | if (Array.from(x.media || []).includes('rtl')) { 145 | x.media.appendMedium('all'); // enable them 146 | } 147 | } 148 | } 149 | }) 150 | }); 151 | })).observe(gradioApp(), { childList: true }); 152 | } 153 | }) 154 | 155 | function download_localization() { 156 | text = JSON.stringify(dumpTranslations(), null, 4) 157 | 158 | var element = document.createElement('a'); 159 | element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); 160 | element.setAttribute('download', "localization.json"); 161 | element.style.display = 'none'; 162 | document.body.appendChild(element); 163 | 164 | element.click(); 165 | 166 | document.body.removeChild(element); 167 | } 168 | -------------------------------------------------------------------------------- /scripts/prompts_from_file.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import random 5 | import sys 6 | import traceback 7 | import shlex 8 | 9 | import modules.scripts as scripts 10 | import gradio as gr 11 | 12 | from modules.processing import Processed, process_images 13 | from PIL import Image 14 | from modules.shared import opts, cmd_opts, state 15 | 16 | 17 | def process_string_tag(tag): 18 | return tag 19 | 20 | 21 | def process_int_tag(tag): 22 | return int(tag) 23 | 24 | 25 | def process_float_tag(tag): 26 | return float(tag) 27 | 28 | 29 | def process_boolean_tag(tag): 30 | return True if (tag == "true") else False 31 | 32 | 33 | prompt_tags = { 34 | "sd_model": None, 35 | "outpath_samples": process_string_tag, 36 | "outpath_grids": process_string_tag, 37 | "prompt_for_display": process_string_tag, 38 | "prompt": process_string_tag, 39 | "negative_prompt": process_string_tag, 40 | "styles": process_string_tag, 41 | "seed": process_int_tag, 42 | "subseed_strength": process_float_tag, 43 | "subseed": process_int_tag, 44 | 
"seed_resize_from_h": process_int_tag, 45 | "seed_resize_from_w": process_int_tag, 46 | "sampler_index": process_int_tag, 47 | "batch_size": process_int_tag, 48 | "n_iter": process_int_tag, 49 | "steps": process_int_tag, 50 | "cfg_scale": process_float_tag, 51 | "width": process_int_tag, 52 | "height": process_int_tag, 53 | "restore_faces": process_boolean_tag, 54 | "tiling": process_boolean_tag, 55 | "do_not_save_samples": process_boolean_tag, 56 | "do_not_save_grid": process_boolean_tag 57 | } 58 | 59 | 60 | def cmdargs(line): 61 | args = shlex.split(line) 62 | pos = 0 63 | res = {} 64 | 65 | while pos < len(args): 66 | arg = args[pos] 67 | 68 | assert arg.startswith("--"), f'must start with "--": {arg}' 69 | tag = arg[2:] 70 | 71 | func = prompt_tags.get(tag, None) 72 | assert func, f'unknown commandline option: {arg}' 73 | 74 | assert pos+1 < len(args), f'missing argument for command line option {arg}' 75 | 76 | val = args[pos+1] 77 | 78 | res[tag] = func(val) 79 | 80 | pos += 2 81 | 82 | return res 83 | 84 | 85 | def load_prompt_file(file): 86 | if file is None: 87 | lines = [] 88 | else: 89 | lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] 90 | 91 | return None, "\n".join(lines), gr.update(lines=7) 92 | 93 | 94 | class Script(scripts.Script): 95 | def title(self): 96 | return "Prompts from file or textbox" 97 | 98 | def ui(self, is_img2img): 99 | checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) 100 | checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False) 101 | 102 | prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) 103 | file = gr.File(label="Upload prompt inputs", type='bytes') 104 | 105 | file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) 106 | 107 | # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. 108 | # We don't shrink back to 1, because that causes the control to ignore [enter], and it may 109 | # be unclear to the user that shift-enter is needed. 
110 | prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) 111 | return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] 112 | 113 | def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): 114 | lines = [x.strip() for x in prompt_txt.splitlines()] 115 | lines = [x for x in lines if len(x) > 0] 116 | 117 | p.do_not_save_grid = True 118 | 119 | job_count = 0 120 | jobs = [] 121 | 122 | for line in lines: 123 | if "--" in line: 124 | try: 125 | args = cmdargs(line) 126 | except Exception: 127 | print(f"Error parsing line [line] as commandline:", file=sys.stderr) 128 | print(traceback.format_exc(), file=sys.stderr) 129 | args = {"prompt": line} 130 | else: 131 | args = {"prompt": line} 132 | 133 | n_iter = args.get("n_iter", 1) 134 | if n_iter != 1: 135 | job_count += n_iter 136 | else: 137 | job_count += 1 138 | 139 | jobs.append(args) 140 | 141 | print(f"Will process {len(lines)} lines in {job_count} jobs.") 142 | if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1: 143 | p.seed = int(random.randrange(4294967294)) 144 | 145 | state.job_count = job_count 146 | 147 | images = [] 148 | for n, args in enumerate(jobs): 149 | state.job = f"{state.job_no + 1} out of {state.job_count}" 150 | 151 | copy_p = copy.copy(p) 152 | for k, v in args.items(): 153 | setattr(copy_p, k, v) 154 | 155 | proc = process_images(copy_p) 156 | images += proc.images 157 | 158 | if checkbox_iterate: 159 | p.seed = p.seed + (p.batch_size * p.n_iter) 160 | 161 | return Processed(p, images, p.seed, "") 162 | -------------------------------------------------------------------------------- /modules/realesrgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from realesrgan import RealESRGANer 9 | 10 | from modules.upscaler import Upscaler, UpscalerData 11 | from modules.shared import cmd_opts, opts 12 | 13 | 14 | class UpscalerRealESRGAN(Upscaler): 15 | def __init__(self, path): 16 | self.name = "RealESRGAN" 17 | self.user_path = path 18 | super().__init__() 19 | try: 20 | from basicsr.archs.rrdbnet_arch import RRDBNet 21 | from realesrgan import RealESRGANer 22 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 23 | self.enable = True 24 | self.scalers = [] 25 | scalers = self.load_models(path) 26 | for scaler in scalers: 27 | if scaler.name in opts.realesrgan_enabled_models: 28 | self.scalers.append(scaler) 29 | 30 | except Exception: 31 | print("Error importing Real-ESRGAN:", file=sys.stderr) 32 | print(traceback.format_exc(), file=sys.stderr) 33 | self.enable = False 34 | self.scalers = [] 35 | 36 | def do_upscale(self, img, path): 37 | if not self.enable: 38 | return img 39 | 40 | info = self.load_model(path) 41 | if not os.path.exists(info.data_path): 42 | print("Unable to load RealESRGAN model: %s" % info.name) 43 | return img 44 | 45 | upsampler = RealESRGANer( 46 | scale=info.scale, 47 | model_path=info.data_path, 48 | model=info.model(), 49 | half=not cmd_opts.no_half, 50 | tile=opts.ESRGAN_tile, 51 | tile_pad=opts.ESRGAN_tile_overlap, 52 | ) 53 | 54 | upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] 55 | 56 | image = Image.fromarray(upsampled) 57 | return image 58 | 59 | def load_model(self, path): 60 | try: 61 | info = None 62 | for scaler in self.scalers: 63 | 
if scaler.data_path == path: 64 | info = scaler 65 | 66 | if info is None: 67 | print(f"Unable to find model info: {path}") 68 | return None 69 | 70 | model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) 71 | info.data_path = model_file 72 | return info 73 | except Exception as e: 74 | print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) 75 | print(traceback.format_exc(), file=sys.stderr) 76 | return None 77 | 78 | def load_models(self, _): 79 | return get_realesrgan_models(self) 80 | 81 | 82 | def get_realesrgan_models(scaler): 83 | try: 84 | from basicsr.archs.rrdbnet_arch import RRDBNet 85 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 86 | models = [ 87 | UpscalerData( 88 | name="R-ESRGAN General 4xV3", 89 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 90 | scale=4, 91 | upscaler=scaler, 92 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 93 | ), 94 | UpscalerData( 95 | name="R-ESRGAN General WDN 4xV3", 96 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", 97 | scale=4, 98 | upscaler=scaler, 99 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 100 | ), 101 | UpscalerData( 102 | name="R-ESRGAN AnimeVideo", 103 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", 104 | scale=4, 105 | upscaler=scaler, 106 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') 107 | ), 108 | UpscalerData( 109 | name="R-ESRGAN 4x+", 110 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 111 | scale=4, 112 | upscaler=scaler, 113 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) 114 | ), 115 | UpscalerData( 116 | name="R-ESRGAN 4x+ Anime6B", 117 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 118 | scale=4, 119 | upscaler=scaler, 120 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 121 | ), 122 | UpscalerData( 123 | name="R-ESRGAN 2x+", 124 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", 125 | scale=2, 126 | upscaler=scaler, 127 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 128 | ), 129 | ] 130 | return models 131 | except Exception as e: 132 | print("Error making Real-ESRGAN models list:", file=sys.stderr) 133 | print(traceback.format_exc(), file=sys.stderr) 134 | -------------------------------------------------------------------------------- /modules/safe.py: -------------------------------------------------------------------------------- 1 | # this code is adapted from the script contributed by anon from /h/ 2 | 3 | import io 4 | import pickle 5 | import collections 6 | import sys 7 | import traceback 8 | 9 | import torch 10 | import numpy 11 | import _codecs 12 | import zipfile 13 | import re 14 | 15 | 16 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage 17 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage 18 | 19 | 20 | def encode(*args): 21 | out = _codecs.encode(*args) 22 | return out 23 | 24 | 25 | 
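# The unpickler below works as an allow-list: persistent_load() hands back an empty TypedStorage for
# every storage reference, and find_class() only resolves the handful of globals that legitimate
# checkpoints need (torch storages and rebuild helpers, collections.OrderedDict, numpy dtype/scalar,
# _codecs.encode, a couple of pytorch_lightning checkpoint callbacks). Everything else raises, which
# is what stops pickle payloads that try to import os/subprocess and similar.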
class RestrictedUnpickler(pickle.Unpickler): 26 | extra_handler = None 27 | 28 | def persistent_load(self, saved_id): 29 | assert saved_id[0] == 'storage' 30 | return TypedStorage() 31 | 32 | def find_class(self, module, name): 33 | if self.extra_handler is not None: 34 | res = self.extra_handler(module, name) 35 | if res is not None: 36 | return res 37 | 38 | if module == 'collections' and name == 'OrderedDict': 39 | return getattr(collections, name) 40 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: 41 | return getattr(torch._utils, name) 42 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']: 43 | return getattr(torch, name) 44 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']: 45 | return getattr(torch.nn.modules.container, name) 46 | if module == 'numpy.core.multiarray' and name == 'scalar': 47 | return numpy.core.multiarray.scalar 48 | if module == 'numpy' and name == 'dtype': 49 | return numpy.dtype 50 | if module == '_codecs' and name == 'encode': 51 | return encode 52 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': 53 | import pytorch_lightning.callbacks 54 | return pytorch_lightning.callbacks.model_checkpoint 55 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': 56 | import pytorch_lightning.callbacks.model_checkpoint 57 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 58 | if module == "__builtin__" and name == 'set': 59 | return set 60 | 61 | # Forbid everything else. 62 | raise Exception(f"global '{module}/{name}' is forbidden") 63 | 64 | 65 | allowed_zip_names = ["archive/data.pkl", "archive/version"] 66 | allowed_zip_names_re = re.compile(r"^archive/data/\d+$") 67 | 68 | 69 | def check_zip_filenames(filename, names): 70 | for name in names: 71 | if name in allowed_zip_names: 72 | continue 73 | if allowed_zip_names_re.match(name): 74 | continue 75 | 76 | raise Exception(f"bad file inside {filename}: {name}") 77 | 78 | 79 | def check_pt(filename, extra_handler): 80 | try: 81 | 82 | # new pytorch format is a zip file 83 | with zipfile.ZipFile(filename) as z: 84 | check_zip_filenames(filename, z.namelist()) 85 | 86 | with z.open('archive/data.pkl') as file: 87 | unpickler = RestrictedUnpickler(file) 88 | unpickler.extra_handler = extra_handler 89 | unpickler.load() 90 | 91 | except zipfile.BadZipfile: 92 | 93 | # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle 94 | with open(filename, "rb") as file: 95 | unpickler = RestrictedUnpickler(file) 96 | unpickler.extra_handler = extra_handler 97 | for i in range(5): 98 | unpickler.load() 99 | 100 | 101 | def load(filename, *args, **kwargs): 102 | return load_with_extra(filename, *args, **kwargs) 103 | 104 | 105 | def load_with_extra(filename, extra_handler=None, *args, **kwargs): 106 | """ 107 | this functon is intended to be used by extensions that want to load models with 108 | some extra classes in them that the usual unpickler would find suspicious. 
109 | 110 | Use the extra_handler argument to specify a function that takes module and field name as text, 111 | and returns that field's value: 112 | 113 | ```python 114 | def extra(module, name): 115 | if module == 'collections' and name == 'OrderedDict': 116 | return collections.OrderedDict 117 | 118 | return None 119 | 120 | safe.load_with_extra('model.pt', extra_handler=extra) 121 | ``` 122 | 123 | The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is 124 | definitely unsafe. 125 | """ 126 | 127 | from modules import shared 128 | 129 | try: 130 | if not shared.cmd_opts.disable_safe_unpickle: 131 | check_pt(filename, extra_handler) 132 | 133 | except pickle.UnpicklingError: 134 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 135 | print(traceback.format_exc(), file=sys.stderr) 136 | print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) 137 | print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) 138 | return None 139 | 140 | except Exception: 141 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 142 | print(traceback.format_exc(), file=sys.stderr) 143 | print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) 144 | print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) 145 | return None 146 | 147 | return unsafe_torch_load(filename, *args, **kwargs) 148 | 149 | 150 | unsafe_torch_load = torch.load 151 | torch.load = load 152 | -------------------------------------------------------------------------------- /modules/img2img.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | import traceback 5 | 6 | import numpy as np 7 | from PIL import Image, ImageOps, ImageChops 8 | 9 | from modules import devices 10 | from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images 11 | from modules.shared import opts, state 12 | import modules.shared as shared 13 | import modules.processing as processing 14 | from modules.ui import plaintext_to_html 15 | import modules.images as images 16 | import modules.scripts 17 | 18 | 19 | def process_batch(p, input_dir, output_dir, args): 20 | processing.fix_seed(p) 21 | 22 | images = shared.listfiles(input_dir) 23 | 24 | print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") 25 | 26 | save_normally = output_dir == '' 27 | 28 | p.do_not_save_grid = True 29 | p.do_not_save_samples = not save_normally 30 | 31 | state.job_count = len(images) * p.n_iter 32 | 33 | for i, image in enumerate(images): 34 | state.job = f"{i+1} out of {len(images)}" 35 | if state.skipped: 36 | state.skipped = False 37 | 38 | if state.interrupted: 39 | break 40 | 41 | img = Image.open(image) 42 | # Use the EXIF orientation of photos taken by smartphones. 
43 | img = ImageOps.exif_transpose(img) 44 | p.init_images = [img] * p.batch_size 45 | 46 | proc = modules.scripts.scripts_img2img.run(p, *args) 47 | if proc is None: 48 | proc = process_images(p) 49 | 50 | for n, processed_image in enumerate(proc.images): 51 | filename = os.path.basename(image) 52 | 53 | if n > 0: 54 | left, right = os.path.splitext(filename) 55 | filename = f"{left}-{n}{right}" 56 | 57 | if not save_normally: 58 | os.makedirs(output_dir, exist_ok=True) 59 | processed_image.save(os.path.join(output_dir, filename)) 60 | 61 | 62 | def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): 63 | is_inpaint = mode == 1 64 | is_batch = mode == 2 65 | 66 | if is_inpaint: 67 | # Drawn mask 68 | if mask_mode == 0: 69 | image = init_img_with_mask['image'] 70 | mask = init_img_with_mask['mask'] 71 | alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') 72 | mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') 73 | image = image.convert('RGB') 74 | # Uploaded mask 75 | else: 76 | image = init_img_inpaint 77 | mask = init_mask_inpaint 78 | # No mask 79 | else: 80 | image = init_img 81 | mask = None 82 | 83 | # Use the EXIF orientation of photos taken by smartphones. 84 | if image is not None: 85 | image = ImageOps.exif_transpose(image) 86 | 87 | assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' 88 | 89 | p = StableDiffusionProcessingImg2Img( 90 | sd_model=shared.sd_model, 91 | outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, 92 | outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, 93 | prompt=prompt, 94 | negative_prompt=negative_prompt, 95 | styles=[prompt_style, prompt_style2], 96 | seed=seed, 97 | subseed=subseed, 98 | subseed_strength=subseed_strength, 99 | seed_resize_from_h=seed_resize_from_h, 100 | seed_resize_from_w=seed_resize_from_w, 101 | seed_enable_extras=seed_enable_extras, 102 | sampler_index=sampler_index, 103 | batch_size=batch_size, 104 | n_iter=n_iter, 105 | steps=steps, 106 | cfg_scale=cfg_scale, 107 | width=width, 108 | height=height, 109 | restore_faces=restore_faces, 110 | tiling=tiling, 111 | init_images=[image], 112 | mask=mask, 113 | mask_blur=mask_blur, 114 | inpainting_fill=inpainting_fill, 115 | resize_mode=resize_mode, 116 | denoising_strength=denoising_strength, 117 | inpaint_full_res=inpaint_full_res, 118 | inpaint_full_res_padding=inpaint_full_res_padding, 119 | inpainting_mask_invert=inpainting_mask_invert, 120 | ) 121 | 122 | p.scripts = modules.scripts.scripts_txt2img 123 | p.script_args = args 124 | 125 | if shared.cmd_opts.enable_console_prompts: 126 | print(f"\nimg2img: {prompt}", file=shared.progress_print_out) 127 | 128 | p.extra_generation_params["Mask blur"] = mask_blur 129 | 130 | if is_batch: 131 | assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" 132 | 133 | process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) 134 | 135 | processed = Processed(p, [], p.seed, "") 136 | else: 137 | processed = modules.scripts.scripts_img2img.run(p, *args) 138 | if processed is None: 139 | processed = process_images(p) 140 | 141 | p.close() 142 | 143 | shared.total_tqdm.clear() 144 | 145 | generation_info_js = processed.js() 146 | if opts.samples_log_stdout: 147 | print(generation_info_js) 148 | 149 | if opts.do_not_show_images: 150 | processed.images = [] 151 | 152 | return processed.images, generation_info_js, plaintext_to_html(processed.info) 153 | -------------------------------------------------------------------------------- /scripts/poor_mans_outpainting.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image, ImageDraw 6 | 7 | from modules import images, processing, devices 8 | from modules.processing import Processed, process_images 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | 13 | class Script(scripts.Script): 14 | def title(self): 15 | return "Poor man's outpainting" 16 | 17 | def show(self, is_img2img): 18 | return is_img2img 19 | 20 | def ui(self, is_img2img): 21 | if not is_img2img: 22 | return None 23 | 24 | pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) 25 | mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) 26 | inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") 27 | direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) 28 | 29 | return [pixels, mask_blur, inpainting_fill, direction] 30 | 31 | def run(self, p, pixels, mask_blur, inpainting_fill, direction): 
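        # Rough flow of this script: pad the init image by `pixels` in the chosen directions on a
        # canvas rounded up to multiples of 64, build a mask that is white over the new border and
        # black over the original image (plus a tighter latent mask), split image and masks into
        # p.width x p.height tiles with `pixels` of overlap, run img2img only on the tiles that touch
        # the new border, and stitch the grid back together at the end.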
32 | initial_seed = None 33 | initial_info = None 34 | 35 | p.mask_blur = mask_blur * 2 36 | p.inpainting_fill = inpainting_fill 37 | p.inpaint_full_res = False 38 | 39 | left = pixels if "left" in direction else 0 40 | right = pixels if "right" in direction else 0 41 | up = pixels if "up" in direction else 0 42 | down = pixels if "down" in direction else 0 43 | 44 | init_img = p.init_images[0] 45 | target_w = math.ceil((init_img.width + left + right) / 64) * 64 46 | target_h = math.ceil((init_img.height + up + down) / 64) * 64 47 | 48 | if left > 0: 49 | left = left * (target_w - init_img.width) // (left + right) 50 | if right > 0: 51 | right = target_w - init_img.width - left 52 | 53 | if up > 0: 54 | up = up * (target_h - init_img.height) // (up + down) 55 | 56 | if down > 0: 57 | down = target_h - init_img.height - up 58 | 59 | img = Image.new("RGB", (target_w, target_h)) 60 | img.paste(init_img, (left, up)) 61 | 62 | mask = Image.new("L", (img.width, img.height), "white") 63 | draw = ImageDraw.Draw(mask) 64 | draw.rectangle(( 65 | left + (mask_blur * 2 if left > 0 else 0), 66 | up + (mask_blur * 2 if up > 0 else 0), 67 | mask.width - right - (mask_blur * 2 if right > 0 else 0), 68 | mask.height - down - (mask_blur * 2 if down > 0 else 0) 69 | ), fill="black") 70 | 71 | latent_mask = Image.new("L", (img.width, img.height), "white") 72 | latent_draw = ImageDraw.Draw(latent_mask) 73 | latent_draw.rectangle(( 74 | left + (mask_blur//2 if left > 0 else 0), 75 | up + (mask_blur//2 if up > 0 else 0), 76 | mask.width - right - (mask_blur//2 if right > 0 else 0), 77 | mask.height - down - (mask_blur//2 if down > 0 else 0) 78 | ), fill="black") 79 | 80 | devices.torch_gc() 81 | 82 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) 83 | grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 84 | grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 85 | 86 | p.n_iter = 1 87 | p.batch_size = 1 88 | p.do_not_save_grid = True 89 | p.do_not_save_samples = True 90 | 91 | work = [] 92 | work_mask = [] 93 | work_latent_mask = [] 94 | work_results = [] 95 | 96 | for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): 97 | for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): 98 | x, w = tiledata[0:2] 99 | 100 | if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 101 | continue 102 | 103 | work.append(tiledata[2]) 104 | work_mask.append(tiledata_mask[2]) 105 | work_latent_mask.append(tiledata_latent_mask[2]) 106 | 107 | batch_count = len(work) 108 | print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") 109 | 110 | state.job_count = batch_count 111 | 112 | for i in range(batch_count): 113 | p.init_images = [work[i]] 114 | p.image_mask = work_mask[i] 115 | p.latent_mask = work_latent_mask[i] 116 | 117 | state.job = f"Batch {i + 1} out of {batch_count}" 118 | processed = process_images(p) 119 | 120 | if initial_seed is None: 121 | initial_seed = processed.seed 122 | initial_info = processed.info 123 | 124 | p.seed = processed.seed + 1 125 | work_results += processed.images 126 | 127 | 128 | image_index = 0 129 | for y, h, row in grid.tiles: 130 | for tiledata in row: 131 | x, w = tiledata[0:2] 132 | 133 | if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 134 | continue 
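                # Tiles fully inside the original image were skipped above and keep their content;
                # every other tile is overwritten with the corresponding processed result, consumed in
                # the same order the work list was built (falling back to a blank tile if a result is
                # missing).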
135 | 136 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 137 | image_index += 1 138 | 139 | combined_image = images.combine_grid(grid) 140 | 141 | if opts.samples_save: 142 | images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p) 143 | 144 | processed = Processed(p, [combined_image], initial_seed, initial_info) 145 | 146 | return processed 147 | 148 | -------------------------------------------------------------------------------- /javascript/contextMenus.js: -------------------------------------------------------------------------------- 1 | 2 | contextMenuInit = function(){ 3 | let eventListenerApplied=false; 4 | let menuSpecs = new Map(); 5 | 6 | const uid = function(){ 7 | return Date.now().toString(36) + Math.random().toString(36).substr(2); 8 | } 9 | 10 | function showContextMenu(event,element,menuEntries){ 11 | let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; 12 | let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; 13 | 14 | let oldMenu = gradioApp().querySelector('#context-menu') 15 | if(oldMenu){ 16 | oldMenu.remove() 17 | } 18 | 19 | let tabButton = uiCurrentTab 20 | let baseStyle = window.getComputedStyle(tabButton) 21 | 22 | const contextMenu = document.createElement('nav') 23 | contextMenu.id = "context-menu" 24 | contextMenu.style.background = baseStyle.background 25 | contextMenu.style.color = baseStyle.color 26 | contextMenu.style.fontFamily = baseStyle.fontFamily 27 | contextMenu.style.top = posy+'px' 28 | contextMenu.style.left = posx+'px' 29 | 30 | 31 | 32 | const contextMenuList = document.createElement('ul') 33 | contextMenuList.className = 'context-menu-items'; 34 | contextMenu.append(contextMenuList); 35 | 36 | menuEntries.forEach(function(entry){ 37 | let contextMenuEntry = document.createElement('a') 38 | contextMenuEntry.innerHTML = entry['name'] 39 | contextMenuEntry.addEventListener("click", function(e) { 40 | entry['func'](); 41 | }) 42 | contextMenuList.append(contextMenuEntry); 43 | 44 | }) 45 | 46 | gradioApp().getRootNode().appendChild(contextMenu) 47 | 48 | let menuWidth = contextMenu.offsetWidth + 4; 49 | let menuHeight = contextMenu.offsetHeight + 4; 50 | 51 | let windowWidth = window.innerWidth; 52 | let windowHeight = window.innerHeight; 53 | 54 | if ( (windowWidth - posx) < menuWidth ) { 55 | contextMenu.style.left = windowWidth - menuWidth + "px"; 56 | } 57 | 58 | if ( (windowHeight - posy) < menuHeight ) { 59 | contextMenu.style.top = windowHeight - menuHeight + "px"; 60 | } 61 | 62 | } 63 | 64 | function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ 65 | 66 | currentItems = menuSpecs.get(targetEmementSelector) 67 | 68 | if(!currentItems){ 69 | currentItems = [] 70 | menuSpecs.set(targetEmementSelector,currentItems); 71 | } 72 | let newItem = {'id':targetEmementSelector+'_'+uid(), 73 | 'name':entryName, 74 | 'func':entryFunction, 75 | 'isNew':true} 76 | 77 | currentItems.push(newItem) 78 | return newItem['id'] 79 | } 80 | 81 | function removeContextMenuOption(uid){ 82 | menuSpecs.forEach(function(v,k) { 83 | let index = -1 84 | v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) 85 | if(index>=0){ 86 | v.splice(index, 1); 87 | } 88 | }) 89 | } 90 | 91 | function addContextMenuEventListener(){ 92 | if(eventListenerApplied){ 93 | return; 94 | } 95 | gradioApp().addEventListener("click", function(e) { 96 | 
let source = e.composedPath()[0] 97 | if(source.id && source.id.indexOf('check_progress')>-1){ 98 | return 99 | } 100 | 101 | let oldMenu = gradioApp().querySelector('#context-menu') 102 | if(oldMenu){ 103 | oldMenu.remove() 104 | } 105 | }); 106 | gradioApp().addEventListener("contextmenu", function(e) { 107 | let oldMenu = gradioApp().querySelector('#context-menu') 108 | if(oldMenu){ 109 | oldMenu.remove() 110 | } 111 | menuSpecs.forEach(function(v,k) { 112 | if(e.composedPath()[0].matches(k)){ 113 | showContextMenu(e,e.composedPath()[0],v) 114 | e.preventDefault() 115 | return 116 | } 117 | }) 118 | }); 119 | eventListenerApplied=true 120 | 121 | } 122 | 123 | return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] 124 | } 125 | 126 | initResponse = contextMenuInit(); 127 | appendContextMenuOption = initResponse[0]; 128 | removeContextMenuOption = initResponse[1]; 129 | addContextMenuEventListener = initResponse[2]; 130 | 131 | (function(){ 132 | //Start example Context Menu Items 133 | let generateOnRepeat = function(genbuttonid,interruptbuttonid){ 134 | let genbutton = gradioApp().querySelector(genbuttonid); 135 | let interruptbutton = gradioApp().querySelector(interruptbuttonid); 136 | if(!interruptbutton.offsetParent){ 137 | genbutton.click(); 138 | } 139 | clearInterval(window.generateOnRepeatInterval) 140 | window.generateOnRepeatInterval = setInterval(function(){ 141 | if(!interruptbutton.offsetParent){ 142 | genbutton.click(); 143 | } 144 | }, 145 | 500) 146 | } 147 | 148 | appendContextMenuOption('#txt2img_generate','Generate forever',function(){ 149 | generateOnRepeat('#txt2img_generate','#txt2img_interrupt'); 150 | }) 151 | appendContextMenuOption('#img2img_generate','Generate forever',function(){ 152 | generateOnRepeat('#img2img_generate','#img2img_interrupt'); 153 | }) 154 | 155 | let cancelGenerateForever = function(){ 156 | clearInterval(window.generateOnRepeatInterval) 157 | } 158 | 159 | appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) 160 | appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) 161 | appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) 162 | appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) 163 | 164 | appendContextMenuOption('#roll','Roll three', 165 | function(){ 166 | let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); 167 | setTimeout(function(){rollbutton.click()},100) 168 | setTimeout(function(){rollbutton.click()},200) 169 | setTimeout(function(){rollbutton.click()},300) 170 | } 171 | ) 172 | })(); 173 | //End example Context Menu Items 174 | 175 | onUiUpdate(function(){ 176 | addContextMenuEventListener() 177 | }); 178 | -------------------------------------------------------------------------------- /modules/swinir_model.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from tqdm import tqdm 9 | 10 | from modules import modelloader, devices 11 | from modules.shared import cmd_opts, opts 12 | from modules.swinir_model_arch import SwinIR as net 13 | from modules.swinir_model_arch_v2 import Swin2SR as net2 14 | from modules.upscaler import Upscaler, UpscalerData 15 | 16 | precision_scope = ( 17 | torch.autocast if 
cmd_opts.precision == "autocast" else contextlib.nullcontext 18 | ) 19 | 20 | 21 | class UpscalerSwinIR(Upscaler): 22 | def __init__(self, dirname): 23 | self.name = "SwinIR" 24 | self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ 25 | "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ 26 | "-L_x4_GAN.pth " 27 | self.model_name = "SwinIR 4x" 28 | self.user_path = dirname 29 | super().__init__() 30 | scalers = [] 31 | model_files = self.find_models(ext_filter=[".pt", ".pth"]) 32 | for model in model_files: 33 | if "http" in model: 34 | name = self.model_name 35 | else: 36 | name = modelloader.friendly_name(model) 37 | model_data = UpscalerData(name, model, self) 38 | scalers.append(model_data) 39 | self.scalers = scalers 40 | 41 | def do_upscale(self, img, model_file): 42 | model = self.load_model(model_file) 43 | if model is None: 44 | return img 45 | model = model.to(devices.device_swinir) 46 | img = upscale(img, model) 47 | try: 48 | torch.cuda.empty_cache() 49 | except: 50 | pass 51 | return img 52 | 53 | def load_model(self, path, scale=4): 54 | if "http" in path: 55 | dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") 56 | filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) 57 | else: 58 | filename = path 59 | if filename is None or not os.path.exists(filename): 60 | return None 61 | if filename.endswith(".v2.pth"): 62 | model = net2( 63 | upscale=scale, 64 | in_chans=3, 65 | img_size=64, 66 | window_size=8, 67 | img_range=1.0, 68 | depths=[6, 6, 6, 6, 6, 6], 69 | embed_dim=180, 70 | num_heads=[6, 6, 6, 6, 6, 6], 71 | mlp_ratio=2, 72 | upsampler="nearest+conv", 73 | resi_connection="1conv", 74 | ) 75 | params = None 76 | else: 77 | model = net( 78 | upscale=scale, 79 | in_chans=3, 80 | img_size=64, 81 | window_size=8, 82 | img_range=1.0, 83 | depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], 84 | embed_dim=240, 85 | num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], 86 | mlp_ratio=2, 87 | upsampler="nearest+conv", 88 | resi_connection="3conv", 89 | ) 90 | params = "params_ema" 91 | 92 | pretrained_model = torch.load(filename) 93 | if params is not None: 94 | model.load_state_dict(pretrained_model[params], strict=True) 95 | else: 96 | model.load_state_dict(pretrained_model, strict=True) 97 | if not cmd_opts.no_half: 98 | model = model.half() 99 | return model 100 | 101 | 102 | def upscale( 103 | img, 104 | model, 105 | tile=opts.SWIN_tile, 106 | tile_overlap=opts.SWIN_tile_overlap, 107 | window_size=8, 108 | scale=4, 109 | ): 110 | img = np.array(img) 111 | img = img[:, :, ::-1] 112 | img = np.moveaxis(img, 2, 0) / 255 113 | img = torch.from_numpy(img).float() 114 | img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) 115 | with torch.no_grad(), precision_scope("cuda"): 116 | _, _, h_old, w_old = img.size() 117 | h_pad = (h_old // window_size + 1) * window_size - h_old 118 | w_pad = (w_old // window_size + 1) * window_size - w_old 119 | img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] 120 | img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] 121 | output = inference(img, model, tile, tile_overlap, window_size, scale) 122 | output = output[..., : h_old * scale, : w_old * scale] 123 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() 124 | if output.ndim == 3: 125 | output = np.transpose( 126 | output[[2, 1, 0], :, :], (1, 2, 0) 127 | ) # CHW-RGB to HCW-BGR 128 | output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 129 | return 
Image.fromarray(output, "RGB") 130 | 131 | 132 | def inference(img, model, tile, tile_overlap, window_size, scale): 133 | # test the image tile by tile 134 | b, c, h, w = img.size() 135 | tile = min(tile, h, w) 136 | assert tile % window_size == 0, "tile size should be a multiple of window_size" 137 | sf = scale 138 | 139 | stride = tile - tile_overlap 140 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile] 141 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile] 142 | E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) 143 | W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) 144 | 145 | with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: 146 | for h_idx in h_idx_list: 147 | for w_idx in w_idx_list: 148 | in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] 149 | out_patch = model(in_patch) 150 | out_patch_mask = torch.ones_like(out_patch) 151 | 152 | E[ 153 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 154 | ].add_(out_patch) 155 | W[ 156 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 157 | ].add_(out_patch_mask) 158 | pbar.update(1) 159 | output = E.div_(W) 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /modules/codeformer_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import cv2 6 | import torch 7 | 8 | import modules.face_restoration 9 | import modules.shared 10 | from modules import shared, devices, modelloader 11 | from modules.paths import script_path, models_path 12 | 13 | # codeformer people made a choice to include modified basicsr library to their project which makes 14 | # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. 15 | # I am making a choice to include some files from codeformer to work around this issue. 
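# (The vendored pieces are imported further down as modules.codeformer.codeformer_arch inside
# setup_model(), so the upstream codeformer package itself never needs to be importable here.)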
16 | model_dir = "Codeformer" 17 | model_path = os.path.join(models_path, model_dir) 18 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 19 | 20 | have_codeformer = False 21 | codeformer = None 22 | 23 | 24 | def setup_model(dirname): 25 | global model_path 26 | if not os.path.exists(model_path): 27 | os.makedirs(model_path) 28 | 29 | path = modules.paths.paths.get("CodeFormer", None) 30 | if path is None: 31 | return 32 | 33 | try: 34 | from torchvision.transforms.functional import normalize 35 | from modules.codeformer.codeformer_arch import CodeFormer 36 | from basicsr.utils.download_util import load_file_from_url 37 | from basicsr.utils import imwrite, img2tensor, tensor2img 38 | from facelib.utils.face_restoration_helper import FaceRestoreHelper 39 | from modules.shared import cmd_opts 40 | 41 | net_class = CodeFormer 42 | 43 | class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): 44 | def name(self): 45 | return "CodeFormer" 46 | 47 | def __init__(self, dirname): 48 | self.net = None 49 | self.face_helper = None 50 | self.cmd_dir = dirname 51 | 52 | def create_models(self): 53 | 54 | if self.net is not None and self.face_helper is not None: 55 | self.net.to(devices.device_codeformer) 56 | return self.net, self.face_helper 57 | model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth') 58 | if len(model_paths) != 0: 59 | ckpt_path = model_paths[0] 60 | else: 61 | print("Unable to load codeformer model.") 62 | return None, None 63 | net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) 64 | checkpoint = torch.load(ckpt_path)['params_ema'] 65 | net.load_state_dict(checkpoint) 66 | net.eval() 67 | 68 | face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) 69 | 70 | self.net = net 71 | self.face_helper = face_helper 72 | 73 | return net, face_helper 74 | 75 | def send_model_to(self, device): 76 | self.net.to(device) 77 | self.face_helper.face_det.to(device) 78 | self.face_helper.face_parse.to(device) 79 | 80 | def restore(self, np_image, w=None): 81 | np_image = np_image[:, :, ::-1] 82 | 83 | original_resolution = np_image.shape[0:2] 84 | 85 | self.create_models() 86 | if self.net is None or self.face_helper is None: 87 | return np_image 88 | 89 | self.send_model_to(devices.device_codeformer) 90 | 91 | self.face_helper.clean_all() 92 | self.face_helper.read_image(np_image) 93 | self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) 94 | self.face_helper.align_warp_face() 95 | 96 | for idx, cropped_face in enumerate(self.face_helper.cropped_faces): 97 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) 98 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 99 | cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) 100 | 101 | try: 102 | with torch.no_grad(): 103 | output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] 104 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) 105 | del output 106 | torch.cuda.empty_cache() 107 | except Exception as error: 108 | print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) 109 | restored_face = tensor2img(cropped_face_t, rgb2bgr=True, 
min_max=(-1, 1)) 110 | 111 | restored_face = restored_face.astype('uint8') 112 | self.face_helper.add_restored_face(restored_face) 113 | 114 | self.face_helper.get_inverse_affine(None) 115 | 116 | restored_img = self.face_helper.paste_faces_to_input_image() 117 | restored_img = restored_img[:, :, ::-1] 118 | 119 | if original_resolution != restored_img.shape[0:2]: 120 | restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) 121 | 122 | self.face_helper.clean_all() 123 | 124 | if shared.opts.face_restoration_unload: 125 | self.send_model_to(devices.cpu) 126 | 127 | return restored_img 128 | 129 | global have_codeformer 130 | have_codeformer = True 131 | 132 | global codeformer 133 | codeformer = FaceRestorerCodeFormer(dirname) 134 | shared.face_restorers.append(codeformer) 135 | 136 | except Exception: 137 | print("Error setting up CodeFormer:", file=sys.stderr) 138 | print(traceback.format_exc(), file=sys.stderr) 139 | 140 | # sys.path = stored_sys_path 141 | -------------------------------------------------------------------------------- /modules/modelloader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import importlib 5 | from urllib.parse import urlparse 6 | 7 | from basicsr.utils.download_util import load_file_from_url 8 | from modules import shared 9 | from modules.upscaler import Upscaler 10 | from modules.paths import script_path, models_path 11 | 12 | 13 | def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: 14 | """ 15 | A one-and done loader to try finding the desired models in specified directories. 16 | 17 | @param download_name: Specify to download from model_url immediately. 18 | @param model_url: If no other models are found, this will be downloaded on upscale. 19 | @param model_path: The location to store/find models in. 20 | @param command_path: A command-line argument to search for models in first. 
21 | @param ext_filter: An optional list of filename extensions to filter by 22 | @return: A list of paths containing the desired model(s) 23 | """ 24 | output = [] 25 | 26 | if ext_filter is None: 27 | ext_filter = [] 28 | 29 | try: 30 | places = [] 31 | 32 | if command_path is not None and command_path != model_path: 33 | pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') 34 | if os.path.exists(pretrained_path): 35 | print(f"Appending path: {pretrained_path}") 36 | places.append(pretrained_path) 37 | elif os.path.exists(command_path): 38 | places.append(command_path) 39 | 40 | places.append(model_path) 41 | 42 | for place in places: 43 | if os.path.exists(place): 44 | for file in glob.iglob(place + '**/**', recursive=True): 45 | full_path = file 46 | if os.path.isdir(full_path): 47 | continue 48 | if len(ext_filter) != 0: 49 | model_name, extension = os.path.splitext(file) 50 | if extension not in ext_filter: 51 | continue 52 | if file not in output: 53 | output.append(full_path) 54 | 55 | if model_url is not None and len(output) == 0: 56 | if download_name is not None: 57 | dl = load_file_from_url(model_url, model_path, True, download_name) 58 | output.append(dl) 59 | else: 60 | output.append(model_url) 61 | 62 | except Exception: 63 | pass 64 | 65 | return output 66 | 67 | 68 | def friendly_name(file: str): 69 | if "http" in file: 70 | file = urlparse(file).path 71 | 72 | file = os.path.basename(file) 73 | model_name, extension = os.path.splitext(file) 74 | return model_name 75 | 76 | 77 | def cleanup_models(): 78 | # This code could probably be more efficient if we used a tuple list or something to store the src/destinations 79 | # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler 80 | # somehow auto-register and just do these things... 
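The comment above suggests driving these moves from a tuple list instead of repeated calls; a minimal sketch of that refactor follows. It reuses the same source/destination pairs as cleanup_models() and takes the move_files helper as a parameter only so the snippet stays self-contained; the name cleanup_models_table is illustrative, not part of the module.

import os

def cleanup_models_table(root_path, models_path, move_files):
    # (source folder, destination folder, optional extension filter) -- same pairs as cleanup_models()
    moves = [
        (models_path,                         os.path.join(models_path, "Stable-diffusion"), ".ckpt"),
        (os.path.join(root_path, "ESRGAN"),   os.path.join(models_path, "ESRGAN"),           None),
        (os.path.join(models_path, "BSRGAN"), os.path.join(models_path, "ESRGAN"),           ".pth"),
        (os.path.join(root_path, "gfpgan"),   os.path.join(models_path, "GFPGAN"),           None),
        (os.path.join(root_path, "SwinIR"),   os.path.join(models_path, "SwinIR"),           None),
        (os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models"),
         os.path.join(models_path, "LDSR"),   None),
    ]
    for src_path, dest_path, ext_filter in moves:
        move_files(src_path, dest_path, ext_filter)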
81 | root_path = script_path 82 | src_path = models_path 83 | dest_path = os.path.join(models_path, "Stable-diffusion") 84 | move_files(src_path, dest_path, ".ckpt") 85 | src_path = os.path.join(root_path, "ESRGAN") 86 | dest_path = os.path.join(models_path, "ESRGAN") 87 | move_files(src_path, dest_path) 88 | src_path = os.path.join(models_path, "BSRGAN") 89 | dest_path = os.path.join(models_path, "ESRGAN") 90 | move_files(src_path, dest_path, ".pth") 91 | src_path = os.path.join(root_path, "gfpgan") 92 | dest_path = os.path.join(models_path, "GFPGAN") 93 | move_files(src_path, dest_path) 94 | src_path = os.path.join(root_path, "SwinIR") 95 | dest_path = os.path.join(models_path, "SwinIR") 96 | move_files(src_path, dest_path) 97 | src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") 98 | dest_path = os.path.join(models_path, "LDSR") 99 | move_files(src_path, dest_path) 100 | 101 | 102 | def move_files(src_path: str, dest_path: str, ext_filter: str = None): 103 | try: 104 | if not os.path.exists(dest_path): 105 | os.makedirs(dest_path) 106 | if os.path.exists(src_path): 107 | for file in os.listdir(src_path): 108 | fullpath = os.path.join(src_path, file) 109 | if os.path.isfile(fullpath): 110 | if ext_filter is not None: 111 | if ext_filter not in file: 112 | continue 113 | print(f"Moving {file} from {src_path} to {dest_path}.") 114 | try: 115 | shutil.move(fullpath, dest_path) 116 | except: 117 | pass 118 | if len(os.listdir(src_path)) == 0: 119 | print(f"Removing empty folder: {src_path}") 120 | shutil.rmtree(src_path, True) 121 | except: 122 | pass 123 | 124 | 125 | def load_upscalers(): 126 | sd = shared.script_path 127 | # We can only do this 'magic' method to dynamically load upscalers if they are referenced, 128 | # so we'll try to import any _model.py files before looking in __subclasses__ 129 | modules_dir = os.path.join(sd, "modules") 130 | for file in os.listdir(modules_dir): 131 | if "_model.py" in file: 132 | model_name = file.replace("_model.py", "") 133 | full_model = f"modules.{model_name}_model" 134 | try: 135 | importlib.import_module(full_model) 136 | except: 137 | pass 138 | datas = [] 139 | c_o = vars(shared.cmd_opts) 140 | for cls in Upscaler.__subclasses__(): 141 | name = cls.__name__ 142 | module_name = cls.__module__ 143 | module = importlib.import_module(module_name) 144 | class_ = getattr(module, name) 145 | cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" 146 | opt_string = None 147 | try: 148 | if cmd_name in c_o: 149 | opt_string = c_o[cmd_name] 150 | except: 151 | pass 152 | scaler = class_(opt_string) 153 | for child in scaler.scalers: 154 | datas.append(child) 155 | 156 | shared.sd_upscalers = datas 157 | -------------------------------------------------------------------------------- /javascript/progressbar.js: -------------------------------------------------------------------------------- 1 | // code related to showing and updating progressbar shown as the image is being made 2 | global_progressbars = {} 3 | galleries = {} 4 | galleryObservers = {} 5 | 6 | // this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running 7 | timeoutIds = {} 8 | 9 | function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ 10 | // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id 11 | // every time you use 
gr.HTML(elem_id='xxx'), so we handle this here 12 | var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar) 13 | var progressbarParent 14 | if(progressbar){ 15 | progressbarParent = gradioApp().querySelector("#"+id_progressbar) 16 | } else{ 17 | progressbar = gradioApp().getElementById(id_progressbar) 18 | progressbarParent = null 19 | } 20 | 21 | var skip = id_skip ? gradioApp().getElementById(id_skip) : null 22 | var interrupt = gradioApp().getElementById(id_interrupt) 23 | 24 | if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ 25 | if(progressbar.innerText){ 26 | let newtitle = 'Stable Diffusion - ' + progressbar.innerText 27 | if(document.title != newtitle){ 28 | document.title = newtitle; 29 | } 30 | }else{ 31 | let newtitle = 'Stable Diffusion' 32 | if(document.title != newtitle){ 33 | document.title = newtitle; 34 | } 35 | } 36 | } 37 | 38 | if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ 39 | global_progressbars[id_progressbar] = progressbar 40 | 41 | var mutationObserver = new MutationObserver(function(m){ 42 | if(timeoutIds[id_part]) return; 43 | 44 | preview = gradioApp().getElementById(id_preview) 45 | gallery = gradioApp().getElementById(id_gallery) 46 | 47 | if(preview != null && gallery != null){ 48 | preview.style.width = gallery.clientWidth + "px" 49 | preview.style.height = gallery.clientHeight + "px" 50 | if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px" 51 | 52 | //only watch gallery if there is a generation process going on 53 | check_gallery(id_gallery); 54 | 55 | var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; 56 | if(progressDiv){ 57 | timeoutIds[id_part] = window.setTimeout(function() { 58 | timeoutIds[id_part] = null 59 | requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) 60 | }, 500) 61 | } else{ 62 | if (skip) { 63 | skip.style.display = "none" 64 | } 65 | interrupt.style.display = "none" 66 | 67 | //disconnect observer once generation finished, so user can close selected image if they want 68 | if (galleryObservers[id_gallery]) { 69 | galleryObservers[id_gallery].disconnect(); 70 | galleries[id_gallery] = null; 71 | } 72 | } 73 | } 74 | 75 | }); 76 | mutationObserver.observe( progressbar, { childList:true, subtree:true }) 77 | } 78 | } 79 | 80 | function check_gallery(id_gallery){ 81 | let gallery = gradioApp().getElementById(id_gallery) 82 | // if gallery has no change, no need to setting up observer again. 
83 | if (gallery && galleries[id_gallery] !== gallery){ 84 | galleries[id_gallery] = gallery; 85 | if(galleryObservers[id_gallery]){ 86 | galleryObservers[id_gallery].disconnect(); 87 | } 88 | let prevSelectedIndex = selected_gallery_index(); 89 | galleryObservers[id_gallery] = new MutationObserver(function (){ 90 | let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') 91 | let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') 92 | if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { 93 | // automatically re-open previously selected index (if exists) 94 | activeElement = gradioApp().activeElement; 95 | 96 | galleryButtons[prevSelectedIndex].click(); 97 | showGalleryImage(); 98 | 99 | if(activeElement){ 100 | // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it 101 | // if somenoe has a better solution please by all means 102 | setTimeout(function() { activeElement.focus() }, 1); 103 | } 104 | } 105 | }) 106 | galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) 107 | } 108 | } 109 | 110 | onUiUpdate(function(){ 111 | check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') 112 | check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') 113 | check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') 114 | }) 115 | 116 | function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ 117 | btn = gradioApp().getElementById(id_part+"_check_progress"); 118 | if(btn==null) return; 119 | 120 | btn.click(); 121 | var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; 122 | var skip = id_skip ? 
gradioApp().getElementById(id_skip) : null 123 | var interrupt = gradioApp().getElementById(id_interrupt) 124 | if(progressDiv && interrupt){ 125 | if (skip) { 126 | skip.style.display = "block" 127 | } 128 | interrupt.style.display = "block" 129 | } 130 | } 131 | 132 | function requestProgress(id_part){ 133 | btn = gradioApp().getElementById(id_part+"_check_progress_initial"); 134 | if(btn==null) return; 135 | 136 | btn.click(); 137 | } 138 | -------------------------------------------------------------------------------- /javascript/ui.js: -------------------------------------------------------------------------------- 1 | // various functions for interation with ui.py not large enough to warrant putting them in separate files 2 | 3 | function set_theme(theme){ 4 | gradioURL = window.location.href 5 | if (!gradioURL.includes('?__theme=')) { 6 | window.location.replace(gradioURL + '?__theme=' + theme); 7 | } 8 | } 9 | 10 | function selected_gallery_index(){ 11 | var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') 12 | var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') 13 | 14 | var result = -1 15 | buttons.forEach(function(v, i){ if(v==button) { result = i } }) 16 | 17 | return result 18 | } 19 | 20 | function extract_image_from_gallery(gallery){ 21 | if(gallery.length == 1){ 22 | return gallery[0] 23 | } 24 | 25 | index = selected_gallery_index() 26 | 27 | if (index < 0 || index >= gallery.length){ 28 | return [null] 29 | } 30 | 31 | return gallery[index]; 32 | } 33 | 34 | function args_to_array(args){ 35 | res = [] 36 | for(var i=0;i label > textarea"); 176 | txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); 177 | } 178 | if (!img2img_textarea) { 179 | img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); 180 | img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); 181 | } 182 | }) 183 | 184 | let txt2img_textarea, img2img_textarea = undefined; 185 | let wait_time = 800 186 | let token_timeout; 187 | 188 | function update_txt2img_tokens(...args) { 189 | update_token_counter("txt2img_token_button") 190 | if (args.length == 2) 191 | return args[0] 192 | return args; 193 | } 194 | 195 | function update_img2img_tokens(...args) { 196 | update_token_counter("img2img_token_button") 197 | if (args.length == 2) 198 | return args[0] 199 | return args; 200 | } 201 | 202 | function update_token_counter(button_id) { 203 | if (token_timeout) 204 | clearTimeout(token_timeout); 205 | token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); 206 | } 207 | 208 | function restart_reload(){ 209 | document.body.innerHTML='

<h1 style="font-family:monospace; margin-top:20%; color:lightgray; text-align:center;">Reloading...</h1>
'; 210 | setTimeout(function(){location.reload()},2000) 211 | 212 | return [] 213 | } 214 | -------------------------------------------------------------------------------- /modules/deepbooru.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from concurrent.futures import ProcessPoolExecutor 3 | import multiprocessing 4 | import time 5 | import re 6 | 7 | re_special = re.compile(r'([\\()])') 8 | 9 | def get_deepbooru_tags(pil_image): 10 | """ 11 | This method is for running only one image at a time for simple use. Used to the img2img interrogate. 12 | """ 13 | from modules import shared # prevents circular reference 14 | 15 | try: 16 | create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) 17 | return get_tags_from_process(pil_image) 18 | finally: 19 | release_process() 20 | 21 | 22 | OPT_INCLUDE_RANKS = "include_ranks" 23 | def create_deepbooru_opts(): 24 | from modules import shared 25 | 26 | return { 27 | "use_spaces": shared.opts.deepbooru_use_spaces, 28 | "use_escape": shared.opts.deepbooru_escape, 29 | "alpha_sort": shared.opts.deepbooru_sort_alpha, 30 | OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks, 31 | } 32 | 33 | 34 | def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): 35 | model, tags = get_deepbooru_tags_model() 36 | while True: # while process is running, keep monitoring queue for new image 37 | pil_image = queue.get() 38 | if pil_image == "QUIT": 39 | break 40 | else: 41 | deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) 42 | 43 | 44 | def create_deepbooru_process(threshold, deepbooru_opts): 45 | """ 46 | Creates deepbooru process. A queue is created to send images into the process. This enables multiple images 47 | to be processed in a row without reloading the model or creating a new process. To return the data, a shared 48 | dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned 49 | to the dictionary and the method adding the image to the queue should wait for this value to be updated with 50 | the tags. 
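The docstring above describes the worker lifecycle: the spawned process and its queue are created once so several images can be tagged without reloading the model, with -1 in the shared dict acting as the "not ready yet" sentinel. A hedged usage sketch of that lifecycle follows; it assumes the webui package layout with shared.opts already initialized, and the image paths are placeholders.

from PIL import Image
from modules import deepbooru  # assumes the webui package layout and configured shared.opts

def tag_images(paths, threshold=0.5):
    # start the worker once, reuse it for every image, then release the memory
    captions = {}
    deepbooru.create_deepbooru_process(threshold, deepbooru.create_deepbooru_opts())
    try:
        for p in paths:
            captions[p] = deepbooru.get_tags_from_process(Image.open(p))
    finally:
        deepbooru.release_process()
    return captions

# tag_images(["example_1.png", "example_2.png"])  # placeholder paths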
51 | """ 52 | from modules import shared # prevents circular reference 53 | context = multiprocessing.get_context("spawn") 54 | shared.deepbooru_process_manager = context.Manager() 55 | shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() 56 | shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() 57 | shared.deepbooru_process_return["value"] = -1 58 | shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) 59 | shared.deepbooru_process.start() 60 | 61 | 62 | def get_tags_from_process(image): 63 | from modules import shared 64 | 65 | shared.deepbooru_process_return["value"] = -1 66 | shared.deepbooru_process_queue.put(image) 67 | while shared.deepbooru_process_return["value"] == -1: 68 | time.sleep(0.2) 69 | caption = shared.deepbooru_process_return["value"] 70 | shared.deepbooru_process_return["value"] = -1 71 | 72 | return caption 73 | 74 | 75 | def release_process(): 76 | """ 77 | Stops the deepbooru process to return used memory 78 | """ 79 | from modules import shared # prevents circular reference 80 | shared.deepbooru_process_queue.put("QUIT") 81 | shared.deepbooru_process.join() 82 | shared.deepbooru_process_queue = None 83 | shared.deepbooru_process = None 84 | shared.deepbooru_process_return = None 85 | shared.deepbooru_process_manager = None 86 | 87 | def get_deepbooru_tags_model(): 88 | import deepdanbooru as dd 89 | import tensorflow as tf 90 | import numpy as np 91 | this_folder = os.path.dirname(__file__) 92 | model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) 93 | if not os.path.exists(os.path.join(model_path, 'project.json')): 94 | # there is no point importing these every time 95 | import zipfile 96 | from basicsr.utils.download_util import load_file_from_url 97 | load_file_from_url( 98 | r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", 99 | model_path) 100 | with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: 101 | zip_ref.extractall(model_path) 102 | os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) 103 | 104 | tags = dd.project.load_tags_from_project(model_path) 105 | model = dd.project.load_model_from_project( 106 | model_path, compile_model=False 107 | ) 108 | return model, tags 109 | 110 | 111 | def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): 112 | import deepdanbooru as dd 113 | import tensorflow as tf 114 | import numpy as np 115 | 116 | alpha_sort = deepbooru_opts['alpha_sort'] 117 | use_spaces = deepbooru_opts['use_spaces'] 118 | use_escape = deepbooru_opts['use_escape'] 119 | include_ranks = deepbooru_opts['include_ranks'] 120 | 121 | width = model.input_shape[2] 122 | height = model.input_shape[1] 123 | image = np.array(pil_image) 124 | image = tf.image.resize( 125 | image, 126 | size=(height, width), 127 | method=tf.image.ResizeMethod.AREA, 128 | preserve_aspect_ratio=True, 129 | ) 130 | image = image.numpy() # EagerTensor to np.array 131 | image = dd.image.transform_and_pad_image(image, width, height) 132 | image = image / 255.0 133 | image_shape = image.shape 134 | image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) 135 | 136 | y = model.predict(image)[0] 137 | 138 | result_dict = {} 139 | 140 | for i, tag in enumerate(tags): 141 | result_dict[tag] = y[i] 142 | 143 | 
unsorted_tags_in_theshold = [] 144 | result_tags_print = [] 145 | for tag in tags: 146 | if result_dict[tag] >= threshold: 147 | if tag.startswith("rating:"): 148 | continue 149 | unsorted_tags_in_theshold.append((result_dict[tag], tag)) 150 | result_tags_print.append(f'{result_dict[tag]} {tag}') 151 | 152 | # sort tags 153 | result_tags_out = [] 154 | sort_ndx = 0 155 | if alpha_sort: 156 | sort_ndx = 1 157 | 158 | # sort by reverse by likelihood and normal for alpha, and format tag text as requested 159 | unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) 160 | for weight, tag in unsorted_tags_in_theshold: 161 | tag_outformat = tag 162 | if use_spaces: 163 | tag_outformat = tag_outformat.replace('_', ' ') 164 | if use_escape: 165 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) 166 | if include_ranks: 167 | tag_outformat = f"({tag_outformat}:{weight:.3f})" 168 | 169 | result_tags_out.append(tag_outformat) 170 | 171 | print('\n'.join(sorted(result_tags_print, reverse=True))) 172 | 173 | return ', '.join(result_tags_out) 174 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | import importlib 5 | import signal 6 | import threading 7 | from fastapi import FastAPI 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from fastapi.middleware.gzip import GZipMiddleware 10 | 11 | from modules.paths import script_path 12 | 13 | from modules import devices, sd_samplers, upscaler, extensions, localization 14 | import modules.codeformer_model as codeformer 15 | import modules.extras 16 | import modules.face_restoration 17 | import modules.gfpgan_model as gfpgan 18 | import modules.img2img 19 | 20 | import modules.lowvram 21 | import modules.paths 22 | import modules.scripts 23 | import modules.sd_hijack 24 | import modules.sd_models 25 | import modules.sd_vae 26 | import modules.shared as shared 27 | import modules.txt2img 28 | import modules.script_callbacks 29 | 30 | import modules.ui 31 | from modules import modelloader 32 | from modules.shared import cmd_opts 33 | import modules.hypernetworks.hypernetwork 34 | 35 | queue_lock = threading.Lock() 36 | server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name 37 | 38 | def wrap_queued_call(func): 39 | def f(*args, **kwargs): 40 | with queue_lock: 41 | res = func(*args, **kwargs) 42 | 43 | return res 44 | 45 | return f 46 | 47 | 48 | def wrap_gradio_gpu_call(func, extra_outputs=None): 49 | def f(*args, **kwargs): 50 | 51 | shared.state.begin() 52 | 53 | with queue_lock: 54 | res = func(*args, **kwargs) 55 | 56 | shared.state.end() 57 | 58 | return res 59 | 60 | return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) 61 | 62 | 63 | def initialize(): 64 | extensions.list_extensions() 65 | localization.list_localizations(cmd_opts.localizations_dir) 66 | 67 | if cmd_opts.ui_debug_mode: 68 | shared.sd_upscalers = upscaler.UpscalerLanczos().scalers 69 | modules.scripts.load_scripts() 70 | return 71 | 72 | modelloader.cleanup_models() 73 | modules.sd_models.setup_model() 74 | codeformer.setup_model(cmd_opts.codeformer_models_path) 75 | gfpgan.setup_model(cmd_opts.gfpgan_models_path) 76 | shared.face_restorers.append(modules.face_restoration.FaceRestoration()) 77 | modelloader.load_upscalers() 78 | 79 | modules.scripts.load_scripts() 80 | 81 | modules.sd_vae.refresh_vae_list() 82 | 
modules.sd_models.load_model() 83 | shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) 84 | shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) 85 | shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) 86 | shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) 87 | 88 | if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: 89 | 90 | try: 91 | if not os.path.exists(cmd_opts.tls_keyfile): 92 | print("Invalid path to TLS keyfile given") 93 | if not os.path.exists(cmd_opts.tls_certfile): 94 | print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") 95 | except TypeError: 96 | cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None 97 | print("TLS setup invalid, running webui without TLS") 98 | else: 99 | print("Running with TLS") 100 | 101 | # make the program just exit at ctrl+c without waiting for anything 102 | def sigint_handler(sig, frame): 103 | print(f'Interrupted with signal {sig} in {frame}') 104 | os._exit(0) 105 | 106 | signal.signal(signal.SIGINT, sigint_handler) 107 | 108 | 109 | def setup_cors(app): 110 | if cmd_opts.cors_allow_origins: 111 | app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) 112 | 113 | 114 | def create_api(app): 115 | from modules.api.api import Api 116 | api = Api(app, queue_lock) 117 | return api 118 | 119 | 120 | def wait_on_server(demo=None): 121 | while 1: 122 | time.sleep(0.5) 123 | if shared.state.need_restart: 124 | shared.state.need_restart = False 125 | time.sleep(0.5) 126 | demo.close() 127 | time.sleep(0.5) 128 | break 129 | 130 | 131 | def api_only(): 132 | initialize() 133 | 134 | app = FastAPI() 135 | setup_cors(app) 136 | app.add_middleware(GZipMiddleware, minimum_size=1000) 137 | api = create_api(app) 138 | 139 | modules.script_callbacks.app_started_callback(None, app) 140 | 141 | api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) 142 | 143 | 144 | def webui(): 145 | launch_api = cmd_opts.api 146 | initialize() 147 | 148 | while 1: 149 | demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) 150 | 151 | app, local_url, share_url = demo.launch( 152 | share=cmd_opts.share, 153 | server_name=server_name, 154 | server_port=cmd_opts.port, 155 | ssl_keyfile=cmd_opts.tls_keyfile, 156 | ssl_certfile=cmd_opts.tls_certfile, 157 | debug=cmd_opts.gradio_debug, 158 | auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, 159 | inbrowser=cmd_opts.autolaunch, 160 | prevent_thread_lock=True 161 | ) 162 | # after initial launch, disable --autolaunch for subsequent restarts 163 | cmd_opts.autolaunch = False 164 | 165 | # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for 166 | # an attacker to trick the user into opening a malicious HTML page, which makes a request to the 167 | # running web ui and do whatever the attcker wants, including installing an extension and 168 | # runnnig its code. We disable this here. Suggested by RyotaK. 
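The auth= argument in demo.launch() above is built from --gradio-auth with a dense one-liner; the worked example below shows what it produces. The credentials are placeholders.

gradio_auth = '"alice:secret1,bob:secret2"'  # flag value, possibly still wrapped in quotes
creds = [tuple(cred.split(':')) for cred in gradio_auth.strip('"').split(',')]
print(creds)  # [('alice', 'secret1'), ('bob', 'secret2')]
# note: a password containing ':' would be split again and yield a tuple gradio cannot use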
169 | app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] 170 | 171 | setup_cors(app) 172 | 173 | app.add_middleware(GZipMiddleware, minimum_size=1000) 174 | 175 | if launch_api: 176 | create_api(app) 177 | 178 | modules.script_callbacks.app_started_callback(demo, app) 179 | 180 | wait_on_server(demo) 181 | 182 | sd_samplers.set_samplers() 183 | 184 | print('Reloading extensions') 185 | extensions.list_extensions() 186 | 187 | localization.list_localizations(cmd_opts.localizations_dir) 188 | 189 | print('Reloading custom scripts') 190 | modules.scripts.reload_scripts() 191 | print('Reloading modules: modules.ui') 192 | importlib.reload(modules.ui) 193 | print('Refreshing Model List') 194 | modules.sd_models.list_models() 195 | print('Restarting Gradio') 196 | 197 | 198 | if __name__ == "__main__": 199 | if cmd_opts.nowebui: 200 | api_only() 201 | else: 202 | webui() 203 | --------------------------------------------------------------------------------
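webui.py serializes GPU-bound work by wrapping callables with wrap_queued_call and wrap_gradio_gpu_call, both of which acquire the same queue_lock. The standalone sketch below shows the effect in isolation: concurrent callers run one at a time. The sleep stands in for real generation work; everything else mirrors the shape of wrap_queued_call above.

import threading
import time

queue_lock = threading.Lock()

def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:          # every wrapped call takes the shared lock
            return func(*args, **kwargs)
    return f

def fake_gpu_job(i):
    time.sleep(0.1)               # stands in for a generation request holding the GPU
    print(f"job {i} done")

wrapped = wrap_queued_call(fake_gpu_job)
threads = [threading.Thread(target=wrapped, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()                      # jobs complete strictly one after another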