├── test ├── __init__.py ├── basic_features │ ├── __init__.py │ ├── extras_test.py │ ├── utils_test.py │ ├── img2img_test.py │ └── txt2img_test.py ├── test_files │ ├── empty.pt │ ├── mask_basic.png │ └── img2img_basic.png └── server_poll.py ├── models ├── VAE │ └── Put VAE here.txt ├── Stable-diffusion │ └── Put Stable Diffusion checkpoints here.txt ├── deepbooru │ └── Put your deepbooru release project folder here.txt └── VAE-approx │ └── model.pt ├── extensions └── put extensions here.txt ├── localizations └── Put localization files here.txt ├── textual_inversion_templates ├── none.txt ├── style.txt ├── subject.txt ├── hypernetwork.txt ├── style_filewords.txt └── subject_filewords.txt ├── embeddings └── Place Textual Inversion embeddings here.txt ├── screenshot.png ├── html ├── card-no-preview.png ├── extra-networks-no-cards.html ├── extra-networks-card.html ├── footer.html └── image-update.svg ├── webui-user.bat ├── .pylintrc ├── modules ├── textual_inversion │ ├── test_embedding.png │ ├── logging.py │ ├── ui.py │ └── learn_schedule.py ├── import_hook.py ├── sd_hijack_ip2p.py ├── face_restoration.py ├── shared_items.py ├── timer.py ├── script_loading.py ├── ngrok.py ├── localization.py ├── extra_networks_hypernet.py ├── ui_extra_networks_hypernets.py ├── errors.py ├── ui_extra_networks_textual_inversion.py ├── sd_hijack_utils.py ├── sd_hijack_checkpoint.py ├── sd_hijack_open_clip.py ├── ui_extra_networks_checkpoints.py ├── sd_samplers.py ├── sd_hijack_xlmr.py ├── scripts_auto_postprocessing.py ├── hypernetworks │ └── ui.py ├── ui_components.py ├── sd_vae_approx.py ├── hashes.py ├── sd_samplers_common.py ├── paths.py ├── ui_tempdir.py ├── memmon.py ├── ui_postprocessing.py ├── txt2img.py ├── deepbooru.py ├── styles.py ├── mac_specific.py ├── masking.py ├── sd_hijack_clip_old.py ├── extensions.py ├── call_queue.py ├── progress.py ├── sd_hijack_unet.py ├── gfpgan_model.py ├── postprocessing.py ├── sd_models_config.py ├── lowvram.py └── upscaler.py ├── environment-wsl2.yaml ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.yml │ └── bug_report.yml ├── workflows │ ├── run_tests.yaml │ └── on_pull_request.yaml └── pull_request_template.md ├── extensions-builtin ├── Lora │ ├── preload.py │ ├── extra_networks_lora.py │ ├── ui_extra_networks_lora.py │ └── scripts │ │ └── lora_script.py ├── LDSR │ ├── preload.py │ └── scripts │ │ └── ldsr_model.py ├── ScuNET │ ├── preload.py │ └── scripts │ │ └── scunet_model.py ├── SwinIR │ └── preload.py └── prompt-bracket-checker │ └── javascript │ └── prompt-bracket-checker.js ├── javascript ├── textualInversion.js ├── imageParams.js ├── hires_fix.js ├── generationParams.js ├── extensions.js ├── imageMaskFix.js ├── notification.js ├── dragdrop.js ├── edit-attention.js ├── extraNetworks.js └── aspectRatioOverlay.js ├── requirements.txt ├── .gitignore ├── requirements_versions.txt ├── CODEOWNERS ├── webui-macos-env.sh ├── scripts ├── postprocessing_gfpgan.py ├── custom_code.py ├── postprocessing_codeformer.py ├── loopback.py └── sd_upscale.py ├── webui-user.sh ├── configs ├── v1-inference.yaml ├── alt-diffusion-inference.yaml ├── v1-inpainting-inference.yaml └── instruct-pix2pix.yaml ├── webui.bat └── script.js /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/VAE/Put VAE here.txt: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /test/basic_features/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /extensions/put extensions here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /localizations/Put localization files here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /textual_inversion_templates/none.txt: -------------------------------------------------------------------------------- 1 | picture 2 | -------------------------------------------------------------------------------- /embeddings/Place Textual Inversion embeddings here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/deepbooru/Put your deepbooru release project folder here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/screenshot.png -------------------------------------------------------------------------------- /html/card-no-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/html/card-no-preview.png -------------------------------------------------------------------------------- /models/VAE-approx/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/models/VAE-approx/model.pt -------------------------------------------------------------------------------- /test/test_files/empty.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/empty.pt -------------------------------------------------------------------------------- /test/test_files/mask_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/mask_basic.png -------------------------------------------------------------------------------- /test/test_files/img2img_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/img2img_basic.png -------------------------------------------------------------------------------- /webui-user.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON= 4 | set GIT= 5 | set VENV_DIR= 6 | set COMMANDLINE_ARGS= 7 | 8 | call webui.bat 9 | 
-------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html 2 | [MESSAGES CONTROL] 3 | disable=C,R,W,E,I 4 | -------------------------------------------------------------------------------- /modules/textual_inversion/test_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/modules/textual_inversion/test_embedding.png -------------------------------------------------------------------------------- /html/extra-networks-no-cards.html: -------------------------------------------------------------------------------- 1 |
2 | Nothing here. Add some content to the following directories:
8 | 9 | -------------------------------------------------------------------------------- /environment-wsl2.yaml: -------------------------------------------------------------------------------- 1 | name: automatic 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - pip=22.2.2 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 -------------------------------------------------------------------------------- /modules/import_hook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it 4 | if "--xformers" not in "".join(sys.argv): 5 | sys.modules["xformers"] = None 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: WebUI Community Support 4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/LDSR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/ScuNET/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/SwinIR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR')) 7 | -------------------------------------------------------------------------------- /html/extra-networks-card.html: -------------------------------------------------------------------------------- 1 |
9 | {name}
12 | 13 | -------------------------------------------------------------------------------- /html/footer.html: -------------------------------------------------------------------------------- 1 |
2 | API • Github • Gradio • Reload UI 12 | {versions}
14 | -------------------------------------------------------------------------------- /modules/sd_hijack_ip2p.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os.path 3 | import sys 4 | import gc 5 | import time 6 | 7 | def should_hijack_ip2p(checkpoint_info): 8 | from modules import sd_models_config 9 | 10 | ckpt_basename = os.path.basename(checkpoint_info.filename).lower() 11 | cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() 12 | 13 | return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename 14 | -------------------------------------------------------------------------------- /javascript/textualInversion.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | function start_training_textual_inversion(){ 5 | gradioApp().querySelector('#ti_error').innerHTML='' 6 | 7 | var id = randomId() 8 | requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ 9 | gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo 10 | }) 11 | 12 | var res = args_to_array(arguments) 13 | 14 | res[0] = id 15 | 16 | return res 17 | } 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | blendmodes 2 | accelerate 3 | basicsr 4 | fonts 5 | font-roboto 6 | gfpgan 7 | gradio==3.16.2 8 | invisible-watermark 9 | numpy 10 | omegaconf 11 | opencv-contrib-python 12 | requests 13 | piexif 14 | Pillow 15 | pytorch_lightning==1.7.7 16 | realesrgan 17 | scikit-image>=0.19 18 | timm==0.4.12 19 | transformers==4.25.1 20 | torch 21 | einops 22 | jsonmerge 23 | clean-fid 24 | resize-right 25 | torchdiffeq 26 | kornia 27 | lark 28 | inflection 29 | GitPython 30 | torchsde 31 | safetensors 32 | psutil 33 | -------------------------------------------------------------------------------- /modules/face_restoration.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | 3 | 4 | class FaceRestoration: 5 | def name(self): 6 | return "None" 7 | 8 | def restore(self, np_image): 9 | return np_image 10 | 11 | 12 | def restore_faces(np_image): 13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] 14 | if len(face_restorers) == 0: 15 | return np_image 16 | 17 | face_restorer = face_restorers[0] 18 | 19 | return face_restorer.restore(np_image) 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.safetensors 4 | *.pth 5 | /ESRGAN/* 6 | /SwinIR/* 7 | /repositories 8 | /venv 9 | /tmp 10 | /model.ckpt 11 | /models/**/* 12 | /GFPGANv1.3.pth 13 | /gfpgan/weights/*.pth 14 | /ui-config.json 15 | /outputs 16 | /config.json 17 | /log 18 | /webui.settings.bat 19 | /embeddings 20 | /styles.csv 21 | /params.txt 22 | /styles.csv.bak 23 | /webui-user.bat 24 | /webui-user.sh 25 | /interrogate 26 | /user.css 27 | /.idea 28 | notification.mp3 29 | /SwinIR 30 | /textual_inversion 31 | .vscode 32 | /extensions 33 | /test/stdout.txt 34 | /test/stderr.txt 35 | /cache.json 36 | 
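The FaceRestoration class in modules/face_restoration.py above is the plug-in point for face restorers: restore_faces() picks whichever registered restorer's name() matches opts.face_restoration_model. A minimal sketch of a custom restorer, assuming it is appended to shared.face_restorers at startup the way the built-in GFPGAN and CodeFormer restorers are (the class name and the registration step below are illustrative, not part of the repository):

from modules import face_restoration, shared


class FaceRestorationIdentity(face_restoration.FaceRestoration):
    """Hypothetical restorer used only to illustrate the plug-in pattern."""

    def name(self):
        # Compared against shared.opts.face_restoration_model in restore_faces().
        return "Identity"

    def restore(self, np_image):
        # A real implementation would run a face-restoration model on the
        # numpy image here; this sketch returns it unchanged.
        return np_image


# Assumed registration step, mirroring how the built-in restorers register themselves.
shared.face_restorers.append(FaceRestorationIdentity())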
-------------------------------------------------------------------------------- /modules/shared_items.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def realesrgan_models_names(): 4 | import modules.realesrgan_model 5 | return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] 6 | 7 | 8 | def postprocessing_scripts(): 9 | import modules.scripts 10 | 11 | return modules.scripts.scripts_postproc.scripts 12 | 13 | 14 | def sd_vae_items(): 15 | import modules.sd_vae 16 | 17 | return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) 18 | 19 | 20 | def refresh_vae_list(): 21 | import modules.sd_vae 22 | 23 | modules.sd_vae.refresh_vae_list() 24 | -------------------------------------------------------------------------------- /requirements_versions.txt: -------------------------------------------------------------------------------- 1 | blendmodes==2022 2 | transformers==4.25.1 3 | accelerate==0.12.0 4 | basicsr==1.4.2 5 | gfpgan==1.3.8 6 | gradio==3.16.2 7 | numpy==1.23.3 8 | Pillow==9.4.0 9 | realesrgan==0.3.0 10 | torch 11 | omegaconf==2.2.3 12 | pytorch_lightning==1.7.6 13 | scikit-image==0.19.2 14 | fonts 15 | font-roboto 16 | timm==0.6.7 17 | piexif==1.1.3 18 | einops==0.4.1 19 | jsonmerge==1.8.0 20 | clean-fid==0.1.29 21 | resize-right==0.0.2 22 | torchdiffeq==0.2.3 23 | kornia==0.6.7 24 | lark==1.1.2 25 | inflection==0.5.1 26 | GitPython==3.1.27 27 | torchsde==0.2.5 28 | safetensors==0.2.7 29 | httpcore<=0.15 30 | fastapi==0.90.1 31 | -------------------------------------------------------------------------------- /textual_inversion_templates/style.txt: -------------------------------------------------------------------------------- 1 | a painting, art by [name] 2 | a rendering, art by [name] 3 | a cropped painting, art by [name] 4 | the painting, art by [name] 5 | a clean painting, art by [name] 6 | a dirty painting, art by [name] 7 | a dark painting, art by [name] 8 | a picture, art by [name] 9 | a cool painting, art by [name] 10 | a close-up painting, art by [name] 11 | a bright painting, art by [name] 12 | a cropped painting, art by [name] 13 | a good painting, art by [name] 14 | a close-up painting, art by [name] 15 | a rendition, art by [name] 16 | a nice painting, art by [name] 17 | a small painting, art by [name] 18 | a weird painting, art by [name] 19 | a large painting, art by [name] 20 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @AUTOMATIC1111 2 | 3 | # if you were managing a localization and were removed from this file, this is because 4 | # the intended way to do localizations now is via extensions. See: 5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions 6 | # Make a repo with your localization and since you are still listed as a collaborator 7 | # you can add it to the wiki page yourself. This change is because some people complained 8 | # the git commit log is cluttered with things unrelated to almost everyone and 9 | # because I believe this is the best overall for the project to handle localizations almost 10 | # entirely without my oversight. 
11 | 12 | 13 | -------------------------------------------------------------------------------- /javascript/imageParams.js: -------------------------------------------------------------------------------- 1 | window.onload = (function(){ 2 | window.addEventListener('drop', e => { 3 | const target = e.composedPath()[0]; 4 | const idx = selected_gallery_index(); 5 | if (target.placeholder.indexOf("Prompt") == -1) return; 6 | 7 | let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; 8 | 9 | e.stopPropagation(); 10 | e.preventDefault(); 11 | const imgParent = gradioApp().getElementById(prompt_target); 12 | const files = e.dataTransfer.files; 13 | const fileInput = imgParent.querySelector('input[type="file"]'); 14 | if ( fileInput ) { 15 | fileInput.files = files; 16 | fileInput.dispatchEvent(new Event('change')); 17 | } 18 | }); 19 | }); 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name] 2 | a rendering of a [name] 3 | a cropped photo of the [name] 4 | the photo of a [name] 5 | a photo of a clean [name] 6 | a photo of a dirty [name] 7 | a dark photo of the [name] 8 | a photo of my [name] 9 | a photo of the cool [name] 10 | a close-up photo of a [name] 11 | a bright photo of the [name] 12 | a cropped photo of a [name] 13 | a photo of the [name] 14 | a good photo of the [name] 15 | a photo of one [name] 16 | a close-up photo of the [name] 17 | a rendition of the [name] 18 | a photo of the clean [name] 19 | a rendition of a [name] 20 | a photo of a nice [name] 21 | a good photo of a [name] 22 | a photo of the nice [name] 23 | a photo of the small [name] 24 | a photo of the weird [name] 25 | a photo of the large [name] 26 | a photo of a cool [name] 27 | a photo of a small [name] 28 | -------------------------------------------------------------------------------- /test/server_poll.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | import time 4 | 5 | 6 | def run_tests(proc, test_dir): 7 | timeout_threshold = 240 8 | start_time = time.time() 9 | while time.time()-start_time < timeout_threshold: 10 | try: 11 | requests.head("http://localhost:7860/") 12 | break 13 | except requests.exceptions.ConnectionError: 14 | if proc.poll() is not None: 15 | break 16 | if proc.poll() is None: 17 | if test_dir is None: 18 | test_dir = "test" 19 | suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") 20 | result = unittest.TextTestRunner(verbosity=2).run(suite) 21 | return len(result.failures) + len(result.errors) 22 | else: 23 | print("Launch unsuccessful") 24 | return 1 25 | -------------------------------------------------------------------------------- /webui-macos-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #################################################################### 3 | # macOS defaults # 4 | # Please modify webui-user.sh to change these instead of this file # 5 | #################################################################### 6 | 7 | if [[ -x "$(command -v python3.10)" ]] 8 | then 9 | python_cmd="python3.10" 10 | fi 11 | 12 | export install_dir="$HOME" 13 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" 14 | export TORCH_COMMAND="pip 
install torch==1.12.1 torchvision==0.13.1" 15 | export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" 16 | export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" 17 | export PYTORCH_ENABLE_MPS_FALLBACK=1 18 | 19 | #################################################################### 20 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run basic features tests on CPU with empty SD model 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout Code 12 | uses: actions/checkout@v3 13 | - name: Set up Python 3.10 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.10.6 17 | cache: pip 18 | cache-dependency-path: | 19 | **/requirements*txt 20 | - name: Run tests 21 | run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test 22 | - name: Upload main app stdout-stderr 23 | uses: actions/upload-artifact@v3 24 | if: always() 25 | with: 26 | name: stdout-stderr 27 | path: | 28 | test/stdout.txt 29 | test/stderr.txt 30 | -------------------------------------------------------------------------------- /textual_inversion_templates/hypernetwork.txt: -------------------------------------------------------------------------------- 1 | a photo of a [filewords] 2 | a rendering of a [filewords] 3 | a cropped photo of the [filewords] 4 | the photo of a [filewords] 5 | a photo of a clean [filewords] 6 | a photo of a dirty [filewords] 7 | a dark photo of the [filewords] 8 | a photo of my [filewords] 9 | a photo of the cool [filewords] 10 | a close-up photo of a [filewords] 11 | a bright photo of the [filewords] 12 | a cropped photo of a [filewords] 13 | a photo of the [filewords] 14 | a good photo of the [filewords] 15 | a photo of one [filewords] 16 | a close-up photo of the [filewords] 17 | a rendition of the [filewords] 18 | a photo of the clean [filewords] 19 | a rendition of a [filewords] 20 | a photo of a nice [filewords] 21 | a good photo of a [filewords] 22 | a photo of the nice [filewords] 23 | a photo of the small [filewords] 24 | a photo of the weird [filewords] 25 | a photo of the large [filewords] 26 | a photo of a cool [filewords] 27 | a photo of a small [filewords] 28 | -------------------------------------------------------------------------------- /textual_inversion_templates/style_filewords.txt: -------------------------------------------------------------------------------- 1 | a painting of [filewords], art by [name] 2 | a rendering of [filewords], art by [name] 3 | a cropped painting of [filewords], art by [name] 4 | the painting of [filewords], art by [name] 5 | a clean painting of [filewords], art by [name] 6 | a dirty painting of [filewords], art by [name] 7 | a dark painting of [filewords], art by [name] 8 | a picture of [filewords], art by [name] 9 | a cool painting of [filewords], art by [name] 10 | a close-up painting of [filewords], art by [name] 11 | a bright painting of [filewords], art by [name] 12 | a cropped painting of [filewords], art by [name] 13 | a good painting of [filewords], art by [name] 14 | a close-up painting of [filewords], art by [name] 15 | a rendition of [filewords], art by [name] 16 | a nice painting of [filewords], art by [name] 17 | a small painting of [filewords], art by [name] 18 | a weird painting of [filewords], art by [name] 19 | a 
large painting of [filewords], art by [name] 20 | -------------------------------------------------------------------------------- /html/image-update.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /modules/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.start = time.time() 7 | self.records = {} 8 | self.total = 0 9 | 10 | def elapsed(self): 11 | end = time.time() 12 | res = end - self.start 13 | self.start = end 14 | return res 15 | 16 | def record(self, category, extra_time=0): 17 | e = self.elapsed() 18 | if category not in self.records: 19 | self.records[category] = 0 20 | 21 | self.records[category] += e + extra_time 22 | self.total += e + extra_time 23 | 24 | def summary(self): 25 | res = f"{self.total:.1f}s" 26 | 27 | additions = [x for x in self.records.items() if x[1] >= 0.1] 28 | if not additions: 29 | return res 30 | 31 | res += " (" 32 | res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) 33 | res += ")" 34 | 35 | return res 36 | -------------------------------------------------------------------------------- /javascript/hires_fix.js: -------------------------------------------------------------------------------- 1 | 2 | function setInactive(elem, inactive){ 3 | if(inactive){ 4 | elem.classList.add('inactive') 5 | } else{ 6 | elem.classList.remove('inactive') 7 | } 8 | } 9 | 10 | function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ 11 | hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') 12 | hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') 13 | hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') 14 | 15 | gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? 
"none" : "" 16 | 17 | setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) 18 | setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) 19 | setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) 20 | 21 | return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] 22 | } 23 | -------------------------------------------------------------------------------- /modules/script_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | import importlib.util 5 | from types import ModuleType 6 | 7 | 8 | def load_module(path): 9 | module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) 10 | module = importlib.util.module_from_spec(module_spec) 11 | module_spec.loader.exec_module(module) 12 | 13 | return module 14 | 15 | 16 | def preload_extensions(extensions_dir, parser): 17 | if not os.path.isdir(extensions_dir): 18 | return 19 | 20 | for dirname in sorted(os.listdir(extensions_dir)): 21 | preload_script = os.path.join(extensions_dir, dirname, "preload.py") 22 | if not os.path.isfile(preload_script): 23 | continue 24 | 25 | try: 26 | module = load_module(preload_script) 27 | if hasattr(module, 'preload'): 28 | module.preload(parser) 29 | 30 | except Exception: 31 | print(f"Error running preload() for {preload_script}", file=sys.stderr) 32 | print(traceback.format_exc(), file=sys.stderr) 33 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/extra_networks_lora.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared 2 | import lora 3 | 4 | class ExtraNetworkLora(extra_networks.ExtraNetwork): 5 | def __init__(self): 6 | super().__init__('lora') 7 | 8 | def activate(self, p, params_list): 9 | additional = shared.opts.sd_lora 10 | 11 | if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: 12 | p.all_prompts = [x + f"" for x in p.all_prompts] 13 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 14 | 15 | names = [] 16 | multipliers = [] 17 | for params in params_list: 18 | assert len(params.items) > 0 19 | 20 | names.append(params.items[0]) 21 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) 22 | 23 | lora.load_loras(names, multipliers) 24 | 25 | def deactivate(self, p): 26 | pass 27 | -------------------------------------------------------------------------------- /modules/ngrok.py: -------------------------------------------------------------------------------- 1 | from pyngrok import ngrok, conf, exception 2 | 3 | def connect(token, port, region): 4 | account = None 5 | if token is None: 6 | token = 'None' 7 | else: 8 | if ':' in token: 9 | # token = authtoken:username:password 10 | account = token.split(':')[1] + ':' + token.split(':')[-1] 11 | token = token.split(':')[0] 12 | 13 | config = conf.PyngrokConfig( 14 | auth_token=token, region=region 15 | ) 16 | try: 17 | if account is None: 18 | public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url 19 | else: 20 | public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url 21 | except exception.PyngrokNgrokError: 22 | print(f'Invalid ngrok authtoken, ngrok connection 
aborted.\n' 23 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') 24 | else: 25 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' 26 | 'You can use this link after the launch is complete.') 27 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import traceback 5 | 6 | 7 | localizations = {} 8 | 9 | 10 | def list_localizations(dirname): 11 | localizations.clear() 12 | 13 | for file in os.listdir(dirname): 14 | fn, ext = os.path.splitext(file) 15 | if ext.lower() != ".json": 16 | continue 17 | 18 | localizations[fn] = os.path.join(dirname, file) 19 | 20 | from modules import scripts 21 | for file in scripts.list_scripts("localizations", ".json"): 22 | fn, ext = os.path.splitext(file.filename) 23 | localizations[fn] = file.path 24 | 25 | 26 | def localization_js(current_localization_name): 27 | fn = localizations.get(current_localization_name, None) 28 | data = {} 29 | if fn is not None: 30 | try: 31 | with open(fn, "r", encoding="utf8") as file: 32 | data = json.load(file) 33 | except Exception: 34 | print(f"Error loading localization from {fn}:", file=sys.stderr) 35 | print(traceback.format_exc(), file=sys.stderr) 36 | 37 | return f"var localization = {json.dumps(data)}\n" 38 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject_filewords.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name], [filewords] 2 | a rendering of a [name], [filewords] 3 | a cropped photo of the [name], [filewords] 4 | the photo of a [name], [filewords] 5 | a photo of a clean [name], [filewords] 6 | a photo of a dirty [name], [filewords] 7 | a dark photo of the [name], [filewords] 8 | a photo of my [name], [filewords] 9 | a photo of the cool [name], [filewords] 10 | a close-up photo of a [name], [filewords] 11 | a bright photo of the [name], [filewords] 12 | a cropped photo of a [name], [filewords] 13 | a photo of the [name], [filewords] 14 | a good photo of the [name], [filewords] 15 | a photo of one [name], [filewords] 16 | a close-up photo of the [name], [filewords] 17 | a rendition of the [name], [filewords] 18 | a photo of the clean [name], [filewords] 19 | a rendition of a [name], [filewords] 20 | a photo of a nice [name], [filewords] 21 | a good photo of a [name], [filewords] 22 | a photo of the nice [name], [filewords] 23 | a photo of the small [name], [filewords] 24 | a photo of the weird [name], [filewords] 25 | a photo of the large [name], [filewords] 26 | a photo of a cool [name], [filewords] 27 | a photo of a small [name], [filewords] 28 | -------------------------------------------------------------------------------- /scripts/postprocessing_gfpgan.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, gfpgan_model 5 | import gradio as gr 6 | 7 | from modules.ui_components import FormRow 8 | 9 | 10 | class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): 11 | name = "GFPGAN" 12 | order = 2000 13 | 14 | def ui(self): 15 | with FormRow(): 16 | gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, 
elem_id="extras_gfpgan_visibility") 17 | 18 | return { 19 | "gfpgan_visibility": gfpgan_visibility, 20 | } 21 | 22 | def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): 23 | if gfpgan_visibility == 0: 24 | return 25 | 26 | restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) 27 | res = Image.fromarray(restored_img) 28 | 29 | if gfpgan_visibility < 1.0: 30 | res = Image.blend(pp.image, res, gfpgan_visibility) 31 | 32 | pp.image = res 33 | pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) 34 | -------------------------------------------------------------------------------- /modules/extra_networks_hypernet.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared, extra_networks 2 | from modules.hypernetworks import hypernetwork 3 | 4 | 5 | class ExtraNetworkHypernet(extra_networks.ExtraNetwork): 6 | def __init__(self): 7 | super().__init__('hypernet') 8 | 9 | def activate(self, p, params_list): 10 | additional = shared.opts.sd_hypernetwork 11 | 12 | if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: 13 | p.all_prompts = [x + f"" for x in p.all_prompts] 14 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 15 | 16 | names = [] 17 | multipliers = [] 18 | for params in params_list: 19 | assert len(params.items) > 0 20 | 21 | names.append(params.items[0]) 22 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) 23 | 24 | hypernetwork.load_hypernetworks(names, multipliers) 25 | 26 | def deactivate(self, p): 27 | pass 28 | -------------------------------------------------------------------------------- /scripts/custom_code.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | 4 | from modules.processing import Processed 5 | from modules.shared import opts, cmd_opts, state 6 | 7 | class Script(scripts.Script): 8 | 9 | def title(self): 10 | return "Custom code" 11 | 12 | def show(self, is_img2img): 13 | return cmd_opts.allow_code 14 | 15 | def ui(self, is_img2img): 16 | code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) 17 | 18 | return [code] 19 | 20 | 21 | def run(self, p, code): 22 | assert cmd_opts.allow_code, '--allow-code option must be enabled' 23 | 24 | display_result_data = [[], -1, ""] 25 | 26 | def display(imgs, s=display_result_data[1], i=display_result_data[2]): 27 | display_result_data[0] = imgs 28 | display_result_data[1] = s 29 | display_result_data[2] = i 30 | 31 | from types import ModuleType 32 | compiled = compile(code, '', 'exec') 33 | module = ModuleType("testmodule") 34 | module.__dict__.update(globals()) 35 | module.p = p 36 | module.display = display 37 | exec(compiled, module.__dict__) 38 | 39 | return Processed(p, *display_result_data) 40 | 41 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_hypernets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from modules import shared, ui_extra_networks 5 | 6 | 7 | class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): 8 | def __init__(self): 9 | super().__init__('Hypernetworks') 10 | 11 | def refresh(self): 12 | shared.reload_hypernetworks() 13 | 
14 | def list_items(self): 15 | for name, path in shared.hypernetworks.items(): 16 | path, ext = os.path.splitext(path) 17 | previews = [path + ".png", path + ".preview.png"] 18 | 19 | preview = None 20 | for file in previews: 21 | if os.path.isfile(file): 22 | preview = self.link_preview(file) 23 | break 24 | 25 | yield { 26 | "name": name, 27 | "filename": path, 28 | "preview": preview, 29 | "search_term": self.search_terms_from_path(path), 30 | "prompt": json.dumps(f""), 31 | "local_preview": path + ".png", 32 | } 33 | 34 | def allowed_directories_for_previews(self): 35 | return [shared.cmd_opts.hypernetwork_dir] 36 | 37 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/ui_extra_networks_lora.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import lora 4 | 5 | from modules import shared, ui_extra_networks 6 | 7 | 8 | class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): 9 | def __init__(self): 10 | super().__init__('Lora') 11 | 12 | def refresh(self): 13 | lora.list_available_loras() 14 | 15 | def list_items(self): 16 | for name, lora_on_disk in lora.available_loras.items(): 17 | path, ext = os.path.splitext(lora_on_disk.filename) 18 | previews = [path + ".png", path + ".preview.png"] 19 | 20 | preview = None 21 | for file in previews: 22 | if os.path.isfile(file): 23 | preview = self.link_preview(file) 24 | break 25 | 26 | yield { 27 | "name": name, 28 | "filename": path, 29 | "preview": preview, 30 | "search_term": self.search_terms_from_path(lora_on_disk.filename), 31 | "prompt": json.dumps(f""), 32 | "local_preview": path + ".png", 33 | } 34 | 35 | def allowed_directories_for_previews(self): 36 | return [shared.cmd_opts.lora_dir] 37 | 38 | -------------------------------------------------------------------------------- /modules/errors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | 4 | 5 | def print_error_explanation(message): 6 | lines = message.strip().split("\n") 7 | max_len = max([len(x) for x in lines]) 8 | 9 | print('=' * max_len, file=sys.stderr) 10 | for line in lines: 11 | print(line, file=sys.stderr) 12 | print('=' * max_len, file=sys.stderr) 13 | 14 | 15 | def display(e: Exception, task): 16 | print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) 17 | print(traceback.format_exc(), file=sys.stderr) 18 | 19 | message = str(e) 20 | if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: 21 | print_error_explanation(""" 22 | The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file. 23 | See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. 
24 | """) 25 | 26 | 27 | already_displayed = {} 28 | 29 | 30 | def display_once(e: Exception, task): 31 | if task in already_displayed: 32 | return 33 | 34 | display(e, task) 35 | 36 | already_displayed[task] = 1 37 | 38 | 39 | def run(code, task): 40 | try: 41 | code() 42 | except Exception as e: 43 | display(task, e) 44 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from modules import ui_extra_networks, sd_hijack 5 | 6 | 7 | class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): 8 | def __init__(self): 9 | super().__init__('Textual Inversion') 10 | self.allow_negative_prompt = True 11 | 12 | def refresh(self): 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) 14 | 15 | def list_items(self): 16 | for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): 17 | path, ext = os.path.splitext(embedding.filename) 18 | preview_file = path + ".preview.png" 19 | 20 | preview = None 21 | if os.path.isfile(preview_file): 22 | preview = self.link_preview(preview_file) 23 | 24 | yield { 25 | "name": embedding.name, 26 | "filename": embedding.filename, 27 | "preview": preview, 28 | "search_term": self.search_terms_from_path(embedding.filename), 29 | "prompt": json.dumps(embedding.name), 30 | "local_preview": path + ".preview.png", 31 | } 32 | 33 | def allowed_directories_for_previews(self): 34 | return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) 35 | -------------------------------------------------------------------------------- /modules/sd_hijack_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | class CondFunc: 4 | def __new__(cls, orig_func, sub_func, cond_func): 5 | self = super(CondFunc, cls).__new__(cls) 6 | if isinstance(orig_func, str): 7 | func_path = orig_func.split('.') 8 | for i in range(len(func_path)-1, -1, -1): 9 | try: 10 | resolved_obj = importlib.import_module('.'.join(func_path[:i])) 11 | break 12 | except ImportError: 13 | pass 14 | for attr_name in func_path[i:-1]: 15 | resolved_obj = getattr(resolved_obj, attr_name) 16 | orig_func = getattr(resolved_obj, func_path[-1]) 17 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) 18 | self.__init__(orig_func, sub_func, cond_func) 19 | return lambda *args, **kwargs: self(*args, **kwargs) 20 | def __init__(self, orig_func, sub_func, cond_func): 21 | self.__orig_func = orig_func 22 | self.__sub_func = sub_func 23 | self.__cond_func = cond_func 24 | def __call__(self, *args, **kwargs): 25 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): 26 | return self.__sub_func(self.__orig_func, *args, **kwargs) 27 | else: 28 | return self.__orig_func(*args, **kwargs) 29 | -------------------------------------------------------------------------------- /modules/sd_hijack_checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch.utils.checkpoint import checkpoint 2 | 3 | import ldm.modules.attention 4 | import ldm.modules.diffusionmodules.openaimodel 5 | 6 | 7 | def BasicTransformerBlock_forward(self, x, context=None): 8 | return checkpoint(self._forward, x, context) 9 | 10 | 11 | def AttentionBlock_forward(self, x): 12 | return checkpoint(self._forward, x) 13 | 14 | 15 
| def ResBlock_forward(self, x, emb): 16 | return checkpoint(self._forward, x, emb) 17 | 18 | 19 | stored = [] 20 | 21 | 22 | def add(): 23 | if len(stored) != 0: 24 | return 25 | 26 | stored.extend([ 27 | ldm.modules.attention.BasicTransformerBlock.forward, 28 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, 29 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward 30 | ]) 31 | 32 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward 33 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward 34 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward 35 | 36 | 37 | def remove(): 38 | if len(stored) == 0: 39 | return 40 | 41 | ldm.modules.attention.BasicTransformerBlock.forward = stored[0] 42 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] 43 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] 44 | 45 | stored.clear() 46 | 47 | -------------------------------------------------------------------------------- /javascript/generationParams.js: -------------------------------------------------------------------------------- 1 | // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes 2 | 3 | let txt2img_gallery, img2img_gallery, modal = undefined; 4 | onUiUpdate(function(){ 5 | if (!txt2img_gallery) { 6 | txt2img_gallery = attachGalleryListeners("txt2img") 7 | } 8 | if (!img2img_gallery) { 9 | img2img_gallery = attachGalleryListeners("img2img") 10 | } 11 | if (!modal) { 12 | modal = gradioApp().getElementById('lightboxModal') 13 | modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] }); 14 | } 15 | }); 16 | 17 | let modalObserver = new MutationObserver(function(mutations) { 18 | mutations.forEach(function(mutationRecord) { 19 | let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText 20 | if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img') 21 | gradioApp().getElementById(selectedTab+"_generation_info_button").click() 22 | }); 23 | }); 24 | 25 | function attachGalleryListeners(tab_name) { 26 | gallery = gradioApp().querySelector('#'+tab_name+'_gallery') 27 | gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); 28 | gallery?.addEventListener('keydown', (e) => { 29 | if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow 30 | gradioApp().getElementById(tab_name+"_generation_info_button").click() 31 | }); 32 | return gallery; 33 | } 34 | -------------------------------------------------------------------------------- /.github/workflows/on_pull_request.yaml: -------------------------------------------------------------------------------- 1 | # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file 2 | name: Run Linting/Formatting on Pull Requests 3 | 4 | on: 5 | - push 6 | - pull_request 7 | # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs 8 | # if you want to filter out branches, delete the `- pull_request` and uncomment these lines : 9 | # pull_request: 10 | # branches: 11 | # - master 12 | # branches-ignore: 13 | # - development 14 | 15 | jobs: 16 | lint: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - 
name: Checkout Code 20 | uses: actions/checkout@v3 21 | - name: Set up Python 3.10 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: 3.10.6 25 | cache: pip 26 | cache-dependency-path: | 27 | **/requirements*txt 28 | - name: Install PyLint 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install pylint 32 | # This lets PyLint check to see if it can resolve imports 33 | - name: Install dependencies 34 | run: | 35 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" 36 | python launch.py 37 | - name: Analysing the code with pylint 38 | run: | 39 | pylint $(git ls-files '*.py') 40 | -------------------------------------------------------------------------------- /modules/sd_hijack_open_clip.py: -------------------------------------------------------------------------------- 1 | import open_clip.tokenizer 2 | import torch 3 | 4 | from modules import sd_hijack_clip, devices 5 | from modules.shared import opts 6 | 7 | tokenizer = open_clip.tokenizer._tokenizer 8 | 9 | 10 | class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): 11 | def __init__(self, wrapped, hijack): 12 | super().__init__(wrapped, hijack) 13 | 14 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] 15 | self.id_start = tokenizer.encoder[""] 16 | self.id_end = tokenizer.encoder[""] 17 | self.id_pad = 0 18 | 19 | def tokenize(self, texts): 20 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' 21 | 22 | tokenized = [tokenizer.encode(text) for text in texts] 23 | 24 | return tokenized 25 | 26 | def encode_with_transformers(self, tokens): 27 | # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers 28 | z = self.wrapped.encode_with_transformer(tokens) 29 | 30 | return z 31 | 32 | def encode_embedding_init_text(self, init_text, nvpt): 33 | ids = tokenizer.encode(init_text) 34 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int) 35 | embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) 36 | 37 | return embedded 38 | -------------------------------------------------------------------------------- /webui-user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ######################################################### 3 | # Uncomment and change the variables below to your need:# 4 | ######################################################### 5 | 6 | # Install directory without trailing slash 7 | #install_dir="/home/$(whoami)" 8 | 9 | # Name of the subdirectory 10 | #clone_dir="stable-diffusion-webui" 11 | 12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" 13 | #export COMMANDLINE_ARGS="" 14 | 15 | # python3 executable 16 | #python_cmd="python3" 17 | 18 | # git executable 19 | #export GIT="git" 20 | 21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 22 | #venv_dir="venv" 23 | 24 | # script to launch to start the app 25 | #export LAUNCH_SCRIPT="launch.py" 26 | 27 | # install command for torch 28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" 29 | 30 | # Requirements file to use for stable-diffusion-webui 31 | #export REQS_FILE="requirements_versions.txt" 32 | 33 | # Fixed git repos 34 | #export K_DIFFUSION_PACKAGE="" 35 | #export GFPGAN_PACKAGE="" 36 | 37 | # Fixed git commits 38 | #export 
STABLE_DIFFUSION_COMMIT_HASH="" 39 | #export TAMING_TRANSFORMERS_COMMIT_HASH="" 40 | #export CODEFORMER_COMMIT_HASH="" 41 | #export BLIP_COMMIT_HASH="" 42 | 43 | # Uncomment to enable accelerated launch 44 | #export ACCELERATE="True" 45 | 46 | ########################################### 47 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_checkpoints.py: -------------------------------------------------------------------------------- 1 | import html 2 | import json 3 | import os 4 | import urllib.parse 5 | 6 | from modules import shared, ui_extra_networks, sd_models 7 | 8 | 9 | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): 10 | def __init__(self): 11 | super().__init__('Checkpoints') 12 | 13 | def refresh(self): 14 | shared.refresh_checkpoints() 15 | 16 | def list_items(self): 17 | checkpoint: sd_models.CheckpointInfo 18 | for name, checkpoint in sd_models.checkpoints_list.items(): 19 | path, ext = os.path.splitext(checkpoint.filename) 20 | previews = [path + ".png", path + ".preview.png"] 21 | 22 | preview = None 23 | for file in previews: 24 | if os.path.isfile(file): 25 | preview = self.link_preview(file) 26 | break 27 | 28 | yield { 29 | "name": checkpoint.name_for_extra, 30 | "filename": path, 31 | "preview": preview, 32 | "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), 33 | "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', 34 | "local_preview": path + ".png", 35 | } 36 | 37 | def allowed_directories_for_previews(self): 38 | return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] 39 | 40 | -------------------------------------------------------------------------------- /modules/textual_inversion/logging.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | 5 | saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} 6 | saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} 7 | saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} 8 | saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet 9 | saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} 10 | 11 | 12 | def save_settings_to_file(log_directory, all_params): 13 | now = datetime.datetime.now() 14 | params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} 15 | 16 | keys = saved_params_all 17 | if all_params.get('preview_from_txt2img'): 18 | keys = keys | saved_params_previews 19 | 20 | params.update({k: v for k, v in all_params.items() if k in keys}) 21 | 22 | filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' 23 | with open(os.path.join(log_directory, filename), "w") as file: 24 | json.dump(params, file, indent=4) 25 | -------------------------------------------------------------------------------- /modules/sd_samplers.py: 
-------------------------------------------------------------------------------- 1 | from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared 2 | 3 | # imports for functions that previously were here and are used by other modules 4 | from modules.sd_samplers_common import samples_to_image_grid, sample_to_image 5 | 6 | all_samplers = [ 7 | *sd_samplers_kdiffusion.samplers_data_k_diffusion, 8 | *sd_samplers_compvis.samplers_data_compvis, 9 | ] 10 | all_samplers_map = {x.name: x for x in all_samplers} 11 | 12 | samplers = [] 13 | samplers_for_img2img = [] 14 | samplers_map = {} 15 | 16 | 17 | def create_sampler(name, model): 18 | if name is not None: 19 | config = all_samplers_map.get(name, None) 20 | else: 21 | config = all_samplers[0] 22 | 23 | assert config is not None, f'bad sampler name: {name}' 24 | 25 | sampler = config.constructor(model) 26 | sampler.config = config 27 | 28 | return sampler 29 | 30 | 31 | def set_samplers(): 32 | global samplers, samplers_for_img2img 33 | 34 | hidden = set(shared.opts.hide_samplers) 35 | hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) 36 | 37 | samplers = [x for x in all_samplers if x.name not in hidden] 38 | samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] 39 | 40 | samplers_map.clear() 41 | for sampler in all_samplers: 42 | samplers_map[sampler.name.lower()] = sampler.name 43 | for alias in sampler.aliases: 44 | samplers_map[alias.lower()] = sampler.name 45 | 46 | 47 | set_samplers() 48 | -------------------------------------------------------------------------------- /javascript/extensions.js: -------------------------------------------------------------------------------- 1 | 2 | function extensions_apply(_, _){ 3 | var disable = [] 4 | var update = [] 5 | 6 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ 7 | if(x.name.startsWith("enable_") && ! x.checked) 8 | disable.push(x.name.substr(7)) 9 | 10 | if(x.name.startsWith("update_") && x.checked) 11 | update.push(x.name.substr(7)) 12 | }) 13 | 14 | restart_reload() 15 | 16 | return [JSON.stringify(disable), JSON.stringify(update)] 17 | } 18 | 19 | function extensions_check(){ 20 | var disable = [] 21 | 22 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ 23 | if(x.name.startsWith("enable_") && ! x.checked) 24 | disable.push(x.name.substr(7)) 25 | }) 26 | 27 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ 28 | x.innerHTML = "Loading..." 29 | }) 30 | 31 | 32 | var id = randomId() 33 | requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){ 34 | 35 | }) 36 | 37 | return [id, JSON.stringify(disable)] 38 | } 39 | 40 | function install_extension_from_index(button, url){ 41 | button.disabled = "disabled" 42 | button.value = "Installing..." 43 | 44 | textarea = gradioApp().querySelector('#extension_to_install textarea') 45 | textarea.value = url 46 | updateInput(textarea) 47 | 48 | gradioApp().querySelector('#install_extension_button').click() 49 | } 50 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | title: "[Feature Request]: " 4 | labels: ["enhancement"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 
10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do ? 22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here. 41 | -------------------------------------------------------------------------------- /javascript/imageMaskFix.js: -------------------------------------------------------------------------------- 1 | /** 2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 3 | * @see https://github.com/gradio-app/gradio/issues/1721 4 | */ 5 | window.addEventListener( 'resize', () => imageMaskResize()); 6 | function imageMaskResize() { 7 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); 8 | if ( ! canvases.length ) { 9 | canvases_fixed = false; 10 | window.removeEventListener( 'resize', imageMaskResize ); 11 | return; 12 | } 13 | 14 | const wrapper = canvases[0].closest('.touch-none'); 15 | const previewImage = wrapper.previousElementSibling; 16 | 17 | if ( ! previewImage.complete ) { 18 | previewImage.addEventListener( 'load', () => imageMaskResize()); 19 | return; 20 | } 21 | 22 | const w = previewImage.width; 23 | const h = previewImage.height; 24 | const nw = previewImage.naturalWidth; 25 | const nh = previewImage.naturalHeight; 26 | const portrait = nh > nw; 27 | const factor = portrait; 28 | 29 | const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); 30 | const wH = Math.min(h, portrait ? 
h/nh*nh : w/nw*nh); 31 | 32 | wrapper.style.width = `${wW}px`; 33 | wrapper.style.height = `${wH}px`; 34 | wrapper.style.left = `0px`; 35 | wrapper.style.top = `0px`; 36 | 37 | canvases.forEach( c => { 38 | c.style.width = c.style.height = ''; 39 | c.style.maxWidth = '100%'; 40 | c.style.maxHeight = '100%'; 41 | c.style.objectFit = 'contain'; 42 | }); 43 | } 44 | 45 | onUiUpdate(() => imageMaskResize()); 46 | -------------------------------------------------------------------------------- /scripts/postprocessing_codeformer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, codeformer_model 5 | import gradio as gr 6 | 7 | from modules.ui_components import FormRow 8 | 9 | 10 | class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): 11 | name = "CodeFormer" 12 | order = 3000 13 | 14 | def ui(self): 15 | with FormRow(): 16 | codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") 17 | codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") 18 | 19 | return { 20 | "codeformer_visibility": codeformer_visibility, 21 | "codeformer_weight": codeformer_weight, 22 | } 23 | 24 | def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): 25 | if codeformer_visibility == 0: 26 | return 27 | 28 | restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) 29 | res = Image.fromarray(restored_img) 30 | 31 | if codeformer_visibility < 1.0: 32 | res = Image.blend(pp.image, res, codeformer_visibility) 33 | 34 | pp.image = res 35 | pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) 36 | pp.info["CodeFormer weight"] = round(codeformer_weight, 3) 37 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/scripts/lora_script.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gradio as gr 3 | 4 | import lora 5 | import extra_networks_lora 6 | import ui_extra_networks_lora 7 | from modules import script_callbacks, ui_extra_networks, extra_networks, shared 8 | 9 | 10 | def unload(): 11 | torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora 12 | torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora 13 | 14 | 15 | def before_ui(): 16 | ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) 17 | extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) 18 | 19 | 20 | if not hasattr(torch.nn, 'Linear_forward_before_lora'): 21 | torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward 22 | 23 | if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): 24 | torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward 25 | 26 | torch.nn.Linear.forward = lora.lora_Linear_forward 27 | torch.nn.Conv2d.forward = lora.lora_Conv2d_forward 28 | 29 | script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) 30 | script_callbacks.on_script_unloaded(unload) 31 | script_callbacks.on_before_ui(before_ui) 32 | 33 | 34 | shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { 35 | "sd_lora": shared.OptionInfo("None", 
"Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), 36 | "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), 37 | 38 | })) 39 | -------------------------------------------------------------------------------- /modules/sd_hijack_xlmr.py: -------------------------------------------------------------------------------- 1 | import open_clip.tokenizer 2 | import torch 3 | 4 | from modules import sd_hijack_clip, devices 5 | from modules.shared import opts 6 | 7 | 8 | class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): 9 | def __init__(self, wrapped, hijack): 10 | super().__init__(wrapped, hijack) 11 | 12 | self.id_start = wrapped.config.bos_token_id 13 | self.id_end = wrapped.config.eos_token_id 14 | self.id_pad = wrapped.config.pad_token_id 15 | 16 | self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma 17 | 18 | def encode_with_transformers(self, tokens): 19 | # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a 20 | # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer 21 | # layer to work with - you have to use the last 22 | 23 | attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) 24 | features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) 25 | z = features['projection_state'] 26 | 27 | return z 28 | 29 | def encode_embedding_init_text(self, init_text, nvpt): 30 | embedding_layer = self.wrapped.roberta.embeddings 31 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] 32 | embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) 33 | 34 | return embedded 35 | -------------------------------------------------------------------------------- /modules/scripts_auto_postprocessing.py: -------------------------------------------------------------------------------- 1 | from modules import scripts, scripts_postprocessing, shared 2 | 3 | 4 | class ScriptPostprocessingForMainUI(scripts.Script): 5 | def __init__(self, script_postproc): 6 | self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc 7 | self.postprocessing_controls = None 8 | 9 | def title(self): 10 | return self.script.name 11 | 12 | def show(self, is_img2img): 13 | return scripts.AlwaysVisible 14 | 15 | def ui(self, is_img2img): 16 | self.postprocessing_controls = self.script.ui() 17 | return self.postprocessing_controls.values() 18 | 19 | def postprocess_image(self, p, script_pp, *args): 20 | args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} 21 | 22 | pp = scripts_postprocessing.PostprocessedImage(script_pp.image) 23 | pp.info = {} 24 | self.script.process(pp, **args_dict) 25 | p.extra_generation_params.update(pp.info) 26 | script_pp.image = pp.image 27 | 28 | 29 | def create_auto_preprocessing_script_data(): 30 | from modules import scripts 31 | 32 | res = [] 33 | 34 | for name in shared.opts.postprocessing_enable_in_main_ui: 35 | script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) 36 | if script is None: 37 | continue 38 | 39 | constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) 40 | res.append(scripts.ScriptClassData(script_class=constructor, 
path=script.path, basedir=script.basedir, module=script.module)) 41 | 42 | return res 43 | -------------------------------------------------------------------------------- /modules/textual_inversion/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | 5 | import modules.textual_inversion.textual_inversion 6 | import modules.textual_inversion.preprocess 7 | from modules import sd_hijack, shared 8 | 9 | 10 | def create_embedding(name, initialization_text, nvpt, overwrite_old): 11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) 12 | 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() 14 | 15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" 16 | 17 | 18 | def preprocess(*args): 19 | modules.textual_inversion.preprocess.preprocess(*args) 20 | 21 | return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" 22 | 23 | 24 | def train_embedding(*args): 25 | 26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' 27 | 28 | apply_optimizations = shared.opts.training_xattention_optimizations 29 | try: 30 | if not apply_optimizations: 31 | sd_hijack.undo_optimizations() 32 | 33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) 34 | 35 | res = f""" 36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 37 | Embedding saved to {html.escape(filename)} 38 | """ 39 | return res, "" 40 | except Exception: 41 | raise 42 | finally: 43 | if not apply_optimizations: 44 | sd_hijack.apply_optimizations() 45 | 46 | -------------------------------------------------------------------------------- /javascript/notification.js: -------------------------------------------------------------------------------- 1 | // Monitors the gallery and sends a browser notification when the leading image is new. 2 | 3 | let lastHeadImg = null; 4 | 5 | notificationButton = null 6 | 7 | onUiUpdate(function(){ 8 | if(notificationButton == null){ 9 | notificationButton = gradioApp().getElementById('request_notifications') 10 | 11 | if(notificationButton != null){ 12 | notificationButton.addEventListener('click', function (evt) { 13 | Notification.requestPermission(); 14 | },true); 15 | } 16 | } 17 | 18 | const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden'); 19 | 20 | if (galleryPreviews == null) return; 21 | 22 | const headImg = galleryPreviews[0]?.src; 23 | 24 | if (headImg == null || headImg == lastHeadImg) return; 25 | 26 | lastHeadImg = headImg; 27 | 28 | // play notification sound if available 29 | gradioApp().querySelector('#audio_notification audio')?.play(); 30 | 31 | if (document.hasFocus()) return; 32 | 33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. 34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src)); 35 | 36 | const notification = new Notification( 37 | 'Stable Diffusion', 38 | { 39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 
's' : ''}`, 40 | icon: headImg, 41 | image: headImg, 42 | } 43 | ); 44 | 45 | notification.onclick = function(_){ 46 | parent.focus(); 47 | this.close(); 48 | }; 49 | }); 50 | -------------------------------------------------------------------------------- /modules/hypernetworks/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | import os 3 | import re 4 | 5 | import gradio as gr 6 | import modules.hypernetworks.hypernetwork 7 | from modules import devices, sd_hijack, shared 8 | 9 | not_available = ["hardswish", "multiheadattention"] 10 | keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) 11 | 12 | 13 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): 14 | filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) 15 | 16 | return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" 17 | 18 | 19 | def train_hypernetwork(*args): 20 | shared.loaded_hypernetworks = [] 21 | 22 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 23 | 24 | try: 25 | sd_hijack.undo_optimizations() 26 | 27 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) 28 | 29 | res = f""" 30 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 31 | Hypernetwork saved to {html.escape(filename)} 32 | """ 33 | return res, "" 34 | except Exception: 35 | raise 36 | finally: 37 | shared.sd_model.cond_stage_model.to(devices.device) 38 | shared.sd_model.first_stage_model.to(devices.device) 39 | sd_hijack.apply_optimizations() 40 | 41 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! 2 | 3 | If you have a large change, pay special attention to this paragraph: 4 | 5 | > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. 6 | 7 | Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. 8 | 9 | **Describe what this pull request is trying to achieve.** 10 | 11 | A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. 12 | 13 | **Additional notes and description of your changes** 14 | 15 | More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. 16 | 17 | **Environment this was tested in** 18 | 19 | List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. 
20 | - OS: [e.g. Windows, Linux] 21 | - Browser: [e.g. chrome, safari] 22 | - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] 23 | 24 | **Screenshots or videos of your changes** 25 | 26 | If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. 27 | 28 | This is **required** for anything that touches the user interface. -------------------------------------------------------------------------------- /modules/ui_components.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | 4 | class ToolButton(gr.Button, gr.components.FormComponent): 5 | """Small button with single emoji as text, fits inside gradio forms""" 6 | 7 | def __init__(self, **kwargs): 8 | super().__init__(variant="tool", **kwargs) 9 | 10 | def get_block_name(self): 11 | return "button" 12 | 13 | 14 | class ToolButtonTop(gr.Button, gr.components.FormComponent): 15 | """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" 16 | 17 | def __init__(self, **kwargs): 18 | super().__init__(variant="tool-top", **kwargs) 19 | 20 | def get_block_name(self): 21 | return "button" 22 | 23 | 24 | class FormRow(gr.Row, gr.components.FormComponent): 25 | """Same as gr.Row but fits inside gradio forms""" 26 | 27 | def get_block_name(self): 28 | return "row" 29 | 30 | 31 | class FormGroup(gr.Group, gr.components.FormComponent): 32 | """Same as gr.Row but fits inside gradio forms""" 33 | 34 | def get_block_name(self): 35 | return "group" 36 | 37 | 38 | class FormHTML(gr.HTML, gr.components.FormComponent): 39 | """Same as gr.HTML but fits inside gradio forms""" 40 | 41 | def get_block_name(self): 42 | return "html" 43 | 44 | 45 | class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): 46 | """Same as gr.ColorPicker but fits inside gradio forms""" 47 | 48 | def get_block_name(self): 49 | return "colorpicker" 50 | 51 | 52 | class DropdownMulti(gr.Dropdown): 53 | """Same as gr.Dropdown but always multiselect""" 54 | def __init__(self, **kwargs): 55 | super().__init__(multiselect=True, **kwargs) 56 | 57 | def get_block_name(self): 58 | return "dropdown" 59 | -------------------------------------------------------------------------------- /modules/sd_vae_approx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from modules import devices, paths 6 | 7 | sd_vae_approx_model = None 8 | 9 | 10 | class VAEApprox(nn.Module): 11 | def __init__(self): 12 | super(VAEApprox, self).__init__() 13 | self.conv1 = nn.Conv2d(4, 8, (7, 7)) 14 | self.conv2 = nn.Conv2d(8, 16, (5, 5)) 15 | self.conv3 = nn.Conv2d(16, 32, (3, 3)) 16 | self.conv4 = nn.Conv2d(32, 64, (3, 3)) 17 | self.conv5 = nn.Conv2d(64, 32, (3, 3)) 18 | self.conv6 = nn.Conv2d(32, 16, (3, 3)) 19 | self.conv7 = nn.Conv2d(16, 8, (3, 3)) 20 | self.conv8 = nn.Conv2d(8, 3, (3, 3)) 21 | 22 | def forward(self, x): 23 | extra = 11 24 | x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) 25 | x = nn.functional.pad(x, (extra, extra, extra, extra)) 26 | 27 | for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: 28 | x = layer(x) 29 | x = nn.functional.leaky_relu(x, 0.1) 30 | 31 | return x 32 | 33 | 34 | def model(): 35 | global sd_vae_approx_model 36 | 37 | if sd_vae_approx_model is None: 38 | 
sd_vae_approx_model = VAEApprox() 39 | sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) 40 | sd_vae_approx_model.eval() 41 | sd_vae_approx_model.to(devices.device, devices.dtype) 42 | 43 | return sd_vae_approx_model 44 | 45 | 46 | def cheap_approximation(sample): 47 | # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 48 | 49 | coefs = torch.tensor([ 50 | [0.298, 0.207, 0.208], 51 | [0.187, 0.286, 0.173], 52 | [-0.158, 0.189, 0.264], 53 | [-0.184, -0.271, -0.473], 54 | ]).to(sample.device) 55 | 56 | x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) 57 | 58 | return x_sample 59 | -------------------------------------------------------------------------------- /configs/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /test/basic_features/extras_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | from gradio.processing_utils import encode_pil_to_base64 4 | from PIL import Image 5 | 6 | class TestExtrasWorking(unittest.TestCase): 7 | def setUp(self): 8 | self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image" 9 | self.extras_single = { 10 | "resize_mode": 0, 11 | "show_extras_results": True, 12 | "gfpgan_visibility": 0, 13 | "codeformer_visibility": 0, 14 | "codeformer_weight": 0, 15 | "upscaling_resize": 2, 16 | "upscaling_resize_w": 128, 17 | "upscaling_resize_h": 128, 18 | "upscaling_crop": True, 19 | "upscaler_1": "None", 20 | "upscaler_2": "None", 21 | 
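A minimal, hedged sketch of the "Approx cheap" preview implemented by cheap_approximation() in modules/sd_vae_approx.py above: it maps a 4-channel latent to RGB with a single einsum instead of a full VAE decode. Only torch, numpy and Pillow are assumed; the latent here is random noise rather than a real sample, and the clamp/scale step mirrors how callers in this repo turn the result into an image.

import numpy as np
import torch
from PIL import Image

# same 4 -> 3 channel coefficients as cheap_approximation()
coefs = torch.tensor([
    [0.298, 0.207, 0.208],
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(4, 64, 64)                      # stand-in for one SD latent
rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)   # 3 x 64 x 64, roughly in [-1, 1]
rgb = torch.clamp((rgb + 1.0) / 2.0, min=0.0, max=1.0)
preview = Image.fromarray((255.0 * np.moveaxis(rgb.numpy(), 0, 2)).astype(np.uint8))
preview.save("latent_preview.png")                   # small preview, no VAE decode needed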
"extras_upscaler_2_visibility": 0, 22 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) 23 | } 24 | 25 | def test_simple_upscaling_performed(self): 26 | self.extras_single["upscaler_1"] = "Lanczos" 27 | self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200) 28 | 29 | 30 | class TestPngInfoWorking(unittest.TestCase): 31 | def setUp(self): 32 | self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" 33 | self.png_info = { 34 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) 35 | } 36 | 37 | def test_png_info_performed(self): 38 | self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200) 39 | 40 | 41 | class TestInterrogateWorking(unittest.TestCase): 42 | def setUp(self): 43 | self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" 44 | self.interrogate = { 45 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), 46 | "model": "clip" 47 | } 48 | 49 | def test_interrogate_performed(self): 50 | self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200) 51 | 52 | 53 | if __name__ == "__main__": 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /configs/alt-diffusion-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: modules.xlmr.BertSeriesModelWithTransformation 71 | params: 72 | name: "XLMR-Large" -------------------------------------------------------------------------------- /configs/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /modules/hashes.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os.path 4 | 5 | import filelock 6 | 7 | from modules import shared 8 | from modules.paths import data_path 9 | 10 | 11 | cache_filename = os.path.join(data_path, "cache.json") 12 | cache_data = None 13 | 14 | 15 | def dump_cache(): 16 | with filelock.FileLock(cache_filename+".lock"): 17 | with open(cache_filename, "w", encoding="utf8") as file: 18 | json.dump(cache_data, file, indent=4) 19 | 20 | 21 | def cache(subsection): 22 | global cache_data 23 | 24 | if cache_data is None: 25 | with filelock.FileLock(cache_filename+".lock"): 26 | if not os.path.isfile(cache_filename): 27 | cache_data = {} 28 | else: 29 | with open(cache_filename, "r", encoding="utf8") as file: 30 | cache_data = json.load(file) 31 | 32 | s = cache_data.get(subsection, {}) 33 | cache_data[subsection] = s 34 | 35 | return s 36 | 37 | 38 | def calculate_sha256(filename): 39 | hash_sha256 = hashlib.sha256() 40 | blksize = 1024 * 1024 41 | 42 | with open(filename, "rb") as f: 43 | for chunk in iter(lambda: f.read(blksize), b""): 44 | hash_sha256.update(chunk) 45 | 46 | return hash_sha256.hexdigest() 47 | 48 | 49 | def sha256_from_cache(filename, title): 50 | hashes = cache("hashes") 51 | ondisk_mtime = os.path.getmtime(filename) 52 | 53 | if title not in hashes: 54 | return None 55 | 56 | cached_sha256 = hashes[title].get("sha256", None) 57 | cached_mtime = hashes[title].get("mtime", 0) 58 | 59 | if ondisk_mtime > cached_mtime or cached_sha256 is None: 60 | return None 61 | 62 | return cached_sha256 63 | 64 | 65 | def sha256(filename, title): 66 | hashes = cache("hashes") 67 | 68 | sha256_value = sha256_from_cache(filename, title) 69 | if sha256_value is not None: 70 | return sha256_value 71 | 72 | if shared.cmd_opts.no_hashing: 73 | return None 74 | 75 | print(f"Calculating sha256 for {filename}: ", end='') 76 | sha256_value = calculate_sha256(filename) 77 | print(f"{sha256_value}") 78 | 79 | hashes[title] = { 80 | "mtime": os.path.getmtime(filename), 81 | "sha256": sha256_value, 82 | } 83 | 84 | dump_cache() 85 | 86 | return sha256_value 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not defined PYTHON (set PYTHON=python) 4 | if not 
defined VENV_DIR (set "VENV_DIR=%~dp0%venv") 5 | 6 | 7 | set ERROR_REPORTING=FALSE 8 | 9 | mkdir tmp 2>NUL 10 | 11 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 12 | if %ERRORLEVEL% == 0 goto :check_pip 13 | echo Couldn't launch python 14 | goto :show_stdout_stderr 15 | 16 | :check_pip 17 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt 18 | if %ERRORLEVEL% == 0 goto :start_venv 19 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr 20 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt 21 | if %ERRORLEVEL% == 0 goto :start_venv 22 | echo Couldn't install pip 23 | goto :show_stdout_stderr 24 | 25 | :start_venv 26 | if ["%VENV_DIR%"] == ["-"] goto :skip_venv 27 | if ["%SKIP_VENV%"] == ["1"] goto :skip_venv 28 | 29 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt 30 | if %ERRORLEVEL% == 0 goto :activate_venv 31 | 32 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 33 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 34 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt 35 | if %ERRORLEVEL% == 0 goto :activate_venv 36 | echo Unable to create venv in directory "%VENV_DIR%" 37 | goto :show_stdout_stderr 38 | 39 | :activate_venv 40 | set PYTHON="%VENV_DIR%\Scripts\Python.exe" 41 | echo venv %PYTHON% 42 | 43 | :skip_venv 44 | if [%ACCELERATE%] == ["True"] goto :accelerate 45 | goto :launch 46 | 47 | :accelerate 48 | echo Checking for accelerate 49 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" 50 | if EXIST %ACCELERATE% goto :accelerate_launch 51 | 52 | :launch 53 | %PYTHON% launch.py %* 54 | pause 55 | exit /b 56 | 57 | :accelerate_launch 58 | echo Accelerating 59 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py 60 | pause 61 | exit /b 62 | 63 | :show_stdout_stderr 64 | 65 | echo. 66 | echo exit code: %errorlevel% 67 | 68 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 69 | if %size% equ 0 goto :show_stderr 70 | echo. 71 | echo stdout: 72 | type tmp\stdout.txt 73 | 74 | :show_stderr 75 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 76 | if %size% equ 0 goto :show_stderr 77 | echo. 78 | echo stderr: 79 | type tmp\stderr.txt 80 | 81 | :endofscript 82 | 83 | echo. 84 | echo Launch unsuccessful. Exiting. 
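Referring back to modules/hashes.py shown earlier: it keeps a cache.json keyed by a title, invalidates entries by file mtime, and otherwise falls back to a chunked SHA-256 pass over the file. Below is a hedged, standard-library-only sketch of that chunked-hash idea; the checkpoint path in the commented usage line is a hypothetical example, not a file shipped with the repo.

import hashlib

def sha256_of_file(path: str, blksize: int = 1024 * 1024) -> str:
    # read the file in 1 MiB chunks so large checkpoints never need to fit in memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            h.update(chunk)
    return h.hexdigest()

# usage (hypothetical path):
# print(sha256_of_file("models/Stable-diffusion/example.ckpt"))

In the repo's version, sha256(filename, title) first consults sha256_from_cache() and only recomputes when there is no cached entry newer than the file on disk; passing --no-hashing skips the computation entirely.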
85 | pause 86 | -------------------------------------------------------------------------------- /modules/sd_samplers_common.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from modules import devices, processing, images, sd_vae_approx 6 | 7 | from modules.shared import opts, state 8 | import modules.shared as shared 9 | 10 | SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) 11 | 12 | 13 | def setup_img2img_steps(p, steps=None): 14 | if opts.img2img_fix_steps or steps is not None: 15 | requested_steps = (steps or p.steps) 16 | steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 17 | t_enc = requested_steps - 1 18 | else: 19 | steps = p.steps 20 | t_enc = int(min(p.denoising_strength, 0.999) * steps) 21 | 22 | return steps, t_enc 23 | 24 | 25 | approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} 26 | 27 | 28 | def single_sample_to_image(sample, approximation=None): 29 | if approximation is None: 30 | approximation = approximation_indexes.get(opts.show_progress_type, 0) 31 | 32 | if approximation == 2: 33 | x_sample = sd_vae_approx.cheap_approximation(sample) 34 | elif approximation == 1: 35 | x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() 36 | else: 37 | x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] 38 | 39 | x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) 40 | x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) 41 | x_sample = x_sample.astype(np.uint8) 42 | return Image.fromarray(x_sample) 43 | 44 | 45 | def sample_to_image(samples, index=0, approximation=None): 46 | return single_sample_to_image(samples[index], approximation) 47 | 48 | 49 | def samples_to_image_grid(samples, approximation=None): 50 | return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) 51 | 52 | 53 | def store_latent(decoded): 54 | state.current_latent = decoded 55 | 56 | if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: 57 | if not shared.parallel_processing_allowed: 58 | shared.state.assign_current_image(sample_to_image(decoded)) 59 | 60 | 61 | class InterruptedException(BaseException): 62 | pass 63 | -------------------------------------------------------------------------------- /modules/paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import modules.safe 5 | 6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 7 | 8 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values 9 | parser = argparse.ArgumentParser(add_help=False) 10 | parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) 11 | cmd_opts_pre = parser.parse_known_args()[0] 12 | data_path = cmd_opts_pre.data_dir 13 | models_path = os.path.join(data_path, "models") 14 | 15 | # data_path = cmd_opts_pre.data 16 | sys.path.insert(0, script_path) 17 | 18 | # search for directory of stable diffusion in following places 19 | sd_path = None 20 | possible_sd_paths = [os.path.join(script_path, 
'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] 21 | for possible_sd_path in possible_sd_paths: 22 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 23 | sd_path = os.path.abspath(possible_sd_path) 24 | break 25 | 26 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 27 | 28 | path_dirs = [ 29 | (sd_path, 'ldm', 'Stable Diffusion', []), 30 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), 31 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), 32 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), 33 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 34 | ] 35 | 36 | paths = {} 37 | 38 | for d, must_exist, what, options in path_dirs: 39 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 40 | if not os.path.exists(must_exist_path): 41 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 42 | else: 43 | d = os.path.abspath(d) 44 | if "atstart" in options: 45 | sys.path.insert(0, d) 46 | else: 47 | sys.path.append(d) 48 | paths[what] = d 49 | 50 | 51 | class Prioritize: 52 | def __init__(self, name): 53 | self.name = name 54 | self.path = None 55 | 56 | def __enter__(self): 57 | self.path = sys.path.copy() 58 | sys.path = [paths[self.name]] + sys.path 59 | 60 | def __exit__(self, exc_type, exc_val, exc_tb): 61 | sys.path = self.path 62 | self.path = None 63 | -------------------------------------------------------------------------------- /test/basic_features/utils_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | 4 | class UtilsTests(unittest.TestCase): 5 | def setUp(self): 6 | self.url_options = "http://localhost:7860/sdapi/v1/options" 7 | self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" 8 | self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" 9 | self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" 10 | self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" 11 | self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" 12 | self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" 13 | self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" 14 | self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" 15 | self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" 16 | 17 | def test_options_get(self): 18 | self.assertEqual(requests.get(self.url_options).status_code, 200) 19 | 20 | def test_options_write(self): 21 | response = requests.get(self.url_options) 22 | self.assertEqual(response.status_code, 200) 23 | 24 | pre_value = response.json()["send_seed"] 25 | 26 | self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) 27 | 28 | response = requests.get(self.url_options) 29 | self.assertEqual(response.status_code, 200) 30 | self.assertEqual(response.json()["send_seed"], not pre_value) 31 | 32 | requests.post(self.url_options, json={"send_seed": pre_value}) 33 | 34 | def test_cmd_flags(self): 35 | self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) 36 | 37 | def test_samplers(self): 38 | self.assertEqual(requests.get(self.url_samplers).status_code, 200) 39 | 40 | def test_upscalers(self): 41 | 
self.assertEqual(requests.get(self.url_upscalers).status_code, 200) 42 | 43 | def test_sd_models(self): 44 | self.assertEqual(requests.get(self.url_sd_models).status_code, 200) 45 | 46 | def test_hypernetworks(self): 47 | self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200) 48 | 49 | def test_face_restorers(self): 50 | self.assertEqual(requests.get(self.url_face_restorers).status_code, 200) 51 | 52 | def test_realesrgan_models(self): 53 | self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200) 54 | 55 | def test_prompt_styles(self): 56 | self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) 57 | 58 | def test_embeddings(self): 59 | self.assertEqual(requests.get(self.url_embeddings).status_code, 200) 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /test/basic_features/img2img_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | from gradio.processing_utils import encode_pil_to_base64 4 | from PIL import Image 5 | 6 | 7 | class TestImg2ImgWorking(unittest.TestCase): 8 | def setUp(self): 9 | self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" 10 | self.simple_img2img = { 11 | "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], 12 | "resize_mode": 0, 13 | "denoising_strength": 0.75, 14 | "mask": None, 15 | "mask_blur": 4, 16 | "inpainting_fill": 0, 17 | "inpaint_full_res": False, 18 | "inpaint_full_res_padding": 0, 19 | "inpainting_mask_invert": False, 20 | "prompt": "example prompt", 21 | "styles": [], 22 | "seed": -1, 23 | "subseed": -1, 24 | "subseed_strength": 0, 25 | "seed_resize_from_h": -1, 26 | "seed_resize_from_w": -1, 27 | "batch_size": 1, 28 | "n_iter": 1, 29 | "steps": 3, 30 | "cfg_scale": 7, 31 | "width": 64, 32 | "height": 64, 33 | "restore_faces": False, 34 | "tiling": False, 35 | "negative_prompt": "", 36 | "eta": 0, 37 | "s_churn": 0, 38 | "s_tmax": 0, 39 | "s_tmin": 0, 40 | "s_noise": 1, 41 | "override_settings": {}, 42 | "sampler_index": "Euler a", 43 | "include_init_images": False 44 | } 45 | 46 | def test_img2img_simple_performed(self): 47 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 48 | 49 | def test_inpainting_masked_performed(self): 50 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) 51 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 52 | 53 | def test_inpainting_with_inverted_masked_performed(self): 54 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) 55 | self.simple_img2img["inpainting_mask_invert"] = True 56 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 57 | 58 | def test_img2img_sd_upscale_performed(self): 59 | self.simple_img2img["script_name"] = "sd upscale" 60 | self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] 61 | 62 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) 63 | 64 | 65 | if __name__ == "__main__": 66 | unittest.main() 67 | -------------------------------------------------------------------------------- /modules/ui_tempdir.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from collections import namedtuple 4 
| from pathlib import Path 5 | 6 | import gradio as gr 7 | 8 | from PIL import PngImagePlugin 9 | 10 | from modules import shared 11 | 12 | 13 | Savedfile = namedtuple("Savedfile", ["name"]) 14 | 15 | 16 | def register_tmp_file(gradio, filename): 17 | if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 18 | gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} 19 | 20 | if hasattr(gradio, 'temp_dirs'): # gradio 3.9 21 | gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} 22 | 23 | 24 | def check_tmp_file(gradio, filename): 25 | if hasattr(gradio, 'temp_file_sets'): 26 | return any([filename in fileset for fileset in gradio.temp_file_sets]) 27 | 28 | if hasattr(gradio, 'temp_dirs'): 29 | return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) 30 | 31 | return False 32 | 33 | 34 | def save_pil_to_file(pil_image, dir=None): 35 | already_saved_as = getattr(pil_image, 'already_saved_as', None) 36 | if already_saved_as and os.path.isfile(already_saved_as): 37 | register_tmp_file(shared.demo, already_saved_as) 38 | 39 | file_obj = Savedfile(already_saved_as) 40 | return file_obj 41 | 42 | if shared.opts.temp_dir != "": 43 | dir = shared.opts.temp_dir 44 | 45 | use_metadata = False 46 | metadata = PngImagePlugin.PngInfo() 47 | for key, value in pil_image.info.items(): 48 | if isinstance(key, str) and isinstance(value, str): 49 | metadata.add_text(key, value) 50 | use_metadata = True 51 | 52 | file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) 53 | pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) 54 | return file_obj 55 | 56 | 57 | # override save to file function so that it also writes PNG info 58 | gr.processing_utils.save_pil_to_file = save_pil_to_file 59 | 60 | 61 | def on_tmpdir_changed(): 62 | if shared.opts.temp_dir == "" or shared.demo is None: 63 | return 64 | 65 | os.makedirs(shared.opts.temp_dir, exist_ok=True) 66 | 67 | register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) 68 | 69 | 70 | def cleanup_tmpdr(): 71 | temp_dir = shared.opts.temp_dir 72 | if temp_dir == "" or not os.path.isdir(temp_dir): 73 | return 74 | 75 | for root, dirs, files in os.walk(temp_dir, topdown=False): 76 | for name in files: 77 | _, extension = os.path.splitext(name) 78 | if extension != ".png": 79 | continue 80 | 81 | filename = os.path.join(root, name) 82 | os.remove(filename) 83 | -------------------------------------------------------------------------------- /modules/memmon.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | 8 | class MemUsageMonitor(threading.Thread): 9 | run_flag = None 10 | device = None 11 | disabled = False 12 | opts = None 13 | data = None 14 | 15 | def __init__(self, name, device, opts): 16 | threading.Thread.__init__(self) 17 | self.name = name 18 | self.device = device 19 | self.opts = opts 20 | 21 | self.daemon = True 22 | self.run_flag = threading.Event() 23 | self.data = defaultdict(int) 24 | 25 | try: 26 | torch.cuda.mem_get_info() 27 | torch.cuda.memory_stats(self.device) 28 | except Exception as e: # AMD or whatever 29 | print(f"Warning: caught exception '{e}', memory monitor disabled") 30 | self.disabled = True 31 | 32 | def run(self): 33 | if self.disabled: 34 | return 35 | 36 | while True: 37 | self.run_flag.wait() 38 | 39 | torch.cuda.reset_peak_memory_stats() 
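Stepping back from the polling loop above, this is a hedged sketch of how MemUsageMonitor is driven by its callers: the thread starts idle, monitor() begins sampling free VRAM, and stop() returns the collected figures. It assumes a CUDA-capable torch build and that modules.memmon is importable on its own; the SimpleNamespace stands in for the real shared.opts object.

import torch
from types import SimpleNamespace
from modules.memmon import MemUsageMonitor

opts = SimpleNamespace(memmon_poll_rate=8)               # roughly 8 samples per second
mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
mon.start()                                              # thread blocks on run_flag until monitor()
mon.monitor()                                            # start recording minimum free VRAM
# ... run a generation or any other CUDA workload here ...
stats = mon.stop()                                       # clears the flag and returns the data dict
print({k: v // (1024 ** 2) for k, v in stats.items()})   # rough MiB values

stop() calls read(), so the dict includes free/total plus the active, reserved and peak counters taken from torch.cuda.memory_stats().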
40 | self.data.clear() 41 | 42 | if self.opts.memmon_poll_rate <= 0: 43 | self.run_flag.clear() 44 | continue 45 | 46 | self.data["min_free"] = torch.cuda.mem_get_info()[0] 47 | 48 | while self.run_flag.is_set(): 49 | free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug? 50 | self.data["min_free"] = min(self.data["min_free"], free) 51 | 52 | time.sleep(1 / self.opts.memmon_poll_rate) 53 | 54 | def dump_debug(self): 55 | print(self, 'recorded data:') 56 | for k, v in self.read().items(): 57 | print(k, -(v // -(1024 ** 2))) 58 | 59 | print(self, 'raw torch memory stats:') 60 | tm = torch.cuda.memory_stats(self.device) 61 | for k, v in tm.items(): 62 | if 'bytes' not in k: 63 | continue 64 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2))) 65 | 66 | print(torch.cuda.memory_summary()) 67 | 68 | def monitor(self): 69 | self.run_flag.set() 70 | 71 | def read(self): 72 | if not self.disabled: 73 | free, total = torch.cuda.mem_get_info() 74 | self.data["free"] = free 75 | self.data["total"] = total 76 | 77 | torch_stats = torch.cuda.memory_stats(self.device) 78 | self.data["active"] = torch_stats["active.all.current"] 79 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"] 80 | self.data["reserved"] = torch_stats["reserved_bytes.all.current"] 81 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] 82 | self.data["system_peak"] = total - self.data["min_free"] 83 | 84 | return self.data 85 | 86 | def stop(self): 87 | self.run_flag.clear() 88 | return self.read() 89 | -------------------------------------------------------------------------------- /modules/ui_postprocessing.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue 3 | import modules.generation_parameters_copypaste as parameters_copypaste 4 | 5 | 6 | def create_ui(): 7 | tab_index = gr.State(value=0) 8 | 9 | with gr.Row().style(equal_height=False, variant='compact'): 10 | with gr.Column(variant='compact'): 11 | with gr.Tabs(elem_id="mode_extras"): 12 | with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: 13 | extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") 14 | 15 | with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: 16 | image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") 17 | 18 | with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: 19 | extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") 20 | extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") 21 | show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") 22 | 23 | submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') 24 | 25 | script_inputs = scripts.scripts_postproc.setup_ui() 26 | 27 | with gr.Column(): 28 | result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) 29 | 30 | 
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) 31 | tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) 32 | tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) 33 | 34 | submit.click( 35 | fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), 36 | inputs=[ 37 | tab_index, 38 | extras_image, 39 | image_batch, 40 | extras_batch_input_dir, 41 | extras_batch_output_dir, 42 | show_extras_results, 43 | *script_inputs 44 | ], 45 | outputs=[ 46 | result_images, 47 | html_info_x, 48 | html_info, 49 | ] 50 | ) 51 | 52 | parameters_copypaste.add_paste_fields("extras", extras_image, None) 53 | 54 | extras_image.change( 55 | fn=scripts.scripts_postproc.image_changed, 56 | inputs=[], outputs=[] 57 | ) 58 | -------------------------------------------------------------------------------- /configs/instruct-pix2pix.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: modules.models.diffusion.ddpm_edit.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: edited 14 | cond_stage_key: edit 15 | # image_size: 64 16 | # image_size: 32 17 | image_size: 16 18 | channels: 4 19 | cond_stage_trainable: false # Note: different from the one we trained before 20 | conditioning_key: hybrid 21 | monitor: val/loss_simple_ema 22 | scale_factor: 0.18215 23 | use_ema: false 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 0 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. 
] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 8 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 768 48 | use_checkpoint: True 49 | legacy: False 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | 74 | cond_stage_config: 75 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 76 | 77 | data: 78 | target: main.DataModuleFromConfig 79 | params: 80 | batch_size: 128 81 | num_workers: 1 82 | wrap: false 83 | validation: 84 | target: edit_dataset.EditDataset 85 | params: 86 | path: data/clip-filtered-dataset 87 | cache_dir: data/ 88 | cache_name: data_10k 89 | split: val 90 | min_text_sim: 0.2 91 | min_image_sim: 0.75 92 | min_direction_sim: 0.2 93 | max_samples_per_prompt: 1 94 | min_resize_res: 512 95 | max_resize_res: 512 96 | crop_res: 512 97 | output_as_edit: False 98 | real_input: True 99 | -------------------------------------------------------------------------------- /extensions-builtin/LDSR/scripts/ldsr_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | from modules.upscaler import Upscaler, UpscalerData 8 | from ldsr_model_arch import LDSR 9 | from modules import shared, script_callbacks 10 | import sd_hijack_autoencoder, sd_hijack_ddpm_v1 11 | 12 | 13 | class UpscalerLDSR(Upscaler): 14 | def __init__(self, user_path): 15 | self.name = "LDSR" 16 | self.user_path = user_path 17 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" 18 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" 19 | super().__init__() 20 | scaler_data = UpscalerData("LDSR", None, self) 21 | self.scalers = [scaler_data] 22 | 23 | def load_model(self, path: str): 24 | # Remove incorrect project.yaml file if too big 25 | yaml_path = os.path.join(self.model_path, "project.yaml") 26 | old_model_path = os.path.join(self.model_path, "model.pth") 27 | new_model_path = os.path.join(self.model_path, "model.ckpt") 28 | safetensors_model_path = os.path.join(self.model_path, "model.safetensors") 29 | if os.path.exists(yaml_path): 30 | statinfo = os.stat(yaml_path) 31 | if statinfo.st_size >= 10485760: 32 | print("Removing invalid LDSR YAML file.") 33 | os.remove(yaml_path) 34 | if os.path.exists(old_model_path): 35 | print("Renaming model from model.pth to model.ckpt") 36 | os.rename(old_model_path, new_model_path) 37 | if os.path.exists(safetensors_model_path): 38 | model = safetensors_model_path 39 | else: 40 | model = load_file_from_url(url=self.model_url, model_dir=self.model_path, 41 | file_name="model.ckpt", progress=True) 42 | yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, 43 | file_name="project.yaml", progress=True) 44 | 45 | try: 46 | return LDSR(model, yaml) 47 | 48 | except Exception: 49 | print("Error 
importing LDSR:", file=sys.stderr) 50 | print(traceback.format_exc(), file=sys.stderr) 51 | return None 52 | 53 | def do_upscale(self, img, path): 54 | ldsr = self.load_model(path) 55 | if ldsr is None: 56 | print("NO LDSR!") 57 | return img 58 | ddim_steps = shared.opts.ldsr_steps 59 | return ldsr.super_resolution(img, ddim_steps, self.scale) 60 | 61 | 62 | def on_ui_settings(): 63 | import gradio as gr 64 | 65 | shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) 66 | shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) 67 | 68 | 69 | script_callbacks.on_ui_settings(on_ui_settings) 70 | -------------------------------------------------------------------------------- /modules/textual_inversion/learn_schedule.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | 4 | class LearnScheduleIterator: 5 | def __init__(self, learn_rate, max_steps, cur_step=0): 6 | """ 7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 8 | """ 9 | 10 | pairs = learn_rate.split(',') 11 | self.rates = [] 12 | self.it = 0 13 | self.maxit = 0 14 | try: 15 | for i, pair in enumerate(pairs): 16 | if not pair.strip(): 17 | continue 18 | tmp = pair.split(':') 19 | if len(tmp) == 2: 20 | step = int(tmp[1]) 21 | if step > cur_step: 22 | self.rates.append((float(tmp[0]), min(step, max_steps))) 23 | self.maxit += 1 24 | if step > max_steps: 25 | return 26 | elif step == -1: 27 | self.rates.append((float(tmp[0]), max_steps)) 28 | self.maxit += 1 29 | return 30 | else: 31 | self.rates.append((float(tmp[0]), max_steps)) 32 | self.maxit += 1 33 | return 34 | assert self.rates 35 | except (ValueError, AssertionError): 36 | raise Exception('Invalid learning rate schedule. 
It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') 37 | 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | if self.it < self.maxit: 44 | self.it += 1 45 | return self.rates[self.it - 1] 46 | else: 47 | raise StopIteration 48 | 49 | 50 | class LearnRateScheduler: 51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): 52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) 53 | (self.learn_rate, self.end_step) = next(self.schedules) 54 | self.verbose = verbose 55 | 56 | if self.verbose: 57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}') 58 | 59 | self.finished = False 60 | 61 | def step(self, step_number): 62 | if step_number < self.end_step: 63 | return False 64 | 65 | try: 66 | (self.learn_rate, self.end_step) = next(self.schedules) 67 | except StopIteration: 68 | self.finished = True 69 | return False 70 | return True 71 | 72 | def apply(self, optimizer, step_number): 73 | if not self.step(step_number): 74 | return 75 | 76 | if self.verbose: 77 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') 78 | 79 | for pg in optimizer.param_groups: 80 | pg['lr'] = self.learn_rate 81 | 82 | -------------------------------------------------------------------------------- /modules/txt2img.py: -------------------------------------------------------------------------------- 1 | import modules.scripts 2 | from modules import sd_samplers 3 | from modules.generation_parameters_copypaste import create_override_settings_dict 4 | from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ 5 | StableDiffusionProcessingImg2Img, process_images 6 | from modules.shared import opts, cmd_opts 7 | import modules.shared as shared 8 | import modules.processing as processing 9 | from modules.ui import plaintext_to_html 10 | 11 | 12 | def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args): 13 | override_settings = create_override_settings_dict(override_settings_texts) 14 | 15 | p = StableDiffusionProcessingTxt2Img( 16 | sd_model=shared.sd_model, 17 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, 18 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, 19 | prompt=prompt, 20 | styles=prompt_styles, 21 | negative_prompt=negative_prompt, 22 | seed=seed, 23 | subseed=subseed, 24 | subseed_strength=subseed_strength, 25 | seed_resize_from_h=seed_resize_from_h, 26 | seed_resize_from_w=seed_resize_from_w, 27 | seed_enable_extras=seed_enable_extras, 28 | sampler_name=sd_samplers.samplers[sampler_index].name, 29 | batch_size=batch_size, 30 | n_iter=n_iter, 31 | steps=steps, 32 | cfg_scale=cfg_scale, 33 | width=width, 34 | height=height, 35 | restore_faces=restore_faces, 36 | tiling=tiling, 37 | enable_hr=enable_hr, 38 | denoising_strength=denoising_strength if enable_hr else None, 39 | hr_scale=hr_scale, 40 | hr_upscaler=hr_upscaler, 41 | 
hr_second_pass_steps=hr_second_pass_steps, 42 | hr_resize_x=hr_resize_x, 43 | hr_resize_y=hr_resize_y, 44 | override_settings=override_settings, 45 | ) 46 | 47 | p.scripts = modules.scripts.scripts_txt2img 48 | p.script_args = args 49 | 50 | if cmd_opts.enable_console_prompts: 51 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) 52 | 53 | processed = modules.scripts.scripts_txt2img.run(p, *args) 54 | 55 | if processed is None: 56 | processed = process_images(p) 57 | 58 | p.close() 59 | 60 | shared.total_tqdm.clear() 61 | 62 | generation_info_js = processed.js() 63 | if opts.samples_log_stdout: 64 | print(generation_info_js) 65 | 66 | if opts.do_not_show_images: 67 | processed.images = [] 68 | 69 | return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) 70 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | function gradioApp() { 2 | const elems = document.getElementsByTagName('gradio-app') 3 | const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot 4 | return !!gradioShadowRoot ? gradioShadowRoot : document; 5 | } 6 | 7 | function get_uiCurrentTab() { 8 | return gradioApp().querySelector('#tabs button:not(.border-transparent)') 9 | } 10 | 11 | function get_uiCurrentTabContent() { 12 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') 13 | } 14 | 15 | uiUpdateCallbacks = [] 16 | uiLoadedCallbacks = [] 17 | uiTabChangeCallbacks = [] 18 | optionsChangedCallbacks = [] 19 | let uiCurrentTab = null 20 | 21 | function onUiUpdate(callback){ 22 | uiUpdateCallbacks.push(callback) 23 | } 24 | function onUiLoaded(callback){ 25 | uiLoadedCallbacks.push(callback) 26 | } 27 | function onUiTabChange(callback){ 28 | uiTabChangeCallbacks.push(callback) 29 | } 30 | function onOptionsChanged(callback){ 31 | optionsChangedCallbacks.push(callback) 32 | } 33 | 34 | function runCallback(x, m){ 35 | try { 36 | x(m) 37 | } catch (e) { 38 | (console.error || console.log).call(console, e.message, e); 39 | } 40 | } 41 | function executeCallbacks(queue, m) { 42 | queue.forEach(function(x){runCallback(x, m)}) 43 | } 44 | 45 | var executedOnLoaded = false; 46 | 47 | document.addEventListener("DOMContentLoaded", function() { 48 | var mutationObserver = new MutationObserver(function(m){ 49 | if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ 50 | executedOnLoaded = true; 51 | executeCallbacks(uiLoadedCallbacks); 52 | } 53 | 54 | executeCallbacks(uiUpdateCallbacks, m); 55 | const newTab = get_uiCurrentTab(); 56 | if ( newTab && ( newTab !== uiCurrentTab ) ) { 57 | uiCurrentTab = newTab; 58 | executeCallbacks(uiTabChangeCallbacks); 59 | } 60 | }); 61 | mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) 62 | }); 63 | 64 | /** 65 | * Add a ctrl+enter as a shortcut to start a generation 66 | */ 67 | document.addEventListener('keydown', function(e) { 68 | var handled = false; 69 | if (e.key !== undefined) { 70 | if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 71 | } else if (e.keyCode !== undefined) { 72 | if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 73 | } 74 | if (handled) { 75 | button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); 76 | if (button) { 77 | button.click(); 78 | } 79 | e.preventDefault(); 80 | } 81 | }) 82 | 83 | /** 84 | * checks that a 
UI element is not in another hidden element or tab content 85 | */ 86 | function uiElementIsVisible(el) { 87 | let isVisible = !el.closest('.\\!hidden'); 88 | if ( ! isVisible ) { 89 | return false; 90 | } 91 | 92 | while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) { 93 | if ( ! isVisible ) { 94 | return false; 95 | } else if ( el.parentElement ) { 96 | el = el.parentElement 97 | } else { 98 | break; 99 | } 100 | } 101 | return isVisible; 102 | } 103 | -------------------------------------------------------------------------------- /modules/deepbooru.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | 8 | from modules import modelloader, paths, deepbooru_model, devices, images, shared 9 | 10 | re_special = re.compile(r'([\\()])') 11 | 12 | 13 | class DeepDanbooru: 14 | def __init__(self): 15 | self.model = None 16 | 17 | def load(self): 18 | if self.model is not None: 19 | return 20 | 21 | files = modelloader.load_models( 22 | model_path=os.path.join(paths.models_path, "torch_deepdanbooru"), 23 | model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', 24 | ext_filter=[".pt"], 25 | download_name='model-resnet_custom_v3.pt', 26 | ) 27 | 28 | self.model = deepbooru_model.DeepDanbooruModel() 29 | self.model.load_state_dict(torch.load(files[0], map_location="cpu")) 30 | 31 | self.model.eval() 32 | self.model.to(devices.cpu, devices.dtype) 33 | 34 | def start(self): 35 | self.load() 36 | self.model.to(devices.device) 37 | 38 | def stop(self): 39 | if not shared.opts.interrogate_keep_models_in_memory: 40 | self.model.to(devices.cpu) 41 | devices.torch_gc() 42 | 43 | def tag(self, pil_image): 44 | self.start() 45 | res = self.tag_multi(pil_image) 46 | self.stop() 47 | 48 | return res 49 | 50 | def tag_multi(self, pil_image, force_disable_ranks=False): 51 | threshold = shared.opts.interrogate_deepbooru_score_threshold 52 | use_spaces = shared.opts.deepbooru_use_spaces 53 | use_escape = shared.opts.deepbooru_escape 54 | alpha_sort = shared.opts.deepbooru_sort_alpha 55 | include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks 56 | 57 | pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512) 58 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 59 | 60 | with torch.no_grad(), devices.autocast(): 61 | x = torch.from_numpy(a).to(devices.device) 62 | y = self.model(x)[0].detach().cpu().numpy() 63 | 64 | probability_dict = {} 65 | 66 | for tag, probability in zip(self.model.tags, y): 67 | if probability < threshold: 68 | continue 69 | 70 | if tag.startswith("rating:"): 71 | continue 72 | 73 | probability_dict[tag] = probability 74 | 75 | if alpha_sort: 76 | tags = sorted(probability_dict) 77 | else: 78 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] 79 | 80 | res = [] 81 | 82 | filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")]) 83 | 84 | for tag in [x for x in tags if x not in filtertags]: 85 | probability = probability_dict[tag] 86 | tag_outformat = tag 87 | if use_spaces: 88 | tag_outformat = tag_outformat.replace('_', ' ') 89 | if use_escape: 90 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) 91 | if include_ranks: 92 | tag_outformat = f"({tag_outformat}:{probability:.3f})" 93 | 94 | res.append(tag_outformat) 95 | 96 | return ", ".join(res) 97 | 98 | 99 
| model = DeepDanbooru() 100 | -------------------------------------------------------------------------------- /modules/styles.py: -------------------------------------------------------------------------------- 1 | # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime 2 | from __future__ import annotations 3 | 4 | import csv 5 | import os 6 | import os.path 7 | import typing 8 | import collections.abc as abc 9 | import tempfile 10 | import shutil 11 | 12 | if typing.TYPE_CHECKING: 13 | # Only import this when code is being type-checked, it doesn't have any effect at runtime 14 | from .processing import StableDiffusionProcessing 15 | 16 | 17 | class PromptStyle(typing.NamedTuple): 18 | name: str 19 | prompt: str 20 | negative_prompt: str 21 | 22 | 23 | def merge_prompts(style_prompt: str, prompt: str) -> str: 24 | if "{prompt}" in style_prompt: 25 | res = style_prompt.replace("{prompt}", prompt) 26 | else: 27 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 28 | res = ", ".join(parts) 29 | 30 | return res 31 | 32 | 33 | def apply_styles_to_prompt(prompt, styles): 34 | for style in styles: 35 | prompt = merge_prompts(style, prompt) 36 | 37 | return prompt 38 | 39 | 40 | class StyleDatabase: 41 | def __init__(self, path: str): 42 | self.no_style = PromptStyle("None", "", "") 43 | self.styles = {} 44 | self.path = path 45 | 46 | self.reload() 47 | 48 | def reload(self): 49 | self.styles.clear() 50 | 51 | if not os.path.exists(self.path): 52 | return 53 | 54 | with open(self.path, "r", encoding="utf-8-sig", newline='') as file: 55 | reader = csv.DictReader(file) 56 | for row in reader: 57 | # Support loading old CSV format with "name, text"-columns 58 | prompt = row["prompt"] if "prompt" in row else row["text"] 59 | negative_prompt = row.get("negative_prompt", "") 60 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) 61 | 62 | def get_style_prompts(self, styles): 63 | return [self.styles.get(x, self.no_style).prompt for x in styles] 64 | 65 | def get_negative_style_prompts(self, styles): 66 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 67 | 68 | def apply_styles_to_prompt(self, prompt, styles): 69 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) 70 | 71 | def apply_negative_styles_to_prompt(self, prompt, styles): 72 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) 73 | 74 | def save_styles(self, path: str) -> None: 75 | # Write to temporary file first, so we don't nuke the file if something goes wrong 76 | fd, temp_path = tempfile.mkstemp(".csv") 77 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: 78 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, 79 | # and collections.NamedTuple has explicit documentation for accessing _fields. 
Same goes for _asdict() 80 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) 81 | writer.writeheader() 82 | writer.writerows(style._asdict() for k, style in self.styles.items()) 83 | 84 | # Always keep a backup file around 85 | if os.path.exists(path): 86 | shutil.move(path, path + ".bak") 87 | shutil.move(temp_path, path) 88 | -------------------------------------------------------------------------------- /javascript/dragdrop.js: -------------------------------------------------------------------------------- 1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard 2 | 3 | function isValidImageList( files ) { 4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); 5 | } 6 | 7 | function dropReplaceImage( imgWrap, files ) { 8 | if ( ! isValidImageList( files ) ) { 9 | return; 10 | } 11 | 12 | const tmpFile = files[0]; 13 | 14 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); 15 | const callback = () => { 16 | const fileInput = imgWrap.querySelector('input[type="file"]'); 17 | if ( fileInput ) { 18 | if ( files.length === 0 ) { 19 | files = new DataTransfer(); 20 | files.items.add(tmpFile); 21 | fileInput.files = files.files; 22 | } else { 23 | fileInput.files = files; 24 | } 25 | fileInput.dispatchEvent(new Event('change')); 26 | } 27 | }; 28 | 29 | if ( imgWrap.closest('#pnginfo_image') ) { 30 | // special treatment for PNG Info tab, wait for fetch request to finish 31 | const oldFetch = window.fetch; 32 | window.fetch = async (input, options) => { 33 | const response = await oldFetch(input, options); 34 | if ( 'api/predict/' === input ) { 35 | const content = await response.text(); 36 | window.fetch = oldFetch; 37 | window.requestAnimationFrame( () => callback() ); 38 | return new Response(content, { 39 | status: response.status, 40 | statusText: response.statusText, 41 | headers: response.headers 42 | }) 43 | } 44 | return response; 45 | }; 46 | } else { 47 | window.requestAnimationFrame( () => callback() ); 48 | } 49 | } 50 | 51 | window.document.addEventListener('dragover', e => { 52 | const target = e.composedPath()[0]; 53 | const imgWrap = target.closest('[data-testid="image"]'); 54 | if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { 55 | return; 56 | } 57 | e.stopPropagation(); 58 | e.preventDefault(); 59 | e.dataTransfer.dropEffect = 'copy'; 60 | }); 61 | 62 | window.document.addEventListener('drop', e => { 63 | const target = e.composedPath()[0]; 64 | if (target.placeholder.indexOf("Prompt") == -1) { 65 | return; 66 | } 67 | const imgWrap = target.closest('[data-testid="image"]'); 68 | if ( !imgWrap ) { 69 | return; 70 | } 71 | e.stopPropagation(); 72 | e.preventDefault(); 73 | const files = e.dataTransfer.files; 74 | dropReplaceImage( imgWrap, files ); 75 | }); 76 | 77 | window.addEventListener('paste', e => { 78 | const files = e.clipboardData.files; 79 | if ( ! isValidImageList( files ) ) { 80 | return; 81 | } 82 | 83 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] 84 | .filter(el => uiElementIsVisible(el)); 85 | if ( ! visibleImageFields.length ) { 86 | return; 87 | } 88 | 89 | const firstFreeImageField = visibleImageFields 90 | .filter(el => el.querySelector('input[type=file]'))?.[0]; 91 | 92 | dropReplaceImage( 93 | firstFreeImageField ? 
94 | firstFreeImageField : 95 | visibleImageFields[visibleImageFields.length - 1] 96 | , files ); 97 | }); 98 | -------------------------------------------------------------------------------- /extensions-builtin/ScuNET/scripts/scunet_model.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | import traceback 4 | 5 | import PIL.Image 6 | import numpy as np 7 | import torch 8 | from basicsr.utils.download_util import load_file_from_url 9 | 10 | import modules.upscaler 11 | from modules import devices, modelloader 12 | from scunet_model_arch import SCUNet as net 13 | 14 | 15 | class UpscalerScuNET(modules.upscaler.Upscaler): 16 | def __init__(self, dirname): 17 | self.name = "ScuNET" 18 | self.model_name = "ScuNET GAN" 19 | self.model_name2 = "ScuNET PSNR" 20 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" 21 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" 22 | self.user_path = dirname 23 | super().__init__() 24 | model_paths = self.find_models(ext_filter=[".pth"]) 25 | scalers = [] 26 | add_model2 = True 27 | for file in model_paths: 28 | if "http" in file: 29 | name = self.model_name 30 | else: 31 | name = modelloader.friendly_name(file) 32 | if name == self.model_name2 or file == self.model_url2: 33 | add_model2 = False 34 | try: 35 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) 36 | scalers.append(scaler_data) 37 | except Exception: 38 | print(f"Error loading ScuNET model: {file}", file=sys.stderr) 39 | print(traceback.format_exc(), file=sys.stderr) 40 | if add_model2: 41 | scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) 42 | scalers.append(scaler_data2) 43 | self.scalers = scalers 44 | 45 | def do_upscale(self, img: PIL.Image, selected_file): 46 | torch.cuda.empty_cache() 47 | 48 | model = self.load_model(selected_file) 49 | if model is None: 50 | return img 51 | 52 | device = devices.get_device_for('scunet') 53 | img = np.array(img) 54 | img = img[:, :, ::-1] 55 | img = np.moveaxis(img, 2, 0) / 255 56 | img = torch.from_numpy(img).float() 57 | img = img.unsqueeze(0).to(device) 58 | 59 | with torch.no_grad(): 60 | output = model(img) 61 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy() 62 | output = 255. 
* np.moveaxis(output, 0, 2) 63 | output = output.astype(np.uint8) 64 | output = output[:, :, ::-1] 65 | torch.cuda.empty_cache() 66 | return PIL.Image.fromarray(output, 'RGB') 67 | 68 | def load_model(self, path: str): 69 | device = devices.get_device_for('scunet') 70 | if "http" in path: 71 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, 72 | progress=True) 73 | else: 74 | filename = path 75 | if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: 76 | print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) 77 | return None 78 | 79 | model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) 80 | model.load_state_dict(torch.load(filename), strict=True) 81 | model.eval() 82 | for k, v in model.named_parameters(): 83 | v.requires_grad = False 84 | model = model.to(device) 85 | 86 | return model 87 | 88 | -------------------------------------------------------------------------------- /test/basic_features/txt2img_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import requests 3 | 4 | 5 | class TestTxt2ImgWorking(unittest.TestCase): 6 | def setUp(self): 7 | self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" 8 | self.simple_txt2img = { 9 | "enable_hr": False, 10 | "denoising_strength": 0, 11 | "firstphase_width": 0, 12 | "firstphase_height": 0, 13 | "prompt": "example prompt", 14 | "styles": [], 15 | "seed": -1, 16 | "subseed": -1, 17 | "subseed_strength": 0, 18 | "seed_resize_from_h": -1, 19 | "seed_resize_from_w": -1, 20 | "batch_size": 1, 21 | "n_iter": 1, 22 | "steps": 3, 23 | "cfg_scale": 7, 24 | "width": 64, 25 | "height": 64, 26 | "restore_faces": False, 27 | "tiling": False, 28 | "negative_prompt": "", 29 | "eta": 0, 30 | "s_churn": 0, 31 | "s_tmax": 0, 32 | "s_tmin": 0, 33 | "s_noise": 1, 34 | "sampler_index": "Euler a" 35 | } 36 | 37 | def test_txt2img_simple_performed(self): 38 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 39 | 40 | def test_txt2img_with_negative_prompt_performed(self): 41 | self.simple_txt2img["negative_prompt"] = "example negative prompt" 42 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 43 | 44 | def test_txt2img_with_complex_prompt_performed(self): 45 | self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" 46 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 47 | 48 | def test_txt2img_not_square_image_performed(self): 49 | self.simple_txt2img["height"] = 128 50 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 51 | 52 | def test_txt2img_with_hrfix_performed(self): 53 | self.simple_txt2img["enable_hr"] = True 54 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 55 | 56 | def test_txt2img_with_tiling_performed(self): 57 | self.simple_txt2img["tiling"] = True 58 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 59 | 60 | def test_txt2img_with_restore_faces_performed(self): 61 | self.simple_txt2img["restore_faces"] = True 62 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 63 | 64 | def test_txt2img_with_vanilla_sampler_performed(self): 65 | self.simple_txt2img["sampler_index"] = "PLMS" 66 | 
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 67 | self.simple_txt2img["sampler_index"] = "DDIM" 68 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 69 | 70 | def test_txt2img_multiple_batches_performed(self): 71 | self.simple_txt2img["n_iter"] = 2 72 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 73 | 74 | def test_txt2img_batch_performed(self): 75 | self.simple_txt2img["batch_size"] = 2 76 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /modules/mac_specific.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import paths 3 | from modules.sd_hijack_utils import CondFunc 4 | from packaging import version 5 | 6 | 7 | # has_mps is only available in nightly pytorch (for now) and macOS 12.3+. 8 | # check `getattr` and try it for compatibility 9 | def check_for_mps() -> bool: 10 | if not getattr(torch, 'has_mps', False): 11 | return False 12 | try: 13 | torch.zeros(1).to(torch.device("mps")) 14 | return True 15 | except Exception: 16 | return False 17 | has_mps = check_for_mps() 18 | 19 | 20 | # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 21 | def cumsum_fix(input, cumsum_func, *args, **kwargs): 22 | if input.device.type == 'mps': 23 | output_dtype = kwargs.get('dtype', input.dtype) 24 | if output_dtype == torch.int64: 25 | return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) 26 | elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): 27 | return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) 28 | return cumsum_func(input, *args, **kwargs) 29 | 30 | 31 | if has_mps: 32 | # MPS fix for randn in torchsde 33 | CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps') 34 | 35 | if version.parse(torch.__version__) < version.parse("1.13"): 36 | # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working 37 | 38 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 39 | CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), 40 | lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) 41 | # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 42 | CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), 43 | lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') 44 | # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 45 | CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, 
**kwargs: self.requires_grad) 46 | elif version.parse(torch.__version__) > version.parse("1.13.1"): 47 | cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) 48 | cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) 49 | cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) 50 | CondFunc('torch.cumsum', cumsum_fix_func, None) 51 | CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) 52 | CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) 53 | 54 | -------------------------------------------------------------------------------- /modules/masking.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageOps 2 | 3 | 4 | def get_crop_region(mask, pad=0): 5 | """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 6 | For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" 7 | 8 | h, w = mask.shape 9 | 10 | crop_left = 0 11 | for i in range(w): 12 | if not (mask[:, i] == 0).all(): 13 | break 14 | crop_left += 1 15 | 16 | crop_right = 0 17 | for i in reversed(range(w)): 18 | if not (mask[:, i] == 0).all(): 19 | break 20 | crop_right += 1 21 | 22 | crop_top = 0 23 | for i in range(h): 24 | if not (mask[i] == 0).all(): 25 | break 26 | crop_top += 1 27 | 28 | crop_bottom = 0 29 | for i in reversed(range(h)): 30 | if not (mask[i] == 0).all(): 31 | break 32 | crop_bottom += 1 33 | 34 | return ( 35 | int(max(crop_left-pad, 0)), 36 | int(max(crop_top-pad, 0)), 37 | int(min(w - crop_right + pad, w)), 38 | int(min(h - crop_bottom + pad, h)) 39 | ) 40 | 41 | 42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): 43 | """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region 44 | for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.""" 45 | 46 | x1, y1, x2, y2 = crop_region 47 | 48 | ratio_crop_region = (x2 - x1) / (y2 - y1) 49 | ratio_processing = processing_width / processing_height 50 | 51 | if ratio_crop_region > ratio_processing: 52 | desired_height = (x2 - x1) / ratio_processing 53 | desired_height_diff = int(desired_height - (y2-y1)) 54 | y1 -= desired_height_diff//2 55 | y2 += desired_height_diff - desired_height_diff//2 56 | if y2 >= image_height: 57 | diff = y2 - image_height 58 | y2 -= diff 59 | y1 -= diff 60 | if y1 < 0: 61 | y2 -= y1 62 | y1 -= y1 63 | if y2 >= image_height: 64 | y2 = image_height 65 | else: 66 | desired_width = (y2 - y1) * ratio_processing 67 | desired_width_diff = int(desired_width - (x2-x1)) 68 | x1 -= desired_width_diff//2 69 | x2 += desired_width_diff - desired_width_diff//2 70 | if x2 >= image_width: 71 | diff = x2 - image_width 72 | x2 -= diff 73 | x1 -= diff 74 | if x1 < 0: 75 | x2 -= x1 76 | x1 -= x1 77 | if x2 >= image_width: 78 | x2 = image_width 79 | 80 | return x1, y1, x2, y2 81 | 82 | 83 | def fill(image, mask): 84 | """fills masked regions with colors from image using blur. 
Not extremely effective.""" 85 | 86 | image_mod = Image.new('RGBA', (image.width, image.height)) 87 | 88 | image_masked = Image.new('RGBa', (image.width, image.height)) 89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L'))) 90 | 91 | image_masked = image_masked.convert('RGBa') 92 | 93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: 94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') 95 | for _ in range(repeats): 96 | image_mod.alpha_composite(blurred) 97 | 98 | return image_mod.convert("RGB") 99 | 100 | -------------------------------------------------------------------------------- /modules/sd_hijack_clip_old.py: -------------------------------------------------------------------------------- 1 | from modules import sd_hijack_clip 2 | from modules import shared 3 | 4 | 5 | def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): 6 | id_start = self.id_start 7 | id_end = self.id_end 8 | maxlen = self.wrapped.max_length # you get to stay at 77 9 | used_custom_terms = [] 10 | remade_batch_tokens = [] 11 | hijack_comments = [] 12 | hijack_fixes = [] 13 | token_count = 0 14 | 15 | cache = {} 16 | batch_tokens = self.tokenize(texts) 17 | batch_multipliers = [] 18 | for tokens in batch_tokens: 19 | tuple_tokens = tuple(tokens) 20 | 21 | if tuple_tokens in cache: 22 | remade_tokens, fixes, multipliers = cache[tuple_tokens] 23 | else: 24 | fixes = [] 25 | remade_tokens = [] 26 | multipliers = [] 27 | mult = 1.0 28 | 29 | i = 0 30 | while i < len(tokens): 31 | token = tokens[i] 32 | 33 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) 34 | 35 | mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None 36 | if mult_change is not None: 37 | mult *= mult_change 38 | i += 1 39 | elif embedding is None: 40 | remade_tokens.append(token) 41 | multipliers.append(mult) 42 | i += 1 43 | else: 44 | emb_len = int(embedding.vec.shape[0]) 45 | fixes.append((len(remade_tokens), embedding)) 46 | remade_tokens += [0] * emb_len 47 | multipliers += [mult] * emb_len 48 | used_custom_terms.append((embedding.name, embedding.checksum())) 49 | i += embedding_length_in_tokens 50 | 51 | if len(remade_tokens) > maxlen - 2: 52 | vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} 53 | ovf = remade_tokens[maxlen - 2:] 54 | overflowing_words = [vocab.get(int(x), "") for x in ovf] 55 | overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) 56 | hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") 57 | 58 | token_count = len(remade_tokens) 59 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) 60 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] 61 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers) 62 | 63 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) 64 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] 65 | 66 | remade_batch_tokens.append(remade_tokens) 67 | hijack_fixes.append(fixes) 68 | batch_multipliers.append(multipliers) 69 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count 70 | 71 | 72 | def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): 73 | batch_multipliers, remade_batch_tokens, 
used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) 74 | 75 | self.hijack.comments += hijack_comments 76 | 77 | if len(used_custom_terms) > 0: 78 | self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) 79 | 80 | self.hijack.fixes = hijack_fixes 81 | return self.process_tokens(remade_batch_tokens, batch_multipliers) 82 | -------------------------------------------------------------------------------- /javascript/edit-attention.js: -------------------------------------------------------------------------------- 1 | function keyupEditAttention(event){ 2 | let target = event.originalTarget || event.composedPath()[0]; 3 | if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; 4 | if (! (event.metaKey || event.ctrlKey)) return; 5 | 6 | let isPlus = event.key == "ArrowUp" 7 | let isMinus = event.key == "ArrowDown" 8 | if (!isPlus && !isMinus) return; 9 | 10 | let selectionStart = target.selectionStart; 11 | let selectionEnd = target.selectionEnd; 12 | let text = target.value; 13 | 14 | function selectCurrentParenthesisBlock(OPEN, CLOSE){ 15 | if (selectionStart !== selectionEnd) return false; 16 | 17 | // Find opening parenthesis around current cursor 18 | const before = text.substring(0, selectionStart); 19 | let beforeParen = before.lastIndexOf(OPEN); 20 | if (beforeParen == -1) return false; 21 | let beforeParenClose = before.lastIndexOf(CLOSE); 22 | while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { 23 | beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); 24 | beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); 25 | } 26 | 27 | // Find closing parenthesis around current cursor 28 | const after = text.substring(selectionStart); 29 | let afterParen = after.indexOf(CLOSE); 30 | if (afterParen == -1) return false; 31 | let afterParenOpen = after.indexOf(OPEN); 32 | while (afterParenOpen !== -1 && afterParen > afterParenOpen) { 33 | afterParen = after.indexOf(CLOSE, afterParen + 1); 34 | afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); 35 | } 36 | if (beforeParen === -1 || afterParen === -1) return false; 37 | 38 | // Set the selection to the text between the parenthesis 39 | const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); 40 | const lastColon = parenContent.lastIndexOf(":"); 41 | selectionStart = beforeParen + 1; 42 | selectionEnd = selectionStart + lastColon; 43 | target.setSelectionRange(selectionStart, selectionEnd); 44 | return true; 45 | } 46 | 47 | // If the user hasn't selected anything, let's select their current parenthesis block 48 | if(! 
selectCurrentParenthesisBlock('<', '>')){ 49 | selectCurrentParenthesisBlock('(', ')') 50 | } 51 | 52 | event.preventDefault(); 53 | 54 | closeCharacter = ')' 55 | delta = opts.keyedit_precision_attention 56 | 57 | if (selectionStart > 0 && text[selectionStart - 1] == '<'){ 58 | closeCharacter = '>' 59 | delta = opts.keyedit_precision_extra 60 | } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { 61 | 62 | // do not include spaces at the end 63 | while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ 64 | selectionEnd -= 1; 65 | } 66 | if(selectionStart == selectionEnd){ 67 | return 68 | } 69 | 70 | text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); 71 | 72 | selectionStart += 1; 73 | selectionEnd += 1; 74 | } 75 | 76 | end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; 77 | weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); 78 | if (isNaN(weight)) return; 79 | 80 | weight += isPlus ? delta : -delta; 81 | weight = parseFloat(weight.toPrecision(12)); 82 | if(String(weight).length == 1) weight += ".0" 83 | 84 | text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); 85 | 86 | target.focus(); 87 | target.value = text; 88 | target.selectionStart = selectionStart; 89 | target.selectionEnd = selectionEnd; 90 | 91 | updateInput(target) 92 | } 93 | 94 | addEventListener('keydown', (event) => { 95 | keyupEditAttention(event); 96 | }); -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: You think somethings is broken in the UI 3 | title: "[Bug]: " 4 | labels: ["bug-report"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** 18 | - type: textarea 19 | id: what-did 20 | attributes: 21 | label: What happened? 22 | description: Tell us what happened in a very clear and simple way 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: steps 27 | attributes: 28 | label: Steps to reproduce the problem 29 | description: Please provide us with precise step by step information on how to reproduce the bug 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: what-should 38 | attributes: 39 | label: What should have happened? 40 | description: Tell what you think the normal behavior should be 41 | validations: 42 | required: true 43 | - type: input 44 | id: commit 45 | attributes: 46 | label: Commit where the problem happens 47 | description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) 
48 | validations: 49 | required: true 50 | - type: dropdown 51 | id: platforms 52 | attributes: 53 | label: What platforms do you use to access the UI ? 54 | multiple: true 55 | options: 56 | - Windows 57 | - Linux 58 | - MacOS 59 | - iOS 60 | - Android 61 | - Other/Cloud 62 | - type: dropdown 63 | id: browsers 64 | attributes: 65 | label: What browsers do you use to access the UI ? 66 | multiple: true 67 | options: 68 | - Mozilla Firefox 69 | - Google Chrome 70 | - Brave 71 | - Apple Safari 72 | - Microsoft Edge 73 | - type: textarea 74 | id: cmdargs 75 | attributes: 76 | label: Command Line Arguments 77 | description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise. 78 | render: Shell 79 | validations: 80 | required: true 81 | - type: textarea 82 | id: extensions 83 | attributes: 84 | label: List of extensions 85 | description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. 86 | validations: 87 | required: true 88 | - type: textarea 89 | id: logs 90 | attributes: 91 | label: Console logs 92 | description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. 93 | render: Shell 94 | validations: 95 | required: true 96 | - type: textarea 97 | id: misc 98 | attributes: 99 | label: Additional information 100 | description: Please provide us with any relevant additional info or context. 101 | -------------------------------------------------------------------------------- /modules/extensions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import time 6 | import git 7 | 8 | from modules import paths, shared 9 | 10 | extensions = [] 11 | extensions_dir = os.path.join(paths.data_path, "extensions") 12 | extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") 13 | 14 | if not os.path.exists(extensions_dir): 15 | os.makedirs(extensions_dir) 16 | 17 | def active(): 18 | return [x for x in extensions if x.enabled] 19 | 20 | 21 | class Extension: 22 | def __init__(self, name, path, enabled=True, is_builtin=False): 23 | self.name = name 24 | self.path = path 25 | self.enabled = enabled 26 | self.status = '' 27 | self.can_update = False 28 | self.is_builtin = is_builtin 29 | self.version = '' 30 | 31 | repo = None 32 | try: 33 | if os.path.exists(os.path.join(path, ".git")): 34 | repo = git.Repo(path) 35 | except Exception: 36 | print(f"Error reading github repository info from {path}:", file=sys.stderr) 37 | print(traceback.format_exc(), file=sys.stderr) 38 | 39 | if repo is None or repo.bare: 40 | self.remote = None 41 | else: 42 | try: 43 | self.remote = next(repo.remote().urls, None) 44 | self.status = 'unknown' 45 | head = repo.head.commit 46 | ts = time.asctime(time.gmtime(repo.head.commit.committed_date)) 47 | self.version = f'{head.hexsha[:8]} ({ts})' 48 | 49 | except Exception: 50 | self.remote = None 51 | 52 | def list_files(self, subdir, extension): 53 | from modules import scripts 54 | 55 | dirpath = os.path.join(self.path, subdir) 56 | if not os.path.isdir(dirpath): 57 | return [] 58 | 59 | res = [] 60 | for filename in sorted(os.listdir(dirpath)): 61 | res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename))) 62 | 63 | res = [x for x in 
res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] 64 | 65 | return res 66 | 67 | def check_updates(self): 68 | repo = git.Repo(self.path) 69 | for fetch in repo.remote().fetch("--dry-run"): 70 | if fetch.flags != fetch.HEAD_UPTODATE: 71 | self.can_update = True 72 | self.status = "behind" 73 | return 74 | 75 | self.can_update = False 76 | self.status = "latest" 77 | 78 | def fetch_and_reset_hard(self): 79 | repo = git.Repo(self.path) 80 | # Fix: `error: Your local changes to the following files would be overwritten by merge`, 81 | # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. 82 | repo.git.fetch('--all') 83 | repo.git.reset('--hard', 'origin') 84 | 85 | 86 | def list_extensions(): 87 | extensions.clear() 88 | 89 | if not os.path.isdir(extensions_dir): 90 | return 91 | 92 | paths = [] 93 | for dirname in [extensions_dir, extensions_builtin_dir]: 94 | if not os.path.isdir(dirname): 95 | return 96 | 97 | for extension_dirname in sorted(os.listdir(dirname)): 98 | path = os.path.join(dirname, extension_dirname) 99 | if not os.path.isdir(path): 100 | continue 101 | 102 | paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) 103 | 104 | for dirname, path, is_builtin in paths: 105 | extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) 106 | extensions.append(extension) 107 | 108 | -------------------------------------------------------------------------------- /javascript/extraNetworks.js: -------------------------------------------------------------------------------- 1 | 2 | function setupExtraNetworksForTab(tabname){ 3 | gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') 4 | 5 | var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div') 6 | var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') 7 | var refresh = gradioApp().getElementById(tabname+'_extra_refresh') 8 | var close = gradioApp().getElementById(tabname+'_extra_close') 9 | 10 | search.classList.add('search') 11 | tabs.appendChild(search) 12 | tabs.appendChild(refresh) 13 | tabs.appendChild(close) 14 | 15 | search.addEventListener("input", function(evt){ 16 | searchTerm = search.value.toLowerCase() 17 | 18 | gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ 19 | text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() 20 | elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" 21 | }) 22 | }); 23 | } 24 | 25 | var activePromptTextarea = {}; 26 | 27 | function setupExtraNetworks(){ 28 | setupExtraNetworksForTab('txt2img') 29 | setupExtraNetworksForTab('img2img') 30 | 31 | function registerPrompt(tabname, id){ 32 | var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); 33 | 34 | if (! 
activePromptTextarea[tabname]){ 35 | activePromptTextarea[tabname] = textarea 36 | } 37 | 38 | textarea.addEventListener("focus", function(){ 39 | activePromptTextarea[tabname] = textarea; 40 | }); 41 | } 42 | 43 | registerPrompt('txt2img', 'txt2img_prompt') 44 | registerPrompt('txt2img', 'txt2img_neg_prompt') 45 | registerPrompt('img2img', 'img2img_prompt') 46 | registerPrompt('img2img', 'img2img_neg_prompt') 47 | } 48 | 49 | onUiLoaded(setupExtraNetworks) 50 | 51 | var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/; 52 | var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; 53 | 54 | function tryToRemoveExtraNetworkFromPrompt(textarea, text){ 55 | var m = text.match(re_extranet) 56 | if(! m) return false 57 | 58 | var partToSearch = m[1] 59 | var replaced = false 60 | var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){ 61 | m = found.match(re_extranet); 62 | if(m[1] == partToSearch){ 63 | replaced = true; 64 | return "" 65 | } 66 | return found; 67 | }) 68 | 69 | if(replaced){ 70 | textarea.value = newTextareaText 71 | return true; 72 | } 73 | 74 | return false 75 | } 76 | 77 | function cardClicked(tabname, textToAdd, allowNegativePrompt){ 78 | var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea") 79 | 80 | if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){ 81 | textarea.value = textarea.value + " " + textToAdd 82 | } 83 | 84 | updateInput(textarea) 85 | } 86 | 87 | function saveCardPreview(event, tabname, filename){ 88 | var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') 89 | var button = gradioApp().getElementById(tabname + '_save_preview') 90 | 91 | textarea.value = filename 92 | updateInput(textarea) 93 | 94 | button.click() 95 | 96 | event.stopPropagation() 97 | event.preventDefault() 98 | } 99 | 100 | function extraNetworksSearchButton(tabs_id, event){ 101 | searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') 102 | button = event.target 103 | text = button.classList.contains("search-all") ? 
"" : button.textContent.trim() 104 | 105 | searchTextarea.value = text 106 | updateInput(searchTextarea) 107 | } -------------------------------------------------------------------------------- /javascript/aspectRatioOverlay.js: -------------------------------------------------------------------------------- 1 | 2 | let currentWidth = null; 3 | let currentHeight = null; 4 | let arFrameTimeout = setTimeout(function(){},0); 5 | 6 | function dimensionChange(e, is_width, is_height){ 7 | 8 | if(is_width){ 9 | currentWidth = e.target.value*1.0 10 | } 11 | if(is_height){ 12 | currentHeight = e.target.value*1.0 13 | } 14 | 15 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 16 | 17 | if(!inImg2img){ 18 | return; 19 | } 20 | 21 | var targetElement = null; 22 | 23 | var tabIndex = get_tab_index('mode_img2img') 24 | if(tabIndex == 0){ // img2img 25 | targetElement = gradioApp().querySelector('div[data-testid=image] img'); 26 | } else if(tabIndex == 1){ //Sketch 27 | targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); 28 | } else if(tabIndex == 2){ // Inpaint 29 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); 30 | } else if(tabIndex == 3){ // Inpaint sketch 31 | targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); 32 | } 33 | 34 | 35 | if(targetElement){ 36 | 37 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 38 | if(!arPreviewRect){ 39 | arPreviewRect = document.createElement('div') 40 | arPreviewRect.id = "imageARPreview"; 41 | gradioApp().getRootNode().appendChild(arPreviewRect) 42 | } 43 | 44 | 45 | 46 | var viewportOffset = targetElement.getBoundingClientRect(); 47 | 48 | viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) 49 | 50 | scaledx = targetElement.naturalWidth*viewportscale 51 | scaledy = targetElement.naturalHeight*viewportscale 52 | 53 | cleintRectTop = (viewportOffset.top+window.scrollY) 54 | cleintRectLeft = (viewportOffset.left+window.scrollX) 55 | cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) 56 | cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) 57 | 58 | viewRectTop = cleintRectCentreY-(scaledy/2) 59 | viewRectLeft = cleintRectCentreX-(scaledx/2) 60 | arRectWidth = scaledx 61 | arRectHeight = scaledy 62 | 63 | arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) 64 | arscaledx = currentWidth*arscale 65 | arscaledy = currentHeight*arscale 66 | 67 | arRectTop = cleintRectCentreY-(arscaledy/2) 68 | arRectLeft = cleintRectCentreX-(arscaledx/2) 69 | arRectWidth = arscaledx 70 | arRectHeight = arscaledy 71 | 72 | arPreviewRect.style.top = arRectTop+'px'; 73 | arPreviewRect.style.left = arRectLeft+'px'; 74 | arPreviewRect.style.width = arRectWidth+'px'; 75 | arPreviewRect.style.height = arRectHeight+'px'; 76 | 77 | clearTimeout(arFrameTimeout); 78 | arFrameTimeout = setTimeout(function(){ 79 | arPreviewRect.style.display = 'none'; 80 | },2000); 81 | 82 | arPreviewRect.style.display = 'block'; 83 | 84 | } 85 | 86 | } 87 | 88 | 89 | onUiUpdate(function(){ 90 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 91 | if(arPreviewRect){ 92 | arPreviewRect.style.display = 'none'; 93 | } 94 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) 95 | if(inImg2img){ 96 | let inputs = gradioApp().querySelectorAll('input'); 97 | 
inputs.forEach(function(e){ 98 | var is_width = e.parentElement.id == "img2img_width" 99 | var is_height = e.parentElement.id == "img2img_height" 100 | 101 | if((is_width || is_height) && !e.classList.contains('scrollwatch')){ 102 | e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) 103 | e.classList.add('scrollwatch') 104 | } 105 | if(is_width){ 106 | currentWidth = e.value*1.0 107 | } 108 | if(is_height){ 109 | currentHeight = e.value*1.0 110 | } 111 | }) 112 | } 113 | }); 114 | -------------------------------------------------------------------------------- /modules/call_queue.py: -------------------------------------------------------------------------------- 1 | import html 2 | import sys 3 | import threading 4 | import traceback 5 | import time 6 | 7 | from modules import shared, progress 8 | 9 | queue_lock = threading.Lock() 10 | 11 | 12 | def wrap_queued_call(func): 13 | def f(*args, **kwargs): 14 | with queue_lock: 15 | res = func(*args, **kwargs) 16 | 17 | return res 18 | 19 | return f 20 | 21 | 22 | def wrap_gradio_gpu_call(func, extra_outputs=None): 23 | def f(*args, **kwargs): 24 | 25 | # if the first argument is a string that says "task(...)", it is treated as a job id 26 | if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": 27 | id_task = args[0] 28 | progress.add_task_to_queue(id_task) 29 | else: 30 | id_task = None 31 | 32 | with queue_lock: 33 | shared.state.begin() 34 | progress.start_task(id_task) 35 | 36 | try: 37 | res = func(*args, **kwargs) 38 | finally: 39 | progress.finish_task(id_task) 40 | 41 | shared.state.end() 42 | 43 | return res 44 | 45 | return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) 46 | 47 | 48 | def wrap_gradio_call(func, extra_outputs=None, add_stats=False): 49 | def f(*args, extra_outputs_array=extra_outputs, **kwargs): 50 | run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats 51 | if run_memmon: 52 | shared.mem_mon.monitor() 53 | t = time.perf_counter() 54 | 55 | try: 56 | res = list(func(*args, **kwargs)) 57 | except Exception as e: 58 | # When printing out our debug argument list, do not print out more than a MB of text 59 | max_debug_str_len = 131072 # (1024*1024)/8 60 | 61 | print("Error completing request", file=sys.stderr) 62 | argStr = f"Arguments: {str(args)} {str(kwargs)}" 63 | print(argStr[:max_debug_str_len], file=sys.stderr) 64 | if len(argStr) > max_debug_str_len: 65 | print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) 66 | 67 | print(traceback.format_exc(), file=sys.stderr) 68 | 69 | shared.state.job = "" 70 | shared.state.job_count = 0 71 | 72 | if extra_outputs_array is None: 73 | extra_outputs_array = [None, ''] 74 | 75 | res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"] 76 | 77 | shared.state.skipped = False 78 | shared.state.interrupted = False 79 | shared.state.job_count = 0 80 | 81 | if not add_stats: 82 | return tuple(res) 83 | 84 | elapsed = time.perf_counter() - t 85 | elapsed_m = int(elapsed // 60) 86 | elapsed_s = elapsed % 60 87 | elapsed_text = f"{elapsed_s:.2f}s" 88 | if elapsed_m > 0: 89 | elapsed_text = f"{elapsed_m}m "+elapsed_text 90 | 91 | if run_memmon: 92 | mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} 93 | active_peak = mem_stats['active_peak'] 94 | reserved_peak = mem_stats['reserved_peak'] 95 | sys_peak = mem_stats['system_peak'] 96 | sys_total = mem_stats['total'] 97 | sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) 98 | 99 | vram_html = f"
Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)
" 100 | else: 101 | vram_html = '' 102 | 103 | # last item is always HTML 104 | res[-1] += f"
Time taken: {elapsed_text}
{vram_html}
" 105 | 106 | return tuple(res) 107 | 108 | return f 109 | 110 | -------------------------------------------------------------------------------- /scripts/loopback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import trange 3 | 4 | import modules.scripts as scripts 5 | import gradio as gr 6 | 7 | from modules import processing, shared, sd_samplers, images 8 | from modules.processing import Processed 9 | from modules.sd_samplers import samplers 10 | from modules.shared import opts, cmd_opts, state 11 | from modules import deepbooru 12 | 13 | 14 | class Script(scripts.Script): 15 | def title(self): 16 | return "Loopback" 17 | 18 | def show(self, is_img2img): 19 | return is_img2img 20 | 21 | def ui(self, is_img2img): 22 | loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) 23 | denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor")) 24 | append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None") 25 | 26 | return [loops, denoising_strength_change_factor, append_interrogation] 27 | 28 | def run(self, p, loops, denoising_strength_change_factor, append_interrogation): 29 | processing.fix_seed(p) 30 | batch_count = p.n_iter 31 | p.extra_generation_params = { 32 | "Denoising strength change factor": denoising_strength_change_factor, 33 | } 34 | 35 | p.batch_size = 1 36 | p.n_iter = 1 37 | 38 | output_images, info = None, None 39 | initial_seed = None 40 | initial_info = None 41 | 42 | grids = [] 43 | all_images = [] 44 | original_init_image = p.init_images 45 | original_prompt = p.prompt 46 | state.job_count = loops * batch_count 47 | 48 | initial_color_corrections = [processing.setup_color_correction(p.init_images[0])] 49 | 50 | for n in range(batch_count): 51 | history = [] 52 | 53 | # Reset to original init image at the start of each batch 54 | p.init_images = original_init_image 55 | 56 | for i in range(loops): 57 | p.n_iter = 1 58 | p.batch_size = 1 59 | p.do_not_save_grid = True 60 | 61 | if opts.img2img_color_correction: 62 | p.color_corrections = initial_color_corrections 63 | 64 | if append_interrogation != "None": 65 | p.prompt = original_prompt + ", " if original_prompt != "" else "" 66 | if append_interrogation == "CLIP": 67 | p.prompt += shared.interrogator.interrogate(p.init_images[0]) 68 | elif append_interrogation == "DeepBooru": 69 | p.prompt += deepbooru.model.tag(p.init_images[0]) 70 | 71 | state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}" 72 | 73 | processed = processing.process_images(p) 74 | 75 | if initial_seed is None: 76 | initial_seed = processed.seed 77 | initial_info = processed.info 78 | 79 | init_img = processed.images[0] 80 | 81 | p.init_images = [init_img] 82 | p.seed = processed.seed + 1 83 | p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) 84 | history.append(processed.images[0]) 85 | 86 | grid = images.image_grid(history, rows=1) 87 | if opts.grid_save: 88 | images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p) 89 | 90 | grids.append(grid) 91 | all_images += history 92 | 93 | if opts.return_grid: 94 | all_images = grids + all_images 95 | 96 | 
processed = Processed(p, all_images, initial_seed, initial_info) 97 | 98 | return processed 99 | -------------------------------------------------------------------------------- /modules/progress.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import time 4 | 5 | import gradio as gr 6 | from pydantic import BaseModel, Field 7 | 8 | from modules.shared import opts 9 | 10 | import modules.shared as shared 11 | 12 | 13 | current_task = None 14 | pending_tasks = {} 15 | finished_tasks = [] 16 | 17 | 18 | def start_task(id_task): 19 | global current_task 20 | 21 | current_task = id_task 22 | pending_tasks.pop(id_task, None) 23 | 24 | 25 | def finish_task(id_task): 26 | global current_task 27 | 28 | if current_task == id_task: 29 | current_task = None 30 | 31 | finished_tasks.append(id_task) 32 | if len(finished_tasks) > 16: 33 | finished_tasks.pop(0) 34 | 35 | 36 | def add_task_to_queue(id_job): 37 | pending_tasks[id_job] = time.time() 38 | 39 | 40 | class ProgressRequest(BaseModel): 41 | id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") 42 | id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") 43 | 44 | 45 | class ProgressResponse(BaseModel): 46 | active: bool = Field(title="Whether the task is being worked on right now") 47 | queued: bool = Field(title="Whether the task is in queue") 48 | completed: bool = Field(title="Whether the task has already finished") 49 | progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") 50 | eta: float = Field(default=None, title="ETA in secs") 51 | live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") 52 | id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") 53 | textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") 54 | 55 | 56 | def setup_progress_api(app): 57 | return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) 58 | 59 | 60 | def progressapi(req: ProgressRequest): 61 | active = req.id_task == current_task 62 | queued = req.id_task in pending_tasks 63 | completed = req.id_task in finished_tasks 64 | 65 | if not active: 66 | return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." 
if queued else "Waiting...") 67 | 68 | progress = 0 69 | 70 | job_count, job_no = shared.state.job_count, shared.state.job_no 71 | sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step 72 | 73 | if job_count > 0: 74 | progress += job_no / job_count 75 | if sampling_steps > 0 and job_count > 0: 76 | progress += 1 / job_count * sampling_step / sampling_steps 77 | 78 | progress = min(progress, 1) 79 | 80 | elapsed_since_start = time.time() - shared.state.time_start 81 | predicted_duration = elapsed_since_start / progress if progress > 0 else None 82 | eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None 83 | 84 | id_live_preview = req.id_live_preview 85 | shared.state.set_current_image() 86 | if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview: 87 | image = shared.state.current_image 88 | if image is not None: 89 | buffered = io.BytesIO() 90 | image.save(buffered, format="png") 91 | live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") 92 | id_live_preview = shared.state.id_live_preview 93 | else: 94 | live_preview = None 95 | else: 96 | live_preview = None 97 | 98 | return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) 99 | 100 | -------------------------------------------------------------------------------- /modules/sd_hijack_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from packaging import version 3 | 4 | from modules import devices 5 | from modules.sd_hijack_utils import CondFunc 6 | 7 | 8 | class TorchHijackForUnet: 9 | """ 10 | This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; 11 | this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 12 | """ 13 | 14 | def __getattr__(self, item): 15 | if item == 'cat': 16 | return self.cat 17 | 18 | if hasattr(torch, item): 19 | return getattr(torch, item) 20 | 21 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) 22 | 23 | def cat(self, tensors, *args, **kwargs): 24 | if len(tensors) == 2: 25 | a, b = tensors 26 | if a.shape[-2:] != b.shape[-2:]: 27 | a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") 28 | 29 | tensors = (a, b) 30 | 31 | return torch.cat(tensors, *args, **kwargs) 32 | 33 | 34 | th = TorchHijackForUnet() 35 | 36 | 37 | # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling 38 | def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): 39 | 40 | if isinstance(cond, dict): 41 | for y in cond.keys(): 42 | cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] 43 | 44 | with devices.autocast(): 45 | return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() 46 | 47 | 48 | class GELUHijack(torch.nn.GELU, torch.nn.Module): 49 | def __init__(self, *args, **kwargs): 50 | torch.nn.GELU.__init__(self, *args, **kwargs) 51 | def forward(self, x): 52 | if devices.unet_needs_upcast: 53 | return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) 54 | else: 55 | return torch.nn.GELU.forward(self, x) 56 | 57 | 58 | ddpm_edit_hijack = None 59 | def hijack_ddpm_edit(): 60 | global ddpm_edit_hijack 61 | if not 
ddpm_edit_hijack: 62 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) 63 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) 64 | ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) 65 | 66 | 67 | unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast 68 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) 69 | CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) 70 | if version.parse(torch.__version__) <= version.parse("1.13.1"): 71 | CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) 72 | CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) 73 | CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) 74 | 75 | first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 76 | first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) 77 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) 78 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) 79 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) 80 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import facexlib 6 | import gfpgan 7 | 8 | import modules.face_restoration 9 | from modules import paths, shared, devices, modelloader 10 | 11 | model_dir = "GFPGAN" 12 | user_path = None 13 | model_path = os.path.join(paths.models_path, model_dir) 14 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 15 | have_gfpgan = False 16 | loaded_gfpgan_model = None 17 | 18 | 19 | def gfpgann(): 20 | global loaded_gfpgan_model 21 | global model_path 22 | if loaded_gfpgan_model is not None: 23 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) 24 | return loaded_gfpgan_model 25 | 26 | if gfpgan_constructor is None: 27 | return None 28 | 29 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") 30 | if len(models) == 1 and "http" in models[0]: 31 | model_file = models[0] 32 | elif len(models) != 0: 33 | latest_file = max(models, key=os.path.getctime) 34 | model_file = latest_file 35 | else: 36 | print("Unable to load gfpgan model!") 37 | return None 38 | if hasattr(facexlib.detection.retinaface, 'device'): 39 | 
facexlib.detection.retinaface.device = devices.device_gfpgan 40 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) 41 | loaded_gfpgan_model = model 42 | 43 | return model 44 | 45 | 46 | def send_model_to(model, device): 47 | model.gfpgan.to(device) 48 | model.face_helper.face_det.to(device) 49 | model.face_helper.face_parse.to(device) 50 | 51 | 52 | def gfpgan_fix_faces(np_image): 53 | model = gfpgann() 54 | if model is None: 55 | return np_image 56 | 57 | send_model_to(model, devices.device_gfpgan) 58 | 59 | np_image_bgr = np_image[:, :, ::-1] 60 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) 61 | np_image = gfpgan_output_bgr[:, :, ::-1] 62 | 63 | model.face_helper.clean_all() 64 | 65 | if shared.opts.face_restoration_unload: 66 | send_model_to(model, devices.cpu) 67 | 68 | return np_image 69 | 70 | 71 | gfpgan_constructor = None 72 | 73 | 74 | def setup_model(dirname): 75 | global model_path 76 | if not os.path.exists(model_path): 77 | os.makedirs(model_path) 78 | 79 | try: 80 | from gfpgan import GFPGANer 81 | from facexlib import detection, parsing 82 | global user_path 83 | global have_gfpgan 84 | global gfpgan_constructor 85 | 86 | load_file_from_url_orig = gfpgan.utils.load_file_from_url 87 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url 88 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url 89 | 90 | def my_load_file_from_url(**kwargs): 91 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) 92 | 93 | def facex_load_file_from_url(**kwargs): 94 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) 95 | 96 | def facex_load_file_from_url2(**kwargs): 97 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) 98 | 99 | gfpgan.utils.load_file_from_url = my_load_file_from_url 100 | facexlib.detection.load_file_from_url = facex_load_file_from_url 101 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 102 | user_path = dirname 103 | have_gfpgan = True 104 | gfpgan_constructor = GFPGANer 105 | 106 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): 107 | def name(self): 108 | return "GFPGAN" 109 | 110 | def restore(self, np_image): 111 | return gfpgan_fix_faces(np_image) 112 | 113 | shared.face_restorers.append(FaceRestorerGFPGAN()) 114 | except Exception: 115 | print("Error setting up GFPGAN:", file=sys.stderr) 116 | print(traceback.format_exc(), file=sys.stderr) 117 | -------------------------------------------------------------------------------- /scripts/sd_upscale.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image 6 | 7 | from modules import processing, shared, sd_samplers, images, devices 8 | from modules.processing import Processed 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "SD upscale" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | info = gr.HTML("
Will upscale the image by the selected scale factor; use width and height sliders to set tile size
") 21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) 22 | scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) 23 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) 24 | 25 | return [info, overlap, upscaler_index, scale_factor] 26 | 27 | def run(self, p, _, overlap, upscaler_index, scale_factor): 28 | if isinstance(upscaler_index, str): 29 | upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) 30 | processing.fix_seed(p) 31 | upscaler = shared.sd_upscalers[upscaler_index] 32 | 33 | p.extra_generation_params["SD upscale overlap"] = overlap 34 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name 35 | 36 | initial_info = None 37 | seed = p.seed 38 | 39 | init_img = p.init_images[0] 40 | init_img = images.flatten(init_img, opts.img2img_background_color) 41 | 42 | if upscaler.name != "None": 43 | img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) 44 | else: 45 | img = init_img 46 | 47 | devices.torch_gc() 48 | 49 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) 50 | 51 | batch_size = p.batch_size 52 | upscale_count = p.n_iter 53 | p.n_iter = 1 54 | p.do_not_save_grid = True 55 | p.do_not_save_samples = True 56 | 57 | work = [] 58 | 59 | for y, h, row in grid.tiles: 60 | for tiledata in row: 61 | work.append(tiledata[2]) 62 | 63 | batch_count = math.ceil(len(work) / batch_size) 64 | state.job_count = batch_count * upscale_count 65 | 66 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") 67 | 68 | result_images = [] 69 | for n in range(upscale_count): 70 | start_seed = seed + n 71 | p.seed = start_seed 72 | 73 | work_results = [] 74 | for i in range(batch_count): 75 | p.batch_size = batch_size 76 | p.init_images = work[i * batch_size:(i + 1) * batch_size] 77 | 78 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" 79 | processed = processing.process_images(p) 80 | 81 | if initial_info is None: 82 | initial_info = processed.info 83 | 84 | p.seed = processed.seed + 1 85 | work_results += processed.images 86 | 87 | image_index = 0 88 | for y, h, row in grid.tiles: 89 | for tiledata in row: 90 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 91 | image_index += 1 92 | 93 | combined_image = images.combine_grid(grid) 94 | result_images.append(combined_image) 95 | 96 | if opts.samples_save: 97 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) 98 | 99 | processed = Processed(p, result_images, seed, initial_info) 100 | 101 | return processed 102 | -------------------------------------------------------------------------------- /modules/postprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste 6 | from modules.shared import opts 7 | 8 | 9 | def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, 
show_extras_results, *args, save_output: bool = True): 10 | devices.torch_gc() 11 | 12 | shared.state.begin() 13 | shared.state.job = 'extras' 14 | 15 | image_data = [] 16 | image_names = [] 17 | outputs = [] 18 | 19 | if extras_mode == 1: 20 | for img in image_folder: 21 | image = Image.open(img) 22 | image_data.append(image) 23 | image_names.append(os.path.splitext(img.orig_name)[0]) 24 | elif extras_mode == 2: 25 | assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' 26 | assert input_dir, 'input directory not selected' 27 | 28 | image_list = shared.listfiles(input_dir) 29 | for filename in image_list: 30 | try: 31 | image = Image.open(filename) 32 | except Exception: 33 | continue 34 | image_data.append(image) 35 | image_names.append(filename) 36 | else: 37 | assert image, 'image not selected' 38 | 39 | image_data.append(image) 40 | image_names.append(None) 41 | 42 | if extras_mode == 2 and output_dir != '': 43 | outpath = output_dir 44 | else: 45 | outpath = opts.outdir_samples or opts.outdir_extras_samples 46 | 47 | infotext = '' 48 | 49 | for image, name in zip(image_data, image_names): 50 | shared.state.textinfo = name 51 | 52 | existing_pnginfo = image.info or {} 53 | 54 | pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) 55 | 56 | scripts.scripts_postproc.run(pp, args) 57 | 58 | if opts.use_original_name_batch and name is not None: 59 | basename = os.path.splitext(os.path.basename(name))[0] 60 | else: 61 | basename = '' 62 | 63 | infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) 64 | 65 | if opts.enable_pnginfo: 66 | pp.image.info = existing_pnginfo 67 | pp.image.info["postprocessing"] = infotext 68 | 69 | if save_output: 70 | images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) 71 | 72 | if extras_mode != 2 or show_extras_results: 73 | outputs.append(pp.image) 74 | 75 | devices.torch_gc() 76 | 77 | return outputs, ui_common.plaintext_to_html(infotext), '' 78 | 79 | 80 | def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): 81 | """old handler for API""" 82 | 83 | args = scripts.scripts_postproc.create_args_for_run({ 84 | "Upscale": { 85 | "upscale_mode": resize_mode, 86 | "upscale_by": upscaling_resize, 87 | "upscale_to_width": upscaling_resize_w, 88 | "upscale_to_height": upscaling_resize_h, 89 | "upscale_crop": upscaling_crop, 90 | "upscaler_1_name": extras_upscaler_1, 91 | "upscaler_2_name": extras_upscaler_2, 92 | "upscaler_2_visibility": extras_upscaler_2_visibility, 93 | }, 94 | "GFPGAN": { 95 | "gfpgan_visibility": gfpgan_visibility, 96 | }, 97 | "CodeFormer": { 98 | "codeformer_visibility": codeformer_visibility, 99 | "codeformer_weight": codeformer_weight, 100 | }, 101 | }) 102 | 103 | return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) 104 | -------------------------------------------------------------------------------- 
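A minimal usage sketch for the legacy run_extras() wrapper above, assuming the WebUI has been initialized so that shared options and the postprocessing scripts are loaded; the input filename and all option values are illustrative only, not taken from the module:

from PIL import Image

from modules import postprocessing

img = Image.open("input.png")  # hypothetical source image

# extras_mode=0 processes the single image passed in; modes 1 and 2 instead take an
# uploaded batch or an input directory, as handled by run_postprocessing above.
images_out, info_html, _ = postprocessing.run_extras(
    extras_mode=0, resize_mode=0, image=img, image_folder=None,
    input_dir="", output_dir="", show_extras_results=True,
    gfpgan_visibility=0.0, codeformer_visibility=0.0, codeformer_weight=0.5,
    upscaling_resize=2.0, upscaling_resize_w=1024, upscaling_resize_h=1024,
    upscaling_crop=True, extras_upscaler_1="Lanczos", extras_upscaler_2="None",
    extras_upscaler_2_visibility=0.0, upscale_first=False, save_output=False,
)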
/modules/sd_models_config.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | 4 | import torch 5 | 6 | from modules import shared, paths, sd_disable_initialization 7 | 8 | sd_configs_path = shared.sd_configs_path 9 | sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") 10 | 11 | 12 | config_default = shared.sd_default_config 13 | config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") 14 | config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") 15 | config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") 16 | config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") 17 | config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") 18 | config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") 19 | config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") 20 | 21 | 22 | def is_using_v_parameterization_for_sd2(state_dict): 23 | """ 24 | Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. 25 | """ 26 | 27 | import ldm.modules.diffusionmodules.openaimodel 28 | from modules import devices 29 | 30 | device = devices.cpu 31 | 32 | with sd_disable_initialization.DisableInitialization(): 33 | unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( 34 | use_checkpoint=True, 35 | use_fp16=False, 36 | image_size=32, 37 | in_channels=4, 38 | out_channels=4, 39 | model_channels=320, 40 | attention_resolutions=[4, 2, 1], 41 | num_res_blocks=2, 42 | channel_mult=[1, 2, 4, 4], 43 | num_head_channels=64, 44 | use_spatial_transformer=True, 45 | use_linear_in_transformer=True, 46 | transformer_depth=1, 47 | context_dim=1024, 48 | legacy=False 49 | ) 50 | unet.eval() 51 | 52 | with torch.no_grad(): 53 | unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} 54 | unet.load_state_dict(unet_sd, strict=True) 55 | unet.to(device=device, dtype=torch.float) 56 | 57 | test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 58 | x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 59 | 60 | out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() 61 | 62 | return out < -1 63 | 64 | 65 | def guess_model_config_from_state_dict(sd, filename): 66 | sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) 67 | diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) 68 | 69 | if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: 70 | return config_depth_model 71 | 72 | if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: 73 | if diffusion_model_input.shape[1] == 9: 74 | return config_sd2_inpainting 75 | elif is_using_v_parameterization_for_sd2(sd): 76 | return config_sd2v 77 | else: 78 | return config_sd2 79 | 80 | if diffusion_model_input is not None: 81 | if diffusion_model_input.shape[1] == 9: 82 | return config_inpainting 83 | if diffusion_model_input.shape[1] == 8: 84 | return config_instruct_pix2pix 85 | 86 | if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: 87 | return config_alt_diffusion 88 | 89 | return config_default 90 | 91 | 92 | def find_checkpoint_config(state_dict, info): 93 | if info is None: 94 | return guess_model_config_from_state_dict(state_dict, "") 95 | 96 | config = find_checkpoint_config_near_filename(info) 97 | if config is not None: 98 | return config 99 | 100 | return guess_model_config_from_state_dict(state_dict, info.filename) 101 | 102 | 103 | def find_checkpoint_config_near_filename(info): 104 | if info is None: 105 | return None 106 | 107 | config = os.path.splitext(info.filename)[0] + ".yaml" 108 | if os.path.exists(config): 109 | return config 110 | 111 | return None 112 | 113 | -------------------------------------------------------------------------------- /extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js: -------------------------------------------------------------------------------- 1 | // Stable Diffusion WebUI - Bracket checker 2 | // Version 1.0 3 | // By Hingashi no Florin/Bwin4L 4 | // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. 5 | // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. 6 | 7 | function checkBrackets(evt, textArea, counterElt) { 8 | errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n'; 9 | errorStringSquare = '[...] 
- Different number of opening and closing square brackets detected.\n'; 10 | errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n'; 11 | 12 | openBracketRegExp = /\(/g; 13 | closeBracketRegExp = /\)/g; 14 | 15 | openSquareBracketRegExp = /\[/g; 16 | closeSquareBracketRegExp = /\]/g; 17 | 18 | openCurlyBracketRegExp = /\{/g; 19 | closeCurlyBracketRegExp = /\}/g; 20 | 21 | totalOpenBracketMatches = 0; 22 | totalCloseBracketMatches = 0; 23 | totalOpenSquareBracketMatches = 0; 24 | totalCloseSquareBracketMatches = 0; 25 | totalOpenCurlyBracketMatches = 0; 26 | totalCloseCurlyBracketMatches = 0; 27 | 28 | openBracketMatches = textArea.value.match(openBracketRegExp); 29 | if(openBracketMatches) { 30 | totalOpenBracketMatches = openBracketMatches.length; 31 | } 32 | 33 | closeBracketMatches = textArea.value.match(closeBracketRegExp); 34 | if(closeBracketMatches) { 35 | totalCloseBracketMatches = closeBracketMatches.length; 36 | } 37 | 38 | openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp); 39 | if(openSquareBracketMatches) { 40 | totalOpenSquareBracketMatches = openSquareBracketMatches.length; 41 | } 42 | 43 | closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp); 44 | if(closeSquareBracketMatches) { 45 | totalCloseSquareBracketMatches = closeSquareBracketMatches.length; 46 | } 47 | 48 | openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp); 49 | if(openCurlyBracketMatches) { 50 | totalOpenCurlyBracketMatches = openCurlyBracketMatches.length; 51 | } 52 | 53 | closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp); 54 | if(closeCurlyBracketMatches) { 55 | totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length; 56 | } 57 | 58 | if(totalOpenBracketMatches != totalCloseBracketMatches) { 59 | if(!counterElt.title.includes(errorStringParen)) { 60 | counterElt.title += errorStringParen; 61 | } 62 | } else { 63 | counterElt.title = counterElt.title.replace(errorStringParen, ''); 64 | } 65 | 66 | if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) { 67 | if(!counterElt.title.includes(errorStringSquare)) { 68 | counterElt.title += errorStringSquare; 69 | } 70 | } else { 71 | counterElt.title = counterElt.title.replace(errorStringSquare, ''); 72 | } 73 | 74 | if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) { 75 | if(!counterElt.title.includes(errorStringCurly)) { 76 | counterElt.title += errorStringCurly; 77 | } 78 | } else { 79 | counterElt.title = counterElt.title.replace(errorStringCurly, ''); 80 | } 81 | 82 | if(counterElt.title != '') { 83 | counterElt.classList.add('error'); 84 | } else { 85 | counterElt.classList.remove('error'); 86 | } 87 | } 88 | 89 | function setupBracketChecking(id_prompt, id_counter){ 90 | var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); 91 | var counter = gradioApp().getElementById(id_counter) 92 | textarea.addEventListener("input", function(evt){ 93 | checkBrackets(evt, textarea, counter) 94 | }); 95 | } 96 | 97 | var shadowRootLoaded = setInterval(function() { 98 | var shadowRoot = document.querySelector('gradio-app').shadowRoot; 99 | if(! 
shadowRoot) return false; 100 | 101 | var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); 102 | if(shadowTextArea.length < 1) return false; 103 | 104 | clearInterval(shadowRootLoaded); 105 | 106 | setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') 107 | setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') 108 | setupBracketChecking('img2img_prompt', 'imgimg_token_counter') 109 | setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') 110 | }, 1000); 111 | -------------------------------------------------------------------------------- /modules/lowvram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import devices 3 | 4 | module_in_gpu = None 5 | cpu = torch.device("cpu") 6 | 7 | 8 | def send_everything_to_cpu(): 9 | global module_in_gpu 10 | 11 | if module_in_gpu is not None: 12 | module_in_gpu.to(cpu) 13 | 14 | module_in_gpu = None 15 | 16 | 17 | def setup_for_low_vram(sd_model, use_medvram): 18 | parents = {} 19 | 20 | def send_me_to_gpu(module, _): 21 | """send this module to GPU; send whatever tracked module was previous in GPU to CPU; 22 | we add this as forward_pre_hook to a lot of modules and this way all but one of them will 23 | be in CPU 24 | """ 25 | global module_in_gpu 26 | 27 | module = parents.get(module, module) 28 | 29 | if module_in_gpu == module: 30 | return 31 | 32 | if module_in_gpu is not None: 33 | module_in_gpu.to(cpu) 34 | 35 | module.to(devices.device) 36 | module_in_gpu = module 37 | 38 | # see below for register_forward_pre_hook; 39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is 40 | # useless here, and we just replace those methods 41 | 42 | first_stage_model = sd_model.first_stage_model 43 | first_stage_model_encode = sd_model.first_stage_model.encode 44 | first_stage_model_decode = sd_model.first_stage_model.decode 45 | 46 | def first_stage_model_encode_wrap(x): 47 | send_me_to_gpu(first_stage_model, None) 48 | return first_stage_model_encode(x) 49 | 50 | def first_stage_model_decode_wrap(z): 51 | send_me_to_gpu(first_stage_model, None) 52 | return first_stage_model_decode(z) 53 | 54 | # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field 55 | if hasattr(sd_model.cond_stage_model, 'model'): 56 | sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model 57 | 58 | # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then 59 | # send the model to GPU. Then put modules back. the modules will be in CPU. 
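# detaching these four references first means the sd_model.to(devices.device) call below
# only moves the small remainder of the model to the GPU; once the references are restored,
# the big submodules stay on the CPU, and the forward_pre_hook defined above moves each one
# to the GPU right before it is used while sending the previously loaded module back to CPU.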
60 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model 61 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None 62 | sd_model.to(devices.device) 63 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored 64 | 65 | # register hooks for those the first three models 66 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) 67 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) 68 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap 69 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap 70 | if sd_model.depth_model: 71 | sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) 72 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model 73 | 74 | if hasattr(sd_model.cond_stage_model, 'model'): 75 | sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer 76 | del sd_model.cond_stage_model.transformer 77 | 78 | if use_medvram: 79 | sd_model.model.register_forward_pre_hook(send_me_to_gpu) 80 | else: 81 | diff_model = sd_model.model.diffusion_model 82 | 83 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules 84 | # so that only one of them is in GPU at a time 85 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed 86 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None 87 | sd_model.model.to(devices.device) 88 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored 89 | 90 | # install hooks for bits of third model 91 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) 92 | for block in diff_model.input_blocks: 93 | block.register_forward_pre_hook(send_me_to_gpu) 94 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) 95 | for block in diff_model.output_blocks: 96 | block.register_forward_pre_hook(send_me_to_gpu) 97 | -------------------------------------------------------------------------------- /modules/upscaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import PIL 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | import modules.shared 10 | from modules import modelloader, shared 11 | 12 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 13 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) 14 | 15 | 16 | class Upscaler: 17 | name = None 18 | model_path = None 19 | model_name = None 20 | model_url = None 21 | enable = True 22 | filter = None 23 | model = None 24 | user_path = None 25 | scalers: [] 26 | tile = True 27 | 28 | def __init__(self, create_dirs=False): 29 | self.mod_pad_h = None 30 | self.tile_size = modules.shared.opts.ESRGAN_tile 31 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap 32 | self.device = modules.shared.device 33 | self.img = None 34 | self.output = None 35 | self.scale = 1 36 | self.half = not modules.shared.cmd_opts.no_half 37 | self.pre_pad = 0 38 | self.mod_scale = None 39 | 40 | if self.model_path is None and self.name: 41 | self.model_path = os.path.join(shared.models_path, 
self.name) 42 | if self.model_path and create_dirs: 43 | os.makedirs(self.model_path, exist_ok=True) 44 | 45 | try: 46 | import cv2 47 | self.can_tile = True 48 | except: 49 | pass 50 | 51 | @abstractmethod 52 | def do_upscale(self, img: PIL.Image, selected_model: str): 53 | return img 54 | 55 | def upscale(self, img: PIL.Image, scale, selected_model: str = None): 56 | self.scale = scale 57 | dest_w = int(img.width * scale) 58 | dest_h = int(img.height * scale) 59 | 60 | for i in range(3): 61 | shape = (img.width, img.height) 62 | 63 | img = self.do_upscale(img, selected_model) 64 | 65 | if shape == (img.width, img.height): 66 | break 67 | 68 | if img.width >= dest_w and img.height >= dest_h: 69 | break 70 | 71 | if img.width != dest_w or img.height != dest_h: 72 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 73 | 74 | return img 75 | 76 | @abstractmethod 77 | def load_model(self, path: str): 78 | pass 79 | 80 | def find_models(self, ext_filter=None) -> list: 81 | return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) 82 | 83 | def update_status(self, prompt): 84 | print(f"\nextras: {prompt}", file=shared.progress_print_out) 85 | 86 | 87 | class UpscalerData: 88 | name = None 89 | data_path = None 90 | scale: int = 4 91 | scaler: Upscaler = None 92 | model: None 93 | 94 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): 95 | self.name = name 96 | self.data_path = path 97 | self.local_data_path = path 98 | self.scaler = upscaler 99 | self.scale = scale 100 | self.model = model 101 | 102 | 103 | class UpscalerNone(Upscaler): 104 | name = "None" 105 | scalers = [] 106 | 107 | def load_model(self, path): 108 | pass 109 | 110 | def do_upscale(self, img, selected_model=None): 111 | return img 112 | 113 | def __init__(self, dirname=None): 114 | super().__init__(False) 115 | self.scalers = [UpscalerData("None", None, self)] 116 | 117 | 118 | class UpscalerLanczos(Upscaler): 119 | scalers = [] 120 | 121 | def do_upscale(self, img, selected_model=None): 122 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 123 | 124 | def load_model(self, _): 125 | pass 126 | 127 | def __init__(self, dirname=None): 128 | super().__init__(False) 129 | self.name = "Lanczos" 130 | self.scalers = [UpscalerData("Lanczos", None, self)] 131 | 132 | 133 | class UpscalerNearest(Upscaler): 134 | scalers = [] 135 | 136 | def do_upscale(self, img, selected_model=None): 137 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) 138 | 139 | def load_model(self, _): 140 | pass 141 | 142 | def __init__(self, dirname=None): 143 | super().__init__(False) 144 | self.name = "Nearest" 145 | self.scalers = [UpscalerData("Nearest", None, self)] 146 | --------------------------------------------------------------------------------
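For reference, a minimal sketch of how the Upscaler/UpscalerData classes above are consumed, mirroring the call made in scripts/sd_upscale.py; it assumes the WebUI has been initialized so that shared.sd_upscalers is populated, and the file names are illustrative:

from PIL import Image

from modules import shared

img = Image.open("input.png")  # hypothetical input image

# pick a registered upscaler by its UpscalerData.name
upscaler = next(u for u in shared.sd_upscalers if u.name == "Lanczos")

# Upscaler.upscale() calls do_upscale() up to three times until the requested size is
# reached, then falls back to a LANCZOS resize for any remaining difference.
result = upscaler.scaler.upscale(img, 2, upscaler.data_path)
result.save("output_2x.png")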