├── test
│   ├── __init__.py
│   ├── basic_features
│   │   ├── __init__.py
│   │   ├── extras_test.py
│   │   ├── utils_test.py
│   │   ├── img2img_test.py
│   │   └── txt2img_test.py
│   ├── test_files
│   │   ├── empty.pt
│   │   ├── mask_basic.png
│   │   └── img2img_basic.png
│   └── server_poll.py
├── models
│   ├── VAE
│   │   └── Put VAE here.txt
│   ├── Stable-diffusion
│   │   └── Put Stable Diffusion checkpoints here.txt
│   ├── deepbooru
│   │   └── Put your deepbooru release project folder here.txt
│   └── VAE-approx
│       └── model.pt
├── extensions
│   └── put extensions here.txt
├── localizations
│   └── Put localization files here.txt
├── textual_inversion_templates
│   ├── none.txt
│   ├── style.txt
│   ├── subject.txt
│   ├── hypernetwork.txt
│   ├── style_filewords.txt
│   └── subject_filewords.txt
├── embeddings
│   └── Place Textual Inversion embeddings here.txt
├── screenshot.png
├── html
│   ├── card-no-preview.png
│   ├── extra-networks-no-cards.html
│   ├── extra-networks-card.html
│   ├── footer.html
│   └── image-update.svg
├── webui-user.bat
├── .pylintrc
├── modules
│   ├── textual_inversion
│   │   ├── test_embedding.png
│   │   ├── logging.py
│   │   ├── ui.py
│   │   └── learn_schedule.py
│   ├── import_hook.py
│   ├── sd_hijack_ip2p.py
│   ├── face_restoration.py
│   ├── shared_items.py
│   ├── timer.py
│   ├── script_loading.py
│   ├── ngrok.py
│   ├── localization.py
│   ├── extra_networks_hypernet.py
│   ├── ui_extra_networks_hypernets.py
│   ├── errors.py
│   ├── ui_extra_networks_textual_inversion.py
│   ├── sd_hijack_utils.py
│   ├── sd_hijack_checkpoint.py
│   ├── sd_hijack_open_clip.py
│   ├── ui_extra_networks_checkpoints.py
│   ├── sd_samplers.py
│   ├── sd_hijack_xlmr.py
│   ├── scripts_auto_postprocessing.py
│   ├── hypernetworks
│   │   └── ui.py
│   ├── ui_components.py
│   ├── sd_vae_approx.py
│   ├── hashes.py
│   ├── sd_samplers_common.py
│   ├── paths.py
│   ├── ui_tempdir.py
│   ├── memmon.py
│   ├── ui_postprocessing.py
│   ├── txt2img.py
│   ├── deepbooru.py
│   ├── styles.py
│   ├── mac_specific.py
│   ├── masking.py
│   ├── sd_hijack_clip_old.py
│   ├── extensions.py
│   ├── call_queue.py
│   ├── progress.py
│   ├── sd_hijack_unet.py
│   ├── gfpgan_model.py
│   ├── postprocessing.py
│   ├── sd_models_config.py
│   ├── lowvram.py
│   └── upscaler.py
├── environment-wsl2.yaml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   └── bug_report.yml
│   ├── workflows
│   │   ├── run_tests.yaml
│   │   └── on_pull_request.yaml
│   └── pull_request_template.md
├── extensions-builtin
│   ├── Lora
│   │   ├── preload.py
│   │   ├── extra_networks_lora.py
│   │   ├── ui_extra_networks_lora.py
│   │   └── scripts
│   │       └── lora_script.py
│   ├── LDSR
│   │   ├── preload.py
│   │   └── scripts
│   │       └── ldsr_model.py
│   ├── ScuNET
│   │   ├── preload.py
│   │   └── scripts
│   │       └── scunet_model.py
│   ├── SwinIR
│   │   └── preload.py
│   └── prompt-bracket-checker
│       └── javascript
│           └── prompt-bracket-checker.js
├── javascript
│   ├── textualInversion.js
│   ├── imageParams.js
│   ├── hires_fix.js
│   ├── generationParams.js
│   ├── extensions.js
│   ├── imageMaskFix.js
│   ├── notification.js
│   ├── dragdrop.js
│   ├── edit-attention.js
│   ├── extraNetworks.js
│   └── aspectRatioOverlay.js
├── requirements.txt
├── .gitignore
├── requirements_versions.txt
├── CODEOWNERS
├── webui-macos-env.sh
├── scripts
│   ├── postprocessing_gfpgan.py
│   ├── custom_code.py
│   ├── postprocessing_codeformer.py
│   ├── loopback.py
│   └── sd_upscale.py
├── webui-user.sh
├── configs
│   ├── v1-inference.yaml
│   ├── alt-diffusion-inference.yaml
│   ├── v1-inpainting-inference.yaml
│   └── instruct-pix2pix.yaml
├── webui.bat
└── script.js
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/VAE/Put VAE here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/basic_features/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/extensions/put extensions here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/localizations/Put localization files here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/textual_inversion_templates/none.txt:
--------------------------------------------------------------------------------
1 | picture
2 |
--------------------------------------------------------------------------------
/embeddings/Place Textual Inversion embeddings here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/deepbooru/Put your deepbooru release project folder here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/screenshot.png
--------------------------------------------------------------------------------
/html/card-no-preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/html/card-no-preview.png
--------------------------------------------------------------------------------
/models/VAE-approx/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/models/VAE-approx/model.pt
--------------------------------------------------------------------------------
/test/test_files/empty.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/empty.pt
--------------------------------------------------------------------------------
/test/test_files/mask_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/mask_basic.png
--------------------------------------------------------------------------------
/test/test_files/img2img_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/test/test_files/img2img_basic.png
--------------------------------------------------------------------------------
/webui-user.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set PYTHON=
4 | set GIT=
5 | set VENV_DIR=
6 | set COMMANDLINE_ARGS=
7 |
8 | call webui.bat
9 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
2 | [MESSAGES CONTROL]
3 | disable=C,R,W,E,I
4 |
--------------------------------------------------------------------------------
/modules/textual_inversion/test_embedding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mthli/stable-diffusion-webui/master/modules/textual_inversion/test_embedding.png
--------------------------------------------------------------------------------
/html/extra-networks-no-cards.html:
--------------------------------------------------------------------------------
1 |
2 |
Nothing here. Add some content to the following directories:
3 |
4 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/environment-wsl2.yaml:
--------------------------------------------------------------------------------
1 | name: automatic
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.10
7 | - pip=22.2.2
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
--------------------------------------------------------------------------------
/modules/import_hook.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | # this will break any attempt to import xformers, which prevents the Stable Diffusion repo from trying to use it
4 | if "--xformers" not in "".join(sys.argv):
5 | sys.modules["xformers"] = None
6 |
--------------------------------------------------------------------------------
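A minimal sketch of what the hook above achieves: once an entry in sys.modules is set to None, any later import of that name raises ImportError (a ModuleNotFoundError), which is how the webui keeps the Stable Diffusion code from picking up xformers unless --xformers was passed. The module name below is purely illustrative.

    import sys

    sys.modules["some_optional_backend"] = None  # hypothetical name, mirroring the xformers hook

    try:
        import some_optional_backend  # noqa: F401
    except ImportError:
        print("import blocked, as the hook intends")
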
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: WebUI Community Support
4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
5 | about: Please ask and answer questions here.
6 |
--------------------------------------------------------------------------------
/extensions-builtin/Lora/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/LDSR/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/ScuNET/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
7 |
--------------------------------------------------------------------------------
/extensions-builtin/SwinIR/preload.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules import paths
3 |
4 |
5 | def preload(parser):
6 | parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
7 |
--------------------------------------------------------------------------------
/html/extra-networks-card.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
7 |
{search_term}
8 |
9 |
{name}
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/html/footer.html:
--------------------------------------------------------------------------------
1 |
10 |
11 |
12 | {versions}
13 |
14 |
--------------------------------------------------------------------------------
/modules/sd_hijack_ip2p.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os.path
3 | import sys
4 | import gc
5 | import time
6 |
7 | def should_hijack_ip2p(checkpoint_info):
8 | from modules import sd_models_config
9 |
10 | ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
11 | cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
12 |
13 | return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
14 |
--------------------------------------------------------------------------------
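In short, should_hijack_ip2p() fires only when the checkpoint filename mentions pix2pix while its resolved config file does not. A rough illustration of that predicate on plain basenames, without the repo's CheckpointInfo/config lookup (the checkpoint filename here is made up; the config names come from the configs directory in this listing):

    def would_hijack(ckpt_basename: str, cfg_basename: str) -> bool:
        # mirrors the return expression above, on bare strings
        return "pix2pix" in ckpt_basename.lower() and "pix2pix" not in cfg_basename.lower()

    print(would_hijack("instruct-pix2pix-00-22000.ckpt", "v1-inference.yaml"))      # True
    print(would_hijack("instruct-pix2pix-00-22000.ckpt", "instruct-pix2pix.yaml"))  # False
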
/javascript/textualInversion.js:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | function start_training_textual_inversion(){
5 | gradioApp().querySelector('#ti_error').innerHTML=''
6 |
7 | var id = randomId()
8 | requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
9 | gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
10 | })
11 |
12 | var res = args_to_array(arguments)
13 |
14 | res[0] = id
15 |
16 | return res
17 | }
18 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | blendmodes
2 | accelerate
3 | basicsr
4 | fonts
5 | font-roboto
6 | gfpgan
7 | gradio==3.16.2
8 | invisible-watermark
9 | numpy
10 | omegaconf
11 | opencv-contrib-python
12 | requests
13 | piexif
14 | Pillow
15 | pytorch_lightning==1.7.7
16 | realesrgan
17 | scikit-image>=0.19
18 | timm==0.4.12
19 | transformers==4.25.1
20 | torch
21 | einops
22 | jsonmerge
23 | clean-fid
24 | resize-right
25 | torchdiffeq
26 | kornia
27 | lark
28 | inflection
29 | GitPython
30 | torchsde
31 | safetensors
32 | psutil
33 |
--------------------------------------------------------------------------------
/modules/face_restoration.py:
--------------------------------------------------------------------------------
1 | from modules import shared
2 |
3 |
4 | class FaceRestoration:
5 | def name(self):
6 | return "None"
7 |
8 | def restore(self, np_image):
9 | return np_image
10 |
11 |
12 | def restore_faces(np_image):
13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14 | if len(face_restorers) == 0:
15 | return np_image
16 |
17 | face_restorer = face_restorers[0]
18 |
19 | return face_restorer.restore(np_image)
20 |
--------------------------------------------------------------------------------
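restore_faces() picks the first registered restorer whose name() matches opts.face_restoration_model and otherwise returns the image untouched. A sketch of plugging in a custom restorer, assuming the webui's modules package is importable; IdentityRestorer is a hypothetical class used only for illustration:

    import numpy as np
    from modules import shared
    from modules.face_restoration import FaceRestoration, restore_faces

    class IdentityRestorer(FaceRestoration):  # hypothetical restorer for illustration
        def name(self):
            return "Identity"

        def restore(self, np_image):
            return np_image  # a real restorer would return the repaired image here

    shared.face_restorers.append(IdentityRestorer())
    shared.opts.face_restoration_model = "Identity"  # restore_faces() matches on name()
    out = restore_faces(np.zeros((64, 64, 3), dtype=np.uint8))
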
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.ckpt
3 | *.safetensors
4 | *.pth
5 | /ESRGAN/*
6 | /SwinIR/*
7 | /repositories
8 | /venv
9 | /tmp
10 | /model.ckpt
11 | /models/**/*
12 | /GFPGANv1.3.pth
13 | /gfpgan/weights/*.pth
14 | /ui-config.json
15 | /outputs
16 | /config.json
17 | /log
18 | /webui.settings.bat
19 | /embeddings
20 | /styles.csv
21 | /params.txt
22 | /styles.csv.bak
23 | /webui-user.bat
24 | /webui-user.sh
25 | /interrogate
26 | /user.css
27 | /.idea
28 | notification.mp3
29 | /SwinIR
30 | /textual_inversion
31 | .vscode
32 | /extensions
33 | /test/stdout.txt
34 | /test/stderr.txt
35 | /cache.json
36 |
--------------------------------------------------------------------------------
/modules/shared_items.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def realesrgan_models_names():
4 | import modules.realesrgan_model
5 | return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
6 |
7 |
8 | def postprocessing_scripts():
9 | import modules.scripts
10 |
11 | return modules.scripts.scripts_postproc.scripts
12 |
13 |
14 | def sd_vae_items():
15 | import modules.sd_vae
16 |
17 | return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
18 |
19 |
20 | def refresh_vae_list():
21 | import modules.sd_vae
22 |
23 | modules.sd_vae.refresh_vae_list()
24 |
--------------------------------------------------------------------------------
/requirements_versions.txt:
--------------------------------------------------------------------------------
1 | blendmodes==2022
2 | transformers==4.25.1
3 | accelerate==0.12.0
4 | basicsr==1.4.2
5 | gfpgan==1.3.8
6 | gradio==3.16.2
7 | numpy==1.23.3
8 | Pillow==9.4.0
9 | realesrgan==0.3.0
10 | torch
11 | omegaconf==2.2.3
12 | pytorch_lightning==1.7.6
13 | scikit-image==0.19.2
14 | fonts
15 | font-roboto
16 | timm==0.6.7
17 | piexif==1.1.3
18 | einops==0.4.1
19 | jsonmerge==1.8.0
20 | clean-fid==0.1.29
21 | resize-right==0.0.2
22 | torchdiffeq==0.2.3
23 | kornia==0.6.7
24 | lark==1.1.2
25 | inflection==0.5.1
26 | GitPython==3.1.27
27 | torchsde==0.2.5
28 | safetensors==0.2.7
29 | httpcore<=0.15
30 | fastapi==0.90.1
31 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style.txt:
--------------------------------------------------------------------------------
1 | a painting, art by [name]
2 | a rendering, art by [name]
3 | a cropped painting, art by [name]
4 | the painting, art by [name]
5 | a clean painting, art by [name]
6 | a dirty painting, art by [name]
7 | a dark painting, art by [name]
8 | a picture, art by [name]
9 | a cool painting, art by [name]
10 | a close-up painting, art by [name]
11 | a bright painting, art by [name]
12 | a cropped painting, art by [name]
13 | a good painting, art by [name]
14 | a close-up painting, art by [name]
15 | a rendition, art by [name]
16 | a nice painting, art by [name]
17 | a small painting, art by [name]
18 | a weird painting, art by [name]
19 | a large painting, art by [name]
20 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @AUTOMATIC1111
2 |
3 | # if you were managing a localization and were removed from this file, this is because
4 | # the intended way to do localizations now is via extensions. See:
5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
6 | # Make a repo with your localization and since you are still listed as a collaborator
7 | # you can add it to the wiki page yourself. This change is because some people complained
8 | # the git commit log is cluttered with things unrelated to almost everyone and
9 | # because I believe this is the best overall for the project to handle localizations almost
10 | # entirely without my oversight.
11 |
12 |
13 |
--------------------------------------------------------------------------------
/javascript/imageParams.js:
--------------------------------------------------------------------------------
1 | window.onload = (function(){
2 | window.addEventListener('drop', e => {
3 | const target = e.composedPath()[0];
4 | const idx = selected_gallery_index();
5 | if (target.placeholder.indexOf("Prompt") == -1) return;
6 |
7 | let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
8 |
9 | e.stopPropagation();
10 | e.preventDefault();
11 | const imgParent = gradioApp().getElementById(prompt_target);
12 | const files = e.dataTransfer.files;
13 | const fileInput = imgParent.querySelector('input[type="file"]');
14 | if ( fileInput ) {
15 | fileInput.files = files;
16 | fileInput.dispatchEvent(new Event('change'));
17 | }
18 | });
19 | });
20 |
--------------------------------------------------------------------------------
/textual_inversion_templates/subject.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name]
2 | a rendering of a [name]
3 | a cropped photo of the [name]
4 | the photo of a [name]
5 | a photo of a clean [name]
6 | a photo of a dirty [name]
7 | a dark photo of the [name]
8 | a photo of my [name]
9 | a photo of the cool [name]
10 | a close-up photo of a [name]
11 | a bright photo of the [name]
12 | a cropped photo of a [name]
13 | a photo of the [name]
14 | a good photo of the [name]
15 | a photo of one [name]
16 | a close-up photo of the [name]
17 | a rendition of the [name]
18 | a photo of the clean [name]
19 | a rendition of a [name]
20 | a photo of a nice [name]
21 | a good photo of a [name]
22 | a photo of the nice [name]
23 | a photo of the small [name]
24 | a photo of the weird [name]
25 | a photo of the large [name]
26 | a photo of a cool [name]
27 | a photo of a small [name]
28 |
--------------------------------------------------------------------------------
/test/server_poll.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 | import time
4 |
5 |
6 | def run_tests(proc, test_dir):
7 | timeout_threshold = 240
8 | start_time = time.time()
9 | while time.time()-start_time < timeout_threshold:
10 | try:
11 | requests.head("http://localhost:7860/")
12 | break
13 | except requests.exceptions.ConnectionError:
14 | if proc.poll() is not None:
15 | break
16 | if proc.poll() is None:
17 | if test_dir is None:
18 | test_dir = "test"
19 | suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
20 | result = unittest.TextTestRunner(verbosity=2).run(suite)
21 | return len(result.failures) + len(result.errors)
22 | else:
23 | print("Launch unsuccessful")
24 | return 1
25 |
--------------------------------------------------------------------------------
/webui-macos-env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ####################################################################
3 | # macOS defaults #
4 | # Please modify webui-user.sh to change these instead of this file #
5 | ####################################################################
6 |
7 | if [[ -x "$(command -v python3.10)" ]]
8 | then
9 | python_cmd="python3.10"
10 | fi
11 |
12 | export install_dir="$HOME"
13 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
14 | export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
15 | export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
16 | export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
17 | export PYTORCH_ENABLE_MPS_FALLBACK=1
18 |
19 | ####################################################################
20 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests.yaml:
--------------------------------------------------------------------------------
1 | name: Run basic features tests on CPU with empty SD model
2 |
3 | on:
4 | - push
5 | - pull_request
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Checkout Code
12 | uses: actions/checkout@v3
13 | - name: Set up Python 3.10
14 | uses: actions/setup-python@v4
15 | with:
16 | python-version: 3.10.6
17 | cache: pip
18 | cache-dependency-path: |
19 | **/requirements*txt
20 | - name: Run tests
21 | run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
22 | - name: Upload main app stdout-stderr
23 | uses: actions/upload-artifact@v3
24 | if: always()
25 | with:
26 | name: stdout-stderr
27 | path: |
28 | test/stdout.txt
29 | test/stderr.txt
30 |
--------------------------------------------------------------------------------
/textual_inversion_templates/hypernetwork.txt:
--------------------------------------------------------------------------------
1 | a photo of a [filewords]
2 | a rendering of a [filewords]
3 | a cropped photo of the [filewords]
4 | the photo of a [filewords]
5 | a photo of a clean [filewords]
6 | a photo of a dirty [filewords]
7 | a dark photo of the [filewords]
8 | a photo of my [filewords]
9 | a photo of the cool [filewords]
10 | a close-up photo of a [filewords]
11 | a bright photo of the [filewords]
12 | a cropped photo of a [filewords]
13 | a photo of the [filewords]
14 | a good photo of the [filewords]
15 | a photo of one [filewords]
16 | a close-up photo of the [filewords]
17 | a rendition of the [filewords]
18 | a photo of the clean [filewords]
19 | a rendition of a [filewords]
20 | a photo of a nice [filewords]
21 | a good photo of a [filewords]
22 | a photo of the nice [filewords]
23 | a photo of the small [filewords]
24 | a photo of the weird [filewords]
25 | a photo of the large [filewords]
26 | a photo of a cool [filewords]
27 | a photo of a small [filewords]
28 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style_filewords.txt:
--------------------------------------------------------------------------------
1 | a painting of [filewords], art by [name]
2 | a rendering of [filewords], art by [name]
3 | a cropped painting of [filewords], art by [name]
4 | the painting of [filewords], art by [name]
5 | a clean painting of [filewords], art by [name]
6 | a dirty painting of [filewords], art by [name]
7 | a dark painting of [filewords], art by [name]
8 | a picture of [filewords], art by [name]
9 | a cool painting of [filewords], art by [name]
10 | a close-up painting of [filewords], art by [name]
11 | a bright painting of [filewords], art by [name]
12 | a cropped painting of [filewords], art by [name]
13 | a good painting of [filewords], art by [name]
14 | a close-up painting of [filewords], art by [name]
15 | a rendition of [filewords], art by [name]
16 | a nice painting of [filewords], art by [name]
17 | a small painting of [filewords], art by [name]
18 | a weird painting of [filewords], art by [name]
19 | a large painting of [filewords], art by [name]
20 |
--------------------------------------------------------------------------------
/html/image-update.svg:
--------------------------------------------------------------------------------
1 |
8 |
--------------------------------------------------------------------------------
/modules/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class Timer:
5 | def __init__(self):
6 | self.start = time.time()
7 | self.records = {}
8 | self.total = 0
9 |
10 | def elapsed(self):
11 | end = time.time()
12 | res = end - self.start
13 | self.start = end
14 | return res
15 |
16 | def record(self, category, extra_time=0):
17 | e = self.elapsed()
18 | if category not in self.records:
19 | self.records[category] = 0
20 |
21 | self.records[category] += e + extra_time
22 | self.total += e + extra_time
23 |
24 | def summary(self):
25 | res = f"{self.total:.1f}s"
26 |
27 | additions = [x for x in self.records.items() if x[1] >= 0.1]
28 | if not additions:
29 | return res
30 |
31 | res += " ("
32 | res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
33 | res += ")"
34 |
35 | return res
36 |
--------------------------------------------------------------------------------
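A usage sketch of Timer: elapsed() measures the time since the previous call, record() files that delta under a category, and summary() reports the total with a per-category breakdown (only categories taking at least 0.1s are listed). Assumes the repo's modules package is importable.

    import time
    from modules.timer import Timer

    timer = Timer()
    time.sleep(0.2)
    timer.record("load weights")
    time.sleep(0.1)
    timer.record("apply optimizations")
    print(timer.summary())  # e.g. "0.3s (load weights: 0.2s, apply optimizations: 0.1s)"
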
/javascript/hires_fix.js:
--------------------------------------------------------------------------------
1 |
2 | function setInactive(elem, inactive){
3 | if(inactive){
4 | elem.classList.add('inactive')
5 | } else{
6 | elem.classList.remove('inactive')
7 | }
8 | }
9 |
10 | function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
11 | hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
12 | hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
13 | hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
14 |
15 | gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
16 |
17 | setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
18 | setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
19 | setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
20 |
21 | return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
22 | }
23 |
--------------------------------------------------------------------------------
/modules/script_loading.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 | import importlib.util
5 | from types import ModuleType
6 |
7 |
8 | def load_module(path):
9 | module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
10 | module = importlib.util.module_from_spec(module_spec)
11 | module_spec.loader.exec_module(module)
12 |
13 | return module
14 |
15 |
16 | def preload_extensions(extensions_dir, parser):
17 | if not os.path.isdir(extensions_dir):
18 | return
19 |
20 | for dirname in sorted(os.listdir(extensions_dir)):
21 | preload_script = os.path.join(extensions_dir, dirname, "preload.py")
22 | if not os.path.isfile(preload_script):
23 | continue
24 |
25 | try:
26 | module = load_module(preload_script)
27 | if hasattr(module, 'preload'):
28 | module.preload(parser)
29 |
30 | except Exception:
31 | print(f"Error running preload() for {preload_script}", file=sys.stderr)
32 | print(traceback.format_exc(), file=sys.stderr)
33 |
--------------------------------------------------------------------------------
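preload_extensions() is what lets each extension's preload.py (such as the bundled Lora, LDSR, ScuNET and SwinIR ones elsewhere in this listing) register its command-line flags before the main parser runs. A rough sketch of that flow, assuming the repo's modules package is importable:

    import argparse
    from modules import script_loading

    parser = argparse.ArgumentParser()
    # every preload.py under the given directory gets a chance to add arguments, e.g. --lora-dir
    script_loading.preload_extensions("extensions-builtin", parser)
    script_loading.preload_extensions("extensions", parser)
    args = parser.parse_args([])
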
/extensions-builtin/Lora/extra_networks_lora.py:
--------------------------------------------------------------------------------
1 | from modules import extra_networks, shared
2 | import lora
3 |
4 | class ExtraNetworkLora(extra_networks.ExtraNetwork):
5 | def __init__(self):
6 | super().__init__('lora')
7 |
8 | def activate(self, p, params_list):
9 | additional = shared.opts.sd_lora
10 |
11 | if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
12 | p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
13 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
14 |
15 | names = []
16 | multipliers = []
17 | for params in params_list:
18 | assert len(params.items) > 0
19 |
20 | names.append(params.items[0])
21 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
22 |
23 | lora.load_loras(names, multipliers)
24 |
25 | def deactivate(self, p):
26 | pass
27 |
--------------------------------------------------------------------------------
/modules/ngrok.py:
--------------------------------------------------------------------------------
1 | from pyngrok import ngrok, conf, exception
2 |
3 | def connect(token, port, region):
4 | account = None
5 | if token is None:
6 | token = 'None'
7 | else:
8 | if ':' in token:
9 | # token = authtoken:username:password
10 | account = token.split(':')[1] + ':' + token.split(':')[-1]
11 | token = token.split(':')[0]
12 |
13 | config = conf.PyngrokConfig(
14 | auth_token=token, region=region
15 | )
16 | try:
17 | if account is None:
18 | public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
19 | else:
20 | public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
21 | except exception.PyngrokNgrokError:
22 | print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
23 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
24 | else:
25 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
26 | 'You can use this link after the launch is complete.')
27 |
--------------------------------------------------------------------------------
/modules/localization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import traceback
5 |
6 |
7 | localizations = {}
8 |
9 |
10 | def list_localizations(dirname):
11 | localizations.clear()
12 |
13 | for file in os.listdir(dirname):
14 | fn, ext = os.path.splitext(file)
15 | if ext.lower() != ".json":
16 | continue
17 |
18 | localizations[fn] = os.path.join(dirname, file)
19 |
20 | from modules import scripts
21 | for file in scripts.list_scripts("localizations", ".json"):
22 | fn, ext = os.path.splitext(file.filename)
23 | localizations[fn] = file.path
24 |
25 |
26 | def localization_js(current_localization_name):
27 | fn = localizations.get(current_localization_name, None)
28 | data = {}
29 | if fn is not None:
30 | try:
31 | with open(fn, "r", encoding="utf8") as file:
32 | data = json.load(file)
33 | except Exception:
34 | print(f"Error loading localization from {fn}:", file=sys.stderr)
35 | print(traceback.format_exc(), file=sys.stderr)
36 |
37 | return f"var localization = {json.dumps(data)}\n"
38 |
--------------------------------------------------------------------------------
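localization_js() simply serializes whichever localization file is selected into a "var localization = {...}" assignment that gets inlined into the page; when nothing is selected it emits an empty object. A tiny sketch of the output shape, using a made-up mapping rather than a real localization file:

    import json

    data = {"Generate": "Générer"}  # illustrative entry, not from the repo
    print(f"var localization = {json.dumps(data)}\n")
    # -> var localization = {"Generate": "G\u00e9n\u00e9rer"}
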
/textual_inversion_templates/subject_filewords.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name], [filewords]
2 | a rendering of a [name], [filewords]
3 | a cropped photo of the [name], [filewords]
4 | the photo of a [name], [filewords]
5 | a photo of a clean [name], [filewords]
6 | a photo of a dirty [name], [filewords]
7 | a dark photo of the [name], [filewords]
8 | a photo of my [name], [filewords]
9 | a photo of the cool [name], [filewords]
10 | a close-up photo of a [name], [filewords]
11 | a bright photo of the [name], [filewords]
12 | a cropped photo of a [name], [filewords]
13 | a photo of the [name], [filewords]
14 | a good photo of the [name], [filewords]
15 | a photo of one [name], [filewords]
16 | a close-up photo of the [name], [filewords]
17 | a rendition of the [name], [filewords]
18 | a photo of the clean [name], [filewords]
19 | a rendition of a [name], [filewords]
20 | a photo of a nice [name], [filewords]
21 | a good photo of a [name], [filewords]
22 | a photo of the nice [name], [filewords]
23 | a photo of the small [name], [filewords]
24 | a photo of the weird [name], [filewords]
25 | a photo of the large [name], [filewords]
26 | a photo of a cool [name], [filewords]
27 | a photo of a small [name], [filewords]
28 |
--------------------------------------------------------------------------------
/scripts/postprocessing_gfpgan.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | from modules import scripts_postprocessing, gfpgan_model
5 | import gradio as gr
6 |
7 | from modules.ui_components import FormRow
8 |
9 |
10 | class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
11 | name = "GFPGAN"
12 | order = 2000
13 |
14 | def ui(self):
15 | with FormRow():
16 | gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
17 |
18 | return {
19 | "gfpgan_visibility": gfpgan_visibility,
20 | }
21 |
22 | def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
23 | if gfpgan_visibility == 0:
24 | return
25 |
26 | restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
27 | res = Image.fromarray(restored_img)
28 |
29 | if gfpgan_visibility < 1.0:
30 | res = Image.blend(pp.image, res, gfpgan_visibility)
31 |
32 | pp.image = res
33 | pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
34 |
--------------------------------------------------------------------------------
/modules/extra_networks_hypernet.py:
--------------------------------------------------------------------------------
1 | from modules import extra_networks, shared
2 | from modules.hypernetworks import hypernetwork
3 |
4 |
5 | class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
6 | def __init__(self):
7 | super().__init__('hypernet')
8 |
9 | def activate(self, p, params_list):
10 | additional = shared.opts.sd_hypernetwork
11 |
12 | if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
13 | p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
14 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
15 |
16 | names = []
17 | multipliers = []
18 | for params in params_list:
19 | assert len(params.items) > 0
20 |
21 | names.append(params.items[0])
22 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
23 |
24 | hypernetwork.load_hypernetworks(names, multipliers)
25 |
26 | def deactivate(self, p):
27 | pass
28 |
--------------------------------------------------------------------------------
/scripts/custom_code.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import gradio as gr
3 |
4 | from modules.processing import Processed
5 | from modules.shared import opts, cmd_opts, state
6 |
7 | class Script(scripts.Script):
8 |
9 | def title(self):
10 | return "Custom code"
11 |
12 | def show(self, is_img2img):
13 | return cmd_opts.allow_code
14 |
15 | def ui(self, is_img2img):
16 | code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
17 |
18 | return [code]
19 |
20 |
21 | def run(self, p, code):
22 | assert cmd_opts.allow_code, '--allow-code option must be enabled'
23 |
24 | display_result_data = [[], -1, ""]
25 |
26 | def display(imgs, s=display_result_data[1], i=display_result_data[2]):
27 | display_result_data[0] = imgs
28 | display_result_data[1] = s
29 | display_result_data[2] = i
30 |
31 | from types import ModuleType
32 | compiled = compile(code, '', 'exec')
33 | module = ModuleType("testmodule")
34 | module.__dict__.update(globals())
35 | module.p = p
36 | module.display = display
37 | exec(compiled, module.__dict__)
38 |
39 | return Processed(p, *display_result_data)
40 |
41 |
--------------------------------------------------------------------------------
/modules/ui_extra_networks_hypernets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from modules import shared, ui_extra_networks
5 |
6 |
7 | class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
8 | def __init__(self):
9 | super().__init__('Hypernetworks')
10 |
11 | def refresh(self):
12 | shared.reload_hypernetworks()
13 |
14 | def list_items(self):
15 | for name, path in shared.hypernetworks.items():
16 | path, ext = os.path.splitext(path)
17 | previews = [path + ".png", path + ".preview.png"]
18 |
19 | preview = None
20 | for file in previews:
21 | if os.path.isfile(file):
22 | preview = self.link_preview(file)
23 | break
24 |
25 | yield {
26 | "name": name,
27 | "filename": path,
28 | "preview": preview,
29 | "search_term": self.search_terms_from_path(path),
30 | "prompt": json.dumps(f""),
31 | "local_preview": path + ".png",
32 | }
33 |
34 | def allowed_directories_for_previews(self):
35 | return [shared.cmd_opts.hypernetwork_dir]
36 |
37 |
--------------------------------------------------------------------------------
/extensions-builtin/Lora/ui_extra_networks_lora.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import lora
4 |
5 | from modules import shared, ui_extra_networks
6 |
7 |
8 | class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
9 | def __init__(self):
10 | super().__init__('Lora')
11 |
12 | def refresh(self):
13 | lora.list_available_loras()
14 |
15 | def list_items(self):
16 | for name, lora_on_disk in lora.available_loras.items():
17 | path, ext = os.path.splitext(lora_on_disk.filename)
18 | previews = [path + ".png", path + ".preview.png"]
19 |
20 | preview = None
21 | for file in previews:
22 | if os.path.isfile(file):
23 | preview = self.link_preview(file)
24 | break
25 |
26 | yield {
27 | "name": name,
28 | "filename": path,
29 | "preview": preview,
30 | "search_term": self.search_terms_from_path(lora_on_disk.filename),
31 | "prompt": json.dumps(f""),
32 | "local_preview": path + ".png",
33 | }
34 |
35 | def allowed_directories_for_previews(self):
36 | return [shared.cmd_opts.lora_dir]
37 |
38 |
--------------------------------------------------------------------------------
/modules/errors.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 |
4 |
5 | def print_error_explanation(message):
6 | lines = message.strip().split("\n")
7 | max_len = max([len(x) for x in lines])
8 |
9 | print('=' * max_len, file=sys.stderr)
10 | for line in lines:
11 | print(line, file=sys.stderr)
12 | print('=' * max_len, file=sys.stderr)
13 |
14 |
15 | def display(e: Exception, task):
16 | print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
17 | print(traceback.format_exc(), file=sys.stderr)
18 |
19 | message = str(e)
20 | if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
21 | print_error_explanation("""
22 | The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
23 | See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
24 | """)
25 |
26 |
27 | already_displayed = {}
28 |
29 |
30 | def display_once(e: Exception, task):
31 | if task in already_displayed:
32 | return
33 |
34 | display(e, task)
35 |
36 | already_displayed[task] = 1
37 |
38 |
39 | def run(code, task):
40 | try:
41 | code()
42 | except Exception as e:
43 | display(e, task)
44 |
--------------------------------------------------------------------------------
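errors.run() is a small guard for best-effort calls: the callable is executed and any exception is reported through display() instead of propagating, while display_once() suppresses repeats for the same task label. A minimal sketch, assuming the repo's modules package is importable:

    from modules import errors

    def flaky():
        raise RuntimeError("boom")

    errors.run(flaky, "loading optional component")   # prints the error, execution continues
    errors.display_once(RuntimeError("boom"), "loading optional component")  # reported only once per task
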
/modules/ui_extra_networks_textual_inversion.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from modules import ui_extra_networks, sd_hijack
5 |
6 |
7 | class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
8 | def __init__(self):
9 | super().__init__('Textual Inversion')
10 | self.allow_negative_prompt = True
11 |
12 | def refresh(self):
13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
14 |
15 | def list_items(self):
16 | for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
17 | path, ext = os.path.splitext(embedding.filename)
18 | preview_file = path + ".preview.png"
19 |
20 | preview = None
21 | if os.path.isfile(preview_file):
22 | preview = self.link_preview(preview_file)
23 |
24 | yield {
25 | "name": embedding.name,
26 | "filename": embedding.filename,
27 | "preview": preview,
28 | "search_term": self.search_terms_from_path(embedding.filename),
29 | "prompt": json.dumps(embedding.name),
30 | "local_preview": path + ".preview.png",
31 | }
32 |
33 | def allowed_directories_for_previews(self):
34 | return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
35 |
--------------------------------------------------------------------------------
/modules/sd_hijack_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | class CondFunc:
4 | def __new__(cls, orig_func, sub_func, cond_func):
5 | self = super(CondFunc, cls).__new__(cls)
6 | if isinstance(orig_func, str):
7 | func_path = orig_func.split('.')
8 | for i in range(len(func_path)-1, -1, -1):
9 | try:
10 | resolved_obj = importlib.import_module('.'.join(func_path[:i]))
11 | break
12 | except ImportError:
13 | pass
14 | for attr_name in func_path[i:-1]:
15 | resolved_obj = getattr(resolved_obj, attr_name)
16 | orig_func = getattr(resolved_obj, func_path[-1])
17 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
18 | self.__init__(orig_func, sub_func, cond_func)
19 | return lambda *args, **kwargs: self(*args, **kwargs)
20 | def __init__(self, orig_func, sub_func, cond_func):
21 | self.__orig_func = orig_func
22 | self.__sub_func = sub_func
23 | self.__cond_func = cond_func
24 | def __call__(self, *args, **kwargs):
25 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
26 | return self.__sub_func(self.__orig_func, *args, **kwargs)
27 | else:
28 | return self.__orig_func(*args, **kwargs)
29 |
--------------------------------------------------------------------------------
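CondFunc patches a dotted function path in place and routes calls through sub_func only when cond_func approves, otherwise falling through to the original. A self-contained sketch against math.pow, which serves here purely as an illustrative target (assumes the repo's modules package is importable):

    import math
    from modules.sd_hijack_utils import CondFunc

    # Substitute math.pow only when the exponent is 2; all other calls hit the original.
    CondFunc('math.pow',
             lambda orig, base, exp: base * base,   # sub_func receives the original as its first argument
             lambda orig, base, exp: exp == 2)      # cond_func decides per call

    print(math.pow(3, 2))   # handled by the substitute -> 9
    print(math.pow(2, 10))  # falls back to the original -> 1024.0
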
/modules/sd_hijack_checkpoint.py:
--------------------------------------------------------------------------------
1 | from torch.utils.checkpoint import checkpoint
2 |
3 | import ldm.modules.attention
4 | import ldm.modules.diffusionmodules.openaimodel
5 |
6 |
7 | def BasicTransformerBlock_forward(self, x, context=None):
8 | return checkpoint(self._forward, x, context)
9 |
10 |
11 | def AttentionBlock_forward(self, x):
12 | return checkpoint(self._forward, x)
13 |
14 |
15 | def ResBlock_forward(self, x, emb):
16 | return checkpoint(self._forward, x, emb)
17 |
18 |
19 | stored = []
20 |
21 |
22 | def add():
23 | if len(stored) != 0:
24 | return
25 |
26 | stored.extend([
27 | ldm.modules.attention.BasicTransformerBlock.forward,
28 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
29 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
30 | ])
31 |
32 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
33 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
34 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
35 |
36 |
37 | def remove():
38 | if len(stored) == 0:
39 | return
40 |
41 | ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
42 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
43 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
44 |
45 | stored.clear()
46 |
47 |
--------------------------------------------------------------------------------
/javascript/generationParams.js:
--------------------------------------------------------------------------------
1 | // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
2 |
3 | let txt2img_gallery, img2img_gallery, modal = undefined;
4 | onUiUpdate(function(){
5 | if (!txt2img_gallery) {
6 | txt2img_gallery = attachGalleryListeners("txt2img")
7 | }
8 | if (!img2img_gallery) {
9 | img2img_gallery = attachGalleryListeners("img2img")
10 | }
11 | if (!modal) {
12 | modal = gradioApp().getElementById('lightboxModal')
13 | modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
14 | }
15 | });
16 |
17 | let modalObserver = new MutationObserver(function(mutations) {
18 | mutations.forEach(function(mutationRecord) {
19 | let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
20 | if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
21 | gradioApp().getElementById(selectedTab+"_generation_info_button").click()
22 | });
23 | });
24 |
25 | function attachGalleryListeners(tab_name) {
26 | gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
27 | gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
28 | gallery?.addEventListener('keydown', (e) => {
29 | if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
30 | gradioApp().getElementById(tab_name+"_generation_info_button").click()
31 | });
32 | return gallery;
33 | }
34 |
--------------------------------------------------------------------------------
/.github/workflows/on_pull_request.yaml:
--------------------------------------------------------------------------------
1 | # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
2 | name: Run Linting/Formatting on Pull Requests
3 |
4 | on:
5 | - push
6 | - pull_request
7 | # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
8 | # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
9 | # pull_request:
10 | # branches:
11 | # - master
12 | # branches-ignore:
13 | # - development
14 |
15 | jobs:
16 | lint:
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Checkout Code
20 | uses: actions/checkout@v3
21 | - name: Set up Python 3.10
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: 3.10.6
25 | cache: pip
26 | cache-dependency-path: |
27 | **/requirements*txt
28 | - name: Install PyLint
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install pylint
32 | # This lets PyLint check to see if it can resolve imports
33 | - name: Install dependencies
34 | run: |
35 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
36 | python launch.py
37 | - name: Analysing the code with pylint
38 | run: |
39 | pylint $(git ls-files '*.py')
40 |
--------------------------------------------------------------------------------
/modules/sd_hijack_open_clip.py:
--------------------------------------------------------------------------------
1 | import open_clip.tokenizer
2 | import torch
3 |
4 | from modules import sd_hijack_clip, devices
5 | from modules.shared import opts
6 |
7 | tokenizer = open_clip.tokenizer._tokenizer
8 |
9 |
10 | class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
11 | def __init__(self, wrapped, hijack):
12 | super().__init__(wrapped, hijack)
13 |
14 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0]
15 | self.id_start = tokenizer.encoder["<start_of_text>"]
16 | self.id_end = tokenizer.encoder["<end_of_text>"]
17 | self.id_pad = 0
18 |
19 | def tokenize(self, texts):
20 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
21 |
22 | tokenized = [tokenizer.encode(text) for text in texts]
23 |
24 | return tokenized
25 |
26 | def encode_with_transformers(self, tokens):
27 | # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
28 | z = self.wrapped.encode_with_transformer(tokens)
29 |
30 | return z
31 |
32 | def encode_embedding_init_text(self, init_text, nvpt):
33 | ids = tokenizer.encode(init_text)
34 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
35 | embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
36 |
37 | return embedded
38 |
--------------------------------------------------------------------------------
/webui-user.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #########################################################
3 | # Uncomment and change the variables below to your need:#
4 | #########################################################
5 |
6 | # Install directory without trailing slash
7 | #install_dir="/home/$(whoami)"
8 |
9 | # Name of the subdirectory
10 | #clone_dir="stable-diffusion-webui"
11 |
12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
13 | #export COMMANDLINE_ARGS=""
14 |
15 | # python3 executable
16 | #python_cmd="python3"
17 |
18 | # git executable
19 | #export GIT="git"
20 |
21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
22 | #venv_dir="venv"
23 |
24 | # script to launch to start the app
25 | #export LAUNCH_SCRIPT="launch.py"
26 |
27 | # install command for torch
28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
29 |
30 | # Requirements file to use for stable-diffusion-webui
31 | #export REQS_FILE="requirements_versions.txt"
32 |
33 | # Fixed git repos
34 | #export K_DIFFUSION_PACKAGE=""
35 | #export GFPGAN_PACKAGE=""
36 |
37 | # Fixed git commits
38 | #export STABLE_DIFFUSION_COMMIT_HASH=""
39 | #export TAMING_TRANSFORMERS_COMMIT_HASH=""
40 | #export CODEFORMER_COMMIT_HASH=""
41 | #export BLIP_COMMIT_HASH=""
42 |
43 | # Uncomment to enable accelerated launch
44 | #export ACCELERATE="True"
45 |
46 | ###########################################
47 |
--------------------------------------------------------------------------------
/modules/ui_extra_networks_checkpoints.py:
--------------------------------------------------------------------------------
1 | import html
2 | import json
3 | import os
4 | import urllib.parse
5 |
6 | from modules import shared, ui_extra_networks, sd_models
7 |
8 |
9 | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
10 | def __init__(self):
11 | super().__init__('Checkpoints')
12 |
13 | def refresh(self):
14 | shared.refresh_checkpoints()
15 |
16 | def list_items(self):
17 | checkpoint: sd_models.CheckpointInfo
18 | for name, checkpoint in sd_models.checkpoints_list.items():
19 | path, ext = os.path.splitext(checkpoint.filename)
20 | previews = [path + ".png", path + ".preview.png"]
21 |
22 | preview = None
23 | for file in previews:
24 | if os.path.isfile(file):
25 | preview = self.link_preview(file)
26 | break
27 |
28 | yield {
29 | "name": checkpoint.name_for_extra,
30 | "filename": path,
31 | "preview": preview,
32 | "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
33 | "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
34 | "local_preview": path + ".png",
35 | }
36 |
37 | def allowed_directories_for_previews(self):
38 | return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
39 |
40 |
--------------------------------------------------------------------------------
/modules/textual_inversion/logging.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 |
5 | saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "latent_sampling_method"}
6 | saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
7 | saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
8 | saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
9 | saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
10 |
11 |
12 | def save_settings_to_file(log_directory, all_params):
13 | now = datetime.datetime.now()
14 | params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
15 |
16 | keys = saved_params_all
17 | if all_params.get('preview_from_txt2img'):
18 | keys = keys | saved_params_previews
19 |
20 | params.update({k: v for k, v in all_params.items() if k in keys})
21 |
22 | filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
23 | with open(os.path.join(log_directory, filename), "w") as file:
24 | json.dump(params, file, indent=4)
25 |
--------------------------------------------------------------------------------
/modules/sd_samplers.py:
--------------------------------------------------------------------------------
1 | from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
2 |
3 | # imports for functions that previously were here and are used by other modules
4 | from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
5 |
6 | all_samplers = [
7 | *sd_samplers_kdiffusion.samplers_data_k_diffusion,
8 | *sd_samplers_compvis.samplers_data_compvis,
9 | ]
10 | all_samplers_map = {x.name: x for x in all_samplers}
11 |
12 | samplers = []
13 | samplers_for_img2img = []
14 | samplers_map = {}
15 |
16 |
17 | def create_sampler(name, model):
18 | if name is not None:
19 | config = all_samplers_map.get(name, None)
20 | else:
21 | config = all_samplers[0]
22 |
23 | assert config is not None, f'bad sampler name: {name}'
24 |
25 | sampler = config.constructor(model)
26 | sampler.config = config
27 |
28 | return sampler
29 |
30 |
31 | def set_samplers():
32 | global samplers, samplers_for_img2img
33 |
34 | hidden = set(shared.opts.hide_samplers)
35 | hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
36 |
37 | samplers = [x for x in all_samplers if x.name not in hidden]
38 | samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
39 |
40 | samplers_map.clear()
41 | for sampler in all_samplers:
42 | samplers_map[sampler.name.lower()] = sampler.name
43 | for alias in sampler.aliases:
44 | samplers_map[alias.lower()] = sampler.name
45 |
46 |
47 | set_samplers()
48 |
--------------------------------------------------------------------------------
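create_sampler() resolves the requested name through all_samplers_map and instantiates that sampler against the loaded model, while set_samplers() filters the visible lists according to opts.hide_samplers. A hedged sketch of typical use inside a running webui (the sampler name and shared.sd_model stand in for whatever is actually loaded):

    from modules import sd_samplers, shared

    names = [x.name for x in sd_samplers.samplers]                     # filled in by set_samplers() at import time
    sampler = sd_samplers.create_sampler("Euler a", shared.sd_model)   # "Euler a" is an illustrative sampler name
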
/javascript/extensions.js:
--------------------------------------------------------------------------------
1 |
2 | function extensions_apply(_, _){
3 | var disable = []
4 | var update = []
5 |
6 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
7 | if(x.name.startsWith("enable_") && ! x.checked)
8 | disable.push(x.name.substr(7))
9 |
10 | if(x.name.startsWith("update_") && x.checked)
11 | update.push(x.name.substr(7))
12 | })
13 |
14 | restart_reload()
15 |
16 | return [JSON.stringify(disable), JSON.stringify(update)]
17 | }
18 |
19 | function extensions_check(){
20 | var disable = []
21 |
22 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
23 | if(x.name.startsWith("enable_") && ! x.checked)
24 | disable.push(x.name.substr(7))
25 | })
26 |
27 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
28 | x.innerHTML = "Loading..."
29 | })
30 |
31 |
32 | var id = randomId()
33 | requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
34 |
35 | })
36 |
37 | return [id, JSON.stringify(disable)]
38 | }
39 |
40 | function install_extension_from_index(button, url){
41 | button.disabled = "disabled"
42 | button.value = "Installing..."
43 |
44 |     var textarea = gradioApp().querySelector('#extension_to_install textarea')
45 | textarea.value = url
46 | updateInput(textarea)
47 |
48 | gradioApp().querySelector('#install_extension_button').click()
49 | }
50 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Suggest an idea for this project
3 | title: "[Feature Request]: "
4 | labels: ["enhancement"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 |       description: Please search the existing issues to see if one already requests the feature you want, and check that it hasn't already been implemented in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 |         *Please fill out this form with as much information as possible, and provide screenshots and/or illustrations of the feature where applicable*
18 | - type: textarea
19 | id: feature
20 | attributes:
21 |       label: What would your feature do?
22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: workflow
27 | attributes:
28 | label: Proposed workflow
29 |       description: Please provide step-by-step information on how you'd like the feature to be accessed and used
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: misc
38 | attributes:
39 | label: Additional information
40 | description: Add any other context or screenshots about the feature request here.
41 |
--------------------------------------------------------------------------------
/javascript/imageMaskFix.js:
--------------------------------------------------------------------------------
1 | /**
2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
3 | * @see https://github.com/gradio-app/gradio/issues/1721
4 | */
5 | window.addEventListener( 'resize', () => imageMaskResize());
6 | function imageMaskResize() {
7 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
8 | if ( ! canvases.length ) {
9 | canvases_fixed = false;
10 | window.removeEventListener( 'resize', imageMaskResize );
11 | return;
12 | }
13 |
14 | const wrapper = canvases[0].closest('.touch-none');
15 | const previewImage = wrapper.previousElementSibling;
16 |
17 | if ( ! previewImage.complete ) {
18 | previewImage.addEventListener( 'load', () => imageMaskResize());
19 | return;
20 | }
21 |
22 | const w = previewImage.width;
23 | const h = previewImage.height;
24 | const nw = previewImage.naturalWidth;
25 | const nh = previewImage.naturalHeight;
26 | const portrait = nh > nw;
27 | const factor = portrait;
28 |
29 | const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
30 | const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
31 |
32 | wrapper.style.width = `${wW}px`;
33 | wrapper.style.height = `${wH}px`;
34 | wrapper.style.left = `0px`;
35 | wrapper.style.top = `0px`;
36 |
37 | canvases.forEach( c => {
38 | c.style.width = c.style.height = '';
39 | c.style.maxWidth = '100%';
40 | c.style.maxHeight = '100%';
41 | c.style.objectFit = 'contain';
42 | });
43 | }
44 |
45 | onUiUpdate(() => imageMaskResize());
46 |
--------------------------------------------------------------------------------
/scripts/postprocessing_codeformer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | from modules import scripts_postprocessing, codeformer_model
5 | import gradio as gr
6 |
7 | from modules.ui_components import FormRow
8 |
9 |
10 | class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
11 | name = "CodeFormer"
12 | order = 3000
13 |
14 | def ui(self):
15 | with FormRow():
16 | codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
17 | codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
18 |
19 | return {
20 | "codeformer_visibility": codeformer_visibility,
21 | "codeformer_weight": codeformer_weight,
22 | }
23 |
24 | def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
25 | if codeformer_visibility == 0:
26 | return
27 |
28 | restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
29 | res = Image.fromarray(restored_img)
30 |
31 | if codeformer_visibility < 1.0:
32 | res = Image.blend(pp.image, res, codeformer_visibility)
33 |
34 | pp.image = res
35 | pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
36 | pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
37 |
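
A hedged sketch of how this script's process() hook can be exercised directly, outside the extras tab (it assumes the webui's CodeFormer model has been set up; the image path reuses one of the repository's test files):

    from PIL import Image
    from modules import scripts_postprocessing

    script = ScriptPostprocessingCodeFormer()
    pp = scripts_postprocessing.PostprocessedImage(Image.open("test/test_files/img2img_basic.png"))
    pp.info = {}
    script.process(pp, codeformer_visibility=0.5, codeformer_weight=0.2)
    print(pp.info)  # {'CodeFormer visibility': 0.5, 'CodeFormer weight': 0.2}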
--------------------------------------------------------------------------------
/extensions-builtin/Lora/scripts/lora_script.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import gradio as gr
3 |
4 | import lora
5 | import extra_networks_lora
6 | import ui_extra_networks_lora
7 | from modules import script_callbacks, ui_extra_networks, extra_networks, shared
8 |
9 |
10 | def unload():
11 | torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
12 | torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
13 |
14 |
15 | def before_ui():
16 | ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
17 | extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
18 |
19 |
20 | if not hasattr(torch.nn, 'Linear_forward_before_lora'):
21 | torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
22 |
23 | if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
24 | torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
25 |
26 | torch.nn.Linear.forward = lora.lora_Linear_forward
27 | torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
28 |
29 | script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
30 | script_callbacks.on_script_unloaded(unload)
31 | script_callbacks.on_before_ui(before_ui)
32 |
33 |
34 | shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
35 | "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
36 | "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
37 |
38 | }))
39 |
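
The patching at the bottom of this script follows a general save-patch-restore idiom for torch.nn modules. A generic, self-contained sketch of that idiom (illustrative only, not the Lora forward implementation):

    import torch

    # keep a reference to the original method once, so repeated loads don't chain patches
    if not hasattr(torch.nn.Linear, 'forward_before_patch'):
        torch.nn.Linear.forward_before_patch = torch.nn.Linear.forward

    def patched_forward(self, x):
        # delegate to the saved original; a real patch would post-process the result here
        return torch.nn.Linear.forward_before_patch(self, x)

    torch.nn.Linear.forward = patched_forward                       # install the patch
    torch.nn.Linear.forward = torch.nn.Linear.forward_before_patch  # restore, as unload() does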
--------------------------------------------------------------------------------
/modules/sd_hijack_xlmr.py:
--------------------------------------------------------------------------------
1 | import open_clip.tokenizer
2 | import torch
3 |
4 | from modules import sd_hijack_clip, devices
5 | from modules.shared import opts
6 |
7 |
8 | class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
9 | def __init__(self, wrapped, hijack):
10 | super().__init__(wrapped, hijack)
11 |
12 | self.id_start = wrapped.config.bos_token_id
13 | self.id_end = wrapped.config.eos_token_id
14 | self.id_pad = wrapped.config.pad_token_id
15 |
16 |         self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
17 |
18 | def encode_with_transformers(self, tokens):
19 | # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
20 | # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
21 | # layer to work with - you have to use the last
22 |
23 | attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
24 | features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
25 | z = features['projection_state']
26 |
27 | return z
28 |
29 | def encode_embedding_init_text(self, init_text, nvpt):
30 | embedding_layer = self.wrapped.roberta.embeddings
31 | ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
32 | embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
33 |
34 | return embedded
35 |
--------------------------------------------------------------------------------
/modules/scripts_auto_postprocessing.py:
--------------------------------------------------------------------------------
1 | from modules import scripts, scripts_postprocessing, shared
2 |
3 |
4 | class ScriptPostprocessingForMainUI(scripts.Script):
5 | def __init__(self, script_postproc):
6 | self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
7 | self.postprocessing_controls = None
8 |
9 | def title(self):
10 | return self.script.name
11 |
12 | def show(self, is_img2img):
13 | return scripts.AlwaysVisible
14 |
15 | def ui(self, is_img2img):
16 | self.postprocessing_controls = self.script.ui()
17 | return self.postprocessing_controls.values()
18 |
19 | def postprocess_image(self, p, script_pp, *args):
20 | args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
21 |
22 | pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
23 | pp.info = {}
24 | self.script.process(pp, **args_dict)
25 | p.extra_generation_params.update(pp.info)
26 | script_pp.image = pp.image
27 |
28 |
29 | def create_auto_preprocessing_script_data():
30 | from modules import scripts
31 |
32 | res = []
33 |
34 | for name in shared.opts.postprocessing_enable_in_main_ui:
35 | script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
36 | if script is None:
37 | continue
38 |
39 | constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
40 | res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
41 |
42 | return res
43 |
--------------------------------------------------------------------------------
/modules/textual_inversion/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 |
3 | import gradio as gr
4 |
5 | import modules.textual_inversion.textual_inversion
6 | import modules.textual_inversion.preprocess
7 | from modules import sd_hijack, shared
8 |
9 |
10 | def create_embedding(name, initialization_text, nvpt, overwrite_old):
11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
12 |
13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
14 |
15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
16 |
17 |
18 | def preprocess(*args):
19 | modules.textual_inversion.preprocess.preprocess(*args)
20 |
21 | return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
22 |
23 |
24 | def train_embedding(*args):
25 |
26 |     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
27 |
28 | apply_optimizations = shared.opts.training_xattention_optimizations
29 | try:
30 | if not apply_optimizations:
31 | sd_hijack.undo_optimizations()
32 |
33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
34 |
35 | res = f"""
36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
37 | Embedding saved to {html.escape(filename)}
38 | """
39 | return res, ""
40 | except Exception:
41 | raise
42 | finally:
43 | if not apply_optimizations:
44 | sd_hijack.apply_optimizations()
45 |
46 |
--------------------------------------------------------------------------------
/javascript/notification.js:
--------------------------------------------------------------------------------
1 | // Monitors the gallery and sends a browser notification when the leading image is new.
2 |
3 | let lastHeadImg = null;
4 |
5 | let notificationButton = null;
6 |
7 | onUiUpdate(function(){
8 | if(notificationButton == null){
9 | notificationButton = gradioApp().getElementById('request_notifications')
10 |
11 | if(notificationButton != null){
12 | notificationButton.addEventListener('click', function (evt) {
13 | Notification.requestPermission();
14 | },true);
15 | }
16 | }
17 |
18 | const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
19 |
20 |     if (galleryPreviews.length == 0) return;
21 |
22 | const headImg = galleryPreviews[0]?.src;
23 |
24 | if (headImg == null || headImg == lastHeadImg) return;
25 |
26 | lastHeadImg = headImg;
27 |
28 | // play notification sound if available
29 | gradioApp().querySelector('#audio_notification audio')?.play();
30 |
31 | if (document.hasFocus()) return;
32 |
33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
35 |
36 | const notification = new Notification(
37 | 'Stable Diffusion',
38 | {
39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
40 | icon: headImg,
41 | image: headImg,
42 | }
43 | );
44 |
45 | notification.onclick = function(_){
46 | parent.focus();
47 | this.close();
48 | };
49 | });
50 |
--------------------------------------------------------------------------------
/modules/hypernetworks/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 | import os
3 | import re
4 |
5 | import gradio as gr
6 | import modules.hypernetworks.hypernetwork
7 | from modules import devices, sd_hijack, shared
8 |
9 | not_available = ["hardswish", "multiheadattention"]
10 | keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
11 |
12 |
13 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
14 | filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
15 |
16 | return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
17 |
18 |
19 | def train_hypernetwork(*args):
20 | shared.loaded_hypernetworks = []
21 |
22 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
23 |
24 | try:
25 | sd_hijack.undo_optimizations()
26 |
27 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
28 |
29 | res = f"""
30 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
31 | Hypernetwork saved to {html.escape(filename)}
32 | """
33 | return res, ""
34 | except Exception:
35 | raise
36 | finally:
37 | shared.sd_model.cond_stage_model.to(devices.device)
38 | shared.sd_model.first_stage_model.to(devices.device)
39 | sd_hijack.apply_optimizations()
40 |
41 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
2 |
3 | If you have a large change, pay special attention to this paragraph:
4 |
5 | > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
6 |
7 | Otherwise, after making sure you're following the rules described in the wiki page, remove this section and continue on.
8 |
9 | **Describe what this pull request is trying to achieve.**
10 |
11 | A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
12 |
13 | **Additional notes and description of your changes**
14 |
15 | More technical discussion about your changes goes here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
16 |
17 | **Environment this was tested in**
18 |
19 | List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
20 | - OS: [e.g. Windows, Linux]
21 | - Browser: [e.g. chrome, safari]
22 | - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
23 |
24 | **Screenshots or videos of your changes**
25 |
26 | If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
27 |
28 | This is **required** for anything that touches the user interface.
--------------------------------------------------------------------------------
/modules/ui_components.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 |
4 | class ToolButton(gr.Button, gr.components.FormComponent):
5 | """Small button with single emoji as text, fits inside gradio forms"""
6 |
7 | def __init__(self, **kwargs):
8 | super().__init__(variant="tool", **kwargs)
9 |
10 | def get_block_name(self):
11 | return "button"
12 |
13 |
14 | class ToolButtonTop(gr.Button, gr.components.FormComponent):
15 | """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
16 |
17 | def __init__(self, **kwargs):
18 | super().__init__(variant="tool-top", **kwargs)
19 |
20 | def get_block_name(self):
21 | return "button"
22 |
23 |
24 | class FormRow(gr.Row, gr.components.FormComponent):
25 | """Same as gr.Row but fits inside gradio forms"""
26 |
27 | def get_block_name(self):
28 | return "row"
29 |
30 |
31 | class FormGroup(gr.Group, gr.components.FormComponent):
32 |     """Same as gr.Group but fits inside gradio forms"""
33 |
34 | def get_block_name(self):
35 | return "group"
36 |
37 |
38 | class FormHTML(gr.HTML, gr.components.FormComponent):
39 | """Same as gr.HTML but fits inside gradio forms"""
40 |
41 | def get_block_name(self):
42 | return "html"
43 |
44 |
45 | class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
46 | """Same as gr.ColorPicker but fits inside gradio forms"""
47 |
48 | def get_block_name(self):
49 | return "colorpicker"
50 |
51 |
52 | class DropdownMulti(gr.Dropdown):
53 | """Same as gr.Dropdown but always multiselect"""
54 | def __init__(self, **kwargs):
55 | super().__init__(multiselect=True, **kwargs)
56 |
57 | def get_block_name(self):
58 | return "dropdown"
59 |
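
All of these helpers follow the same two-step recipe: inherit from the gradio component plus gr.components.FormComponent, and report the base component's block name so gradio renders it inside forms. A hypothetical additional component built the same way (FormCheckbox is not part of this module):

    import gradio as gr

    class FormCheckbox(gr.Checkbox, gr.components.FormComponent):
        """Same as gr.Checkbox but fits inside gradio forms"""

        def get_block_name(self):
            return "checkbox"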
--------------------------------------------------------------------------------
/modules/sd_vae_approx.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from modules import devices, paths
6 |
7 | sd_vae_approx_model = None
8 |
9 |
10 | class VAEApprox(nn.Module):
11 | def __init__(self):
12 | super(VAEApprox, self).__init__()
13 | self.conv1 = nn.Conv2d(4, 8, (7, 7))
14 | self.conv2 = nn.Conv2d(8, 16, (5, 5))
15 | self.conv3 = nn.Conv2d(16, 32, (3, 3))
16 | self.conv4 = nn.Conv2d(32, 64, (3, 3))
17 | self.conv5 = nn.Conv2d(64, 32, (3, 3))
18 | self.conv6 = nn.Conv2d(32, 16, (3, 3))
19 | self.conv7 = nn.Conv2d(16, 8, (3, 3))
20 | self.conv8 = nn.Conv2d(8, 3, (3, 3))
21 |
22 | def forward(self, x):
23 | extra = 11
24 | x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
25 | x = nn.functional.pad(x, (extra, extra, extra, extra))
26 |
27 | for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
28 | x = layer(x)
29 | x = nn.functional.leaky_relu(x, 0.1)
30 |
31 | return x
32 |
33 |
34 | def model():
35 | global sd_vae_approx_model
36 |
37 | if sd_vae_approx_model is None:
38 | sd_vae_approx_model = VAEApprox()
39 | sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
40 | sd_vae_approx_model.eval()
41 | sd_vae_approx_model.to(devices.device, devices.dtype)
42 |
43 | return sd_vae_approx_model
44 |
45 |
46 | def cheap_approximation(sample):
47 | # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
48 |
49 | coefs = torch.tensor([
50 | [0.298, 0.207, 0.208],
51 | [0.187, 0.286, 0.173],
52 | [-0.158, 0.189, 0.264],
53 | [-0.184, -0.271, -0.473],
54 | ]).to(sample.device)
55 |
56 | x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
57 |
58 | return x_sample
59 |
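
A small sketch of what cheap_approximation() expects: a single 4-channel latent without a batch dimension, which the einsum maps through the 4x3 coefficient matrix to a 3-channel preview of the same spatial size (values below are random and purely illustrative):

    import torch
    from modules import sd_vae_approx

    latent = torch.randn(4, 64, 64)                      # one SD latent (4 x H/8 x W/8)
    preview = sd_vae_approx.cheap_approximation(latent)  # -> torch.Size([3, 64, 64])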
--------------------------------------------------------------------------------
/configs/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/test/basic_features/extras_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 | from gradio.processing_utils import encode_pil_to_base64
4 | from PIL import Image
5 |
6 | class TestExtrasWorking(unittest.TestCase):
7 | def setUp(self):
8 | self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image"
9 | self.extras_single = {
10 | "resize_mode": 0,
11 | "show_extras_results": True,
12 | "gfpgan_visibility": 0,
13 | "codeformer_visibility": 0,
14 | "codeformer_weight": 0,
15 | "upscaling_resize": 2,
16 | "upscaling_resize_w": 128,
17 | "upscaling_resize_h": 128,
18 | "upscaling_crop": True,
19 | "upscaler_1": "None",
20 | "upscaler_2": "None",
21 | "extras_upscaler_2_visibility": 0,
22 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
23 | }
24 |
25 | def test_simple_upscaling_performed(self):
26 | self.extras_single["upscaler_1"] = "Lanczos"
27 | self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200)
28 |
29 |
30 | class TestPngInfoWorking(unittest.TestCase):
31 | def setUp(self):
32 |         self.url_png_info = "http://localhost:7860/sdapi/v1/png-info"
33 | self.png_info = {
34 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
35 | }
36 |
37 | def test_png_info_performed(self):
38 | self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200)
39 |
40 |
41 | class TestInterrogateWorking(unittest.TestCase):
42 | def setUp(self):
43 |         self.url_interrogate = "http://localhost:7860/sdapi/v1/interrogate"
44 | self.interrogate = {
45 | "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
46 | "model": "clip"
47 | }
48 |
49 | def test_interrogate_performed(self):
50 | self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200)
51 |
52 |
53 | if __name__ == "__main__":
54 | unittest.main()
55 |
--------------------------------------------------------------------------------
/configs/alt-diffusion-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: modules.xlmr.BertSeriesModelWithTransformation
71 | params:
72 | name: "XLMR-Large"
--------------------------------------------------------------------------------
/configs/v1-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | finetune_keys: null
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/modules/hashes.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os.path
4 |
5 | import filelock
6 |
7 | from modules import shared
8 | from modules.paths import data_path
9 |
10 |
11 | cache_filename = os.path.join(data_path, "cache.json")
12 | cache_data = None
13 |
14 |
15 | def dump_cache():
16 | with filelock.FileLock(cache_filename+".lock"):
17 | with open(cache_filename, "w", encoding="utf8") as file:
18 | json.dump(cache_data, file, indent=4)
19 |
20 |
21 | def cache(subsection):
22 | global cache_data
23 |
24 | if cache_data is None:
25 | with filelock.FileLock(cache_filename+".lock"):
26 | if not os.path.isfile(cache_filename):
27 | cache_data = {}
28 | else:
29 | with open(cache_filename, "r", encoding="utf8") as file:
30 | cache_data = json.load(file)
31 |
32 | s = cache_data.get(subsection, {})
33 | cache_data[subsection] = s
34 |
35 | return s
36 |
37 |
38 | def calculate_sha256(filename):
39 | hash_sha256 = hashlib.sha256()
40 | blksize = 1024 * 1024
41 |
42 | with open(filename, "rb") as f:
43 | for chunk in iter(lambda: f.read(blksize), b""):
44 | hash_sha256.update(chunk)
45 |
46 | return hash_sha256.hexdigest()
47 |
48 |
49 | def sha256_from_cache(filename, title):
50 | hashes = cache("hashes")
51 | ondisk_mtime = os.path.getmtime(filename)
52 |
53 | if title not in hashes:
54 | return None
55 |
56 | cached_sha256 = hashes[title].get("sha256", None)
57 | cached_mtime = hashes[title].get("mtime", 0)
58 |
59 | if ondisk_mtime > cached_mtime or cached_sha256 is None:
60 | return None
61 |
62 | return cached_sha256
63 |
64 |
65 | def sha256(filename, title):
66 | hashes = cache("hashes")
67 |
68 | sha256_value = sha256_from_cache(filename, title)
69 | if sha256_value is not None:
70 | return sha256_value
71 |
72 | if shared.cmd_opts.no_hashing:
73 | return None
74 |
75 | print(f"Calculating sha256 for {filename}: ", end='')
76 | sha256_value = calculate_sha256(filename)
77 | print(f"{sha256_value}")
78 |
79 | hashes[title] = {
80 | "mtime": os.path.getmtime(filename),
81 | "sha256": sha256_value,
82 | }
83 |
84 | dump_cache()
85 |
86 | return sha256_value
87 |
88 |
89 |
90 |
91 |
92 |
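
A usage sketch for the cache-backed hashing above; the checkpoint path and the "checkpoint/model" title are hypothetical examples of how callers key the cache:

    from modules import hashes

    digest = hashes.sha256("models/Stable-diffusion/model.ckpt", "checkpoint/model")
    cached = hashes.sha256_from_cache("models/Stable-diffusion/model.ckpt", "checkpoint/model")
    # the second call is answered from cache.json as long as the file's mtime hasn't changed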
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if not defined PYTHON (set PYTHON=python)
4 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
5 |
6 |
7 | set ERROR_REPORTING=FALSE
8 |
9 | mkdir tmp 2>NUL
10 |
11 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
12 | if %ERRORLEVEL% == 0 goto :check_pip
13 | echo Couldn't launch python
14 | goto :show_stdout_stderr
15 |
16 | :check_pip
17 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
18 | if %ERRORLEVEL% == 0 goto :start_venv
19 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
20 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
21 | if %ERRORLEVEL% == 0 goto :start_venv
22 | echo Couldn't install pip
23 | goto :show_stdout_stderr
24 |
25 | :start_venv
26 | if ["%VENV_DIR%"] == ["-"] goto :skip_venv
27 | if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
28 |
29 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
30 | if %ERRORLEVEL% == 0 goto :activate_venv
31 |
32 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
33 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
34 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
35 | if %ERRORLEVEL% == 0 goto :activate_venv
36 | echo Unable to create venv in directory "%VENV_DIR%"
37 | goto :show_stdout_stderr
38 |
39 | :activate_venv
40 | set PYTHON="%VENV_DIR%\Scripts\Python.exe"
41 | echo venv %PYTHON%
42 |
43 | :skip_venv
44 | if [%ACCELERATE%] == ["True"] goto :accelerate
45 | goto :launch
46 |
47 | :accelerate
48 | echo Checking for accelerate
49 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
50 | if EXIST %ACCELERATE% goto :accelerate_launch
51 |
52 | :launch
53 | %PYTHON% launch.py %*
54 | pause
55 | exit /b
56 |
57 | :accelerate_launch
58 | echo Accelerating
59 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
60 | pause
61 | exit /b
62 |
63 | :show_stdout_stderr
64 |
65 | echo.
66 | echo exit code: %errorlevel%
67 |
68 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
69 | if %size% equ 0 goto :show_stderr
70 | echo.
71 | echo stdout:
72 | type tmp\stdout.txt
73 |
74 | :show_stderr
75 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
76 | if %size% equ 0 goto :endofscript
77 | echo.
78 | echo stderr:
79 | type tmp\stderr.txt
80 |
81 | :endofscript
82 |
83 | echo.
84 | echo Launch unsuccessful. Exiting.
85 | pause
86 |
--------------------------------------------------------------------------------
/modules/sd_samplers_common.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 | from modules import devices, processing, images, sd_vae_approx
6 |
7 | from modules.shared import opts, state
8 | import modules.shared as shared
9 |
10 | SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
11 |
12 |
13 | def setup_img2img_steps(p, steps=None):
14 | if opts.img2img_fix_steps or steps is not None:
15 | requested_steps = (steps or p.steps)
16 | steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
17 | t_enc = requested_steps - 1
18 | else:
19 | steps = p.steps
20 | t_enc = int(min(p.denoising_strength, 0.999) * steps)
21 |
22 | return steps, t_enc
23 |
24 |
25 | approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
26 |
27 |
28 | def single_sample_to_image(sample, approximation=None):
29 | if approximation is None:
30 | approximation = approximation_indexes.get(opts.show_progress_type, 0)
31 |
32 | if approximation == 2:
33 | x_sample = sd_vae_approx.cheap_approximation(sample)
34 | elif approximation == 1:
35 | x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
36 | else:
37 | x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
38 |
39 | x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
40 | x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
41 | x_sample = x_sample.astype(np.uint8)
42 | return Image.fromarray(x_sample)
43 |
44 |
45 | def sample_to_image(samples, index=0, approximation=None):
46 | return single_sample_to_image(samples[index], approximation)
47 |
48 |
49 | def samples_to_image_grid(samples, approximation=None):
50 | return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
51 |
52 |
53 | def store_latent(decoded):
54 | state.current_latent = decoded
55 |
56 | if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
57 | if not shared.parallel_processing_allowed:
58 | shared.state.assign_current_image(sample_to_image(decoded))
59 |
60 |
61 | class InterruptedException(BaseException):
62 | pass
63 |
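
A worked example of the arithmetic in setup_img2img_steps(), using illustrative numbers (20 requested steps, denoising strength 0.5):

    requested_steps, denoising = 20, 0.5
    int(denoising * requested_steps)    # 10 -> t_enc in the default branch (steps stay at 20)
    int(requested_steps / denoising)    # 40 -> steps when opts.img2img_fix_steps is enabled
    requested_steps - 1                 # 19 -> t_enc in the img2img_fix_steps branch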
--------------------------------------------------------------------------------
/modules/paths.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import modules.safe
5 |
6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7 |
8 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values
9 | parser = argparse.ArgumentParser(add_help=False)
10 | parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
11 | cmd_opts_pre = parser.parse_known_args()[0]
12 | data_path = cmd_opts_pre.data_dir
13 | models_path = os.path.join(data_path, "models")
14 |
15 | # data_path = cmd_opts_pre.data
16 | sys.path.insert(0, script_path)
17 |
18 | # search for directory of stable diffusion in following places
19 | sd_path = None
20 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
21 | for possible_sd_path in possible_sd_paths:
22 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
23 | sd_path = os.path.abspath(possible_sd_path)
24 | break
25 |
26 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
27 |
28 | path_dirs = [
29 | (sd_path, 'ldm', 'Stable Diffusion', []),
30 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
31 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
32 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
33 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
34 | ]
35 |
36 | paths = {}
37 |
38 | for d, must_exist, what, options in path_dirs:
39 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
40 | if not os.path.exists(must_exist_path):
41 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
42 | else:
43 | d = os.path.abspath(d)
44 | if "atstart" in options:
45 | sys.path.insert(0, d)
46 | else:
47 | sys.path.append(d)
48 | paths[what] = d
49 |
50 |
51 | class Prioritize:
52 | def __init__(self, name):
53 | self.name = name
54 | self.path = None
55 |
56 | def __enter__(self):
57 | self.path = sys.path.copy()
58 | sys.path = [paths[self.name]] + sys.path
59 |
60 | def __exit__(self, exc_type, exc_val, exc_tb):
61 | sys.path = self.path
62 | self.path = None
63 |
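
A sketch of how the Prioritize context manager is meant to be used: temporarily move one of the discovered repositories to the front of sys.path for an import, then restore the previous path (the name must be a key of the paths dict above):

    from modules.paths import Prioritize

    with Prioritize("k_diffusion"):
        import k_diffusion  # resolved from the prioritized repository first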
--------------------------------------------------------------------------------
/test/basic_features/utils_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 |
4 | class UtilsTests(unittest.TestCase):
5 | def setUp(self):
6 | self.url_options = "http://localhost:7860/sdapi/v1/options"
7 | self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
8 | self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
9 | self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
10 | self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
11 | self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
12 | self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
13 | self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
14 | self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
15 | self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
16 |
17 | def test_options_get(self):
18 | self.assertEqual(requests.get(self.url_options).status_code, 200)
19 |
20 | def test_options_write(self):
21 | response = requests.get(self.url_options)
22 | self.assertEqual(response.status_code, 200)
23 |
24 | pre_value = response.json()["send_seed"]
25 |
26 | self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
27 |
28 | response = requests.get(self.url_options)
29 | self.assertEqual(response.status_code, 200)
30 | self.assertEqual(response.json()["send_seed"], not pre_value)
31 |
32 | requests.post(self.url_options, json={"send_seed": pre_value})
33 |
34 | def test_cmd_flags(self):
35 | self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
36 |
37 | def test_samplers(self):
38 | self.assertEqual(requests.get(self.url_samplers).status_code, 200)
39 |
40 | def test_upscalers(self):
41 | self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
42 |
43 | def test_sd_models(self):
44 | self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
45 |
46 | def test_hypernetworks(self):
47 | self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
48 |
49 | def test_face_restorers(self):
50 | self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
51 |
52 | def test_realesrgan_models(self):
53 | self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
54 |
55 | def test_prompt_styles(self):
56 | self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
57 |
58 | def test_embeddings(self):
59 | self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
60 |
61 | if __name__ == "__main__":
62 | unittest.main()
63 |
--------------------------------------------------------------------------------
/test/basic_features/img2img_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 | from gradio.processing_utils import encode_pil_to_base64
4 | from PIL import Image
5 |
6 |
7 | class TestImg2ImgWorking(unittest.TestCase):
8 | def setUp(self):
9 | self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
10 | self.simple_img2img = {
11 | "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
12 | "resize_mode": 0,
13 | "denoising_strength": 0.75,
14 | "mask": None,
15 | "mask_blur": 4,
16 | "inpainting_fill": 0,
17 | "inpaint_full_res": False,
18 | "inpaint_full_res_padding": 0,
19 | "inpainting_mask_invert": False,
20 | "prompt": "example prompt",
21 | "styles": [],
22 | "seed": -1,
23 | "subseed": -1,
24 | "subseed_strength": 0,
25 | "seed_resize_from_h": -1,
26 | "seed_resize_from_w": -1,
27 | "batch_size": 1,
28 | "n_iter": 1,
29 | "steps": 3,
30 | "cfg_scale": 7,
31 | "width": 64,
32 | "height": 64,
33 | "restore_faces": False,
34 | "tiling": False,
35 | "negative_prompt": "",
36 | "eta": 0,
37 | "s_churn": 0,
38 | "s_tmax": 0,
39 | "s_tmin": 0,
40 | "s_noise": 1,
41 | "override_settings": {},
42 | "sampler_index": "Euler a",
43 | "include_init_images": False
44 | }
45 |
46 | def test_img2img_simple_performed(self):
47 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
48 |
49 | def test_inpainting_masked_performed(self):
50 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
51 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
52 |
53 | def test_inpainting_with_inverted_masked_performed(self):
54 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
55 | self.simple_img2img["inpainting_mask_invert"] = True
56 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
57 |
58 | def test_img2img_sd_upscale_performed(self):
59 | self.simple_img2img["script_name"] = "sd upscale"
60 | self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
61 |
62 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
63 |
64 |
65 | if __name__ == "__main__":
66 | unittest.main()
67 |
--------------------------------------------------------------------------------
/modules/ui_tempdir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from collections import namedtuple
4 | from pathlib import Path
5 |
6 | import gradio as gr
7 |
8 | from PIL import PngImagePlugin
9 |
10 | from modules import shared
11 |
12 |
13 | Savedfile = namedtuple("Savedfile", ["name"])
14 |
15 |
16 | def register_tmp_file(gradio, filename):
17 | if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
18 | gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
19 |
20 | if hasattr(gradio, 'temp_dirs'): # gradio 3.9
21 | gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
22 |
23 |
24 | def check_tmp_file(gradio, filename):
25 | if hasattr(gradio, 'temp_file_sets'):
26 | return any([filename in fileset for fileset in gradio.temp_file_sets])
27 |
28 | if hasattr(gradio, 'temp_dirs'):
29 | return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
30 |
31 | return False
32 |
33 |
34 | def save_pil_to_file(pil_image, dir=None):
35 | already_saved_as = getattr(pil_image, 'already_saved_as', None)
36 | if already_saved_as and os.path.isfile(already_saved_as):
37 | register_tmp_file(shared.demo, already_saved_as)
38 |
39 | file_obj = Savedfile(already_saved_as)
40 | return file_obj
41 |
42 | if shared.opts.temp_dir != "":
43 | dir = shared.opts.temp_dir
44 |
45 | use_metadata = False
46 | metadata = PngImagePlugin.PngInfo()
47 | for key, value in pil_image.info.items():
48 | if isinstance(key, str) and isinstance(value, str):
49 | metadata.add_text(key, value)
50 | use_metadata = True
51 |
52 | file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
53 | pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
54 | return file_obj
55 |
56 |
57 | # override save to file function so that it also writes PNG info
58 | gr.processing_utils.save_pil_to_file = save_pil_to_file
59 |
60 |
61 | def on_tmpdir_changed():
62 | if shared.opts.temp_dir == "" or shared.demo is None:
63 | return
64 |
65 | os.makedirs(shared.opts.temp_dir, exist_ok=True)
66 |
67 | register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
68 |
69 |
70 | def cleanup_tmpdr():
71 | temp_dir = shared.opts.temp_dir
72 | if temp_dir == "" or not os.path.isdir(temp_dir):
73 | return
74 |
75 | for root, dirs, files in os.walk(temp_dir, topdown=False):
76 | for name in files:
77 | _, extension = os.path.splitext(name)
78 | if extension != ".png":
79 | continue
80 |
81 | filename = os.path.join(root, name)
82 | os.remove(filename)
83 |
--------------------------------------------------------------------------------
/modules/memmon.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 | from collections import defaultdict
4 |
5 | import torch
6 |
7 |
8 | class MemUsageMonitor(threading.Thread):
9 | run_flag = None
10 | device = None
11 | disabled = False
12 | opts = None
13 | data = None
14 |
15 | def __init__(self, name, device, opts):
16 | threading.Thread.__init__(self)
17 | self.name = name
18 | self.device = device
19 | self.opts = opts
20 |
21 | self.daemon = True
22 | self.run_flag = threading.Event()
23 | self.data = defaultdict(int)
24 |
25 | try:
26 | torch.cuda.mem_get_info()
27 | torch.cuda.memory_stats(self.device)
28 | except Exception as e: # AMD or whatever
29 | print(f"Warning: caught exception '{e}', memory monitor disabled")
30 | self.disabled = True
31 |
32 | def run(self):
33 | if self.disabled:
34 | return
35 |
36 | while True:
37 | self.run_flag.wait()
38 |
39 | torch.cuda.reset_peak_memory_stats()
40 | self.data.clear()
41 |
42 | if self.opts.memmon_poll_rate <= 0:
43 | self.run_flag.clear()
44 | continue
45 |
46 | self.data["min_free"] = torch.cuda.mem_get_info()[0]
47 |
48 | while self.run_flag.is_set():
49 | free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
50 | self.data["min_free"] = min(self.data["min_free"], free)
51 |
52 | time.sleep(1 / self.opts.memmon_poll_rate)
53 |
54 | def dump_debug(self):
55 | print(self, 'recorded data:')
56 | for k, v in self.read().items():
57 | print(k, -(v // -(1024 ** 2)))
58 |
59 | print(self, 'raw torch memory stats:')
60 | tm = torch.cuda.memory_stats(self.device)
61 | for k, v in tm.items():
62 | if 'bytes' not in k:
63 | continue
64 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65 |
66 | print(torch.cuda.memory_summary())
67 |
68 | def monitor(self):
69 | self.run_flag.set()
70 |
71 | def read(self):
72 | if not self.disabled:
73 | free, total = torch.cuda.mem_get_info()
74 | self.data["free"] = free
75 | self.data["total"] = total
76 |
77 | torch_stats = torch.cuda.memory_stats(self.device)
78 | self.data["active"] = torch_stats["active.all.current"]
79 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
80 | self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
81 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
82 | self.data["system_peak"] = total - self.data["min_free"]
83 |
84 | return self.data
85 |
86 | def stop(self):
87 | self.run_flag.clear()
88 | return self.read()
89 |
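
A minimal usage sketch for MemUsageMonitor, assuming a CUDA device and an options object with a positive memmon_poll_rate (the wiring below mirrors how the webui uses it, but the names are illustrative):

    import torch
    from modules import memmon, shared

    vram_monitor = memmon.MemUsageMonitor("MemMon", torch.device("cuda"), shared.opts)
    vram_monitor.start()          # the daemon thread idles until monitoring is requested
    vram_monitor.monitor()        # begin polling free VRAM
    # ... run a generation here ...
    stats = vram_monitor.stop()   # stops polling and returns the collected figures in bytes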
--------------------------------------------------------------------------------
/modules/ui_postprocessing.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
3 | import modules.generation_parameters_copypaste as parameters_copypaste
4 |
5 |
6 | def create_ui():
7 | tab_index = gr.State(value=0)
8 |
9 | with gr.Row().style(equal_height=False, variant='compact'):
10 | with gr.Column(variant='compact'):
11 | with gr.Tabs(elem_id="mode_extras"):
12 | with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
13 | extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
14 |
15 | with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
16 | image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
17 |
18 | with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
19 | extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
20 | extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
21 | show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
22 |
23 | submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
24 |
25 | script_inputs = scripts.scripts_postproc.setup_ui()
26 |
27 | with gr.Column():
28 | result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
29 |
30 | tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
31 | tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
32 | tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
33 |
34 | submit.click(
35 | fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
36 | inputs=[
37 | tab_index,
38 | extras_image,
39 | image_batch,
40 | extras_batch_input_dir,
41 | extras_batch_output_dir,
42 | show_extras_results,
43 | *script_inputs
44 | ],
45 | outputs=[
46 | result_images,
47 | html_info_x,
48 | html_info,
49 | ]
50 | )
51 |
52 | parameters_copypaste.add_paste_fields("extras", extras_image, None)
53 |
54 | extras_image.change(
55 | fn=scripts.scripts_postproc.image_changed,
56 | inputs=[], outputs=[]
57 | )
58 |
--------------------------------------------------------------------------------
/configs/instruct-pix2pix.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | model:
5 | base_learning_rate: 1.0e-04
6 | target: modules.models.diffusion.ddpm_edit.LatentDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: edited
14 | cond_stage_key: edit
15 | # image_size: 64
16 | # image_size: 32
17 | image_size: 16
18 | channels: 4
19 | cond_stage_trainable: false # Note: different from the one we trained before
20 | conditioning_key: hybrid
21 | monitor: val/loss_simple_ema
22 | scale_factor: 0.18215
23 | use_ema: false
24 |
25 | scheduler_config: # 10000 warmup steps
26 | target: ldm.lr_scheduler.LambdaLinearScheduler
27 | params:
28 | warm_up_steps: [ 0 ]
29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30 | f_start: [ 1.e-6 ]
31 | f_max: [ 1. ]
32 | f_min: [ 1. ]
33 |
34 | unet_config:
35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
36 | params:
37 | image_size: 32 # unused
38 | in_channels: 8
39 | out_channels: 4
40 | model_channels: 320
41 | attention_resolutions: [ 4, 2, 1 ]
42 | num_res_blocks: 2
43 | channel_mult: [ 1, 2, 4, 4 ]
44 | num_heads: 8
45 | use_spatial_transformer: True
46 | transformer_depth: 1
47 | context_dim: 768
48 | use_checkpoint: True
49 | legacy: False
50 |
51 | first_stage_config:
52 | target: ldm.models.autoencoder.AutoencoderKL
53 | params:
54 | embed_dim: 4
55 | monitor: val/rec_loss
56 | ddconfig:
57 | double_z: true
58 | z_channels: 4
59 | resolution: 256
60 | in_channels: 3
61 | out_ch: 3
62 | ch: 128
63 | ch_mult:
64 | - 1
65 | - 2
66 | - 4
67 | - 4
68 | num_res_blocks: 2
69 | attn_resolutions: []
70 | dropout: 0.0
71 | lossconfig:
72 | target: torch.nn.Identity
73 |
74 | cond_stage_config:
75 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
76 |
77 | data:
78 | target: main.DataModuleFromConfig
79 | params:
80 | batch_size: 128
81 | num_workers: 1
82 | wrap: false
83 | validation:
84 | target: edit_dataset.EditDataset
85 | params:
86 | path: data/clip-filtered-dataset
87 | cache_dir: data/
88 | cache_name: data_10k
89 | split: val
90 | min_text_sim: 0.2
91 | min_image_sim: 0.75
92 | min_direction_sim: 0.2
93 | max_samples_per_prompt: 1
94 | min_resize_res: 512
95 | max_resize_res: 512
96 | crop_res: 512
97 | output_as_edit: False
98 | real_input: True
99 |
--------------------------------------------------------------------------------
/extensions-builtin/LDSR/scripts/ldsr_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | from basicsr.utils.download_util import load_file_from_url
6 |
7 | from modules.upscaler import Upscaler, UpscalerData
8 | from ldsr_model_arch import LDSR
9 | from modules import shared, script_callbacks
10 | import sd_hijack_autoencoder, sd_hijack_ddpm_v1
11 |
12 |
13 | class UpscalerLDSR(Upscaler):
14 | def __init__(self, user_path):
15 | self.name = "LDSR"
16 | self.user_path = user_path
17 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
18 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
19 | super().__init__()
20 | scaler_data = UpscalerData("LDSR", None, self)
21 | self.scalers = [scaler_data]
22 |
23 | def load_model(self, path: str):
24 | # Remove incorrect project.yaml file if too big
25 | yaml_path = os.path.join(self.model_path, "project.yaml")
26 | old_model_path = os.path.join(self.model_path, "model.pth")
27 | new_model_path = os.path.join(self.model_path, "model.ckpt")
28 | safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
29 | if os.path.exists(yaml_path):
30 | statinfo = os.stat(yaml_path)
31 | if statinfo.st_size >= 10485760:
32 | print("Removing invalid LDSR YAML file.")
33 | os.remove(yaml_path)
34 | if os.path.exists(old_model_path):
35 | print("Renaming model from model.pth to model.ckpt")
36 | os.rename(old_model_path, new_model_path)
37 | if os.path.exists(safetensors_model_path):
38 | model = safetensors_model_path
39 | else:
40 | model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
41 | file_name="model.ckpt", progress=True)
42 | yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
43 | file_name="project.yaml", progress=True)
44 |
45 | try:
46 | return LDSR(model, yaml)
47 |
48 | except Exception:
49 | print("Error importing LDSR:", file=sys.stderr)
50 | print(traceback.format_exc(), file=sys.stderr)
51 | return None
52 |
53 | def do_upscale(self, img, path):
54 | ldsr = self.load_model(path)
55 | if ldsr is None:
56 | print("NO LDSR!")
57 | return img
58 | ddim_steps = shared.opts.ldsr_steps
59 | return ldsr.super_resolution(img, ddim_steps, self.scale)
60 |
61 |
62 | def on_ui_settings():
63 | import gradio as gr
64 |
65 | shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
66 | shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
67 |
68 |
69 | script_callbacks.on_ui_settings(on_ui_settings)
70 |
--------------------------------------------------------------------------------
/modules/textual_inversion/learn_schedule.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 |
4 | class LearnScheduleIterator:
5 | def __init__(self, learn_rate, max_steps, cur_step=0):
6 | """
7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
8 | """
9 |
10 | pairs = learn_rate.split(',')
11 | self.rates = []
12 | self.it = 0
13 | self.maxit = 0
14 | try:
15 | for i, pair in enumerate(pairs):
16 | if not pair.strip():
17 | continue
18 | tmp = pair.split(':')
19 | if len(tmp) == 2:
20 | step = int(tmp[1])
21 | if step > cur_step:
22 | self.rates.append((float(tmp[0]), min(step, max_steps)))
23 | self.maxit += 1
24 | if step > max_steps:
25 | return
26 | elif step == -1:
27 | self.rates.append((float(tmp[0]), max_steps))
28 | self.maxit += 1
29 | return
30 | else:
31 | self.rates.append((float(tmp[0]), max_steps))
32 | self.maxit += 1
33 | return
34 | assert self.rates
35 | except (ValueError, AssertionError):
36 | raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
37 |
38 |
39 | def __iter__(self):
40 | return self
41 |
42 | def __next__(self):
43 | if self.it < self.maxit:
44 | self.it += 1
45 | return self.rates[self.it - 1]
46 | else:
47 | raise StopIteration
48 |
49 |
50 | class LearnRateScheduler:
51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
53 | (self.learn_rate, self.end_step) = next(self.schedules)
54 | self.verbose = verbose
55 |
56 | if self.verbose:
57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
58 |
59 | self.finished = False
60 |
61 | def step(self, step_number):
62 | if step_number < self.end_step:
63 | return False
64 |
65 | try:
66 | (self.learn_rate, self.end_step) = next(self.schedules)
67 | except StopIteration:
68 | self.finished = True
69 | return False
70 | return True
71 |
72 | def apply(self, optimizer, step_number):
73 | if not self.step(step_number):
74 | return
75 |
76 | if self.verbose:
77 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
78 |
79 | for pg in optimizer.param_groups:
80 | pg['lr'] = self.learn_rate
81 |
82 |
--------------------------------------------------------------------------------
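
Editor's note (illustrative, not part of the repository): the schedule string documented in LearnScheduleIterator's docstring above expands into (learn_rate, end_step) pairs that LearnRateScheduler then walks with step(). A minimal sketch, assuming the webui root is on sys.path so the module imports cleanly:

    from modules.textual_inversion.learn_schedule import LearnRateScheduler

    # "rate:step" pairs; a trailing rate without a step runs until max_steps.
    scheduler = LearnRateScheduler("0.005:100, 1e-3:500, 1e-5", max_steps=1000, verbose=False)
    print(scheduler.learn_rate, scheduler.end_step)      # 0.005 until step 100

    for step_number in (50, 100, 500, 1000):
        changed = scheduler.step(step_number)            # True whenever the rate advances
        print(step_number, changed, scheduler.learn_rate, scheduler.end_step)
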
/modules/txt2img.py:
--------------------------------------------------------------------------------
1 | import modules.scripts
2 | from modules import sd_samplers
3 | from modules.generation_parameters_copypaste import create_override_settings_dict
4 | from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
5 | StableDiffusionProcessingImg2Img, process_images
6 | from modules.shared import opts, cmd_opts
7 | import modules.shared as shared
8 | import modules.processing as processing
9 | from modules.ui import plaintext_to_html
10 |
11 |
12 | def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
13 | override_settings = create_override_settings_dict(override_settings_texts)
14 |
15 | p = StableDiffusionProcessingTxt2Img(
16 | sd_model=shared.sd_model,
17 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
18 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
19 | prompt=prompt,
20 | styles=prompt_styles,
21 | negative_prompt=negative_prompt,
22 | seed=seed,
23 | subseed=subseed,
24 | subseed_strength=subseed_strength,
25 | seed_resize_from_h=seed_resize_from_h,
26 | seed_resize_from_w=seed_resize_from_w,
27 | seed_enable_extras=seed_enable_extras,
28 | sampler_name=sd_samplers.samplers[sampler_index].name,
29 | batch_size=batch_size,
30 | n_iter=n_iter,
31 | steps=steps,
32 | cfg_scale=cfg_scale,
33 | width=width,
34 | height=height,
35 | restore_faces=restore_faces,
36 | tiling=tiling,
37 | enable_hr=enable_hr,
38 | denoising_strength=denoising_strength if enable_hr else None,
39 | hr_scale=hr_scale,
40 | hr_upscaler=hr_upscaler,
41 | hr_second_pass_steps=hr_second_pass_steps,
42 | hr_resize_x=hr_resize_x,
43 | hr_resize_y=hr_resize_y,
44 | override_settings=override_settings,
45 | )
46 |
47 | p.scripts = modules.scripts.scripts_txt2img
48 | p.script_args = args
49 |
50 | if cmd_opts.enable_console_prompts:
51 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
52 |
53 | processed = modules.scripts.scripts_txt2img.run(p, *args)
54 |
55 | if processed is None:
56 | processed = process_images(p)
57 |
58 | p.close()
59 |
60 | shared.total_tqdm.clear()
61 |
62 | generation_info_js = processed.js()
63 | if opts.samples_log_stdout:
64 | print(generation_info_js)
65 |
66 | if opts.do_not_show_images:
67 | processed.images = []
68 |
69 | return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
70 |
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
1 | function gradioApp() {
2 | const elems = document.getElementsByTagName('gradio-app')
3 | const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
4 | return !!gradioShadowRoot ? gradioShadowRoot : document;
5 | }
6 |
7 | function get_uiCurrentTab() {
8 | return gradioApp().querySelector('#tabs button:not(.border-transparent)')
9 | }
10 |
11 | function get_uiCurrentTabContent() {
12 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
13 | }
14 |
15 | uiUpdateCallbacks = []
16 | uiLoadedCallbacks = []
17 | uiTabChangeCallbacks = []
18 | optionsChangedCallbacks = []
19 | let uiCurrentTab = null
20 |
21 | function onUiUpdate(callback){
22 | uiUpdateCallbacks.push(callback)
23 | }
24 | function onUiLoaded(callback){
25 | uiLoadedCallbacks.push(callback)
26 | }
27 | function onUiTabChange(callback){
28 | uiTabChangeCallbacks.push(callback)
29 | }
30 | function onOptionsChanged(callback){
31 | optionsChangedCallbacks.push(callback)
32 | }
33 |
34 | function runCallback(x, m){
35 | try {
36 | x(m)
37 | } catch (e) {
38 | (console.error || console.log).call(console, e.message, e);
39 | }
40 | }
41 | function executeCallbacks(queue, m) {
42 | queue.forEach(function(x){runCallback(x, m)})
43 | }
44 |
45 | var executedOnLoaded = false;
46 |
47 | document.addEventListener("DOMContentLoaded", function() {
48 | var mutationObserver = new MutationObserver(function(m){
49 | if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
50 | executedOnLoaded = true;
51 | executeCallbacks(uiLoadedCallbacks);
52 | }
53 |
54 | executeCallbacks(uiUpdateCallbacks, m);
55 | const newTab = get_uiCurrentTab();
56 | if ( newTab && ( newTab !== uiCurrentTab ) ) {
57 | uiCurrentTab = newTab;
58 | executeCallbacks(uiTabChangeCallbacks);
59 | }
60 | });
61 | mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
62 | });
63 |
64 | /**
65 | * Add a ctrl+enter as a shortcut to start a generation
66 | */
67 | document.addEventListener('keydown', function(e) {
68 | var handled = false;
69 | if (e.key !== undefined) {
70 | if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
71 | } else if (e.keyCode !== undefined) {
72 | if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
73 | }
74 | if (handled) {
75 | button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
76 | if (button) {
77 | button.click();
78 | }
79 | e.preventDefault();
80 | }
81 | })
82 |
83 | /**
84 | * checks that a UI element is not in another hidden element or tab content
85 | */
86 | function uiElementIsVisible(el) {
87 | let isVisible = !el.closest('.\\!hidden');
88 | if ( ! isVisible ) {
89 | return false;
90 | }
91 |
92 | while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
93 | if ( ! isVisible ) {
94 | return false;
95 | } else if ( el.parentElement ) {
96 | el = el.parentElement
97 | } else {
98 | break;
99 | }
100 | }
101 | return isVisible;
102 | }
103 |
--------------------------------------------------------------------------------
/modules/deepbooru.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import torch
5 | from PIL import Image
6 | import numpy as np
7 |
8 | from modules import modelloader, paths, deepbooru_model, devices, images, shared
9 |
10 | re_special = re.compile(r'([\\()])')
11 |
12 |
13 | class DeepDanbooru:
14 | def __init__(self):
15 | self.model = None
16 |
17 | def load(self):
18 | if self.model is not None:
19 | return
20 |
21 | files = modelloader.load_models(
22 | model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
23 | model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
24 | ext_filter=[".pt"],
25 | download_name='model-resnet_custom_v3.pt',
26 | )
27 |
28 | self.model = deepbooru_model.DeepDanbooruModel()
29 | self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
30 |
31 | self.model.eval()
32 | self.model.to(devices.cpu, devices.dtype)
33 |
34 | def start(self):
35 | self.load()
36 | self.model.to(devices.device)
37 |
38 | def stop(self):
39 | if not shared.opts.interrogate_keep_models_in_memory:
40 | self.model.to(devices.cpu)
41 | devices.torch_gc()
42 |
43 | def tag(self, pil_image):
44 | self.start()
45 | res = self.tag_multi(pil_image)
46 | self.stop()
47 |
48 | return res
49 |
50 | def tag_multi(self, pil_image, force_disable_ranks=False):
51 | threshold = shared.opts.interrogate_deepbooru_score_threshold
52 | use_spaces = shared.opts.deepbooru_use_spaces
53 | use_escape = shared.opts.deepbooru_escape
54 | alpha_sort = shared.opts.deepbooru_sort_alpha
55 | include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
56 |
57 | pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
58 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
59 |
60 | with torch.no_grad(), devices.autocast():
61 | x = torch.from_numpy(a).to(devices.device)
62 | y = self.model(x)[0].detach().cpu().numpy()
63 |
64 | probability_dict = {}
65 |
66 | for tag, probability in zip(self.model.tags, y):
67 | if probability < threshold:
68 | continue
69 |
70 | if tag.startswith("rating:"):
71 | continue
72 |
73 | probability_dict[tag] = probability
74 |
75 | if alpha_sort:
76 | tags = sorted(probability_dict)
77 | else:
78 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
79 |
80 | res = []
81 |
82 | filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
83 |
84 | for tag in [x for x in tags if x not in filtertags]:
85 | probability = probability_dict[tag]
86 | tag_outformat = tag
87 | if use_spaces:
88 | tag_outformat = tag_outformat.replace('_', ' ')
89 | if use_escape:
90 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
91 | if include_ranks:
92 | tag_outformat = f"({tag_outformat}:{probability:.3f})"
93 |
94 | res.append(tag_outformat)
95 |
96 | return ", ".join(res)
97 |
98 |
99 | model = DeepDanbooru()
100 |
--------------------------------------------------------------------------------
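
Editor's note (illustrative, not part of the repository): the formatting branch at the end of tag_multi above is what turns each surviving tag into a prompt fragment. A standalone sketch of just that step, with the shared.opts values replaced by hypothetical plain variables:

    import re

    re_special = re.compile(r'([\\()])')
    use_spaces, use_escape, include_ranks = True, True, True    # hypothetical option values

    tag, probability = "long_hair_(style)", 0.871
    tag_outformat = tag
    if use_spaces:
        tag_outformat = tag_outformat.replace('_', ' ')              # underscores -> spaces
    if use_escape:
        tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)   # escape ( ) and backslash
    if include_ranks:
        tag_outformat = f"({tag_outformat}:{probability:.3f})"       # append confidence as a weight

    print(tag_outformat)   # (long hair \(style\):0.871)
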
/modules/styles.py:
--------------------------------------------------------------------------------
1 | # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
2 | from __future__ import annotations
3 |
4 | import csv
5 | import os
6 | import os.path
7 | import typing
8 | import collections.abc as abc
9 | import tempfile
10 | import shutil
11 |
12 | if typing.TYPE_CHECKING:
13 | # Only import this when code is being type-checked, it doesn't have any effect at runtime
14 | from .processing import StableDiffusionProcessing
15 |
16 |
17 | class PromptStyle(typing.NamedTuple):
18 | name: str
19 | prompt: str
20 | negative_prompt: str
21 |
22 |
23 | def merge_prompts(style_prompt: str, prompt: str) -> str:
24 | if "{prompt}" in style_prompt:
25 | res = style_prompt.replace("{prompt}", prompt)
26 | else:
27 | parts = filter(None, (prompt.strip(), style_prompt.strip()))
28 | res = ", ".join(parts)
29 |
30 | return res
31 |
32 |
33 | def apply_styles_to_prompt(prompt, styles):
34 | for style in styles:
35 | prompt = merge_prompts(style, prompt)
36 |
37 | return prompt
38 |
39 |
40 | class StyleDatabase:
41 | def __init__(self, path: str):
42 | self.no_style = PromptStyle("None", "", "")
43 | self.styles = {}
44 | self.path = path
45 |
46 | self.reload()
47 |
48 | def reload(self):
49 | self.styles.clear()
50 |
51 | if not os.path.exists(self.path):
52 | return
53 |
54 | with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
55 | reader = csv.DictReader(file)
56 | for row in reader:
57 | # Support loading old CSV format with "name, text"-columns
58 | prompt = row["prompt"] if "prompt" in row else row["text"]
59 | negative_prompt = row.get("negative_prompt", "")
60 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
61 |
62 | def get_style_prompts(self, styles):
63 | return [self.styles.get(x, self.no_style).prompt for x in styles]
64 |
65 | def get_negative_style_prompts(self, styles):
66 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
67 |
68 | def apply_styles_to_prompt(self, prompt, styles):
69 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
70 |
71 | def apply_negative_styles_to_prompt(self, prompt, styles):
72 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
73 |
74 | def save_styles(self, path: str) -> None:
75 | # Write to temporary file first, so we don't nuke the file if something goes wrong
76 | fd, temp_path = tempfile.mkstemp(".csv")
77 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
78 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
79 | # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
80 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
81 | writer.writeheader()
82 | writer.writerows(style._asdict() for k, style in self.styles.items())
83 |
84 | # Always keep a backup file around
85 | if os.path.exists(path):
86 | shutil.move(path, path + ".bak")
87 | shutil.move(temp_path, path)
88 |
--------------------------------------------------------------------------------
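
Editor's note (illustrative, not part of the repository): merge_prompts above behaves differently depending on whether the style text contains a "{prompt}" placeholder. A minimal sketch, assuming the webui root is on sys.path (modules.styles only needs the standard library):

    from modules.styles import apply_styles_to_prompt, merge_prompts

    print(merge_prompts("masterpiece, {prompt}, 4k", "a cat"))   # -> "masterpiece, a cat, 4k"
    print(merge_prompts("oil painting", "a cat"))                # -> "a cat, oil painting"

    # Styles are applied left to right, each one wrapping or appending to the result so far.
    print(apply_styles_to_prompt("a cat", ["oil painting", "{prompt}, trending"]))
    # -> "a cat, oil painting, trending"
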
/javascript/dragdrop.js:
--------------------------------------------------------------------------------
1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
2 |
3 | function isValidImageList( files ) {
4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
5 | }
6 |
7 | function dropReplaceImage( imgWrap, files ) {
8 | if ( ! isValidImageList( files ) ) {
9 | return;
10 | }
11 |
12 | const tmpFile = files[0];
13 |
14 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
15 | const callback = () => {
16 | const fileInput = imgWrap.querySelector('input[type="file"]');
17 | if ( fileInput ) {
18 | if ( files.length === 0 ) {
19 | files = new DataTransfer();
20 | files.items.add(tmpFile);
21 | fileInput.files = files.files;
22 | } else {
23 | fileInput.files = files;
24 | }
25 | fileInput.dispatchEvent(new Event('change'));
26 | }
27 | };
28 |
29 | if ( imgWrap.closest('#pnginfo_image') ) {
30 | // special treatment for PNG Info tab, wait for fetch request to finish
31 | const oldFetch = window.fetch;
32 | window.fetch = async (input, options) => {
33 | const response = await oldFetch(input, options);
34 | if ( 'api/predict/' === input ) {
35 | const content = await response.text();
36 | window.fetch = oldFetch;
37 | window.requestAnimationFrame( () => callback() );
38 | return new Response(content, {
39 | status: response.status,
40 | statusText: response.statusText,
41 | headers: response.headers
42 | })
43 | }
44 | return response;
45 | };
46 | } else {
47 | window.requestAnimationFrame( () => callback() );
48 | }
49 | }
50 |
51 | window.document.addEventListener('dragover', e => {
52 | const target = e.composedPath()[0];
53 | const imgWrap = target.closest('[data-testid="image"]');
54 | if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
55 | return;
56 | }
57 | e.stopPropagation();
58 | e.preventDefault();
59 | e.dataTransfer.dropEffect = 'copy';
60 | });
61 |
62 | window.document.addEventListener('drop', e => {
63 | const target = e.composedPath()[0];
64 | if (target.placeholder.indexOf("Prompt") == -1) {
65 | return;
66 | }
67 | const imgWrap = target.closest('[data-testid="image"]');
68 | if ( !imgWrap ) {
69 | return;
70 | }
71 | e.stopPropagation();
72 | e.preventDefault();
73 | const files = e.dataTransfer.files;
74 | dropReplaceImage( imgWrap, files );
75 | });
76 |
77 | window.addEventListener('paste', e => {
78 | const files = e.clipboardData.files;
79 | if ( ! isValidImageList( files ) ) {
80 | return;
81 | }
82 |
83 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
84 | .filter(el => uiElementIsVisible(el));
85 | if ( ! visibleImageFields.length ) {
86 | return;
87 | }
88 |
89 | const firstFreeImageField = visibleImageFields
90 | .filter(el => el.querySelector('input[type=file]'))?.[0];
91 |
92 | dropReplaceImage(
93 | firstFreeImageField ?
94 | firstFreeImageField :
95 | visibleImageFields[visibleImageFields.length - 1]
96 | , files );
97 | });
98 |
--------------------------------------------------------------------------------
/extensions-builtin/ScuNET/scripts/scunet_model.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import sys
3 | import traceback
4 |
5 | import PIL.Image
6 | import numpy as np
7 | import torch
8 | from basicsr.utils.download_util import load_file_from_url
9 |
10 | import modules.upscaler
11 | from modules import devices, modelloader
12 | from scunet_model_arch import SCUNet as net
13 |
14 |
15 | class UpscalerScuNET(modules.upscaler.Upscaler):
16 | def __init__(self, dirname):
17 | self.name = "ScuNET"
18 | self.model_name = "ScuNET GAN"
19 | self.model_name2 = "ScuNET PSNR"
20 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
21 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
22 | self.user_path = dirname
23 | super().__init__()
24 | model_paths = self.find_models(ext_filter=[".pth"])
25 | scalers = []
26 | add_model2 = True
27 | for file in model_paths:
28 | if "http" in file:
29 | name = self.model_name
30 | else:
31 | name = modelloader.friendly_name(file)
32 | if name == self.model_name2 or file == self.model_url2:
33 | add_model2 = False
34 | try:
35 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
36 | scalers.append(scaler_data)
37 | except Exception:
38 | print(f"Error loading ScuNET model: {file}", file=sys.stderr)
39 | print(traceback.format_exc(), file=sys.stderr)
40 | if add_model2:
41 | scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
42 | scalers.append(scaler_data2)
43 | self.scalers = scalers
44 |
45 | def do_upscale(self, img: PIL.Image, selected_file):
46 | torch.cuda.empty_cache()
47 |
48 | model = self.load_model(selected_file)
49 | if model is None:
50 | return img
51 |
52 | device = devices.get_device_for('scunet')
53 | img = np.array(img)
54 | img = img[:, :, ::-1]
55 | img = np.moveaxis(img, 2, 0) / 255
56 | img = torch.from_numpy(img).float()
57 | img = img.unsqueeze(0).to(device)
58 |
59 | with torch.no_grad():
60 | output = model(img)
61 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
62 | output = 255. * np.moveaxis(output, 0, 2)
63 | output = output.astype(np.uint8)
64 | output = output[:, :, ::-1]
65 | torch.cuda.empty_cache()
66 | return PIL.Image.fromarray(output, 'RGB')
67 |
68 | def load_model(self, path: str):
69 | device = devices.get_device_for('scunet')
70 | if "http" in path:
71 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
72 | progress=True)
73 | else:
74 | filename = path
75 | if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
76 | print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
77 | return None
78 |
79 | model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
80 | model.load_state_dict(torch.load(filename), strict=True)
81 | model.eval()
82 | for k, v in model.named_parameters():
83 | v.requires_grad = False
84 | model = model.to(device)
85 |
86 | return model
87 |
88 |
--------------------------------------------------------------------------------
/test/basic_features/txt2img_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 |
4 |
5 | class TestTxt2ImgWorking(unittest.TestCase):
6 | def setUp(self):
7 | self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
8 | self.simple_txt2img = {
9 | "enable_hr": False,
10 | "denoising_strength": 0,
11 | "firstphase_width": 0,
12 | "firstphase_height": 0,
13 | "prompt": "example prompt",
14 | "styles": [],
15 | "seed": -1,
16 | "subseed": -1,
17 | "subseed_strength": 0,
18 | "seed_resize_from_h": -1,
19 | "seed_resize_from_w": -1,
20 | "batch_size": 1,
21 | "n_iter": 1,
22 | "steps": 3,
23 | "cfg_scale": 7,
24 | "width": 64,
25 | "height": 64,
26 | "restore_faces": False,
27 | "tiling": False,
28 | "negative_prompt": "",
29 | "eta": 0,
30 | "s_churn": 0,
31 | "s_tmax": 0,
32 | "s_tmin": 0,
33 | "s_noise": 1,
34 | "sampler_index": "Euler a"
35 | }
36 |
37 | def test_txt2img_simple_performed(self):
38 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
39 |
40 | def test_txt2img_with_negative_prompt_performed(self):
41 | self.simple_txt2img["negative_prompt"] = "example negative prompt"
42 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
43 |
44 | def test_txt2img_with_complex_prompt_performed(self):
45 | self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
46 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
47 |
48 | def test_txt2img_not_square_image_performed(self):
49 | self.simple_txt2img["height"] = 128
50 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
51 |
52 | def test_txt2img_with_hrfix_performed(self):
53 | self.simple_txt2img["enable_hr"] = True
54 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
55 |
56 | def test_txt2img_with_tiling_performed(self):
57 | self.simple_txt2img["tiling"] = True
58 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
59 |
60 | def test_txt2img_with_restore_faces_performed(self):
61 | self.simple_txt2img["restore_faces"] = True
62 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
63 |
64 | def test_txt2img_with_vanilla_sampler_performed(self):
65 | self.simple_txt2img["sampler_index"] = "PLMS"
66 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
67 | self.simple_txt2img["sampler_index"] = "DDIM"
68 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
69 |
70 | def test_txt2img_multiple_batches_performed(self):
71 | self.simple_txt2img["n_iter"] = 2
72 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
73 |
74 | def test_txt2img_batch_performed(self):
75 | self.simple_txt2img["batch_size"] = 2
76 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
77 |
78 |
79 | if __name__ == "__main__":
80 | unittest.main()
81 |
--------------------------------------------------------------------------------
/modules/mac_specific.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules import paths
3 | from modules.sd_hijack_utils import CondFunc
4 | from packaging import version
5 |
6 |
7 | # has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
8 | # check `getattr` and try it for compatibility
9 | def check_for_mps() -> bool:
10 | if not getattr(torch, 'has_mps', False):
11 | return False
12 | try:
13 | torch.zeros(1).to(torch.device("mps"))
14 | return True
15 | except Exception:
16 | return False
17 | has_mps = check_for_mps()
18 |
19 |
20 | # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
21 | def cumsum_fix(input, cumsum_func, *args, **kwargs):
22 | if input.device.type == 'mps':
23 | output_dtype = kwargs.get('dtype', input.dtype)
24 | if output_dtype == torch.int64:
25 | return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
26 | elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
27 | return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
28 | return cumsum_func(input, *args, **kwargs)
29 |
30 |
31 | if has_mps:
32 | # MPS fix for randn in torchsde
33 | CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
34 |
35 | if version.parse(torch.__version__) < version.parse("1.13"):
36 | # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
37 |
38 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
39 | CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
40 | lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
41 | # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
42 | CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
43 | lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
44 | # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
45 | CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
46 | elif version.parse(torch.__version__) > version.parse("1.13.1"):
47 | cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
48 | cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
49 | cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
50 | CondFunc('torch.cumsum', cumsum_fix_func, None)
51 | CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
52 | CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
53 |
54 |
--------------------------------------------------------------------------------
/modules/masking.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFilter, ImageOps
2 |
3 |
4 | def get_crop_region(mask, pad=0):
5 | """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6 | For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
7 |
8 | h, w = mask.shape
9 |
10 | crop_left = 0
11 | for i in range(w):
12 | if not (mask[:, i] == 0).all():
13 | break
14 | crop_left += 1
15 |
16 | crop_right = 0
17 | for i in reversed(range(w)):
18 | if not (mask[:, i] == 0).all():
19 | break
20 | crop_right += 1
21 |
22 | crop_top = 0
23 | for i in range(h):
24 | if not (mask[i] == 0).all():
25 | break
26 | crop_top += 1
27 |
28 | crop_bottom = 0
29 | for i in reversed(range(h)):
30 | if not (mask[i] == 0).all():
31 | break
32 | crop_bottom += 1
33 |
34 | return (
35 | int(max(crop_left-pad, 0)),
36 | int(max(crop_top-pad, 0)),
37 | int(min(w - crop_right + pad, w)),
38 | int(min(h - crop_bottom + pad, h))
39 | )
40 |
41 |
42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43 | """expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region
44 | for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45 |
46 | x1, y1, x2, y2 = crop_region
47 |
48 | ratio_crop_region = (x2 - x1) / (y2 - y1)
49 | ratio_processing = processing_width / processing_height
50 |
51 | if ratio_crop_region > ratio_processing:
52 | desired_height = (x2 - x1) / ratio_processing
53 | desired_height_diff = int(desired_height - (y2-y1))
54 | y1 -= desired_height_diff//2
55 | y2 += desired_height_diff - desired_height_diff//2
56 | if y2 >= image_height:
57 | diff = y2 - image_height
58 | y2 -= diff
59 | y1 -= diff
60 | if y1 < 0:
61 | y2 -= y1
62 | y1 -= y1
63 | if y2 >= image_height:
64 | y2 = image_height
65 | else:
66 | desired_width = (y2 - y1) * ratio_processing
67 | desired_width_diff = int(desired_width - (x2-x1))
68 | x1 -= desired_width_diff//2
69 | x2 += desired_width_diff - desired_width_diff//2
70 | if x2 >= image_width:
71 | diff = x2 - image_width
72 | x2 -= diff
73 | x1 -= diff
74 | if x1 < 0:
75 | x2 -= x1
76 | x1 -= x1
77 | if x2 >= image_width:
78 | x2 = image_width
79 |
80 | return x1, y1, x2, y2
81 |
82 |
83 | def fill(image, mask):
84 | """fills masked regions with colors from image using blur. Not extremely effective."""
85 |
86 | image_mod = Image.new('RGBA', (image.width, image.height))
87 |
88 | image_masked = Image.new('RGBa', (image.width, image.height))
89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90 |
91 | image_masked = image_masked.convert('RGBa')
92 |
93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95 | for _ in range(repeats):
96 | image_mod.alpha_composite(blurred)
97 |
98 | return image_mod.convert("RGB")
99 |
100 |
--------------------------------------------------------------------------------
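
Editor's note (illustrative, not part of the repository): get_crop_region and expand_crop_region above compose during "inpaint at full resolution": the first finds the tight box around the mask, the second grows it to the processing aspect ratio. A minimal sketch using a plain NumPy mask, assuming the webui root is on sys.path:

    import numpy as np
    from modules.masking import get_crop_region, expand_crop_region

    # 512x512 mask with a 128x32 painted rectangle, matching the docstring example.
    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[64:96, 256:384] = 255

    region = get_crop_region(mask)                           # (256, 64, 384, 96) -> 128x32 box
    print(expand_crop_region(region, 512, 512, 512, 512))    # (256, 16, 384, 144) -> grown to 128x128
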
/modules/sd_hijack_clip_old.py:
--------------------------------------------------------------------------------
1 | from modules import sd_hijack_clip
2 | from modules import shared
3 |
4 |
5 | def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
6 | id_start = self.id_start
7 | id_end = self.id_end
8 | maxlen = self.wrapped.max_length # you get to stay at 77
9 | used_custom_terms = []
10 | remade_batch_tokens = []
11 | hijack_comments = []
12 | hijack_fixes = []
13 | token_count = 0
14 |
15 | cache = {}
16 | batch_tokens = self.tokenize(texts)
17 | batch_multipliers = []
18 | for tokens in batch_tokens:
19 | tuple_tokens = tuple(tokens)
20 |
21 | if tuple_tokens in cache:
22 | remade_tokens, fixes, multipliers = cache[tuple_tokens]
23 | else:
24 | fixes = []
25 | remade_tokens = []
26 | multipliers = []
27 | mult = 1.0
28 |
29 | i = 0
30 | while i < len(tokens):
31 | token = tokens[i]
32 |
33 | embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
34 |
35 | mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
36 | if mult_change is not None:
37 | mult *= mult_change
38 | i += 1
39 | elif embedding is None:
40 | remade_tokens.append(token)
41 | multipliers.append(mult)
42 | i += 1
43 | else:
44 | emb_len = int(embedding.vec.shape[0])
45 | fixes.append((len(remade_tokens), embedding))
46 | remade_tokens += [0] * emb_len
47 | multipliers += [mult] * emb_len
48 | used_custom_terms.append((embedding.name, embedding.checksum()))
49 | i += embedding_length_in_tokens
50 |
51 | if len(remade_tokens) > maxlen - 2:
52 | vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
53 | ovf = remade_tokens[maxlen - 2:]
54 | overflowing_words = [vocab.get(int(x), "") for x in ovf]
55 | overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
56 | hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
57 |
58 | token_count = len(remade_tokens)
59 | remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
60 | remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
61 | cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
62 |
63 | multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
64 | multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
65 |
66 | remade_batch_tokens.append(remade_tokens)
67 | hijack_fixes.append(fixes)
68 | batch_multipliers.append(multipliers)
69 | return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
70 |
71 |
72 | def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
73 | batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
74 |
75 | self.hijack.comments += hijack_comments
76 |
77 | if len(used_custom_terms) > 0:
78 | self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
79 |
80 | self.hijack.fixes = hijack_fixes
81 | return self.process_tokens(remade_batch_tokens, batch_multipliers)
82 |
--------------------------------------------------------------------------------
/javascript/edit-attention.js:
--------------------------------------------------------------------------------
1 | function keyupEditAttention(event){
2 | let target = event.originalTarget || event.composedPath()[0];
3 | if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
4 | if (! (event.metaKey || event.ctrlKey)) return;
5 |
6 | let isPlus = event.key == "ArrowUp"
7 | let isMinus = event.key == "ArrowDown"
8 | if (!isPlus && !isMinus) return;
9 |
10 | let selectionStart = target.selectionStart;
11 | let selectionEnd = target.selectionEnd;
12 | let text = target.value;
13 |
14 | function selectCurrentParenthesisBlock(OPEN, CLOSE){
15 | if (selectionStart !== selectionEnd) return false;
16 |
17 | // Find opening parenthesis around current cursor
18 | const before = text.substring(0, selectionStart);
19 | let beforeParen = before.lastIndexOf(OPEN);
20 | if (beforeParen == -1) return false;
21 | let beforeParenClose = before.lastIndexOf(CLOSE);
22 | while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
23 | beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
24 | beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
25 | }
26 |
27 | // Find closing parenthesis around current cursor
28 | const after = text.substring(selectionStart);
29 | let afterParen = after.indexOf(CLOSE);
30 | if (afterParen == -1) return false;
31 | let afterParenOpen = after.indexOf(OPEN);
32 | while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
33 | afterParen = after.indexOf(CLOSE, afterParen + 1);
34 | afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
35 | }
36 | if (beforeParen === -1 || afterParen === -1) return false;
37 |
38 | // Set the selection to the text between the parenthesis
39 | const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
40 | const lastColon = parenContent.lastIndexOf(":");
41 | selectionStart = beforeParen + 1;
42 | selectionEnd = selectionStart + lastColon;
43 | target.setSelectionRange(selectionStart, selectionEnd);
44 | return true;
45 | }
46 |
47 | // If the user hasn't selected anything, let's select their current parenthesis block
48 | if(! selectCurrentParenthesisBlock('<', '>')){
49 | selectCurrentParenthesisBlock('(', ')')
50 | }
51 |
52 | event.preventDefault();
53 |
54 | closeCharacter = ')'
55 | delta = opts.keyedit_precision_attention
56 |
57 | if (selectionStart > 0 && text[selectionStart - 1] == '<'){
58 | closeCharacter = '>'
59 | delta = opts.keyedit_precision_extra
60 | } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
61 |
62 | // do not include spaces at the end
63 | while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
64 | selectionEnd -= 1;
65 | }
66 | if(selectionStart == selectionEnd){
67 | return
68 | }
69 |
70 | text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
71 |
72 | selectionStart += 1;
73 | selectionEnd += 1;
74 | }
75 |
76 | end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
77 | weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
78 | if (isNaN(weight)) return;
79 |
80 | weight += isPlus ? delta : -delta;
81 | weight = parseFloat(weight.toPrecision(12));
82 | if(String(weight).length == 1) weight += ".0"
83 |
84 | text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
85 |
86 | target.focus();
87 | target.value = text;
88 | target.selectionStart = selectionStart;
89 | target.selectionEnd = selectionEnd;
90 |
91 | updateInput(target)
92 | }
93 |
94 | addEventListener('keydown', (event) => {
95 | keyupEditAttention(event);
96 | });
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: You think something is broken in the UI
3 | title: "[Bug]: "
4 | labels: ["bug-report"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 |           *Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers", and **provide screenshots if possible***
18 | - type: textarea
19 | id: what-did
20 | attributes:
21 | label: What happened?
22 | description: Tell us what happened in a very clear and simple way
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: steps
27 | attributes:
28 | label: Steps to reproduce the problem
29 | description: Please provide us with precise step by step information on how to reproduce the bug
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: what-should
38 | attributes:
39 | label: What should have happened?
40 |       description: Tell us what you think the normal behavior should be
41 | validations:
42 | required: true
43 | - type: input
44 | id: commit
45 | attributes:
46 | label: Commit where the problem happens
47 |       description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
48 | validations:
49 | required: true
50 | - type: dropdown
51 | id: platforms
52 | attributes:
53 |       label: What platforms do you use to access the UI?
54 | multiple: true
55 | options:
56 | - Windows
57 | - Linux
58 | - MacOS
59 | - iOS
60 | - Android
61 | - Other/Cloud
62 | - type: dropdown
63 | id: browsers
64 | attributes:
65 |       label: What browsers do you use to access the UI?
66 | multiple: true
67 | options:
68 | - Mozilla Firefox
69 | - Google Chrome
70 | - Brave
71 | - Apple Safari
72 | - Microsoft Edge
73 | - type: textarea
74 | id: cmdargs
75 | attributes:
76 | label: Command Line Arguments
77 |       description: Are you using any launching parameters/command line arguments (modified webui-user.bat/.sh)? If yes, please write them below. Write "No" otherwise.
78 | render: Shell
79 | validations:
80 | required: true
81 | - type: textarea
82 | id: extensions
83 | attributes:
84 | label: List of extensions
85 |       description: Are you using any extensions other than built-ins? If yes, provide a list; you can copy it from the "Extensions" tab. Write "No" otherwise.
86 | validations:
87 | required: true
88 | - type: textarea
89 | id: logs
90 | attributes:
91 | label: Console logs
92 |       description: Please provide **full** cmd/terminal logs from the moment you started the UI until after your bug happened. If it's very long, provide a link to pastebin or a similar service.
93 | render: Shell
94 | validations:
95 | required: true
96 | - type: textarea
97 | id: misc
98 | attributes:
99 | label: Additional information
100 | description: Please provide us with any relevant additional info or context.
101 |
--------------------------------------------------------------------------------
/modules/extensions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | import time
6 | import git
7 |
8 | from modules import paths, shared
9 |
10 | extensions = []
11 | extensions_dir = os.path.join(paths.data_path, "extensions")
12 | extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
13 |
14 | if not os.path.exists(extensions_dir):
15 | os.makedirs(extensions_dir)
16 |
17 | def active():
18 | return [x for x in extensions if x.enabled]
19 |
20 |
21 | class Extension:
22 | def __init__(self, name, path, enabled=True, is_builtin=False):
23 | self.name = name
24 | self.path = path
25 | self.enabled = enabled
26 | self.status = ''
27 | self.can_update = False
28 | self.is_builtin = is_builtin
29 | self.version = ''
30 |
31 | repo = None
32 | try:
33 | if os.path.exists(os.path.join(path, ".git")):
34 | repo = git.Repo(path)
35 | except Exception:
36 | print(f"Error reading github repository info from {path}:", file=sys.stderr)
37 | print(traceback.format_exc(), file=sys.stderr)
38 |
39 | if repo is None or repo.bare:
40 | self.remote = None
41 | else:
42 | try:
43 | self.remote = next(repo.remote().urls, None)
44 | self.status = 'unknown'
45 | head = repo.head.commit
46 | ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
47 | self.version = f'{head.hexsha[:8]} ({ts})'
48 |
49 | except Exception:
50 | self.remote = None
51 |
52 | def list_files(self, subdir, extension):
53 | from modules import scripts
54 |
55 | dirpath = os.path.join(self.path, subdir)
56 | if not os.path.isdir(dirpath):
57 | return []
58 |
59 | res = []
60 | for filename in sorted(os.listdir(dirpath)):
61 | res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
62 |
63 | res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
64 |
65 | return res
66 |
67 | def check_updates(self):
68 | repo = git.Repo(self.path)
69 | for fetch in repo.remote().fetch("--dry-run"):
70 | if fetch.flags != fetch.HEAD_UPTODATE:
71 | self.can_update = True
72 | self.status = "behind"
73 | return
74 |
75 | self.can_update = False
76 | self.status = "latest"
77 |
78 | def fetch_and_reset_hard(self):
79 | repo = git.Repo(self.path)
80 | # Fix: `error: Your local changes to the following files would be overwritten by merge`,
81 | # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
82 | repo.git.fetch('--all')
83 | repo.git.reset('--hard', 'origin')
84 |
85 |
86 | def list_extensions():
87 | extensions.clear()
88 |
89 | if not os.path.isdir(extensions_dir):
90 | return
91 |
92 | paths = []
93 | for dirname in [extensions_dir, extensions_builtin_dir]:
94 | if not os.path.isdir(dirname):
95 | return
96 |
97 | for extension_dirname in sorted(os.listdir(dirname)):
98 | path = os.path.join(dirname, extension_dirname)
99 | if not os.path.isdir(path):
100 | continue
101 |
102 | paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
103 |
104 | for dirname, path, is_builtin in paths:
105 | extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
106 | extensions.append(extension)
107 |
108 |
--------------------------------------------------------------------------------
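
Editor's note (illustrative, not part of the repository): list_files above compares the filter argument against os.path.splitext output, so it must include the leading dot. A hedged sketch of the typical call pattern (requires the webui environment, since list_extensions reads shared.opts.disabled_extensions):

    from modules import extensions

    extensions.list_extensions()
    for ext in extensions.active():
        for script_file in ext.list_files("scripts", ".py"):   # note ".py", not "py"
            print(ext.name, script_file.path)
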
/javascript/extraNetworks.js:
--------------------------------------------------------------------------------
1 |
2 | function setupExtraNetworksForTab(tabname){
3 | gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
4 |
5 | var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
6 | var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
7 | var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
8 | var close = gradioApp().getElementById(tabname+'_extra_close')
9 |
10 | search.classList.add('search')
11 | tabs.appendChild(search)
12 | tabs.appendChild(refresh)
13 | tabs.appendChild(close)
14 |
15 | search.addEventListener("input", function(evt){
16 | searchTerm = search.value.toLowerCase()
17 |
18 | gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
19 | text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
20 | elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
21 | })
22 | });
23 | }
24 |
25 | var activePromptTextarea = {};
26 |
27 | function setupExtraNetworks(){
28 | setupExtraNetworksForTab('txt2img')
29 | setupExtraNetworksForTab('img2img')
30 |
31 | function registerPrompt(tabname, id){
32 | var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
33 |
34 | if (! activePromptTextarea[tabname]){
35 | activePromptTextarea[tabname] = textarea
36 | }
37 |
38 | textarea.addEventListener("focus", function(){
39 | activePromptTextarea[tabname] = textarea;
40 | });
41 | }
42 |
43 | registerPrompt('txt2img', 'txt2img_prompt')
44 | registerPrompt('txt2img', 'txt2img_neg_prompt')
45 | registerPrompt('img2img', 'img2img_prompt')
46 | registerPrompt('img2img', 'img2img_neg_prompt')
47 | }
48 |
49 | onUiLoaded(setupExtraNetworks)
50 |
51 | var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
52 | var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
53 |
54 | function tryToRemoveExtraNetworkFromPrompt(textarea, text){
55 | var m = text.match(re_extranet)
56 | if(! m) return false
57 |
58 | var partToSearch = m[1]
59 | var replaced = false
60 | var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
61 | m = found.match(re_extranet);
62 | if(m[1] == partToSearch){
63 | replaced = true;
64 | return ""
65 | }
66 | return found;
67 | })
68 |
69 | if(replaced){
70 | textarea.value = newTextareaText
71 | return true;
72 | }
73 |
74 | return false
75 | }
76 |
77 | function cardClicked(tabname, textToAdd, allowNegativePrompt){
78 | var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
79 |
80 | if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
81 | textarea.value = textarea.value + " " + textToAdd
82 | }
83 |
84 | updateInput(textarea)
85 | }
86 |
87 | function saveCardPreview(event, tabname, filename){
88 | var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
89 | var button = gradioApp().getElementById(tabname + '_save_preview')
90 |
91 | textarea.value = filename
92 | updateInput(textarea)
93 |
94 | button.click()
95 |
96 | event.stopPropagation()
97 | event.preventDefault()
98 | }
99 |
100 | function extraNetworksSearchButton(tabs_id, event){
101 | searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
102 | button = event.target
103 | text = button.classList.contains("search-all") ? "" : button.textContent.trim()
104 |
105 | searchTextarea.value = text
106 | updateInput(searchTextarea)
107 | }
--------------------------------------------------------------------------------
/javascript/aspectRatioOverlay.js:
--------------------------------------------------------------------------------
1 |
2 | let currentWidth = null;
3 | let currentHeight = null;
4 | let arFrameTimeout = setTimeout(function(){},0);
5 |
6 | function dimensionChange(e, is_width, is_height){
7 |
8 | if(is_width){
9 | currentWidth = e.target.value*1.0
10 | }
11 | if(is_height){
12 | currentHeight = e.target.value*1.0
13 | }
14 |
15 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16 |
17 | if(!inImg2img){
18 | return;
19 | }
20 |
21 | var targetElement = null;
22 |
23 | var tabIndex = get_tab_index('mode_img2img')
24 | if(tabIndex == 0){ // img2img
25 | targetElement = gradioApp().querySelector('div[data-testid=image] img');
26 | } else if(tabIndex == 1){ //Sketch
27 | targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
28 | } else if(tabIndex == 2){ // Inpaint
29 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
30 | } else if(tabIndex == 3){ // Inpaint sketch
31 | targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
32 | }
33 |
34 |
35 | if(targetElement){
36 |
37 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
38 | if(!arPreviewRect){
39 | arPreviewRect = document.createElement('div')
40 | arPreviewRect.id = "imageARPreview";
41 | gradioApp().getRootNode().appendChild(arPreviewRect)
42 | }
43 |
44 |
45 |
46 | var viewportOffset = targetElement.getBoundingClientRect();
47 |
48 | viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
49 |
50 | scaledx = targetElement.naturalWidth*viewportscale
51 | scaledy = targetElement.naturalHeight*viewportscale
52 |
53 | cleintRectTop = (viewportOffset.top+window.scrollY)
54 | cleintRectLeft = (viewportOffset.left+window.scrollX)
55 | cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
56 | cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
57 |
58 | viewRectTop = cleintRectCentreY-(scaledy/2)
59 | viewRectLeft = cleintRectCentreX-(scaledx/2)
60 | arRectWidth = scaledx
61 | arRectHeight = scaledy
62 |
63 | arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
64 | arscaledx = currentWidth*arscale
65 | arscaledy = currentHeight*arscale
66 |
67 | arRectTop = cleintRectCentreY-(arscaledy/2)
68 | arRectLeft = cleintRectCentreX-(arscaledx/2)
69 | arRectWidth = arscaledx
70 | arRectHeight = arscaledy
71 |
72 | arPreviewRect.style.top = arRectTop+'px';
73 | arPreviewRect.style.left = arRectLeft+'px';
74 | arPreviewRect.style.width = arRectWidth+'px';
75 | arPreviewRect.style.height = arRectHeight+'px';
76 |
77 | clearTimeout(arFrameTimeout);
78 | arFrameTimeout = setTimeout(function(){
79 | arPreviewRect.style.display = 'none';
80 | },2000);
81 |
82 | arPreviewRect.style.display = 'block';
83 |
84 | }
85 |
86 | }
87 |
88 |
89 | onUiUpdate(function(){
90 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
91 | if(arPreviewRect){
92 | arPreviewRect.style.display = 'none';
93 | }
94 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
95 | if(inImg2img){
96 | let inputs = gradioApp().querySelectorAll('input');
97 | inputs.forEach(function(e){
98 | var is_width = e.parentElement.id == "img2img_width"
99 | var is_height = e.parentElement.id == "img2img_height"
100 |
101 | if((is_width || is_height) && !e.classList.contains('scrollwatch')){
102 | e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
103 | e.classList.add('scrollwatch')
104 | }
105 | if(is_width){
106 | currentWidth = e.value*1.0
107 | }
108 | if(is_height){
109 | currentHeight = e.value*1.0
110 | }
111 | })
112 | }
113 | });
114 |
--------------------------------------------------------------------------------
/modules/call_queue.py:
--------------------------------------------------------------------------------
1 | import html
2 | import sys
3 | import threading
4 | import traceback
5 | import time
6 |
7 | from modules import shared, progress
8 |
9 | queue_lock = threading.Lock()
10 |
11 |
12 | def wrap_queued_call(func):
13 | def f(*args, **kwargs):
14 | with queue_lock:
15 | res = func(*args, **kwargs)
16 |
17 | return res
18 |
19 | return f
20 |
21 |
22 | def wrap_gradio_gpu_call(func, extra_outputs=None):
23 | def f(*args, **kwargs):
24 |
25 | # if the first argument is a string that says "task(...)", it is treated as a job id
26 | if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
27 | id_task = args[0]
28 | progress.add_task_to_queue(id_task)
29 | else:
30 | id_task = None
31 |
32 | with queue_lock:
33 | shared.state.begin()
34 | progress.start_task(id_task)
35 |
36 | try:
37 | res = func(*args, **kwargs)
38 | finally:
39 | progress.finish_task(id_task)
40 |
41 | shared.state.end()
42 |
43 | return res
44 |
45 | return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
46 |
47 |
48 | def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
49 | def f(*args, extra_outputs_array=extra_outputs, **kwargs):
50 | run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
51 | if run_memmon:
52 | shared.mem_mon.monitor()
53 | t = time.perf_counter()
54 |
55 | try:
56 | res = list(func(*args, **kwargs))
57 | except Exception as e:
58 | # When printing out our debug argument list, do not print out more than a MB of text
59 | max_debug_str_len = 131072 # (1024*1024)/8
60 |
61 | print("Error completing request", file=sys.stderr)
62 | argStr = f"Arguments: {str(args)} {str(kwargs)}"
63 | print(argStr[:max_debug_str_len], file=sys.stderr)
64 | if len(argStr) > max_debug_str_len:
65 | print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
66 |
67 | print(traceback.format_exc(), file=sys.stderr)
68 |
69 | shared.state.job = ""
70 | shared.state.job_count = 0
71 |
72 | if extra_outputs_array is None:
73 | extra_outputs_array = [None, '']
74 |
75 |             res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
76 |
77 | shared.state.skipped = False
78 | shared.state.interrupted = False
79 | shared.state.job_count = 0
80 |
81 | if not add_stats:
82 | return tuple(res)
83 |
84 | elapsed = time.perf_counter() - t
85 | elapsed_m = int(elapsed // 60)
86 | elapsed_s = elapsed % 60
87 | elapsed_text = f"{elapsed_s:.2f}s"
88 | if elapsed_m > 0:
89 | elapsed_text = f"{elapsed_m}m "+elapsed_text
90 |
91 | if run_memmon:
92 | mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
93 | active_peak = mem_stats['active_peak']
94 | reserved_peak = mem_stats['reserved_peak']
95 | sys_peak = mem_stats['system_peak']
96 | sys_total = mem_stats['total']
97 | sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
98 |
99 |             vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
100 | else:
101 | vram_html = ''
102 |
103 | # last item is always HTML
104 |         res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
105 |
106 | return tuple(res)
107 |
108 | return f
109 |
110 |
--------------------------------------------------------------------------------
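A stripped-down, standalone sketch of the wrapping convention above: a lock serialises GPU work, and a first argument of the form "task(...)" doubles as the job id. wrap_gpu_call and generate below are illustrative stand-ins, not part of the module:

    import threading

    queue_lock = threading.Lock()

    def wrap_gpu_call(func):
        def f(*args, **kwargs):
            # the UI passes an id like "task(<uuid>)" as the first argument
            if args and isinstance(args[0], str) and args[0].startswith("task(") and args[0].endswith(")"):
                print("queued job:", args[0])
            with queue_lock:  # only one wrapped call touches the GPU at a time
                return func(*args, **kwargs)
        return f

    generate = wrap_gpu_call(lambda id_task, prompt: [f"image for {prompt!r}"])
    print(generate("task(1234)", "a cat"))
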
/scripts/loopback.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import trange
3 |
4 | import modules.scripts as scripts
5 | import gradio as gr
6 |
7 | from modules import processing, shared, sd_samplers, images
8 | from modules.processing import Processed
9 | from modules.sd_samplers import samplers
10 | from modules.shared import opts, cmd_opts, state
11 | from modules import deepbooru
12 |
13 |
14 | class Script(scripts.Script):
15 | def title(self):
16 | return "Loopback"
17 |
18 | def show(self, is_img2img):
19 | return is_img2img
20 |
21 | def ui(self, is_img2img):
22 | loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
23 | denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
24 | append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
25 |
26 | return [loops, denoising_strength_change_factor, append_interrogation]
27 |
28 | def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
29 | processing.fix_seed(p)
30 | batch_count = p.n_iter
31 | p.extra_generation_params = {
32 | "Denoising strength change factor": denoising_strength_change_factor,
33 | }
34 |
35 | p.batch_size = 1
36 | p.n_iter = 1
37 |
38 | output_images, info = None, None
39 | initial_seed = None
40 | initial_info = None
41 |
42 | grids = []
43 | all_images = []
44 | original_init_image = p.init_images
45 | original_prompt = p.prompt
46 | state.job_count = loops * batch_count
47 |
48 | initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
49 |
50 | for n in range(batch_count):
51 | history = []
52 |
53 | # Reset to original init image at the start of each batch
54 | p.init_images = original_init_image
55 |
56 | for i in range(loops):
57 | p.n_iter = 1
58 | p.batch_size = 1
59 | p.do_not_save_grid = True
60 |
61 | if opts.img2img_color_correction:
62 | p.color_corrections = initial_color_corrections
63 |
64 | if append_interrogation != "None":
65 | p.prompt = original_prompt + ", " if original_prompt != "" else ""
66 | if append_interrogation == "CLIP":
67 | p.prompt += shared.interrogator.interrogate(p.init_images[0])
68 | elif append_interrogation == "DeepBooru":
69 | p.prompt += deepbooru.model.tag(p.init_images[0])
70 |
71 | state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"
72 |
73 | processed = processing.process_images(p)
74 |
75 | if initial_seed is None:
76 | initial_seed = processed.seed
77 | initial_info = processed.info
78 |
79 | init_img = processed.images[0]
80 |
81 | p.init_images = [init_img]
82 | p.seed = processed.seed + 1
83 | p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
84 | history.append(processed.images[0])
85 |
86 | grid = images.image_grid(history, rows=1)
87 | if opts.grid_save:
88 | images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
89 |
90 | grids.append(grid)
91 | all_images += history
92 |
93 | if opts.return_grid:
94 | all_images = grids + all_images
95 |
96 | processed = Processed(p, all_images, initial_seed, initial_info)
97 |
98 | return processed
99 |
--------------------------------------------------------------------------------
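A small worked example of the schedule the "Denoising strength change factor" produces: after every loop the strength is multiplied by the factor and clamped to [0.1, 1.0]. Starting values are illustrative:

    strength, factor = 0.6, 0.95
    for i in range(4):
        strength = min(max(strength * factor, 0.1), 1.0)
        print(f"loop {i + 1}: denoising strength = {strength:.3f}")
    # the strength shrinks geometrically toward the 0.1 floor (or grows toward 1.0 for factors > 1)
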
/modules/progress.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 | import time
4 |
5 | import gradio as gr
6 | from pydantic import BaseModel, Field
7 |
8 | from modules.shared import opts
9 |
10 | import modules.shared as shared
11 |
12 |
13 | current_task = None
14 | pending_tasks = {}
15 | finished_tasks = []
16 |
17 |
18 | def start_task(id_task):
19 | global current_task
20 |
21 | current_task = id_task
22 | pending_tasks.pop(id_task, None)
23 |
24 |
25 | def finish_task(id_task):
26 | global current_task
27 |
28 | if current_task == id_task:
29 | current_task = None
30 |
31 | finished_tasks.append(id_task)
32 | if len(finished_tasks) > 16:
33 | finished_tasks.pop(0)
34 |
35 |
36 | def add_task_to_queue(id_job):
37 | pending_tasks[id_job] = time.time()
38 |
39 |
40 | class ProgressRequest(BaseModel):
41 | id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
42 |     id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
43 |
44 |
45 | class ProgressResponse(BaseModel):
46 | active: bool = Field(title="Whether the task is being worked on right now")
47 | queued: bool = Field(title="Whether the task is in queue")
48 | completed: bool = Field(title="Whether the task has already finished")
49 | progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
50 | eta: float = Field(default=None, title="ETA in secs")
51 | live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
52 | id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
53 | textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
54 |
55 |
56 | def setup_progress_api(app):
57 | return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
58 |
59 |
60 | def progressapi(req: ProgressRequest):
61 | active = req.id_task == current_task
62 | queued = req.id_task in pending_tasks
63 | completed = req.id_task in finished_tasks
64 |
65 | if not active:
66 | return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
67 |
68 | progress = 0
69 |
70 | job_count, job_no = shared.state.job_count, shared.state.job_no
71 | sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
72 |
73 | if job_count > 0:
74 | progress += job_no / job_count
75 | if sampling_steps > 0 and job_count > 0:
76 | progress += 1 / job_count * sampling_step / sampling_steps
77 |
78 | progress = min(progress, 1)
79 |
80 | elapsed_since_start = time.time() - shared.state.time_start
81 | predicted_duration = elapsed_since_start / progress if progress > 0 else None
82 | eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
83 |
84 | id_live_preview = req.id_live_preview
85 | shared.state.set_current_image()
86 | if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
87 | image = shared.state.current_image
88 | if image is not None:
89 | buffered = io.BytesIO()
90 | image.save(buffered, format="png")
91 | live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
92 | id_live_preview = shared.state.id_live_preview
93 | else:
94 | live_preview = None
95 | else:
96 | live_preview = None
97 |
98 | return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
99 |
100 |
--------------------------------------------------------------------------------
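The progress/ETA arithmetic in progressapi can be exercised on its own; the sketch below mirrors it with plain arguments standing in for shared.state:

    import time

    def estimate(job_count, job_no, sampling_steps, sampling_step, time_start):
        progress = 0.0
        if job_count > 0:
            progress += job_no / job_count
            if sampling_steps > 0:
                progress += (1 / job_count) * (sampling_step / sampling_steps)
        progress = min(progress, 1.0)
        elapsed = time.time() - time_start
        eta = (elapsed / progress - elapsed) if progress > 0 else None
        return progress, eta

    # 2nd of 4 jobs, halfway through 20 sampling steps, started 30 s ago:
    print(estimate(4, 1, 20, 10, time.time() - 30))  # roughly (0.375, 50.0)
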
/modules/sd_hijack_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from packaging import version
3 |
4 | from modules import devices
5 | from modules.sd_hijack_utils import CondFunc
6 |
7 |
8 | class TorchHijackForUnet:
9 | """
10 | This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
11 | this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
12 | """
13 |
14 | def __getattr__(self, item):
15 | if item == 'cat':
16 | return self.cat
17 |
18 | if hasattr(torch, item):
19 | return getattr(torch, item)
20 |
21 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
22 |
23 | def cat(self, tensors, *args, **kwargs):
24 | if len(tensors) == 2:
25 | a, b = tensors
26 | if a.shape[-2:] != b.shape[-2:]:
27 | a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
28 |
29 | tensors = (a, b)
30 |
31 | return torch.cat(tensors, *args, **kwargs)
32 |
33 |
34 | th = TorchHijackForUnet()
35 |
36 |
37 | # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
38 | def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
39 |
40 | if isinstance(cond, dict):
41 | for y in cond.keys():
42 | cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
43 |
44 | with devices.autocast():
45 | return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
46 |
47 |
48 | class GELUHijack(torch.nn.GELU, torch.nn.Module):
49 | def __init__(self, *args, **kwargs):
50 | torch.nn.GELU.__init__(self, *args, **kwargs)
51 | def forward(self, x):
52 | if devices.unet_needs_upcast:
53 | return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
54 | else:
55 | return torch.nn.GELU.forward(self, x)
56 |
57 |
58 | ddpm_edit_hijack = None
59 | def hijack_ddpm_edit():
60 | global ddpm_edit_hijack
61 | if not ddpm_edit_hijack:
62 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
63 | CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
64 | ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
65 |
66 |
67 | unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
68 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
69 | CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
70 | if version.parse(torch.__version__) <= version.parse("1.13.1"):
71 | CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
72 | CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
73 | CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
74 |
75 | first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
76 | first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
77 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
78 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
79 | CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
80 |
--------------------------------------------------------------------------------
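What the hijacked cat buys in practice: if the two tensors being concatenated disagree in spatial size (which happens when width/height are multiples of 8 but not of 64), the first is nearest-interpolated to match the second before torch.cat. A standalone sketch with made-up shapes:

    import torch
    import torch.nn.functional as F

    a = torch.randn(1, 4, 40, 64)  # latent from a dimension that is a multiple of 8
    b = torch.randn(1, 4, 48, 64)  # skip connection with a different height

    if a.shape[-2:] != b.shape[-2:]:
        a = F.interpolate(a, b.shape[-2:], mode="nearest")

    out = torch.cat((a, b), dim=1)
    print(out.shape)  # torch.Size([1, 8, 48, 64])
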
/modules/gfpgan_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | import facexlib
6 | import gfpgan
7 |
8 | import modules.face_restoration
9 | from modules import paths, shared, devices, modelloader
10 |
11 | model_dir = "GFPGAN"
12 | user_path = None
13 | model_path = os.path.join(paths.models_path, model_dir)
14 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
15 | have_gfpgan = False
16 | loaded_gfpgan_model = None
17 |
18 |
19 | def gfpgann():
20 | global loaded_gfpgan_model
21 | global model_path
22 | if loaded_gfpgan_model is not None:
23 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
24 | return loaded_gfpgan_model
25 |
26 | if gfpgan_constructor is None:
27 | return None
28 |
29 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
30 | if len(models) == 1 and "http" in models[0]:
31 | model_file = models[0]
32 | elif len(models) != 0:
33 | latest_file = max(models, key=os.path.getctime)
34 | model_file = latest_file
35 | else:
36 | print("Unable to load gfpgan model!")
37 | return None
38 | if hasattr(facexlib.detection.retinaface, 'device'):
39 | facexlib.detection.retinaface.device = devices.device_gfpgan
40 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
41 | loaded_gfpgan_model = model
42 |
43 | return model
44 |
45 |
46 | def send_model_to(model, device):
47 | model.gfpgan.to(device)
48 | model.face_helper.face_det.to(device)
49 | model.face_helper.face_parse.to(device)
50 |
51 |
52 | def gfpgan_fix_faces(np_image):
53 | model = gfpgann()
54 | if model is None:
55 | return np_image
56 |
57 | send_model_to(model, devices.device_gfpgan)
58 |
59 | np_image_bgr = np_image[:, :, ::-1]
60 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
61 | np_image = gfpgan_output_bgr[:, :, ::-1]
62 |
63 | model.face_helper.clean_all()
64 |
65 | if shared.opts.face_restoration_unload:
66 | send_model_to(model, devices.cpu)
67 |
68 | return np_image
69 |
70 |
71 | gfpgan_constructor = None
72 |
73 |
74 | def setup_model(dirname):
75 | global model_path
76 | if not os.path.exists(model_path):
77 | os.makedirs(model_path)
78 |
79 | try:
80 | from gfpgan import GFPGANer
81 | from facexlib import detection, parsing
82 | global user_path
83 | global have_gfpgan
84 | global gfpgan_constructor
85 |
86 | load_file_from_url_orig = gfpgan.utils.load_file_from_url
87 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
88 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
89 |
90 | def my_load_file_from_url(**kwargs):
91 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
92 |
93 | def facex_load_file_from_url(**kwargs):
94 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
95 |
96 | def facex_load_file_from_url2(**kwargs):
97 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
98 |
99 | gfpgan.utils.load_file_from_url = my_load_file_from_url
100 | facexlib.detection.load_file_from_url = facex_load_file_from_url
101 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2
102 | user_path = dirname
103 | have_gfpgan = True
104 | gfpgan_constructor = GFPGANer
105 |
106 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
107 | def name(self):
108 | return "GFPGAN"
109 |
110 | def restore(self, np_image):
111 | return gfpgan_fix_faces(np_image)
112 |
113 | shared.face_restorers.append(FaceRestorerGFPGAN())
114 | except Exception:
115 | print("Error setting up GFPGAN:", file=sys.stderr)
116 | print(traceback.format_exc(), file=sys.stderr)
117 |
--------------------------------------------------------------------------------
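setup_model relies on swapping the libraries' load_file_from_url helpers so downloads land in the webui model directory. The standalone sketch below reproduces that monkey-patch pattern with a fake stand-in library (fake_lib is not a real module):

    import types

    fake_lib = types.SimpleNamespace(
        load_file_from_url=lambda url, model_dir=None: f"saved {url} under {model_dir}"
    )

    _orig = fake_lib.load_file_from_url

    def patched_load_file_from_url(**kwargs):
        # force every download into our own directory, whatever the caller asked for
        return _orig(**dict(kwargs, model_dir="models/GFPGAN"))

    fake_lib.load_file_from_url = patched_load_file_from_url
    print(fake_lib.load_file_from_url(url="https://example.com/GFPGANv1.4.pth"))
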
/scripts/sd_upscale.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import modules.scripts as scripts
4 | import gradio as gr
5 | from PIL import Image
6 |
7 | from modules import processing, shared, sd_samplers, images, devices
8 | from modules.processing import Processed
9 | from modules.shared import opts, cmd_opts, state
10 |
11 |
12 | class Script(scripts.Script):
13 | def title(self):
14 | return "SD upscale"
15 |
16 | def show(self, is_img2img):
17 | return is_img2img
18 |
19 | def ui(self, is_img2img):
20 |         info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
22 | scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
23 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))
24 |
25 | return [info, overlap, upscaler_index, scale_factor]
26 |
27 | def run(self, p, _, overlap, upscaler_index, scale_factor):
28 | if isinstance(upscaler_index, str):
29 | upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
30 | processing.fix_seed(p)
31 | upscaler = shared.sd_upscalers[upscaler_index]
32 |
33 | p.extra_generation_params["SD upscale overlap"] = overlap
34 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name
35 |
36 | initial_info = None
37 | seed = p.seed
38 |
39 | init_img = p.init_images[0]
40 | init_img = images.flatten(init_img, opts.img2img_background_color)
41 |
42 | if upscaler.name != "None":
43 | img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
44 | else:
45 | img = init_img
46 |
47 | devices.torch_gc()
48 |
49 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
50 |
51 | batch_size = p.batch_size
52 | upscale_count = p.n_iter
53 | p.n_iter = 1
54 | p.do_not_save_grid = True
55 | p.do_not_save_samples = True
56 |
57 | work = []
58 |
59 | for y, h, row in grid.tiles:
60 | for tiledata in row:
61 | work.append(tiledata[2])
62 |
63 | batch_count = math.ceil(len(work) / batch_size)
64 | state.job_count = batch_count * upscale_count
65 |
66 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
67 |
68 | result_images = []
69 | for n in range(upscale_count):
70 | start_seed = seed + n
71 | p.seed = start_seed
72 |
73 | work_results = []
74 | for i in range(batch_count):
75 | p.batch_size = batch_size
76 | p.init_images = work[i * batch_size:(i + 1) * batch_size]
77 |
78 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
79 | processed = processing.process_images(p)
80 |
81 | if initial_info is None:
82 | initial_info = processed.info
83 |
84 | p.seed = processed.seed + 1
85 | work_results += processed.images
86 |
87 | image_index = 0
88 | for y, h, row in grid.tiles:
89 | for tiledata in row:
90 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
91 | image_index += 1
92 |
93 | combined_image = images.combine_grid(grid)
94 | result_images.append(combined_image)
95 |
96 | if opts.samples_save:
97 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
98 |
99 | processed = Processed(p, result_images, seed, initial_info)
100 |
101 | return processed
102 |
--------------------------------------------------------------------------------
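For a feel of how many tiles a run processes: the exact split lives in images.split_grid, but to a first approximation each axis needs ceil((size - overlap) / (tile - overlap)) tiles. Numbers below are illustrative only:

    import math

    def approx_tiles(size, tile, overlap):
        return max(1, math.ceil((size - overlap) / (tile - overlap)))

    # a 1024x1024 upscaled image, 512x512 tiles, 64 px overlap:
    cols = approx_tiles(1024, 512, 64)
    rows = approx_tiles(1024, 512, 64)
    print(cols, rows, cols * rows)  # 3 3 9 tiles per upscale
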
/modules/postprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 | from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
6 | from modules.shared import opts
7 |
8 |
9 | def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
10 | devices.torch_gc()
11 |
12 | shared.state.begin()
13 | shared.state.job = 'extras'
14 |
15 | image_data = []
16 | image_names = []
17 | outputs = []
18 |
19 | if extras_mode == 1:
20 | for img in image_folder:
21 | image = Image.open(img)
22 | image_data.append(image)
23 | image_names.append(os.path.splitext(img.orig_name)[0])
24 | elif extras_mode == 2:
25 | assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
26 | assert input_dir, 'input directory not selected'
27 |
28 | image_list = shared.listfiles(input_dir)
29 | for filename in image_list:
30 | try:
31 | image = Image.open(filename)
32 | except Exception:
33 | continue
34 | image_data.append(image)
35 | image_names.append(filename)
36 | else:
37 | assert image, 'image not selected'
38 |
39 | image_data.append(image)
40 | image_names.append(None)
41 |
42 | if extras_mode == 2 and output_dir != '':
43 | outpath = output_dir
44 | else:
45 | outpath = opts.outdir_samples or opts.outdir_extras_samples
46 |
47 | infotext = ''
48 |
49 | for image, name in zip(image_data, image_names):
50 | shared.state.textinfo = name
51 |
52 | existing_pnginfo = image.info or {}
53 |
54 | pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
55 |
56 | scripts.scripts_postproc.run(pp, args)
57 |
58 | if opts.use_original_name_batch and name is not None:
59 | basename = os.path.splitext(os.path.basename(name))[0]
60 | else:
61 | basename = ''
62 |
63 | infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
64 |
65 | if opts.enable_pnginfo:
66 | pp.image.info = existing_pnginfo
67 | pp.image.info["postprocessing"] = infotext
68 |
69 | if save_output:
70 | images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
71 |
72 | if extras_mode != 2 or show_extras_results:
73 | outputs.append(pp.image)
74 |
75 | devices.torch_gc()
76 |
77 | return outputs, ui_common.plaintext_to_html(infotext), ''
78 |
79 |
80 | def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
81 | """old handler for API"""
82 |
83 | args = scripts.scripts_postproc.create_args_for_run({
84 | "Upscale": {
85 | "upscale_mode": resize_mode,
86 | "upscale_by": upscaling_resize,
87 | "upscale_to_width": upscaling_resize_w,
88 | "upscale_to_height": upscaling_resize_h,
89 | "upscale_crop": upscaling_crop,
90 | "upscaler_1_name": extras_upscaler_1,
91 | "upscaler_2_name": extras_upscaler_2,
92 | "upscaler_2_visibility": extras_upscaler_2_visibility,
93 | },
94 | "GFPGAN": {
95 | "gfpgan_visibility": gfpgan_visibility,
96 | },
97 | "CodeFormer": {
98 | "codeformer_visibility": codeformer_visibility,
99 | "codeformer_weight": codeformer_weight,
100 | },
101 | })
102 |
103 | return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
104 |
--------------------------------------------------------------------------------
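The infotext line assembled in run_postprocessing keeps keys whose value equals the key bare and renders everything else as "key: value". A standalone example with a simplified stand-in for generation_parameters_copypaste.quote:

    def quote(v):
        # simplified stand-in; the real helper also handles escaping
        v = str(v)
        return f'"{v}"' if ("," in v or ":" in v) else v

    info = {"GFPGAN": "GFPGAN", "Postprocess upscale by": 2, "Postprocess upscaler": "R-ESRGAN 4x+"}
    print(", ".join(k if k == v else f"{k}: {quote(v)}" for k, v in info.items() if v is not None))
    # GFPGAN, Postprocess upscale by: 2, Postprocess upscaler: R-ESRGAN 4x+
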
/modules/sd_models_config.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 |
4 | import torch
5 |
6 | from modules import shared, paths, sd_disable_initialization
7 |
8 | sd_configs_path = shared.sd_configs_path
9 | sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
10 |
11 |
12 | config_default = shared.sd_default_config
13 | config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
14 | config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
15 | config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
16 | config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
17 | config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
18 | config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
19 | config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
20 |
21 |
22 | def is_using_v_parameterization_for_sd2(state_dict):
23 | """
24 | Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
25 | """
26 |
27 | import ldm.modules.diffusionmodules.openaimodel
28 | from modules import devices
29 |
30 | device = devices.cpu
31 |
32 | with sd_disable_initialization.DisableInitialization():
33 | unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
34 | use_checkpoint=True,
35 | use_fp16=False,
36 | image_size=32,
37 | in_channels=4,
38 | out_channels=4,
39 | model_channels=320,
40 | attention_resolutions=[4, 2, 1],
41 | num_res_blocks=2,
42 | channel_mult=[1, 2, 4, 4],
43 | num_head_channels=64,
44 | use_spatial_transformer=True,
45 | use_linear_in_transformer=True,
46 | transformer_depth=1,
47 | context_dim=1024,
48 | legacy=False
49 | )
50 | unet.eval()
51 |
52 | with torch.no_grad():
53 | unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
54 | unet.load_state_dict(unet_sd, strict=True)
55 | unet.to(device=device, dtype=torch.float)
56 |
57 | test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
58 | x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
59 |
60 | out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
61 |
62 | return out < -1
63 |
64 |
65 | def guess_model_config_from_state_dict(sd, filename):
66 | sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
67 | diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
68 |
69 | if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
70 | return config_depth_model
71 |
72 | if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
73 | if diffusion_model_input.shape[1] == 9:
74 | return config_sd2_inpainting
75 | elif is_using_v_parameterization_for_sd2(sd):
76 | return config_sd2v
77 | else:
78 | return config_sd2
79 |
80 | if diffusion_model_input is not None:
81 | if diffusion_model_input.shape[1] == 9:
82 | return config_inpainting
83 | if diffusion_model_input.shape[1] == 8:
84 | return config_instruct_pix2pix
85 |
86 | if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
87 | return config_alt_diffusion
88 |
89 | return config_default
90 |
91 |
92 | def find_checkpoint_config(state_dict, info):
93 | if info is None:
94 | return guess_model_config_from_state_dict(state_dict, "")
95 |
96 | config = find_checkpoint_config_near_filename(info)
97 | if config is not None:
98 | return config
99 |
100 | return guess_model_config_from_state_dict(state_dict, info.filename)
101 |
102 |
103 | def find_checkpoint_config_near_filename(info):
104 | if info is None:
105 | return None
106 |
107 | config = os.path.splitext(info.filename)[0] + ".yaml"
108 | if os.path.exists(config):
109 | return config
110 |
111 | return None
112 |
113 |
--------------------------------------------------------------------------------
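The config guesser keys off tensor shapes in the state dict: an inpainting UNet takes 9 input channels, instruct-pix2pix takes 8, a plain model takes 4. The dummy tensor below only illustrates the check:

    import torch

    sd = {"model.diffusion_model.input_blocks.0.0.weight": torch.zeros(320, 9, 3, 3)}

    in_channels = sd["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
    if in_channels == 9:
        print("inpainting config")
    elif in_channels == 8:
        print("instruct-pix2pix config")
    else:
        print("default config")
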
/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js:
--------------------------------------------------------------------------------
1 | // Stable Diffusion WebUI - Bracket checker
2 | // Version 1.0
3 | // By Hingashi no Florin/Bwin4L
4 | // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
5 | // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
6 |
7 | function checkBrackets(evt, textArea, counterElt) {
8 | errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
9 | errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
10 | errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
11 |
12 | openBracketRegExp = /\(/g;
13 | closeBracketRegExp = /\)/g;
14 |
15 | openSquareBracketRegExp = /\[/g;
16 | closeSquareBracketRegExp = /\]/g;
17 |
18 | openCurlyBracketRegExp = /\{/g;
19 | closeCurlyBracketRegExp = /\}/g;
20 |
21 | totalOpenBracketMatches = 0;
22 | totalCloseBracketMatches = 0;
23 | totalOpenSquareBracketMatches = 0;
24 | totalCloseSquareBracketMatches = 0;
25 | totalOpenCurlyBracketMatches = 0;
26 | totalCloseCurlyBracketMatches = 0;
27 |
28 | openBracketMatches = textArea.value.match(openBracketRegExp);
29 | if(openBracketMatches) {
30 | totalOpenBracketMatches = openBracketMatches.length;
31 | }
32 |
33 | closeBracketMatches = textArea.value.match(closeBracketRegExp);
34 | if(closeBracketMatches) {
35 | totalCloseBracketMatches = closeBracketMatches.length;
36 | }
37 |
38 | openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
39 | if(openSquareBracketMatches) {
40 | totalOpenSquareBracketMatches = openSquareBracketMatches.length;
41 | }
42 |
43 | closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
44 | if(closeSquareBracketMatches) {
45 | totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
46 | }
47 |
48 | openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
49 | if(openCurlyBracketMatches) {
50 | totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
51 | }
52 |
53 | closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
54 | if(closeCurlyBracketMatches) {
55 | totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
56 | }
57 |
58 | if(totalOpenBracketMatches != totalCloseBracketMatches) {
59 | if(!counterElt.title.includes(errorStringParen)) {
60 | counterElt.title += errorStringParen;
61 | }
62 | } else {
63 | counterElt.title = counterElt.title.replace(errorStringParen, '');
64 | }
65 |
66 | if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
67 | if(!counterElt.title.includes(errorStringSquare)) {
68 | counterElt.title += errorStringSquare;
69 | }
70 | } else {
71 | counterElt.title = counterElt.title.replace(errorStringSquare, '');
72 | }
73 |
74 | if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
75 | if(!counterElt.title.includes(errorStringCurly)) {
76 | counterElt.title += errorStringCurly;
77 | }
78 | } else {
79 | counterElt.title = counterElt.title.replace(errorStringCurly, '');
80 | }
81 |
82 | if(counterElt.title != '') {
83 | counterElt.classList.add('error');
84 | } else {
85 | counterElt.classList.remove('error');
86 | }
87 | }
88 |
89 | function setupBracketChecking(id_prompt, id_counter){
90 | var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
91 | var counter = gradioApp().getElementById(id_counter)
92 | textarea.addEventListener("input", function(evt){
93 | checkBrackets(evt, textarea, counter)
94 | });
95 | }
96 |
97 | var shadowRootLoaded = setInterval(function() {
98 | var shadowRoot = document.querySelector('gradio-app').shadowRoot;
99 | if(! shadowRoot) return false;
100 |
101 | var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
102 | if(shadowTextArea.length < 1) return false;
103 |
104 | clearInterval(shadowRootLoaded);
105 |
106 | setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
107 | setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
108 | setupBracketChecking('img2img_prompt', 'img2img_token_counter')
109 | setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
110 | }, 1000);
111 |
--------------------------------------------------------------------------------
/modules/lowvram.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules import devices
3 |
4 | module_in_gpu = None
5 | cpu = torch.device("cpu")
6 |
7 |
8 | def send_everything_to_cpu():
9 | global module_in_gpu
10 |
11 | if module_in_gpu is not None:
12 | module_in_gpu.to(cpu)
13 |
14 | module_in_gpu = None
15 |
16 |
17 | def setup_for_low_vram(sd_model, use_medvram):
18 | parents = {}
19 |
20 | def send_me_to_gpu(module, _):
21 | """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
22 | we add this as forward_pre_hook to a lot of modules and this way all but one of them will
23 | be in CPU
24 | """
25 | global module_in_gpu
26 |
27 | module = parents.get(module, module)
28 |
29 | if module_in_gpu == module:
30 | return
31 |
32 | if module_in_gpu is not None:
33 | module_in_gpu.to(cpu)
34 |
35 | module.to(devices.device)
36 | module_in_gpu = module
37 |
38 | # see below for register_forward_pre_hook;
39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
40 | # useless here, and we just replace those methods
41 |
42 | first_stage_model = sd_model.first_stage_model
43 | first_stage_model_encode = sd_model.first_stage_model.encode
44 | first_stage_model_decode = sd_model.first_stage_model.decode
45 |
46 | def first_stage_model_encode_wrap(x):
47 | send_me_to_gpu(first_stage_model, None)
48 | return first_stage_model_encode(x)
49 |
50 | def first_stage_model_decode_wrap(z):
51 | send_me_to_gpu(first_stage_model, None)
52 | return first_stage_model_decode(z)
53 |
54 |     # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in the model field
55 | if hasattr(sd_model.cond_stage_model, 'model'):
56 | sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
57 |
58 | # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
59 | # send the model to GPU. Then put modules back. the modules will be in CPU.
60 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
61 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
62 | sd_model.to(devices.device)
63 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
64 |
65 |     # register hooks for the first three models
66 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
67 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
68 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap
69 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap
70 | if sd_model.depth_model:
71 | sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
72 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
73 |
74 | if hasattr(sd_model.cond_stage_model, 'model'):
75 | sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
76 | del sd_model.cond_stage_model.transformer
77 |
78 | if use_medvram:
79 | sd_model.model.register_forward_pre_hook(send_me_to_gpu)
80 | else:
81 | diff_model = sd_model.model.diffusion_model
82 |
83 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
84 | # so that only one of them is in GPU at a time
85 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
86 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
87 | sd_model.model.to(devices.device)
88 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
89 |
90 | # install hooks for bits of third model
91 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
92 | for block in diff_model.input_blocks:
93 | block.register_forward_pre_hook(send_me_to_gpu)
94 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
95 | for block in diff_model.output_blocks:
96 | block.register_forward_pre_hook(send_me_to_gpu)
97 |
--------------------------------------------------------------------------------
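The core of the low-VRAM scheme is a forward_pre_hook that moves each module onto the compute device just before it runs and evicts whatever ran previously. A CPU-only standalone sketch (the device is illustrative):

    import torch
    import torch.nn as nn

    module_on_device = None
    device = torch.device("cpu")  # stands in for the GPU

    def send_me_to_gpu(module, _inputs):
        global module_on_device
        if module_on_device is module:
            return
        if module_on_device is not None:
            module_on_device.to("cpu")
        module.to(device)
        module_on_device = module

    blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
    for block in blocks:
        block.register_forward_pre_hook(send_me_to_gpu)

    x = torch.randn(1, 8)
    for block in blocks:
        x = block(x)  # each block is moved in just before it runs
    print(x.shape)  # torch.Size([1, 8])
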
/modules/upscaler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 |
4 | import PIL
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 |
9 | import modules.shared
10 | from modules import modelloader, shared
11 |
12 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
13 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
14 |
15 |
16 | class Upscaler:
17 | name = None
18 | model_path = None
19 | model_name = None
20 | model_url = None
21 | enable = True
22 | filter = None
23 | model = None
24 | user_path = None
25 |     scalers = []
26 | tile = True
27 |
28 | def __init__(self, create_dirs=False):
29 | self.mod_pad_h = None
30 | self.tile_size = modules.shared.opts.ESRGAN_tile
31 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
32 | self.device = modules.shared.device
33 | self.img = None
34 | self.output = None
35 | self.scale = 1
36 | self.half = not modules.shared.cmd_opts.no_half
37 | self.pre_pad = 0
38 | self.mod_scale = None
39 |
40 | if self.model_path is None and self.name:
41 | self.model_path = os.path.join(shared.models_path, self.name)
42 | if self.model_path and create_dirs:
43 | os.makedirs(self.model_path, exist_ok=True)
44 |
45 | try:
46 | import cv2
47 | self.can_tile = True
48 |         except Exception:
49 |             pass  # cv2 is optional; tiling support simply stays off without it
50 |
51 | @abstractmethod
52 | def do_upscale(self, img: PIL.Image, selected_model: str):
53 | return img
54 |
55 | def upscale(self, img: PIL.Image, scale, selected_model: str = None):
56 | self.scale = scale
57 | dest_w = int(img.width * scale)
58 | dest_h = int(img.height * scale)
59 |
60 | for i in range(3):
61 | shape = (img.width, img.height)
62 |
63 | img = self.do_upscale(img, selected_model)
64 |
65 | if shape == (img.width, img.height):
66 | break
67 |
68 | if img.width >= dest_w and img.height >= dest_h:
69 | break
70 |
71 | if img.width != dest_w or img.height != dest_h:
72 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
73 |
74 | return img
75 |
76 | @abstractmethod
77 | def load_model(self, path: str):
78 | pass
79 |
80 | def find_models(self, ext_filter=None) -> list:
81 | return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
82 |
83 | def update_status(self, prompt):
84 | print(f"\nextras: {prompt}", file=shared.progress_print_out)
85 |
86 |
87 | class UpscalerData:
88 | name = None
89 | data_path = None
90 | scale: int = 4
91 | scaler: Upscaler = None
92 |     model = None
93 |
94 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
95 | self.name = name
96 | self.data_path = path
97 | self.local_data_path = path
98 | self.scaler = upscaler
99 | self.scale = scale
100 | self.model = model
101 |
102 |
103 | class UpscalerNone(Upscaler):
104 | name = "None"
105 | scalers = []
106 |
107 | def load_model(self, path):
108 | pass
109 |
110 | def do_upscale(self, img, selected_model=None):
111 | return img
112 |
113 | def __init__(self, dirname=None):
114 | super().__init__(False)
115 | self.scalers = [UpscalerData("None", None, self)]
116 |
117 |
118 | class UpscalerLanczos(Upscaler):
119 | scalers = []
120 |
121 | def do_upscale(self, img, selected_model=None):
122 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
123 |
124 | def load_model(self, _):
125 | pass
126 |
127 | def __init__(self, dirname=None):
128 | super().__init__(False)
129 | self.name = "Lanczos"
130 | self.scalers = [UpscalerData("Lanczos", None, self)]
131 |
132 |
133 | class UpscalerNearest(Upscaler):
134 | scalers = []
135 |
136 | def do_upscale(self, img, selected_model=None):
137 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
138 |
139 | def load_model(self, _):
140 | pass
141 |
142 | def __init__(self, dirname=None):
143 | super().__init__(False)
144 | self.name = "Nearest"
145 | self.scalers = [UpscalerData("Nearest", None, self)]
146 |
--------------------------------------------------------------------------------
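The retry-then-resize behaviour of Upscaler.upscale can be checked in isolation; in the sketch below do_upscale is a simple 2x Lanczos stand-in rather than a real model:

    from PIL import Image

    LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS

    def upscale(img, scale, do_upscale):
        dest_w, dest_h = int(img.width * scale), int(img.height * scale)
        for _ in range(3):
            shape = (img.width, img.height)
            img = do_upscale(img)
            if shape == (img.width, img.height):  # upscaler made no progress
                break
            if img.width >= dest_w and img.height >= dest_h:
                break
        if (img.width, img.height) != (dest_w, dest_h):
            img = img.resize((dest_w, dest_h), resample=LANCZOS)
        return img

    def double(im):
        return im.resize((im.width * 2, im.height * 2), resample=LANCZOS)

    print(upscale(Image.new("RGB", (64, 64)), 3, double).size)  # (192, 192)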