├── test
├── __init__.py
├── test_files
│   ├── empty.pt
│   ├── mask_basic.png
│   └── img2img_basic.png
├── conftest.py
├── test_utils.py
├── test_extras.py
├── test_img2img.py
└── test_txt2img.py
├── models
├── VAE
│   └── Put VAE here.txt
├── Stable-diffusion
│   └── Put Stable Diffusion checkpoints here.txt
├── deepbooru
│   └── Put your deepbooru release project folder here.txt
├── VAE-approx
│   └── model.pt
└── karlo
│   └── ViT-L-14_stats.th
├── extensions
└── put extensions here.txt
├── localizations
└── Put localization files here.txt
├── textual_inversion_templates
├── none.txt
├── style.txt
├── subject.txt
├── hypernetwork.txt
├── style_filewords.txt
└── subject_filewords.txt
├── embeddings
└── Place Textual Inversion embeddings here.txt
├── .eslintignore
├── .git-blame-ignore-revs
├── requirements-test.txt
├── modules
├── models
│   └── diffusion
│   │   └── uni_pc
│   │   ├── __init__.py
│   │   └── sampler.py
├── Roboto-Regular.ttf
├── textual_inversion
│   ├── test_embedding.png
│   ├── logging.py
│   ├── ui.py
│   └── learn_schedule.py
├── import_hook.py
├── sd_hijack_ip2p.py
├── face_restoration.py
├── shared_items.py
├── script_loading.py
├── timer.py
├── ui_extra_networks_hypernets.py
├── localization.py
├── ngrok.py
├── extra_networks_hypernet.py
├── ui_extra_networks_textual_inversion.py
├── ui_extra_networks_checkpoints.py
├── errors.py
├── sd_hijack_utils.py
├── paths_internal.py
├── sd_hijack_checkpoint.py
├── sd_hijack_open_clip.py
├── sd_hijack_xlmr.py
├── hypernetworks
│   └── ui.py
├── scripts_auto_postprocessing.py
├── sd_samplers.py
├── paths.py
├── ui_components.py
├── sd_vae_approx.py
├── ui_postprocessing.py
├── ui_tempdir.py
├── styles.py
├── sd_vae_taesd.py
├── memmon.py
├── txt2img.py
├── hashes.py
├── deepbooru.py
├── sd_samplers_common.py
├── masking.py
├── sd_hijack_clip_old.py
├── call_queue.py
├── sd_hijack_unet.py
├── gfpgan_model.py
├── postprocessing.py
├── upscaler.py
├── sd_hijack_inpainting.py
├── mac_specific.py
└── lowvram.py
├── screenshot.png
├── html
├── card-no-preview.png
├── extra-networks-no-cards.html
├── footer.html
├── extra-networks-card.html
└── image-update.svg
├── webui-user.bat
├── .pylintrc
├── environment-wsl2.yaml
├── package.json
├── .github
├── ISSUE_TEMPLATE
│   ├── config.yml
│   ├── feature_request.yml
│   └── bug_report.yml
├── pull_request_template.md
└── workflows
│   ├── on_pull_request.yaml
│   └── run_tests.yaml
├── extensions-builtin
├── Lora
│   ├── preload.py
│   ├── ui_extra_networks_lora.py
│   └── extra_networks_lora.py
├── LDSR
│   ├── preload.py
│   └── scripts
│   │   └── ldsr_model.py
├── ScuNET
│   └── preload.py
├── SwinIR
│   └── preload.py
└── prompt-bracket-checker
│   └── javascript
│   └── prompt-bracket-checker.js
├── requirements.txt
├── javascript
├── textualInversion.js
├── imageParams.js
├── hires_fix.js
├── imageMaskFix.js
├── generationParams.js
├── notification.js
├── imageviewerGamepad.js
├── ui_settings_hints.js
├── extensions.js
├── dragdrop.js
└── aspectRatioOverlay.js
├── requirements_versions.txt
├── .gitignore
├── CODEOWNERS
├── webui-macos-env.sh
├── pyproject.toml
├── launch.py
├── scripts
├── postprocessing_gfpgan.py
├── postprocessing_codeformer.py
├── custom_code.py
└── sd_upscale.py
├── webui-user.sh
├── configs
├── v1-inference.yaml
├── alt-diffusion-inference.yaml
├── v1-inpainting-inference.yaml
└── instruct-pix2pix.yaml
├── webui.bat
├── .eslintrc.js
└── script.js
/test/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /models/VAE/Put VAE here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /extensions/put extensions here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /localizations/Put localization files here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /textual_inversion_templates/none.txt: -------------------------------------------------------------------------------- 1 | picture 2 | -------------------------------------------------------------------------------- /embeddings/Place Textual Inversion embeddings here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/deepbooru/Put your deepbooru release project folder here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.eslintignore: -------------------------------------------------------------------------------- 1 | extensions 2 | extensions-disabled 3 | repositories 4 | venv -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Apply ESlint 2 | 9c54b78d9dde5601e916f308d9a9d6953ec39430 -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-base-url~=2.0 2 | pytest-cov~=4.0 3 | pytest~=7.3 4 | -------------------------------------------------------------------------------- /modules/models/diffusion/uni_pc/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import UniPCSampler # noqa: F401 2 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/screenshot.png -------------------------------------------------------------------------------- /html/card-no-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/html/card-no-preview.png -------------------------------------------------------------------------------- /test/test_files/empty.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/empty.pt -------------------------------------------------------------------------------- /models/VAE-approx/model.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/models/VAE-approx/model.pt -------------------------------------------------------------------------------- /modules/Roboto-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/modules/Roboto-Regular.ttf -------------------------------------------------------------------------------- /models/karlo/ViT-L-14_stats.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/models/karlo/ViT-L-14_stats.th -------------------------------------------------------------------------------- /test/test_files/mask_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/mask_basic.png -------------------------------------------------------------------------------- /webui-user.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON= 4 | set GIT= 5 | set VENV_DIR= 6 | set COMMANDLINE_ARGS= 7 | 8 | call webui.bat 9 | -------------------------------------------------------------------------------- /test/test_files/img2img_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/img2img_basic.png -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html 2 | [MESSAGES CONTROL] 3 | disable=C,R,W,E,I 4 | -------------------------------------------------------------------------------- /modules/textual_inversion/test_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/modules/textual_inversion/test_embedding.png -------------------------------------------------------------------------------- /html/extra-networks-no-cards.html: -------------------------------------------------------------------------------- 1 |
2 |

Nothing here. Add some content to the following directories:

3 | 4 | 7 |
8 | 9 | -------------------------------------------------------------------------------- /environment-wsl2.yaml: -------------------------------------------------------------------------------- 1 | name: automatic 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - pip=23.0 8 | - cudatoolkit=11.8 9 | - pytorch=2.0 10 | - torchvision=0.15 11 | - numpy=1.23 12 | -------------------------------------------------------------------------------- /modules/import_hook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it 4 | if "--xformers" not in "".join(sys.argv): 5 | sys.modules["xformers"] = None 6 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stable-diffusion-webui", 3 | "version": "0.0.0", 4 | "devDependencies": { 5 | "eslint": "^8.40.0" 6 | }, 7 | "scripts": { 8 | "lint": "eslint .", 9 | "fix": "eslint --fix ." 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: WebUI Community Support 4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/LDSR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/ScuNET/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/SwinIR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR')) 7 | -------------------------------------------------------------------------------- /modules/sd_hijack_ip2p.py: -------------------------------------------------------------------------------- 1 | 
import os.path 2 | 3 | 4 | def should_hijack_ip2p(checkpoint_info): 5 | from modules import sd_models_config 6 | 7 | ckpt_basename = os.path.basename(checkpoint_info.filename).lower() 8 | cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() 9 | 10 | return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename 11 | -------------------------------------------------------------------------------- /html/footer.html: -------------------------------------------------------------------------------- 1 |
2 | API 3 |  •  4 | Github 5 |  •  6 | Gradio 7 |  •  8 | Reload UI 9 |
10 |
11 |
12 | {versions} 13 |
14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astunparse 2 | blendmodes 3 | accelerate 4 | basicsr 5 | gfpgan 6 | gradio==3.31.0 7 | numpy 8 | omegaconf 9 | opencv-contrib-python 10 | requests 11 | piexif 12 | Pillow 13 | pytorch_lightning==1.7.7 14 | realesrgan 15 | scikit-image>=0.19 16 | timm==0.4.12 17 | transformers==4.25.1 18 | torch 19 | einops 20 | jsonmerge 21 | clean-fid 22 | resize-right 23 | torchdiffeq 24 | kornia 25 | lark 26 | inflection 27 | GitPython 28 | torchsde 29 | safetensors 30 | psutil 31 | rich 32 | tomesd 33 | -------------------------------------------------------------------------------- /javascript/textualInversion.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | function start_training_textual_inversion() { 5 | gradioApp().querySelector('#ti_error').innerHTML = ''; 6 | 7 | var id = randomId(); 8 | requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) { 9 | gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; 10 | }); 11 | 12 | var res = Array.from(arguments); 13 | 14 | res[0] = id; 15 | 16 | return res; 17 | } 18 | -------------------------------------------------------------------------------- /html/extra-networks-card.html: -------------------------------------------------------------------------------- 1 |
2 | {background_image} 3 | {metadata_button} 4 |
5 |
6 | 9 | 10 |
11 | {name} 12 | {description} 13 |
14 |
15 | -------------------------------------------------------------------------------- /modules/face_restoration.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | 3 | 4 | class FaceRestoration: 5 | def name(self): 6 | return "None" 7 | 8 | def restore(self, np_image): 9 | return np_image 10 | 11 | 12 | def restore_faces(np_image): 13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] 14 | if len(face_restorers) == 0: 15 | return np_image 16 | 17 | face_restorer = face_restorers[0] 18 | 19 | return face_restorer.restore(np_image) 20 | -------------------------------------------------------------------------------- /requirements_versions.txt: -------------------------------------------------------------------------------- 1 | blendmodes==2022 2 | transformers==4.25.1 3 | accelerate==0.18.0 4 | basicsr==1.4.2 5 | gfpgan==1.3.8 6 | gradio==3.31.0 7 | numpy==1.23.5 8 | Pillow==9.5.0 9 | realesrgan==0.3.0 10 | torch 11 | omegaconf==2.2.3 12 | pytorch_lightning==1.9.4 13 | scikit-image==0.20.0 14 | timm==0.6.7 15 | piexif==1.1.3 16 | einops==0.4.1 17 | jsonmerge==1.8.0 18 | clean-fid==0.1.35 19 | resize-right==0.0.2 20 | torchdiffeq==0.2.3 21 | kornia==0.6.7 22 | lark==1.1.2 23 | inflection==0.5.1 24 | GitPython==3.1.30 25 | torchsde==0.2.5 26 | safetensors==0.3.1 27 | httpcore<=0.15 28 | fastapi==0.94.0 29 | tomesd==0.1.2 30 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | * a simple description of what you're trying to accomplish 4 | * a summary of changes in code 5 | * which issues it fixes, if any 6 | 7 | ## Screenshots/videos: 8 | 9 | 10 | ## Checklist: 11 | 12 | - [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) 13 | - [ ] I have performed a self-review of my own code 14 | - [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style) 15 | - [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests) 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.safetensors 4 | *.pth 5 | /ESRGAN/* 6 | /SwinIR/* 7 | /repositories 8 | /venv 9 | /tmp 10 | /model.ckpt 11 | /models/**/* 12 | /GFPGANv1.3.pth 13 | /gfpgan/weights/*.pth 14 | /ui-config.json 15 | /outputs 16 | /config.json 17 | /log 18 | /webui.settings.bat 19 | /embeddings 20 | /styles.csv 21 | /params.txt 22 | /styles.csv.bak 23 | /webui-user.bat 24 | /webui-user.sh 25 | /interrogate 26 | /user.css 27 | /.idea 28 | notification.mp3 29 | /SwinIR 30 | /textual_inversion 31 | .vscode 32 | /extensions 33 | /test/stdout.txt 34 | /test/stderr.txt 35 | /cache.json* 36 | /config_states/ 37 | /node_modules 38 | /package-lock.json 39 | /.coverage* 40 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from PIL import Image 5 | from gradio.processing_utils import encode_pil_to_base64 6 | 7 | test_files_path = 
os.path.dirname(__file__) + "/test_files" 8 | 9 | 10 | @pytest.fixture(scope="session") # session so we don't read this over and over 11 | def img2img_basic_image_base64() -> str: 12 | return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png"))) 13 | 14 | 15 | @pytest.fixture(scope="session") # session so we don't read this over and over 16 | def mask_basic_image_base64() -> str: 17 | return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "mask_basic.png"))) 18 | -------------------------------------------------------------------------------- /textual_inversion_templates/style.txt: -------------------------------------------------------------------------------- 1 | a painting, art by [name] 2 | a rendering, art by [name] 3 | a cropped painting, art by [name] 4 | the painting, art by [name] 5 | a clean painting, art by [name] 6 | a dirty painting, art by [name] 7 | a dark painting, art by [name] 8 | a picture, art by [name] 9 | a cool painting, art by [name] 10 | a close-up painting, art by [name] 11 | a bright painting, art by [name] 12 | a cropped painting, art by [name] 13 | a good painting, art by [name] 14 | a close-up painting, art by [name] 15 | a rendition, art by [name] 16 | a nice painting, art by [name] 17 | a small painting, art by [name] 18 | a weird painting, art by [name] 19 | a large painting, art by [name] 20 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @AUTOMATIC1111 2 | 3 | # if you were managing a localization and were removed from this file, this is because 4 | # the intended way to do localizations now is via extensions. See: 5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions 6 | # Make a repo with your localization and since you are still listed as a collaborator 7 | # you can add it to the wiki page yourself. This change is because some people complained 8 | # the git commit log is cluttered with things unrelated to almost everyone and 9 | # because I believe this is the best overall for the project to handle localizations almost 10 | # entirely without my oversight. 11 | 12 | 13 | -------------------------------------------------------------------------------- /javascript/imageParams.js: -------------------------------------------------------------------------------- 1 | window.onload = (function() { 2 | window.addEventListener('drop', e => { 3 | const target = e.composedPath()[0]; 4 | if (target.placeholder.indexOf("Prompt") == -1) return; 5 | 6 | let prompt_target = get_tab_index('tabs') == 1 ? 
"img2img_prompt_image" : "txt2img_prompt_image"; 7 | 8 | e.stopPropagation(); 9 | e.preventDefault(); 10 | const imgParent = gradioApp().getElementById(prompt_target); 11 | const files = e.dataTransfer.files; 12 | const fileInput = imgParent.querySelector('input[type="file"]'); 13 | if (fileInput) { 14 | fileInput.files = files; 15 | fileInput.dispatchEvent(new Event('change')); 16 | } 17 | }); 18 | }); 19 | -------------------------------------------------------------------------------- /modules/shared_items.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def realesrgan_models_names(): 4 | import modules.realesrgan_model 5 | return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] 6 | 7 | 8 | def postprocessing_scripts(): 9 | import modules.scripts 10 | 11 | return modules.scripts.scripts_postproc.scripts 12 | 13 | 14 | def sd_vae_items(): 15 | import modules.sd_vae 16 | 17 | return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) 18 | 19 | 20 | def refresh_vae_list(): 21 | import modules.sd_vae 22 | 23 | modules.sd_vae.refresh_vae_list() 24 | 25 | 26 | def cross_attention_optimizations(): 27 | import modules.sd_hijack 28 | 29 | return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] 30 | 31 | 32 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name] 2 | a rendering of a [name] 3 | a cropped photo of the [name] 4 | the photo of a [name] 5 | a photo of a clean [name] 6 | a photo of a dirty [name] 7 | a dark photo of the [name] 8 | a photo of my [name] 9 | a photo of the cool [name] 10 | a close-up photo of a [name] 11 | a bright photo of the [name] 12 | a cropped photo of a [name] 13 | a photo of the [name] 14 | a good photo of the [name] 15 | a photo of one [name] 16 | a close-up photo of the [name] 17 | a rendition of the [name] 18 | a photo of the clean [name] 19 | a rendition of a [name] 20 | a photo of a nice [name] 21 | a good photo of a [name] 22 | a photo of the nice [name] 23 | a photo of the small [name] 24 | a photo of the weird [name] 25 | a photo of the large [name] 26 | a photo of a cool [name] 27 | a photo of a small [name] 28 | -------------------------------------------------------------------------------- /webui-macos-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #################################################################### 3 | # macOS defaults # 4 | # Please modify webui-user.sh to change these instead of this file # 5 | #################################################################### 6 | 7 | if [[ -x "$(command -v python3.10)" ]] 8 | then 9 | python_cmd="python3.10" 10 | fi 11 | 12 | export install_dir="$HOME" 13 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" 14 | export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" 15 | export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" 16 | export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" 17 | export PYTORCH_ENABLE_MPS_FALLBACK=1 18 | 19 | #################################################################### 20 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | 3 | target-version = "py39" 4 | 5 | extend-select = [ 6 | "B", 7 | "C", 8 | "I", 9 | "W", 10 | ] 11 | 12 | exclude = [ 13 | "extensions", 14 | "extensions-disabled", 15 | ] 16 | 17 | ignore = [ 18 | "E501", # Line too long 19 | "E731", # Do not assign a `lambda` expression, use a `def` 20 | 21 | "I001", # Import block is un-sorted or un-formatted 22 | "C901", # Function is too complex 23 | "C408", # Rewrite as a literal 24 | "W605", # invalid escape sequence, messes with some docstrings 25 | ] 26 | 27 | [tool.ruff.per-file-ignores] 28 | "webui.py" = ["E402"] # Module level import not at top of file 29 | 30 | [tool.ruff.flake8-bugbear] 31 | # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. 32 | extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] 33 | 34 | [tool.pytest.ini_options] 35 | base_url = "http://127.0.0.1:7860" 36 | -------------------------------------------------------------------------------- /javascript/hires_fix.js: -------------------------------------------------------------------------------- 1 | 2 | function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) { 3 | function setInactive(elem, inactive) { 4 | elem.classList.toggle('inactive', !!inactive); 5 | } 6 | 7 | var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale'); 8 | var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x'); 9 | var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y'); 10 | 11 | gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""; 12 | 13 | setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0); 14 | setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0); 15 | setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0); 16 | 17 | return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]; 18 | } 19 | -------------------------------------------------------------------------------- /textual_inversion_templates/hypernetwork.txt: -------------------------------------------------------------------------------- 1 | a photo of a [filewords] 2 | a rendering of a [filewords] 3 | a cropped photo of the [filewords] 4 | the photo of a [filewords] 5 | a photo of a clean [filewords] 6 | a photo of a dirty [filewords] 7 | a dark photo of the [filewords] 8 | a photo of my [filewords] 9 | a photo of the cool [filewords] 10 | a close-up photo of a [filewords] 11 | a bright photo of the [filewords] 12 | a cropped photo of a [filewords] 13 | a photo of the [filewords] 14 | a good photo of the [filewords] 15 | a photo of one [filewords] 16 | a close-up photo of the [filewords] 17 | a rendition of the [filewords] 18 | a photo of the clean [filewords] 19 | a rendition of a [filewords] 20 | a photo of a nice [filewords] 21 | a good photo of a [filewords] 22 | a photo of the nice [filewords] 23 | a photo of the small [filewords] 24 | a photo of the weird [filewords] 25 | a photo of the large [filewords] 26 | a photo of a cool [filewords] 27 | a photo of a small [filewords] 28 | -------------------------------------------------------------------------------- /textual_inversion_templates/style_filewords.txt: -------------------------------------------------------------------------------- 1 | a painting of [filewords], art by [name] 2 | a rendering of 
[filewords], art by [name] 3 | a cropped painting of [filewords], art by [name] 4 | the painting of [filewords], art by [name] 5 | a clean painting of [filewords], art by [name] 6 | a dirty painting of [filewords], art by [name] 7 | a dark painting of [filewords], art by [name] 8 | a picture of [filewords], art by [name] 9 | a cool painting of [filewords], art by [name] 10 | a close-up painting of [filewords], art by [name] 11 | a bright painting of [filewords], art by [name] 12 | a cropped painting of [filewords], art by [name] 13 | a good painting of [filewords], art by [name] 14 | a close-up painting of [filewords], art by [name] 15 | a rendition of [filewords], art by [name] 16 | a nice painting of [filewords], art by [name] 17 | a small painting of [filewords], art by [name] 18 | a weird painting of [filewords], art by [name] 19 | a large painting of [filewords], art by [name] 20 | -------------------------------------------------------------------------------- /html/image-update.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | 4 | 5 | def test_options_write(base_url): 6 | url_options = f"{base_url}/sdapi/v1/options" 7 | response = requests.get(url_options) 8 | assert response.status_code == 200 9 | 10 | pre_value = response.json()["send_seed"] 11 | 12 | assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200 13 | 14 | response = requests.get(url_options) 15 | assert response.status_code == 200 16 | assert response.json()['send_seed'] == (not pre_value) 17 | 18 | requests.post(url_options, json={"send_seed": pre_value}) 19 | 20 | 21 | @pytest.mark.parametrize("url", [ 22 | "sdapi/v1/cmd-flags", 23 | "sdapi/v1/samplers", 24 | "sdapi/v1/upscalers", 25 | "sdapi/v1/sd-models", 26 | "sdapi/v1/hypernetworks", 27 | "sdapi/v1/face-restorers", 28 | "sdapi/v1/realesrgan-models", 29 | "sdapi/v1/prompt-styles", 30 | "sdapi/v1/embeddings", 31 | ]) 32 | def test_get_api_url(base_url, url): 33 | assert requests.get(f"{base_url}/{url}").status_code == 200 34 | -------------------------------------------------------------------------------- /modules/script_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | import importlib.util 5 | 6 | 7 | def load_module(path): 8 | module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) 9 | module = importlib.util.module_from_spec(module_spec) 10 | module_spec.loader.exec_module(module) 11 | 12 | return module 13 | 14 | 15 | def preload_extensions(extensions_dir, parser): 16 | if not os.path.isdir(extensions_dir): 17 | return 18 | 19 | for dirname in sorted(os.listdir(extensions_dir)): 20 | preload_script = os.path.join(extensions_dir, dirname, "preload.py") 21 | if not os.path.isfile(preload_script): 22 | continue 23 | 24 | try: 25 | module = load_module(preload_script) 26 | if hasattr(module, 'preload'): 27 | module.preload(parser) 28 | 29 | except Exception: 30 | print(f"Error running preload() for {preload_script}", file=sys.stderr) 31 | print(traceback.format_exc(), file=sys.stderr) 32 | -------------------------------------------------------------------------------- /modules/timer.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.start = time.time() 7 | self.records = {} 8 | self.total = 0 9 | 10 | def elapsed(self): 11 | end = time.time() 12 | res = end - self.start 13 | self.start = end 14 | return res 15 | 16 | def record(self, category, extra_time=0): 17 | e = self.elapsed() 18 | if category not in self.records: 19 | self.records[category] = 0 20 | 21 | self.records[category] += e + extra_time 22 | self.total += e + extra_time 23 | 24 | def summary(self): 25 | res = f"{self.total:.1f}s" 26 | 27 | additions = [x for x in self.records.items() if x[1] >= 0.1] 28 | if not additions: 29 | return res 30 | 31 | res += " (" 32 | res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) 33 | res += ")" 34 | 35 | return res 36 | 37 | def reset(self): 38 | self.__init__() 39 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | from modules import launch_utils 2 | 3 | 4 | args = launch_utils.args 5 | python = launch_utils.python 6 | git = launch_utils.git 7 | index_url = launch_utils.index_url 8 | dir_repos = launch_utils.dir_repos 9 | 10 | commit_hash = launch_utils.commit_hash 11 | git_tag = launch_utils.git_tag 12 | 13 | run = launch_utils.run 14 | is_installed = launch_utils.is_installed 15 | repo_dir = launch_utils.repo_dir 16 | 17 | run_pip = launch_utils.run_pip 18 | check_run_python = launch_utils.check_run_python 19 | git_clone = launch_utils.git_clone 20 | git_pull_recursive = launch_utils.git_pull_recursive 21 | run_extension_installer = launch_utils.run_extension_installer 22 | prepare_environment = launch_utils.prepare_environment 23 | configure_for_tests = launch_utils.configure_for_tests 24 | start = launch_utils.start 25 | 26 | 27 | def main(): 28 | if not args.skip_prepare_environment: 29 | prepare_environment() 30 | 31 | if args.test_server: 32 | configure_for_tests() 33 | 34 | start() 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /.github/workflows/on_pull_request.yaml: -------------------------------------------------------------------------------- 1 | name: Run Linting/Formatting on Pull Requests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | lint-python: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout Code 12 | uses: actions/checkout@v3 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: 3.11 16 | # NB: there's no cache: pip here since we're not installing anything 17 | # from the requirements.txt file(s) in the repository; it's faster 18 | # not to have GHA download an (at the time of writing) 4 GB cache 19 | # of PyTorch and other dependencies. 20 | - name: Install Ruff 21 | run: pip install ruff==0.0.265 22 | - name: Run Ruff 23 | run: ruff . 
24 | lint-js: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - name: Checkout Code 28 | uses: actions/checkout@v3 29 | - name: Install Node.js 30 | uses: actions/setup-node@v3 31 | with: 32 | node-version: 18 33 | - run: npm i --ci 34 | - run: npm run lint 35 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_hypernets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from modules import shared, ui_extra_networks 5 | 6 | 7 | class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): 8 | def __init__(self): 9 | super().__init__('Hypernetworks') 10 | 11 | def refresh(self): 12 | shared.reload_hypernetworks() 13 | 14 | def list_items(self): 15 | for name, path in shared.hypernetworks.items(): 16 | path, ext = os.path.splitext(path) 17 | 18 | yield { 19 | "name": name, 20 | "filename": path, 21 | "preview": self.find_preview(path), 22 | "description": self.find_description(path), 23 | "search_term": self.search_terms_from_path(path), 24 | "prompt": json.dumps(f""), 25 | "local_preview": f"{path}.preview.{shared.opts.samples_format}", 26 | } 27 | 28 | def allowed_directories_for_previews(self): 29 | return [shared.cmd_opts.hypernetwork_dir] 30 | 31 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject_filewords.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name], [filewords] 2 | a rendering of a [name], [filewords] 3 | a cropped photo of the [name], [filewords] 4 | the photo of a [name], [filewords] 5 | a photo of a clean [name], [filewords] 6 | a photo of a dirty [name], [filewords] 7 | a dark photo of the [name], [filewords] 8 | a photo of my [name], [filewords] 9 | a photo of the cool [name], [filewords] 10 | a close-up photo of a [name], [filewords] 11 | a bright photo of the [name], [filewords] 12 | a cropped photo of a [name], [filewords] 13 | a photo of the [name], [filewords] 14 | a good photo of the [name], [filewords] 15 | a photo of one [name], [filewords] 16 | a close-up photo of the [name], [filewords] 17 | a rendition of the [name], [filewords] 18 | a photo of the clean [name], [filewords] 19 | a rendition of a [name], [filewords] 20 | a photo of a nice [name], [filewords] 21 | a good photo of a [name], [filewords] 22 | a photo of the nice [name], [filewords] 23 | a photo of the small [name], [filewords] 24 | a photo of the weird [name], [filewords] 25 | a photo of the large [name], [filewords] 26 | a photo of a cool [name], [filewords] 27 | a photo of a small [name], [filewords] 28 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import traceback 5 | 6 | 7 | localizations = {} 8 | 9 | 10 | def list_localizations(dirname): 11 | localizations.clear() 12 | 13 | for file in os.listdir(dirname): 14 | fn, ext = os.path.splitext(file) 15 | if ext.lower() != ".json": 16 | continue 17 | 18 | localizations[fn] = os.path.join(dirname, file) 19 | 20 | from modules import scripts 21 | for file in scripts.list_scripts("localizations", ".json"): 22 | fn, ext = os.path.splitext(file.filename) 23 | localizations[fn] = file.path 24 | 25 | 26 | def localization_js(current_localization_name: str) -> str: 27 | fn = 
localizations.get(current_localization_name, None) 28 | data = {} 29 | if fn is not None: 30 | try: 31 | with open(fn, "r", encoding="utf8") as file: 32 | data = json.load(file) 33 | except Exception: 34 | print(f"Error loading localization from {fn}:", file=sys.stderr) 35 | print(traceback.format_exc(), file=sys.stderr) 36 | 37 | return f"window.localization = {json.dumps(data)}" 38 | -------------------------------------------------------------------------------- /scripts/postprocessing_gfpgan.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, gfpgan_model 5 | import gradio as gr 6 | 7 | from modules.ui_components import FormRow 8 | 9 | 10 | class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): 11 | name = "GFPGAN" 12 | order = 2000 13 | 14 | def ui(self): 15 | with FormRow(): 16 | gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility") 17 | 18 | return { 19 | "gfpgan_visibility": gfpgan_visibility, 20 | } 21 | 22 | def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility): 23 | if gfpgan_visibility == 0: 24 | return 25 | 26 | restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8)) 27 | res = Image.fromarray(restored_img) 28 | 29 | if gfpgan_visibility < 1.0: 30 | res = Image.blend(pp.image, res, gfpgan_visibility) 31 | 32 | pp.image = res 33 | pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) 34 | -------------------------------------------------------------------------------- /modules/ngrok.py: -------------------------------------------------------------------------------- 1 | import ngrok 2 | 3 | # Connect to ngrok for ingress 4 | def connect(token, port, options): 5 | account = None 6 | if token is None: 7 | token = 'None' 8 | else: 9 | if ':' in token: 10 | # token = authtoken:username:password 11 | token, username, password = token.split(':', 2) 12 | account = f"{username}:{password}" 13 | 14 | # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py 15 | if not options.get('authtoken_from_env'): 16 | options['authtoken'] = token 17 | if account: 18 | options['basic_auth'] = account 19 | if not options.get('session_metadata'): 20 | options['session_metadata'] = 'stable-diffusion-webui' 21 | 22 | 23 | try: 24 | public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url() 25 | except Exception as e: 26 | print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n' 27 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') 28 | else: 29 | print(f'ngrok connected to localhost:{port}! 
URL: {public_url}\n' 30 | 'You can use this link after the launch is complete.') 31 | -------------------------------------------------------------------------------- /modules/extra_networks_hypernet.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared 2 | from modules.hypernetworks import hypernetwork 3 | 4 | 5 | class ExtraNetworkHypernet(extra_networks.ExtraNetwork): 6 | def __init__(self): 7 | super().__init__('hypernet') 8 | 9 | def activate(self, p, params_list): 10 | additional = shared.opts.sd_hypernetwork 11 | 12 | if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: 13 | hypernet_prompt_text = f"" 14 | p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] 15 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 16 | 17 | names = [] 18 | multipliers = [] 19 | for params in params_list: 20 | assert len(params.items) > 0 21 | 22 | names.append(params.items[0]) 23 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) 24 | 25 | hypernetwork.load_hypernetworks(names, multipliers) 26 | 27 | def deactivate(self, p): 28 | pass 29 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from modules import ui_extra_networks, sd_hijack, shared 5 | 6 | 7 | class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): 8 | def __init__(self): 9 | super().__init__('Textual Inversion') 10 | self.allow_negative_prompt = True 11 | 12 | def refresh(self): 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) 14 | 15 | def list_items(self): 16 | for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): 17 | path, ext = os.path.splitext(embedding.filename) 18 | yield { 19 | "name": embedding.name, 20 | "filename": embedding.filename, 21 | "preview": self.find_preview(path), 22 | "description": self.find_description(path), 23 | "search_term": self.search_terms_from_path(embedding.filename), 24 | "prompt": json.dumps(embedding.name), 25 | "local_preview": f"{path}.preview.{shared.opts.samples_format}", 26 | } 27 | 28 | def allowed_directories_for_previews(self): 29 | return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) 30 | -------------------------------------------------------------------------------- /test/test_extras.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def test_simple_upscaling_performed(base_url, img2img_basic_image_base64): 5 | payload = { 6 | "resize_mode": 0, 7 | "show_extras_results": True, 8 | "gfpgan_visibility": 0, 9 | "codeformer_visibility": 0, 10 | "codeformer_weight": 0, 11 | "upscaling_resize": 2, 12 | "upscaling_resize_w": 128, 13 | "upscaling_resize_h": 128, 14 | "upscaling_crop": True, 15 | "upscaler_1": "Lanczos", 16 | "upscaler_2": "None", 17 | "extras_upscaler_2_visibility": 0, 18 | "image": img2img_basic_image_base64, 19 | } 20 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 21 | 22 | 23 | def test_png_info_performed(base_url, img2img_basic_image_base64): 24 | payload = { 25 | "image": 
img2img_basic_image_base64, 26 | } 27 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 28 | 29 | 30 | def test_interrogate_performed(base_url, img2img_basic_image_base64): 31 | payload = { 32 | "image": img2img_basic_image_base64, 33 | "model": "clip", 34 | } 35 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 36 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_checkpoints.py: -------------------------------------------------------------------------------- 1 | import html 2 | import json 3 | import os 4 | 5 | from modules import shared, ui_extra_networks, sd_models 6 | 7 | 8 | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): 9 | def __init__(self): 10 | super().__init__('Checkpoints') 11 | 12 | def refresh(self): 13 | shared.refresh_checkpoints() 14 | 15 | def list_items(self): 16 | checkpoint: sd_models.CheckpointInfo 17 | for name, checkpoint in sd_models.checkpoints_list.items(): 18 | path, ext = os.path.splitext(checkpoint.filename) 19 | yield { 20 | "name": checkpoint.name_for_extra, 21 | "filename": path, 22 | "preview": self.find_preview(path), 23 | "description": self.find_description(path), 24 | "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), 25 | "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', 26 | "local_preview": f"{path}.{shared.opts.samples_format}", 27 | } 28 | 29 | def allowed_directories_for_previews(self): 30 | return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] 31 | 32 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/ui_extra_networks_lora.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import lora 4 | 5 | from modules import shared, ui_extra_networks 6 | 7 | 8 | class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): 9 | def __init__(self): 10 | super().__init__('Lora') 11 | 12 | def refresh(self): 13 | lora.list_available_loras() 14 | 15 | def list_items(self): 16 | for name, lora_on_disk in lora.available_loras.items(): 17 | path, ext = os.path.splitext(lora_on_disk.filename) 18 | 19 | alias = lora_on_disk.get_alias() 20 | 21 | yield { 22 | "name": name, 23 | "filename": path, 24 | "preview": self.find_preview(path), 25 | "description": self.find_description(path), 26 | "search_term": self.search_terms_from_path(lora_on_disk.filename), 27 | "prompt": json.dumps(f""), 28 | "local_preview": f"{path}.{shared.opts.samples_format}", 29 | "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, 30 | } 31 | 32 | def allowed_directories_for_previews(self): 33 | return [shared.cmd_opts.lora_dir] 34 | 35 | -------------------------------------------------------------------------------- /modules/errors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | 4 | 5 | def print_error_explanation(message): 6 | lines = message.strip().split("\n") 7 | max_len = max([len(x) for x in lines]) 8 | 9 | print('=' * max_len, file=sys.stderr) 10 | for line in lines: 11 | print(line, file=sys.stderr) 12 | print('=' * max_len, file=sys.stderr) 13 | 14 | 15 | def display(e: Exception, task): 16 | print(f"{task or 'error'}: {type(e).__name__}", 
file=sys.stderr) 17 | print(traceback.format_exc(), file=sys.stderr) 18 | 19 | message = str(e) 20 | if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: 21 | print_error_explanation(""" 22 | The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file. 23 | See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. 24 | """) 25 | 26 | 27 | already_displayed = {} 28 | 29 | 30 | def display_once(e: Exception, task): 31 | if task in already_displayed: 32 | return 33 | 34 | display(e, task) 35 | 36 | already_displayed[task] = 1 37 | 38 | 39 | def run(code, task): 40 | try: 41 | code() 42 | except Exception as e: 43 | display(task, e) 44 | -------------------------------------------------------------------------------- /modules/sd_hijack_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | class CondFunc: 4 | def __new__(cls, orig_func, sub_func, cond_func): 5 | self = super(CondFunc, cls).__new__(cls) 6 | if isinstance(orig_func, str): 7 | func_path = orig_func.split('.') 8 | for i in range(len(func_path)-1, -1, -1): 9 | try: 10 | resolved_obj = importlib.import_module('.'.join(func_path[:i])) 11 | break 12 | except ImportError: 13 | pass 14 | for attr_name in func_path[i:-1]: 15 | resolved_obj = getattr(resolved_obj, attr_name) 16 | orig_func = getattr(resolved_obj, func_path[-1]) 17 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) 18 | self.__init__(orig_func, sub_func, cond_func) 19 | return lambda *args, **kwargs: self(*args, **kwargs) 20 | def __init__(self, orig_func, sub_func, cond_func): 21 | self.__orig_func = orig_func 22 | self.__sub_func = sub_func 23 | self.__cond_func = cond_func 24 | def __call__(self, *args, **kwargs): 25 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): 26 | return self.__sub_func(self.__orig_func, *args, **kwargs) 27 | else: 28 | return self.__orig_func(*args, **kwargs) 29 | -------------------------------------------------------------------------------- /modules/paths_internal.py: -------------------------------------------------------------------------------- 1 | """this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py""" 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import shlex 7 | 8 | commandline_args = os.environ.get('COMMANDLINE_ARGS', "") 9 | sys.argv += shlex.split(commandline_args) 10 | 11 | modules_path = os.path.dirname(os.path.realpath(__file__)) 12 | script_path = os.path.dirname(modules_path) 13 | 14 | sd_configs_path = os.path.join(script_path, "configs") 15 | sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") 16 | sd_model_file = os.path.join(script_path, 'model.ckpt') 17 | default_sd_model_file = sd_model_file 18 | 19 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values 20 | parser_pre = argparse.ArgumentParser(add_help=False) 21 | parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", ) 22 | cmd_opts_pre = parser_pre.parse_known_args()[0] 23 | 24 | data_path = cmd_opts_pre.data_dir 25 | 26 | models_path = os.path.join(data_path, "models") 27 | extensions_dir = os.path.join(data_path, 
"extensions") 28 | extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") 29 | config_states_dir = os.path.join(script_path, "config_states") 30 | 31 | roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') 32 | -------------------------------------------------------------------------------- /modules/sd_hijack_checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch.utils.checkpoint import checkpoint 2 | 3 | import ldm.modules.attention 4 | import ldm.modules.diffusionmodules.openaimodel 5 | 6 | 7 | def BasicTransformerBlock_forward(self, x, context=None): 8 | return checkpoint(self._forward, x, context) 9 | 10 | 11 | def AttentionBlock_forward(self, x): 12 | return checkpoint(self._forward, x) 13 | 14 | 15 | def ResBlock_forward(self, x, emb): 16 | return checkpoint(self._forward, x, emb) 17 | 18 | 19 | stored = [] 20 | 21 | 22 | def add(): 23 | if len(stored) != 0: 24 | return 25 | 26 | stored.extend([ 27 | ldm.modules.attention.BasicTransformerBlock.forward, 28 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, 29 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward 30 | ]) 31 | 32 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward 33 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward 34 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward 35 | 36 | 37 | def remove(): 38 | if len(stored) == 0: 39 | return 40 | 41 | ldm.modules.attention.BasicTransformerBlock.forward = stored[0] 42 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] 43 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] 44 | 45 | stored.clear() 46 | 47 | -------------------------------------------------------------------------------- /modules/sd_hijack_open_clip.py: -------------------------------------------------------------------------------- 1 | import open_clip.tokenizer 2 | import torch 3 | 4 | from modules import sd_hijack_clip, devices 5 | from modules.shared import opts 6 | 7 | tokenizer = open_clip.tokenizer._tokenizer 8 | 9 | 10 | class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): 11 | def __init__(self, wrapped, hijack): 12 | super().__init__(wrapped, hijack) 13 | 14 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] 15 | self.id_start = tokenizer.encoder[""] 16 | self.id_end = tokenizer.encoder[""] 17 | self.id_pad = 0 18 | 19 | def tokenize(self, texts): 20 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' 21 | 22 | tokenized = [tokenizer.encode(text) for text in texts] 23 | 24 | return tokenized 25 | 26 | def encode_with_transformers(self, tokens): 27 | # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers 28 | z = self.wrapped.encode_with_transformer(tokens) 29 | 30 | return z 31 | 32 | def encode_embedding_init_text(self, init_text, nvpt): 33 | ids = tokenizer.encode(init_text) 34 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int) 35 | embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) 36 | 37 | return embedded 38 | -------------------------------------------------------------------------------- /javascript/imageMaskFix.js: -------------------------------------------------------------------------------- 1 | /** 2 | * temporary fix for 
https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 3 | * @see https://github.com/gradio-app/gradio/issues/1721 4 | */ 5 | function imageMaskResize() { 6 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); 7 | if (!canvases.length) { 8 | window.removeEventListener('resize', imageMaskResize); 9 | return; 10 | } 11 | 12 | const wrapper = canvases[0].closest('.touch-none'); 13 | const previewImage = wrapper.previousElementSibling; 14 | 15 | if (!previewImage.complete) { 16 | previewImage.addEventListener('load', imageMaskResize); 17 | return; 18 | } 19 | 20 | const w = previewImage.width; 21 | const h = previewImage.height; 22 | const nw = previewImage.naturalWidth; 23 | const nh = previewImage.naturalHeight; 24 | const portrait = nh > nw; 25 | 26 | const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw); 27 | const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh); 28 | 29 | wrapper.style.width = `${wW}px`; 30 | wrapper.style.height = `${wH}px`; 31 | wrapper.style.left = `0px`; 32 | wrapper.style.top = `0px`; 33 | 34 | canvases.forEach(c => { 35 | c.style.width = c.style.height = ''; 36 | c.style.maxWidth = '100%'; 37 | c.style.maxHeight = '100%'; 38 | c.style.objectFit = 'contain'; 39 | }); 40 | } 41 | 42 | onUiUpdate(imageMaskResize); 43 | window.addEventListener('resize', imageMaskResize); 44 | -------------------------------------------------------------------------------- /modules/textual_inversion/logging.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | 5 | saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} 6 | saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} 7 | saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} 8 | saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet 9 | saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} 10 | 11 | 12 | def save_settings_to_file(log_directory, all_params): 13 | now = datetime.datetime.now() 14 | params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} 15 | 16 | keys = saved_params_all 17 | if all_params.get('preview_from_txt2img'): 18 | keys = keys | saved_params_previews 19 | 20 | params.update({k: v for k, v in all_params.items() if k in keys}) 21 | 22 | filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' 23 | with open(os.path.join(log_directory, filename), "w") as file: 24 | json.dump(params, file, indent=4) 25 | -------------------------------------------------------------------------------- /modules/sd_hijack_xlmr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from modules import sd_hijack_clip, devices 4 | 5 | 6 | class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): 7 | def __init__(self, wrapped, hijack): 8 | super().__init__(wrapped, hijack) 9 | 10 | self.id_start = 
wrapped.config.bos_token_id 11 | self.id_end = wrapped.config.eos_token_id 12 | self.id_pad = wrapped.config.pad_token_id 13 | 14 | self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma 15 | 16 | def encode_with_transformers(self, tokens): 17 | # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a 18 | # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer 19 | # layer to work with - you have to use the last 20 | 21 | attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) 22 | features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) 23 | z = features['projection_state'] 24 | 25 | return z 26 | 27 | def encode_embedding_init_text(self, init_text, nvpt): 28 | embedding_layer = self.wrapped.roberta.embeddings 29 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] 30 | embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) 31 | 32 | return embedded 33 | -------------------------------------------------------------------------------- /webui-user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ######################################################### 3 | # Uncomment and change the variables below to your need:# 4 | ######################################################### 5 | 6 | # Install directory without trailing slash 7 | #install_dir="/home/$(whoami)" 8 | 9 | # Name of the subdirectory 10 | #clone_dir="stable-diffusion-webui" 11 | 12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" 13 | #export COMMANDLINE_ARGS="" 14 | 15 | # python3 executable 16 | #python_cmd="python3" 17 | 18 | # git executable 19 | #export GIT="git" 20 | 21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 22 | #venv_dir="venv" 23 | 24 | # script to launch to start the app 25 | #export LAUNCH_SCRIPT="launch.py" 26 | 27 | # install command for torch 28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" 29 | 30 | # Requirements file to use for stable-diffusion-webui 31 | #export REQS_FILE="requirements_versions.txt" 32 | 33 | # Fixed git repos 34 | #export K_DIFFUSION_PACKAGE="" 35 | #export GFPGAN_PACKAGE="" 36 | 37 | # Fixed git commits 38 | #export STABLE_DIFFUSION_COMMIT_HASH="" 39 | #export TAMING_TRANSFORMERS_COMMIT_HASH="" 40 | #export CODEFORMER_COMMIT_HASH="" 41 | #export BLIP_COMMIT_HASH="" 42 | 43 | # Uncomment to enable accelerated launch 44 | #export ACCELERATE="True" 45 | 46 | # Uncomment to disable TCMalloc 47 | #export NO_TCMALLOC="True" 48 | 49 | ########################################### 50 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | title: "[Feature Request]: " 4 | labels: ["enhancement"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do ? 22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here. 41 | -------------------------------------------------------------------------------- /scripts/postprocessing_codeformer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, codeformer_model 5 | import gradio as gr 6 | 7 | from modules.ui_components import FormRow 8 | 9 | 10 | class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): 11 | name = "CodeFormer" 12 | order = 3000 13 | 14 | def ui(self): 15 | with FormRow(): 16 | codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility") 17 | codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") 18 | 19 | return { 20 | "codeformer_visibility": codeformer_visibility, 21 | "codeformer_weight": codeformer_weight, 22 | } 23 | 24 | def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight): 25 | if codeformer_visibility == 0: 26 | return 27 | 28 | restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight) 29 | res = Image.fromarray(restored_img) 30 | 31 | if codeformer_visibility < 1.0: 32 | res = Image.blend(pp.image, res, codeformer_visibility) 33 | 34 | pp.image = res 35 | pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) 36 | pp.info["CodeFormer weight"] = round(codeformer_weight, 3) 37 | -------------------------------------------------------------------------------- /javascript/generationParams.js: -------------------------------------------------------------------------------- 1 | // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes 2 | 3 | let txt2img_gallery, img2img_gallery, modal = undefined; 4 | onUiUpdate(function() { 5 | if (!txt2img_gallery) { 6 | txt2img_gallery = attachGalleryListeners("txt2img"); 7 | } 8 | if (!img2img_gallery) { 9 | img2img_gallery = attachGalleryListeners("img2img"); 10 | } 11 | if (!modal) { 12 | modal = gradioApp().getElementById('lightboxModal'); 13 | modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']}); 14 | } 15 | }); 16 | 17 | let modalObserver = new MutationObserver(function(mutations) { 18 | 
mutations.forEach(function(mutationRecord) { 19 | let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText; 20 | if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) { 21 | gradioApp().getElementById(selectedTab + "_generation_info_button")?.click(); 22 | } 23 | }); 24 | }); 25 | 26 | function attachGalleryListeners(tab_name) { 27 | var gallery = gradioApp().querySelector('#' + tab_name + '_gallery'); 28 | gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click()); 29 | gallery?.addEventListener('keydown', (e) => { 30 | if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow 31 | gradioApp().getElementById(tab_name + "_generation_info_button").click(); 32 | } 33 | }); 34 | return gallery; 35 | } 36 | -------------------------------------------------------------------------------- /modules/hypernetworks/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | import modules.hypernetworks.hypernetwork 5 | from modules import devices, sd_hijack, shared 6 | 7 | not_available = ["hardswish", "multiheadattention"] 8 | keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available] 9 | 10 | 11 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): 12 | filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) 13 | 14 | return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", "" 15 | 16 | 17 | def train_hypernetwork(*args): 18 | shared.loaded_hypernetworks = [] 19 | 20 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 21 | 22 | try: 23 | sd_hijack.undo_optimizations() 24 | 25 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) 26 | 27 | res = f""" 28 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 
29 | Hypernetwork saved to <b>{html.escape(filename)}</b> 30 | """ 31 | return res, "" 32 | except Exception: 33 | raise 34 | finally: 35 | shared.sd_model.cond_stage_model.to(devices.device) 36 | shared.sd_model.first_stage_model.to(devices.device) 37 | sd_hijack.apply_optimizations() 38 | 39 | -------------------------------------------------------------------------------- /modules/scripts_auto_postprocessing.py: -------------------------------------------------------------------------------- 1 | from modules import scripts, scripts_postprocessing, shared 2 | 3 | 4 | class ScriptPostprocessingForMainUI(scripts.Script): 5 | def __init__(self, script_postproc): 6 | self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc 7 | self.postprocessing_controls = None 8 | 9 | def title(self): 10 | return self.script.name 11 | 12 | def show(self, is_img2img): 13 | return scripts.AlwaysVisible 14 | 15 | def ui(self, is_img2img): 16 | self.postprocessing_controls = self.script.ui() 17 | return self.postprocessing_controls.values() 18 | 19 | def postprocess_image(self, p, script_pp, *args): 20 | args_dict = dict(zip(self.postprocessing_controls, args)) 21 | 22 | pp = scripts_postprocessing.PostprocessedImage(script_pp.image) 23 | pp.info = {} 24 | self.script.process(pp, **args_dict) 25 | p.extra_generation_params.update(pp.info) 26 | script_pp.image = pp.image 27 | 28 | 29 | def create_auto_preprocessing_script_data(): 30 | from modules import scripts 31 | 32 | res = [] 33 | 34 | for name in shared.opts.postprocessing_enable_in_main_ui: 35 | script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) 36 | if script is None: 37 | continue 38 | 39 | constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) 40 | res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) 41 | 42 | return res 43 | -------------------------------------------------------------------------------- /modules/textual_inversion/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | 5 | import modules.textual_inversion.textual_inversion 6 | import modules.textual_inversion.preprocess 7 | from modules import sd_hijack, shared 8 | 9 | 10 | def create_embedding(name, initialization_text, nvpt, overwrite_old): 11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) 12 | 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() 14 | 15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" 16 | 17 | 18 | def preprocess(*args): 19 | modules.textual_inversion.preprocess.preprocess(*args) 20 | 21 | return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" 22 | 23 | 24 | def train_embedding(*args): 25 | 26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' 27 | 28 | apply_optimizations = shared.opts.training_xattention_optimizations 29 | try: 30 | if not apply_optimizations: 31 | sd_hijack.undo_optimizations() 32 | 33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) 34 | 35 | res = f""" 36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
37 | Embedding saved to <b>{html.escape(filename)}</b> 38 | """ 39 | return res, "" 40 | except Exception: 41 | raise 42 | finally: 43 | if not apply_optimizations: 44 | sd_hijack.apply_optimizations() 45 | 46 | -------------------------------------------------------------------------------- /modules/sd_samplers.py: -------------------------------------------------------------------------------- 1 | from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared 2 | 3 | # imports for functions that previously were here and are used by other modules 4 | from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 5 | 6 | all_samplers = [ 7 | *sd_samplers_kdiffusion.samplers_data_k_diffusion, 8 | *sd_samplers_compvis.samplers_data_compvis, 9 | ] 10 | all_samplers_map = {x.name: x for x in all_samplers} 11 | 12 | samplers = [] 13 | samplers_for_img2img = [] 14 | samplers_map = {} 15 | 16 | 17 | def find_sampler_config(name): 18 | if name is not None: 19 | config = all_samplers_map.get(name, None) 20 | else: 21 | config = all_samplers[0] 22 | 23 | return config 24 | 25 | 26 | def create_sampler(name, model): 27 | config = find_sampler_config(name) 28 | 29 | assert config is not None, f'bad sampler name: {name}' 30 | 31 | sampler = config.constructor(model) 32 | sampler.config = config 33 | 34 | return sampler 35 | 36 | 37 | def set_samplers(): 38 | global samplers, samplers_for_img2img 39 | 40 | hidden = set(shared.opts.hide_samplers) 41 | hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) 42 | 43 | samplers = [x for x in all_samplers if x.name not in hidden] 44 | samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] 45 | 46 | samplers_map.clear() 47 | for sampler in all_samplers: 48 | samplers_map[sampler.name.lower()] = sampler.name 49 | for alias in sampler.aliases: 50 | samplers_map[alias.lower()] = sampler.name 51 | 52 | 53 | set_samplers() 54 | -------------------------------------------------------------------------------- /javascript/notification.js: -------------------------------------------------------------------------------- 1 | // Monitors the gallery and sends a browser notification when the leading image is new. 2 | 3 | let lastHeadImg = null; 4 | 5 | let notificationButton = null; 6 | 7 | onUiUpdate(function() { 8 | if (notificationButton == null) { 9 | notificationButton = gradioApp().getElementById('request_notifications'); 10 | 11 | if (notificationButton != null) { 12 | notificationButton.addEventListener('click', () => { 13 | void Notification.requestPermission(); 14 | }, true); 15 | } 16 | } 17 | 18 | const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img'); 19 | 20 | if (galleryPreviews == null) return; 21 | 22 | const headImg = galleryPreviews[0]?.src; 23 | 24 | if (headImg == null || headImg == lastHeadImg) return; 25 | 26 | lastHeadImg = headImg; 27 | 28 | // play notification sound if available 29 | gradioApp().querySelector('#audio_notification audio')?.play(); 30 | 31 | if (document.hasFocus()) return; 32 | 33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. 34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src)); 35 | 36 | const notification = new Notification( 37 | 'Stable Diffusion', 38 | { 39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ?
's' : ''}`, 40 | icon: headImg, 41 | image: headImg, 42 | } 43 | ); 44 | 45 | notification.onclick = function(_) { 46 | parent.focus(); 47 | this.close(); 48 | }; 49 | }); 50 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/extra_networks_lora.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared 2 | import lora 3 | 4 | 5 | class ExtraNetworkLora(extra_networks.ExtraNetwork): 6 | def __init__(self): 7 | super().__init__('lora') 8 | 9 | def activate(self, p, params_list): 10 | additional = shared.opts.sd_lora 11 | 12 | if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: 13 | p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] 14 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 15 | 16 | names = [] 17 | multipliers = [] 18 | for params in params_list: 19 | assert len(params.items) > 0 20 | 21 | names.append(params.items[0]) 22 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) 23 | 24 | lora.load_loras(names, multipliers) 25 | 26 | if shared.opts.lora_add_hashes_to_infotext: 27 | lora_hashes = [] 28 | for item in lora.loaded_loras: 29 | shorthash = item.lora_on_disk.shorthash 30 | if not shorthash: 31 | continue 32 | 33 | alias = item.mentioned_name 34 | if not alias: 35 | continue 36 | 37 | alias = alias.replace(":", "").replace(",", "") 38 | 39 | lora_hashes.append(f"{alias}: {shorthash}") 40 | 41 | if lora_hashes: 42 | p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes) 43 | 44 | def deactivate(self, p): 45 | pass 46 | -------------------------------------------------------------------------------- /extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js: -------------------------------------------------------------------------------- 1 | // Stable Diffusion WebUI - Bracket checker 2 | // By Hingashi no Florin/Bwin4L & @akx 3 | // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. 4 | // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
5 | 6 | function checkBrackets(textArea, counterElt) { 7 | var counts = {}; 8 | (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => { 9 | counts[bracket] = (counts[bracket] || 0) + 1; 10 | }); 11 | var errors = []; 12 | 13 | function checkPair(open, close, kind) { 14 | if (counts[open] !== counts[close]) { 15 | errors.push( 16 | `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` 17 | ); 18 | } 19 | } 20 | 21 | checkPair('(', ')', 'round brackets'); 22 | checkPair('[', ']', 'square brackets'); 23 | checkPair('{', '}', 'curly brackets'); 24 | counterElt.title = errors.join('\n'); 25 | counterElt.classList.toggle('error', errors.length !== 0); 26 | } 27 | 28 | function setupBracketChecking(id_prompt, id_counter) { 29 | var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); 30 | var counter = gradioApp().getElementById(id_counter); 31 | 32 | if (textarea && counter) { 33 | textarea.addEventListener("input", () => checkBrackets(textarea, counter)); 34 | } 35 | } 36 | 37 | onUiLoaded(function() { 38 | setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); 39 | setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); 40 | setupBracketChecking('img2img_prompt', 'img2img_token_counter'); 41 | setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); 42 | }); 43 | -------------------------------------------------------------------------------- /javascript/imageviewerGamepad.js: -------------------------------------------------------------------------------- 1 | window.addEventListener('gamepadconnected', (e) => { 2 | const index = e.gamepad.index; 3 | let isWaiting = false; 4 | setInterval(async() => { 5 | if (!opts.js_modal_lightbox_gamepad || isWaiting) return; 6 | const gamepad = navigator.getGamepads()[index]; 7 | const xValue = gamepad.axes[0]; 8 | if (xValue <= -0.3) { 9 | modalPrevImage(e); 10 | isWaiting = true; 11 | } else if (xValue >= 0.3) { 12 | modalNextImage(e); 13 | isWaiting = true; 14 | } 15 | if (isWaiting) { 16 | await sleepUntil(() => { 17 | const xValue = navigator.getGamepads()[index].axes[0]; 18 | if (xValue < 0.3 && xValue > -0.3) { 19 | return true; 20 | } 21 | }, opts.js_modal_lightbox_gamepad_repeat); 22 | isWaiting = false; 23 | } 24 | }, 10); 25 | }); 26 | 27 | /* 28 | Primarily for vr controller type pointer devices. 29 | I use the wheel event because there's currently no way to do it properly with web xr. 
30 | */ 31 | let isScrolling = false; 32 | window.addEventListener('wheel', (e) => { 33 | if (!opts.js_modal_lightbox_gamepad || isScrolling) return; 34 | isScrolling = true; 35 | 36 | if (e.deltaX <= -0.6) { 37 | modalPrevImage(e); 38 | } else if (e.deltaX >= 0.6) { 39 | modalNextImage(e); 40 | } 41 | 42 | setTimeout(() => { 43 | isScrolling = false; 44 | }, opts.js_modal_lightbox_gamepad_repeat); 45 | }); 46 | 47 | function sleepUntil(f, timeout) { 48 | return new Promise((resolve) => { 49 | const timeStart = new Date(); 50 | const wait = setInterval(function() { 51 | if (f() || new Date() - timeStart > timeout) { 52 | clearInterval(wait); 53 | resolve(); 54 | } 55 | }, 20); 56 | }); 57 | } 58 | -------------------------------------------------------------------------------- /configs/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /modules/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401 4 | 5 | import modules.safe # noqa: F401 6 | 7 | 8 | # data_path = cmd_opts_pre.data 9 | sys.path.insert(0, script_path) 10 | 11 | # search for directory of stable diffusion in following places 12 | sd_path = None 13 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] 14 | for possible_sd_path in possible_sd_paths: 15 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 16 | sd_path = os.path.abspath(possible_sd_path) 17 | break 18 | 19 
| assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" 20 | 21 | path_dirs = [ 22 | (sd_path, 'ldm', 'Stable Diffusion', []), 23 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), 24 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), 25 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), 26 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 27 | ] 28 | 29 | paths = {} 30 | 31 | for d, must_exist, what, options in path_dirs: 32 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 33 | if not os.path.exists(must_exist_path): 34 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 35 | else: 36 | d = os.path.abspath(d) 37 | if "atstart" in options: 38 | sys.path.insert(0, d) 39 | else: 40 | sys.path.append(d) 41 | paths[what] = d 42 | 43 | 44 | class Prioritize: 45 | def __init__(self, name): 46 | self.name = name 47 | self.path = None 48 | 49 | def __enter__(self): 50 | self.path = sys.path.copy() 51 | sys.path = [paths[self.name]] + sys.path 52 | 53 | def __exit__(self, exc_type, exc_val, exc_tb): 54 | sys.path = self.path 55 | self.path = None 56 | -------------------------------------------------------------------------------- /modules/ui_components.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | 4 | class FormComponent: 5 | def get_expected_parent(self): 6 | return gr.components.Form 7 | 8 | 9 | gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent 10 | 11 | 12 | class ToolButton(FormComponent, gr.Button): 13 | """Small button with single emoji as text, fits inside gradio forms""" 14 | 15 | def __init__(self, *args, **kwargs): 16 | classes = kwargs.pop("elem_classes", []) 17 | super().__init__(*args, elem_classes=["tool", *classes], **kwargs) 18 | 19 | def get_block_name(self): 20 | return "button" 21 | 22 | 23 | class FormRow(FormComponent, gr.Row): 24 | """Same as gr.Row but fits inside gradio forms""" 25 | 26 | def get_block_name(self): 27 | return "row" 28 | 29 | 30 | class FormColumn(FormComponent, gr.Column): 31 | """Same as gr.Column but fits inside gradio forms""" 32 | 33 | def get_block_name(self): 34 | return "column" 35 | 36 | 37 | class FormGroup(FormComponent, gr.Group): 38 | """Same as gr.Row but fits inside gradio forms""" 39 | 40 | def get_block_name(self): 41 | return "group" 42 | 43 | 44 | class FormHTML(FormComponent, gr.HTML): 45 | """Same as gr.HTML but fits inside gradio forms""" 46 | 47 | def get_block_name(self): 48 | return "html" 49 | 50 | 51 | class FormColorPicker(FormComponent, gr.ColorPicker): 52 | """Same as gr.ColorPicker but fits inside gradio forms""" 53 | 54 | def get_block_name(self): 55 | return "colorpicker" 56 | 57 | 58 | class DropdownMulti(FormComponent, gr.Dropdown): 59 | """Same as gr.Dropdown but always multiselect""" 60 | def __init__(self, **kwargs): 61 | super().__init__(multiselect=True, **kwargs) 62 | 63 | def get_block_name(self): 64 | return "dropdown" 65 | 66 | 67 | class DropdownEditable(FormComponent, gr.Dropdown): 68 | """Same as gr.Dropdown but allows editing value""" 69 | def __init__(self, **kwargs): 70 | super().__init__(allow_custom_value=True, **kwargs) 71 | 72 | def get_block_name(self): 73 | return "dropdown" 74 | 75 | -------------------------------------------------------------------------------- 
/configs/alt-diffusion-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: modules.xlmr.BertSeriesModelWithTransformation 71 | params: 72 | name: "XLMR-Large" -------------------------------------------------------------------------------- /modules/sd_vae_approx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from modules import devices, paths 6 | 7 | sd_vae_approx_model = None 8 | 9 | 10 | class VAEApprox(nn.Module): 11 | def __init__(self): 12 | super(VAEApprox, self).__init__() 13 | self.conv1 = nn.Conv2d(4, 8, (7, 7)) 14 | self.conv2 = nn.Conv2d(8, 16, (5, 5)) 15 | self.conv3 = nn.Conv2d(16, 32, (3, 3)) 16 | self.conv4 = nn.Conv2d(32, 64, (3, 3)) 17 | self.conv5 = nn.Conv2d(64, 32, (3, 3)) 18 | self.conv6 = nn.Conv2d(32, 16, (3, 3)) 19 | self.conv7 = nn.Conv2d(16, 8, (3, 3)) 20 | self.conv8 = nn.Conv2d(8, 3, (3, 3)) 21 | 22 | def forward(self, x): 23 | extra = 11 24 | x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) 25 | x = nn.functional.pad(x, (extra, extra, extra, extra)) 26 | 27 | for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: 28 | x = layer(x) 29 | x = nn.functional.leaky_relu(x, 0.1) 30 | 31 | return x 32 | 33 | 34 | def model(): 35 | global sd_vae_approx_model 36 | 37 | if sd_vae_approx_model is None: 38 | model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") 39 | sd_vae_approx_model = VAEApprox() 40 | if not os.path.exists(model_path): 41 | model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") 42 | sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) 43 | 
sd_vae_approx_model.eval() 44 | sd_vae_approx_model.to(devices.device, devices.dtype) 45 | 46 | return sd_vae_approx_model 47 | 48 | 49 | def cheap_approximation(sample): 50 | # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 51 | 52 | coefs = torch.tensor([ 53 | [0.298, 0.207, 0.208], 54 | [0.187, 0.286, 0.173], 55 | [-0.158, 0.189, 0.264], 56 | [-0.184, -0.271, -0.473], 57 | ]).to(sample.device) 58 | 59 | x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) 60 | 61 | return x_sample 62 | -------------------------------------------------------------------------------- /configs/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run basic features tests on CPU with empty SD model 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout Code 12 | uses: actions/checkout@v3 13 | - name: Set up Python 3.10 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.10.6 17 | cache: pip 18 | cache-dependency-path: | 19 | **/requirements*txt 20 | launch.py 21 | - name: Install test dependencies 22 | run: pip install wait-for-it -r requirements-test.txt 23 | env: 24 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 25 | PIP_PROGRESS_BAR: "off" 26 | - name: Setup environment 27 | run: python launch.py --skip-torch-cuda-test --exit 28 | env: 29 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 30 | 
PIP_PROGRESS_BAR: "off" 31 | TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu 32 | WEBUI_LAUNCH_LIVE_OUTPUT: "1" 33 | PYTHONUNBUFFERED: "1" 34 | - name: Start test server 35 | run: > 36 | python -m coverage run 37 | --data-file=.coverage.server 38 | launch.py 39 | --skip-prepare-environment 40 | --skip-torch-cuda-test 41 | --test-server 42 | --no-half 43 | --disable-opt-split-attention 44 | --use-cpu all 45 | --add-stop-route 46 | 2>&1 | tee output.txt & 47 | - name: Run tests 48 | run: | 49 | wait-for-it --service 127.0.0.1:7860 -t 600 50 | python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test 51 | - name: Kill test server 52 | if: always() 53 | run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 54 | - name: Show coverage 55 | run: | 56 | python -m coverage combine .coverage* 57 | python -m coverage report -i 58 | python -m coverage html -i 59 | - name: Upload main app output 60 | uses: actions/upload-artifact@v3 61 | if: always() 62 | with: 63 | name: output 64 | path: output.txt 65 | - name: Upload coverage HTML 66 | uses: actions/upload-artifact@v3 67 | if: always() 68 | with: 69 | name: htmlcov 70 | path: htmlcov 71 | -------------------------------------------------------------------------------- /test/test_img2img.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import requests 4 | 5 | 6 | @pytest.fixture() 7 | def url_img2img(base_url): 8 | return f"{base_url}/sdapi/v1/img2img" 9 | 10 | 11 | @pytest.fixture() 12 | def simple_img2img_request(img2img_basic_image_base64): 13 | return { 14 | "batch_size": 1, 15 | "cfg_scale": 7, 16 | "denoising_strength": 0.75, 17 | "eta": 0, 18 | "height": 64, 19 | "include_init_images": False, 20 | "init_images": [img2img_basic_image_base64], 21 | "inpaint_full_res": False, 22 | "inpaint_full_res_padding": 0, 23 | "inpainting_fill": 0, 24 | "inpainting_mask_invert": False, 25 | "mask": None, 26 | "mask_blur": 4, 27 | "n_iter": 1, 28 | "negative_prompt": "", 29 | "override_settings": {}, 30 | "prompt": "example prompt", 31 | "resize_mode": 0, 32 | "restore_faces": False, 33 | "s_churn": 0, 34 | "s_noise": 1, 35 | "s_tmax": 0, 36 | "s_tmin": 0, 37 | "sampler_index": "Euler a", 38 | "seed": -1, 39 | "seed_resize_from_h": -1, 40 | "seed_resize_from_w": -1, 41 | "steps": 3, 42 | "styles": [], 43 | "subseed": -1, 44 | "subseed_strength": 0, 45 | "tiling": False, 46 | "width": 64, 47 | } 48 | 49 | 50 | def test_img2img_simple_performed(url_img2img, simple_img2img_request): 51 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 52 | 53 | 54 | def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): 55 | simple_img2img_request["mask"] = mask_basic_image_base64 56 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 57 | 58 | 59 | def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): 60 | simple_img2img_request["mask"] = mask_basic_image_base64 61 | simple_img2img_request["inpainting_mask_invert"] = True 62 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 63 | 64 | 65 | def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request): 66 | simple_img2img_request["script_name"] = "sd upscale" 67 | simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0] 68 | assert requests.post(url_img2img, 
json=simple_img2img_request).status_code == 200 69 | -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not defined PYTHON (set PYTHON=python) 4 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") 5 | 6 | 7 | set ERROR_REPORTING=FALSE 8 | 9 | mkdir tmp 2>NUL 10 | 11 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 12 | if %ERRORLEVEL% == 0 goto :check_pip 13 | echo Couldn't launch python 14 | goto :show_stdout_stderr 15 | 16 | :check_pip 17 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt 18 | if %ERRORLEVEL% == 0 goto :start_venv 19 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr 20 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt 21 | if %ERRORLEVEL% == 0 goto :start_venv 22 | echo Couldn't install pip 23 | goto :show_stdout_stderr 24 | 25 | :start_venv 26 | if ["%VENV_DIR%"] == ["-"] goto :skip_venv 27 | if ["%SKIP_VENV%"] == ["1"] goto :skip_venv 28 | 29 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt 30 | if %ERRORLEVEL% == 0 goto :activate_venv 31 | 32 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 33 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 34 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt 35 | if %ERRORLEVEL% == 0 goto :activate_venv 36 | echo Unable to create venv in directory "%VENV_DIR%" 37 | goto :show_stdout_stderr 38 | 39 | :activate_venv 40 | set PYTHON="%VENV_DIR%\Scripts\Python.exe" 41 | echo venv %PYTHON% 42 | 43 | :skip_venv 44 | if [%ACCELERATE%] == ["True"] goto :accelerate 45 | goto :launch 46 | 47 | :accelerate 48 | echo Checking for accelerate 49 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" 50 | if EXIST %ACCELERATE% goto :accelerate_launch 51 | 52 | :launch 53 | %PYTHON% launch.py %* 54 | pause 55 | exit /b 56 | 57 | :accelerate_launch 58 | echo Accelerating 59 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py 60 | pause 61 | exit /b 62 | 63 | :show_stdout_stderr 64 | 65 | echo. 66 | echo exit code: %errorlevel% 67 | 68 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 69 | if %size% equ 0 goto :show_stderr 70 | echo. 71 | echo stdout: 72 | type tmp\stdout.txt 73 | 74 | :show_stderr 75 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 76 | if %size% equ 0 goto :show_stderr 77 | echo. 78 | echo stderr: 79 | type tmp\stderr.txt 80 | 81 | :endofscript 82 | 83 | echo. 84 | echo Launch unsuccessful. Exiting. 
85 | pause 86 | -------------------------------------------------------------------------------- /javascript/ui_settings_hints.js: -------------------------------------------------------------------------------- 1 | // various hints and extra info for the settings tab 2 | 3 | var settingsHintsSetup = false; 4 | 5 | onOptionsChanged(function() { 6 | if (settingsHintsSetup) return; 7 | settingsHintsSetup = true; 8 | 9 | gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) { 10 | var name = div.id.substr(8); 11 | var commentBefore = opts._comments_before[name]; 12 | var commentAfter = opts._comments_after[name]; 13 | 14 | if (!commentBefore && !commentAfter) return; 15 | 16 | var span = null; 17 | if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span'); 18 | else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild; 19 | else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild; 20 | else span = div.querySelector('label span').firstChild; 21 | 22 | if (!span) return; 23 | 24 | if (commentBefore) { 25 | var comment = document.createElement('DIV'); 26 | comment.className = 'settings-comment'; 27 | comment.innerHTML = commentBefore; 28 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span); 29 | span.parentElement.insertBefore(comment, span); 30 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span); 31 | } 32 | if (commentAfter) { 33 | comment = document.createElement('DIV'); 34 | comment.className = 'settings-comment'; 35 | comment.innerHTML = commentAfter; 36 | span.parentElement.insertBefore(comment, span.nextSibling); 37 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling); 38 | } 39 | }); 40 | }); 41 | 42 | function settingsHintsShowQuicksettings() { 43 | requestGet("./internal/quicksettings-hint", {}, function(data) { 44 | var table = document.createElement('table'); 45 | table.className = 'settings-value-table'; 46 | 47 | data.forEach(function(obj) { 48 | var tr = document.createElement('tr'); 49 | var td = document.createElement('td'); 50 | td.textContent = obj.name; 51 | tr.appendChild(td); 52 | 53 | td = document.createElement('td'); 54 | td.textContent = obj.label; 55 | tr.appendChild(td); 56 | 57 | table.appendChild(tr); 58 | }); 59 | 60 | popup(table); 61 | }); 62 | } 63 | -------------------------------------------------------------------------------- /javascript/extensions.js: -------------------------------------------------------------------------------- 1 | 2 | function extensions_apply(_disabled_list, _update_list, disable_all) { 3 | var disable = []; 4 | var update = []; 5 | 6 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) { 7 | if (x.name.startsWith("enable_") && !x.checked) { 8 | disable.push(x.name.substring(7)); 9 | } 10 | 11 | if (x.name.startsWith("update_") && x.checked) { 12 | update.push(x.name.substring(7)); 13 | } 14 | }); 15 | 16 | restart_reload(); 17 | 18 | return [JSON.stringify(disable), JSON.stringify(update), disable_all]; 19 | } 20 | 21 | function extensions_check() { 22 | var disable = []; 23 | 24 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) { 25 | if (x.name.startsWith("enable_") && !x.checked) { 26 | disable.push(x.name.substring(7)); 27 | } 28 | }); 29 | 30 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) { 31 | x.innerHTML = 
"Loading..."; 32 | }); 33 | 34 | 35 | var id = randomId(); 36 | requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() { 37 | 38 | }); 39 | 40 | return [id, JSON.stringify(disable)]; 41 | } 42 | 43 | function install_extension_from_index(button, url) { 44 | button.disabled = "disabled"; 45 | button.value = "Installing..."; 46 | 47 | var textarea = gradioApp().querySelector('#extension_to_install textarea'); 48 | textarea.value = url; 49 | updateInput(textarea); 50 | 51 | gradioApp().querySelector('#install_extension_button').click(); 52 | } 53 | 54 | function config_state_confirm_restore(_, config_state_name, config_restore_type) { 55 | if (config_state_name == "Current") { 56 | return [false, config_state_name, config_restore_type]; 57 | } 58 | let restored = ""; 59 | if (config_restore_type == "extensions") { 60 | restored = "all saved extension versions"; 61 | } else if (config_restore_type == "webui") { 62 | restored = "the webui version"; 63 | } else { 64 | restored = "the webui version and all saved extension versions"; 65 | } 66 | let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + "."); 67 | if (confirmed) { 68 | restart_reload(); 69 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) { 70 | x.innerHTML = "Loading..."; 71 | }); 72 | } 73 | return [confirmed, config_state_name, config_restore_type]; 74 | } 75 | -------------------------------------------------------------------------------- /scripts/custom_code.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | import ast 4 | import copy 5 | 6 | from modules.processing import Processed 7 | from modules.shared import cmd_opts 8 | 9 | 10 | def convertExpr2Expression(expr): 11 | expr.lineno = 0 12 | expr.col_offset = 0 13 | result = ast.Expression(expr.value, lineno=0, col_offset = 0) 14 | 15 | return result 16 | 17 | 18 | def exec_with_return(code, module): 19 | """ 20 | like exec() but can return values 21 | https://stackoverflow.com/a/52361938/5862977 22 | """ 23 | code_ast = ast.parse(code) 24 | 25 | init_ast = copy.deepcopy(code_ast) 26 | init_ast.body = code_ast.body[:-1] 27 | 28 | last_ast = copy.deepcopy(code_ast) 29 | last_ast.body = code_ast.body[-1:] 30 | 31 | exec(compile(init_ast, "", "exec"), module.__dict__) 32 | if type(last_ast.body[0]) == ast.Expr: 33 | return eval(compile(convertExpr2Expression(last_ast.body[0]), "", "eval"), module.__dict__) 34 | else: 35 | exec(compile(last_ast, "", "exec"), module.__dict__) 36 | 37 | 38 | class Script(scripts.Script): 39 | 40 | def title(self): 41 | return "Custom code" 42 | 43 | def show(self, is_img2img): 44 | return cmd_opts.allow_code 45 | 46 | def ui(self, is_img2img): 47 | example = """from modules.processing import process_images 48 | 49 | p.width = 768 50 | p.height = 768 51 | p.batch_size = 2 52 | p.steps = 10 53 | 54 | return process_images(p) 55 | """ 56 | 57 | 58 | code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code")) 59 | indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level")) 60 | 61 | return [code, indent_level] 62 | 63 | def run(self, p, code, indent_level): 64 | assert cmd_opts.allow_code, '--allow-code option must be enabled' 65 | 66 | display_result_data = [[], -1, ""] 67 | 68 | def display(imgs, s=display_result_data[1], 
i=display_result_data[2]): 69 | display_result_data[0] = imgs 70 | display_result_data[1] = s 71 | display_result_data[2] = i 72 | 73 | from types import ModuleType 74 | module = ModuleType("testmodule") 75 | module.__dict__.update(globals()) 76 | module.p = p 77 | module.display = display 78 | 79 | indent = " " * indent_level 80 | indented = code.replace('\n', f"\n{indent}") 81 | body = f"""def __webuitemp__(): 82 | {indent}{indented} 83 | __webuitemp__()""" 84 | 85 | result = exec_with_return(body, module) 86 | 87 | if isinstance(result, Processed): 88 | return result 89 | 90 | return Processed(p, *display_result_data) 91 | -------------------------------------------------------------------------------- /modules/ui_postprocessing.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from modules import scripts, shared, ui_common, postprocessing, call_queue 3 | import modules.generation_parameters_copypaste as parameters_copypaste 4 | 5 | 6 | def create_ui(): 7 | tab_index = gr.State(value=0) 8 | 9 | with gr.Row().style(equal_height=False, variant='compact'): 10 | with gr.Column(variant='compact'): 11 | with gr.Tabs(elem_id="mode_extras"): 12 | with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single: 13 | extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") 14 | 15 | with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch: 16 | image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch") 17 | 18 | with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir: 19 | extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") 20 | extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") 21 | show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") 22 | 23 | submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') 24 | 25 | script_inputs = scripts.scripts_postproc.setup_ui() 26 | 27 | with gr.Column(): 28 | result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) 29 | 30 | tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) 31 | tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) 32 | tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) 33 | 34 | submit.click( 35 | fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), 36 | inputs=[ 37 | tab_index, 38 | extras_image, 39 | image_batch, 40 | extras_batch_input_dir, 41 | extras_batch_output_dir, 42 | show_extras_results, 43 | *script_inputs 44 | ], 45 | outputs=[ 46 | result_images, 47 | html_info_x, 48 | html_info, 49 | ] 50 | ) 51 | 52 | parameters_copypaste.add_paste_fields("extras", extras_image, None) 53 | 54 | extras_image.change( 55 | fn=scripts.scripts_postproc.image_changed, 56 | inputs=[], outputs=[] 57 | ) 58 | -------------------------------------------------------------------------------- /modules/ui_tempdir.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from collections import namedtuple 4 | from pathlib import Path 5 | 6 | import gradio.components 7 | 8 | from PIL import PngImagePlugin 9 | 10 | from modules import shared 11 | 12 | 13 | Savedfile = namedtuple("Savedfile", ["name"]) 14 | 15 | 16 | def register_tmp_file(gradio, filename): 17 | if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 18 | gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} 19 | 20 | if hasattr(gradio, 'temp_dirs'): # gradio 3.9 21 | gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} 22 | 23 | 24 | def check_tmp_file(gradio, filename): 25 | if hasattr(gradio, 'temp_file_sets'): 26 | return any(filename in fileset for fileset in gradio.temp_file_sets) 27 | 28 | if hasattr(gradio, 'temp_dirs'): 29 | return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) 30 | 31 | return False 32 | 33 | 34 | def save_pil_to_file(self, pil_image, dir=None): 35 | already_saved_as = getattr(pil_image, 'already_saved_as', None) 36 | if already_saved_as and os.path.isfile(already_saved_as): 37 | register_tmp_file(shared.demo, already_saved_as) 38 | filename = already_saved_as 39 | 40 | if not shared.opts.save_images_add_number: 41 | filename += f'?{os.path.getmtime(already_saved_as)}' 42 | 43 | return filename 44 | 45 | if shared.opts.temp_dir != "": 46 | dir = shared.opts.temp_dir 47 | 48 | use_metadata = False 49 | metadata = PngImagePlugin.PngInfo() 50 | for key, value in pil_image.info.items(): 51 | if isinstance(key, str) and isinstance(value, str): 52 | metadata.add_text(key, value) 53 | use_metadata = True 54 | 55 | file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) 56 | pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) 57 | return file_obj.name 58 | 59 | 60 | # override save to file function so that it also writes PNG info 61 | gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file 62 | 63 | 64 | def on_tmpdir_changed(): 65 | if shared.opts.temp_dir == "" or shared.demo is None: 66 | return 67 | 68 | os.makedirs(shared.opts.temp_dir, exist_ok=True) 69 | 70 | register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) 71 | 72 | 73 | def cleanup_tmpdr(): 74 | temp_dir = shared.opts.temp_dir 75 | if temp_dir == "" or not os.path.isdir(temp_dir): 76 | return 77 | 78 | for root, _, files in os.walk(temp_dir, topdown=False): 79 | for name in files: 80 | _, extension = os.path.splitext(name) 81 | if extension != ".png": 82 | continue 83 | 84 | filename = os.path.join(root, name) 85 | os.remove(filename) 86 | -------------------------------------------------------------------------------- /configs/instruct-pix2pix.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 
3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: modules.models.diffusion.ddpm_edit.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: edited 14 | cond_stage_key: edit 15 | # image_size: 64 16 | # image_size: 32 17 | image_size: 16 18 | channels: 4 19 | cond_stage_trainable: false # Note: different from the one we trained before 20 | conditioning_key: hybrid 21 | monitor: val/loss_simple_ema 22 | scale_factor: 0.18215 23 | use_ema: false 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 0 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. ] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 8 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 768 48 | use_checkpoint: True 49 | legacy: False 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | 74 | cond_stage_config: 75 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 76 | 77 | data: 78 | target: main.DataModuleFromConfig 79 | params: 80 | batch_size: 128 81 | num_workers: 1 82 | wrap: false 83 | validation: 84 | target: edit_dataset.EditDataset 85 | params: 86 | path: data/clip-filtered-dataset 87 | cache_dir: data/ 88 | cache_name: data_10k 89 | split: val 90 | min_text_sim: 0.2 91 | min_image_sim: 0.75 92 | min_direction_sim: 0.2 93 | max_samples_per_prompt: 1 94 | min_resize_res: 512 95 | max_resize_res: 512 96 | crop_res: 512 97 | output_as_edit: False 98 | real_input: True 99 | -------------------------------------------------------------------------------- /modules/styles.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import os.path 4 | import typing 5 | import shutil 6 | 7 | 8 | class PromptStyle(typing.NamedTuple): 9 | name: str 10 | prompt: str 11 | negative_prompt: str 12 | 13 | 14 | def merge_prompts(style_prompt: str, prompt: str) -> str: 15 | if "{prompt}" in style_prompt: 16 | res = style_prompt.replace("{prompt}", prompt) 17 | else: 18 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 19 | res = ", ".join(parts) 20 | 21 | return res 22 | 23 | 24 | def apply_styles_to_prompt(prompt, styles): 25 | for style in styles: 26 | prompt = merge_prompts(style, prompt) 27 | 28 | return prompt 29 | 30 | 31 | class StyleDatabase: 32 | def __init__(self, path: str): 33 | self.no_style = PromptStyle("None", "", "") 34 | self.styles = {} 35 | self.path = path 36 | 37 | self.reload() 38 | 39 | def reload(self): 40 | self.styles.clear() 41 | 42 | if not os.path.exists(self.path): 43 | return 44 | 45 | with open(self.path, "r", encoding="utf-8-sig", newline='') as file: 46 | reader = 
csv.DictReader(file, skipinitialspace=True) 47 | for row in reader: 48 | # Support loading old CSV format with "name, text"-columns 49 | prompt = row["prompt"] if "prompt" in row else row["text"] 50 | negative_prompt = row.get("negative_prompt", "") 51 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) 52 | 53 | def get_style_prompts(self, styles): 54 | return [self.styles.get(x, self.no_style).prompt for x in styles] 55 | 56 | def get_negative_style_prompts(self, styles): 57 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 58 | 59 | def apply_styles_to_prompt(self, prompt, styles): 60 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) 61 | 62 | def apply_negative_styles_to_prompt(self, prompt, styles): 63 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) 64 | 65 | def save_styles(self, path: str) -> None: 66 | # Always keep a backup file around 67 | if os.path.exists(path): 68 | shutil.copy(path, f"{path}.bak") 69 | 70 | fd = os.open(path, os.O_RDWR|os.O_CREAT) 71 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: 72 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, 73 | # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() 74 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) 75 | writer.writeheader() 76 | writer.writerows(style._asdict() for k, style in self.styles.items()) 77 | -------------------------------------------------------------------------------- /modules/sd_vae_taesd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tiny AutoEncoder for Stable Diffusion 3 | (DNN for encoding / decoding SD's latent space) 4 | 5 | https://github.com/madebyollin/taesd 6 | """ 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | 11 | from modules import devices, paths_internal 12 | 13 | sd_vae_taesd = None 14 | 15 | 16 | def conv(n_in, n_out, **kwargs): 17 | return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) 18 | 19 | 20 | class Clamp(nn.Module): 21 | @staticmethod 22 | def forward(x): 23 | return torch.tanh(x / 3) * 3 24 | 25 | 26 | class Block(nn.Module): 27 | def __init__(self, n_in, n_out): 28 | super().__init__() 29 | self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) 30 | self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() 31 | self.fuse = nn.ReLU() 32 | 33 | def forward(self, x): 34 | return self.fuse(self.conv(x) + self.skip(x)) 35 | 36 | 37 | def decoder(): 38 | return nn.Sequential( 39 | Clamp(), conv(4, 64), nn.ReLU(), 40 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), 41 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), 42 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), 43 | Block(64, 64), conv(64, 3), 44 | ) 45 | 46 | 47 | class TAESD(nn.Module): 48 | latent_magnitude = 3 49 | latent_shift = 0.5 50 | 51 | def __init__(self, decoder_path="taesd_decoder.pth"): 52 | """Initialize pretrained TAESD on the given device from the given checkpoints.""" 53 | super().__init__() 54 | self.decoder = decoder() 55 | self.decoder.load_state_dict( 56 | torch.load(decoder_path, map_location='cpu' 
if devices.device.type != 'cuda' else None)) 57 | 58 | @staticmethod 59 | def unscale_latents(x): 60 | """[0, 1] -> raw latents""" 61 | return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) 62 | 63 | 64 | def download_model(model_path): 65 | model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' 66 | 67 | if not os.path.exists(model_path): 68 | os.makedirs(os.path.dirname(model_path), exist_ok=True) 69 | 70 | print(f'Downloading TAESD decoder to: {model_path}') 71 | torch.hub.download_url_to_file(model_url, model_path) 72 | 73 | 74 | def model(): 75 | global sd_vae_taesd 76 | 77 | if sd_vae_taesd is None: 78 | model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") 79 | download_model(model_path) 80 | 81 | if os.path.exists(model_path): 82 | sd_vae_taesd = TAESD(model_path) 83 | sd_vae_taesd.eval() 84 | sd_vae_taesd.to(devices.device, devices.dtype) 85 | else: 86 | raise FileNotFoundError('TAESD model not found') 87 | 88 | return sd_vae_taesd.decoder 89 | -------------------------------------------------------------------------------- /modules/memmon.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | 8 | class MemUsageMonitor(threading.Thread): 9 | run_flag = None 10 | device = None 11 | disabled = False 12 | opts = None 13 | data = None 14 | 15 | def __init__(self, name, device, opts): 16 | threading.Thread.__init__(self) 17 | self.name = name 18 | self.device = device 19 | self.opts = opts 20 | 21 | self.daemon = True 22 | self.run_flag = threading.Event() 23 | self.data = defaultdict(int) 24 | 25 | try: 26 | self.cuda_mem_get_info() 27 | torch.cuda.memory_stats(self.device) 28 | except Exception as e: # AMD or whatever 29 | print(f"Warning: caught exception '{e}', memory monitor disabled") 30 | self.disabled = True 31 | 32 | def cuda_mem_get_info(self): 33 | index = self.device.index if self.device.index is not None else torch.cuda.current_device() 34 | return torch.cuda.mem_get_info(index) 35 | 36 | def run(self): 37 | if self.disabled: 38 | return 39 | 40 | while True: 41 | self.run_flag.wait() 42 | 43 | torch.cuda.reset_peak_memory_stats() 44 | self.data.clear() 45 | 46 | if self.opts.memmon_poll_rate <= 0: 47 | self.run_flag.clear() 48 | continue 49 | 50 | self.data["min_free"] = self.cuda_mem_get_info()[0] 51 | 52 | while self.run_flag.is_set(): 53 | free, total = self.cuda_mem_get_info() 54 | self.data["min_free"] = min(self.data["min_free"], free) 55 | 56 | time.sleep(1 / self.opts.memmon_poll_rate) 57 | 58 | def dump_debug(self): 59 | print(self, 'recorded data:') 60 | for k, v in self.read().items(): 61 | print(k, -(v // -(1024 ** 2))) 62 | 63 | print(self, 'raw torch memory stats:') 64 | tm = torch.cuda.memory_stats(self.device) 65 | for k, v in tm.items(): 66 | if 'bytes' not in k: 67 | continue 68 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2))) 69 | 70 | print(torch.cuda.memory_summary()) 71 | 72 | def monitor(self): 73 | self.run_flag.set() 74 | 75 | def read(self): 76 | if not self.disabled: 77 | free, total = self.cuda_mem_get_info() 78 | self.data["free"] = free 79 | self.data["total"] = total 80 | 81 | torch_stats = torch.cuda.memory_stats(self.device) 82 | self.data["active"] = torch_stats["active.all.current"] 83 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"] 84 | self.data["reserved"] = torch_stats["reserved_bytes.all.current"] 85 | 
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] 86 | self.data["system_peak"] = total - self.data["min_free"] 87 | 88 | return self.data 89 | 90 | def stop(self): 91 | self.run_flag.clear() 92 | return self.read() 93 | -------------------------------------------------------------------------------- /modules/textual_inversion/learn_schedule.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | 4 | class LearnScheduleIterator: 5 | def __init__(self, learn_rate, max_steps, cur_step=0): 6 | """ 7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 8 | """ 9 | 10 | pairs = learn_rate.split(',') 11 | self.rates = [] 12 | self.it = 0 13 | self.maxit = 0 14 | try: 15 | for pair in pairs: 16 | if not pair.strip(): 17 | continue 18 | tmp = pair.split(':') 19 | if len(tmp) == 2: 20 | step = int(tmp[1]) 21 | if step > cur_step: 22 | self.rates.append((float(tmp[0]), min(step, max_steps))) 23 | self.maxit += 1 24 | if step > max_steps: 25 | return 26 | elif step == -1: 27 | self.rates.append((float(tmp[0]), max_steps)) 28 | self.maxit += 1 29 | return 30 | else: 31 | self.rates.append((float(tmp[0]), max_steps)) 32 | self.maxit += 1 33 | return 34 | assert self.rates 35 | except (ValueError, AssertionError) as e: 36 | raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e 37 | 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | if self.it < self.maxit: 44 | self.it += 1 45 | return self.rates[self.it - 1] 46 | else: 47 | raise StopIteration 48 | 49 | 50 | class LearnRateScheduler: 51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): 52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) 53 | (self.learn_rate, self.end_step) = next(self.schedules) 54 | self.verbose = verbose 55 | 56 | if self.verbose: 57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}') 58 | 59 | self.finished = False 60 | 61 | def step(self, step_number): 62 | if step_number < self.end_step: 63 | return False 64 | 65 | try: 66 | (self.learn_rate, self.end_step) = next(self.schedules) 67 | except StopIteration: 68 | self.finished = True 69 | return False 70 | return True 71 | 72 | def apply(self, optimizer, step_number): 73 | if not self.step(step_number): 74 | return 75 | 76 | if self.verbose: 77 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') 78 | 79 | for pg in optimizer.param_groups: 80 | pg['lr'] = self.learn_rate 81 | 82 | -------------------------------------------------------------------------------- /modules/txt2img.py: -------------------------------------------------------------------------------- 1 | import modules.scripts 2 | from modules import sd_samplers, processing 3 | from modules.generation_parameters_copypaste import create_override_settings_dict 4 | from modules.shared import opts, cmd_opts 5 | import modules.shared as shared 6 | from modules.ui import plaintext_to_html 7 | 8 | 9 | 10 | def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, 
seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): 11 | override_settings = create_override_settings_dict(override_settings_texts) 12 | 13 | p = processing.StableDiffusionProcessingTxt2Img( 14 | sd_model=shared.sd_model, 15 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, 16 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, 17 | prompt=prompt, 18 | styles=prompt_styles, 19 | negative_prompt=negative_prompt, 20 | seed=seed, 21 | subseed=subseed, 22 | subseed_strength=subseed_strength, 23 | seed_resize_from_h=seed_resize_from_h, 24 | seed_resize_from_w=seed_resize_from_w, 25 | seed_enable_extras=seed_enable_extras, 26 | sampler_name=sd_samplers.samplers[sampler_index].name, 27 | batch_size=batch_size, 28 | n_iter=n_iter, 29 | steps=steps, 30 | cfg_scale=cfg_scale, 31 | width=width, 32 | height=height, 33 | restore_faces=restore_faces, 34 | tiling=tiling, 35 | enable_hr=enable_hr, 36 | denoising_strength=denoising_strength if enable_hr else None, 37 | hr_scale=hr_scale, 38 | hr_upscaler=hr_upscaler, 39 | hr_second_pass_steps=hr_second_pass_steps, 40 | hr_resize_x=hr_resize_x, 41 | hr_resize_y=hr_resize_y, 42 | hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, 43 | hr_prompt=hr_prompt, 44 | hr_negative_prompt=hr_negative_prompt, 45 | override_settings=override_settings, 46 | ) 47 | 48 | p.scripts = modules.scripts.scripts_txt2img 49 | p.script_args = args 50 | 51 | if cmd_opts.enable_console_prompts: 52 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) 53 | 54 | processed = modules.scripts.scripts_txt2img.run(p, *args) 55 | 56 | if processed is None: 57 | processed = processing.process_images(p) 58 | 59 | p.close() 60 | 61 | shared.total_tqdm.clear() 62 | 63 | generation_info_js = processed.js() 64 | if opts.samples_log_stdout: 65 | print(generation_info_js) 66 | 67 | if opts.do_not_show_images: 68 | processed.images = [] 69 | 70 | return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) 71 | -------------------------------------------------------------------------------- /modules/hashes.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os.path 4 | 5 | import filelock 6 | 7 | from modules import shared 8 | from modules.paths import data_path 9 | 10 | 11 | cache_filename = os.path.join(data_path, "cache.json") 12 | cache_data = None 13 | 14 | 15 | def dump_cache(): 16 | with filelock.FileLock(f"{cache_filename}.lock"): 17 | with open(cache_filename, "w", encoding="utf8") as file: 18 | json.dump(cache_data, file, indent=4) 19 | 20 | 21 | def cache(subsection): 22 | global cache_data 23 | 24 | if cache_data is None: 25 | with filelock.FileLock(f"{cache_filename}.lock"): 26 | if not os.path.isfile(cache_filename): 27 | cache_data = {} 28 | else: 29 | with open(cache_filename, "r", encoding="utf8") as file: 30 | cache_data = json.load(file) 31 | 32 | s = cache_data.get(subsection, {}) 33 | cache_data[subsection] = s 34 | 35 | return s 36 | 37 | 38 | def calculate_sha256(filename): 39 | hash_sha256 = hashlib.sha256() 40 | blksize = 1024 * 1024 41 | 42 | with open(filename, "rb") as f: 43 | for 
chunk in iter(lambda: f.read(blksize), b""): 44 | hash_sha256.update(chunk) 45 | 46 | return hash_sha256.hexdigest() 47 | 48 | 49 | def sha256_from_cache(filename, title, use_addnet_hash=False): 50 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") 51 | ondisk_mtime = os.path.getmtime(filename) 52 | 53 | if title not in hashes: 54 | return None 55 | 56 | cached_sha256 = hashes[title].get("sha256", None) 57 | cached_mtime = hashes[title].get("mtime", 0) 58 | 59 | if ondisk_mtime > cached_mtime or cached_sha256 is None: 60 | return None 61 | 62 | return cached_sha256 63 | 64 | 65 | def sha256(filename, title, use_addnet_hash=False): 66 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") 67 | 68 | sha256_value = sha256_from_cache(filename, title, use_addnet_hash) 69 | if sha256_value is not None: 70 | return sha256_value 71 | 72 | if shared.cmd_opts.no_hashing: 73 | return None 74 | 75 | print(f"Calculating sha256 for {filename}: ", end='') 76 | if use_addnet_hash: 77 | with open(filename, "rb") as file: 78 | sha256_value = addnet_hash_safetensors(file) 79 | else: 80 | sha256_value = calculate_sha256(filename) 81 | print(f"{sha256_value}") 82 | 83 | hashes[title] = { 84 | "mtime": os.path.getmtime(filename), 85 | "sha256": sha256_value, 86 | } 87 | 88 | dump_cache() 89 | 90 | return sha256_value 91 | 92 | 93 | def addnet_hash_safetensors(b): 94 | """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py""" 95 | hash_sha256 = hashlib.sha256() 96 | blksize = 1024 * 1024 97 | 98 | b.seek(0) 99 | header = b.read(8) 100 | n = int.from_bytes(header, "little") 101 | 102 | offset = n + 8 103 | b.seek(offset) 104 | for chunk in iter(lambda: b.read(blksize), b""): 105 | hash_sha256.update(chunk) 106 | 107 | return hash_sha256.hexdigest() 108 | 109 | -------------------------------------------------------------------------------- /modules/deepbooru.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from modules import modelloader, paths, deepbooru_model, devices, images, shared 8 | 9 | re_special = re.compile(r'([\\()])') 10 | 11 | 12 | class DeepDanbooru: 13 | def __init__(self): 14 | self.model = None 15 | 16 | def load(self): 17 | if self.model is not None: 18 | return 19 | 20 | files = modelloader.load_models( 21 | model_path=os.path.join(paths.models_path, "torch_deepdanbooru"), 22 | model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', 23 | ext_filter=[".pt"], 24 | download_name='model-resnet_custom_v3.pt', 25 | ) 26 | 27 | self.model = deepbooru_model.DeepDanbooruModel() 28 | self.model.load_state_dict(torch.load(files[0], map_location="cpu")) 29 | 30 | self.model.eval() 31 | self.model.to(devices.cpu, devices.dtype) 32 | 33 | def start(self): 34 | self.load() 35 | self.model.to(devices.device) 36 | 37 | def stop(self): 38 | if not shared.opts.interrogate_keep_models_in_memory: 39 | self.model.to(devices.cpu) 40 | devices.torch_gc() 41 | 42 | def tag(self, pil_image): 43 | self.start() 44 | res = self.tag_multi(pil_image) 45 | self.stop() 46 | 47 | return res 48 | 49 | def tag_multi(self, pil_image, force_disable_ranks=False): 50 | threshold = shared.opts.interrogate_deepbooru_score_threshold 51 | use_spaces = shared.opts.deepbooru_use_spaces 52 | use_escape = shared.opts.deepbooru_escape 53 | alpha_sort = 
shared.opts.deepbooru_sort_alpha 54 | include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks 55 | 56 | pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512) 57 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 58 | 59 | with torch.no_grad(), devices.autocast(): 60 | x = torch.from_numpy(a).to(devices.device) 61 | y = self.model(x)[0].detach().cpu().numpy() 62 | 63 | probability_dict = {} 64 | 65 | for tag, probability in zip(self.model.tags, y): 66 | if probability < threshold: 67 | continue 68 | 69 | if tag.startswith("rating:"): 70 | continue 71 | 72 | probability_dict[tag] = probability 73 | 74 | if alpha_sort: 75 | tags = sorted(probability_dict) 76 | else: 77 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] 78 | 79 | res = [] 80 | 81 | filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")} 82 | 83 | for tag in [x for x in tags if x not in filtertags]: 84 | probability = probability_dict[tag] 85 | tag_outformat = tag 86 | if use_spaces: 87 | tag_outformat = tag_outformat.replace('_', ' ') 88 | if use_escape: 89 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) 90 | if include_ranks: 91 | tag_outformat = f"({tag_outformat}:{probability:.3f})" 92 | 93 | res.append(tag_outformat) 94 | 95 | return ", ".join(res) 96 | 97 | 98 | model = DeepDanbooru() 99 | -------------------------------------------------------------------------------- /.eslintrc.js: -------------------------------------------------------------------------------- 1 | /* global module */ 2 | module.exports = { 3 | env: { 4 | browser: true, 5 | es2021: true, 6 | }, 7 | extends: "eslint:recommended", 8 | parserOptions: { 9 | ecmaVersion: "latest", 10 | }, 11 | rules: { 12 | "arrow-spacing": "error", 13 | "block-spacing": "error", 14 | "brace-style": "error", 15 | "comma-dangle": ["error", "only-multiline"], 16 | "comma-spacing": "error", 17 | "comma-style": ["error", "last"], 18 | "curly": ["error", "multi-line", "consistent"], 19 | "eol-last": "error", 20 | "func-call-spacing": "error", 21 | "function-call-argument-newline": ["error", "consistent"], 22 | "function-paren-newline": ["error", "consistent"], 23 | "indent": ["error", 4], 24 | "key-spacing": "error", 25 | "keyword-spacing": "error", 26 | "linebreak-style": ["error", "unix"], 27 | "no-extra-semi": "error", 28 | "no-mixed-spaces-and-tabs": "error", 29 | "no-multi-spaces": "error", 30 | "no-redeclare": ["error", {builtinGlobals: false}], 31 | "no-trailing-spaces": "error", 32 | "no-unused-vars": "off", 33 | "no-whitespace-before-property": "error", 34 | "object-curly-newline": ["error", {consistent: true, multiline: true}], 35 | "object-curly-spacing": ["error", "never"], 36 | "operator-linebreak": ["error", "after"], 37 | "quote-props": ["error", "consistent-as-needed"], 38 | "semi": ["error", "always"], 39 | "semi-spacing": "error", 40 | "semi-style": ["error", "last"], 41 | "space-before-blocks": "error", 42 | "space-before-function-paren": ["error", "never"], 43 | "space-in-parens": ["error", "never"], 44 | "space-infix-ops": "error", 45 | "space-unary-ops": "error", 46 | "switch-colon-spacing": "error", 47 | "template-curly-spacing": ["error", "never"], 48 | "unicode-bom": "error", 49 | }, 50 | globals: { 51 | //script.js 52 | gradioApp: "readonly", 53 | onUiLoaded: "readonly", 54 | onUiUpdate: "readonly", 55 | onOptionsChanged: "readonly", 56 | uiCurrentTab: "writable", 57 | uiElementIsVisible: "readonly", 58 | uiElementInSight: 
"readonly", 59 | executeCallbacks: "readonly", 60 | //ui.js 61 | opts: "writable", 62 | all_gallery_buttons: "readonly", 63 | selected_gallery_button: "readonly", 64 | selected_gallery_index: "readonly", 65 | switch_to_txt2img: "readonly", 66 | switch_to_img2img_tab: "readonly", 67 | switch_to_img2img: "readonly", 68 | switch_to_sketch: "readonly", 69 | switch_to_inpaint: "readonly", 70 | switch_to_inpaint_sketch: "readonly", 71 | switch_to_extras: "readonly", 72 | get_tab_index: "readonly", 73 | create_submit_args: "readonly", 74 | restart_reload: "readonly", 75 | updateInput: "readonly", 76 | //extraNetworks.js 77 | requestGet: "readonly", 78 | popup: "readonly", 79 | // from python 80 | localization: "readonly", 81 | // progrssbar.js 82 | randomId: "readonly", 83 | requestProgress: "readonly", 84 | // imageviewer.js 85 | modalPrevImage: "readonly", 86 | modalNextImage: "readonly", 87 | } 88 | }; 89 | -------------------------------------------------------------------------------- /test/test_txt2img.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import requests 4 | 5 | 6 | @pytest.fixture() 7 | def url_txt2img(base_url): 8 | return f"{base_url}/sdapi/v1/txt2img" 9 | 10 | 11 | @pytest.fixture() 12 | def simple_txt2img_request(): 13 | return { 14 | "batch_size": 1, 15 | "cfg_scale": 7, 16 | "denoising_strength": 0, 17 | "enable_hr": False, 18 | "eta": 0, 19 | "firstphase_height": 0, 20 | "firstphase_width": 0, 21 | "height": 64, 22 | "n_iter": 1, 23 | "negative_prompt": "", 24 | "prompt": "example prompt", 25 | "restore_faces": False, 26 | "s_churn": 0, 27 | "s_noise": 1, 28 | "s_tmax": 0, 29 | "s_tmin": 0, 30 | "sampler_index": "Euler a", 31 | "seed": -1, 32 | "seed_resize_from_h": -1, 33 | "seed_resize_from_w": -1, 34 | "steps": 3, 35 | "styles": [], 36 | "subseed": -1, 37 | "subseed_strength": 0, 38 | "tiling": False, 39 | "width": 64, 40 | } 41 | 42 | 43 | def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request): 44 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 45 | 46 | 47 | def test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request): 48 | simple_txt2img_request["negative_prompt"] = "example negative prompt" 49 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 50 | 51 | 52 | def test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request): 53 | simple_txt2img_request["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" 54 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 55 | 56 | 57 | def test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request): 58 | simple_txt2img_request["height"] = 128 59 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 60 | 61 | 62 | def test_txt2img_with_hrfix_performed(url_txt2img, simple_txt2img_request): 63 | simple_txt2img_request["enable_hr"] = True 64 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 65 | 66 | 67 | def test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request): 68 | simple_txt2img_request["tiling"] = True 69 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 70 | 71 | 72 | def test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request): 73 | simple_txt2img_request["restore_faces"] = True 74 | assert requests.post(url_txt2img, 
json=simple_txt2img_request).status_code == 200 75 | 76 | 77 | @pytest.mark.parametrize("sampler", ["PLMS", "DDIM", "UniPC"]) 78 | def test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler): 79 | simple_txt2img_request["sampler_index"] = sampler 80 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 81 | 82 | 83 | def test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request): 84 | simple_txt2img_request["n_iter"] = 2 85 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 86 | 87 | 88 | def test_txt2img_batch_performed(url_txt2img, simple_txt2img_request): 89 | simple_txt2img_request["batch_size"] = 2 90 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 91 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | function gradioApp() { 2 | const elems = document.getElementsByTagName('gradio-app'); 3 | const elem = elems.length == 0 ? document : elems[0]; 4 | 5 | if (elem !== document) { 6 | elem.getElementById = function(id) { 7 | return document.getElementById(id); 8 | }; 9 | } 10 | return elem.shadowRoot ? elem.shadowRoot : elem; 11 | } 12 | 13 | function get_uiCurrentTab() { 14 | return gradioApp().querySelector('#tabs button.selected'); 15 | } 16 | 17 | function get_uiCurrentTabContent() { 18 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])'); 19 | } 20 | 21 | var uiUpdateCallbacks = []; 22 | var uiLoadedCallbacks = []; 23 | var uiTabChangeCallbacks = []; 24 | var optionsChangedCallbacks = []; 25 | var uiCurrentTab = null; 26 | 27 | function onUiUpdate(callback) { 28 | uiUpdateCallbacks.push(callback); 29 | } 30 | function onUiLoaded(callback) { 31 | uiLoadedCallbacks.push(callback); 32 | } 33 | function onUiTabChange(callback) { 34 | uiTabChangeCallbacks.push(callback); 35 | } 36 | function onOptionsChanged(callback) { 37 | optionsChangedCallbacks.push(callback); 38 | } 39 | 40 | function runCallback(x, m) { 41 | try { 42 | x(m); 43 | } catch (e) { 44 | (console.error || console.log).call(console, e.message, e); 45 | } 46 | } 47 | function executeCallbacks(queue, m) { 48 | queue.forEach(function(x) { 49 | runCallback(x, m); 50 | }); 51 | } 52 | 53 | var executedOnLoaded = false; 54 | 55 | document.addEventListener("DOMContentLoaded", function() { 56 | var mutationObserver = new MutationObserver(function(m) { 57 | if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) { 58 | executedOnLoaded = true; 59 | executeCallbacks(uiLoadedCallbacks); 60 | } 61 | 62 | executeCallbacks(uiUpdateCallbacks, m); 63 | const newTab = get_uiCurrentTab(); 64 | if (newTab && (newTab !== uiCurrentTab)) { 65 | uiCurrentTab = newTab; 66 | executeCallbacks(uiTabChangeCallbacks); 67 | } 68 | }); 69 | mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); 70 | }); 71 | 72 | /** 73 | * Add a ctrl+enter as a shortcut to start a generation 74 | */ 75 | document.addEventListener('keydown', function(e) { 76 | var handled = false; 77 | if (e.key !== undefined) { 78 | if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 79 | } else if (e.keyCode !== undefined) { 80 | if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 81 | } 82 | if (handled) { 83 | var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); 84 
| if (button) { 85 | button.click(); 86 | } 87 | e.preventDefault(); 88 | } 89 | }); 90 | 91 | /** 92 | * checks that a UI element is not in another hidden element or tab content 93 | */ 94 | function uiElementIsVisible(el) { 95 | if (el === document) { 96 | return true; 97 | } 98 | 99 | const computedStyle = getComputedStyle(el); 100 | const isVisible = computedStyle.display !== 'none'; 101 | 102 | if (!isVisible) return false; 103 | return uiElementIsVisible(el.parentNode); 104 | } 105 | 106 | function uiElementInSight(el) { 107 | const clRect = el.getBoundingClientRect(); 108 | const windowHeight = window.innerHeight; 109 | const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight; 110 | 111 | return isOnScreen; 112 | } 113 | -------------------------------------------------------------------------------- /extensions-builtin/LDSR/scripts/ldsr_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | from basicsr.utils.download_util import load_file_from_url 6 | 7 | from modules.upscaler import Upscaler, UpscalerData 8 | from ldsr_model_arch import LDSR 9 | from modules import shared, script_callbacks 10 | import sd_hijack_autoencoder # noqa: F401 11 | import sd_hijack_ddpm_v1 # noqa: F401 12 | 13 | 14 | class UpscalerLDSR(Upscaler): 15 | def __init__(self, user_path): 16 | self.name = "LDSR" 17 | self.user_path = user_path 18 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" 19 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" 20 | super().__init__() 21 | scaler_data = UpscalerData("LDSR", None, self) 22 | self.scalers = [scaler_data] 23 | 24 | def load_model(self, path: str): 25 | # Remove incorrect project.yaml file if too big 26 | yaml_path = os.path.join(self.model_path, "project.yaml") 27 | old_model_path = os.path.join(self.model_path, "model.pth") 28 | new_model_path = os.path.join(self.model_path, "model.ckpt") 29 | 30 | local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"]) 31 | local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None) 32 | local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None) 33 | local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None) 34 | 35 | if os.path.exists(yaml_path): 36 | statinfo = os.stat(yaml_path) 37 | if statinfo.st_size >= 10485760: 38 | print("Removing invalid LDSR YAML file.") 39 | os.remove(yaml_path) 40 | 41 | if os.path.exists(old_model_path): 42 | print("Renaming model from model.pth to model.ckpt") 43 | os.rename(old_model_path, new_model_path) 44 | 45 | if local_safetensors_path is not None and os.path.exists(local_safetensors_path): 46 | model = local_safetensors_path 47 | else: 48 | model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) 49 | 50 | yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) 51 | 52 | try: 53 | return LDSR(model, yaml) 54 | 55 | except Exception: 56 | print("Error importing LDSR:", file=sys.stderr) 57 | print(traceback.format_exc(), file=sys.stderr) 58 | return None 59 | 60 | 
def do_upscale(self, img, path): 61 | ldsr = self.load_model(path) 62 | if ldsr is None: 63 | print("NO LDSR!") 64 | return img 65 | ddim_steps = shared.opts.ldsr_steps 66 | return ldsr.super_resolution(img, ddim_steps, self.scale) 67 | 68 | 69 | def on_ui_settings(): 70 | import gradio as gr 71 | 72 | shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) 73 | shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) 74 | 75 | 76 | script_callbacks.on_ui_settings(on_ui_settings) 77 | -------------------------------------------------------------------------------- /javascript/dragdrop.js: -------------------------------------------------------------------------------- 1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard 2 | 3 | function isValidImageList(files) { 4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type); 5 | } 6 | 7 | function dropReplaceImage(imgWrap, files) { 8 | if (!isValidImageList(files)) { 9 | return; 10 | } 11 | 12 | const tmpFile = files[0]; 13 | 14 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); 15 | const callback = () => { 16 | const fileInput = imgWrap.querySelector('input[type="file"]'); 17 | if (fileInput) { 18 | if (files.length === 0) { 19 | files = new DataTransfer(); 20 | files.items.add(tmpFile); 21 | fileInput.files = files.files; 22 | } else { 23 | fileInput.files = files; 24 | } 25 | fileInput.dispatchEvent(new Event('change')); 26 | } 27 | }; 28 | 29 | if (imgWrap.closest('#pnginfo_image')) { 30 | // special treatment for PNG Info tab, wait for fetch request to finish 31 | const oldFetch = window.fetch; 32 | window.fetch = async(input, options) => { 33 | const response = await oldFetch(input, options); 34 | if ('api/predict/' === input) { 35 | const content = await response.text(); 36 | window.fetch = oldFetch; 37 | window.requestAnimationFrame(() => callback()); 38 | return new Response(content, { 39 | status: response.status, 40 | statusText: response.statusText, 41 | headers: response.headers 42 | }); 43 | } 44 | return response; 45 | }; 46 | } else { 47 | window.requestAnimationFrame(() => callback()); 48 | } 49 | } 50 | 51 | window.document.addEventListener('dragover', e => { 52 | const target = e.composedPath()[0]; 53 | const imgWrap = target.closest('[data-testid="image"]'); 54 | if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { 55 | return; 56 | } 57 | e.stopPropagation(); 58 | e.preventDefault(); 59 | e.dataTransfer.dropEffect = 'copy'; 60 | }); 61 | 62 | window.document.addEventListener('drop', e => { 63 | const target = e.composedPath()[0]; 64 | if (target.placeholder.indexOf("Prompt") == -1) { 65 | return; 66 | } 67 | const imgWrap = target.closest('[data-testid="image"]'); 68 | if (!imgWrap) { 69 | return; 70 | } 71 | e.stopPropagation(); 72 | e.preventDefault(); 73 | const files = e.dataTransfer.files; 74 | dropReplaceImage(imgWrap, files); 75 | }); 76 | 77 | window.addEventListener('paste', e => { 78 | const files = e.clipboardData.files; 79 | if (!isValidImageList(files)) { 80 | return; 81 | } 82 | 83 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')] 84 | .filter(el => 
uiElementIsVisible(el)) 85 | .sort((a, b) => uiElementInSight(b) - uiElementInSight(a)); 86 | 87 | 88 | if (!visibleImageFields.length) { 89 | return; 90 | } 91 | 92 | const firstFreeImageField = visibleImageFields 93 | .filter(el => el.querySelector('input[type=file]'))?.[0]; 94 | 95 | dropReplaceImage( 96 | firstFreeImageField ? 97 | firstFreeImageField : 98 | visibleImageFields[visibleImageFields.length - 1] 99 | , files 100 | ); 101 | }); 102 | -------------------------------------------------------------------------------- /modules/sd_samplers_common.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd 6 | 7 | from modules.shared import opts, state 8 | import modules.shared as shared 9 | 10 | SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) 11 | 12 | 13 | def setup_img2img_steps(p, steps=None): 14 | if opts.img2img_fix_steps or steps is not None: 15 | requested_steps = (steps or p.steps) 16 | steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 17 | t_enc = requested_steps - 1 18 | else: 19 | steps = p.steps 20 | t_enc = int(min(p.denoising_strength, 0.999) * steps) 21 | 22 | return steps, t_enc 23 | 24 | 25 | approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} 26 | 27 | 28 | def single_sample_to_image(sample, approximation=None): 29 | if approximation is None: 30 | approximation = approximation_indexes.get(opts.show_progress_type, 0) 31 | 32 | if approximation == 2: 33 | x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 34 | elif approximation == 1: 35 | x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 36 | elif approximation == 3: 37 | x_sample = sample * 1.5 38 | x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() 39 | else: 40 | x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 41 | 42 | x_sample = torch.clamp(x_sample, min=0.0, max=1.0) 43 | x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) 44 | x_sample = x_sample.astype(np.uint8) 45 | 46 | return Image.fromarray(x_sample) 47 | 48 | 49 | def sample_to_image(samples, index=0, approximation=None): 50 | return single_sample_to_image(samples[index], approximation) 51 | 52 | 53 | def samples_to_image_grid(samples, approximation=None): 54 | return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) 55 | 56 | 57 | def store_latent(decoded): 58 | state.current_latent = decoded 59 | 60 | if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: 61 | if not shared.parallel_processing_allowed: 62 | shared.state.assign_current_image(sample_to_image(decoded)) 63 | 64 | 65 | def is_sampler_using_eta_noise_seed_delta(p): 66 | """returns whether sampler from config will use eta noise seed delta for image creation""" 67 | 68 | sampler_config = sd_samplers.find_sampler_config(p.sampler_name) 69 | 70 | eta = p.eta 71 | 72 | if eta is None and p.sampler is not None: 73 | eta = p.sampler.eta 74 | 75 | if eta is None and sampler_config is not None: 76 | eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0 77 | 78 | if eta == 0: 79 | return False 80 | 81 | return sampler_config.options.get("uses_ensd", False) 82 | 83 | 84 | class InterruptedException(BaseException): 85 | pass 86 | 87 | 88 | if opts.randn_source == "CPU": 89 | import torchsde._brownian.brownian_interval 90 | 91 | def torchsde_randn(size, dtype, device, seed): 92 | generator = torch.Generator(devices.cpu).manual_seed(int(seed)) 93 | return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) 94 | 95 | torchsde._brownian.brownian_interval._randn = torchsde_randn 96 | -------------------------------------------------------------------------------- /modules/masking.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageOps 2 | 3 | 4 | def get_crop_region(mask, pad=0): 5 | """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. 
6 | For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" 7 | 8 | h, w = mask.shape 9 | 10 | crop_left = 0 11 | for i in range(w): 12 | if not (mask[:, i] == 0).all(): 13 | break 14 | crop_left += 1 15 | 16 | crop_right = 0 17 | for i in reversed(range(w)): 18 | if not (mask[:, i] == 0).all(): 19 | break 20 | crop_right += 1 21 | 22 | crop_top = 0 23 | for i in range(h): 24 | if not (mask[i] == 0).all(): 25 | break 26 | crop_top += 1 27 | 28 | crop_bottom = 0 29 | for i in reversed(range(h)): 30 | if not (mask[i] == 0).all(): 31 | break 32 | crop_bottom += 1 33 | 34 | return ( 35 | int(max(crop_left-pad, 0)), 36 | int(max(crop_top-pad, 0)), 37 | int(min(w - crop_right + pad, w)), 38 | int(min(h - crop_bottom + pad, h)) 39 | ) 40 | 41 | 42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): 43 | """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region 44 | for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.""" 45 | 46 | x1, y1, x2, y2 = crop_region 47 | 48 | ratio_crop_region = (x2 - x1) / (y2 - y1) 49 | ratio_processing = processing_width / processing_height 50 | 51 | if ratio_crop_region > ratio_processing: 52 | desired_height = (x2 - x1) / ratio_processing 53 | desired_height_diff = int(desired_height - (y2-y1)) 54 | y1 -= desired_height_diff//2 55 | y2 += desired_height_diff - desired_height_diff//2 56 | if y2 >= image_height: 57 | diff = y2 - image_height 58 | y2 -= diff 59 | y1 -= diff 60 | if y1 < 0: 61 | y2 -= y1 62 | y1 -= y1 63 | if y2 >= image_height: 64 | y2 = image_height 65 | else: 66 | desired_width = (y2 - y1) * ratio_processing 67 | desired_width_diff = int(desired_width - (x2-x1)) 68 | x1 -= desired_width_diff//2 69 | x2 += desired_width_diff - desired_width_diff//2 70 | if x2 >= image_width: 71 | diff = x2 - image_width 72 | x2 -= diff 73 | x1 -= diff 74 | if x1 < 0: 75 | x2 -= x1 76 | x1 -= x1 77 | if x2 >= image_width: 78 | x2 = image_width 79 | 80 | return x1, y1, x2, y2 81 | 82 | 83 | def fill(image, mask): 84 | """fills masked regions with colors from image using blur. 
Not extremely effective.""" 85 | 86 | image_mod = Image.new('RGBA', (image.width, image.height)) 87 | 88 | image_masked = Image.new('RGBa', (image.width, image.height)) 89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L'))) 90 | 91 | image_masked = image_masked.convert('RGBa') 92 | 93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: 94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') 95 | for _ in range(repeats): 96 | image_mod.alpha_composite(blurred) 97 | 98 | return image_mod.convert("RGB") 99 | 100 | -------------------------------------------------------------------------------- /modules/sd_hijack_clip_old.py: -------------------------------------------------------------------------------- 1 | from modules import sd_hijack_clip 2 | from modules import shared 3 | 4 | 5 | def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): 6 | id_start = self.id_start 7 | id_end = self.id_end 8 | maxlen = self.wrapped.max_length # you get to stay at 77 9 | used_custom_terms = [] 10 | remade_batch_tokens = [] 11 | hijack_comments = [] 12 | hijack_fixes = [] 13 | token_count = 0 14 | 15 | cache = {} 16 | batch_tokens = self.tokenize(texts) 17 | batch_multipliers = [] 18 | for tokens in batch_tokens: 19 | tuple_tokens = tuple(tokens) 20 | 21 | if tuple_tokens in cache: 22 | remade_tokens, fixes, multipliers = cache[tuple_tokens] 23 | else: 24 | fixes = [] 25 | remade_tokens = [] 26 | multipliers = [] 27 | mult = 1.0 28 | 29 | i = 0 30 | while i < len(tokens): 31 | token = tokens[i] 32 | 33 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) 34 | 35 | mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None 36 | if mult_change is not None: 37 | mult *= mult_change 38 | i += 1 39 | elif embedding is None: 40 | remade_tokens.append(token) 41 | multipliers.append(mult) 42 | i += 1 43 | else: 44 | emb_len = int(embedding.vec.shape[0]) 45 | fixes.append((len(remade_tokens), embedding)) 46 | remade_tokens += [0] * emb_len 47 | multipliers += [mult] * emb_len 48 | used_custom_terms.append((embedding.name, embedding.checksum())) 49 | i += embedding_length_in_tokens 50 | 51 | if len(remade_tokens) > maxlen - 2: 52 | vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} 53 | ovf = remade_tokens[maxlen - 2:] 54 | overflowing_words = [vocab.get(int(x), "") for x in ovf] 55 | overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) 56 | hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") 57 | 58 | token_count = len(remade_tokens) 59 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) 60 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] 61 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers) 62 | 63 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) 64 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] 65 | 66 | remade_batch_tokens.append(remade_tokens) 67 | hijack_fixes.append(fixes) 68 | batch_multipliers.append(multipliers) 69 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count 70 | 71 | 72 | def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): 73 | batch_multipliers, remade_batch_tokens, 
used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) 74 | 75 | self.hijack.comments += hijack_comments 76 | 77 | if len(used_custom_terms) > 0: 78 | embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) 79 | self.hijack.comments.append(f"Used embeddings: {embedding_names}") 80 | 81 | self.hijack.fixes = hijack_fixes 82 | return self.process_tokens(remade_batch_tokens, batch_multipliers) 83 | -------------------------------------------------------------------------------- /modules/call_queue.py: -------------------------------------------------------------------------------- 1 | import html 2 | import sys 3 | import threading 4 | import traceback 5 | import time 6 | 7 | from modules import shared, progress 8 | 9 | queue_lock = threading.Lock() 10 | 11 | 12 | def wrap_queued_call(func): 13 | def f(*args, **kwargs): 14 | with queue_lock: 15 | res = func(*args, **kwargs) 16 | 17 | return res 18 | 19 | return f 20 | 21 | 22 | def wrap_gradio_gpu_call(func, extra_outputs=None): 23 | def f(*args, **kwargs): 24 | 25 | # if the first argument is a string that says "task(...)", it is treated as a job id 26 | if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": 27 | id_task = args[0] 28 | progress.add_task_to_queue(id_task) 29 | else: 30 | id_task = None 31 | 32 | with queue_lock: 33 | shared.state.begin() 34 | progress.start_task(id_task) 35 | 36 | try: 37 | res = func(*args, **kwargs) 38 | progress.record_results(id_task, res) 39 | finally: 40 | progress.finish_task(id_task) 41 | 42 | shared.state.end() 43 | 44 | return res 45 | 46 | return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) 47 | 48 | 49 | def wrap_gradio_call(func, extra_outputs=None, add_stats=False): 50 | def f(*args, extra_outputs_array=extra_outputs, **kwargs): 51 | run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats 52 | if run_memmon: 53 | shared.mem_mon.monitor() 54 | t = time.perf_counter() 55 | 56 | try: 57 | res = list(func(*args, **kwargs)) 58 | except Exception as e: 59 | # When printing out our debug argument list, do not print out more than a MB of text 60 | max_debug_str_len = 131072 # (1024*1024)/8 61 | 62 | print("Error completing request", file=sys.stderr) 63 | argStr = f"Arguments: {args} {kwargs}" 64 | print(argStr[:max_debug_str_len], file=sys.stderr) 65 | if len(argStr) > max_debug_str_len: 66 | print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) 67 | 68 | print(traceback.format_exc(), file=sys.stderr) 69 | 70 | shared.state.job = "" 71 | shared.state.job_count = 0 72 | 73 | if extra_outputs_array is None: 74 | extra_outputs_array = [None, ''] 75 | 76 | error_message = f'{type(e).__name__}: {e}' 77 | res = extra_outputs_array + [f"
<div class='error'>{html.escape(error_message)}</div>
"] 78 | 79 | shared.state.skipped = False 80 | shared.state.interrupted = False 81 | shared.state.job_count = 0 82 | 83 | if not add_stats: 84 | return tuple(res) 85 | 86 | elapsed = time.perf_counter() - t 87 | elapsed_m = int(elapsed // 60) 88 | elapsed_s = elapsed % 60 89 | elapsed_text = f"{elapsed_s:.2f}s" 90 | if elapsed_m > 0: 91 | elapsed_text = f"{elapsed_m}m "+elapsed_text 92 | 93 | if run_memmon: 94 | mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} 95 | active_peak = mem_stats['active_peak'] 96 | reserved_peak = mem_stats['reserved_peak'] 97 | sys_peak = mem_stats['system_peak'] 98 | sys_total = mem_stats['total'] 99 | sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) 100 | 101 | vram_html = f"
<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>
" 102 | else: 103 | vram_html = '' 104 | 105 | # last item is always HTML 106 | res[-1] += f"
<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>
" 107 | 108 | return tuple(res) 109 | 110 | return f 111 | 112 | -------------------------------------------------------------------------------- /modules/sd_hijack_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from packaging import version 3 | 4 | from modules import devices 5 | from modules.sd_hijack_utils import CondFunc 6 | 7 | 8 | class TorchHijackForUnet: 9 | """ 10 | This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; 11 | this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 12 | """ 13 | 14 | def __getattr__(self, item): 15 | if item == 'cat': 16 | return self.cat 17 | 18 | if hasattr(torch, item): 19 | return getattr(torch, item) 20 | 21 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") 22 | 23 | def cat(self, tensors, *args, **kwargs): 24 | if len(tensors) == 2: 25 | a, b = tensors 26 | if a.shape[-2:] != b.shape[-2:]: 27 | a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") 28 | 29 | tensors = (a, b) 30 | 31 | return torch.cat(tensors, *args, **kwargs) 32 | 33 | 34 | th = TorchHijackForUnet() 35 | 36 | 37 | # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling 38 | def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): 39 | 40 | if isinstance(cond, dict): 41 | for y in cond.keys(): 42 | cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] 43 | 44 | with devices.autocast(): 45 | return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() 46 | 47 | 48 | class GELUHijack(torch.nn.GELU, torch.nn.Module): 49 | def __init__(self, *args, **kwargs): 50 | torch.nn.GELU.__init__(self, *args, **kwargs) 51 | def forward(self, x): 52 | if devices.unet_needs_upcast: 53 | return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) 54 | else: 55 | return torch.nn.GELU.forward(self, x) 56 | 57 | 58 | ddpm_edit_hijack = None 59 | def hijack_ddpm_edit(): 60 | global ddpm_edit_hijack 61 | if not ddpm_edit_hijack: 62 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) 63 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) 64 | ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) 65 | 66 | 67 | unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast 68 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) 69 | CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) 70 | if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): 71 | CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) 72 | CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) 73 | CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: 
kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) 74 | 75 | first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 76 | first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) 77 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) 78 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) 79 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) 80 | -------------------------------------------------------------------------------- /modules/models/diffusion/uni_pc/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC 6 | from modules import shared, devices 7 | 8 | 9 | class UniPCSampler(object): 10 | def __init__(self, model, **kwargs): 11 | super().__init__() 12 | self.model = model 13 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 14 | self.before_sample = None 15 | self.after_sample = None 16 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != devices.device: 21 | attr = attr.to(devices.device) 22 | setattr(self, name, attr) 23 | 24 | def set_hooks(self, before_sample, after_sample, after_update): 25 | self.before_sample = before_sample 26 | self.after_sample = after_sample 27 | self.after_update = after_update 28 | 29 | @torch.no_grad() 30 | def sample(self, 31 | S, 32 | batch_size, 33 | shape, 34 | conditioning=None, 35 | callback=None, 36 | normals_sequence=None, 37 | img_callback=None, 38 | quantize_x0=False, 39 | eta=0., 40 | mask=None, 41 | x0=None, 42 | temperature=1., 43 | noise_dropout=0., 44 | score_corrector=None, 45 | corrector_kwargs=None, 46 | verbose=True, 47 | x_T=None, 48 | log_every_t=100, 49 | unconditional_guidance_scale=1., 50 | unconditional_conditioning=None, 51 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
52 | **kwargs 53 | ): 54 | if conditioning is not None: 55 | if isinstance(conditioning, dict): 56 | ctmp = conditioning[list(conditioning.keys())[0]] 57 | while isinstance(ctmp, list): 58 | ctmp = ctmp[0] 59 | cbs = ctmp.shape[0] 60 | if cbs != batch_size: 61 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 62 | 63 | elif isinstance(conditioning, list): 64 | for ctmp in conditioning: 65 | if ctmp.shape[0] != batch_size: 66 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 67 | 68 | else: 69 | if conditioning.shape[0] != batch_size: 70 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 71 | 72 | # sampling 73 | C, H, W = shape 74 | size = (batch_size, C, H, W) 75 | # print(f'Data shape for UniPC sampling is {size}') 76 | 77 | device = self.model.betas.device 78 | if x_T is None: 79 | img = torch.randn(size, device=device) 80 | else: 81 | img = x_T 82 | 83 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 84 | 85 | # SD 1.X is "noise", SD 2.X is "v" 86 | model_type = "v" if self.model.parameterization == "v" else "noise" 87 | 88 | model_fn = model_wrapper( 89 | lambda x, t, c: self.model.apply_model(x, t, c), 90 | ns, 91 | model_type=model_type, 92 | guidance_type="classifier-free", 93 | #condition=conditioning, 94 | #unconditional_condition=unconditional_conditioning, 95 | guidance_scale=unconditional_guidance_scale, 96 | ) 97 | 98 | uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) 99 | x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) 100 | 101 | return x.to(device), None 102 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import facexlib 6 | import gfpgan 7 | 8 | import modules.face_restoration 9 | from modules import paths, shared, devices, modelloader 10 | 11 | model_dir = "GFPGAN" 12 | user_path = None 13 | model_path = os.path.join(paths.models_path, model_dir) 14 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 15 | have_gfpgan = False 16 | loaded_gfpgan_model = None 17 | 18 | 19 | def gfpgann(): 20 | global loaded_gfpgan_model 21 | global model_path 22 | if loaded_gfpgan_model is not None: 23 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) 24 | return loaded_gfpgan_model 25 | 26 | if gfpgan_constructor is None: 27 | return None 28 | 29 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") 30 | if len(models) == 1 and "http" in models[0]: 31 | model_file = models[0] 32 | elif len(models) != 0: 33 | latest_file = max(models, key=os.path.getctime) 34 | model_file = latest_file 35 | else: 36 | print("Unable to load gfpgan model!") 37 | return None 38 | if hasattr(facexlib.detection.retinaface, 'device'): 39 | facexlib.detection.retinaface.device = devices.device_gfpgan 40 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) 41 | 
loaded_gfpgan_model = model 42 | 43 | return model 44 | 45 | 46 | def send_model_to(model, device): 47 | model.gfpgan.to(device) 48 | model.face_helper.face_det.to(device) 49 | model.face_helper.face_parse.to(device) 50 | 51 | 52 | def gfpgan_fix_faces(np_image): 53 | model = gfpgann() 54 | if model is None: 55 | return np_image 56 | 57 | send_model_to(model, devices.device_gfpgan) 58 | 59 | np_image_bgr = np_image[:, :, ::-1] 60 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) 61 | np_image = gfpgan_output_bgr[:, :, ::-1] 62 | 63 | model.face_helper.clean_all() 64 | 65 | if shared.opts.face_restoration_unload: 66 | send_model_to(model, devices.cpu) 67 | 68 | return np_image 69 | 70 | 71 | gfpgan_constructor = None 72 | 73 | 74 | def setup_model(dirname): 75 | global model_path 76 | if not os.path.exists(model_path): 77 | os.makedirs(model_path) 78 | 79 | try: 80 | from gfpgan import GFPGANer 81 | from facexlib import detection, parsing # noqa: F401 82 | global user_path 83 | global have_gfpgan 84 | global gfpgan_constructor 85 | 86 | load_file_from_url_orig = gfpgan.utils.load_file_from_url 87 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url 88 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url 89 | 90 | def my_load_file_from_url(**kwargs): 91 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) 92 | 93 | def facex_load_file_from_url(**kwargs): 94 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) 95 | 96 | def facex_load_file_from_url2(**kwargs): 97 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) 98 | 99 | gfpgan.utils.load_file_from_url = my_load_file_from_url 100 | facexlib.detection.load_file_from_url = facex_load_file_from_url 101 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 102 | user_path = dirname 103 | have_gfpgan = True 104 | gfpgan_constructor = GFPGANer 105 | 106 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): 107 | def name(self): 108 | return "GFPGAN" 109 | 110 | def restore(self, np_image): 111 | return gfpgan_fix_faces(np_image) 112 | 113 | shared.face_restorers.append(FaceRestorerGFPGAN()) 114 | except Exception: 115 | print("Error setting up GFPGAN:", file=sys.stderr) 116 | print(traceback.format_exc(), file=sys.stderr) 117 | -------------------------------------------------------------------------------- /scripts/sd_upscale.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image 6 | 7 | from modules import processing, shared, images, devices 8 | from modules.processing import Processed 9 | from modules.shared import opts, state 10 | 11 | 12 | class Script(scripts.Script): 13 | def title(self): 14 | return "SD upscale" 15 | 16 | def show(self, is_img2img): 17 | return is_img2img 18 | 19 | def ui(self, is_img2img): 20 | info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>
") 21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) 22 | scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) 23 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) 24 | 25 | return [info, overlap, upscaler_index, scale_factor] 26 | 27 | def run(self, p, _, overlap, upscaler_index, scale_factor): 28 | if isinstance(upscaler_index, str): 29 | upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) 30 | processing.fix_seed(p) 31 | upscaler = shared.sd_upscalers[upscaler_index] 32 | 33 | p.extra_generation_params["SD upscale overlap"] = overlap 34 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name 35 | 36 | initial_info = None 37 | seed = p.seed 38 | 39 | init_img = p.init_images[0] 40 | init_img = images.flatten(init_img, opts.img2img_background_color) 41 | 42 | if upscaler.name != "None": 43 | img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) 44 | else: 45 | img = init_img 46 | 47 | devices.torch_gc() 48 | 49 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) 50 | 51 | batch_size = p.batch_size 52 | upscale_count = p.n_iter 53 | p.n_iter = 1 54 | p.do_not_save_grid = True 55 | p.do_not_save_samples = True 56 | 57 | work = [] 58 | 59 | for _y, _h, row in grid.tiles: 60 | for tiledata in row: 61 | work.append(tiledata[2]) 62 | 63 | batch_count = math.ceil(len(work) / batch_size) 64 | state.job_count = batch_count * upscale_count 65 | 66 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") 67 | 68 | result_images = [] 69 | for n in range(upscale_count): 70 | start_seed = seed + n 71 | p.seed = start_seed 72 | 73 | work_results = [] 74 | for i in range(batch_count): 75 | p.batch_size = batch_size 76 | p.init_images = work[i * batch_size:(i + 1) * batch_size] 77 | 78 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" 79 | processed = processing.process_images(p) 80 | 81 | if initial_info is None: 82 | initial_info = processed.info 83 | 84 | p.seed = processed.seed + 1 85 | work_results += processed.images 86 | 87 | image_index = 0 88 | for _y, _h, row in grid.tiles: 89 | for tiledata in row: 90 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 91 | image_index += 1 92 | 93 | combined_image = images.combine_grid(grid) 94 | result_images.append(combined_image) 95 | 96 | if opts.samples_save: 97 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) 98 | 99 | processed = Processed(p, result_images, seed, initial_info) 100 | 101 | return processed 102 | -------------------------------------------------------------------------------- /javascript/aspectRatioOverlay.js: -------------------------------------------------------------------------------- 1 | 2 | let currentWidth = null; 3 | let currentHeight = null; 4 | let arFrameTimeout = setTimeout(function() {}, 0); 5 | 6 | function dimensionChange(e, is_width, is_height) { 7 | 8 | if (is_width) { 9 | currentWidth = e.target.value * 1.0; 10 | } 11 | if (is_height) { 12 | currentHeight = 
e.target.value * 1.0; 13 | } 14 | 15 | var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block"; 16 | 17 | if (!inImg2img) { 18 | return; 19 | } 20 | 21 | var targetElement = null; 22 | 23 | var tabIndex = get_tab_index('mode_img2img'); 24 | if (tabIndex == 0) { // img2img 25 | targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); 26 | } else if (tabIndex == 1) { //Sketch 27 | targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); 28 | } else if (tabIndex == 2) { // Inpaint 29 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); 30 | } else if (tabIndex == 3) { // Inpaint sketch 31 | targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); 32 | } 33 | 34 | 35 | if (targetElement) { 36 | 37 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 38 | if (!arPreviewRect) { 39 | arPreviewRect = document.createElement('div'); 40 | arPreviewRect.id = "imageARPreview"; 41 | gradioApp().appendChild(arPreviewRect); 42 | } 43 | 44 | 45 | 46 | var viewportOffset = targetElement.getBoundingClientRect(); 47 | 48 | var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight); 49 | 50 | var scaledx = targetElement.naturalWidth * viewportscale; 51 | var scaledy = targetElement.naturalHeight * viewportscale; 52 | 53 | var cleintRectTop = (viewportOffset.top + window.scrollY); 54 | var cleintRectLeft = (viewportOffset.left + window.scrollX); 55 | var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2); 56 | var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2); 57 | 58 | var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight); 59 | var arscaledx = currentWidth * arscale; 60 | var arscaledy = currentHeight * arscale; 61 | 62 | var arRectTop = cleintRectCentreY - (arscaledy / 2); 63 | var arRectLeft = cleintRectCentreX - (arscaledx / 2); 64 | var arRectWidth = arscaledx; 65 | var arRectHeight = arscaledy; 66 | 67 | arPreviewRect.style.top = arRectTop + 'px'; 68 | arPreviewRect.style.left = arRectLeft + 'px'; 69 | arPreviewRect.style.width = arRectWidth + 'px'; 70 | arPreviewRect.style.height = arRectHeight + 'px'; 71 | 72 | clearTimeout(arFrameTimeout); 73 | arFrameTimeout = setTimeout(function() { 74 | arPreviewRect.style.display = 'none'; 75 | }, 2000); 76 | 77 | arPreviewRect.style.display = 'block'; 78 | 79 | } 80 | 81 | } 82 | 83 | 84 | onUiUpdate(function() { 85 | var arPreviewRect = gradioApp().querySelector('#imageARPreview'); 86 | if (arPreviewRect) { 87 | arPreviewRect.style.display = 'none'; 88 | } 89 | var tabImg2img = gradioApp().querySelector("#tab_img2img"); 90 | if (tabImg2img) { 91 | var inImg2img = tabImg2img.style.display == "block"; 92 | if (inImg2img) { 93 | let inputs = gradioApp().querySelectorAll('input'); 94 | inputs.forEach(function(e) { 95 | var is_width = e.parentElement.id == "img2img_width"; 96 | var is_height = e.parentElement.id == "img2img_height"; 97 | 98 | if ((is_width || is_height) && !e.classList.contains('scrollwatch')) { 99 | e.addEventListener('input', function(e) { 100 | dimensionChange(e, is_width, is_height); 101 | }); 102 | e.classList.add('scrollwatch'); 103 | } 104 | if (is_width) { 105 | currentWidth = e.value * 1.0; 106 | } 107 | if (is_height) { 108 | currentHeight = e.value * 1.0; 109 | } 110 | }); 111 | } 112 | } 113 | }); 114 | 
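The overlay logic above amounts to two letterbox fits: first scale the natural image into its on-screen box, then scale the requested width/height into that displayed area and centre the result. A minimal, framework-free Python sketch of the same math follows; the helper name and return convention are illustrative only and are not part of this repository.

def ar_preview_rect(natural_w, natural_h, client_w, client_h, target_w, target_h):
    # scale at which the natural image is displayed (letterboxed fit)
    viewport_scale = min(client_w / natural_w, client_h / natural_h)
    scaled_w = natural_w * viewport_scale
    scaled_h = natural_h * viewport_scale

    # fit the requested generation size into the displayed image
    ar_scale = min(scaled_w / target_w, scaled_h / target_h)
    rect_w = target_w * ar_scale
    rect_h = target_h * ar_scale

    # centre the rectangle within the element's box; the JS above additionally
    # adds the element's page offset and the window scroll position
    left = (client_w - rect_w) / 2
    top = (client_h - rect_h) / 2
    return left, top, rect_w, rect_h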
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: You think something is broken in the UI 3 | title: "[Bug]: " 4 | labels: ["bug-report"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible; don't forget to fill in "What OS..." and "What browsers", and provide screenshots if possible.* 18 | - type: textarea 19 | id: what-did 20 | attributes: 21 | label: What happened? 22 | description: Tell us what happened in a very clear and simple way 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: steps 27 | attributes: 28 | label: Steps to reproduce the problem 29 | description: Please provide us with precise step-by-step information on how to reproduce the bug 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: what-should 38 | attributes: 39 | label: What should have happened? 40 | description: Tell us what you think the normal behavior should be 41 | validations: 42 | required: true 43 | - type: input 44 | id: commit 45 | attributes: 46 | label: Commit where the problem happens 47 | description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) 48 | validations: 49 | required: true 50 | - type: dropdown 51 | id: py-version 52 | attributes: 53 | label: What Python version are you running on? 54 | multiple: false 55 | options: 56 | - Python 3.10.x 57 | - Python 3.11.x (above, not supported yet) 58 | - Python 3.9.x (below, not recommended) 59 | - type: dropdown 60 | id: platforms 61 | attributes: 62 | label: What platforms do you use to access the UI? 63 | multiple: true 64 | options: 65 | - Windows 66 | - Linux 67 | - MacOS 68 | - iOS 69 | - Android 70 | - Other/Cloud 71 | - type: dropdown 72 | id: device 73 | attributes: 74 | label: What device are you running WebUI on? 75 | multiple: true 76 | options: 77 | - Nvidia GPUs (RTX 20 above) 78 | - Nvidia GPUs (GTX 16 below) 79 | - AMD GPUs (RX 6000 above) 80 | - AMD GPUs (RX 5000 below) 81 | - CPU 82 | - Other GPUs 83 | - type: dropdown 84 | id: browsers 85 | attributes: 86 | label: What browsers do you use to access the UI? 87 | multiple: true 88 | options: 89 | - Mozilla Firefox 90 | - Google Chrome 91 | - Brave 92 | - Apple Safari 93 | - Microsoft Edge 94 | - type: textarea 95 | id: cmdargs 96 | attributes: 97 | label: Command Line Arguments 98 | description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh)? If yes, please write them below. Write "No" otherwise. 99 | render: Shell 100 | validations: 101 | required: true 102 | - type: textarea 103 | id: extensions 104 | attributes: 105 | label: List of extensions 106 | description: Are you using any extensions other than built-ins?
If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. 107 | validations: 108 | required: true 109 | - type: textarea 110 | id: logs 111 | attributes: 112 | label: Console logs 113 | description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. 114 | render: Shell 115 | validations: 116 | required: true 117 | - type: textarea 118 | id: misc 119 | attributes: 120 | label: Additional information 121 | description: Please provide us with any relevant additional info or context. 122 | -------------------------------------------------------------------------------- /modules/postprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste 6 | from modules.shared import opts 7 | 8 | 9 | def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): 10 | devices.torch_gc() 11 | 12 | shared.state.begin() 13 | shared.state.job = 'extras' 14 | 15 | image_data = [] 16 | image_names = [] 17 | outputs = [] 18 | 19 | if extras_mode == 1: 20 | for img in image_folder: 21 | if isinstance(img, Image.Image): 22 | image = img 23 | fn = '' 24 | else: 25 | image = Image.open(os.path.abspath(img.name)) 26 | fn = os.path.splitext(img.orig_name)[0] 27 | image_data.append(image) 28 | image_names.append(fn) 29 | elif extras_mode == 2: 30 | assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' 31 | assert input_dir, 'input directory not selected' 32 | 33 | image_list = shared.listfiles(input_dir) 34 | for filename in image_list: 35 | try: 36 | image = Image.open(filename) 37 | except Exception: 38 | continue 39 | image_data.append(image) 40 | image_names.append(filename) 41 | else: 42 | assert image, 'image not selected' 43 | 44 | image_data.append(image) 45 | image_names.append(None) 46 | 47 | if extras_mode == 2 and output_dir != '': 48 | outpath = output_dir 49 | else: 50 | outpath = opts.outdir_samples or opts.outdir_extras_samples 51 | 52 | infotext = '' 53 | 54 | for image, name in zip(image_data, image_names): 55 | shared.state.textinfo = name 56 | 57 | existing_pnginfo = image.info or {} 58 | 59 | pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) 60 | 61 | scripts.scripts_postproc.run(pp, args) 62 | 63 | if opts.use_original_name_batch and name is not None: 64 | basename = os.path.splitext(os.path.basename(name))[0] 65 | else: 66 | basename = '' 67 | 68 | infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) 69 | 70 | if opts.enable_pnginfo: 71 | pp.image.info = existing_pnginfo 72 | pp.image.info["postprocessing"] = infotext 73 | 74 | if save_output: 75 | images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) 76 | 77 | if extras_mode != 2 or show_extras_results: 78 | outputs.append(pp.image) 79 | 80 | devices.torch_gc() 81 | 82 | return outputs, ui_common.plaintext_to_html(infotext), '' 83 | 84 | 85 | def run_extras(extras_mode, 
resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): 86 | """old handler for API""" 87 | 88 | args = scripts.scripts_postproc.create_args_for_run({ 89 | "Upscale": { 90 | "upscale_mode": resize_mode, 91 | "upscale_by": upscaling_resize, 92 | "upscale_to_width": upscaling_resize_w, 93 | "upscale_to_height": upscaling_resize_h, 94 | "upscale_crop": upscaling_crop, 95 | "upscaler_1_name": extras_upscaler_1, 96 | "upscaler_2_name": extras_upscaler_2, 97 | "upscaler_2_visibility": extras_upscaler_2_visibility, 98 | }, 99 | "GFPGAN": { 100 | "gfpgan_visibility": gfpgan_visibility, 101 | }, 102 | "CodeFormer": { 103 | "codeformer_visibility": codeformer_visibility, 104 | "codeformer_weight": codeformer_weight, 105 | }, 106 | }) 107 | 108 | return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) 109 | -------------------------------------------------------------------------------- /modules/upscaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import PIL 5 | from PIL import Image 6 | 7 | import modules.shared 8 | from modules import modelloader, shared 9 | 10 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 11 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) 12 | 13 | 14 | class Upscaler: 15 | name = None 16 | model_path = None 17 | model_name = None 18 | model_url = None 19 | enable = True 20 | filter = None 21 | model = None 22 | user_path = None 23 | scalers: [] 24 | tile = True 25 | 26 | def __init__(self, create_dirs=False): 27 | self.mod_pad_h = None 28 | self.tile_size = modules.shared.opts.ESRGAN_tile 29 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap 30 | self.device = modules.shared.device 31 | self.img = None 32 | self.output = None 33 | self.scale = 1 34 | self.half = not modules.shared.cmd_opts.no_half 35 | self.pre_pad = 0 36 | self.mod_scale = None 37 | self.model_download_path = None 38 | 39 | if self.model_path is None and self.name: 40 | self.model_path = os.path.join(shared.models_path, self.name) 41 | if self.model_path and create_dirs: 42 | os.makedirs(self.model_path, exist_ok=True) 43 | 44 | try: 45 | import cv2 # noqa: F401 46 | self.can_tile = True 47 | except Exception: 48 | pass 49 | 50 | @abstractmethod 51 | def do_upscale(self, img: PIL.Image, selected_model: str): 52 | return img 53 | 54 | def upscale(self, img: PIL.Image, scale, selected_model: str = None): 55 | self.scale = scale 56 | dest_w = int(img.width * scale) 57 | dest_h = int(img.height * scale) 58 | 59 | for _ in range(3): 60 | shape = (img.width, img.height) 61 | 62 | img = self.do_upscale(img, selected_model) 63 | 64 | if shape == (img.width, img.height): 65 | break 66 | 67 | if img.width >= dest_w and img.height >= dest_h: 68 | break 69 | 70 | if img.width != dest_w or img.height != dest_h: 71 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 72 | 73 | return img 74 | 75 | @abstractmethod 76 | def load_model(self, path: str): 77 | pass 78 | 79 | def find_models(self, ext_filter=None) -> list: 80 | return modelloader.load_models(model_path=self.model_path, 
model_url=self.model_url, command_path=self.user_path) 81 | 82 | def update_status(self, prompt): 83 | print(f"\nextras: {prompt}", file=shared.progress_print_out) 84 | 85 | 86 | class UpscalerData: 87 | name = None 88 | data_path = None 89 | scale: int = 4 90 | scaler: Upscaler = None 91 | model: None 92 | 93 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): 94 | self.name = name 95 | self.data_path = path 96 | self.local_data_path = path 97 | self.scaler = upscaler 98 | self.scale = scale 99 | self.model = model 100 | 101 | 102 | class UpscalerNone(Upscaler): 103 | name = "None" 104 | scalers = [] 105 | 106 | def load_model(self, path): 107 | pass 108 | 109 | def do_upscale(self, img, selected_model=None): 110 | return img 111 | 112 | def __init__(self, dirname=None): 113 | super().__init__(False) 114 | self.scalers = [UpscalerData("None", None, self)] 115 | 116 | 117 | class UpscalerLanczos(Upscaler): 118 | scalers = [] 119 | 120 | def do_upscale(self, img, selected_model=None): 121 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 122 | 123 | def load_model(self, _): 124 | pass 125 | 126 | def __init__(self, dirname=None): 127 | super().__init__(False) 128 | self.name = "Lanczos" 129 | self.scalers = [UpscalerData("Lanczos", None, self)] 130 | 131 | 132 | class UpscalerNearest(Upscaler): 133 | scalers = [] 134 | 135 | def do_upscale(self, img, selected_model=None): 136 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) 137 | 138 | def load_model(self, _): 139 | pass 140 | 141 | def __init__(self, dirname=None): 142 | super().__init__(False) 143 | self.name = "Nearest" 144 | self.scalers = [UpscalerData("Nearest", None, self)] 145 | -------------------------------------------------------------------------------- /modules/sd_hijack_inpainting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import ldm.models.diffusion.ddpm 4 | import ldm.models.diffusion.ddim 5 | import ldm.models.diffusion.plms 6 | 7 | from ldm.models.diffusion.ddim import noise_like 8 | from ldm.models.diffusion.sampling_util import norm_thresholding 9 | 10 | 11 | @torch.no_grad() 12 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 13 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 14 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): 15 | b, *_, device = *x.shape, x.device 16 | 17 | def get_model_output(x, t): 18 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 19 | e_t = self.model.apply_model(x, t, c) 20 | else: 21 | x_in = torch.cat([x] * 2) 22 | t_in = torch.cat([t] * 2) 23 | 24 | if isinstance(c, dict): 25 | assert isinstance(unconditional_conditioning, dict) 26 | c_in = {} 27 | for k in c: 28 | if isinstance(c[k], list): 29 | c_in[k] = [ 30 | torch.cat([unconditional_conditioning[k][i], c[k][i]]) 31 | for i in range(len(c[k])) 32 | ] 33 | else: 34 | c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) 35 | else: 36 | c_in = torch.cat([unconditional_conditioning, c]) 37 | 38 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 39 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 40 | 41 | if score_corrector is not None: 42 | assert self.model.parameterization == "eps" 43 | 
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 44 | 45 | return e_t 46 | 47 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 48 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 49 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 50 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 51 | 52 | def get_x_prev_and_pred_x0(e_t, index): 53 | # select parameters corresponding to the currently considered timestep 54 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 55 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 56 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 57 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 58 | 59 | # current prediction for x_0 60 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 61 | if quantize_denoised: 62 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 63 | if dynamic_threshold is not None: 64 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 65 | # direction pointing to x_t 66 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 67 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 68 | if noise_dropout > 0.: 69 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 70 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 71 | return x_prev, pred_x0 72 | 73 | e_t = get_model_output(x, t) 74 | if len(old_eps) == 0: 75 | # Pseudo Improved Euler (2nd order) 76 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 77 | e_t_next = get_model_output(x_prev, t_next) 78 | e_t_prime = (e_t + e_t_next) / 2 79 | elif len(old_eps) == 1: 80 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 81 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 82 | elif len(old_eps) == 2: 83 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 84 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 85 | elif len(old_eps) >= 3: 86 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 87 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 88 | 89 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 90 | 91 | return x_prev, pred_x0, e_t 92 | 93 | 94 | def do_inpainting_hijack(): 95 | # p_sample_plms is needed because PLMS can't work with dicts as conditionings 96 | 97 | ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms 98 | -------------------------------------------------------------------------------- /modules/mac_specific.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import platform 3 | from modules.sd_hijack_utils import CondFunc 4 | from packaging import version 5 | 6 | 7 | # has_mps is only available in nightly pytorch (for now) and macOS 12.3+. 
8 | # check `getattr` and try it for compatibility 9 | def check_for_mps() -> bool: 10 | if not getattr(torch, 'has_mps', False): 11 | return False 12 | try: 13 | torch.zeros(1).to(torch.device("mps")) 14 | return True 15 | except Exception: 16 | return False 17 | has_mps = check_for_mps() 18 | 19 | 20 | # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 21 | def cumsum_fix(input, cumsum_func, *args, **kwargs): 22 | if input.device.type == 'mps': 23 | output_dtype = kwargs.get('dtype', input.dtype) 24 | if output_dtype == torch.int64: 25 | return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) 26 | elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): 27 | return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) 28 | return cumsum_func(input, *args, **kwargs) 29 | 30 | 31 | if has_mps: 32 | # MPS fix for randn in torchsde 33 | CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps') 34 | 35 | if platform.mac_ver()[0].startswith("13.2."): 36 | # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) 37 | CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) 38 | 39 | if version.parse(torch.__version__) < version.parse("1.13"): 40 | # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working 41 | 42 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 43 | CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), 44 | lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) 45 | # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 46 | CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), 47 | lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') 48 | # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 49 | CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) 50 | elif version.parse(torch.__version__) > version.parse("1.13.1"): 51 | cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) 52 | cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) 53 | CondFunc('torch.cumsum', cumsum_fix_func, None) 54 | CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) 55 | CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) 56 | 57 | # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 58 | 
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') 59 | 60 | # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 61 | if platform.processor() == 'i386': 62 | for funcName in ['torch.argmax', 'torch.Tensor.argmax']: 63 | CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') 64 | -------------------------------------------------------------------------------- /modules/lowvram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import devices 3 | 4 | module_in_gpu = None 5 | cpu = torch.device("cpu") 6 | 7 | 8 | def send_everything_to_cpu(): 9 | global module_in_gpu 10 | 11 | if module_in_gpu is not None: 12 | module_in_gpu.to(cpu) 13 | 14 | module_in_gpu = None 15 | 16 | 17 | def setup_for_low_vram(sd_model, use_medvram): 18 | parents = {} 19 | 20 | def send_me_to_gpu(module, _): 21 | """send this module to GPU; send whatever tracked module was previous in GPU to CPU; 22 | we add this as forward_pre_hook to a lot of modules and this way all but one of them will 23 | be in CPU 24 | """ 25 | global module_in_gpu 26 | 27 | module = parents.get(module, module) 28 | 29 | if module_in_gpu == module: 30 | return 31 | 32 | if module_in_gpu is not None: 33 | module_in_gpu.to(cpu) 34 | 35 | module.to(devices.device) 36 | module_in_gpu = module 37 | 38 | # see below for register_forward_pre_hook; 39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is 40 | # useless here, and we just replace those methods 41 | 42 | first_stage_model = sd_model.first_stage_model 43 | first_stage_model_encode = sd_model.first_stage_model.encode 44 | first_stage_model_decode = sd_model.first_stage_model.decode 45 | 46 | def first_stage_model_encode_wrap(x): 47 | send_me_to_gpu(first_stage_model, None) 48 | return first_stage_model_encode(x) 49 | 50 | def first_stage_model_decode_wrap(z): 51 | send_me_to_gpu(first_stage_model, None) 52 | return first_stage_model_decode(z) 53 | 54 | # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field 55 | if hasattr(sd_model.cond_stage_model, 'model'): 56 | sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model 57 | 58 | # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then 59 | # send the model to GPU. Then put modules back. the modules will be in CPU. 
60 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model 61 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None 62 | sd_model.to(devices.device) 63 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored 64 | 65 | # register hooks for those the first three models 66 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) 67 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) 68 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap 69 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap 70 | if sd_model.depth_model: 71 | sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) 72 | if sd_model.embedder: 73 | sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) 74 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model 75 | 76 | if hasattr(sd_model.cond_stage_model, 'model'): 77 | sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer 78 | del sd_model.cond_stage_model.transformer 79 | 80 | if use_medvram: 81 | sd_model.model.register_forward_pre_hook(send_me_to_gpu) 82 | else: 83 | diff_model = sd_model.model.diffusion_model 84 | 85 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules 86 | # so that only one of them is in GPU at a time 87 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed 88 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None 89 | sd_model.model.to(devices.device) 90 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored 91 | 92 | # install hooks for bits of third model 93 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) 94 | for block in diff_model.input_blocks: 95 | block.register_forward_pre_hook(send_me_to_gpu) 96 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) 97 | for block in diff_model.output_blocks: 98 | block.register_forward_pre_hook(send_me_to_gpu) 99 | --------------------------------------------------------------------------------
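The low-VRAM strategy in lowvram.py keeps at most one large submodule on the GPU at a time: every big module gets a forward_pre_hook that moves it to the GPU and evicts whichever module was there before. A self-contained toy sketch of that pattern (illustrative module names, not the webui implementation):

import torch
import torch.nn as nn

gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
module_on_gpu = None  # mirrors module_in_gpu above


def send_me_to_gpu(module, _inputs):
    """Pre-forward hook: move this module to the GPU and evict the previous one."""
    global module_on_gpu
    if module_on_gpu is module:
        return
    if module_on_gpu is not None:
        module_on_gpu.to(cpu)
    module.to(gpu)
    module_on_gpu = module


# toy two-stage "pipeline"; only the stage currently running sits on the GPU
stage_a = nn.Linear(512, 512)
stage_b = nn.Linear(512, 512)
for stage in (stage_a, stage_b):
    stage.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 512)
y = stage_a(x.to(gpu))   # hook moves stage_a onto the GPU before its forward pass
z = stage_b(y)           # hook moves stage_b onto the GPU and sends stage_a back to the CPU

This is also why setup_for_low_vram temporarily detaches the large submodules, sends the remaining shell to the GPU, and then reattaches them: the detached parts stay on the CPU so the hooks can shuttle them in one at a time.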