├── test
│   ├── __init__.py
│   ├── test_files
│   │   ├── empty.pt
│   │   ├── mask_basic.png
│   │   └── img2img_basic.png
│   ├── conftest.py
│   ├── test_utils.py
│   ├── test_extras.py
│   ├── test_img2img.py
│   └── test_txt2img.py
├── models
│   ├── VAE
│   │   └── Put VAE here.txt
│   ├── Stable-diffusion
│   │   └── Put Stable Diffusion checkpoints here.txt
│   ├── deepbooru
│   │   └── Put your deepbooru release project folder here.txt
│   ├── VAE-approx
│   │   └── model.pt
│   └── karlo
│       └── ViT-L-14_stats.th
├── extensions
│   └── put extensions here.txt
├── localizations
│   └── Put localization files here.txt
├── textual_inversion_templates
│   ├── none.txt
│   ├── style.txt
│   ├── subject.txt
│   ├── hypernetwork.txt
│   ├── style_filewords.txt
│   └── subject_filewords.txt
├── embeddings
│   └── Place Textual Inversion embeddings here.txt
├── .eslintignore
├── .git-blame-ignore-revs
├── requirements-test.txt
├── modules
│   ├── models
│   │   └── diffusion
│   │       └── uni_pc
│   │           ├── __init__.py
│   │           └── sampler.py
│   ├── Roboto-Regular.ttf
│   ├── textual_inversion
│   │   ├── test_embedding.png
│   │   ├── logging.py
│   │   ├── ui.py
│   │   └── learn_schedule.py
│   ├── import_hook.py
│   ├── sd_hijack_ip2p.py
│   ├── face_restoration.py
│   ├── shared_items.py
│   ├── script_loading.py
│   ├── timer.py
│   ├── ui_extra_networks_hypernets.py
│   ├── localization.py
│   ├── ngrok.py
│   ├── extra_networks_hypernet.py
│   ├── ui_extra_networks_textual_inversion.py
│   ├── ui_extra_networks_checkpoints.py
│   ├── errors.py
│   ├── sd_hijack_utils.py
│   ├── paths_internal.py
│   ├── sd_hijack_checkpoint.py
│   ├── sd_hijack_open_clip.py
│   ├── sd_hijack_xlmr.py
│   ├── hypernetworks
│   │   └── ui.py
│   ├── scripts_auto_postprocessing.py
│   ├── sd_samplers.py
│   ├── paths.py
│   ├── ui_components.py
│   ├── sd_vae_approx.py
│   ├── ui_postprocessing.py
│   ├── ui_tempdir.py
│   ├── styles.py
│   ├── sd_vae_taesd.py
│   ├── memmon.py
│   ├── txt2img.py
│   ├── hashes.py
│   ├── deepbooru.py
│   ├── sd_samplers_common.py
│   ├── masking.py
│   ├── sd_hijack_clip_old.py
│   ├── call_queue.py
│   ├── sd_hijack_unet.py
│   ├── gfpgan_model.py
│   ├── postprocessing.py
│   ├── upscaler.py
│   ├── sd_hijack_inpainting.py
│   ├── mac_specific.py
│   └── lowvram.py
├── screenshot.png
├── html
│   ├── card-no-preview.png
│   ├── extra-networks-no-cards.html
│   ├── footer.html
│   ├── extra-networks-card.html
│   └── image-update.svg
├── webui-user.bat
├── .pylintrc
├── environment-wsl2.yaml
├── package.json
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   └── bug_report.yml
│   ├── pull_request_template.md
│   └── workflows
│       ├── on_pull_request.yaml
│       └── run_tests.yaml
├── extensions-builtin
│   ├── Lora
│   │   ├── preload.py
│   │   ├── ui_extra_networks_lora.py
│   │   └── extra_networks_lora.py
│   ├── LDSR
│   │   ├── preload.py
│   │   └── scripts
│   │       └── ldsr_model.py
│   ├── ScuNET
│   │   └── preload.py
│   ├── SwinIR
│   │   └── preload.py
│   └── prompt-bracket-checker
│       └── javascript
│           └── prompt-bracket-checker.js
├── requirements.txt
├── javascript
│   ├── textualInversion.js
│   ├── imageParams.js
│   ├── hires_fix.js
│   ├── imageMaskFix.js
│   ├── generationParams.js
│   ├── notification.js
│   ├── imageviewerGamepad.js
│   ├── ui_settings_hints.js
│   ├── extensions.js
│   ├── dragdrop.js
│   └── aspectRatioOverlay.js
├── requirements_versions.txt
├── .gitignore
├── CODEOWNERS
├── webui-macos-env.sh
├── pyproject.toml
├── launch.py
├── scripts
│   ├── postprocessing_gfpgan.py
│   ├── postprocessing_codeformer.py
│   ├── custom_code.py
│   └── sd_upscale.py
├── webui-user.sh
├── configs
│   ├── v1-inference.yaml
│   ├── alt-diffusion-inference.yaml
│   ├── v1-inpainting-inference.yaml
│   └── instruct-pix2pix.yaml
├── webui.bat
├── .eslintrc.js
└── script.js
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/VAE/Put VAE here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/extensions/put extensions here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/localizations/Put localization files here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/textual_inversion_templates/none.txt:
--------------------------------------------------------------------------------
1 | picture
2 |
--------------------------------------------------------------------------------
/embeddings/Place Textual Inversion embeddings here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/deepbooru/Put your deepbooru release project folder here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.eslintignore:
--------------------------------------------------------------------------------
1 | extensions
2 | extensions-disabled
3 | repositories
4 | venv
--------------------------------------------------------------------------------
/.git-blame-ignore-revs:
--------------------------------------------------------------------------------
1 | # Apply ESlint
2 | 9c54b78d9dde5601e916f308d9a9d6953ec39430
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest-base-url~=2.0
2 | pytest-cov~=4.0
3 | pytest~=7.3
4 |
--------------------------------------------------------------------------------
/modules/models/diffusion/uni_pc/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import UniPCSampler # noqa: F401
2 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/screenshot.png
--------------------------------------------------------------------------------
/html/card-no-preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/html/card-no-preview.png
--------------------------------------------------------------------------------
/test/test_files/empty.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/empty.pt
--------------------------------------------------------------------------------
/models/VAE-approx/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/models/VAE-approx/model.pt
--------------------------------------------------------------------------------
/modules/Roboto-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/modules/Roboto-Regular.ttf
--------------------------------------------------------------------------------
/models/karlo/ViT-L-14_stats.th:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/models/karlo/ViT-L-14_stats.th
--------------------------------------------------------------------------------
/test/test_files/mask_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/mask_basic.png
--------------------------------------------------------------------------------
/webui-user.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set PYTHON=
4 | set GIT=
5 | set VENV_DIR=
6 | set COMMANDLINE_ARGS=
7 |
8 | call webui.bat
9 |
--------------------------------------------------------------------------------
/test/test_files/img2img_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/test/test_files/img2img_basic.png
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
2 | [MESSAGES CONTROL]
3 | disable=C,R,W,E,I
4 |
--------------------------------------------------------------------------------
/modules/textual_inversion/test_embedding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tamanna18/stable-diffusion-webui/master/modules/textual_inversion/test_embedding.png
--------------------------------------------------------------------------------
/html/extra-networks-no-cards.html:
--------------------------------------------------------------------------------
(HTML markup was stripped from this file in the dump; only the visible text survives. The page shows the message "Nothing here. Add some content to the following directories:" followed by the directory list placeholder.)
--------------------------------------------------------------------------------
/environment-wsl2.yaml:
--------------------------------------------------------------------------------
1 | name: automatic
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.10
7 | - pip=23.0
8 | - cudatoolkit=11.8
9 | - pytorch=2.0
10 | - torchvision=0.15
11 | - numpy=1.23
12 |
--------------------------------------------------------------------------------
/modules/import_hook.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | # this will break any attempt to import xformers which will prevent stable diffusion repo from trying to use it
4 | if "--xformers" not in "".join(sys.argv):
5 | sys.modules["xformers"] = None
6 |
--------------------------------------------------------------------------------
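The hook above relies on a documented behavior of Python's import machinery: once sys.modules maps a name to None, any later import of that name raises ImportError immediately. A minimal standalone sketch of the same trick, independent of the webui's --xformers flag handling:

import sys

sys.modules["xformers"] = None  # poison the import cache

try:
    import xformers  # noqa: F401
except ImportError as exc:
    print("xformers import blocked:", exc)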
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "stable-diffusion-webui",
3 | "version": "0.0.0",
4 | "devDependencies": {
5 | "eslint": "^8.40.0"
6 | },
7 | "scripts": {
8 | "lint": "eslint .",
9 | "fix": "eslint --fix ."
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: WebUI Community Support
4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
5 | about: Please ask and answer questions here.
6 |
--------------------------------------------------------------------------------
/extensions-builtin/Lora/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/LDSR/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/ScuNET/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/SwinIR/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
7 |
--------------------------------------------------------------------------------
/modules/sd_hijack_ip2p.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 |
4 | def should_hijack_ip2p(checkpoint_info):
5 | from modules import sd_models_config
6 |
7 | ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
8 | cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
9 |
10 | return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
11 |
--------------------------------------------------------------------------------
/html/footer.html:
--------------------------------------------------------------------------------
(HTML markup was stripped from this file in the dump; the footer template interpolates a {versions} placeholder near the end of the page.)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | astunparse
2 | blendmodes
3 | accelerate
4 | basicsr
5 | gfpgan
6 | gradio==3.31.0
7 | numpy
8 | omegaconf
9 | opencv-contrib-python
10 | requests
11 | piexif
12 | Pillow
13 | pytorch_lightning==1.7.7
14 | realesrgan
15 | scikit-image>=0.19
16 | timm==0.4.12
17 | transformers==4.25.1
18 | torch
19 | einops
20 | jsonmerge
21 | clean-fid
22 | resize-right
23 | torchdiffeq
24 | kornia
25 | lark
26 | inflection
27 | GitPython
28 | torchsde
29 | safetensors
30 | psutil
31 | rich
32 | tomesd
33 |
--------------------------------------------------------------------------------
/javascript/textualInversion.js:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | function start_training_textual_inversion() {
5 | gradioApp().querySelector('#ti_error').innerHTML = '';
6 |
7 | var id = randomId();
8 | requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
9 | gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
10 | });
11 |
12 | var res = Array.from(arguments);
13 |
14 | res[0] = id;
15 |
16 | return res;
17 | }
18 |
--------------------------------------------------------------------------------
/html/extra-networks-card.html:
--------------------------------------------------------------------------------
(HTML markup was stripped from this file in the dump; the card template interpolates the placeholders {background_image}, {metadata_button}, {search_term}, {name}, and {description}.)
--------------------------------------------------------------------------------
/modules/face_restoration.py:
--------------------------------------------------------------------------------
1 | from modules import shared
2 |
3 |
4 | class FaceRestoration:
5 | def name(self):
6 | return "None"
7 |
8 | def restore(self, np_image):
9 | return np_image
10 |
11 |
12 | def restore_faces(np_image):
13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14 | if len(face_restorers) == 0:
15 | return np_image
16 |
17 | face_restorer = face_restorers[0]
18 |
19 | return face_restorer.restore(np_image)
20 |
--------------------------------------------------------------------------------
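restore_faces() picks the restorer whose name() matches shared.opts.face_restoration_model and otherwise returns the image untouched. A hedged sketch of plugging in a custom restorer, assuming the webui environment is importable and that shared.face_restorers and shared.opts can be set as shown:

import numpy as np
from modules import face_restoration, shared

class IdentityRestorer(face_restoration.FaceRestoration):
    def name(self):
        return "Identity"

    def restore(self, np_image):
        return np_image  # no-op restorer, for illustration only

shared.face_restorers.append(IdentityRestorer())
shared.opts.face_restoration_model = "Identity"  # assumption: opts accepts direct assignment
out = face_restoration.restore_faces(np.zeros((64, 64, 3), dtype=np.uint8))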
/requirements_versions.txt:
--------------------------------------------------------------------------------
1 | blendmodes==2022
2 | transformers==4.25.1
3 | accelerate==0.18.0
4 | basicsr==1.4.2
5 | gfpgan==1.3.8
6 | gradio==3.31.0
7 | numpy==1.23.5
8 | Pillow==9.5.0
9 | realesrgan==0.3.0
10 | torch
11 | omegaconf==2.2.3
12 | pytorch_lightning==1.9.4
13 | scikit-image==0.20.0
14 | timm==0.6.7
15 | piexif==1.1.3
16 | einops==0.4.1
17 | jsonmerge==1.8.0
18 | clean-fid==0.1.35
19 | resize-right==0.0.2
20 | torchdiffeq==0.2.3
21 | kornia==0.6.7
22 | lark==1.1.2
23 | inflection==0.5.1
24 | GitPython==3.1.30
25 | torchsde==0.2.5
26 | safetensors==0.3.1
27 | httpcore<=0.15
28 | fastapi==0.94.0
29 | tomesd==0.1.2
30 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 | * a simple description of what you're trying to accomplish
4 | * a summary of changes in code
5 | * which issues it fixes, if any
6 |
7 | ## Screenshots/videos:
8 |
9 |
10 | ## Checklist:
11 |
12 | - [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
13 | - [ ] I have performed a self-review of my own code
14 | - [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style)
15 | - [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests)
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.ckpt
3 | *.safetensors
4 | *.pth
5 | /ESRGAN/*
6 | /SwinIR/*
7 | /repositories
8 | /venv
9 | /tmp
10 | /model.ckpt
11 | /models/**/*
12 | /GFPGANv1.3.pth
13 | /gfpgan/weights/*.pth
14 | /ui-config.json
15 | /outputs
16 | /config.json
17 | /log
18 | /webui.settings.bat
19 | /embeddings
20 | /styles.csv
21 | /params.txt
22 | /styles.csv.bak
23 | /webui-user.bat
24 | /webui-user.sh
25 | /interrogate
26 | /user.css
27 | /.idea
28 | notification.mp3
29 | /SwinIR
30 | /textual_inversion
31 | .vscode
32 | /extensions
33 | /test/stdout.txt
34 | /test/stderr.txt
35 | /cache.json*
36 | /config_states/
37 | /node_modules
38 | /package-lock.json
39 | /.coverage*
40 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from PIL import Image
5 | from gradio.processing_utils import encode_pil_to_base64
6 |
7 | test_files_path = os.path.dirname(__file__) + "/test_files"
8 |
9 |
10 | @pytest.fixture(scope="session") # session so we don't read this over and over
11 | def img2img_basic_image_base64() -> str:
12 | return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png")))
13 |
14 |
15 | @pytest.fixture(scope="session") # session so we don't read this over and over
16 | def mask_basic_image_base64() -> str:
17 | return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "mask_basic.png")))
18 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style.txt:
--------------------------------------------------------------------------------
1 | a painting, art by [name]
2 | a rendering, art by [name]
3 | a cropped painting, art by [name]
4 | the painting, art by [name]
5 | a clean painting, art by [name]
6 | a dirty painting, art by [name]
7 | a dark painting, art by [name]
8 | a picture, art by [name]
9 | a cool painting, art by [name]
10 | a close-up painting, art by [name]
11 | a bright painting, art by [name]
12 | a cropped painting, art by [name]
13 | a good painting, art by [name]
14 | a close-up painting, art by [name]
15 | a rendition, art by [name]
16 | a nice painting, art by [name]
17 | a small painting, art by [name]
18 | a weird painting, art by [name]
19 | a large painting, art by [name]
20 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @AUTOMATIC1111
2 |
3 | # if you were managing a localization and were removed from this file, this is because
4 | # the intended way to do localizations now is via extensions. See:
5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
6 | # Make a repo with your localization and since you are still listed as a collaborator
7 | # you can add it to the wiki page yourself. This change is because some people complained
8 | # the git commit log is cluttered with things unrelated to almost everyone and
9 | # because I believe this is the best overall for the project to handle localizations almost
10 | # entirely without my oversight.
11 |
12 |
13 |
--------------------------------------------------------------------------------
/javascript/imageParams.js:
--------------------------------------------------------------------------------
1 | window.onload = (function() {
2 | window.addEventListener('drop', e => {
3 | const target = e.composedPath()[0];
4 | if (target.placeholder.indexOf("Prompt") == -1) return;
5 |
6 | let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
7 |
8 | e.stopPropagation();
9 | e.preventDefault();
10 | const imgParent = gradioApp().getElementById(prompt_target);
11 | const files = e.dataTransfer.files;
12 | const fileInput = imgParent.querySelector('input[type="file"]');
13 | if (fileInput) {
14 | fileInput.files = files;
15 | fileInput.dispatchEvent(new Event('change'));
16 | }
17 | });
18 | });
19 |
--------------------------------------------------------------------------------
/modules/shared_items.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def realesrgan_models_names():
4 | import modules.realesrgan_model
5 | return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
6 |
7 |
8 | def postprocessing_scripts():
9 | import modules.scripts
10 |
11 | return modules.scripts.scripts_postproc.scripts
12 |
13 |
14 | def sd_vae_items():
15 | import modules.sd_vae
16 |
17 | return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
18 |
19 |
20 | def refresh_vae_list():
21 | import modules.sd_vae
22 |
23 | modules.sd_vae.refresh_vae_list()
24 |
25 |
26 | def cross_attention_optimizations():
27 | import modules.sd_hijack
28 |
29 | return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
30 |
31 |
32 |
--------------------------------------------------------------------------------
/textual_inversion_templates/subject.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name]
2 | a rendering of a [name]
3 | a cropped photo of the [name]
4 | the photo of a [name]
5 | a photo of a clean [name]
6 | a photo of a dirty [name]
7 | a dark photo of the [name]
8 | a photo of my [name]
9 | a photo of the cool [name]
10 | a close-up photo of a [name]
11 | a bright photo of the [name]
12 | a cropped photo of a [name]
13 | a photo of the [name]
14 | a good photo of the [name]
15 | a photo of one [name]
16 | a close-up photo of the [name]
17 | a rendition of the [name]
18 | a photo of the clean [name]
19 | a rendition of a [name]
20 | a photo of a nice [name]
21 | a good photo of a [name]
22 | a photo of the nice [name]
23 | a photo of the small [name]
24 | a photo of the weird [name]
25 | a photo of the large [name]
26 | a photo of a cool [name]
27 | a photo of a small [name]
28 |
--------------------------------------------------------------------------------
/webui-macos-env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ####################################################################
3 | # macOS defaults #
4 | # Please modify webui-user.sh to change these instead of this file #
5 | ####################################################################
6 |
7 | if [[ -x "$(command -v python3.10)" ]]
8 | then
9 | python_cmd="python3.10"
10 | fi
11 |
12 | export install_dir="$HOME"
13 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
14 | export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
15 | export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
16 | export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
17 | export PYTORCH_ENABLE_MPS_FALLBACK=1
18 |
19 | ####################################################################
20 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 |
3 | target-version = "py39"
4 |
5 | extend-select = [
6 | "B",
7 | "C",
8 | "I",
9 | "W",
10 | ]
11 |
12 | exclude = [
13 | "extensions",
14 | "extensions-disabled",
15 | ]
16 |
17 | ignore = [
18 | "E501", # Line too long
19 | "E731", # Do not assign a `lambda` expression, use a `def`
20 |
21 | "I001", # Import block is un-sorted or un-formatted
22 | "C901", # Function is too complex
23 | "C408", # Rewrite as a literal
24 | "W605", # invalid escape sequence, messes with some docstrings
25 | ]
26 |
27 | [tool.ruff.per-file-ignores]
28 | "webui.py" = ["E402"] # Module level import not at top of file
29 |
30 | [tool.ruff.flake8-bugbear]
31 | # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
32 | extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
33 |
34 | [tool.pytest.ini_options]
35 | base_url = "http://127.0.0.1:7860"
36 |
--------------------------------------------------------------------------------
/javascript/hires_fix.js:
--------------------------------------------------------------------------------
1 |
2 | function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
3 | function setInactive(elem, inactive) {
4 | elem.classList.toggle('inactive', !!inactive);
5 | }
6 |
7 | var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
8 | var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
9 | var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
10 |
11 | gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
12 |
13 | setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
14 | setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
15 | setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
16 |
17 | return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
18 | }
19 |
--------------------------------------------------------------------------------
/textual_inversion_templates/hypernetwork.txt:
--------------------------------------------------------------------------------
1 | a photo of a [filewords]
2 | a rendering of a [filewords]
3 | a cropped photo of the [filewords]
4 | the photo of a [filewords]
5 | a photo of a clean [filewords]
6 | a photo of a dirty [filewords]
7 | a dark photo of the [filewords]
8 | a photo of my [filewords]
9 | a photo of the cool [filewords]
10 | a close-up photo of a [filewords]
11 | a bright photo of the [filewords]
12 | a cropped photo of a [filewords]
13 | a photo of the [filewords]
14 | a good photo of the [filewords]
15 | a photo of one [filewords]
16 | a close-up photo of the [filewords]
17 | a rendition of the [filewords]
18 | a photo of the clean [filewords]
19 | a rendition of a [filewords]
20 | a photo of a nice [filewords]
21 | a good photo of a [filewords]
22 | a photo of the nice [filewords]
23 | a photo of the small [filewords]
24 | a photo of the weird [filewords]
25 | a photo of the large [filewords]
26 | a photo of a cool [filewords]
27 | a photo of a small [filewords]
28 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style_filewords.txt:
--------------------------------------------------------------------------------
1 | a painting of [filewords], art by [name]
2 | a rendering of [filewords], art by [name]
3 | a cropped painting of [filewords], art by [name]
4 | the painting of [filewords], art by [name]
5 | a clean painting of [filewords], art by [name]
6 | a dirty painting of [filewords], art by [name]
7 | a dark painting of [filewords], art by [name]
8 | a picture of [filewords], art by [name]
9 | a cool painting of [filewords], art by [name]
10 | a close-up painting of [filewords], art by [name]
11 | a bright painting of [filewords], art by [name]
12 | a cropped painting of [filewords], art by [name]
13 | a good painting of [filewords], art by [name]
14 | a close-up painting of [filewords], art by [name]
15 | a rendition of [filewords], art by [name]
16 | a nice painting of [filewords], art by [name]
17 | a small painting of [filewords], art by [name]
18 | a weird painting of [filewords], art by [name]
19 | a large painting of [filewords], art by [name]
20 |
--------------------------------------------------------------------------------
/html/image-update.svg:
--------------------------------------------------------------------------------
(SVG markup was stripped from this file in the dump.)
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import requests
3 |
4 |
5 | def test_options_write(base_url):
6 | url_options = f"{base_url}/sdapi/v1/options"
7 | response = requests.get(url_options)
8 | assert response.status_code == 200
9 |
10 | pre_value = response.json()["send_seed"]
11 |
12 | assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200
13 |
14 | response = requests.get(url_options)
15 | assert response.status_code == 200
16 | assert response.json()['send_seed'] == (not pre_value)
17 |
18 | requests.post(url_options, json={"send_seed": pre_value})
19 |
20 |
21 | @pytest.mark.parametrize("url", [
22 | "sdapi/v1/cmd-flags",
23 | "sdapi/v1/samplers",
24 | "sdapi/v1/upscalers",
25 | "sdapi/v1/sd-models",
26 | "sdapi/v1/hypernetworks",
27 | "sdapi/v1/face-restorers",
28 | "sdapi/v1/realesrgan-models",
29 | "sdapi/v1/prompt-styles",
30 | "sdapi/v1/embeddings",
31 | ])
32 | def test_get_api_url(base_url, url):
33 | assert requests.get(f"{base_url}/{url}").status_code == 200
34 |
--------------------------------------------------------------------------------
/modules/script_loading.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 | import importlib.util
5 |
6 |
7 | def load_module(path):
8 | module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
9 | module = importlib.util.module_from_spec(module_spec)
10 | module_spec.loader.exec_module(module)
11 |
12 | return module
13 |
14 |
15 | def preload_extensions(extensions_dir, parser):
16 | if not os.path.isdir(extensions_dir):
17 | return
18 |
19 | for dirname in sorted(os.listdir(extensions_dir)):
20 | preload_script = os.path.join(extensions_dir, dirname, "preload.py")
21 | if not os.path.isfile(preload_script):
22 | continue
23 |
24 | try:
25 | module = load_module(preload_script)
26 | if hasattr(module, 'preload'):
27 | module.preload(parser)
28 |
29 | except Exception:
30 | print(f"Error running preload() for {preload_script}", file=sys.stderr)
31 | print(traceback.format_exc(), file=sys.stderr)
32 |
--------------------------------------------------------------------------------
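preload_extensions() walks each extension directory, imports its preload.py via load_module(), and hands the shared argparse parser to that module's preload() hook so extensions can register command-line flags before the main parser runs; the bundled Lora/LDSR/ScuNET/SwinIR preload.py files above are examples of such hooks. A sketch of the assumed call pattern, run from the repo root (the real wiring lives elsewhere in the webui startup code):

import argparse
from modules import script_loading

parser = argparse.ArgumentParser()
script_loading.preload_extensions("extensions-builtin", parser)  # bundled extensions
script_loading.preload_extensions("extensions", parser)          # user-installed extensions
args, _ = parser.parse_known_args()
print(args)  # now includes flags such as --lora-dir if the Lora extension was preloaded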
/modules/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class Timer:
5 | def __init__(self):
6 | self.start = time.time()
7 | self.records = {}
8 | self.total = 0
9 |
10 | def elapsed(self):
11 | end = time.time()
12 | res = end - self.start
13 | self.start = end
14 | return res
15 |
16 | def record(self, category, extra_time=0):
17 | e = self.elapsed()
18 | if category not in self.records:
19 | self.records[category] = 0
20 |
21 | self.records[category] += e + extra_time
22 | self.total += e + extra_time
23 |
24 | def summary(self):
25 | res = f"{self.total:.1f}s"
26 |
27 | additions = [x for x in self.records.items() if x[1] >= 0.1]
28 | if not additions:
29 | return res
30 |
31 | res += " ("
32 | res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
33 | res += ")"
34 |
35 | return res
36 |
37 | def reset(self):
38 | self.__init__()
39 |
--------------------------------------------------------------------------------
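Timer.record() attributes the time elapsed since the previous record()/elapsed() call to a category, and summary() reports the total plus every category that took at least 0.1s. A hypothetical usage sketch, not taken from the repo:

import time
from modules.timer import Timer

startup_timer = Timer()
time.sleep(0.2)
startup_timer.record("load config")
time.sleep(0.3)
startup_timer.record("load model")
print(startup_timer.summary())  # e.g. "0.5s (load config: 0.2s, load model: 0.3s)"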
/launch.py:
--------------------------------------------------------------------------------
1 | from modules import launch_utils
2 |
3 |
4 | args = launch_utils.args
5 | python = launch_utils.python
6 | git = launch_utils.git
7 | index_url = launch_utils.index_url
8 | dir_repos = launch_utils.dir_repos
9 |
10 | commit_hash = launch_utils.commit_hash
11 | git_tag = launch_utils.git_tag
12 |
13 | run = launch_utils.run
14 | is_installed = launch_utils.is_installed
15 | repo_dir = launch_utils.repo_dir
16 |
17 | run_pip = launch_utils.run_pip
18 | check_run_python = launch_utils.check_run_python
19 | git_clone = launch_utils.git_clone
20 | git_pull_recursive = launch_utils.git_pull_recursive
21 | run_extension_installer = launch_utils.run_extension_installer
22 | prepare_environment = launch_utils.prepare_environment
23 | configure_for_tests = launch_utils.configure_for_tests
24 | start = launch_utils.start
25 |
26 |
27 | def main():
28 | if not args.skip_prepare_environment:
29 | prepare_environment()
30 |
31 | if args.test_server:
32 | configure_for_tests()
33 |
34 | start()
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
/.github/workflows/on_pull_request.yaml:
--------------------------------------------------------------------------------
1 | name: Run Linting/Formatting on Pull Requests
2 |
3 | on:
4 | - push
5 | - pull_request
6 |
7 | jobs:
8 | lint-python:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Checkout Code
12 | uses: actions/checkout@v3
13 | - uses: actions/setup-python@v4
14 | with:
15 | python-version: 3.11
16 | # NB: there's no cache: pip here since we're not installing anything
17 | # from the requirements.txt file(s) in the repository; it's faster
18 | # not to have GHA download an (at the time of writing) 4 GB cache
19 | # of PyTorch and other dependencies.
20 | - name: Install Ruff
21 | run: pip install ruff==0.0.265
22 | - name: Run Ruff
23 | run: ruff .
24 | lint-js:
25 | runs-on: ubuntu-latest
26 | steps:
27 | - name: Checkout Code
28 | uses: actions/checkout@v3
29 | - name: Install Node.js
30 | uses: actions/setup-node@v3
31 | with:
32 | node-version: 18
33 | - run: npm i --ci
34 | - run: npm run lint
35 |
--------------------------------------------------------------------------------
/modules/ui_extra_networks_hypernets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from modules import shared, ui_extra_networks
5 |
6 |
7 | class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
8 | def __init__(self):
9 | super().__init__('Hypernetworks')
10 |
11 | def refresh(self):
12 | shared.reload_hypernetworks()
13 |
14 | def list_items(self):
15 | for name, path in shared.hypernetworks.items():
16 | path, ext = os.path.splitext(path)
17 |
18 | yield {
19 | "name": name,
20 | "filename": path,
21 | "preview": self.find_preview(path),
22 | "description": self.find_description(path),
23 | "search_term": self.search_terms_from_path(path),
24 | "prompt": json.dumps(f"<hypernet:{name}:{shared.opts.extra_networks_default_multiplier}>"),
25 | "local_preview": f"{path}.preview.{shared.opts.samples_format}",
26 | }
27 |
28 | def allowed_directories_for_previews(self):
29 | return [shared.cmd_opts.hypernetwork_dir]
30 |
31 |
--------------------------------------------------------------------------------
/textual_inversion_templates/subject_filewords.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name], [filewords]
2 | a rendering of a [name], [filewords]
3 | a cropped photo of the [name], [filewords]
4 | the photo of a [name], [filewords]
5 | a photo of a clean [name], [filewords]
6 | a photo of a dirty [name], [filewords]
7 | a dark photo of the [name], [filewords]
8 | a photo of my [name], [filewords]
9 | a photo of the cool [name], [filewords]
10 | a close-up photo of a [name], [filewords]
11 | a bright photo of the [name], [filewords]
12 | a cropped photo of a [name], [filewords]
13 | a photo of the [name], [filewords]
14 | a good photo of the [name], [filewords]
15 | a photo of one [name], [filewords]
16 | a close-up photo of the [name], [filewords]
17 | a rendition of the [name], [filewords]
18 | a photo of the clean [name], [filewords]
19 | a rendition of a [name], [filewords]
20 | a photo of a nice [name], [filewords]
21 | a good photo of a [name], [filewords]
22 | a photo of the nice [name], [filewords]
23 | a photo of the small [name], [filewords]
24 | a photo of the weird [name], [filewords]
25 | a photo of the large [name], [filewords]
26 | a photo of a cool [name], [filewords]
27 | a photo of a small [name], [filewords]
28 |
--------------------------------------------------------------------------------
/modules/localization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import traceback
5 |
6 |
7 | localizations = {}
8 |
9 |
10 | def list_localizations(dirname):
11 | localizations.clear()
12 |
13 | for file in os.listdir(dirname):
14 | fn, ext = os.path.splitext(file)
15 | if ext.lower() != ".json":
16 | continue
17 |
18 | localizations[fn] = os.path.join(dirname, file)
19 |
20 | from modules import scripts
21 | for file in scripts.list_scripts("localizations", ".json"):
22 | fn, ext = os.path.splitext(file.filename)
23 | localizations[fn] = file.path
24 |
25 |
26 | def localization_js(current_localization_name: str) -> str:
27 | fn = localizations.get(current_localization_name, None)
28 | data = {}
29 | if fn is not None:
30 | try:
31 | with open(fn, "r", encoding="utf8") as file:
32 | data = json.load(file)
33 | except Exception:
34 | print(f"Error loading localization from {fn}:", file=sys.stderr)
35 | print(traceback.format_exc(), file=sys.stderr)
36 |
37 | return f"window.localization = {json.dumps(data)}"
38 |
--------------------------------------------------------------------------------
/scripts/postprocessing_gfpgan.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | from modules import scripts_postprocessing, gfpgan_model
5 | import gradio as gr
6 |
7 | from modules.ui_components import FormRow
8 |
9 |
10 | class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
11 | name = "GFPGAN"
12 | order = 2000
13 |
14 | def ui(self):
15 | with FormRow():
16 | gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
17 |
18 | return {
19 | "gfpgan_visibility": gfpgan_visibility,
20 | }
21 |
22 | def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
23 | if gfpgan_visibility == 0:
24 | return
25 |
26 | restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
27 | res = Image.fromarray(restored_img)
28 |
29 | if gfpgan_visibility < 1.0:
30 | res = Image.blend(pp.image, res, gfpgan_visibility)
31 |
32 | pp.image = res
33 | pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
34 |
--------------------------------------------------------------------------------
/modules/ngrok.py:
--------------------------------------------------------------------------------
1 | import ngrok
2 |
3 | # Connect to ngrok for ingress
4 | def connect(token, port, options):
5 | account = None
6 | if token is None:
7 | token = 'None'
8 | else:
9 | if ':' in token:
10 | # token = authtoken:username:password
11 | token, username, password = token.split(':', 2)
12 | account = f"{username}:{password}"
13 |
14 | # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
15 | if not options.get('authtoken_from_env'):
16 | options['authtoken'] = token
17 | if account:
18 | options['basic_auth'] = account
19 | if not options.get('session_metadata'):
20 | options['session_metadata'] = 'stable-diffusion-webui'
21 |
22 |
23 | try:
24 | public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
25 | except Exception as e:
26 | print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
27 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
28 | else:
29 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
30 | 'You can use this link after the launch is complete.')
31 |
--------------------------------------------------------------------------------
/modules/extra_networks_hypernet.py:
--------------------------------------------------------------------------------
1 | from modules import extra_networks, shared
2 | from modules.hypernetworks import hypernetwork
3 |
4 |
5 | class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6 | def __init__(self):
7 | super().__init__('hypernet')
8 |
9 | def activate(self, p, params_list):
10 | additional = shared.opts.sd_hypernetwork
11 |
12 | if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
13 | hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
14 | p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
15 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
16 |
17 | names = []
18 | multipliers = []
19 | for params in params_list:
20 | assert len(params.items) > 0
21 |
22 | names.append(params.items[0])
23 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
24 |
25 | hypernetwork.load_hypernetworks(names, multipliers)
26 |
27 | def deactivate(self, p):
28 | pass
29 |
--------------------------------------------------------------------------------
/modules/ui_extra_networks_textual_inversion.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from modules import ui_extra_networks, sd_hijack, shared
5 |
6 |
7 | class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
8 | def __init__(self):
9 | super().__init__('Textual Inversion')
10 | self.allow_negative_prompt = True
11 |
12 | def refresh(self):
13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
14 |
15 | def list_items(self):
16 | for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
17 | path, ext = os.path.splitext(embedding.filename)
18 | yield {
19 | "name": embedding.name,
20 | "filename": embedding.filename,
21 | "preview": self.find_preview(path),
22 | "description": self.find_description(path),
23 | "search_term": self.search_terms_from_path(embedding.filename),
24 | "prompt": json.dumps(embedding.name),
25 | "local_preview": f"{path}.preview.{shared.opts.samples_format}",
26 | }
27 |
28 | def allowed_directories_for_previews(self):
29 | return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
30 |
--------------------------------------------------------------------------------
/test/test_extras.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 |
4 | def test_simple_upscaling_performed(base_url, img2img_basic_image_base64):
5 | payload = {
6 | "resize_mode": 0,
7 | "show_extras_results": True,
8 | "gfpgan_visibility": 0,
9 | "codeformer_visibility": 0,
10 | "codeformer_weight": 0,
11 | "upscaling_resize": 2,
12 | "upscaling_resize_w": 128,
13 | "upscaling_resize_h": 128,
14 | "upscaling_crop": True,
15 | "upscaler_1": "Lanczos",
16 | "upscaler_2": "None",
17 | "extras_upscaler_2_visibility": 0,
18 | "image": img2img_basic_image_base64,
19 | }
20 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
21 |
22 |
23 | def test_png_info_performed(base_url, img2img_basic_image_base64):
24 | payload = {
25 | "image": img2img_basic_image_base64,
26 | }
27 | assert requests.post(f"{base_url}/sdapi/v1/png-info", json=payload).status_code == 200
28 |
29 |
30 | def test_interrogate_performed(base_url, img2img_basic_image_base64):
31 | payload = {
32 | "image": img2img_basic_image_base64,
33 | "model": "clip",
34 | }
35 | assert requests.post(f"{base_url}/sdapi/v1/interrogate", json=payload).status_code == 200
36 |
--------------------------------------------------------------------------------
/modules/ui_extra_networks_checkpoints.py:
--------------------------------------------------------------------------------
1 | import html
2 | import json
3 | import os
4 |
5 | from modules import shared, ui_extra_networks, sd_models
6 |
7 |
8 | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
9 | def __init__(self):
10 | super().__init__('Checkpoints')
11 |
12 | def refresh(self):
13 | shared.refresh_checkpoints()
14 |
15 | def list_items(self):
16 | checkpoint: sd_models.CheckpointInfo
17 | for name, checkpoint in sd_models.checkpoints_list.items():
18 | path, ext = os.path.splitext(checkpoint.filename)
19 | yield {
20 | "name": checkpoint.name_for_extra,
21 | "filename": path,
22 | "preview": self.find_preview(path),
23 | "description": self.find_description(path),
24 | "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
25 | "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
26 | "local_preview": f"{path}.{shared.opts.samples_format}",
27 | }
28 |
29 | def allowed_directories_for_previews(self):
30 | return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
31 |
32 |
--------------------------------------------------------------------------------
/extensions-builtin/Lora/ui_extra_networks_lora.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import lora
4 |
5 | from modules import shared, ui_extra_networks
6 |
7 |
8 | class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
9 | def __init__(self):
10 | super().__init__('Lora')
11 |
12 | def refresh(self):
13 | lora.list_available_loras()
14 |
15 | def list_items(self):
16 | for name, lora_on_disk in lora.available_loras.items():
17 | path, ext = os.path.splitext(lora_on_disk.filename)
18 |
19 | alias = lora_on_disk.get_alias()
20 |
21 | yield {
22 | "name": name,
23 | "filename": path,
24 | "preview": self.find_preview(path),
25 | "description": self.find_description(path),
26 | "search_term": self.search_terms_from_path(lora_on_disk.filename),
27 | "prompt": json.dumps(f"<lora:{alias}:{shared.opts.extra_networks_default_multiplier}>"),
28 | "local_preview": f"{path}.{shared.opts.samples_format}",
29 | "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
30 | }
31 |
32 | def allowed_directories_for_previews(self):
33 | return [shared.cmd_opts.lora_dir]
34 |
35 |
--------------------------------------------------------------------------------
/modules/errors.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 |
4 |
5 | def print_error_explanation(message):
6 | lines = message.strip().split("\n")
7 | max_len = max([len(x) for x in lines])
8 |
9 | print('=' * max_len, file=sys.stderr)
10 | for line in lines:
11 | print(line, file=sys.stderr)
12 | print('=' * max_len, file=sys.stderr)
13 |
14 |
15 | def display(e: Exception, task):
16 | print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
17 | print(traceback.format_exc(), file=sys.stderr)
18 |
19 | message = str(e)
20 | if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
21 | print_error_explanation("""
22 | The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
23 | See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
24 | """)
25 |
26 |
27 | already_displayed = {}
28 |
29 |
30 | def display_once(e: Exception, task):
31 | if task in already_displayed:
32 | return
33 |
34 | display(e, task)
35 |
36 | already_displayed[task] = 1
37 |
38 |
39 | def run(code, task):
40 | try:
41 | code()
42 | except Exception as e:
43 | display(e, task)
44 |
--------------------------------------------------------------------------------
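errors.run() executes a callable and reports any exception through display() instead of letting it propagate, while display_once() suppresses repeat reports for the same task label. A small hypothetical caller:

from modules import errors

def flaky_setup():
    raise RuntimeError("boom")

errors.run(flaky_setup, "initializing optional component")  # prints the traceback, does not raise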
/modules/sd_hijack_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | class CondFunc:
4 | def __new__(cls, orig_func, sub_func, cond_func):
5 | self = super(CondFunc, cls).__new__(cls)
6 | if isinstance(orig_func, str):
7 | func_path = orig_func.split('.')
8 | for i in range(len(func_path)-1, -1, -1):
9 | try:
10 | resolved_obj = importlib.import_module('.'.join(func_path[:i]))
11 | break
12 | except ImportError:
13 | pass
14 | for attr_name in func_path[i:-1]:
15 | resolved_obj = getattr(resolved_obj, attr_name)
16 | orig_func = getattr(resolved_obj, func_path[-1])
17 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
18 | self.__init__(orig_func, sub_func, cond_func)
19 | return lambda *args, **kwargs: self(*args, **kwargs)
20 | def __init__(self, orig_func, sub_func, cond_func):
21 | self.__orig_func = orig_func
22 | self.__sub_func = sub_func
23 | self.__cond_func = cond_func
24 | def __call__(self, *args, **kwargs):
25 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
26 | return self.__sub_func(self.__orig_func, *args, **kwargs)
27 | else:
28 | return self.__orig_func(*args, **kwargs)
29 |
--------------------------------------------------------------------------------
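CondFunc monkey-patches a function in place: given a dotted path (or a callable), it swaps the attribute for a wrapper that calls sub_func(orig, ...) whenever cond_func(orig, ...) is truthy and falls through to the original otherwise. An illustrative sketch against a hypothetical target; the webui itself uses this to patch ldm/torch internals:

import math
from modules.sd_hijack_utils import CondFunc

CondFunc(
    'math.cos',             # dotted path to the function being wrapped
    lambda orig, x: 0.0,    # substitute; receives the original function as its first argument
    lambda orig, x: x < 0,  # condition: use the substitute only for negative inputs
)

print(math.cos(-1.0))  # 0.0, the substitute ran
print(math.cos(1.0))   # ~0.5403, the original ran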
/modules/paths_internal.py:
--------------------------------------------------------------------------------
1 | """this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
2 |
3 | import argparse
4 | import os
5 | import sys
6 | import shlex
7 |
8 | commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
9 | sys.argv += shlex.split(commandline_args)
10 |
11 | modules_path = os.path.dirname(os.path.realpath(__file__))
12 | script_path = os.path.dirname(modules_path)
13 |
14 | sd_configs_path = os.path.join(script_path, "configs")
15 | sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
16 | sd_model_file = os.path.join(script_path, 'model.ckpt')
17 | default_sd_model_file = sd_model_file
18 |
19 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values
20 | parser_pre = argparse.ArgumentParser(add_help=False)
21 | parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
22 | cmd_opts_pre = parser_pre.parse_known_args()[0]
23 |
24 | data_path = cmd_opts_pre.data_dir
25 |
26 | models_path = os.path.join(data_path, "models")
27 | extensions_dir = os.path.join(data_path, "extensions")
28 | extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
29 | config_states_dir = os.path.join(script_path, "config_states")
30 |
31 | roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
32 |
--------------------------------------------------------------------------------
/modules/sd_hijack_checkpoint.py:
--------------------------------------------------------------------------------
1 | from torch.utils.checkpoint import checkpoint
2 |
3 | import ldm.modules.attention
4 | import ldm.modules.diffusionmodules.openaimodel
5 |
6 |
7 | def BasicTransformerBlock_forward(self, x, context=None):
8 | return checkpoint(self._forward, x, context)
9 |
10 |
11 | def AttentionBlock_forward(self, x):
12 | return checkpoint(self._forward, x)
13 |
14 |
15 | def ResBlock_forward(self, x, emb):
16 | return checkpoint(self._forward, x, emb)
17 |
18 |
19 | stored = []
20 |
21 |
22 | def add():
23 | if len(stored) != 0:
24 | return
25 |
26 | stored.extend([
27 | ldm.modules.attention.BasicTransformerBlock.forward,
28 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
29 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
30 | ])
31 |
32 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
33 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
34 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
35 |
36 |
37 | def remove():
38 | if len(stored) == 0:
39 | return
40 |
41 | ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
42 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
43 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
44 |
45 | stored.clear()
46 |
47 |
--------------------------------------------------------------------------------
/modules/sd_hijack_open_clip.py:
--------------------------------------------------------------------------------
1 | import open_clip.tokenizer
2 | import torch
3 |
4 | from modules import sd_hijack_clip, devices
5 | from modules.shared import opts
6 |
7 | tokenizer = open_clip.tokenizer._tokenizer
8 |
9 |
10 | class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
11 | def __init__(self, wrapped, hijack):
12 | super().__init__(wrapped, hijack)
13 |
14 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
15 | self.id_start = tokenizer.encoder["<start_of_text>"]
16 | self.id_end = tokenizer.encoder["<end_of_text>"]
17 | self.id_pad = 0
18 |
19 | def tokenize(self, texts):
20 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
21 |
22 | tokenized = [tokenizer.encode(text) for text in texts]
23 |
24 | return tokenized
25 |
26 | def encode_with_transformers(self, tokens):
27 | # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
28 | z = self.wrapped.encode_with_transformer(tokens)
29 |
30 | return z
31 |
32 | def encode_embedding_init_text(self, init_text, nvpt):
33 | ids = tokenizer.encode(init_text)
34 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
35 | embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
36 |
37 | return embedded
38 |
--------------------------------------------------------------------------------
/javascript/imageMaskFix.js:
--------------------------------------------------------------------------------
1 | /**
2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
3 | * @see https://github.com/gradio-app/gradio/issues/1721
4 | */
5 | function imageMaskResize() {
6 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
7 | if (!canvases.length) {
8 | window.removeEventListener('resize', imageMaskResize);
9 | return;
10 | }
11 |
12 | const wrapper = canvases[0].closest('.touch-none');
13 | const previewImage = wrapper.previousElementSibling;
14 |
15 | if (!previewImage.complete) {
16 | previewImage.addEventListener('load', imageMaskResize);
17 | return;
18 | }
19 |
20 | const w = previewImage.width;
21 | const h = previewImage.height;
22 | const nw = previewImage.naturalWidth;
23 | const nh = previewImage.naturalHeight;
24 | const portrait = nh > nw;
25 |
26 | const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
27 | const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
28 |
29 | wrapper.style.width = `${wW}px`;
30 | wrapper.style.height = `${wH}px`;
31 | wrapper.style.left = `0px`;
32 | wrapper.style.top = `0px`;
33 |
34 | canvases.forEach(c => {
35 | c.style.width = c.style.height = '';
36 | c.style.maxWidth = '100%';
37 | c.style.maxHeight = '100%';
38 | c.style.objectFit = 'contain';
39 | });
40 | }
41 |
42 | onUiUpdate(imageMaskResize);
43 | window.addEventListener('resize', imageMaskResize);
44 |
--------------------------------------------------------------------------------
/modules/textual_inversion/logging.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "latent_sampling_method"}
6 | saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
7 | saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
8 | saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
9 | saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
10 |
11 |
12 | def save_settings_to_file(log_directory, all_params):
13 | now = datetime.datetime.now()
14 | params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
15 |
16 | keys = saved_params_all
17 | if all_params.get('preview_from_txt2img'):
18 | keys = keys | saved_params_previews
19 |
20 | params.update({k: v for k, v in all_params.items() if k in keys})
21 |
22 | filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
23 | with open(os.path.join(log_directory, filename), "w") as file:
24 | json.dump(params, file, indent=4)
25 |
--------------------------------------------------------------------------------
/modules/sd_hijack_xlmr.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from modules import sd_hijack_clip, devices
4 |
5 |
6 | class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
7 | def __init__(self, wrapped, hijack):
8 | super().__init__(wrapped, hijack)
9 |
10 | self.id_start = wrapped.config.bos_token_id
11 | self.id_end = wrapped.config.eos_token_id
12 | self.id_pad = wrapped.config.pad_token_id
13 |
14 |         self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma
15 |
16 | def encode_with_transformers(self, tokens):
17 | # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
18 | # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
19 | # layer to work with - you have to use the last
20 |
21 | attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
22 | features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
23 | z = features['projection_state']
24 |
25 | return z
26 |
27 | def encode_embedding_init_text(self, init_text, nvpt):
28 | embedding_layer = self.wrapped.roberta.embeddings
29 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
30 | embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
31 |
32 | return embedded
33 |
--------------------------------------------------------------------------------
/webui-user.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #########################################################
3 | # Uncomment and change the variables below to your need:#
4 | #########################################################
5 |
6 | # Install directory without trailing slash
7 | #install_dir="/home/$(whoami)"
8 |
9 | # Name of the subdirectory
10 | #clone_dir="stable-diffusion-webui"
11 |
12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
13 | #export COMMANDLINE_ARGS=""
14 |
15 | # python3 executable
16 | #python_cmd="python3"
17 |
18 | # git executable
19 | #export GIT="git"
20 |
21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
22 | #venv_dir="venv"
23 |
24 | # script to launch to start the app
25 | #export LAUNCH_SCRIPT="launch.py"
26 |
27 | # install command for torch
28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
29 |
30 | # Requirements file to use for stable-diffusion-webui
31 | #export REQS_FILE="requirements_versions.txt"
32 |
33 | # Fixed git repos
34 | #export K_DIFFUSION_PACKAGE=""
35 | #export GFPGAN_PACKAGE=""
36 |
37 | # Fixed git commits
38 | #export STABLE_DIFFUSION_COMMIT_HASH=""
39 | #export TAMING_TRANSFORMERS_COMMIT_HASH=""
40 | #export CODEFORMER_COMMIT_HASH=""
41 | #export BLIP_COMMIT_HASH=""
42 |
43 | # Uncomment to enable accelerated launch
44 | #export ACCELERATE="True"
45 |
46 | # Uncomment to disable TCMalloc
47 | #export NO_TCMALLOC="True"
48 |
49 | ###########################################
50 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Suggest an idea for this project
3 | title: "[Feature Request]: "
4 | labels: ["enhancement"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
18 | - type: textarea
19 | id: feature
20 | attributes:
21 |       label: What would your feature do?
22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: workflow
27 | attributes:
28 | label: Proposed workflow
29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: misc
38 | attributes:
39 | label: Additional information
40 | description: Add any other context or screenshots about the feature request here.
41 |
--------------------------------------------------------------------------------
/scripts/postprocessing_codeformer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | from modules import scripts_postprocessing, codeformer_model
5 | import gradio as gr
6 |
7 | from modules.ui_components import FormRow
8 |
9 |
10 | class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
11 | name = "CodeFormer"
12 | order = 3000
13 |
14 | def ui(self):
15 | with FormRow():
16 | codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
17 | codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
18 |
19 | return {
20 | "codeformer_visibility": codeformer_visibility,
21 | "codeformer_weight": codeformer_weight,
22 | }
23 |
24 | def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
25 | if codeformer_visibility == 0:
26 | return
27 |
28 | restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
29 | res = Image.fromarray(restored_img)
30 |
31 | if codeformer_visibility < 1.0:
32 | res = Image.blend(pp.image, res, codeformer_visibility)
33 |
34 | pp.image = res
35 | pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
36 | pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
37 |
--------------------------------------------------------------------------------
/javascript/generationParams.js:
--------------------------------------------------------------------------------
1 | // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
2 |
3 | let txt2img_gallery, img2img_gallery, modal = undefined;
4 | onUiUpdate(function() {
5 | if (!txt2img_gallery) {
6 | txt2img_gallery = attachGalleryListeners("txt2img");
7 | }
8 | if (!img2img_gallery) {
9 | img2img_gallery = attachGalleryListeners("img2img");
10 | }
11 | if (!modal) {
12 | modal = gradioApp().getElementById('lightboxModal');
13 | modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
14 | }
15 | });
16 |
17 | let modalObserver = new MutationObserver(function(mutations) {
18 | mutations.forEach(function(mutationRecord) {
19 | let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
20 | if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
21 | gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
22 | }
23 | });
24 | });
25 |
26 | function attachGalleryListeners(tab_name) {
27 | var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
28 | gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
29 | gallery?.addEventListener('keydown', (e) => {
30 | if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
31 | gradioApp().getElementById(tab_name + "_generation_info_button").click();
32 | }
33 | });
34 | return gallery;
35 | }
36 |
--------------------------------------------------------------------------------
/modules/hypernetworks/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 |
3 | import gradio as gr
4 | import modules.hypernetworks.hypernetwork
5 | from modules import devices, sd_hijack, shared
6 |
7 | not_available = ["hardswish", "multiheadattention"]
8 | keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
9 |
10 |
11 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
12 | filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
13 |
14 | return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
15 |
16 |
17 | def train_hypernetwork(*args):
18 | shared.loaded_hypernetworks = []
19 |
20 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
21 |
22 | try:
23 | sd_hijack.undo_optimizations()
24 |
25 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
26 |
27 | res = f"""
28 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
29 | Hypernetwork saved to {html.escape(filename)}
30 | """
31 | return res, ""
32 | except Exception:
33 | raise
34 | finally:
35 | shared.sd_model.cond_stage_model.to(devices.device)
36 | shared.sd_model.first_stage_model.to(devices.device)
37 | sd_hijack.apply_optimizations()
38 |
39 |
--------------------------------------------------------------------------------
/modules/scripts_auto_postprocessing.py:
--------------------------------------------------------------------------------
1 | from modules import scripts, scripts_postprocessing, shared
2 |
3 |
4 | class ScriptPostprocessingForMainUI(scripts.Script):
5 | def __init__(self, script_postproc):
6 | self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
7 | self.postprocessing_controls = None
8 |
9 | def title(self):
10 | return self.script.name
11 |
12 | def show(self, is_img2img):
13 | return scripts.AlwaysVisible
14 |
15 | def ui(self, is_img2img):
16 | self.postprocessing_controls = self.script.ui()
17 | return self.postprocessing_controls.values()
18 |
19 | def postprocess_image(self, p, script_pp, *args):
20 | args_dict = dict(zip(self.postprocessing_controls, args))
21 |
22 | pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
23 | pp.info = {}
24 | self.script.process(pp, **args_dict)
25 | p.extra_generation_params.update(pp.info)
26 | script_pp.image = pp.image
27 |
28 |
29 | def create_auto_preprocessing_script_data():
30 | from modules import scripts
31 |
32 | res = []
33 |
34 | for name in shared.opts.postprocessing_enable_in_main_ui:
35 | script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
36 | if script is None:
37 | continue
38 |
39 | constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
40 | res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
41 |
42 | return res
43 |
--------------------------------------------------------------------------------
/modules/textual_inversion/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 |
3 | import gradio as gr
4 |
5 | import modules.textual_inversion.textual_inversion
6 | import modules.textual_inversion.preprocess
7 | from modules import sd_hijack, shared
8 |
9 |
10 | def create_embedding(name, initialization_text, nvpt, overwrite_old):
11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
12 |
13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
14 |
15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
16 |
17 |
18 | def preprocess(*args):
19 | modules.textual_inversion.preprocess.preprocess(*args)
20 |
21 | return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
22 |
23 |
24 | def train_embedding(*args):
25 |
26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
27 |
28 | apply_optimizations = shared.opts.training_xattention_optimizations
29 | try:
30 | if not apply_optimizations:
31 | sd_hijack.undo_optimizations()
32 |
33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
34 |
35 | res = f"""
36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
37 | Embedding saved to {html.escape(filename)}
38 | """
39 | return res, ""
40 | except Exception:
41 | raise
42 | finally:
43 | if not apply_optimizations:
44 | sd_hijack.apply_optimizations()
45 |
46 |
--------------------------------------------------------------------------------
/modules/sd_samplers.py:
--------------------------------------------------------------------------------
1 | from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
2 |
3 | # imports for functions that previously were here and are used by other modules
4 | from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
5 |
6 | all_samplers = [
7 | *sd_samplers_kdiffusion.samplers_data_k_diffusion,
8 | *sd_samplers_compvis.samplers_data_compvis,
9 | ]
10 | all_samplers_map = {x.name: x for x in all_samplers}
11 |
12 | samplers = []
13 | samplers_for_img2img = []
14 | samplers_map = {}
15 |
16 |
17 | def find_sampler_config(name):
18 | if name is not None:
19 | config = all_samplers_map.get(name, None)
20 | else:
21 | config = all_samplers[0]
22 |
23 | return config
24 |
25 |
26 | def create_sampler(name, model):
27 | config = find_sampler_config(name)
28 |
29 | assert config is not None, f'bad sampler name: {name}'
30 |
31 | sampler = config.constructor(model)
32 | sampler.config = config
33 |
34 | return sampler
35 |
36 |
37 | def set_samplers():
38 | global samplers, samplers_for_img2img
39 |
40 | hidden = set(shared.opts.hide_samplers)
41 | hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
42 |
43 | samplers = [x for x in all_samplers if x.name not in hidden]
44 | samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
45 |
46 | samplers_map.clear()
47 | for sampler in all_samplers:
48 | samplers_map[sampler.name.lower()] = sampler.name
49 | for alias in sampler.aliases:
50 | samplers_map[alias.lower()] = sampler.name
51 |
52 |
53 | set_samplers()
54 |
--------------------------------------------------------------------------------
/javascript/notification.js:
--------------------------------------------------------------------------------
1 | // Monitors the gallery and sends a browser notification when the leading image is new.
2 |
3 | let lastHeadImg = null;
4 |
5 | let notificationButton = null;
6 |
7 | onUiUpdate(function() {
8 | if (notificationButton == null) {
9 | notificationButton = gradioApp().getElementById('request_notifications');
10 |
11 | if (notificationButton != null) {
12 | notificationButton.addEventListener('click', () => {
13 | void Notification.requestPermission();
14 | }, true);
15 | }
16 | }
17 |
18 | const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
19 |
20 | if (galleryPreviews == null) return;
21 |
22 | const headImg = galleryPreviews[0]?.src;
23 |
24 | if (headImg == null || headImg == lastHeadImg) return;
25 |
26 | lastHeadImg = headImg;
27 |
28 | // play notification sound if available
29 | gradioApp().querySelector('#audio_notification audio')?.play();
30 |
31 | if (document.hasFocus()) return;
32 |
33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
35 |
36 | const notification = new Notification(
37 | 'Stable Diffusion',
38 | {
39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
40 | icon: headImg,
41 | image: headImg,
42 | }
43 | );
44 |
45 | notification.onclick = function(_) {
46 | parent.focus();
47 | this.close();
48 | };
49 | });
50 |
--------------------------------------------------------------------------------
/extensions-builtin/Lora/extra_networks_lora.py:
--------------------------------------------------------------------------------
1 | from modules import extra_networks, shared
2 | import lora
3 |
4 |
5 | class ExtraNetworkLora(extra_networks.ExtraNetwork):
6 | def __init__(self):
7 | super().__init__('lora')
8 |
9 | def activate(self, p, params_list):
10 | additional = shared.opts.sd_lora
11 |
12 | if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
13 |             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
14 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
15 |
16 | names = []
17 | multipliers = []
18 | for params in params_list:
19 | assert len(params.items) > 0
20 |
21 | names.append(params.items[0])
22 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
23 |
24 | lora.load_loras(names, multipliers)
25 |
26 | if shared.opts.lora_add_hashes_to_infotext:
27 | lora_hashes = []
28 | for item in lora.loaded_loras:
29 | shorthash = item.lora_on_disk.shorthash
30 | if not shorthash:
31 | continue
32 |
33 | alias = item.mentioned_name
34 | if not alias:
35 | continue
36 |
37 | alias = alias.replace(":", "").replace(",", "")
38 |
39 | lora_hashes.append(f"{alias}: {shorthash}")
40 |
41 | if lora_hashes:
42 | p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
43 |
44 | def deactivate(self, p):
45 | pass
46 |
--------------------------------------------------------------------------------
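The activate() method above consumes parameters parsed from <lora:name:multiplier> tags in the prompt. A minimal sketch of how names and multipliers are split out, using a stand-in params class (illustrative only; the real class is modules.extra_networks.ExtraNetworkParams):

class Params:
    # Stand-in for extra_networks.ExtraNetworkParams: items[0] is the lora name, items[1] the optional multiplier.
    def __init__(self, items):
        self.items = items

params_list = [Params(["add_detail", "0.6"]), Params(["film_grain"])]

names = [p.items[0] for p in params_list]
multipliers = [float(p.items[1]) if len(p.items) > 1 else 1.0 for p in params_list]
print(names, multipliers)  # ['add_detail', 'film_grain'] [0.6, 1.0]
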
/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js:
--------------------------------------------------------------------------------
1 | // Stable Diffusion WebUI - Bracket checker
2 | // By Hingashi no Florin/Bwin4L & @akx
3 | // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
4 | // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
5 |
6 | function checkBrackets(textArea, counterElt) {
7 | var counts = {};
8 | (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
9 | counts[bracket] = (counts[bracket] || 0) + 1;
10 | });
11 | var errors = [];
12 |
13 | function checkPair(open, close, kind) {
14 | if (counts[open] !== counts[close]) {
15 | errors.push(
16 | `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
17 | );
18 | }
19 | }
20 |
21 | checkPair('(', ')', 'round brackets');
22 | checkPair('[', ']', 'square brackets');
23 | checkPair('{', '}', 'curly brackets');
24 | counterElt.title = errors.join('\n');
25 | counterElt.classList.toggle('error', errors.length !== 0);
26 | }
27 |
28 | function setupBracketChecking(id_prompt, id_counter) {
29 | var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
30 | var counter = gradioApp().getElementById(id_counter);
31 |
32 | if (textarea && counter) {
33 | textarea.addEventListener("input", () => checkBrackets(textarea, counter));
34 | }
35 | }
36 |
37 | onUiLoaded(function() {
38 | setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
39 | setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
40 | setupBracketChecking('img2img_prompt', 'img2img_token_counter');
41 | setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
42 | });
43 |
--------------------------------------------------------------------------------
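Note that the checker above only compares totals per bracket type, so it flags missing brackets but not crossed nesting. A minimal Python sketch of the same count-based check (names here are illustrative, not part of the repository):

import re
from collections import Counter

def bracket_errors(text):
    # Count every bracket character, mirroring the regex used in checkBrackets().
    counts = Counter(re.findall(r"[(){}\[\]]", text))
    errors = []
    for open_ch, close_ch, kind in [("(", ")", "round"), ("[", "]", "square"), ("{", "}", "curly")]:
        if counts[open_ch] != counts[close_ch]:
            errors.append(f"{open_ch}...{close_ch} - Detected {counts[open_ch]} opening and {counts[close_ch]} closing {kind} brackets.")
    return errors

print(bracket_errors("a (masterpiece:1.2], best quality"))  # one round and one square mismatch
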
/javascript/imageviewerGamepad.js:
--------------------------------------------------------------------------------
1 | window.addEventListener('gamepadconnected', (e) => {
2 | const index = e.gamepad.index;
3 | let isWaiting = false;
4 | setInterval(async() => {
5 | if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
6 | const gamepad = navigator.getGamepads()[index];
7 | const xValue = gamepad.axes[0];
8 | if (xValue <= -0.3) {
9 | modalPrevImage(e);
10 | isWaiting = true;
11 | } else if (xValue >= 0.3) {
12 | modalNextImage(e);
13 | isWaiting = true;
14 | }
15 | if (isWaiting) {
16 | await sleepUntil(() => {
17 | const xValue = navigator.getGamepads()[index].axes[0];
18 | if (xValue < 0.3 && xValue > -0.3) {
19 | return true;
20 | }
21 | }, opts.js_modal_lightbox_gamepad_repeat);
22 | isWaiting = false;
23 | }
24 | }, 10);
25 | });
26 |
27 | /*
28 | Primarily for vr controller type pointer devices.
29 | I use the wheel event because there's currently no way to do it properly with web xr.
30 | */
31 | let isScrolling = false;
32 | window.addEventListener('wheel', (e) => {
33 | if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
34 | isScrolling = true;
35 |
36 | if (e.deltaX <= -0.6) {
37 | modalPrevImage(e);
38 | } else if (e.deltaX >= 0.6) {
39 | modalNextImage(e);
40 | }
41 |
42 | setTimeout(() => {
43 | isScrolling = false;
44 | }, opts.js_modal_lightbox_gamepad_repeat);
45 | });
46 |
47 | function sleepUntil(f, timeout) {
48 | return new Promise((resolve) => {
49 | const timeStart = new Date();
50 | const wait = setInterval(function() {
51 | if (f() || new Date() - timeStart > timeout) {
52 | clearInterval(wait);
53 | resolve();
54 | }
55 | }, 20);
56 | });
57 | }
58 |
--------------------------------------------------------------------------------
/configs/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/modules/paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
4 |
5 | import modules.safe # noqa: F401
6 |
7 |
8 | # data_path = cmd_opts_pre.data
9 | sys.path.insert(0, script_path)
10 |
11 | # search for directory of stable diffusion in following places
12 | sd_path = None
13 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
14 | for possible_sd_path in possible_sd_paths:
15 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
16 | sd_path = os.path.abspath(possible_sd_path)
17 | break
18 |
19 | assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
20 |
21 | path_dirs = [
22 | (sd_path, 'ldm', 'Stable Diffusion', []),
23 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
24 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
25 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
26 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
27 | ]
28 |
29 | paths = {}
30 |
31 | for d, must_exist, what, options in path_dirs:
32 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
33 | if not os.path.exists(must_exist_path):
34 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
35 | else:
36 | d = os.path.abspath(d)
37 | if "atstart" in options:
38 | sys.path.insert(0, d)
39 | else:
40 | sys.path.append(d)
41 | paths[what] = d
42 |
43 |
44 | class Prioritize:
45 | def __init__(self, name):
46 | self.name = name
47 | self.path = None
48 |
49 | def __enter__(self):
50 | self.path = sys.path.copy()
51 | sys.path = [paths[self.name]] + sys.path
52 |
53 | def __exit__(self, exc_type, exc_val, exc_tb):
54 | sys.path = self.path
55 | self.path = None
56 |
--------------------------------------------------------------------------------
/modules/ui_components.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 |
4 | class FormComponent:
5 | def get_expected_parent(self):
6 | return gr.components.Form
7 |
8 |
9 | gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
10 |
11 |
12 | class ToolButton(FormComponent, gr.Button):
13 | """Small button with single emoji as text, fits inside gradio forms"""
14 |
15 | def __init__(self, *args, **kwargs):
16 | classes = kwargs.pop("elem_classes", [])
17 | super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
18 |
19 | def get_block_name(self):
20 | return "button"
21 |
22 |
23 | class FormRow(FormComponent, gr.Row):
24 | """Same as gr.Row but fits inside gradio forms"""
25 |
26 | def get_block_name(self):
27 | return "row"
28 |
29 |
30 | class FormColumn(FormComponent, gr.Column):
31 | """Same as gr.Column but fits inside gradio forms"""
32 |
33 | def get_block_name(self):
34 | return "column"
35 |
36 |
37 | class FormGroup(FormComponent, gr.Group):
38 | """Same as gr.Row but fits inside gradio forms"""
39 |
40 | def get_block_name(self):
41 | return "group"
42 |
43 |
44 | class FormHTML(FormComponent, gr.HTML):
45 | """Same as gr.HTML but fits inside gradio forms"""
46 |
47 | def get_block_name(self):
48 | return "html"
49 |
50 |
51 | class FormColorPicker(FormComponent, gr.ColorPicker):
52 | """Same as gr.ColorPicker but fits inside gradio forms"""
53 |
54 | def get_block_name(self):
55 | return "colorpicker"
56 |
57 |
58 | class DropdownMulti(FormComponent, gr.Dropdown):
59 | """Same as gr.Dropdown but always multiselect"""
60 | def __init__(self, **kwargs):
61 | super().__init__(multiselect=True, **kwargs)
62 |
63 | def get_block_name(self):
64 | return "dropdown"
65 |
66 |
67 | class DropdownEditable(FormComponent, gr.Dropdown):
68 | """Same as gr.Dropdown but allows editing value"""
69 | def __init__(self, **kwargs):
70 | super().__init__(allow_custom_value=True, **kwargs)
71 |
72 | def get_block_name(self):
73 | return "dropdown"
74 |
75 |
--------------------------------------------------------------------------------
/configs/alt-diffusion-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: modules.xlmr.BertSeriesModelWithTransformation
71 | params:
72 | name: "XLMR-Large"
--------------------------------------------------------------------------------
/modules/sd_vae_approx.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from modules import devices, paths
6 |
7 | sd_vae_approx_model = None
8 |
9 |
10 | class VAEApprox(nn.Module):
11 | def __init__(self):
12 | super(VAEApprox, self).__init__()
13 | self.conv1 = nn.Conv2d(4, 8, (7, 7))
14 | self.conv2 = nn.Conv2d(8, 16, (5, 5))
15 | self.conv3 = nn.Conv2d(16, 32, (3, 3))
16 | self.conv4 = nn.Conv2d(32, 64, (3, 3))
17 | self.conv5 = nn.Conv2d(64, 32, (3, 3))
18 | self.conv6 = nn.Conv2d(32, 16, (3, 3))
19 | self.conv7 = nn.Conv2d(16, 8, (3, 3))
20 | self.conv8 = nn.Conv2d(8, 3, (3, 3))
21 |
22 | def forward(self, x):
23 | extra = 11
24 | x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
25 | x = nn.functional.pad(x, (extra, extra, extra, extra))
26 |
27 | for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
28 | x = layer(x)
29 | x = nn.functional.leaky_relu(x, 0.1)
30 |
31 | return x
32 |
33 |
34 | def model():
35 | global sd_vae_approx_model
36 |
37 | if sd_vae_approx_model is None:
38 | model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
39 | sd_vae_approx_model = VAEApprox()
40 | if not os.path.exists(model_path):
41 | model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
42 | sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
43 | sd_vae_approx_model.eval()
44 | sd_vae_approx_model.to(devices.device, devices.dtype)
45 |
46 | return sd_vae_approx_model
47 |
48 |
49 | def cheap_approximation(sample):
50 | # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
51 |
52 | coefs = torch.tensor([
53 | [0.298, 0.207, 0.208],
54 | [0.187, 0.286, 0.173],
55 | [-0.158, 0.189, 0.264],
56 | [-0.184, -0.271, -0.473],
57 | ]).to(sample.device)
58 |
59 | x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
60 |
61 | return x_sample
62 |
--------------------------------------------------------------------------------
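For reference, the linear latent-to-RGB trick used by cheap_approximation() above can be exercised on its own. A minimal sketch with a randomly generated latent (the 4x64x64 size and the variable names are arbitrary assumptions for illustration):

import torch

# Same per-channel coefficients as cheap_approximation(); rows are the 4 latent channels, columns are R, G, B.
coefs = torch.tensor([
    [0.298, 0.207, 0.208],
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(4, 64, 64)                     # stand-in for a single latent sample
rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)  # -> [3, 64, 64]
image = ((rgb + 1.0) / 2.0).clamp(0, 1)             # map to [0, 1] for display
print(image.shape)
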
/configs/v1-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | finetune_keys: null
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run basic features tests on CPU with empty SD model
2 |
3 | on:
4 | - push
5 | - pull_request
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Checkout Code
12 | uses: actions/checkout@v3
13 | - name: Set up Python 3.10
14 | uses: actions/setup-python@v4
15 | with:
16 | python-version: 3.10.6
17 | cache: pip
18 | cache-dependency-path: |
19 | **/requirements*txt
20 | launch.py
21 | - name: Install test dependencies
22 | run: pip install wait-for-it -r requirements-test.txt
23 | env:
24 | PIP_DISABLE_PIP_VERSION_CHECK: "1"
25 | PIP_PROGRESS_BAR: "off"
26 | - name: Setup environment
27 | run: python launch.py --skip-torch-cuda-test --exit
28 | env:
29 | PIP_DISABLE_PIP_VERSION_CHECK: "1"
30 | PIP_PROGRESS_BAR: "off"
31 | TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
32 | WEBUI_LAUNCH_LIVE_OUTPUT: "1"
33 | PYTHONUNBUFFERED: "1"
34 | - name: Start test server
35 | run: >
36 | python -m coverage run
37 | --data-file=.coverage.server
38 | launch.py
39 | --skip-prepare-environment
40 | --skip-torch-cuda-test
41 | --test-server
42 | --no-half
43 | --disable-opt-split-attention
44 | --use-cpu all
45 | --add-stop-route
46 | 2>&1 | tee output.txt &
47 | - name: Run tests
48 | run: |
49 | wait-for-it --service 127.0.0.1:7860 -t 600
50 | python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
51 | - name: Kill test server
52 | if: always()
53 | run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
54 | - name: Show coverage
55 | run: |
56 | python -m coverage combine .coverage*
57 | python -m coverage report -i
58 | python -m coverage html -i
59 | - name: Upload main app output
60 | uses: actions/upload-artifact@v3
61 | if: always()
62 | with:
63 | name: output
64 | path: output.txt
65 | - name: Upload coverage HTML
66 | uses: actions/upload-artifact@v3
67 | if: always()
68 | with:
69 | name: htmlcov
70 | path: htmlcov
71 |
--------------------------------------------------------------------------------
/test/test_img2img.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import requests
4 |
5 |
6 | @pytest.fixture()
7 | def url_img2img(base_url):
8 | return f"{base_url}/sdapi/v1/img2img"
9 |
10 |
11 | @pytest.fixture()
12 | def simple_img2img_request(img2img_basic_image_base64):
13 | return {
14 | "batch_size": 1,
15 | "cfg_scale": 7,
16 | "denoising_strength": 0.75,
17 | "eta": 0,
18 | "height": 64,
19 | "include_init_images": False,
20 | "init_images": [img2img_basic_image_base64],
21 | "inpaint_full_res": False,
22 | "inpaint_full_res_padding": 0,
23 | "inpainting_fill": 0,
24 | "inpainting_mask_invert": False,
25 | "mask": None,
26 | "mask_blur": 4,
27 | "n_iter": 1,
28 | "negative_prompt": "",
29 | "override_settings": {},
30 | "prompt": "example prompt",
31 | "resize_mode": 0,
32 | "restore_faces": False,
33 | "s_churn": 0,
34 | "s_noise": 1,
35 | "s_tmax": 0,
36 | "s_tmin": 0,
37 | "sampler_index": "Euler a",
38 | "seed": -1,
39 | "seed_resize_from_h": -1,
40 | "seed_resize_from_w": -1,
41 | "steps": 3,
42 | "styles": [],
43 | "subseed": -1,
44 | "subseed_strength": 0,
45 | "tiling": False,
46 | "width": 64,
47 | }
48 |
49 |
50 | def test_img2img_simple_performed(url_img2img, simple_img2img_request):
51 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
52 |
53 |
54 | def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
55 | simple_img2img_request["mask"] = mask_basic_image_base64
56 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
57 |
58 |
59 | def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
60 | simple_img2img_request["mask"] = mask_basic_image_base64
61 | simple_img2img_request["inpainting_mask_invert"] = True
62 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
63 |
64 |
65 | def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request):
66 | simple_img2img_request["script_name"] = "sd upscale"
67 | simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0]
68 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
69 |
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if not defined PYTHON (set PYTHON=python)
4 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
5 |
6 |
7 | set ERROR_REPORTING=FALSE
8 |
9 | mkdir tmp 2>NUL
10 |
11 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
12 | if %ERRORLEVEL% == 0 goto :check_pip
13 | echo Couldn't launch python
14 | goto :show_stdout_stderr
15 |
16 | :check_pip
17 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
18 | if %ERRORLEVEL% == 0 goto :start_venv
19 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
20 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
21 | if %ERRORLEVEL% == 0 goto :start_venv
22 | echo Couldn't install pip
23 | goto :show_stdout_stderr
24 |
25 | :start_venv
26 | if ["%VENV_DIR%"] == ["-"] goto :skip_venv
27 | if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
28 |
29 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
30 | if %ERRORLEVEL% == 0 goto :activate_venv
31 |
32 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
33 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
34 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
35 | if %ERRORLEVEL% == 0 goto :activate_venv
36 | echo Unable to create venv in directory "%VENV_DIR%"
37 | goto :show_stdout_stderr
38 |
39 | :activate_venv
40 | set PYTHON="%VENV_DIR%\Scripts\Python.exe"
41 | echo venv %PYTHON%
42 |
43 | :skip_venv
44 | if [%ACCELERATE%] == ["True"] goto :accelerate
45 | goto :launch
46 |
47 | :accelerate
48 | echo Checking for accelerate
49 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
50 | if EXIST %ACCELERATE% goto :accelerate_launch
51 |
52 | :launch
53 | %PYTHON% launch.py %*
54 | pause
55 | exit /b
56 |
57 | :accelerate_launch
58 | echo Accelerating
59 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
60 | pause
61 | exit /b
62 |
63 | :show_stdout_stderr
64 |
65 | echo.
66 | echo exit code: %errorlevel%
67 |
68 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
69 | if %size% equ 0 goto :show_stderr
70 | echo.
71 | echo stdout:
72 | type tmp\stdout.txt
73 |
74 | :show_stderr
75 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
76 | if %size% equ 0 goto :endofscript
77 | echo.
78 | echo stderr:
79 | type tmp\stderr.txt
80 |
81 | :endofscript
82 |
83 | echo.
84 | echo Launch unsuccessful. Exiting.
85 | pause
86 |
--------------------------------------------------------------------------------
/javascript/ui_settings_hints.js:
--------------------------------------------------------------------------------
1 | // various hints and extra info for the settings tab
2 |
3 | var settingsHintsSetup = false;
4 |
5 | onOptionsChanged(function() {
6 | if (settingsHintsSetup) return;
7 | settingsHintsSetup = true;
8 |
9 | gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
10 | var name = div.id.substr(8);
11 | var commentBefore = opts._comments_before[name];
12 | var commentAfter = opts._comments_after[name];
13 |
14 | if (!commentBefore && !commentAfter) return;
15 |
16 | var span = null;
17 | if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
18 | else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
19 | else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
20 | else span = div.querySelector('label span').firstChild;
21 |
22 | if (!span) return;
23 |
24 | if (commentBefore) {
25 | var comment = document.createElement('DIV');
26 | comment.className = 'settings-comment';
27 | comment.innerHTML = commentBefore;
28 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
29 | span.parentElement.insertBefore(comment, span);
30 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
31 | }
32 | if (commentAfter) {
33 | comment = document.createElement('DIV');
34 | comment.className = 'settings-comment';
35 | comment.innerHTML = commentAfter;
36 | span.parentElement.insertBefore(comment, span.nextSibling);
37 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
38 | }
39 | });
40 | });
41 |
42 | function settingsHintsShowQuicksettings() {
43 | requestGet("./internal/quicksettings-hint", {}, function(data) {
44 | var table = document.createElement('table');
45 | table.className = 'settings-value-table';
46 |
47 | data.forEach(function(obj) {
48 | var tr = document.createElement('tr');
49 | var td = document.createElement('td');
50 | td.textContent = obj.name;
51 | tr.appendChild(td);
52 |
53 | td = document.createElement('td');
54 | td.textContent = obj.label;
55 | tr.appendChild(td);
56 |
57 | table.appendChild(tr);
58 | });
59 |
60 | popup(table);
61 | });
62 | }
63 |
--------------------------------------------------------------------------------
/javascript/extensions.js:
--------------------------------------------------------------------------------
1 |
2 | function extensions_apply(_disabled_list, _update_list, disable_all) {
3 | var disable = [];
4 | var update = [];
5 |
6 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
7 | if (x.name.startsWith("enable_") && !x.checked) {
8 | disable.push(x.name.substring(7));
9 | }
10 |
11 | if (x.name.startsWith("update_") && x.checked) {
12 | update.push(x.name.substring(7));
13 | }
14 | });
15 |
16 | restart_reload();
17 |
18 | return [JSON.stringify(disable), JSON.stringify(update), disable_all];
19 | }
20 |
21 | function extensions_check() {
22 | var disable = [];
23 |
24 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
25 | if (x.name.startsWith("enable_") && !x.checked) {
26 | disable.push(x.name.substring(7));
27 | }
28 | });
29 |
30 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
31 | x.innerHTML = "Loading...";
32 | });
33 |
34 |
35 | var id = randomId();
36 | requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
37 |
38 | });
39 |
40 | return [id, JSON.stringify(disable)];
41 | }
42 |
43 | function install_extension_from_index(button, url) {
44 | button.disabled = "disabled";
45 | button.value = "Installing...";
46 |
47 | var textarea = gradioApp().querySelector('#extension_to_install textarea');
48 | textarea.value = url;
49 | updateInput(textarea);
50 |
51 | gradioApp().querySelector('#install_extension_button').click();
52 | }
53 |
54 | function config_state_confirm_restore(_, config_state_name, config_restore_type) {
55 | if (config_state_name == "Current") {
56 | return [false, config_state_name, config_restore_type];
57 | }
58 | let restored = "";
59 | if (config_restore_type == "extensions") {
60 | restored = "all saved extension versions";
61 | } else if (config_restore_type == "webui") {
62 | restored = "the webui version";
63 | } else {
64 | restored = "the webui version and all saved extension versions";
65 | }
66 | let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
67 | if (confirmed) {
68 | restart_reload();
69 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
70 | x.innerHTML = "Loading...";
71 | });
72 | }
73 | return [confirmed, config_state_name, config_restore_type];
74 | }
75 |
--------------------------------------------------------------------------------
/scripts/custom_code.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import gradio as gr
3 | import ast
4 | import copy
5 |
6 | from modules.processing import Processed
7 | from modules.shared import cmd_opts
8 |
9 |
10 | def convertExpr2Expression(expr):
11 | expr.lineno = 0
12 | expr.col_offset = 0
13 | result = ast.Expression(expr.value, lineno=0, col_offset = 0)
14 |
15 | return result
16 |
17 |
18 | def exec_with_return(code, module):
19 | """
20 | like exec() but can return values
21 | https://stackoverflow.com/a/52361938/5862977
22 | """
23 | code_ast = ast.parse(code)
24 |
25 | init_ast = copy.deepcopy(code_ast)
26 | init_ast.body = code_ast.body[:-1]
27 |
28 | last_ast = copy.deepcopy(code_ast)
29 | last_ast.body = code_ast.body[-1:]
30 |
31 | exec(compile(init_ast, "", "exec"), module.__dict__)
32 | if type(last_ast.body[0]) == ast.Expr:
33 | return eval(compile(convertExpr2Expression(last_ast.body[0]), "", "eval"), module.__dict__)
34 | else:
35 | exec(compile(last_ast, "", "exec"), module.__dict__)
36 |
37 |
38 | class Script(scripts.Script):
39 |
40 | def title(self):
41 | return "Custom code"
42 |
43 | def show(self, is_img2img):
44 | return cmd_opts.allow_code
45 |
46 | def ui(self, is_img2img):
47 | example = """from modules.processing import process_images
48 |
49 | p.width = 768
50 | p.height = 768
51 | p.batch_size = 2
52 | p.steps = 10
53 |
54 | return process_images(p)
55 | """
56 |
57 |
58 | code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
59 | indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
60 |
61 | return [code, indent_level]
62 |
63 | def run(self, p, code, indent_level):
64 | assert cmd_opts.allow_code, '--allow-code option must be enabled'
65 |
66 | display_result_data = [[], -1, ""]
67 |
68 | def display(imgs, s=display_result_data[1], i=display_result_data[2]):
69 | display_result_data[0] = imgs
70 | display_result_data[1] = s
71 | display_result_data[2] = i
72 |
73 | from types import ModuleType
74 | module = ModuleType("testmodule")
75 | module.__dict__.update(globals())
76 | module.p = p
77 | module.display = display
78 |
79 | indent = " " * indent_level
80 | indented = code.replace('\n', f"\n{indent}")
81 | body = f"""def __webuitemp__():
82 | {indent}{indented}
83 | __webuitemp__()"""
84 |
85 | result = exec_with_return(body, module)
86 |
87 | if isinstance(result, Processed):
88 | return result
89 |
90 | return Processed(p, *display_result_data)
91 |
--------------------------------------------------------------------------------
/modules/ui_postprocessing.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from modules import scripts, shared, ui_common, postprocessing, call_queue
3 | import modules.generation_parameters_copypaste as parameters_copypaste
4 |
5 |
6 | def create_ui():
7 | tab_index = gr.State(value=0)
8 |
9 | with gr.Row().style(equal_height=False, variant='compact'):
10 | with gr.Column(variant='compact'):
11 | with gr.Tabs(elem_id="mode_extras"):
12 | with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
13 | extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
14 |
15 | with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
16 | image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
17 |
18 | with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
19 | extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
20 | extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
21 | show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
22 |
23 | submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
24 |
25 | script_inputs = scripts.scripts_postproc.setup_ui()
26 |
27 | with gr.Column():
28 | result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
29 |
30 | tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
31 | tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
32 | tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
33 |
34 | submit.click(
35 | fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
36 | inputs=[
37 | tab_index,
38 | extras_image,
39 | image_batch,
40 | extras_batch_input_dir,
41 | extras_batch_output_dir,
42 | show_extras_results,
43 | *script_inputs
44 | ],
45 | outputs=[
46 | result_images,
47 | html_info_x,
48 | html_info,
49 | ]
50 | )
51 |
52 | parameters_copypaste.add_paste_fields("extras", extras_image, None)
53 |
54 | extras_image.change(
55 | fn=scripts.scripts_postproc.image_changed,
56 | inputs=[], outputs=[]
57 | )
58 |
--------------------------------------------------------------------------------
/modules/ui_tempdir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from collections import namedtuple
4 | from pathlib import Path
5 |
6 | import gradio.components
7 |
8 | from PIL import PngImagePlugin
9 |
10 | from modules import shared
11 |
12 |
13 | Savedfile = namedtuple("Savedfile", ["name"])
14 |
15 |
16 | def register_tmp_file(gradio, filename):
17 | if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
18 | gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
19 |
20 | if hasattr(gradio, 'temp_dirs'): # gradio 3.9
21 | gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
22 |
23 |
24 | def check_tmp_file(gradio, filename):
25 | if hasattr(gradio, 'temp_file_sets'):
26 | return any(filename in fileset for fileset in gradio.temp_file_sets)
27 |
28 | if hasattr(gradio, 'temp_dirs'):
29 | return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
30 |
31 | return False
32 |
33 |
34 | def save_pil_to_file(self, pil_image, dir=None):
35 | already_saved_as = getattr(pil_image, 'already_saved_as', None)
36 | if already_saved_as and os.path.isfile(already_saved_as):
37 | register_tmp_file(shared.demo, already_saved_as)
38 | filename = already_saved_as
39 |
40 | if not shared.opts.save_images_add_number:
41 | filename += f'?{os.path.getmtime(already_saved_as)}'
42 |
43 | return filename
44 |
45 | if shared.opts.temp_dir != "":
46 | dir = shared.opts.temp_dir
47 |
48 | use_metadata = False
49 | metadata = PngImagePlugin.PngInfo()
50 | for key, value in pil_image.info.items():
51 | if isinstance(key, str) and isinstance(value, str):
52 | metadata.add_text(key, value)
53 | use_metadata = True
54 |
55 | file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
56 | pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
57 | return file_obj.name
58 |
59 |
60 | # override save to file function so that it also writes PNG info
61 | gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
62 |
63 |
64 | def on_tmpdir_changed():
65 | if shared.opts.temp_dir == "" or shared.demo is None:
66 | return
67 |
68 | os.makedirs(shared.opts.temp_dir, exist_ok=True)
69 |
70 | register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
71 |
72 |
73 | def cleanup_tmpdr():
74 | temp_dir = shared.opts.temp_dir
75 | if temp_dir == "" or not os.path.isdir(temp_dir):
76 | return
77 |
78 | for root, _, files in os.walk(temp_dir, topdown=False):
79 | for name in files:
80 | _, extension = os.path.splitext(name)
81 | if extension != ".png":
82 | continue
83 |
84 | filename = os.path.join(root, name)
85 | os.remove(filename)
86 |
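87 |
88 | # Illustrative sketch (not exercised by the webui itself): any object exposing
89 | # `temp_file_sets` the way gradio 3.15+ does can stand in for the running demo,
90 | # so a SimpleNamespace is used here to show the register/check round trip.
91 | def _example_register_and_check(path="example.png"):
92 |     import types
93 |     fake_demo = types.SimpleNamespace(temp_file_sets=[set()])
94 |     register_tmp_file(fake_demo, path)
95 |     return check_tmp_file(fake_demo, os.path.abspath(path))  # -> True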
--------------------------------------------------------------------------------
/configs/instruct-pix2pix.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | model:
5 | base_learning_rate: 1.0e-04
6 | target: modules.models.diffusion.ddpm_edit.LatentDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: edited
14 | cond_stage_key: edit
15 | # image_size: 64
16 | # image_size: 32
17 | image_size: 16
18 | channels: 4
19 | cond_stage_trainable: false # Note: different from the one we trained before
20 | conditioning_key: hybrid
21 | monitor: val/loss_simple_ema
22 | scale_factor: 0.18215
23 | use_ema: false
24 |
25 | scheduler_config: # 10000 warmup steps
26 | target: ldm.lr_scheduler.LambdaLinearScheduler
27 | params:
28 | warm_up_steps: [ 0 ]
29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30 | f_start: [ 1.e-6 ]
31 | f_max: [ 1. ]
32 | f_min: [ 1. ]
33 |
34 | unet_config:
35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36 | params:
37 | image_size: 32 # unused
38 | in_channels: 8
39 | out_channels: 4
40 | model_channels: 320
41 | attention_resolutions: [ 4, 2, 1 ]
42 | num_res_blocks: 2
43 | channel_mult: [ 1, 2, 4, 4 ]
44 | num_heads: 8
45 | use_spatial_transformer: True
46 | transformer_depth: 1
47 | context_dim: 768
48 | use_checkpoint: True
49 | legacy: False
50 |
51 | first_stage_config:
52 | target: ldm.models.autoencoder.AutoencoderKL
53 | params:
54 | embed_dim: 4
55 | monitor: val/rec_loss
56 | ddconfig:
57 | double_z: true
58 | z_channels: 4
59 | resolution: 256
60 | in_channels: 3
61 | out_ch: 3
62 | ch: 128
63 | ch_mult:
64 | - 1
65 | - 2
66 | - 4
67 | - 4
68 | num_res_blocks: 2
69 | attn_resolutions: []
70 | dropout: 0.0
71 | lossconfig:
72 | target: torch.nn.Identity
73 |
74 | cond_stage_config:
75 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
76 |
77 | data:
78 | target: main.DataModuleFromConfig
79 | params:
80 | batch_size: 128
81 | num_workers: 1
82 | wrap: false
83 | validation:
84 | target: edit_dataset.EditDataset
85 | params:
86 | path: data/clip-filtered-dataset
87 | cache_dir: data/
88 | cache_name: data_10k
89 | split: val
90 | min_text_sim: 0.2
91 | min_image_sim: 0.75
92 | min_direction_sim: 0.2
93 | max_samples_per_prompt: 1
94 | min_resize_res: 512
95 | max_resize_res: 512
96 | crop_res: 512
97 | output_as_edit: False
98 | real_input: True
99 |
--------------------------------------------------------------------------------
/modules/styles.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import os.path
4 | import typing
5 | import shutil
6 |
7 |
8 | class PromptStyle(typing.NamedTuple):
9 | name: str
10 | prompt: str
11 | negative_prompt: str
12 |
13 |
14 | def merge_prompts(style_prompt: str, prompt: str) -> str:
15 | if "{prompt}" in style_prompt:
16 | res = style_prompt.replace("{prompt}", prompt)
17 | else:
18 | parts = filter(None, (prompt.strip(), style_prompt.strip()))
19 | res = ", ".join(parts)
20 |
21 | return res
22 |
23 |
24 | def apply_styles_to_prompt(prompt, styles):
25 | for style in styles:
26 | prompt = merge_prompts(style, prompt)
27 |
28 | return prompt
29 |
30 |
31 | class StyleDatabase:
32 | def __init__(self, path: str):
33 | self.no_style = PromptStyle("None", "", "")
34 | self.styles = {}
35 | self.path = path
36 |
37 | self.reload()
38 |
39 | def reload(self):
40 | self.styles.clear()
41 |
42 | if not os.path.exists(self.path):
43 | return
44 |
45 | with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
46 | reader = csv.DictReader(file, skipinitialspace=True)
47 | for row in reader:
48 | # Support loading old CSV format with "name, text"-columns
49 | prompt = row["prompt"] if "prompt" in row else row["text"]
50 | negative_prompt = row.get("negative_prompt", "")
51 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
52 |
53 | def get_style_prompts(self, styles):
54 | return [self.styles.get(x, self.no_style).prompt for x in styles]
55 |
56 | def get_negative_style_prompts(self, styles):
57 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
58 |
59 | def apply_styles_to_prompt(self, prompt, styles):
60 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
61 |
62 | def apply_negative_styles_to_prompt(self, prompt, styles):
63 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
64 |
65 | def save_styles(self, path: str) -> None:
66 | # Always keep a backup file around
67 | if os.path.exists(path):
68 | shutil.copy(path, f"{path}.bak")
69 |
70 | fd = os.open(path, os.O_RDWR|os.O_CREAT)
71 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
72 | # _fields is actually part of the public API: typing.NamedTuple is a typed replacement for collections.namedtuple,
73 | # and collections.namedtuple has explicit documentation for accessing _fields. Same goes for _asdict()
74 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
75 | writer.writeheader()
76 | writer.writerows(style._asdict() for k, style in self.styles.items())
77 |
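78 |
79 | # Illustrative example of the merging rules above (made-up style texts, not read from
80 | # any styles.csv): a style containing "{prompt}" substitutes the prompt in place,
81 | # otherwise the style text is appended after a comma.
82 | def _example_merge():
83 |     assert merge_prompts("masterpiece, {prompt}, 4k", "a cat") == "masterpiece, a cat, 4k"
84 |     assert merge_prompts("oil painting", "a cat") == "a cat, oil painting"
85 |     return apply_styles_to_prompt("a cat", ["oil painting", "by someone"])  # "a cat, oil painting, by someone"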
--------------------------------------------------------------------------------
/modules/sd_vae_taesd.py:
--------------------------------------------------------------------------------
1 | """
2 | Tiny AutoEncoder for Stable Diffusion
3 | (DNN for encoding / decoding SD's latent space)
4 |
5 | https://github.com/madebyollin/taesd
6 | """
7 | import os
8 | import torch
9 | import torch.nn as nn
10 |
11 | from modules import devices, paths_internal
12 |
13 | sd_vae_taesd = None
14 |
15 |
16 | def conv(n_in, n_out, **kwargs):
17 | return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
18 |
19 |
20 | class Clamp(nn.Module):
21 | @staticmethod
22 | def forward(x):
23 | return torch.tanh(x / 3) * 3
24 |
25 |
26 | class Block(nn.Module):
27 | def __init__(self, n_in, n_out):
28 | super().__init__()
29 | self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
30 | self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
31 | self.fuse = nn.ReLU()
32 |
33 | def forward(self, x):
34 | return self.fuse(self.conv(x) + self.skip(x))
35 |
36 |
37 | def decoder():
38 | return nn.Sequential(
39 | Clamp(), conv(4, 64), nn.ReLU(),
40 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
41 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
42 | Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
43 | Block(64, 64), conv(64, 3),
44 | )
45 |
46 |
47 | class TAESD(nn.Module):
48 | latent_magnitude = 3
49 | latent_shift = 0.5
50 |
51 | def __init__(self, decoder_path="taesd_decoder.pth"):
52 | """Initialize the pretrained TAESD decoder from the given checkpoint."""
53 | super().__init__()
54 | self.decoder = decoder()
55 | self.decoder.load_state_dict(
56 | torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
57 |
58 | @staticmethod
59 | def unscale_latents(x):
60 | """[0, 1] -> raw latents"""
61 | return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
62 |
63 |
64 | def download_model(model_path):
65 | model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
66 |
67 | if not os.path.exists(model_path):
68 | os.makedirs(os.path.dirname(model_path), exist_ok=True)
69 |
70 | print(f'Downloading TAESD decoder to: {model_path}')
71 | torch.hub.download_url_to_file(model_url, model_path)
72 |
73 |
74 | def model():
75 | global sd_vae_taesd
76 |
77 | if sd_vae_taesd is None:
78 | model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
79 | download_model(model_path)
80 |
81 | if os.path.exists(model_path):
82 | sd_vae_taesd = TAESD(model_path)
83 | sd_vae_taesd.eval()
84 | sd_vae_taesd.to(devices.device, devices.dtype)
85 | else:
86 | raise FileNotFoundError('TAESD model not found')
87 |
88 | return sd_vae_taesd.decoder
89 |
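90 |
91 | # Worked example for unscale_latents (values only, no model weights needed):
92 | # the mapping is (x - latent_shift) * 2 * latent_magnitude, i.e. [0, 1] -> [-3, 3].
93 | def _example_unscale():
94 |     return TAESD.unscale_latents(torch.tensor([0.0, 0.5, 1.0]))  # tensor([-3., 0., 3.])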
--------------------------------------------------------------------------------
/modules/memmon.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 | from collections import defaultdict
4 |
5 | import torch
6 |
7 |
8 | class MemUsageMonitor(threading.Thread):
9 | run_flag = None
10 | device = None
11 | disabled = False
12 | opts = None
13 | data = None
14 |
15 | def __init__(self, name, device, opts):
16 | threading.Thread.__init__(self)
17 | self.name = name
18 | self.device = device
19 | self.opts = opts
20 |
21 | self.daemon = True
22 | self.run_flag = threading.Event()
23 | self.data = defaultdict(int)
24 |
25 | try:
26 | self.cuda_mem_get_info()
27 | torch.cuda.memory_stats(self.device)
28 | except Exception as e: # AMD or whatever
29 | print(f"Warning: caught exception '{e}', memory monitor disabled")
30 | self.disabled = True
31 |
32 | def cuda_mem_get_info(self):
33 | index = self.device.index if self.device.index is not None else torch.cuda.current_device()
34 | return torch.cuda.mem_get_info(index)
35 |
36 | def run(self):
37 | if self.disabled:
38 | return
39 |
40 | while True:
41 | self.run_flag.wait()
42 |
43 | torch.cuda.reset_peak_memory_stats()
44 | self.data.clear()
45 |
46 | if self.opts.memmon_poll_rate <= 0:
47 | self.run_flag.clear()
48 | continue
49 |
50 | self.data["min_free"] = self.cuda_mem_get_info()[0]
51 |
52 | while self.run_flag.is_set():
53 | free, total = self.cuda_mem_get_info()
54 | self.data["min_free"] = min(self.data["min_free"], free)
55 |
56 | time.sleep(1 / self.opts.memmon_poll_rate)
57 |
58 | def dump_debug(self):
59 | print(self, 'recorded data:')
60 | for k, v in self.read().items():
61 | print(k, -(v // -(1024 ** 2)))
62 |
63 | print(self, 'raw torch memory stats:')
64 | tm = torch.cuda.memory_stats(self.device)
65 | for k, v in tm.items():
66 | if 'bytes' not in k:
67 | continue
68 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
69 |
70 | print(torch.cuda.memory_summary())
71 |
72 | def monitor(self):
73 | self.run_flag.set()
74 |
75 | def read(self):
76 | if not self.disabled:
77 | free, total = self.cuda_mem_get_info()
78 | self.data["free"] = free
79 | self.data["total"] = total
80 |
81 | torch_stats = torch.cuda.memory_stats(self.device)
82 | self.data["active"] = torch_stats["active.all.current"]
83 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
84 | self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
85 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
86 | self.data["system_peak"] = total - self.data["min_free"]
87 |
88 | return self.data
89 |
90 | def stop(self):
91 | self.run_flag.clear()
92 | return self.read()
93 |
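94 |
95 | # Illustrative usage sketch (assumes a CUDA device and an opts object exposing
96 | # memmon_poll_rate, as shared.opts does in the webui):
97 | #
98 | #   mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
99 | #   mon.start()         # the thread idles until monitor() is called
100 | #   mon.monitor()       # start polling free VRAM
101 | #   ...                 # run a generation here
102 | #   stats = mon.stop()  # dict with free/total/active_peak/reserved_peak/... in bytes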
--------------------------------------------------------------------------------
/modules/textual_inversion/learn_schedule.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 |
4 | class LearnScheduleIterator:
5 | def __init__(self, learn_rate, max_steps, cur_step=0):
6 | """
7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
8 | """
9 |
10 | pairs = learn_rate.split(',')
11 | self.rates = []
12 | self.it = 0
13 | self.maxit = 0
14 | try:
15 | for pair in pairs:
16 | if not pair.strip():
17 | continue
18 | tmp = pair.split(':')
19 | if len(tmp) == 2:
20 | step = int(tmp[1])
21 | if step > cur_step:
22 | self.rates.append((float(tmp[0]), min(step, max_steps)))
23 | self.maxit += 1
24 | if step > max_steps:
25 | return
26 | elif step == -1:
27 | self.rates.append((float(tmp[0]), max_steps))
28 | self.maxit += 1
29 | return
30 | else:
31 | self.rates.append((float(tmp[0]), max_steps))
32 | self.maxit += 1
33 | return
34 | assert self.rates
35 | except (ValueError, AssertionError) as e:
36 | raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
37 |
38 |
39 | def __iter__(self):
40 | return self
41 |
42 | def __next__(self):
43 | if self.it < self.maxit:
44 | self.it += 1
45 | return self.rates[self.it - 1]
46 | else:
47 | raise StopIteration
48 |
49 |
50 | class LearnRateScheduler:
51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
53 | (self.learn_rate, self.end_step) = next(self.schedules)
54 | self.verbose = verbose
55 |
56 | if self.verbose:
57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
58 |
59 | self.finished = False
60 |
61 | def step(self, step_number):
62 | if step_number < self.end_step:
63 | return False
64 |
65 | try:
66 | (self.learn_rate, self.end_step) = next(self.schedules)
67 | except StopIteration:
68 | self.finished = True
69 | return False
70 | return True
71 |
72 | def apply(self, optimizer, step_number):
73 | if not self.step(step_number):
74 | return
75 |
76 | if self.verbose:
77 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
78 |
79 | for pg in optimizer.param_groups:
80 | pg['lr'] = self.learn_rate
81 |
82 |
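83 |
84 | # Illustrative example of the schedule format from the docstring above (made-up values):
85 | def _example_schedule():
86 |     return list(LearnScheduleIterator("0.001:100, 1e-5:1000", max_steps=1000))
87 |     # -> [(0.001, 100), (1e-05, 1000)]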
--------------------------------------------------------------------------------
/modules/txt2img.py:
--------------------------------------------------------------------------------
1 | import modules.scripts
2 | from modules import sd_samplers, processing
3 | from modules.generation_parameters_copypaste import create_override_settings_dict
4 | from modules.shared import opts, cmd_opts
5 | import modules.shared as shared
6 | from modules.ui import plaintext_to_html
7 |
8 |
9 |
10 | def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
11 | override_settings = create_override_settings_dict(override_settings_texts)
12 |
13 | p = processing.StableDiffusionProcessingTxt2Img(
14 | sd_model=shared.sd_model,
15 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
16 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
17 | prompt=prompt,
18 | styles=prompt_styles,
19 | negative_prompt=negative_prompt,
20 | seed=seed,
21 | subseed=subseed,
22 | subseed_strength=subseed_strength,
23 | seed_resize_from_h=seed_resize_from_h,
24 | seed_resize_from_w=seed_resize_from_w,
25 | seed_enable_extras=seed_enable_extras,
26 | sampler_name=sd_samplers.samplers[sampler_index].name,
27 | batch_size=batch_size,
28 | n_iter=n_iter,
29 | steps=steps,
30 | cfg_scale=cfg_scale,
31 | width=width,
32 | height=height,
33 | restore_faces=restore_faces,
34 | tiling=tiling,
35 | enable_hr=enable_hr,
36 | denoising_strength=denoising_strength if enable_hr else None,
37 | hr_scale=hr_scale,
38 | hr_upscaler=hr_upscaler,
39 | hr_second_pass_steps=hr_second_pass_steps,
40 | hr_resize_x=hr_resize_x,
41 | hr_resize_y=hr_resize_y,
42 | hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
43 | hr_prompt=hr_prompt,
44 | hr_negative_prompt=hr_negative_prompt,
45 | override_settings=override_settings,
46 | )
47 |
48 | p.scripts = modules.scripts.scripts_txt2img
49 | p.script_args = args
50 |
51 | if cmd_opts.enable_console_prompts:
52 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
53 |
54 | processed = modules.scripts.scripts_txt2img.run(p, *args)
55 |
56 | if processed is None:
57 | processed = processing.process_images(p)
58 |
59 | p.close()
60 |
61 | shared.total_tqdm.clear()
62 |
63 | generation_info_js = processed.js()
64 | if opts.samples_log_stdout:
65 | print(generation_info_js)
66 |
67 | if opts.do_not_show_images:
68 | processed.images = []
69 |
70 | return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
71 |
--------------------------------------------------------------------------------
/modules/hashes.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os.path
4 |
5 | import filelock
6 |
7 | from modules import shared
8 | from modules.paths import data_path
9 |
10 |
11 | cache_filename = os.path.join(data_path, "cache.json")
12 | cache_data = None
13 |
14 |
15 | def dump_cache():
16 | with filelock.FileLock(f"{cache_filename}.lock"):
17 | with open(cache_filename, "w", encoding="utf8") as file:
18 | json.dump(cache_data, file, indent=4)
19 |
20 |
21 | def cache(subsection):
22 | global cache_data
23 |
24 | if cache_data is None:
25 | with filelock.FileLock(f"{cache_filename}.lock"):
26 | if not os.path.isfile(cache_filename):
27 | cache_data = {}
28 | else:
29 | with open(cache_filename, "r", encoding="utf8") as file:
30 | cache_data = json.load(file)
31 |
32 | s = cache_data.get(subsection, {})
33 | cache_data[subsection] = s
34 |
35 | return s
36 |
37 |
38 | def calculate_sha256(filename):
39 | hash_sha256 = hashlib.sha256()
40 | blksize = 1024 * 1024
41 |
42 | with open(filename, "rb") as f:
43 | for chunk in iter(lambda: f.read(blksize), b""):
44 | hash_sha256.update(chunk)
45 |
46 | return hash_sha256.hexdigest()
47 |
48 |
49 | def sha256_from_cache(filename, title, use_addnet_hash=False):
50 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
51 | ondisk_mtime = os.path.getmtime(filename)
52 |
53 | if title not in hashes:
54 | return None
55 |
56 | cached_sha256 = hashes[title].get("sha256", None)
57 | cached_mtime = hashes[title].get("mtime", 0)
58 |
59 | if ondisk_mtime > cached_mtime or cached_sha256 is None:
60 | return None
61 |
62 | return cached_sha256
63 |
64 |
65 | def sha256(filename, title, use_addnet_hash=False):
66 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
67 |
68 | sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
69 | if sha256_value is not None:
70 | return sha256_value
71 |
72 | if shared.cmd_opts.no_hashing:
73 | return None
74 |
75 | print(f"Calculating sha256 for {filename}: ", end='')
76 | if use_addnet_hash:
77 | with open(filename, "rb") as file:
78 | sha256_value = addnet_hash_safetensors(file)
79 | else:
80 | sha256_value = calculate_sha256(filename)
81 | print(f"{sha256_value}")
82 |
83 | hashes[title] = {
84 | "mtime": os.path.getmtime(filename),
85 | "sha256": sha256_value,
86 | }
87 |
88 | dump_cache()
89 |
90 | return sha256_value
91 |
92 |
93 | def addnet_hash_safetensors(b):
94 | """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
95 | hash_sha256 = hashlib.sha256()
96 | blksize = 1024 * 1024
97 |
98 | b.seek(0)
99 | header = b.read(8)
100 | n = int.from_bytes(header, "little")
101 |
102 | offset = n + 8
103 | b.seek(offset)
104 | for chunk in iter(lambda: b.read(blksize), b""):
105 | hash_sha256.update(chunk)
106 |
107 | return hash_sha256.hexdigest()
108 |
109 |
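110 |
111 | # Illustrative sketch of the plain hashing path (uses a throwaway temp file; the cached
112 | # sha256()/dump_cache() path additionally needs the webui's data_path to be writable):
113 | def _example_sha256():
114 |     import tempfile
115 |     with tempfile.NamedTemporaryFile(delete=False) as f:
116 |         f.write(b"hello")
117 |         name = f.name
118 |     digest = calculate_sha256(name)
119 |     assert digest == hashlib.sha256(b"hello").hexdigest()
120 |     os.remove(name)
121 |     return digest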
--------------------------------------------------------------------------------
/modules/deepbooru.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from modules import modelloader, paths, deepbooru_model, devices, images, shared
8 |
9 | re_special = re.compile(r'([\\()])')
10 |
11 |
12 | class DeepDanbooru:
13 | def __init__(self):
14 | self.model = None
15 |
16 | def load(self):
17 | if self.model is not None:
18 | return
19 |
20 | files = modelloader.load_models(
21 | model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
22 | model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
23 | ext_filter=[".pt"],
24 | download_name='model-resnet_custom_v3.pt',
25 | )
26 |
27 | self.model = deepbooru_model.DeepDanbooruModel()
28 | self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
29 |
30 | self.model.eval()
31 | self.model.to(devices.cpu, devices.dtype)
32 |
33 | def start(self):
34 | self.load()
35 | self.model.to(devices.device)
36 |
37 | def stop(self):
38 | if not shared.opts.interrogate_keep_models_in_memory:
39 | self.model.to(devices.cpu)
40 | devices.torch_gc()
41 |
42 | def tag(self, pil_image):
43 | self.start()
44 | res = self.tag_multi(pil_image)
45 | self.stop()
46 |
47 | return res
48 |
49 | def tag_multi(self, pil_image, force_disable_ranks=False):
50 | threshold = shared.opts.interrogate_deepbooru_score_threshold
51 | use_spaces = shared.opts.deepbooru_use_spaces
52 | use_escape = shared.opts.deepbooru_escape
53 | alpha_sort = shared.opts.deepbooru_sort_alpha
54 | include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
55 |
56 | pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
57 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
58 |
59 | with torch.no_grad(), devices.autocast():
60 | x = torch.from_numpy(a).to(devices.device)
61 | y = self.model(x)[0].detach().cpu().numpy()
62 |
63 | probability_dict = {}
64 |
65 | for tag, probability in zip(self.model.tags, y):
66 | if probability < threshold:
67 | continue
68 |
69 | if tag.startswith("rating:"):
70 | continue
71 |
72 | probability_dict[tag] = probability
73 |
74 | if alpha_sort:
75 | tags = sorted(probability_dict)
76 | else:
77 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
78 |
79 | res = []
80 |
81 | filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
82 |
83 | for tag in [x for x in tags if x not in filtertags]:
84 | probability = probability_dict[tag]
85 | tag_outformat = tag
86 | if use_spaces:
87 | tag_outformat = tag_outformat.replace('_', ' ')
88 | if use_escape:
89 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
90 | if include_ranks:
91 | tag_outformat = f"({tag_outformat}:{probability:.3f})"
92 |
93 | res.append(tag_outformat)
94 |
95 | return ", ".join(res)
96 |
97 |
98 | model = DeepDanbooru()
99 |
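100 |
101 | # Illustrative usage sketch (downloads the DeepDanbooru weights on first use; the tag
102 | # string below is made up for illustration):
103 | #
104 | #   from PIL import Image
105 | #   tags = model.tag(Image.open("example.png"))  # e.g. "1girl, solo, long hair"
106 | #   # threshold, ordering, escaping and filtering follow the interrogate options in shared.opts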
--------------------------------------------------------------------------------
/.eslintrc.js:
--------------------------------------------------------------------------------
1 | /* global module */
2 | module.exports = {
3 | env: {
4 | browser: true,
5 | es2021: true,
6 | },
7 | extends: "eslint:recommended",
8 | parserOptions: {
9 | ecmaVersion: "latest",
10 | },
11 | rules: {
12 | "arrow-spacing": "error",
13 | "block-spacing": "error",
14 | "brace-style": "error",
15 | "comma-dangle": ["error", "only-multiline"],
16 | "comma-spacing": "error",
17 | "comma-style": ["error", "last"],
18 | "curly": ["error", "multi-line", "consistent"],
19 | "eol-last": "error",
20 | "func-call-spacing": "error",
21 | "function-call-argument-newline": ["error", "consistent"],
22 | "function-paren-newline": ["error", "consistent"],
23 | "indent": ["error", 4],
24 | "key-spacing": "error",
25 | "keyword-spacing": "error",
26 | "linebreak-style": ["error", "unix"],
27 | "no-extra-semi": "error",
28 | "no-mixed-spaces-and-tabs": "error",
29 | "no-multi-spaces": "error",
30 | "no-redeclare": ["error", {builtinGlobals: false}],
31 | "no-trailing-spaces": "error",
32 | "no-unused-vars": "off",
33 | "no-whitespace-before-property": "error",
34 | "object-curly-newline": ["error", {consistent: true, multiline: true}],
35 | "object-curly-spacing": ["error", "never"],
36 | "operator-linebreak": ["error", "after"],
37 | "quote-props": ["error", "consistent-as-needed"],
38 | "semi": ["error", "always"],
39 | "semi-spacing": "error",
40 | "semi-style": ["error", "last"],
41 | "space-before-blocks": "error",
42 | "space-before-function-paren": ["error", "never"],
43 | "space-in-parens": ["error", "never"],
44 | "space-infix-ops": "error",
45 | "space-unary-ops": "error",
46 | "switch-colon-spacing": "error",
47 | "template-curly-spacing": ["error", "never"],
48 | "unicode-bom": "error",
49 | },
50 | globals: {
51 | //script.js
52 | gradioApp: "readonly",
53 | onUiLoaded: "readonly",
54 | onUiUpdate: "readonly",
55 | onOptionsChanged: "readonly",
56 | uiCurrentTab: "writable",
57 | uiElementIsVisible: "readonly",
58 | uiElementInSight: "readonly",
59 | executeCallbacks: "readonly",
60 | //ui.js
61 | opts: "writable",
62 | all_gallery_buttons: "readonly",
63 | selected_gallery_button: "readonly",
64 | selected_gallery_index: "readonly",
65 | switch_to_txt2img: "readonly",
66 | switch_to_img2img_tab: "readonly",
67 | switch_to_img2img: "readonly",
68 | switch_to_sketch: "readonly",
69 | switch_to_inpaint: "readonly",
70 | switch_to_inpaint_sketch: "readonly",
71 | switch_to_extras: "readonly",
72 | get_tab_index: "readonly",
73 | create_submit_args: "readonly",
74 | restart_reload: "readonly",
75 | updateInput: "readonly",
76 | //extraNetworks.js
77 | requestGet: "readonly",
78 | popup: "readonly",
79 | // from python
80 | localization: "readonly",
81 | // progressbar.js
82 | randomId: "readonly",
83 | requestProgress: "readonly",
84 | // imageviewer.js
85 | modalPrevImage: "readonly",
86 | modalNextImage: "readonly",
87 | }
88 | };
89 |
--------------------------------------------------------------------------------
/test/test_txt2img.py:
--------------------------------------------------------------------------------
1 |
2 | import pytest
3 | import requests
4 |
5 |
6 | @pytest.fixture()
7 | def url_txt2img(base_url):
8 | return f"{base_url}/sdapi/v1/txt2img"
9 |
10 |
11 | @pytest.fixture()
12 | def simple_txt2img_request():
13 | return {
14 | "batch_size": 1,
15 | "cfg_scale": 7,
16 | "denoising_strength": 0,
17 | "enable_hr": False,
18 | "eta": 0,
19 | "firstphase_height": 0,
20 | "firstphase_width": 0,
21 | "height": 64,
22 | "n_iter": 1,
23 | "negative_prompt": "",
24 | "prompt": "example prompt",
25 | "restore_faces": False,
26 | "s_churn": 0,
27 | "s_noise": 1,
28 | "s_tmax": 0,
29 | "s_tmin": 0,
30 | "sampler_index": "Euler a",
31 | "seed": -1,
32 | "seed_resize_from_h": -1,
33 | "seed_resize_from_w": -1,
34 | "steps": 3,
35 | "styles": [],
36 | "subseed": -1,
37 | "subseed_strength": 0,
38 | "tiling": False,
39 | "width": 64,
40 | }
41 |
42 |
43 | def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request):
44 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
45 |
46 |
47 | def test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request):
48 | simple_txt2img_request["negative_prompt"] = "example negative prompt"
49 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
50 |
51 |
52 | def test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request):
53 | simple_txt2img_request["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
54 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
55 |
56 |
57 | def test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request):
58 | simple_txt2img_request["height"] = 128
59 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
60 |
61 |
62 | def test_txt2img_with_hrfix_performed(url_txt2img, simple_txt2img_request):
63 | simple_txt2img_request["enable_hr"] = True
64 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
65 |
66 |
67 | def test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request):
68 | simple_txt2img_request["tiling"] = True
69 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
70 |
71 |
72 | def test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request):
73 | simple_txt2img_request["restore_faces"] = True
74 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
75 |
76 |
77 | @pytest.mark.parametrize("sampler", ["PLMS", "DDIM", "UniPC"])
78 | def test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler):
79 | simple_txt2img_request["sampler_index"] = sampler
80 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
81 |
82 |
83 | def test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request):
84 | simple_txt2img_request["n_iter"] = 2
85 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
86 |
87 |
88 | def test_txt2img_batch_performed(url_txt2img, simple_txt2img_request):
89 | simple_txt2img_request["batch_size"] = 2
90 | assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
91 |
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
1 | function gradioApp() {
2 | const elems = document.getElementsByTagName('gradio-app');
3 | const elem = elems.length == 0 ? document : elems[0];
4 |
5 | if (elem !== document) {
6 | elem.getElementById = function(id) {
7 | return document.getElementById(id);
8 | };
9 | }
10 | return elem.shadowRoot ? elem.shadowRoot : elem;
11 | }
12 |
13 | function get_uiCurrentTab() {
14 | return gradioApp().querySelector('#tabs button.selected');
15 | }
16 |
17 | function get_uiCurrentTabContent() {
18 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])');
19 | }
20 |
21 | var uiUpdateCallbacks = [];
22 | var uiLoadedCallbacks = [];
23 | var uiTabChangeCallbacks = [];
24 | var optionsChangedCallbacks = [];
25 | var uiCurrentTab = null;
26 |
27 | function onUiUpdate(callback) {
28 | uiUpdateCallbacks.push(callback);
29 | }
30 | function onUiLoaded(callback) {
31 | uiLoadedCallbacks.push(callback);
32 | }
33 | function onUiTabChange(callback) {
34 | uiTabChangeCallbacks.push(callback);
35 | }
36 | function onOptionsChanged(callback) {
37 | optionsChangedCallbacks.push(callback);
38 | }
39 |
40 | function runCallback(x, m) {
41 | try {
42 | x(m);
43 | } catch (e) {
44 | (console.error || console.log).call(console, e.message, e);
45 | }
46 | }
47 | function executeCallbacks(queue, m) {
48 | queue.forEach(function(x) {
49 | runCallback(x, m);
50 | });
51 | }
52 |
53 | var executedOnLoaded = false;
54 |
55 | document.addEventListener("DOMContentLoaded", function() {
56 | var mutationObserver = new MutationObserver(function(m) {
57 | if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {
58 | executedOnLoaded = true;
59 | executeCallbacks(uiLoadedCallbacks);
60 | }
61 |
62 | executeCallbacks(uiUpdateCallbacks, m);
63 | const newTab = get_uiCurrentTab();
64 | if (newTab && (newTab !== uiCurrentTab)) {
65 | uiCurrentTab = newTab;
66 | executeCallbacks(uiTabChangeCallbacks);
67 | }
68 | });
69 | mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
70 | });
71 |
72 | /**
73 | * Add a ctrl+enter as a shortcut to start a generation
74 | */
75 | document.addEventListener('keydown', function(e) {
76 | var handled = false;
77 | if (e.key !== undefined) {
78 | if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
79 | } else if (e.keyCode !== undefined) {
80 | if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
81 | }
82 | if (handled) {
83 | var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
84 | if (button) {
85 | button.click();
86 | }
87 | e.preventDefault();
88 | }
89 | });
90 |
91 | /**
92 | * checks that a UI element is not in another hidden element or tab content
93 | */
94 | function uiElementIsVisible(el) {
95 | if (el === document) {
96 | return true;
97 | }
98 |
99 | const computedStyle = getComputedStyle(el);
100 | const isVisible = computedStyle.display !== 'none';
101 |
102 | if (!isVisible) return false;
103 | return uiElementIsVisible(el.parentNode);
104 | }
105 |
106 | function uiElementInSight(el) {
107 | const clRect = el.getBoundingClientRect();
108 | const windowHeight = window.innerHeight;
109 | const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight;
110 |
111 | return isOnScreen;
112 | }
113 |
--------------------------------------------------------------------------------
/extensions-builtin/LDSR/scripts/ldsr_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | from basicsr.utils.download_util import load_file_from_url
6 |
7 | from modules.upscaler import Upscaler, UpscalerData
8 | from ldsr_model_arch import LDSR
9 | from modules import shared, script_callbacks
10 | import sd_hijack_autoencoder # noqa: F401
11 | import sd_hijack_ddpm_v1 # noqa: F401
12 |
13 |
14 | class UpscalerLDSR(Upscaler):
15 | def __init__(self, user_path):
16 | self.name = "LDSR"
17 | self.user_path = user_path
18 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
19 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
20 | super().__init__()
21 | scaler_data = UpscalerData("LDSR", None, self)
22 | self.scalers = [scaler_data]
23 |
24 | def load_model(self, path: str):
25 | # Remove incorrect project.yaml file if too big
26 | yaml_path = os.path.join(self.model_path, "project.yaml")
27 | old_model_path = os.path.join(self.model_path, "model.pth")
28 | new_model_path = os.path.join(self.model_path, "model.ckpt")
29 |
30 | local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
31 | local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
32 | local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
33 | local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
34 |
35 | if os.path.exists(yaml_path):
36 | statinfo = os.stat(yaml_path)
37 | if statinfo.st_size >= 10485760:
38 | print("Removing invalid LDSR YAML file.")
39 | os.remove(yaml_path)
40 |
41 | if os.path.exists(old_model_path):
42 | print("Renaming model from model.pth to model.ckpt")
43 | os.rename(old_model_path, new_model_path)
44 |
45 | if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
46 | model = local_safetensors_path
47 | else:
48 | model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
49 |
50 | yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
51 |
52 | try:
53 | return LDSR(model, yaml)
54 |
55 | except Exception:
56 | print("Error importing LDSR:", file=sys.stderr)
57 | print(traceback.format_exc(), file=sys.stderr)
58 | return None
59 |
60 | def do_upscale(self, img, path):
61 | ldsr = self.load_model(path)
62 | if ldsr is None:
63 | print("NO LDSR!")
64 | return img
65 | ddim_steps = shared.opts.ldsr_steps
66 | return ldsr.super_resolution(img, ddim_steps, self.scale)
67 |
68 |
69 | def on_ui_settings():
70 | import gradio as gr
71 |
72 | shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
73 | shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
74 |
75 |
76 | script_callbacks.on_ui_settings(on_ui_settings)
77 |
--------------------------------------------------------------------------------
/javascript/dragdrop.js:
--------------------------------------------------------------------------------
1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
2 |
3 | function isValidImageList(files) {
4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
5 | }
6 |
7 | function dropReplaceImage(imgWrap, files) {
8 | if (!isValidImageList(files)) {
9 | return;
10 | }
11 |
12 | const tmpFile = files[0];
13 |
14 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
15 | const callback = () => {
16 | const fileInput = imgWrap.querySelector('input[type="file"]');
17 | if (fileInput) {
18 | if (files.length === 0) {
19 | files = new DataTransfer();
20 | files.items.add(tmpFile);
21 | fileInput.files = files.files;
22 | } else {
23 | fileInput.files = files;
24 | }
25 | fileInput.dispatchEvent(new Event('change'));
26 | }
27 | };
28 |
29 | if (imgWrap.closest('#pnginfo_image')) {
30 | // special treatment for PNG Info tab, wait for fetch request to finish
31 | const oldFetch = window.fetch;
32 | window.fetch = async(input, options) => {
33 | const response = await oldFetch(input, options);
34 | if ('api/predict/' === input) {
35 | const content = await response.text();
36 | window.fetch = oldFetch;
37 | window.requestAnimationFrame(() => callback());
38 | return new Response(content, {
39 | status: response.status,
40 | statusText: response.statusText,
41 | headers: response.headers
42 | });
43 | }
44 | return response;
45 | };
46 | } else {
47 | window.requestAnimationFrame(() => callback());
48 | }
49 | }
50 |
51 | window.document.addEventListener('dragover', e => {
52 | const target = e.composedPath()[0];
53 | const imgWrap = target.closest('[data-testid="image"]');
54 | if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
55 | return;
56 | }
57 | e.stopPropagation();
58 | e.preventDefault();
59 | e.dataTransfer.dropEffect = 'copy';
60 | });
61 |
62 | window.document.addEventListener('drop', e => {
63 | const target = e.composedPath()[0];
64 | if (target.placeholder.indexOf("Prompt") == -1) {
65 | return;
66 | }
67 | const imgWrap = target.closest('[data-testid="image"]');
68 | if (!imgWrap) {
69 | return;
70 | }
71 | e.stopPropagation();
72 | e.preventDefault();
73 | const files = e.dataTransfer.files;
74 | dropReplaceImage(imgWrap, files);
75 | });
76 |
77 | window.addEventListener('paste', e => {
78 | const files = e.clipboardData.files;
79 | if (!isValidImageList(files)) {
80 | return;
81 | }
82 |
83 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
84 | .filter(el => uiElementIsVisible(el))
85 | .sort((a, b) => uiElementInSight(b) - uiElementInSight(a));
86 |
87 |
88 | if (!visibleImageFields.length) {
89 | return;
90 | }
91 |
92 | const firstFreeImageField = visibleImageFields
93 | .filter(el => el.querySelector('input[type=file]'))?.[0];
94 |
95 | dropReplaceImage(
96 | firstFreeImageField ?
97 | firstFreeImageField :
98 | visibleImageFields[visibleImageFields.length - 1]
99 | , files
100 | );
101 | });
102 |
--------------------------------------------------------------------------------
/modules/sd_samplers_common.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
6 |
7 | from modules.shared import opts, state
8 | import modules.shared as shared
9 |
10 | SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
11 |
12 |
13 | def setup_img2img_steps(p, steps=None):
14 | if opts.img2img_fix_steps or steps is not None:
15 | requested_steps = (steps or p.steps)
16 | steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
17 | t_enc = requested_steps - 1
18 | else:
19 | steps = p.steps
20 | t_enc = int(min(p.denoising_strength, 0.999) * steps)
21 |
22 | return steps, t_enc
23 |
24 |
25 | approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
26 |
27 |
28 | def single_sample_to_image(sample, approximation=None):
29 | if approximation is None:
30 | approximation = approximation_indexes.get(opts.show_progress_type, 0)
31 |
32 | if approximation == 2:
33 | x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
34 | elif approximation == 1:
35 | x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
36 | elif approximation == 3:
37 | x_sample = sample * 1.5
38 | x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
39 | else:
40 | x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
41 |
42 | x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
43 | x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
44 | x_sample = x_sample.astype(np.uint8)
45 |
46 | return Image.fromarray(x_sample)
47 |
48 |
49 | def sample_to_image(samples, index=0, approximation=None):
50 | return single_sample_to_image(samples[index], approximation)
51 |
52 |
53 | def samples_to_image_grid(samples, approximation=None):
54 | return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
55 |
56 |
57 | def store_latent(decoded):
58 | state.current_latent = decoded
59 |
60 | if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
61 | if not shared.parallel_processing_allowed:
62 | shared.state.assign_current_image(sample_to_image(decoded))
63 |
64 |
65 | def is_sampler_using_eta_noise_seed_delta(p):
66 | """returns whether sampler from config will use eta noise seed delta for image creation"""
67 |
68 | sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
69 |
70 | eta = p.eta
71 |
72 | if eta is None and p.sampler is not None:
73 | eta = p.sampler.eta
74 |
75 | if eta is None and sampler_config is not None:
76 | eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
77 |
78 | if eta == 0:
79 | return False
80 |
81 | return sampler_config.options.get("uses_ensd", False)
82 |
83 |
84 | class InterruptedException(BaseException):
85 | pass
86 |
87 |
88 | if opts.randn_source == "CPU":
89 | import torchsde._brownian.brownian_interval
90 |
91 | def torchsde_randn(size, dtype, device, seed):
92 | generator = torch.Generator(devices.cpu).manual_seed(int(seed))
93 | return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
94 |
95 | torchsde._brownian.brownian_interval._randn = torchsde_randn
96 |
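97 |
98 | # Worked example for setup_img2img_steps (illustrative numbers): with p.steps=20 and
99 | # p.denoising_strength=0.75, the default branch gives steps=20, t_enc=int(0.75*20)=15;
100 | # with opts.img2img_fix_steps enabled, steps=int(20/0.75)=26 and t_enc=20-1=19, so the
101 | # requested 20 denoising steps are actually performed on the image.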
--------------------------------------------------------------------------------
/modules/masking.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFilter, ImageOps
2 |
3 |
4 | def get_crop_region(mask, pad=0):
5 | """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6 | For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
7 |
8 | h, w = mask.shape
9 |
10 | crop_left = 0
11 | for i in range(w):
12 | if not (mask[:, i] == 0).all():
13 | break
14 | crop_left += 1
15 |
16 | crop_right = 0
17 | for i in reversed(range(w)):
18 | if not (mask[:, i] == 0).all():
19 | break
20 | crop_right += 1
21 |
22 | crop_top = 0
23 | for i in range(h):
24 | if not (mask[i] == 0).all():
25 | break
26 | crop_top += 1
27 |
28 | crop_bottom = 0
29 | for i in reversed(range(h)):
30 | if not (mask[i] == 0).all():
31 | break
32 | crop_bottom += 1
33 |
34 | return (
35 | int(max(crop_left-pad, 0)),
36 | int(max(crop_top-pad, 0)),
37 | int(min(w - crop_right + pad, w)),
38 | int(min(h - crop_bottom + pad, h))
39 | )
40 |
41 |
42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43 | """expands a crop region returned by get_crop_region() to match the aspect ratio of the image the region will be processed in; returns the expanded region.
44 | for example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45 |
46 | x1, y1, x2, y2 = crop_region
47 |
48 | ratio_crop_region = (x2 - x1) / (y2 - y1)
49 | ratio_processing = processing_width / processing_height
50 |
51 | if ratio_crop_region > ratio_processing:
52 | desired_height = (x2 - x1) / ratio_processing
53 | desired_height_diff = int(desired_height - (y2-y1))
54 | y1 -= desired_height_diff//2
55 | y2 += desired_height_diff - desired_height_diff//2
56 | if y2 >= image_height:
57 | diff = y2 - image_height
58 | y2 -= diff
59 | y1 -= diff
60 | if y1 < 0:
61 | y2 -= y1
62 | y1 -= y1
63 | if y2 >= image_height:
64 | y2 = image_height
65 | else:
66 | desired_width = (y2 - y1) * ratio_processing
67 | desired_width_diff = int(desired_width - (x2-x1))
68 | x1 -= desired_width_diff//2
69 | x2 += desired_width_diff - desired_width_diff//2
70 | if x2 >= image_width:
71 | diff = x2 - image_width
72 | x2 -= diff
73 | x1 -= diff
74 | if x1 < 0:
75 | x2 -= x1
76 | x1 -= x1
77 | if x2 >= image_width:
78 | x2 = image_width
79 |
80 | return x1, y1, x2, y2
81 |
82 |
83 | def fill(image, mask):
84 | """fills masked regions with colors from image using blur. Not extremely effective."""
85 |
86 | image_mod = Image.new('RGBA', (image.width, image.height))
87 |
88 | image_masked = Image.new('RGBa', (image.width, image.height))
89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90 |
91 | image_masked = image_masked.convert('RGBa')
92 |
93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95 | for _ in range(repeats):
96 | image_mod.alpha_composite(blurred)
97 |
98 | return image_mod.convert("RGB")
99 |
100 |
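101 |
102 | # Illustrative example for get_crop_region (a small made-up numpy mask; get_crop_region
103 | # expects a 2D array-like mask of shape (height, width)):
104 | def _example_crop_region():
105 |     import numpy as np
106 |     mask = np.zeros((8, 8), dtype=np.uint8)
107 |     mask[2:4, 3:6] = 255                 # paint rows 2-3, columns 3-5
108 |     return get_crop_region(mask)         # -> (3, 2, 6, 4) == (x1, y1, x2, y2)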
--------------------------------------------------------------------------------
/modules/sd_hijack_clip_old.py:
--------------------------------------------------------------------------------
1 | from modules import sd_hijack_clip
2 | from modules import shared
3 |
4 |
5 | def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
6 | id_start = self.id_start
7 | id_end = self.id_end
8 | maxlen = self.wrapped.max_length # you get to stay at 77
9 | used_custom_terms = []
10 | remade_batch_tokens = []
11 | hijack_comments = []
12 | hijack_fixes = []
13 | token_count = 0
14 |
15 | cache = {}
16 | batch_tokens = self.tokenize(texts)
17 | batch_multipliers = []
18 | for tokens in batch_tokens:
19 | tuple_tokens = tuple(tokens)
20 |
21 | if tuple_tokens in cache:
22 | remade_tokens, fixes, multipliers = cache[tuple_tokens]
23 | else:
24 | fixes = []
25 | remade_tokens = []
26 | multipliers = []
27 | mult = 1.0
28 |
29 | i = 0
30 | while i < len(tokens):
31 | token = tokens[i]
32 |
33 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
34 |
35 | mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
36 | if mult_change is not None:
37 | mult *= mult_change
38 | i += 1
39 | elif embedding is None:
40 | remade_tokens.append(token)
41 | multipliers.append(mult)
42 | i += 1
43 | else:
44 | emb_len = int(embedding.vec.shape[0])
45 | fixes.append((len(remade_tokens), embedding))
46 | remade_tokens += [0] * emb_len
47 | multipliers += [mult] * emb_len
48 | used_custom_terms.append((embedding.name, embedding.checksum()))
49 | i += embedding_length_in_tokens
50 |
51 | if len(remade_tokens) > maxlen - 2:
52 | vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
53 | ovf = remade_tokens[maxlen - 2:]
54 | overflowing_words = [vocab.get(int(x), "") for x in ovf]
55 | overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
56 | hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
57 |
58 | token_count = len(remade_tokens)
59 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
60 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
61 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
62 |
63 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
64 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
65 |
66 | remade_batch_tokens.append(remade_tokens)
67 | hijack_fixes.append(fixes)
68 | batch_multipliers.append(multipliers)
69 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70 |
71 |
72 | def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
73 | batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
74 |
75 | self.hijack.comments += hijack_comments
76 |
77 | if len(used_custom_terms) > 0:
78 | embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
79 | self.hijack.comments.append(f"Used embeddings: {embedding_names}")
80 |
81 | self.hijack.fixes = hijack_fixes
82 | return self.process_tokens(remade_batch_tokens, batch_multipliers)
83 |
--------------------------------------------------------------------------------
/modules/call_queue.py:
--------------------------------------------------------------------------------
1 | import html
2 | import sys
3 | import threading
4 | import traceback
5 | import time
6 |
7 | from modules import shared, progress
8 |
9 | queue_lock = threading.Lock()
10 |
11 |
12 | def wrap_queued_call(func):
13 | def f(*args, **kwargs):
14 | with queue_lock:
15 | res = func(*args, **kwargs)
16 |
17 | return res
18 |
19 | return f
20 |
21 |
22 | def wrap_gradio_gpu_call(func, extra_outputs=None):
23 | def f(*args, **kwargs):
24 |
25 | # if the first argument is a string that says "task(...)", it is treated as a job id
26 | if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
27 | id_task = args[0]
28 | progress.add_task_to_queue(id_task)
29 | else:
30 | id_task = None
31 |
32 | with queue_lock:
33 | shared.state.begin()
34 | progress.start_task(id_task)
35 |
36 | try:
37 | res = func(*args, **kwargs)
38 | progress.record_results(id_task, res)
39 | finally:
40 | progress.finish_task(id_task)
41 |
42 | shared.state.end()
43 |
44 | return res
45 |
46 | return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
47 |
48 |
49 | def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
50 | def f(*args, extra_outputs_array=extra_outputs, **kwargs):
51 | run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
52 | if run_memmon:
53 | shared.mem_mon.monitor()
54 | t = time.perf_counter()
55 |
56 | try:
57 | res = list(func(*args, **kwargs))
58 | except Exception as e:
59 | # When printing out our debug argument list, do not print out more than a MB of text
60 | max_debug_str_len = 131072 # (1024*1024)/8
61 |
62 | print("Error completing request", file=sys.stderr)
63 | argStr = f"Arguments: {args} {kwargs}"
64 | print(argStr[:max_debug_str_len], file=sys.stderr)
65 | if len(argStr) > max_debug_str_len:
66 | print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
67 |
68 | print(traceback.format_exc(), file=sys.stderr)
69 |
70 | shared.state.job = ""
71 | shared.state.job_count = 0
72 |
73 | if extra_outputs_array is None:
74 | extra_outputs_array = [None, '']
75 |
76 | error_message = f'{type(e).__name__}: {e}'
77 | res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
78 |
79 | shared.state.skipped = False
80 | shared.state.interrupted = False
81 | shared.state.job_count = 0
82 |
83 | if not add_stats:
84 | return tuple(res)
85 |
86 | elapsed = time.perf_counter() - t
87 | elapsed_m = int(elapsed // 60)
88 | elapsed_s = elapsed % 60
89 | elapsed_text = f"{elapsed_s:.2f}s"
90 | if elapsed_m > 0:
91 | elapsed_text = f"{elapsed_m}m "+elapsed_text
92 |
93 | if run_memmon:
94 | mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
95 | active_peak = mem_stats['active_peak']
96 | reserved_peak = mem_stats['reserved_peak']
97 | sys_peak = mem_stats['system_peak']
98 | sys_total = mem_stats['total']
99 | sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
100 |
101 | vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
102 | else:
103 | vram_html = ''
104 |
105 | # last item is always HTML
106 | res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
107 |
108 | return tuple(res)
109 |
110 | return f
111 |
112 |
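113 |
114 | # Illustrative sketch of the calling convention (run_job is a hypothetical function; the
115 | # real callers are the gradio event handlers): when the first positional argument looks
116 | # like "task(<id>)", it is treated as a job id and tracked through modules.progress.
117 | #
118 | #   wrapped = wrap_gradio_gpu_call(run_job, extra_outputs=[None, ''])
119 | #   wrapped("task(abc123)", "a prompt")  # queued behind the GPU lock, timed, memory-profiled
120 | #   # on error, extra_outputs plus an HTML error message are returned instead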
--------------------------------------------------------------------------------
/modules/sd_hijack_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from packaging import version
3 |
4 | from modules import devices
5 | from modules.sd_hijack_utils import CondFunc
6 |
7 |
8 | class TorchHijackForUnet:
9 | """
10 | This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
11 | this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
12 | """
13 |
14 | def __getattr__(self, item):
15 | if item == 'cat':
16 | return self.cat
17 |
18 | if hasattr(torch, item):
19 | return getattr(torch, item)
20 |
21 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
22 |
23 | def cat(self, tensors, *args, **kwargs):
24 | if len(tensors) == 2:
25 | a, b = tensors
26 | if a.shape[-2:] != b.shape[-2:]:
27 | a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
28 |
29 | tensors = (a, b)
30 |
31 | return torch.cat(tensors, *args, **kwargs)
32 |
33 |
34 | th = TorchHijackForUnet()
35 |
36 |
37 | # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
38 | def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
39 |
40 | if isinstance(cond, dict):
41 | for y in cond.keys():
42 | cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
43 |
44 | with devices.autocast():
45 | return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
46 |
47 |
48 | class GELUHijack(torch.nn.GELU, torch.nn.Module):
49 | def __init__(self, *args, **kwargs):
50 | torch.nn.GELU.__init__(self, *args, **kwargs)
51 | def forward(self, x):
52 | if devices.unet_needs_upcast:
53 | return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
54 | else:
55 | return torch.nn.GELU.forward(self, x)
56 |
57 |
58 | ddpm_edit_hijack = None
59 | def hijack_ddpm_edit():
60 | global ddpm_edit_hijack
61 | if not ddpm_edit_hijack:
62 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
63 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
64 | ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
65 |
66 |
67 | unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
68 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
69 | CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
70 | if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
71 | CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
72 | CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
73 | CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
74 |
75 | first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
76 | first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
77 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
78 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
79 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
80 |
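
A short usage sketch of the `th.cat` hijack defined above: when exactly two tensors with mismatched spatial sizes are concatenated (which happens in UNet skip connections once image dimensions are multiples of 8 rather than 64), the first tensor is nearest-neighbour resized to match the second before the real `torch.cat` runs. The tensor shapes below are illustrative, and the import assumes the webui environment:

import torch
from modules.sd_hijack_unet import TorchHijackForUnet

th = TorchHijackForUnet()            # same object the module above exposes as `th`
skip = torch.randn(1, 4, 37, 52)     # e.g. a skip-connection feature map
up = torch.randn(1, 4, 38, 52)       # upsampled path, one row larger
out = th.cat([skip, up], dim=1)      # `skip` is interpolated to 38x52 before torch.cat
print(out.shape)                     # torch.Size([1, 8, 38, 52])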
--------------------------------------------------------------------------------
/modules/models/diffusion/uni_pc/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
6 | from modules import shared, devices
7 |
8 |
9 | class UniPCSampler(object):
10 | def __init__(self, model, **kwargs):
11 | super().__init__()
12 | self.model = model
13 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
14 | self.before_sample = None
15 | self.after_sample = None
16 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
17 |
18 | def register_buffer(self, name, attr):
19 | if type(attr) == torch.Tensor:
20 | if attr.device != devices.device:
21 | attr = attr.to(devices.device)
22 | setattr(self, name, attr)
23 |
24 | def set_hooks(self, before_sample, after_sample, after_update):
25 | self.before_sample = before_sample
26 | self.after_sample = after_sample
27 | self.after_update = after_update
28 |
29 | @torch.no_grad()
30 | def sample(self,
31 | S,
32 | batch_size,
33 | shape,
34 | conditioning=None,
35 | callback=None,
36 | normals_sequence=None,
37 | img_callback=None,
38 | quantize_x0=False,
39 | eta=0.,
40 | mask=None,
41 | x0=None,
42 | temperature=1.,
43 | noise_dropout=0.,
44 | score_corrector=None,
45 | corrector_kwargs=None,
46 | verbose=True,
47 | x_T=None,
48 | log_every_t=100,
49 | unconditional_guidance_scale=1.,
50 | unconditional_conditioning=None,
51 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
52 | **kwargs
53 | ):
54 | if conditioning is not None:
55 | if isinstance(conditioning, dict):
56 | ctmp = conditioning[list(conditioning.keys())[0]]
57 | while isinstance(ctmp, list):
58 | ctmp = ctmp[0]
59 | cbs = ctmp.shape[0]
60 | if cbs != batch_size:
61 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
62 |
63 | elif isinstance(conditioning, list):
64 | for ctmp in conditioning:
65 | if ctmp.shape[0] != batch_size:
66 |                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
67 |
68 | else:
69 | if conditioning.shape[0] != batch_size:
70 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
71 |
72 | # sampling
73 | C, H, W = shape
74 | size = (batch_size, C, H, W)
75 | # print(f'Data shape for UniPC sampling is {size}')
76 |
77 | device = self.model.betas.device
78 | if x_T is None:
79 | img = torch.randn(size, device=device)
80 | else:
81 | img = x_T
82 |
83 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
84 |
85 | # SD 1.X is "noise", SD 2.X is "v"
86 | model_type = "v" if self.model.parameterization == "v" else "noise"
87 |
88 | model_fn = model_wrapper(
89 | lambda x, t, c: self.model.apply_model(x, t, c),
90 | ns,
91 | model_type=model_type,
92 | guidance_type="classifier-free",
93 | #condition=conditioning,
94 | #unconditional_condition=unconditional_conditioning,
95 | guidance_scale=unconditional_guidance_scale,
96 | )
97 |
98 | uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
99 | x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
100 |
101 | return x.to(device), None
102 |
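
A hedged usage sketch for `UniPCSampler` as defined above. Note that `sample()` passes `self.after_update` to `UniPC`, and that attribute is only assigned in `set_hooks()`, so hooks should be set (even to `None`) before sampling; `model`, `cond` and `uncond` below are placeholders for the LatentDiffusion-style objects (exposing `betas`, `alphas_cumprod`, `apply_model`, `parameterization`) that the webui supplies elsewhere:

# Illustrative only; not runnable without a loaded model and encoded conditionings.
sampler = UniPCSampler(model)
sampler.set_hooks(before_sample=None, after_sample=None, after_update=None)

samples, _ = sampler.sample(
    S=20,                              # number of UniPC steps
    batch_size=1,
    shape=(4, 64, 64),                 # latent C, H, W
    conditioning=cond,
    unconditional_conditioning=uncond,
    unconditional_guidance_scale=7.0,
)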
--------------------------------------------------------------------------------
/modules/gfpgan_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | import facexlib
6 | import gfpgan
7 |
8 | import modules.face_restoration
9 | from modules import paths, shared, devices, modelloader
10 |
11 | model_dir = "GFPGAN"
12 | user_path = None
13 | model_path = os.path.join(paths.models_path, model_dir)
14 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
15 | have_gfpgan = False
16 | loaded_gfpgan_model = None
17 |
18 |
19 | def gfpgann():
20 | global loaded_gfpgan_model
21 | global model_path
22 | if loaded_gfpgan_model is not None:
23 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
24 | return loaded_gfpgan_model
25 |
26 | if gfpgan_constructor is None:
27 | return None
28 |
29 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
30 | if len(models) == 1 and "http" in models[0]:
31 | model_file = models[0]
32 | elif len(models) != 0:
33 | latest_file = max(models, key=os.path.getctime)
34 | model_file = latest_file
35 | else:
36 | print("Unable to load gfpgan model!")
37 | return None
38 | if hasattr(facexlib.detection.retinaface, 'device'):
39 | facexlib.detection.retinaface.device = devices.device_gfpgan
40 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
41 | loaded_gfpgan_model = model
42 |
43 | return model
44 |
45 |
46 | def send_model_to(model, device):
47 | model.gfpgan.to(device)
48 | model.face_helper.face_det.to(device)
49 | model.face_helper.face_parse.to(device)
50 |
51 |
52 | def gfpgan_fix_faces(np_image):
53 | model = gfpgann()
54 | if model is None:
55 | return np_image
56 |
57 | send_model_to(model, devices.device_gfpgan)
58 |
59 | np_image_bgr = np_image[:, :, ::-1]
60 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
61 | np_image = gfpgan_output_bgr[:, :, ::-1]
62 |
63 | model.face_helper.clean_all()
64 |
65 | if shared.opts.face_restoration_unload:
66 | send_model_to(model, devices.cpu)
67 |
68 | return np_image
69 |
70 |
71 | gfpgan_constructor = None
72 |
73 |
74 | def setup_model(dirname):
75 | global model_path
76 | if not os.path.exists(model_path):
77 | os.makedirs(model_path)
78 |
79 | try:
80 | from gfpgan import GFPGANer
81 | from facexlib import detection, parsing # noqa: F401
82 | global user_path
83 | global have_gfpgan
84 | global gfpgan_constructor
85 |
86 | load_file_from_url_orig = gfpgan.utils.load_file_from_url
87 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
88 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
89 |
90 | def my_load_file_from_url(**kwargs):
91 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
92 |
93 | def facex_load_file_from_url(**kwargs):
94 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
95 |
96 | def facex_load_file_from_url2(**kwargs):
97 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
98 |
99 | gfpgan.utils.load_file_from_url = my_load_file_from_url
100 | facexlib.detection.load_file_from_url = facex_load_file_from_url
101 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2
102 | user_path = dirname
103 | have_gfpgan = True
104 | gfpgan_constructor = GFPGANer
105 |
106 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
107 | def name(self):
108 | return "GFPGAN"
109 |
110 | def restore(self, np_image):
111 | return gfpgan_fix_faces(np_image)
112 |
113 | shared.face_restorers.append(FaceRestorerGFPGAN())
114 | except Exception:
115 | print("Error setting up GFPGAN:", file=sys.stderr)
116 | print(traceback.format_exc(), file=sys.stderr)
117 |
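
A sketch of how this module is meant to be driven: `setup_model()` must run once so `gfpgan_constructor` and the download redirects are in place, after which `gfpgan_fix_faces()` takes and returns an RGB numpy array. The directory below is a placeholder, and the first call downloads GFPGANv1.4 if no local model is found:

import numpy as np
from modules import gfpgan_model

gfpgan_model.setup_model("models/GFPGAN")             # placeholder path; registers FaceRestorerGFPGAN

np_image = np.zeros((512, 512, 3), dtype=np.uint8)    # stand-in for a decoded RGB image
restored = gfpgan_model.gfpgan_fix_faces(np_image)    # returns the input unchanged if the model can't load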
--------------------------------------------------------------------------------
/scripts/sd_upscale.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import modules.scripts as scripts
4 | import gradio as gr
5 | from PIL import Image
6 |
7 | from modules import processing, shared, images, devices
8 | from modules.processing import Processed
9 | from modules.shared import opts, state
10 |
11 |
12 | class Script(scripts.Script):
13 | def title(self):
14 | return "SD upscale"
15 |
16 | def show(self, is_img2img):
17 | return is_img2img
18 |
19 | def ui(self, is_img2img):
20 |         info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
22 | scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
23 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))
24 |
25 | return [info, overlap, upscaler_index, scale_factor]
26 |
27 | def run(self, p, _, overlap, upscaler_index, scale_factor):
28 | if isinstance(upscaler_index, str):
29 | upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
30 | processing.fix_seed(p)
31 | upscaler = shared.sd_upscalers[upscaler_index]
32 |
33 | p.extra_generation_params["SD upscale overlap"] = overlap
34 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name
35 |
36 | initial_info = None
37 | seed = p.seed
38 |
39 | init_img = p.init_images[0]
40 | init_img = images.flatten(init_img, opts.img2img_background_color)
41 |
42 | if upscaler.name != "None":
43 | img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
44 | else:
45 | img = init_img
46 |
47 | devices.torch_gc()
48 |
49 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
50 |
51 | batch_size = p.batch_size
52 | upscale_count = p.n_iter
53 | p.n_iter = 1
54 | p.do_not_save_grid = True
55 | p.do_not_save_samples = True
56 |
57 | work = []
58 |
59 | for _y, _h, row in grid.tiles:
60 | for tiledata in row:
61 | work.append(tiledata[2])
62 |
63 | batch_count = math.ceil(len(work) / batch_size)
64 | state.job_count = batch_count * upscale_count
65 |
66 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
67 |
68 | result_images = []
69 | for n in range(upscale_count):
70 | start_seed = seed + n
71 | p.seed = start_seed
72 |
73 | work_results = []
74 | for i in range(batch_count):
75 | p.batch_size = batch_size
76 | p.init_images = work[i * batch_size:(i + 1) * batch_size]
77 |
78 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
79 | processed = processing.process_images(p)
80 |
81 | if initial_info is None:
82 | initial_info = processed.info
83 |
84 | p.seed = processed.seed + 1
85 | work_results += processed.images
86 |
87 | image_index = 0
88 | for _y, _h, row in grid.tiles:
89 | for tiledata in row:
90 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
91 | image_index += 1
92 |
93 | combined_image = images.combine_grid(grid)
94 | result_images.append(combined_image)
95 |
96 | if opts.samples_save:
97 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
98 |
99 | processed = Processed(p, result_images, seed, initial_info)
100 |
101 | return processed
102 |
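
The job accounting in `run()` above is plain tiling arithmetic: every tile produced by `images.split_grid` becomes one img2img work item, work items are grouped into batches of `p.batch_size`, and the whole pass repeats `p.n_iter` times. A worked example with made-up numbers:

import math

tiles_per_row, rows = 4, 3                        # hypothetical grid from images.split_grid
work_items = tiles_per_row * rows                 # 12 tiles -> 12 work items
batch_size = 4
upscale_count = 2                                 # p.n_iter before it is reset to 1

batch_count = math.ceil(work_items / batch_size)  # 3 batches per upscale pass
job_count = batch_count * upscale_count           # 6 batches total, reported to state.job_count
print(batch_count, job_count)                     # 3 6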
--------------------------------------------------------------------------------
/javascript/aspectRatioOverlay.js:
--------------------------------------------------------------------------------
1 |
2 | let currentWidth = null;
3 | let currentHeight = null;
4 | let arFrameTimeout = setTimeout(function() {}, 0);
5 |
6 | function dimensionChange(e, is_width, is_height) {
7 |
8 | if (is_width) {
9 | currentWidth = e.target.value * 1.0;
10 | }
11 | if (is_height) {
12 | currentHeight = e.target.value * 1.0;
13 | }
14 |
15 | var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
16 |
17 | if (!inImg2img) {
18 | return;
19 | }
20 |
21 | var targetElement = null;
22 |
23 | var tabIndex = get_tab_index('mode_img2img');
24 | if (tabIndex == 0) { // img2img
25 | targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
26 | } else if (tabIndex == 1) { //Sketch
27 | targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
28 | } else if (tabIndex == 2) { // Inpaint
29 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
30 | } else if (tabIndex == 3) { // Inpaint sketch
31 | targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
32 | }
33 |
34 |
35 | if (targetElement) {
36 |
37 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
38 | if (!arPreviewRect) {
39 | arPreviewRect = document.createElement('div');
40 | arPreviewRect.id = "imageARPreview";
41 | gradioApp().appendChild(arPreviewRect);
42 | }
43 |
44 |
45 |
46 | var viewportOffset = targetElement.getBoundingClientRect();
47 |
48 | var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
49 |
50 | var scaledx = targetElement.naturalWidth * viewportscale;
51 | var scaledy = targetElement.naturalHeight * viewportscale;
52 |
53 |         var clientRectTop = (viewportOffset.top + window.scrollY);
54 |         var clientRectLeft = (viewportOffset.left + window.scrollX);
55 |         var clientRectCentreY = clientRectTop + (targetElement.clientHeight / 2);
56 |         var clientRectCentreX = clientRectLeft + (targetElement.clientWidth / 2);
57 |
58 |         var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
59 |         var arscaledx = currentWidth * arscale;
60 |         var arscaledy = currentHeight * arscale;
61 |
62 |         var arRectTop = clientRectCentreY - (arscaledy / 2);
63 |         var arRectLeft = clientRectCentreX - (arscaledx / 2);
64 | var arRectWidth = arscaledx;
65 | var arRectHeight = arscaledy;
66 |
67 | arPreviewRect.style.top = arRectTop + 'px';
68 | arPreviewRect.style.left = arRectLeft + 'px';
69 | arPreviewRect.style.width = arRectWidth + 'px';
70 | arPreviewRect.style.height = arRectHeight + 'px';
71 |
72 | clearTimeout(arFrameTimeout);
73 | arFrameTimeout = setTimeout(function() {
74 | arPreviewRect.style.display = 'none';
75 | }, 2000);
76 |
77 | arPreviewRect.style.display = 'block';
78 |
79 | }
80 |
81 | }
82 |
83 |
84 | onUiUpdate(function() {
85 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
86 | if (arPreviewRect) {
87 | arPreviewRect.style.display = 'none';
88 | }
89 | var tabImg2img = gradioApp().querySelector("#tab_img2img");
90 | if (tabImg2img) {
91 | var inImg2img = tabImg2img.style.display == "block";
92 | if (inImg2img) {
93 | let inputs = gradioApp().querySelectorAll('input');
94 | inputs.forEach(function(e) {
95 | var is_width = e.parentElement.id == "img2img_width";
96 | var is_height = e.parentElement.id == "img2img_height";
97 |
98 | if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
99 | e.addEventListener('input', function(e) {
100 | dimensionChange(e, is_width, is_height);
101 | });
102 | e.classList.add('scrollwatch');
103 | }
104 | if (is_width) {
105 | currentWidth = e.value * 1.0;
106 | }
107 | if (is_height) {
108 | currentHeight = e.value * 1.0;
109 | }
110 | });
111 | }
112 | }
113 | });
114 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: You think something is broken in the UI
3 | title: "[Bug]: "
4 | labels: ["bug-report"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 |         *Please fill this form with as much information as possible; don't forget to fill "What OS..." and "What browsers", and **provide screenshots if possible***
18 | - type: textarea
19 | id: what-did
20 | attributes:
21 | label: What happened?
22 | description: Tell us what happened in a very clear and simple way
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: steps
27 | attributes:
28 | label: Steps to reproduce the problem
29 |     description: Please provide us with precise step-by-step information on how to reproduce the bug
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: what-should
38 | attributes:
39 | label: What should have happened?
40 |     description: Tell us what you think the normal behavior should be
41 | validations:
42 | required: true
43 | - type: input
44 | id: commit
45 | attributes:
46 | label: Commit where the problem happens
47 |     description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
48 | validations:
49 | required: true
50 | - type: dropdown
51 | id: py-version
52 | attributes:
53 |     label: What Python version are you running on?
54 | multiple: false
55 | options:
56 | - Python 3.10.x
57 |       - Python 3.11.x (above, not supported yet)
58 |       - Python 3.9.x (below, not recommended)
59 | - type: dropdown
60 | id: platforms
61 | attributes:
62 |     label: What platforms do you use to access the UI?
63 | multiple: true
64 | options:
65 | - Windows
66 | - Linux
67 | - MacOS
68 | - iOS
69 | - Android
70 | - Other/Cloud
71 | - type: dropdown
72 | id: device
73 | attributes:
74 | label: What device are you running WebUI on?
75 | multiple: true
76 | options:
77 | - Nvidia GPUs (RTX 20 above)
78 | - Nvidia GPUs (GTX 16 below)
79 | - AMD GPUs (RX 6000 above)
80 | - AMD GPUs (RX 5000 below)
81 | - CPU
82 | - Other GPUs
83 | - type: dropdown
84 | id: browsers
85 | attributes:
86 |     label: What browsers do you use to access the UI?
87 | multiple: true
88 | options:
89 | - Mozilla Firefox
90 | - Google Chrome
91 | - Brave
92 | - Apple Safari
93 | - Microsoft Edge
94 | - type: textarea
95 | id: cmdargs
96 | attributes:
97 | label: Command Line Arguments
98 |     description: Are you using any launching parameters/command line arguments (modified webui-user.bat/.sh)? If yes, please write them below. Write "No" otherwise.
99 | render: Shell
100 | validations:
101 | required: true
102 | - type: textarea
103 | id: extensions
104 | attributes:
105 | label: List of extensions
106 |     description: Are you using any extensions other than built-ins? If yes, provide a list; you can copy it from the "Extensions" tab. Write "No" otherwise.
107 | validations:
108 | required: true
109 | - type: textarea
110 | id: logs
111 | attributes:
112 | label: Console logs
113 |     description: Please provide **full** cmd/terminal logs from the moment you started the UI until after the bug happened. If it's very long, provide a link to pastebin or a similar service.
114 | render: Shell
115 | validations:
116 | required: true
117 | - type: textarea
118 | id: misc
119 | attributes:
120 | label: Additional information
121 | description: Please provide us with any relevant additional info or context.
122 |
--------------------------------------------------------------------------------
/modules/postprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 | from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
6 | from modules.shared import opts
7 |
8 |
9 | def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
10 | devices.torch_gc()
11 |
12 | shared.state.begin()
13 | shared.state.job = 'extras'
14 |
15 | image_data = []
16 | image_names = []
17 | outputs = []
18 |
19 | if extras_mode == 1:
20 | for img in image_folder:
21 | if isinstance(img, Image.Image):
22 | image = img
23 | fn = ''
24 | else:
25 | image = Image.open(os.path.abspath(img.name))
26 | fn = os.path.splitext(img.orig_name)[0]
27 | image_data.append(image)
28 | image_names.append(fn)
29 | elif extras_mode == 2:
30 | assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
31 | assert input_dir, 'input directory not selected'
32 |
33 | image_list = shared.listfiles(input_dir)
34 | for filename in image_list:
35 | try:
36 | image = Image.open(filename)
37 | except Exception:
38 | continue
39 | image_data.append(image)
40 | image_names.append(filename)
41 | else:
42 | assert image, 'image not selected'
43 |
44 | image_data.append(image)
45 | image_names.append(None)
46 |
47 | if extras_mode == 2 and output_dir != '':
48 | outpath = output_dir
49 | else:
50 | outpath = opts.outdir_samples or opts.outdir_extras_samples
51 |
52 | infotext = ''
53 |
54 | for image, name in zip(image_data, image_names):
55 | shared.state.textinfo = name
56 |
57 | existing_pnginfo = image.info or {}
58 |
59 | pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
60 |
61 | scripts.scripts_postproc.run(pp, args)
62 |
63 | if opts.use_original_name_batch and name is not None:
64 | basename = os.path.splitext(os.path.basename(name))[0]
65 | else:
66 | basename = ''
67 |
68 | infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
69 |
70 | if opts.enable_pnginfo:
71 | pp.image.info = existing_pnginfo
72 | pp.image.info["postprocessing"] = infotext
73 |
74 | if save_output:
75 | images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
76 |
77 | if extras_mode != 2 or show_extras_results:
78 | outputs.append(pp.image)
79 |
80 | devices.torch_gc()
81 |
82 | return outputs, ui_common.plaintext_to_html(infotext), ''
83 |
84 |
85 | def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
86 | """old handler for API"""
87 |
88 | args = scripts.scripts_postproc.create_args_for_run({
89 | "Upscale": {
90 | "upscale_mode": resize_mode,
91 | "upscale_by": upscaling_resize,
92 | "upscale_to_width": upscaling_resize_w,
93 | "upscale_to_height": upscaling_resize_h,
94 | "upscale_crop": upscaling_crop,
95 | "upscaler_1_name": extras_upscaler_1,
96 | "upscaler_2_name": extras_upscaler_2,
97 | "upscaler_2_visibility": extras_upscaler_2_visibility,
98 | },
99 | "GFPGAN": {
100 | "gfpgan_visibility": gfpgan_visibility,
101 | },
102 | "CodeFormer": {
103 | "codeformer_visibility": codeformer_visibility,
104 | "codeformer_weight": codeformer_weight,
105 | },
106 | })
107 |
108 | return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
109 |
--------------------------------------------------------------------------------
/modules/upscaler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 |
4 | import PIL
5 | from PIL import Image
6 |
7 | import modules.shared
8 | from modules import modelloader, shared
9 |
10 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
11 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
12 |
13 |
14 | class Upscaler:
15 | name = None
16 | model_path = None
17 | model_name = None
18 | model_url = None
19 | enable = True
20 | filter = None
21 | model = None
22 | user_path = None
23 |     scalers = []
24 | tile = True
25 |
26 | def __init__(self, create_dirs=False):
27 | self.mod_pad_h = None
28 | self.tile_size = modules.shared.opts.ESRGAN_tile
29 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
30 | self.device = modules.shared.device
31 | self.img = None
32 | self.output = None
33 | self.scale = 1
34 | self.half = not modules.shared.cmd_opts.no_half
35 | self.pre_pad = 0
36 | self.mod_scale = None
37 | self.model_download_path = None
38 |
39 | if self.model_path is None and self.name:
40 | self.model_path = os.path.join(shared.models_path, self.name)
41 | if self.model_path and create_dirs:
42 | os.makedirs(self.model_path, exist_ok=True)
43 |
44 | try:
45 | import cv2 # noqa: F401
46 | self.can_tile = True
47 | except Exception:
48 | pass
49 |
50 | @abstractmethod
51 | def do_upscale(self, img: PIL.Image, selected_model: str):
52 | return img
53 |
54 | def upscale(self, img: PIL.Image, scale, selected_model: str = None):
55 | self.scale = scale
56 | dest_w = int(img.width * scale)
57 | dest_h = int(img.height * scale)
58 |
59 | for _ in range(3):
60 | shape = (img.width, img.height)
61 |
62 | img = self.do_upscale(img, selected_model)
63 |
64 | if shape == (img.width, img.height):
65 | break
66 |
67 | if img.width >= dest_w and img.height >= dest_h:
68 | break
69 |
70 | if img.width != dest_w or img.height != dest_h:
71 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
72 |
73 | return img
74 |
75 | @abstractmethod
76 | def load_model(self, path: str):
77 | pass
78 |
79 | def find_models(self, ext_filter=None) -> list:
80 |         return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
81 |
82 | def update_status(self, prompt):
83 | print(f"\nextras: {prompt}", file=shared.progress_print_out)
84 |
85 |
86 | class UpscalerData:
87 | name = None
88 | data_path = None
89 | scale: int = 4
90 | scaler: Upscaler = None
91 |     model = None
92 |
93 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
94 | self.name = name
95 | self.data_path = path
96 | self.local_data_path = path
97 | self.scaler = upscaler
98 | self.scale = scale
99 | self.model = model
100 |
101 |
102 | class UpscalerNone(Upscaler):
103 | name = "None"
104 | scalers = []
105 |
106 | def load_model(self, path):
107 | pass
108 |
109 | def do_upscale(self, img, selected_model=None):
110 | return img
111 |
112 | def __init__(self, dirname=None):
113 | super().__init__(False)
114 | self.scalers = [UpscalerData("None", None, self)]
115 |
116 |
117 | class UpscalerLanczos(Upscaler):
118 | scalers = []
119 |
120 | def do_upscale(self, img, selected_model=None):
121 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
122 |
123 | def load_model(self, _):
124 | pass
125 |
126 | def __init__(self, dirname=None):
127 | super().__init__(False)
128 | self.name = "Lanczos"
129 | self.scalers = [UpscalerData("Lanczos", None, self)]
130 |
131 |
132 | class UpscalerNearest(Upscaler):
133 | scalers = []
134 |
135 | def do_upscale(self, img, selected_model=None):
136 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
137 |
138 | def load_model(self, _):
139 | pass
140 |
141 | def __init__(self, dirname=None):
142 | super().__init__(False)
143 | self.name = "Nearest"
144 | self.scalers = [UpscalerData("Nearest", None, self)]
145 |
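
A minimal sketch of how a new upscaler plugs into the base class above, following the pattern of `UpscalerLanczos`/`UpscalerNearest`; the bicubic variant is purely illustrative and not part of the repo:

from PIL import Image
from modules.upscaler import Upscaler, UpscalerData

BICUBIC = Image.Resampling.BICUBIC if hasattr(Image, "Resampling") else Image.BICUBIC

class UpscalerBicubicExample(Upscaler):
    scalers = []

    def __init__(self, dirname=None):
        super().__init__(False)
        self.name = "Bicubic (example)"
        self.scalers = [UpscalerData("Bicubic (example)", None, self)]

    def do_upscale(self, img, selected_model=None):
        # Upscaler.upscale() sets self.scale and calls this up to three times
        # until the requested size is reached.
        return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=BICUBIC)

    def load_model(self, _):
        pass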
--------------------------------------------------------------------------------
/modules/sd_hijack_inpainting.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import ldm.models.diffusion.ddpm
4 | import ldm.models.diffusion.ddim
5 | import ldm.models.diffusion.plms
6 |
7 | from ldm.models.diffusion.ddim import noise_like
8 | from ldm.models.diffusion.sampling_util import norm_thresholding
9 |
10 |
11 | @torch.no_grad()
12 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
13 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
14 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
15 | b, *_, device = *x.shape, x.device
16 |
17 | def get_model_output(x, t):
18 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
19 | e_t = self.model.apply_model(x, t, c)
20 | else:
21 | x_in = torch.cat([x] * 2)
22 | t_in = torch.cat([t] * 2)
23 |
24 | if isinstance(c, dict):
25 | assert isinstance(unconditional_conditioning, dict)
26 | c_in = {}
27 | for k in c:
28 | if isinstance(c[k], list):
29 | c_in[k] = [
30 | torch.cat([unconditional_conditioning[k][i], c[k][i]])
31 | for i in range(len(c[k]))
32 | ]
33 | else:
34 | c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
35 | else:
36 | c_in = torch.cat([unconditional_conditioning, c])
37 |
38 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
39 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
40 |
41 | if score_corrector is not None:
42 | assert self.model.parameterization == "eps"
43 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
44 |
45 | return e_t
46 |
47 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
48 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
49 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
50 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
51 |
52 | def get_x_prev_and_pred_x0(e_t, index):
53 | # select parameters corresponding to the currently considered timestep
54 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
55 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
56 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
57 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
58 |
59 | # current prediction for x_0
60 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
61 | if quantize_denoised:
62 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
63 | if dynamic_threshold is not None:
64 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
65 | # direction pointing to x_t
66 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
67 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
68 | if noise_dropout > 0.:
69 | noise = torch.nn.functional.dropout(noise, p=noise_dropout)
70 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
71 | return x_prev, pred_x0
72 |
73 | e_t = get_model_output(x, t)
74 | if len(old_eps) == 0:
75 | # Pseudo Improved Euler (2nd order)
76 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
77 | e_t_next = get_model_output(x_prev, t_next)
78 | e_t_prime = (e_t + e_t_next) / 2
79 | elif len(old_eps) == 1:
80 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
81 | e_t_prime = (3 * e_t - old_eps[-1]) / 2
82 | elif len(old_eps) == 2:
83 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
84 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
85 | elif len(old_eps) >= 3:
86 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
87 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
88 |
89 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
90 |
91 | return x_prev, pred_x0, e_t
92 |
93 |
94 | def do_inpainting_hijack():
95 | # p_sample_plms is needed because PLMS can't work with dicts as conditionings
96 |
97 | ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
98 |
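
For reference, `get_x_prev_and_pred_x0` above is the standard DDIM-style update (with $\alpha$ denoting the cumulative products the code indexes into), and the `old_eps` buffer blends noise predictions with linear-multistep (Adams-Bashforth) weights; written out:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_t}{\sqrt{\alpha_t}}, \qquad x_{t-1} = \sqrt{\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\,\epsilon_t + \sigma_t z$$

and, once three earlier predictions are buffered (the 4th-order branch),

$$\epsilon_t' = \tfrac{1}{24}\left(55\,\epsilon_t - 59\,\epsilon_{t-1} + 37\,\epsilon_{t-2} - 9\,\epsilon_{t-3}\right).$$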
--------------------------------------------------------------------------------
/modules/mac_specific.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import platform
3 | from modules.sd_hijack_utils import CondFunc
4 | from packaging import version
5 |
6 |
7 | # has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
8 | # check via getattr first, then attempt an MPS allocation, for compatibility
9 | def check_for_mps() -> bool:
10 | if not getattr(torch, 'has_mps', False):
11 | return False
12 | try:
13 | torch.zeros(1).to(torch.device("mps"))
14 | return True
15 | except Exception:
16 | return False
17 | has_mps = check_for_mps()
18 |
19 |
20 | # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
21 | def cumsum_fix(input, cumsum_func, *args, **kwargs):
22 | if input.device.type == 'mps':
23 | output_dtype = kwargs.get('dtype', input.dtype)
24 | if output_dtype == torch.int64:
25 | return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
26 | elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
27 | return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
28 | return cumsum_func(input, *args, **kwargs)
29 |
30 |
31 | if has_mps:
32 | # MPS fix for randn in torchsde
33 | CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
34 |
35 | if platform.mac_ver()[0].startswith("13.2."):
36 | # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
37 | CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
38 |
39 | if version.parse(torch.__version__) < version.parse("1.13"):
40 | # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
41 |
42 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
43 | CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
44 | lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
45 | # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
46 | CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
47 | lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
48 | # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
49 | CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
50 | elif version.parse(torch.__version__) > version.parse("1.13.1"):
51 | cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
52 | cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
53 | CondFunc('torch.cumsum', cumsum_fix_func, None)
54 | CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
55 | CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
56 |
57 | # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
58 | CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
59 |
60 | # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
61 | if platform.processor() == 'i386':
62 | for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
63 | CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
64 |
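
All of the patches above go through `CondFunc` from modules/sd_hijack_utils.py. Below is a simplified sketch of the contract those calls rely on, not the repo's actual implementation (which also handles dotted class attributes such as 'torch.Tensor.to'); this toy version only patches module-level functions:

import importlib

def cond_patch(func_path, sub_func, cond_func):
    # Replace the callable at `func_path` with a wrapper that delegates to
    # sub_func(orig, *args, **kwargs) whenever cond_func(orig, *args, **kwargs)
    # is truthy (or cond_func is None), and falls back to the original otherwise.
    module_path, attr = func_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)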
--------------------------------------------------------------------------------
/modules/lowvram.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules import devices
3 |
4 | module_in_gpu = None
5 | cpu = torch.device("cpu")
6 |
7 |
8 | def send_everything_to_cpu():
9 | global module_in_gpu
10 |
11 | if module_in_gpu is not None:
12 | module_in_gpu.to(cpu)
13 |
14 | module_in_gpu = None
15 |
16 |
17 | def setup_for_low_vram(sd_model, use_medvram):
18 | parents = {}
19 |
20 | def send_me_to_gpu(module, _):
21 |         """send this module to GPU; send whatever tracked module was previously on GPU back to CPU;
22 |         we add this as a forward_pre_hook to a lot of modules, and this way all but one of them
23 |         will stay on CPU
24 |         """
25 | global module_in_gpu
26 |
27 | module = parents.get(module, module)
28 |
29 | if module_in_gpu == module:
30 | return
31 |
32 | if module_in_gpu is not None:
33 | module_in_gpu.to(cpu)
34 |
35 | module.to(devices.device)
36 | module_in_gpu = module
37 |
38 | # see below for register_forward_pre_hook;
39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
40 | # useless here, and we just replace those methods
41 |
42 | first_stage_model = sd_model.first_stage_model
43 | first_stage_model_encode = sd_model.first_stage_model.encode
44 | first_stage_model_decode = sd_model.first_stage_model.decode
45 |
46 | def first_stage_model_encode_wrap(x):
47 | send_me_to_gpu(first_stage_model, None)
48 | return first_stage_model_encode(x)
49 |
50 | def first_stage_model_decode_wrap(z):
51 | send_me_to_gpu(first_stage_model, None)
52 | return first_stage_model_decode(z)
53 |
54 |     # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
55 | if hasattr(sd_model.cond_stage_model, 'model'):
56 | sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
57 |
58 | # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
59 |     # send the model to GPU. Then put the modules back; they will stay on the CPU.
60 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
61 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
62 | sd_model.to(devices.device)
63 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
64 |
65 |     # register hooks for the first three models
66 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
67 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
68 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap
69 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap
70 | if sd_model.depth_model:
71 | sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
72 | if sd_model.embedder:
73 | sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
74 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
75 |
76 | if hasattr(sd_model.cond_stage_model, 'model'):
77 | sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
78 | del sd_model.cond_stage_model.transformer
79 |
80 | if use_medvram:
81 | sd_model.model.register_forward_pre_hook(send_me_to_gpu)
82 | else:
83 | diff_model = sd_model.model.diffusion_model
84 |
85 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
86 | # so that only one of them is in GPU at a time
87 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
88 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
89 | sd_model.model.to(devices.device)
90 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
91 |
92 | # install hooks for bits of third model
93 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
94 | for block in diff_model.input_blocks:
95 | block.register_forward_pre_hook(send_me_to_gpu)
96 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
97 | for block in diff_model.output_blocks:
98 | block.register_forward_pre_hook(send_me_to_gpu)
99 |
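
The whole low-VRAM scheme above hinges on `register_forward_pre_hook`: right before a hooked module runs, `send_me_to_gpu` moves it to the GPU and evicts whichever module held the slot before, so at most one large block is resident at a time. A toy sketch of that swapping pattern with plain `nn.Linear` modules (illustrative only, not the repo's code):

import torch
import torch.nn as nn

gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
module_on_gpu = None

def swap_in(module, _inputs):
    # forward_pre_hook: called with (module, inputs) just before module.forward
    global module_on_gpu
    if module_on_gpu is module:
        return
    if module_on_gpu is not None:
        module_on_gpu.to(cpu)        # evict the previous tenant
    module.to(gpu)
    module_on_gpu = module

layers = [nn.Linear(8, 8) for _ in range(3)]
for layer in layers:
    layer.register_forward_pre_hook(swap_in)

x = torch.randn(1, 8)
for layer in layers:
    x = layer(x.to(gpu))             # at most one Linear sits on the GPU at any time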
--------------------------------------------------------------------------------