├── test
│   ├── __init__.py
│   ├── test_files
│   │   ├── mask_basic.png
│   │   └── img2img_basic.png
│   ├── server_poll.py
│   ├── extras_test.py
│   ├── img2img_test.py
│   ├── utils_test.py
│   └── txt2img_test.py
├── models
│   ├── VAE
│   │   └── Put VAE here.txt
│   ├── Stable-diffusion
│   │   └── Put Stable Diffusion checkpoints here.txt
│   └── deepbooru
│       └── Put your deepbooru release project folder here.txt
├── extensions
│   └── put extensions here.txt
├── localizations
│   └── Put localization files here.txt
├── textual_inversion_templates
│   ├── none.txt
│   ├── style.txt
│   ├── subject.txt
│   ├── hypernetwork.txt
│   ├── style_filewords.txt
│   └── subject_filewords.txt
├── embeddings
│   └── Place Textual Inversion embeddings here.txt
├── screenshot.png
├── txt2img_Screenshot.png
├── webui-user.bat
├── .pylintrc
├── modules
│   ├── textual_inversion
│   │   ├── test_embedding.png
│   │   ├── ui.py
│   │   ├── learn_schedule.py
│   │   └── dataset.py
│   ├── errors.py
│   ├── face_restoration.py
│   ├── ngrok.py
│   ├── artists.py
│   ├── localization.py
│   ├── safety.py
│   ├── paths.py
│   ├── ldsr_model.py
│   ├── hypernetworks
│   │   └── ui.py
│   ├── txt2img.py
│   ├── extensions.py
│   ├── memmon.py
│   ├── devices.py
│   ├── scunet_model.py
│   ├── masking.py
│   ├── lowvram.py
│   ├── styles.py
│   ├── gfpgan_model.py
│   ├── upscaler.py
│   ├── realesrgan_model.py
│   ├── safe.py
│   ├── img2img.py
│   ├── swinir_model.py
│   ├── codeformer_model.py
│   ├── modelloader.py
│   └── deepbooru.py
├── javascript
│   ├── textualInversion.js
│   ├── imageParams.js
│   ├── extensions.js
│   ├── imageMaskFix.js
│   ├── notification.js
│   ├── dragdrop.js
│   ├── edit-attention.js
│   ├── aspectRatioOverlay.js
│   ├── localization.js
│   ├── contextMenus.js
│   ├── progressbar.js
│   └── ui.js
├── environment-wsl2.yaml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature_request.yml
│   │   └── bug_report.yml
│   ├── workflows
│   │   └── on_pull_request.yaml
│   └── PULL_REQUEST_TEMPLATE
│       └── pull_request_template.md
├── requirements.txt
├── .gitignore
├── requirements_versions.txt
├── CODEOWNERS
├── scripts
│   ├── custom_code.py
│   ├── loopback.py
│   ├── prompt_matrix.py
│   ├── sd_upscale.py
│   ├── prompts_from_file.py
│   └── poor_mans_outpainting.py
├── webui-user.sh
├── webui.bat
├── README.md
├── repositories
│   └── stable-diffusion-taiyi
│       └── configs
│           └── stable-diffusion
│               ├── v1-inference.yaml
│               └── v1-inference-en.yaml
├── script.js
├── webui.sh
└── webui.py
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/VAE/Put VAE here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/extensions/put extensions here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/localizations/Put localization files here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/textual_inversion_templates/none.txt:
--------------------------------------------------------------------------------
1 | picture
2 |
--------------------------------------------------------------------------------
/embeddings/Place Textual Inversion embeddings here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/deepbooru/Put your deepbooru release project folder here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/screenshot.png
--------------------------------------------------------------------------------
/txt2img_Screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/txt2img_Screenshot.png
--------------------------------------------------------------------------------
/test/test_files/mask_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/test/test_files/mask_basic.png
--------------------------------------------------------------------------------
/webui-user.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set PYTHON=
4 | set GIT=
5 | set VENV_DIR=
6 | set COMMANDLINE_ARGS=
7 |
8 | call webui.bat
9 |
--------------------------------------------------------------------------------
/test/test_files/img2img_basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/test/test_files/img2img_basic.png
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
2 | [MESSAGES CONTROL]
3 | disable=C,R,W,E,I
4 |
--------------------------------------------------------------------------------
/modules/textual_inversion/test_embedding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-CCNL/stable-diffusion-webui/HEAD/modules/textual_inversion/test_embedding.png
--------------------------------------------------------------------------------
/javascript/textualInversion.js:
--------------------------------------------------------------------------------
1 |
2 |
3 | function start_training_textual_inversion(){
4 | requestProgress('ti')
5 | gradioApp().querySelector('#ti_error').innerHTML=''
6 |
7 | return args_to_array(arguments)
8 | }
9 |
--------------------------------------------------------------------------------
/environment-wsl2.yaml:
--------------------------------------------------------------------------------
1 | name: automatic
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.10
7 | - pip=22.2.2
8 | - cudatoolkit=11.3
9 | - pytorch=1.12.1
10 | - torchvision=0.13.1
11 | - numpy=1.23.1
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: WebUI Community Support
4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
5 | about: Please ask and answer questions here.
6 |
--------------------------------------------------------------------------------
/modules/errors.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 |
4 |
5 | def run(code, task):
6 | try:
7 | code()
8 | except Exception as e:
9 | print(f"{task}: {type(e).__name__}", file=sys.stderr)
10 | print(traceback.format_exc(), file=sys.stderr)
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | basicsr
2 | diffusers
3 | fairscale==0.4.4
4 | fonts
5 | font-roboto
6 | gfpgan
7 | gradio==3.8
8 | invisible-watermark
9 | numpy
10 | omegaconf
11 | opencv-python
12 | requests
13 | piexif
14 | Pillow
15 | pytorch_lightning==1.7.7
16 | realesrgan
17 | scikit-image>=0.19
18 | timm==0.4.12
19 | transformers==4.19.2
20 | torch
21 | einops
22 | jsonmerge
23 | clean-fid
24 | resize-right
25 | torchdiffeq
26 | kornia
27 | lark
28 | inflection
29 | GitPython
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.ckpt
3 | *.pth
4 | /ESRGAN/*
5 | /SwinIR/*
6 | /repositories
7 | /venv
8 | /tmp
9 | /model.ckpt
10 | /models/**/*
11 | /GFPGANv1.3.pth
12 | /gfpgan/weights/*.pth
13 | /ui-config.json
14 | /outputs
15 | /config.json
16 | /log
17 | /webui.settings.bat
18 | /embeddings
19 | /styles.csv
20 | /params.txt
21 | /styles.csv.bak
22 | /webui-user.bat
23 | /webui-user.sh
24 | /interrogate
25 | /user.css
26 | /.idea
27 | notification.mp3
28 | /SwinIR
29 | /textual_inversion
30 | .vscode
31 | /extensions
32 | /test/stdout.txt
33 | /test/stderr.txt
34 |
--------------------------------------------------------------------------------
/requirements_versions.txt:
--------------------------------------------------------------------------------
1 | transformers==4.19.2
2 | diffusers==0.3.0
3 | basicsr==1.4.2
4 | gfpgan==1.3.8
5 | gradio==3.8
6 | numpy==1.23.3
7 | Pillow==9.2.0
8 | realesrgan==0.3.0
9 | torch
10 | omegaconf==2.2.3
11 | pytorch_lightning==1.7.6
12 | scikit-image==0.19.2
13 | fonts
14 | font-roboto
15 | timm==0.6.7
16 | fairscale==0.4.9
17 | piexif==1.1.3
18 | einops==0.4.1
19 | jsonmerge==1.8.0
20 | clean-fid==0.1.29
21 | resize-right==0.0.2
22 | torchdiffeq==0.2.3
23 | kornia==0.6.7
24 | lark==1.1.2
25 | inflection==0.5.1
26 | GitPython==3.1.27
27 |
--------------------------------------------------------------------------------
/modules/face_restoration.py:
--------------------------------------------------------------------------------
1 | from modules import shared
2 |
3 |
4 | class FaceRestoration:
5 | def name(self):
6 | return "None"
7 |
8 | def restore(self, np_image):
9 | return np_image
10 |
11 |
12 | def restore_faces(np_image):
13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
14 | if len(face_restorers) == 0:
15 | return np_image
16 |
17 | face_restorer = face_restorers[0]
18 |
19 | return face_restorer.restore(np_image)
20 |
--------------------------------------------------------------------------------
/test/server_poll.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 | import time
4 |
5 |
6 | def run_tests():
7 | timeout_threshold = 240
8 | start_time = time.time()
9 | while time.time()-start_time < timeout_threshold:
10 | try:
11 | requests.head("http://localhost:7860/")
12 | break
13 | except requests.exceptions.ConnectionError:
14 | pass
15 | if time.time()-start_time < timeout_threshold:
16 | suite = unittest.TestLoader().discover('', pattern='*_test.py')
17 | result = unittest.TextTestRunner(verbosity=2).run(suite)
18 | else:
19 | print("Launch unsuccessful")
20 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style.txt:
--------------------------------------------------------------------------------
1 | a painting, art by [name]
2 | a rendering, art by [name]
3 | a cropped painting, art by [name]
4 | the painting, art by [name]
5 | a clean painting, art by [name]
6 | a dirty painting, art by [name]
7 | a dark painting, art by [name]
8 | a picture, art by [name]
9 | a cool painting, art by [name]
10 | a close-up painting, art by [name]
11 | a bright painting, art by [name]
12 | a cropped painting, art by [name]
13 | a good painting, art by [name]
14 | a close-up painting, art by [name]
15 | a rendition, art by [name]
16 | a nice painting, art by [name]
17 | a small painting, art by [name]
18 | a weird painting, art by [name]
19 | a large painting, art by [name]
20 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @AUTOMATIC1111
2 |
3 | # if you were managing a localization and were removed from this file, this is because
4 | # the intended way to do localizations now is via extensions. See:
5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
6 | # Make a repo with your localization and since you are still listed as a collaborator
7 | # you can add it to the wiki page yourself. This change is because some people complained
8 | # the git commit log is cluttered with things unrelated to almost everyone and
9 | # because I believe this is the best overall for the project to handle localizations almost
10 | # entirely without my oversight.
11 |
12 |
13 |
--------------------------------------------------------------------------------
/modules/ngrok.py:
--------------------------------------------------------------------------------
1 | from pyngrok import ngrok, conf, exception
2 |
3 |
4 | def connect(token, port, region):
5 | if token == None:
6 | token = 'None'
7 | config = conf.PyngrokConfig(
8 | auth_token=token, region=region
9 | )
10 | try:
11 | public_url = ngrok.connect(port, pyngrok_config=config).public_url
12 | except exception.PyngrokNgrokError:
13 | print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
14 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
15 | else:
16 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
17 | 'You can use this link after the launch is complete.')
18 |
--------------------------------------------------------------------------------
/javascript/imageParams.js:
--------------------------------------------------------------------------------
1 | window.onload = (function(){
2 | window.addEventListener('drop', e => {
3 | const target = e.composedPath()[0];
4 | const idx = selected_gallery_index();
5 | if (target.placeholder.indexOf("Prompt") == -1) return;
6 |
7 | let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
8 |
9 | e.stopPropagation();
10 | e.preventDefault();
11 | const imgParent = gradioApp().getElementById(prompt_target);
12 | const files = e.dataTransfer.files;
13 | const fileInput = imgParent.querySelector('input[type="file"]');
14 | if ( fileInput ) {
15 | fileInput.files = files;
16 | fileInput.dispatchEvent(new Event('change'));
17 | }
18 | });
19 | });
20 |
--------------------------------------------------------------------------------
/modules/artists.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import csv
3 | from collections import namedtuple
4 |
5 | Artist = namedtuple("Artist", ['name', 'weight', 'category'])
6 |
7 |
8 | class ArtistsDatabase:
9 | def __init__(self, filename):
10 | self.cats = set()
11 | self.artists = []
12 |
13 | if not os.path.exists(filename):
14 | return
15 |
16 | with open(filename, "r", newline='', encoding="utf8") as file:
17 | reader = csv.DictReader(file)
18 |
19 | for row in reader:
20 | artist = Artist(row["artist"], float(row["score"]), row["category"])
21 | self.artists.append(artist)
22 | self.cats.add(artist.category)
23 |
24 | def categories(self):
25 | return sorted(self.cats)
26 |
--------------------------------------------------------------------------------
/textual_inversion_templates/subject.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name]
2 | a rendering of a [name]
3 | a cropped photo of the [name]
4 | the photo of a [name]
5 | a photo of a clean [name]
6 | a photo of a dirty [name]
7 | a dark photo of the [name]
8 | a photo of my [name]
9 | a photo of the cool [name]
10 | a close-up photo of a [name]
11 | a bright photo of the [name]
12 | a cropped photo of a [name]
13 | a photo of the [name]
14 | a good photo of the [name]
15 | a photo of one [name]
16 | a close-up photo of the [name]
17 | a rendition of the [name]
18 | a photo of the clean [name]
19 | a rendition of a [name]
20 | a photo of a nice [name]
21 | a good photo of a [name]
22 | a photo of the nice [name]
23 | a photo of the small [name]
24 | a photo of the weird [name]
25 | a photo of the large [name]
26 | a photo of a cool [name]
27 | a photo of a small [name]
28 |
--------------------------------------------------------------------------------
/test/extras_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 |
4 | class TestExtrasWorking(unittest.TestCase):
5 | def setUp(self):
6 | self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
7 | self.simple_extras = {
8 | "resize_mode": 0,
9 | "show_extras_results": True,
10 | "gfpgan_visibility": 0,
11 | "codeformer_visibility": 0,
12 | "codeformer_weight": 0,
13 | "upscaling_resize": 2,
14 | "upscaling_resize_w": 512,
15 | "upscaling_resize_h": 512,
16 | "upscaling_crop": True,
17 | "upscaler_1": "None",
18 | "upscaler_2": "None",
19 | "extras_upscaler_2_visibility": 0,
20 | "image": ""
21 | }
22 |
23 |
24 | class TestExtrasCorrectness(unittest.TestCase):
25 | pass
26 |
27 |
28 | if __name__ == "__main__":
29 | unittest.main()
30 |
--------------------------------------------------------------------------------
/textual_inversion_templates/hypernetwork.txt:
--------------------------------------------------------------------------------
1 | a photo of a [filewords]
2 | a rendering of a [filewords]
3 | a cropped photo of the [filewords]
4 | the photo of a [filewords]
5 | a photo of a clean [filewords]
6 | a photo of a dirty [filewords]
7 | a dark photo of the [filewords]
8 | a photo of my [filewords]
9 | a photo of the cool [filewords]
10 | a close-up photo of a [filewords]
11 | a bright photo of the [filewords]
12 | a cropped photo of a [filewords]
13 | a photo of the [filewords]
14 | a good photo of the [filewords]
15 | a photo of one [filewords]
16 | a close-up photo of the [filewords]
17 | a rendition of the [filewords]
18 | a photo of the clean [filewords]
19 | a rendition of a [filewords]
20 | a photo of a nice [filewords]
21 | a good photo of a [filewords]
22 | a photo of the nice [filewords]
23 | a photo of the small [filewords]
24 | a photo of the weird [filewords]
25 | a photo of the large [filewords]
26 | a photo of a cool [filewords]
27 | a photo of a small [filewords]
28 |
--------------------------------------------------------------------------------
/textual_inversion_templates/style_filewords.txt:
--------------------------------------------------------------------------------
1 | a painting of [filewords], art by [name]
2 | a rendering of [filewords], art by [name]
3 | a cropped painting of [filewords], art by [name]
4 | the painting of [filewords], art by [name]
5 | a clean painting of [filewords], art by [name]
6 | a dirty painting of [filewords], art by [name]
7 | a dark painting of [filewords], art by [name]
8 | a picture of [filewords], art by [name]
9 | a cool painting of [filewords], art by [name]
10 | a close-up painting of [filewords], art by [name]
11 | a bright painting of [filewords], art by [name]
12 | a cropped painting of [filewords], art by [name]
13 | a good painting of [filewords], art by [name]
14 | a close-up painting of [filewords], art by [name]
15 | a rendition of [filewords], art by [name]
16 | a nice painting of [filewords], art by [name]
17 | a small painting of [filewords], art by [name]
18 | a weird painting of [filewords], art by [name]
19 | a large painting of [filewords], art by [name]
20 |
--------------------------------------------------------------------------------
/javascript/extensions.js:
--------------------------------------------------------------------------------
1 |
2 | function extensions_apply(_, _){
3 | disable = []
4 | update = []
5 | gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
6 | if(x.name.startsWith("enable_") && ! x.checked)
7 | disable.push(x.name.substr(7))
8 |
9 | if(x.name.startsWith("update_") && x.checked)
10 | update.push(x.name.substr(7))
11 | })
12 |
13 | restart_reload()
14 |
15 | return [JSON.stringify(disable), JSON.stringify(update)]
16 | }
17 |
18 | function extensions_check(){
19 | gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
20 | x.innerHTML = "Loading..."
21 | })
22 |
23 | return []
24 | }
25 |
26 | function install_extension_from_index(button, url){
27 | button.disabled = "disabled"
28 | button.value = "Installing..."
29 |
30 | textarea = gradioApp().querySelector('#extension_to_install textarea')
31 | textarea.value = url
32 | textarea.dispatchEvent(new Event("input", { bubbles: true }))
33 |
34 | gradioApp().querySelector('#install_extension_button').click()
35 | }
36 |
--------------------------------------------------------------------------------
/modules/localization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import traceback
5 |
6 |
7 | localizations = {}
8 |
9 |
10 | def list_localizations(dirname):
11 | localizations.clear()
12 |
13 | for file in os.listdir(dirname):
14 | fn, ext = os.path.splitext(file)
15 | if ext.lower() != ".json":
16 | continue
17 |
18 | localizations[fn] = os.path.join(dirname, file)
19 |
20 | from modules import scripts
21 | for file in scripts.list_scripts("localizations", ".json"):
22 | fn, ext = os.path.splitext(file.filename)
23 | localizations[fn] = file.path
24 |
25 |
26 | def localization_js(current_localization_name):
27 | fn = localizations.get(current_localization_name, None)
28 | data = {}
29 | if fn is not None:
30 | try:
31 | with open(fn, "r", encoding="utf8") as file:
32 | data = json.load(file)
33 | except Exception:
34 | print(f"Error loading localization from {fn}:", file=sys.stderr)
35 | print(traceback.format_exc(), file=sys.stderr)
36 |
37 | return f"var localization = {json.dumps(data)}\n"
38 |
--------------------------------------------------------------------------------
/textual_inversion_templates/subject_filewords.txt:
--------------------------------------------------------------------------------
1 | a photo of a [name], [filewords]
2 | a rendering of a [name], [filewords]
3 | a cropped photo of the [name], [filewords]
4 | the photo of a [name], [filewords]
5 | a photo of a clean [name], [filewords]
6 | a photo of a dirty [name], [filewords]
7 | a dark photo of the [name], [filewords]
8 | a photo of my [name], [filewords]
9 | a photo of the cool [name], [filewords]
10 | a close-up photo of a [name], [filewords]
11 | a bright photo of the [name], [filewords]
12 | a cropped photo of a [name], [filewords]
13 | a photo of the [name], [filewords]
14 | a good photo of the [name], [filewords]
15 | a photo of one [name], [filewords]
16 | a close-up photo of the [name], [filewords]
17 | a rendition of the [name], [filewords]
18 | a photo of the clean [name], [filewords]
19 | a rendition of a [name], [filewords]
20 | a photo of a nice [name], [filewords]
21 | a good photo of a [name], [filewords]
22 | a photo of the nice [name], [filewords]
23 | a photo of the small [name], [filewords]
24 | a photo of the weird [name], [filewords]
25 | a photo of the large [name], [filewords]
26 | a photo of a cool [name], [filewords]
27 | a photo of a small [name], [filewords]
28 |
--------------------------------------------------------------------------------
/scripts/custom_code.py:
--------------------------------------------------------------------------------
1 | import modules.scripts as scripts
2 | import gradio as gr
3 |
4 | from modules.processing import Processed
5 | from modules.shared import opts, cmd_opts, state
6 |
7 | class Script(scripts.Script):
8 |
9 | def title(self):
10 | return "Custom code"
11 |
12 |
13 | def show(self, is_img2img):
14 | return cmd_opts.allow_code
15 |
16 | def ui(self, is_img2img):
17 | code = gr.Textbox(label="Python code", lines=1)
18 |
19 | return [code]
20 |
21 |
22 | def run(self, p, code):
23 | assert cmd_opts.allow_code, '--allow-code option must be enabled'
24 |
25 | display_result_data = [[], -1, ""]
26 |
27 | def display(imgs, s=display_result_data[1], i=display_result_data[2]):
28 | display_result_data[0] = imgs
29 | display_result_data[1] = s
30 | display_result_data[2] = i
31 |
32 | from types import ModuleType
33 | compiled = compile(code, '', 'exec')
34 | module = ModuleType("testmodule")
35 | module.__dict__.update(globals())
36 | module.p = p
37 | module.display = display
38 | exec(compiled, module.__dict__)
39 |
40 | return Processed(p, *display_result_data)
41 |
42 |
--------------------------------------------------------------------------------
/webui-user.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #########################################################
3 | # Uncomment and change the variables below to your need:#
4 | #########################################################
5 |
6 | # Install directory without trailing slash
7 | #install_dir="/home/$(whoami)"
8 |
9 | # Name of the subdirectory
10 | #clone_dir="stable-diffusion-webui"
11 |
12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
13 | export COMMANDLINE_ARGS=""
14 |
15 | # python3 executable
16 | #python_cmd="python3"
17 |
18 | # git executable
19 | #export GIT="git"
20 |
21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
22 | #venv_dir="venv"
23 |
24 | # script to launch to start the app
25 | #export LAUNCH_SCRIPT="launch.py"
26 |
27 | # install command for torch
28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
29 |
30 | # Requirements file to use for stable-diffusion-webui
31 | #export REQS_FILE="requirements_versions.txt"
32 |
33 | # Fixed git repos
34 | #export K_DIFFUSION_PACKAGE=""
35 | #export GFPGAN_PACKAGE=""
36 |
37 | # Fixed git commits
38 | #export STABLE_DIFFUSION_COMMIT_HASH=""
39 | #export TAMING_TRANSFORMERS_COMMIT_HASH=""
40 | #export CODEFORMER_COMMIT_HASH=""
41 | #export BLIP_COMMIT_HASH=""
42 |
43 | ###########################################
44 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Suggest an idea for this project
3 | title: "[Feature Request]: "
4 | labels: ["suggestion"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 | *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
18 | - type: textarea
19 | id: feature
20 | attributes:
21 | label: What would your feature do?
22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: workflow
27 | attributes:
28 | label: Proposed workflow
29 | description: Please provide us with step by step information on how you'd like the feature to be accessed and used
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: misc
38 | attributes:
39 | label: Additional information
40 | description: Add any other context or screenshots about the feature request here.
41 |
--------------------------------------------------------------------------------
/javascript/imageMaskFix.js:
--------------------------------------------------------------------------------
1 | /**
2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
3 | * @see https://github.com/gradio-app/gradio/issues/1721
4 | */
5 | window.addEventListener( 'resize', () => imageMaskResize());
6 | function imageMaskResize() {
7 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
8 | if ( ! canvases.length ) {
9 | canvases_fixed = false;
10 | window.removeEventListener( 'resize', imageMaskResize );
11 | return;
12 | }
13 |
14 | const wrapper = canvases[0].closest('.touch-none');
15 | const previewImage = wrapper.previousElementSibling;
16 |
17 | if ( ! previewImage.complete ) {
18 | previewImage.addEventListener( 'load', () => imageMaskResize());
19 | return;
20 | }
21 |
22 | const w = previewImage.width;
23 | const h = previewImage.height;
24 | const nw = previewImage.naturalWidth;
25 | const nh = previewImage.naturalHeight;
26 | const portrait = nh > nw;
27 | const factor = portrait;
28 |
29 | const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
30 | const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
31 |
32 | wrapper.style.width = `${wW}px`;
33 | wrapper.style.height = `${wH}px`;
34 | wrapper.style.left = `0px`;
35 | wrapper.style.top = `0px`;
36 |
37 | canvases.forEach( c => {
38 | c.style.width = c.style.height = '';
39 | c.style.maxWidth = '100%';
40 | c.style.maxHeight = '100%';
41 | c.style.objectFit = 'contain';
42 | });
43 | }
44 |
45 | onUiUpdate(() => imageMaskResize());
46 |
--------------------------------------------------------------------------------
/.github/workflows/on_pull_request.yaml:
--------------------------------------------------------------------------------
1 | # See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
2 | name: Run Linting/Formatting on Pull Requests
3 |
4 | on:
5 | - push
6 | - pull_request
7 | # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
8 | # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
9 | # pull_request:
10 | # branches:
11 | # - master
12 | # branches-ignore:
13 | # - development
14 |
15 | jobs:
16 | lint:
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Checkout Code
20 | uses: actions/checkout@v3
21 | - name: Set up Python 3.10
22 | uses: actions/setup-python@v3
23 | with:
24 | python-version: 3.10.6
25 | - uses: actions/cache@v2
26 | with:
27 | path: ~/.cache/pip
28 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
29 | restore-keys: |
30 | ${{ runner.os }}-pip-
31 | - name: Install PyLint
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install pylint
35 | # This lets PyLint check to see if it can resolve imports
36 | - name: Install dependencies
37 | run : |
38 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
39 | python launch.py
40 | - name: Analysing the code with pylint
41 | run: |
42 | pylint $(git ls-files '*.py')
43 |
--------------------------------------------------------------------------------
/modules/textual_inversion/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 |
3 | import gradio as gr
4 |
5 | import modules.textual_inversion.textual_inversion
6 | import modules.textual_inversion.preprocess
7 | from modules import sd_hijack, shared
8 |
9 |
10 | def create_embedding(name, initialization_text, nvpt, overwrite_old):
11 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
12 |
13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
14 |
15 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
16 |
17 |
18 | def preprocess(*args):
19 | modules.textual_inversion.preprocess.preprocess(*args)
20 |
21 | return "Preprocessing finished.", ""
22 |
23 |
24 | def train_embedding(*args):
25 |
26 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
27 |
28 | apply_optimizations = shared.opts.training_xattention_optimizations
29 | try:
30 | if not apply_optimizations:
31 | sd_hijack.undo_optimizations()
32 |
33 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
34 |
35 | res = f"""
36 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
37 | Embedding saved to {html.escape(filename)}
38 | """
39 | return res, ""
40 | except Exception:
41 | raise
42 | finally:
43 | if not apply_optimizations:
44 | sd_hijack.apply_optimizations()
45 |
46 |
--------------------------------------------------------------------------------
/modules/safety.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3 | from transformers import AutoFeatureExtractor
4 | from PIL import Image
5 |
6 | import modules.shared as shared
7 |
8 | safety_model_id = "CompVis/stable-diffusion-safety-checker"
9 | safety_feature_extractor = None
10 | safety_checker = None
11 |
12 | def numpy_to_pil(images):
13 | """
14 | Convert a numpy image or a batch of images to a PIL image.
15 | """
16 | if images.ndim == 3:
17 | images = images[None, ...]
18 | images = (images * 255).round().astype("uint8")
19 | pil_images = [Image.fromarray(image) for image in images]
20 |
21 | return pil_images
22 |
23 | # check and replace nsfw content
24 | def check_safety(x_image):
25 | global safety_feature_extractor, safety_checker
26 |
27 | if safety_feature_extractor is None:
28 | safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
29 | safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
30 |
31 | safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
32 | x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
33 |
34 | return x_checked_image, has_nsfw_concept
35 |
36 |
37 | def censor_batch(x):
38 | x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
39 | x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
40 | x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
41 |
42 | return x
43 |
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if not defined PYTHON (set PYTHON=python)
4 | if not defined VENV_DIR (set VENV_DIR=venv)
5 |
6 | set ERROR_REPORTING=FALSE
7 |
8 | mkdir tmp 2>NUL
9 |
10 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
11 | if %ERRORLEVEL% == 0 goto :start_venv
12 | echo Couldn't launch python
13 | goto :show_stdout_stderr
14 |
15 | :start_venv
16 | if [%VENV_DIR%] == [-] goto :skip_venv
17 |
18 | dir %VENV_DIR%\Scripts\Python.exe >tmp/stdout.txt 2>tmp/stderr.txt
19 | if %ERRORLEVEL% == 0 goto :activate_venv
20 |
21 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
22 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
23 | %PYTHON_FULLNAME% -m venv %VENV_DIR% >tmp/stdout.txt 2>tmp/stderr.txt
24 | if %ERRORLEVEL% == 0 goto :activate_venv
25 | echo Unable to create venv in directory %VENV_DIR%
26 | goto :show_stdout_stderr
27 |
28 | :activate_venv
29 | set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
30 | echo venv %PYTHON%
31 | goto :launch
32 |
33 | :skip_venv
34 |
35 | :launch
36 | %PYTHON% launch.py %*
37 | pause
38 | exit /b
39 |
40 | :show_stdout_stderr
41 |
42 | echo.
43 | echo exit code: %errorlevel%
44 |
45 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
46 | if %size% equ 0 goto :show_stderr
47 | echo.
48 | echo stdout:
49 | type tmp\stdout.txt
50 |
51 | :show_stderr
52 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
53 | if %size% equ 0 goto :endofscript
54 | echo.
55 | echo stderr:
56 | type tmp\stderr.txt
57 |
58 | :endofscript
59 |
60 | echo.
61 | echo Launch unsuccessful. Exiting.
62 | pause
63 |
--------------------------------------------------------------------------------
/javascript/notification.js:
--------------------------------------------------------------------------------
1 | // Monitors the gallery and sends a browser notification when the leading image is new.
2 |
3 | let lastHeadImg = null;
4 |
5 | notificationButton = null
6 |
7 | onUiUpdate(function(){
8 | if(notificationButton == null){
9 | notificationButton = gradioApp().getElementById('request_notifications')
10 |
11 | if(notificationButton != null){
12 | notificationButton.addEventListener('click', function (evt) {
13 | Notification.requestPermission();
14 | },true);
15 | }
16 | }
17 |
18 | const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
19 |
20 | if (galleryPreviews == null) return;
21 |
22 | const headImg = galleryPreviews[0]?.src;
23 |
24 | if (headImg == null || headImg == lastHeadImg) return;
25 |
26 | lastHeadImg = headImg;
27 |
28 | // play notification sound if available
29 | gradioApp().querySelector('#audio_notification audio')?.play();
30 |
31 | if (document.hasFocus()) return;
32 |
33 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.
34 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));
35 |
36 | const notification = new Notification(
37 | 'Stable Diffusion',
38 | {
39 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
40 | icon: headImg,
41 | image: headImg,
42 | }
43 | );
44 |
45 | notification.onclick = function(_){
46 | parent.focus();
47 | this.close();
48 | };
49 | });
50 |
--------------------------------------------------------------------------------
/modules/paths.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import modules.safe
5 |
6 | script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
7 | models_path = os.path.join(script_path, "models")
8 | sys.path.insert(0, script_path)
9 |
10 | # search for directory of stable diffusion in following places
11 | sd_path = None
12 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
13 | for possible_sd_path in possible_sd_paths:
14 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
15 | sd_path = os.path.abspath(possible_sd_path)
16 | break
17 |
18 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
19 |
20 | path_dirs = [
21 | (sd_path, 'ldm', 'Stable Diffusion', []),
22 | (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
23 | (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
24 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
25 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
26 | ]
27 |
28 | paths = {}
29 |
30 | for d, must_exist, what, options in path_dirs:
31 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
32 | if not os.path.exists(must_exist_path):
33 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
34 | else:
35 | d = os.path.abspath(d)
36 | if "atstart" in options:
37 | sys.path.insert(0, d)
38 | else:
39 | sys.path.append(d)
40 | paths[what] = d
41 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md:
--------------------------------------------------------------------------------
1 | # Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
2 |
3 | If you have a large change, pay special attention to this paragraph:
4 |
5 | > Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
6 |
7 | Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
8 |
9 | **Describe what this pull request is trying to achieve.**
10 |
11 | A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
12 |
13 | **Additional notes and description of your changes**
14 |
15 | More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
16 |
17 | **Environment this was tested in**
18 |
19 | List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
20 | - OS: [e.g. Windows, Linux]
21 | - Browser [e.g. chrome, safari]
22 | - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
23 |
24 | **Screenshots or videos of your changes**
25 |
26 | If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
27 |
28 | This is **required** for anything that touches the user interface.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Taiyi stable-diffusion-webui
2 | Stable Diffusion web UI for Taiyi
3 |
4 | Make sure you meet at least the following requirements; it helps a lot (one way to install them is sketched after the list):
5 |
6 | - transformers>=4.24.0
7 | - diffusers>=0.7.2
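If you manage the Python environment yourself, one way to satisfy these is via pip (a minimal sketch; normally webui.sh installs the dependencies for you):

```
# install at least the versions the Taiyi fork expects
pip install "transformers>=4.24.0" "diffusers>=0.7.2"
```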
8 |
9 | ## step 1
10 |
11 | Since Taiyi's text_encoder has been modified (BertModel instead of CLIPTextModel) and the webui currently only supports Stable Diffusion in English, you need to use Fengshenbang's own fork of the webui project.
12 |
13 | ```
14 | git clone https://github.com/IDEA-CCNL/stable-diffusion-webui.git
15 | cd stable-diffusion-webui
16 | ```
17 |
18 | ## step 2
19 |
20 | Run webui's own command to check and install the environment. The webui will pull the required repositories into the stable-diffusion-webui/repositories directory; this process will take some time.
21 |
22 | ```
23 | bash webui.sh
24 | ```
25 |
26 | The script will then automatically download the required files, back up the original v1-inference.yaml file, and replace it with the version needed to start our [Taiyi model](https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1); see the sketch below.
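Conceptually, this replacement amounts to something like the following (a hedged sketch only: the exact target path depends on where webui cloned stable-diffusion, and the script performs these steps for you):

```
# illustrative paths: back up the stock config, then drop in the Taiyi config
cp repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml.bak
cp repositories/stable-diffusion-taiyi/configs/stable-diffusion/v1-inference.yaml repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml
```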
27 |
28 | **Notice that**, if you choose to re-download the Taiyi model, the total size of all the required files is over 10 GB, so the step
29 | "Cloning taiyi_model into repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1..." will take a long time. Please be patient.
30 |
31 | If you have already downloaded the whole Taiyi model from <https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1> once, follow the path checker, choose "(2) move your downloaded Taiyi model path?", and move your downloaded model folder to ./repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1, as in the example below.
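A hypothetical example, assuming the model was downloaded to /path/to:

```
# move the previously downloaded model folder into the location the webui expects
mv /path/to/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 ./repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1
```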
32 |
33 | Once all of this is done, the web-ui service will be started on port 12345.
34 |
35 |
36 |
37 |
38 | You can run either of the following commands to start the web-ui service:
39 |
40 | ```
41 | bash webui.sh --listen --port 12345
42 | bash webui.sh --ckpt repositories/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/Taiyi-Stable-Diffusion-1B-Chinese-v0.1.ckpt --listen --port 12345
43 | ```
44 |
--------------------------------------------------------------------------------
/test/img2img_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 | from gradio.processing_utils import encode_pil_to_base64
4 | from PIL import Image
5 |
6 |
7 | class TestImg2ImgWorking(unittest.TestCase):
8 | def setUp(self):
9 | self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
10 | self.simple_img2img = {
11 | "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
12 | "resize_mode": 0,
13 | "denoising_strength": 0.75,
14 | "mask": None,
15 | "mask_blur": 4,
16 | "inpainting_fill": 0,
17 | "inpaint_full_res": False,
18 | "inpaint_full_res_padding": 0,
19 | "inpainting_mask_invert": 0,
20 | "prompt": "example prompt",
21 | "styles": [],
22 | "seed": -1,
23 | "subseed": -1,
24 | "subseed_strength": 0,
25 | "seed_resize_from_h": -1,
26 | "seed_resize_from_w": -1,
27 | "batch_size": 1,
28 | "n_iter": 1,
29 | "steps": 3,
30 | "cfg_scale": 7,
31 | "width": 64,
32 | "height": 64,
33 | "restore_faces": False,
34 | "tiling": False,
35 | "negative_prompt": "",
36 | "eta": 0,
37 | "s_churn": 0,
38 | "s_tmax": 0,
39 | "s_tmin": 0,
40 | "s_noise": 1,
41 | "override_settings": {},
42 | "sampler_index": "Euler a",
43 | "include_init_images": False
44 | }
45 |
46 | def test_img2img_simple_performed(self):
47 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
48 |
49 | def test_inpainting_masked_performed(self):
50 | self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
51 | self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
52 |
53 |
54 | class TestImg2ImgCorrectness(unittest.TestCase):
55 | pass
56 |
57 |
58 | if __name__ == "__main__":
59 | unittest.main()
60 |
--------------------------------------------------------------------------------
/modules/ldsr_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | from basicsr.utils.download_util import load_file_from_url
6 |
7 | from modules.upscaler import Upscaler, UpscalerData
8 | from modules.ldsr_model_arch import LDSR
9 | from modules import shared
10 |
11 |
12 | class UpscalerLDSR(Upscaler):
13 | def __init__(self, user_path):
14 | self.name = "LDSR"
15 | self.user_path = user_path
16 | self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
17 | self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
18 | super().__init__()
19 | scaler_data = UpscalerData("LDSR", None, self)
20 | self.scalers = [scaler_data]
21 |
22 | def load_model(self, path: str):
23 | # Remove incorrect project.yaml file if too big
24 | yaml_path = os.path.join(self.model_path, "project.yaml")
25 | old_model_path = os.path.join(self.model_path, "model.pth")
26 | new_model_path = os.path.join(self.model_path, "model.ckpt")
27 | if os.path.exists(yaml_path):
28 | statinfo = os.stat(yaml_path)
29 | if statinfo.st_size >= 10485760:
30 | print("Removing invalid LDSR YAML file.")
31 | os.remove(yaml_path)
32 | if os.path.exists(old_model_path):
33 | print("Renaming model from model.pth to model.ckpt")
34 | os.rename(old_model_path, new_model_path)
35 | model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
36 | file_name="model.ckpt", progress=True)
37 | yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
38 | file_name="project.yaml", progress=True)
39 |
40 | try:
41 | return LDSR(model, yaml)
42 |
43 | except Exception:
44 | print("Error importing LDSR:", file=sys.stderr)
45 | print(traceback.format_exc(), file=sys.stderr)
46 | return None
47 |
48 | def do_upscale(self, img, path):
49 | ldsr = self.load_model(path)
50 | if ldsr is None:
51 | print("NO LDSR!")
52 | return img
53 | ddim_steps = shared.opts.ldsr_steps
54 | return ldsr.super_resolution(img, ddim_steps, self.scale)
55 |
--------------------------------------------------------------------------------
/repositories/stable-diffusion-taiyi/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | # cond_stage_config:
70 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 | cond_stage_config:
72 | target: ldm.modules.encoders.modules.TaiyiCLIPEmbedder
73 | params:
74 | # you can git clone the model and change the version to your local model path
75 | version: your_path/Taiyi-Stable-Diffusion-1B-Chinese-v0.1
76 | max_length: 512
77 |
78 |
--------------------------------------------------------------------------------
/repositories/stable-diffusion-taiyi/configs/stable-diffusion/v1-inference-en.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | # cond_stage_config:
70 | # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 | cond_stage_config:
72 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
73 | params:
74 | # you can git clone the model and change the version to your local model path
75 | version: your_path/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1
76 | max_length: 77
77 |
78 |
--------------------------------------------------------------------------------
/modules/hypernetworks/ui.py:
--------------------------------------------------------------------------------
1 | import html
2 | import os
3 | import re
4 |
5 | import gradio as gr
6 | import modules.textual_inversion.preprocess
7 | import modules.textual_inversion.textual_inversion
8 | from modules import devices, sd_hijack, shared
9 | from modules.hypernetworks import hypernetwork
10 |
11 | not_available = ["hardswish", "multiheadattention"]
12 | keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
13 |
14 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
15 | # Remove illegal characters from name.
16 | name = "".join( x for x in name if (x.isalnum() or x in "._- "))
17 |
18 | fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
19 | if not overwrite_old:
20 | assert not os.path.exists(fn), f"file {fn} already exists"
21 |
22 | if type(layer_structure) == str:
23 | layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
24 |
25 | hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
26 | name=name,
27 | enable_sizes=[int(x) for x in enable_sizes],
28 | layer_structure=layer_structure,
29 | activation_func=activation_func,
30 | weight_init=weight_init,
31 | add_layer_norm=add_layer_norm,
32 | use_dropout=use_dropout,
33 | )
34 | hypernet.save(fn)
35 |
36 | shared.reload_hypernetworks()
37 |
38 | return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
39 |
40 |
41 | def train_hypernetwork(*args):
42 |
43 | initial_hypernetwork = shared.loaded_hypernetwork
44 |
45 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
46 |
47 | try:
48 | sd_hijack.undo_optimizations()
49 |
50 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
51 |
52 | res = f"""
53 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
54 | Hypernetwork saved to {html.escape(filename)}
55 | """
56 | return res, ""
57 | except Exception:
58 | raise
59 | finally:
60 | shared.loaded_hypernetwork = initial_hypernetwork
61 | shared.sd_model.cond_stage_model.to(devices.device)
62 | shared.sd_model.first_stage_model.to(devices.device)
63 | sd_hijack.apply_optimizations()
64 |
65 |
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
1 | function gradioApp(){
2 | return document.getElementsByTagName('gradio-app')[0].shadowRoot;
3 | }
4 |
5 | function get_uiCurrentTab() {
6 | return gradioApp().querySelector('.tabs button:not(.border-transparent)')
7 | }
8 |
9 | function get_uiCurrentTabContent() {
10 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
11 | }
12 |
13 | uiUpdateCallbacks = []
14 | uiTabChangeCallbacks = []
15 | let uiCurrentTab = null
16 |
17 | function onUiUpdate(callback){
18 | uiUpdateCallbacks.push(callback)
19 | }
20 | function onUiTabChange(callback){
21 | uiTabChangeCallbacks.push(callback)
22 | }
23 |
24 | function runCallback(x, m){
25 | try {
26 | x(m)
27 | } catch (e) {
28 | (console.error || console.log).call(console, e.message, e);
29 | }
30 | }
31 | function executeCallbacks(queue, m) {
32 | queue.forEach(function(x){runCallback(x, m)})
33 | }
34 |
35 | document.addEventListener("DOMContentLoaded", function() {
36 | var mutationObserver = new MutationObserver(function(m){
37 | executeCallbacks(uiUpdateCallbacks, m);
38 | const newTab = get_uiCurrentTab();
39 | if ( newTab && ( newTab !== uiCurrentTab ) ) {
40 | uiCurrentTab = newTab;
41 | executeCallbacks(uiTabChangeCallbacks);
42 | }
43 | });
44 | mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
45 | });
46 |
47 | /**
48 |  * Add ctrl+enter as a shortcut to start a generation
49 | */
50 | document.addEventListener('keydown', function(e) {
51 | var handled = false;
52 | if (e.key !== undefined) {
53 | if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
54 | } else if (e.keyCode !== undefined) {
55 | if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
56 | }
57 | if (handled) {
58 | button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
59 | if (button) {
60 | button.click();
61 | }
62 | e.preventDefault();
63 | }
64 | })
65 |
66 | /**
67 | * checks that a UI element is not in another hidden element or tab content
68 | */
69 | function uiElementIsVisible(el) {
70 | let isVisible = !el.closest('.\\!hidden');
71 | if ( ! isVisible ) {
72 | return false;
73 | }
74 |
75 | while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
76 | if ( ! isVisible ) {
77 | return false;
78 | } else if ( el.parentElement ) {
79 | el = el.parentElement
80 | } else {
81 | break;
82 | }
83 | }
84 | return isVisible;
85 | }
--------------------------------------------------------------------------------
/modules/txt2img.py:
--------------------------------------------------------------------------------
1 | import modules.scripts
2 | from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
3 | StableDiffusionProcessingImg2Img, process_images
4 | from modules.shared import opts, cmd_opts
5 | import modules.shared as shared
6 | import modules.processing as processing
7 | from modules.ui import plaintext_to_html
8 |
9 |
10 | def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
11 | p = StableDiffusionProcessingTxt2Img(
12 | sd_model=shared.sd_model,
13 | outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
14 | outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
15 | prompt=prompt,
16 | styles=[prompt_style, prompt_style2],
17 | negative_prompt=negative_prompt,
18 | seed=seed,
19 | subseed=subseed,
20 | subseed_strength=subseed_strength,
21 | seed_resize_from_h=seed_resize_from_h,
22 | seed_resize_from_w=seed_resize_from_w,
23 | seed_enable_extras=seed_enable_extras,
24 | sampler_index=sampler_index,
25 | batch_size=batch_size,
26 | n_iter=n_iter,
27 | steps=steps,
28 | cfg_scale=cfg_scale,
29 | width=width,
30 | height=height,
31 | restore_faces=restore_faces,
32 | tiling=tiling,
33 | enable_hr=enable_hr,
34 | denoising_strength=denoising_strength if enable_hr else None,
35 | firstphase_width=firstphase_width if enable_hr else None,
36 | firstphase_height=firstphase_height if enable_hr else None,
37 | )
38 |
39 | p.scripts = modules.scripts.scripts_txt2img
40 | p.script_args = args
41 |
42 | if cmd_opts.enable_console_prompts:
43 | print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
44 |
45 | processed = modules.scripts.scripts_txt2img.run(p, *args)
46 |
47 | if processed is None:
48 | processed = process_images(p)
49 |
50 | p.close()
51 |
52 | shared.total_tqdm.clear()
53 |
54 | generation_info_js = processed.js()
55 | if opts.samples_log_stdout:
56 | print(generation_info_js)
57 |
58 | if opts.do_not_show_images:
59 | processed.images = []
60 |
61 | return processed.images, generation_info_js, plaintext_to_html(processed.info)
62 |
--------------------------------------------------------------------------------
/modules/extensions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import traceback
4 |
5 | import git
6 |
7 | from modules import paths, shared
8 |
9 |
10 | extensions = []
11 | extensions_dir = os.path.join(paths.script_path, "extensions")
12 |
13 |
14 | def active():
15 | return [x for x in extensions if x.enabled]
16 |
17 |
18 | class Extension:
19 | def __init__(self, name, path, enabled=True):
20 | self.name = name
21 | self.path = path
22 | self.enabled = enabled
23 | self.status = ''
24 | self.can_update = False
25 |
26 | repo = None
27 | try:
28 | if os.path.exists(os.path.join(path, ".git")):
29 | repo = git.Repo(path)
30 | except Exception:
31 | print(f"Error reading github repository info from {path}:", file=sys.stderr)
32 | print(traceback.format_exc(), file=sys.stderr)
33 |
34 | if repo is None or repo.bare:
35 | self.remote = None
36 | else:
37 | try:
38 | self.remote = next(repo.remote().urls, None)
39 | self.status = 'unknown'
40 | except Exception:
41 | self.remote = None
42 |
43 | def list_files(self, subdir, extension):
44 | from modules import scripts
45 |
46 | dirpath = os.path.join(self.path, subdir)
47 | if not os.path.isdir(dirpath):
48 | return []
49 |
50 | res = []
51 | for filename in sorted(os.listdir(dirpath)):
52 | res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
53 |
54 | res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
55 |
56 | return res
57 |
58 | def check_updates(self):
59 | repo = git.Repo(self.path)
60 | for fetch in repo.remote().fetch("--dry-run"):
61 | if fetch.flags != fetch.HEAD_UPTODATE:
62 | self.can_update = True
63 | self.status = "behind"
64 | return
65 |
66 | self.can_update = False
67 | self.status = "latest"
68 |
69 | def pull(self):
70 | repo = git.Repo(self.path)
71 | repo.remotes.origin.pull()
72 |
73 |
74 | def list_extensions():
75 | extensions.clear()
76 |
77 | if not os.path.isdir(extensions_dir):
78 | return
79 |
80 | for dirname in sorted(os.listdir(extensions_dir)):
81 | path = os.path.join(extensions_dir, dirname)
82 | if not os.path.isdir(path):
83 | continue
84 |
85 | extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
86 | extensions.append(extension)
87 |
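A minimal usage sketch for this module (assuming it is run from the webui root with GitPython and the other requirements installed; the printed fields are the Extension attributes defined above):

    # Hypothetical usage sketch: scan the extensions/ directory and list what is enabled.
    from modules import extensions

    extensions.list_extensions()            # populates the module-level `extensions` list
    for ext in extensions.active():         # only extensions not in opts.disabled_extensions
        print(ext.name, ext.remote or "(no git remote)", ext.status)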
--------------------------------------------------------------------------------
/modules/memmon.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import time
3 | from collections import defaultdict
4 |
5 | import torch
6 |
7 |
8 | class MemUsageMonitor(threading.Thread):
9 | run_flag = None
10 | device = None
11 | disabled = False
12 | opts = None
13 | data = None
14 |
15 | def __init__(self, name, device, opts):
16 | threading.Thread.__init__(self)
17 | self.name = name
18 | self.device = device
19 | self.opts = opts
20 |
21 | self.daemon = True
22 | self.run_flag = threading.Event()
23 | self.data = defaultdict(int)
24 |
25 | try:
26 | torch.cuda.mem_get_info()
27 | torch.cuda.memory_stats(self.device)
28 | except Exception as e: # AMD or whatever
29 | print(f"Warning: caught exception '{e}', memory monitor disabled")
30 | self.disabled = True
31 |
32 | def run(self):
33 | if self.disabled:
34 | return
35 |
36 | while True:
37 | self.run_flag.wait()
38 |
39 | torch.cuda.reset_peak_memory_stats()
40 | self.data.clear()
41 |
42 | if self.opts.memmon_poll_rate <= 0:
43 | self.run_flag.clear()
44 | continue
45 |
46 | self.data["min_free"] = torch.cuda.mem_get_info()[0]
47 |
48 | while self.run_flag.is_set():
49 | free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
50 | self.data["min_free"] = min(self.data["min_free"], free)
51 |
52 | time.sleep(1 / self.opts.memmon_poll_rate)
53 |
54 | def dump_debug(self):
55 | print(self, 'recorded data:')
56 | for k, v in self.read().items():
57 | print(k, -(v // -(1024 ** 2)))
58 |
59 | print(self, 'raw torch memory stats:')
60 | tm = torch.cuda.memory_stats(self.device)
61 | for k, v in tm.items():
62 | if 'bytes' not in k:
63 | continue
64 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
65 |
66 | print(torch.cuda.memory_summary())
67 |
68 | def monitor(self):
69 | self.run_flag.set()
70 |
71 | def read(self):
72 | if not self.disabled:
73 | free, total = torch.cuda.mem_get_info()
74 | self.data["total"] = total
75 |
76 | torch_stats = torch.cuda.memory_stats(self.device)
77 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
78 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
79 | self.data["system_peak"] = total - self.data["min_free"]
80 |
81 | return self.data
82 |
83 | def stop(self):
84 | self.run_flag.clear()
85 | return self.read()
86 |
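A sketch of the intended monitor lifecycle (assumes a CUDA device; the `opts` stub below only provides the `memmon_poll_rate` field this class reads):

    # Hypothetical usage sketch: poll VRAM while work runs, then read the peaks.
    import types
    import torch
    from modules.memmon import MemUsageMonitor

    opts = types.SimpleNamespace(memmon_poll_rate=8)       # 8 polls per second (stub object)
    mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
    mon.start()                                            # daemon thread; idle until monitor()
    mon.monitor()                                          # start polling torch.cuda.mem_get_info()

    # ... run generation / training here ...

    data = mon.stop()                                      # stop polling, return collected stats
    print(data["min_free"], data["active_peak"], data["reserved_peak"])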
--------------------------------------------------------------------------------
/test/utils_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 |
4 | class UtilsTests(unittest.TestCase):
5 | def setUp(self):
6 | self.url_options = "http://localhost:7860/sdapi/v1/options"
7 | self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
8 | self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
9 | self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
10 | self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
11 | self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
12 | self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
13 | self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
14 | self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
15 | self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
16 | self.url_artists = "http://localhost:7860/sdapi/v1/artists"
17 |
18 | def test_options_get(self):
19 | self.assertEqual(requests.get(self.url_options).status_code, 200)
20 |
21 | def test_options_write(self):
22 | response = requests.get(self.url_options)
23 | self.assertEqual(response.status_code, 200)
24 |
25 | pre_value = response.json()["send_seed"]
26 |
27 | self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
28 |
29 | response = requests.get(self.url_options)
30 | self.assertEqual(response.status_code, 200)
31 | self.assertEqual(response.json()["send_seed"], not pre_value)
32 |
33 | requests.post(self.url_options, json={"send_seed": pre_value})
34 |
35 | def test_cmd_flags(self):
36 | self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
37 |
38 | def test_samplers(self):
39 | self.assertEqual(requests.get(self.url_samplers).status_code, 200)
40 |
41 | def test_upscalers(self):
42 | self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
43 |
44 | def test_sd_models(self):
45 | self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
46 |
47 | def test_hypernetworks(self):
48 | self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
49 |
50 | def test_face_restorers(self):
51 | self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
52 |
53 | def test_realesrgan_models(self):
54 | self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
55 |
56 | def test_prompt_styles(self):
57 | self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
58 |
59 | def test_artist_categories(self):
60 | self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
61 |
62 | def test_artists(self):
63 | self.assertEqual(requests.get(self.url_artists).status_code, 200)
--------------------------------------------------------------------------------
/modules/devices.py:
--------------------------------------------------------------------------------
1 | import sys, os, shlex
2 | import contextlib
3 | import torch
4 | from modules import errors
5 |
6 | # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
7 | has_mps = getattr(torch, 'has_mps', False)
8 |
9 | cpu = torch.device("cpu")
10 |
11 | def extract_device_id(args, name):
12 | for x in range(len(args)):
13 | if name in args[x]: return args[x+1]
14 | return None
15 |
16 | def get_optimal_device():
17 | if torch.cuda.is_available():
18 | from modules import shared
19 |
20 | device_id = shared.cmd_opts.device_id
21 |
22 | if device_id is not None:
23 | cuda_device = f"cuda:{device_id}"
24 | return torch.device(cuda_device)
25 | else:
26 | return torch.device("cuda")
27 |
28 | if has_mps:
29 | return torch.device("mps")
30 |
31 | return cpu
32 |
33 |
34 | def torch_gc():
35 | if torch.cuda.is_available():
36 | torch.cuda.empty_cache()
37 | torch.cuda.ipc_collect()
38 |
39 |
40 | def enable_tf32():
41 | if torch.cuda.is_available():
42 | torch.backends.cuda.matmul.allow_tf32 = True
43 | torch.backends.cudnn.allow_tf32 = True
44 |
45 |
46 | errors.run(enable_tf32, "Enabling TF32")
47 |
48 | device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
49 | dtype = torch.float16
50 | dtype_vae = torch.float16
51 |
52 | def randn(seed, shape):
53 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
54 | if device.type == 'mps':
55 | generator = torch.Generator(device=cpu)
56 | generator.manual_seed(seed)
57 | noise = torch.randn(shape, generator=generator, device=cpu).to(device)
58 | return noise
59 |
60 | torch.manual_seed(seed)
61 | return torch.randn(shape, device=device)
62 |
63 |
64 | def randn_without_seed(shape):
65 | # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
66 | if device.type == 'mps':
67 | generator = torch.Generator(device=cpu)
68 | noise = torch.randn(shape, generator=generator, device=cpu).to(device)
69 | return noise
70 |
71 | return torch.randn(shape, device=device)
72 |
73 |
74 | def autocast(disable=False):
75 | from modules import shared
76 |
77 | if disable:
78 | return contextlib.nullcontext()
79 |
80 | if dtype == torch.float32 or shared.cmd_opts.precision == "full":
81 | return contextlib.nullcontext()
82 |
83 | return torch.autocast("cuda")
84 |
85 | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
86 | def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
87 | def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
88 |
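In the webui itself the `device*` globals above are assigned at startup elsewhere in the codebase; the sketch below assigns `device` by hand just to show how the seeded-noise helper behaves (the latent shape is only an example):

    # Hypothetical usage sketch: pick a device, then draw reproducible latent noise.
    from modules import devices

    devices.device = devices.get_optimal_device()   # cuda, mps or cpu

    noise_a = devices.randn(1234, (1, 4, 64, 64))   # seeds the RNG, so this is reproducible
    noise_b = devices.randn(1234, (1, 4, 64, 64))
    assert (noise_a == noise_b).all()

    devices.torch_gc()                              # release cached CUDA memory, no-op on CPU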
--------------------------------------------------------------------------------
/modules/textual_inversion/learn_schedule.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 |
3 |
4 | class LearnScheduleIterator:
5 | def __init__(self, learn_rate, max_steps, cur_step=0):
6 | """
7 | specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
8 | """
9 |
10 | pairs = learn_rate.split(',')
11 | self.rates = []
12 | self.it = 0
13 | self.maxit = 0
14 | try:
15 | for i, pair in enumerate(pairs):
16 | if not pair.strip():
17 | continue
18 | tmp = pair.split(':')
19 | if len(tmp) == 2:
20 | step = int(tmp[1])
21 | if step > cur_step:
22 | self.rates.append((float(tmp[0]), min(step, max_steps)))
23 | self.maxit += 1
24 | if step > max_steps:
25 | return
26 | elif step == -1:
27 | self.rates.append((float(tmp[0]), max_steps))
28 | self.maxit += 1
29 | return
30 | else:
31 | self.rates.append((float(tmp[0]), max_steps))
32 | self.maxit += 1
33 | return
34 | assert self.rates
35 | except (ValueError, AssertionError):
36 | raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
37 |
38 |
39 | def __iter__(self):
40 | return self
41 |
42 | def __next__(self):
43 | if self.it < self.maxit:
44 | self.it += 1
45 | return self.rates[self.it - 1]
46 | else:
47 | raise StopIteration
48 |
49 |
50 | class LearnRateScheduler:
51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
53 | (self.learn_rate, self.end_step) = next(self.schedules)
54 | self.verbose = verbose
55 |
56 | if self.verbose:
57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
58 |
59 | self.finished = False
60 |
61 | def apply(self, optimizer, step_number):
62 | if step_number < self.end_step:
63 | return
64 |
65 | try:
66 | (self.learn_rate, self.end_step) = next(self.schedules)
67 | except Exception:
68 | self.finished = True
69 | return
70 |
71 | if self.verbose:
72 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
73 |
74 | for pg in optimizer.param_groups:
75 | pg['lr'] = self.learn_rate
76 |
77 |
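A standalone illustration of the schedule format documented in the docstring above (run from the webui root so the module is importable; the numbers are the docstring's own example):

    # Hypothetical usage sketch: print the (rate, end_step) pairs the iterator yields.
    from modules.textual_inversion.learn_schedule import LearnScheduleIterator

    for rate, end_step in LearnScheduleIterator("0.001:100, 0.00001:1000, 1e-5:10000", 5000):
        print(rate, end_step)
    # 0.001 100
    # 1e-05 1000
    # 1e-05 5000   <- the last entry is clamped to max_steps=5000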
--------------------------------------------------------------------------------
/test/txt2img_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import requests
3 |
4 |
5 | class TestTxt2ImgWorking(unittest.TestCase):
6 | def setUp(self):
7 | self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
8 | self.simple_txt2img = {
9 | "enable_hr": False,
10 | "denoising_strength": 0,
11 | "firstphase_width": 0,
12 | "firstphase_height": 0,
13 | "prompt": "example prompt",
14 | "styles": [],
15 | "seed": -1,
16 | "subseed": -1,
17 | "subseed_strength": 0,
18 | "seed_resize_from_h": -1,
19 | "seed_resize_from_w": -1,
20 | "batch_size": 1,
21 | "n_iter": 1,
22 | "steps": 3,
23 | "cfg_scale": 7,
24 | "width": 64,
25 | "height": 64,
26 | "restore_faces": False,
27 | "tiling": False,
28 | "negative_prompt": "",
29 | "eta": 0,
30 | "s_churn": 0,
31 | "s_tmax": 0,
32 | "s_tmin": 0,
33 | "s_noise": 1,
34 | "sampler_index": "Euler a"
35 | }
36 |
37 | def test_txt2img_simple_performed(self):
38 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
39 |
40 | def test_txt2img_with_negative_prompt_performed(self):
41 | self.simple_txt2img["negative_prompt"] = "example negative prompt"
42 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
43 |
44 | def test_txt2img_not_square_image_performed(self):
45 | self.simple_txt2img["height"] = 128
46 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
47 |
48 | def test_txt2img_with_hrfix_performed(self):
49 | self.simple_txt2img["enable_hr"] = True
50 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
51 |
52 | def test_txt2img_with_restore_faces_performed(self):
53 | self.simple_txt2img["restore_faces"] = True
54 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
55 |
56 |     def test_txt2img_with_tiling_performed(self):
57 | self.simple_txt2img["tiling"] = True
58 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
59 |
60 | def test_txt2img_with_vanilla_sampler_performed(self):
61 | self.simple_txt2img["sampler_index"] = "PLMS"
62 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
63 |
64 | def test_txt2img_multiple_batches_performed(self):
65 | self.simple_txt2img["n_iter"] = 2
66 | self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
67 |
68 |
69 | class TestTxt2ImgCorrectness(unittest.TestCase):
70 | pass
71 |
72 |
73 | if __name__ == "__main__":
74 | unittest.main()
75 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: You think something is broken in the UI
3 | title: "[Bug]: "
4 | labels: ["bug-report"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits
13 | required: true
14 | - type: markdown
15 | attributes:
16 | value: |
17 |         *Please fill this form with as much information as possible; don't forget to fill "What OS..." and "What browsers", and **provide screenshots if possible**.*
18 | - type: textarea
19 | id: what-did
20 | attributes:
21 | label: What happened?
22 | description: Tell us what happened in a very clear and simple way
23 | validations:
24 | required: true
25 | - type: textarea
26 | id: steps
27 | attributes:
28 | label: Steps to reproduce the problem
29 | description: Please provide us with precise step by step information on how to reproduce the bug
30 | value: |
31 | 1. Go to ....
32 | 2. Press ....
33 | 3. ...
34 | validations:
35 | required: true
36 | - type: textarea
37 | id: what-should
38 | attributes:
39 | label: What should have happened?
40 |       description: Tell us what you think the normal behavior should be
41 | validations:
42 | required: true
43 | - type: input
44 | id: commit
45 | attributes:
46 | label: Commit where the problem happens
47 |       description: Which commit are you running? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
48 | validations:
49 | required: true
50 | - type: dropdown
51 | id: platforms
52 | attributes:
53 |       label: What platforms do you use to access the UI?
54 | multiple: true
55 | options:
56 | - Windows
57 | - Linux
58 | - MacOS
59 | - iOS
60 | - Android
61 | - Other/Cloud
62 | - type: dropdown
63 | id: browsers
64 | attributes:
65 |       label: What browsers do you use to access the UI?
66 | multiple: true
67 | options:
68 | - Mozilla Firefox
69 | - Google Chrome
70 | - Brave
71 | - Apple Safari
72 | - Microsoft Edge
73 | - type: textarea
74 | id: cmdargs
75 | attributes:
76 | label: Command Line Arguments
77 |       description: Are you using any launch parameters/command line arguments (modified webui-user.bat / webui-user.sh)? If yes, please write them below
78 | render: Shell
79 | - type: textarea
80 | id: misc
81 | attributes:
82 | label: Additional information, context and logs
83 | description: Please provide us with any relevant additional info, context or log output.
84 |
--------------------------------------------------------------------------------
/javascript/dragdrop.js:
--------------------------------------------------------------------------------
1 | // allows drag-dropping files into gradio image elements, and also pasting images from clipboard
2 |
3 | function isValidImageList( files ) {
4 | return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
5 | }
6 |
7 | function dropReplaceImage( imgWrap, files ) {
8 | if ( ! isValidImageList( files ) ) {
9 | return;
10 | }
11 |
12 | imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
13 | const callback = () => {
14 | const fileInput = imgWrap.querySelector('input[type="file"]');
15 | if ( fileInput ) {
16 | fileInput.files = files;
17 | fileInput.dispatchEvent(new Event('change'));
18 | }
19 | };
20 |
21 | if ( imgWrap.closest('#pnginfo_image') ) {
22 | // special treatment for PNG Info tab, wait for fetch request to finish
23 | const oldFetch = window.fetch;
24 | window.fetch = async (input, options) => {
25 | const response = await oldFetch(input, options);
26 | if ( 'api/predict/' === input ) {
27 | const content = await response.text();
28 | window.fetch = oldFetch;
29 | window.requestAnimationFrame( () => callback() );
30 | return new Response(content, {
31 | status: response.status,
32 | statusText: response.statusText,
33 | headers: response.headers
34 | })
35 | }
36 | return response;
37 | };
38 | } else {
39 | window.requestAnimationFrame( () => callback() );
40 | }
41 | }
42 |
43 | window.document.addEventListener('dragover', e => {
44 | const target = e.composedPath()[0];
45 | const imgWrap = target.closest('[data-testid="image"]');
46 | if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
47 | return;
48 | }
49 | e.stopPropagation();
50 | e.preventDefault();
51 | e.dataTransfer.dropEffect = 'copy';
52 | });
53 |
54 | window.document.addEventListener('drop', e => {
55 | const target = e.composedPath()[0];
56 | if (target.placeholder.indexOf("Prompt") == -1) {
57 | return;
58 | }
59 | const imgWrap = target.closest('[data-testid="image"]');
60 | if ( !imgWrap ) {
61 | return;
62 | }
63 | e.stopPropagation();
64 | e.preventDefault();
65 | const files = e.dataTransfer.files;
66 | dropReplaceImage( imgWrap, files );
67 | });
68 |
69 | window.addEventListener('paste', e => {
70 | const files = e.clipboardData.files;
71 | if ( ! isValidImageList( files ) ) {
72 | return;
73 | }
74 |
75 | const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
76 | .filter(el => uiElementIsVisible(el));
77 | if ( ! visibleImageFields.length ) {
78 | return;
79 | }
80 |
81 | const firstFreeImageField = visibleImageFields
82 | .filter(el => el.querySelector('input[type=file]'))?.[0];
83 |
84 | dropReplaceImage(
85 | firstFreeImageField ?
86 | firstFreeImageField :
87 | visibleImageFields[visibleImageFields.length - 1]
88 | , files );
89 | });
90 |
--------------------------------------------------------------------------------
/javascript/edit-attention.js:
--------------------------------------------------------------------------------
1 | addEventListener('keydown', (event) => {
2 | let target = event.originalTarget || event.composedPath()[0];
3 | if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
4 | if (! (event.metaKey || event.ctrlKey)) return;
5 |
6 |
7 | let plus = "ArrowUp"
8 | let minus = "ArrowDown"
9 | if (event.key != plus && event.key != minus) return;
10 |
11 | let selectionStart = target.selectionStart;
12 | let selectionEnd = target.selectionEnd;
13 | // If the user hasn't selected anything, let's select their current parenthesis block
14 | if (selectionStart === selectionEnd) {
15 | // Find opening parenthesis around current cursor
16 | const before = target.value.substring(0, selectionStart);
17 | let beforeParen = before.lastIndexOf("(");
18 | if (beforeParen == -1) return;
19 | let beforeParenClose = before.lastIndexOf(")");
20 | while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
21 | beforeParen = before.lastIndexOf("(", beforeParen - 1);
22 | beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
23 | }
24 |
25 | // Find closing parenthesis around current cursor
26 | const after = target.value.substring(selectionStart);
27 | let afterParen = after.indexOf(")");
28 | if (afterParen == -1) return;
29 | let afterParenOpen = after.indexOf("(");
30 | while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
31 | afterParen = after.indexOf(")", afterParen + 1);
32 | afterParenOpen = after.indexOf("(", afterParenOpen + 1);
33 | }
34 | if (beforeParen === -1 || afterParen === -1) return;
35 |
36 | // Set the selection to the text between the parenthesis
37 | const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
38 | const lastColon = parenContent.lastIndexOf(":");
39 | selectionStart = beforeParen + 1;
40 | selectionEnd = selectionStart + lastColon;
41 | target.setSelectionRange(selectionStart, selectionEnd);
42 | }
43 |
44 | event.preventDefault();
45 |
46 | if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
47 | target.value = target.value.slice(0, selectionStart) +
48 | "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
49 | target.value.slice(selectionEnd);
50 |
51 | target.focus();
52 | target.selectionStart = selectionStart + 1;
53 | target.selectionEnd = selectionEnd + 1;
54 |
55 | } else {
56 | end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
57 | weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
58 | if (isNaN(weight)) return;
59 | if (event.key == minus) weight -= 0.1;
60 | if (event.key == plus) weight += 0.1;
61 |
62 | weight = parseFloat(weight.toPrecision(12));
63 |
64 | target.value = target.value.slice(0, selectionEnd + 1) +
65 | weight +
66 | target.value.slice(selectionEnd + 1 + end - 1);
67 |
68 | target.focus();
69 | target.selectionStart = selectionStart;
70 | target.selectionEnd = selectionEnd;
71 | }
72 | // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
73 | // internal Svelte data binding remains in sync.
74 | target.dispatchEvent(new Event("input", { bubbles: true }));
75 | });
76 |
--------------------------------------------------------------------------------
/scripts/loopback.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import trange
3 |
4 | import modules.scripts as scripts
5 | import gradio as gr
6 |
7 | from modules import processing, shared, sd_samplers, images
8 | from modules.processing import Processed
9 | from modules.sd_samplers import samplers
10 | from modules.shared import opts, cmd_opts, state
11 |
12 | class Script(scripts.Script):
13 | def title(self):
14 | return "Loopback"
15 |
16 | def show(self, is_img2img):
17 | return is_img2img
18 |
19 | def ui(self, is_img2img):
20 | loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4)
21 | denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1)
22 |
23 | return [loops, denoising_strength_change_factor]
24 |
25 | def run(self, p, loops, denoising_strength_change_factor):
26 | processing.fix_seed(p)
27 | batch_count = p.n_iter
28 | p.extra_generation_params = {
29 | "Denoising strength change factor": denoising_strength_change_factor,
30 | }
31 |
32 | p.batch_size = 1
33 | p.n_iter = 1
34 |
35 | output_images, info = None, None
36 | initial_seed = None
37 | initial_info = None
38 |
39 | grids = []
40 | all_images = []
41 | original_init_image = p.init_images
42 | state.job_count = loops * batch_count
43 |
44 | initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
45 |
46 | for n in range(batch_count):
47 | history = []
48 |
49 | # Reset to original init image at the start of each batch
50 | p.init_images = original_init_image
51 |
52 | for i in range(loops):
53 | p.n_iter = 1
54 | p.batch_size = 1
55 | p.do_not_save_grid = True
56 |
57 | if opts.img2img_color_correction:
58 | p.color_corrections = initial_color_corrections
59 |
60 | state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"
61 |
62 | processed = processing.process_images(p)
63 |
64 | if initial_seed is None:
65 | initial_seed = processed.seed
66 | initial_info = processed.info
67 |
68 | init_img = processed.images[0]
69 |
70 | p.init_images = [init_img]
71 | p.seed = processed.seed + 1
72 | p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
73 | history.append(processed.images[0])
74 |
75 | grid = images.image_grid(history, rows=1)
76 | if opts.grid_save:
77 | images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
78 |
79 | grids.append(grid)
80 | all_images += history
81 |
82 | if opts.return_grid:
83 | all_images = grids + all_images
84 |
85 | processed = Processed(p, all_images, initial_seed, initial_info)
86 |
87 | return processed
88 |
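The only non-obvious arithmetic above is how the denoising strength drifts from one loop to the next (line 72): new = clamp(old * factor, 0.1, 1). A standalone sketch with example numbers:

    # Standalone illustration of the per-loop denoising strength update.
    strength = 0.6          # example starting denoising strength
    factor = 0.95           # example "Denoising strength change factor"
    for i in range(4):
        strength = min(max(strength * factor, 0.1), 1)
        print(f"loop {i + 1}: denoising strength = {strength:.3f}")
    # roughly 0.570, 0.541, 0.514, 0.489 -- decaying toward the 0.1 floor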
--------------------------------------------------------------------------------
/scripts/prompt_matrix.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import namedtuple
3 | from copy import copy
4 | import random
5 |
6 | import modules.scripts as scripts
7 | import gradio as gr
8 |
9 | from modules import images
10 | from modules.processing import process_images, Processed
11 | from modules.shared import opts, cmd_opts, state
12 | import modules.sd_samplers
13 |
14 |
15 | def draw_xy_grid(xs, ys, x_label, y_label, cell):
16 | res = []
17 |
18 | ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
19 | hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
20 |
21 |     first_processed = None
22 |
23 | state.job_count = len(xs) * len(ys)
24 |
25 | for iy, y in enumerate(ys):
26 | for ix, x in enumerate(xs):
27 | state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
28 |
29 | processed = cell(x, y)
30 |             if first_processed is None:
31 |                 first_processed = processed
32 |
33 | res.append(processed.images[0])
34 |
35 | grid = images.image_grid(res, rows=len(ys))
36 | grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
37 |
38 |     first_processed.images = [grid]
39 |
40 |     return first_processed
41 |
42 |
43 | class Script(scripts.Script):
44 | def title(self):
45 | return "Prompt matrix"
46 |
47 | def ui(self, is_img2img):
48 | put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False)
49 |
50 | return [put_at_start]
51 |
52 | def run(self, p, put_at_start):
53 | modules.processing.fix_seed(p)
54 |
55 | original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
56 |
57 | all_prompts = []
58 | prompt_matrix_parts = original_prompt.split("|")
59 | combination_count = 2 ** (len(prompt_matrix_parts) - 1)
60 | for combination_num in range(combination_count):
61 | selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
62 |
63 | if put_at_start:
64 | selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
65 | else:
66 | selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
67 |
68 | all_prompts.append(", ".join(selected_prompts))
69 |
70 | p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
71 | p.do_not_save_grid = True
72 |
73 | print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
74 |
75 | p.prompt = all_prompts
76 | p.seed = [p.seed for _ in all_prompts]
77 | p.prompt_for_display = original_prompt
78 | processed = process_images(p)
79 |
80 | grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
81 | grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
82 | processed.images.insert(0, grid)
83 |
84 | if opts.grid_save:
85 | images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)
86 |
87 | return processed
88 |
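A standalone sketch of the bitmask enumeration used in run() above: the first "|"-separated part is always kept, and each remaining part is toggled by one bit of combination_num (the prompt text is just an example):

    # Standalone illustration of how "a cat|wearing a hat|oil painting" expands.
    parts = "a cat|wearing a hat|oil painting".split("|")
    combination_count = 2 ** (len(parts) - 1)            # 4 combinations here

    for combination_num in range(combination_count):
        selected = [p.strip() for n, p in enumerate(parts[1:]) if combination_num & (1 << n)]
        print(", ".join([parts[0].strip()] + selected))
    # a cat
    # a cat, wearing a hat
    # a cat, oil painting
    # a cat, wearing a hat, oil painting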
--------------------------------------------------------------------------------
/modules/scunet_model.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import sys
3 | import traceback
4 |
5 | import PIL.Image
6 | import numpy as np
7 | import torch
8 | from basicsr.utils.download_util import load_file_from_url
9 |
10 | import modules.upscaler
11 | from modules import devices, modelloader
12 | from modules.scunet_model_arch import SCUNet as net
13 |
14 |
15 | class UpscalerScuNET(modules.upscaler.Upscaler):
16 | def __init__(self, dirname):
17 | self.name = "ScuNET"
18 | self.model_name = "ScuNET GAN"
19 | self.model_name2 = "ScuNET PSNR"
20 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
21 | self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
22 | self.user_path = dirname
23 | super().__init__()
24 | model_paths = self.find_models(ext_filter=[".pth"])
25 | scalers = []
26 | add_model2 = True
27 | for file in model_paths:
28 | if "http" in file:
29 | name = self.model_name
30 | else:
31 | name = modelloader.friendly_name(file)
32 | if name == self.model_name2 or file == self.model_url2:
33 | add_model2 = False
34 | try:
35 | scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
36 | scalers.append(scaler_data)
37 | except Exception:
38 | print(f"Error loading ScuNET model: {file}", file=sys.stderr)
39 | print(traceback.format_exc(), file=sys.stderr)
40 | if add_model2:
41 | scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
42 | scalers.append(scaler_data2)
43 | self.scalers = scalers
44 |
45 | def do_upscale(self, img: PIL.Image, selected_file):
46 | torch.cuda.empty_cache()
47 |
48 | model = self.load_model(selected_file)
49 | if model is None:
50 | return img
51 |
52 | device = devices.device_scunet
53 | img = np.array(img)
54 | img = img[:, :, ::-1]
55 | img = np.moveaxis(img, 2, 0) / 255
56 | img = torch.from_numpy(img).float()
57 | img = devices.mps_contiguous_to(img.unsqueeze(0), device)
58 |
59 | with torch.no_grad():
60 | output = model(img)
61 | output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
62 | output = 255. * np.moveaxis(output, 0, 2)
63 | output = output.astype(np.uint8)
64 | output = output[:, :, ::-1]
65 | torch.cuda.empty_cache()
66 | return PIL.Image.fromarray(output, 'RGB')
67 |
68 | def load_model(self, path: str):
69 | device = devices.device_scunet
70 | if "http" in path:
71 | filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
72 | progress=True)
73 | else:
74 | filename = path
75 | if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
76 | print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
77 | return None
78 |
79 | model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
80 | model.load_state_dict(torch.load(filename), strict=True)
81 | model.eval()
82 | for k, v in model.named_parameters():
83 | v.requires_grad = False
84 | model = model.to(device)
85 |
86 | return model
87 |
88 |
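The tensor plumbing around the model call in do_upscale (RGB -> BGR, HWC -> CHW, [0, 255] -> [0, 1], and back) is easy to get backwards; below is a standalone round-trip sketch using only NumPy and PIL, with an identity function standing in for the network:

    # Standalone illustration of the pre-/post-processing done around the model call.
    import numpy as np
    from PIL import Image

    img = Image.new("RGB", (8, 8), (255, 0, 0))          # toy red image
    x = np.array(img)[:, :, ::-1]                        # RGB -> BGR
    x = np.moveaxis(x, 2, 0) / 255                       # HWC -> CHW, scale to [0, 1]

    y = x                                                # identity stand-in for the network

    y = (255.0 * np.moveaxis(y, 0, 2)).astype(np.uint8)[:, :, ::-1]   # undo both steps
    assert np.array_equal(np.array(img), y)
    print(Image.fromarray(y, "RGB").size)                # (8, 8)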
--------------------------------------------------------------------------------
/modules/masking.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFilter, ImageOps
2 |
3 |
4 | def get_crop_region(mask, pad=0):
5 |     """finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
6 |     For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
7 |
8 | h, w = mask.shape
9 |
10 | crop_left = 0
11 | for i in range(w):
12 | if not (mask[:, i] == 0).all():
13 | break
14 | crop_left += 1
15 |
16 | crop_right = 0
17 | for i in reversed(range(w)):
18 | if not (mask[:, i] == 0).all():
19 | break
20 | crop_right += 1
21 |
22 | crop_top = 0
23 | for i in range(h):
24 | if not (mask[i] == 0).all():
25 | break
26 | crop_top += 1
27 |
28 | crop_bottom = 0
29 | for i in reversed(range(h)):
30 | if not (mask[i] == 0).all():
31 | break
32 | crop_bottom += 1
33 |
34 | return (
35 | int(max(crop_left-pad, 0)),
36 | int(max(crop_top-pad, 0)),
37 | int(min(w - crop_right + pad, w)),
38 | int(min(h - crop_bottom + pad, h))
39 | )
40 |
41 |
42 | def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
43 |     """expands a crop region returned by get_crop_region() to match the aspect ratio of the image it will be processed in; returns the expanded region.
44 |     For example, if the user drew a mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
45 |
46 | x1, y1, x2, y2 = crop_region
47 |
48 | ratio_crop_region = (x2 - x1) / (y2 - y1)
49 | ratio_processing = processing_width / processing_height
50 |
51 | if ratio_crop_region > ratio_processing:
52 | desired_height = (x2 - x1) / ratio_processing
53 | desired_height_diff = int(desired_height - (y2-y1))
54 | y1 -= desired_height_diff//2
55 | y2 += desired_height_diff - desired_height_diff//2
56 | if y2 >= image_height:
57 | diff = y2 - image_height
58 | y2 -= diff
59 | y1 -= diff
60 | if y1 < 0:
61 | y2 -= y1
62 | y1 -= y1
63 | if y2 >= image_height:
64 | y2 = image_height
65 | else:
66 | desired_width = (y2 - y1) * ratio_processing
67 | desired_width_diff = int(desired_width - (x2-x1))
68 | x1 -= desired_width_diff//2
69 | x2 += desired_width_diff - desired_width_diff//2
70 | if x2 >= image_width:
71 | diff = x2 - image_width
72 | x2 -= diff
73 | x1 -= diff
74 | if x1 < 0:
75 | x2 -= x1
76 | x1 -= x1
77 | if x2 >= image_width:
78 | x2 = image_width
79 |
80 | return x1, y1, x2, y2
81 |
82 |
83 | def fill(image, mask):
84 | """fills masked regions with colors from image using blur. Not extremely effective."""
85 |
86 | image_mod = Image.new('RGBA', (image.width, image.height))
87 |
88 | image_masked = Image.new('RGBa', (image.width, image.height))
89 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
90 |
91 | image_masked = image_masked.convert('RGBa')
92 |
93 | for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
94 | blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
95 | for _ in range(repeats):
96 | image_mod.alpha_composite(blurred)
97 |
98 | return image_mod.convert("RGB")
99 |
100 |
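A quick standalone check of the two functions above (run from the webui root so `modules.masking` is importable; the 16x16 mask is just an example whose painted area is the top-right 8x4 strip):

    # Hypothetical usage sketch for get_crop_region / expand_crop_region.
    import numpy as np
    from modules import masking

    mask = np.zeros((16, 16), dtype=np.uint8)
    mask[0:4, 8:16] = 255                                       # "painted" area

    region = masking.get_crop_region(mask)
    print(region)                                               # (8, 0, 16, 4)
    print(masking.expand_crop_region(region, 64, 64, 16, 16))   # (8, 0, 16, 8): grown to the 1:1 processing ratio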
--------------------------------------------------------------------------------
/javascript/aspectRatioOverlay.js:
--------------------------------------------------------------------------------
1 |
2 | let currentWidth = null;
3 | let currentHeight = null;
4 | let arFrameTimeout = setTimeout(function(){},0);
5 |
6 | function dimensionChange(e, is_width, is_height){
7 |
8 | if(is_width){
9 | currentWidth = e.target.value*1.0
10 | }
11 | if(is_height){
12 | currentHeight = e.target.value*1.0
13 | }
14 |
15 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
16 |
17 | if(!inImg2img){
18 | return;
19 | }
20 |
21 | var targetElement = null;
22 |
23 | var tabIndex = get_tab_index('mode_img2img')
24 | if(tabIndex == 0){
25 | targetElement = gradioApp().querySelector('div[data-testid=image] img');
26 | } else if(tabIndex == 1){
27 | targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
28 | }
29 |
30 | if(targetElement){
31 |
32 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
33 | if(!arPreviewRect){
34 | arPreviewRect = document.createElement('div')
35 | arPreviewRect.id = "imageARPreview";
36 | gradioApp().getRootNode().appendChild(arPreviewRect)
37 | }
38 |
39 |
40 |
41 | var viewportOffset = targetElement.getBoundingClientRect();
42 |
43 | viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
44 |
45 | scaledx = targetElement.naturalWidth*viewportscale
46 | scaledy = targetElement.naturalHeight*viewportscale
47 |
48 | cleintRectTop = (viewportOffset.top+window.scrollY)
49 | cleintRectLeft = (viewportOffset.left+window.scrollX)
50 | cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
51 | cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
52 |
53 | viewRectTop = cleintRectCentreY-(scaledy/2)
54 | viewRectLeft = cleintRectCentreX-(scaledx/2)
55 | arRectWidth = scaledx
56 | arRectHeight = scaledy
57 |
58 | arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
59 | arscaledx = currentWidth*arscale
60 | arscaledy = currentHeight*arscale
61 |
62 | arRectTop = cleintRectCentreY-(arscaledy/2)
63 | arRectLeft = cleintRectCentreX-(arscaledx/2)
64 | arRectWidth = arscaledx
65 | arRectHeight = arscaledy
66 |
67 | arPreviewRect.style.top = arRectTop+'px';
68 | arPreviewRect.style.left = arRectLeft+'px';
69 | arPreviewRect.style.width = arRectWidth+'px';
70 | arPreviewRect.style.height = arRectHeight+'px';
71 |
72 | clearTimeout(arFrameTimeout);
73 | arFrameTimeout = setTimeout(function(){
74 | arPreviewRect.style.display = 'none';
75 | },2000);
76 |
77 | arPreviewRect.style.display = 'block';
78 |
79 | }
80 |
81 | }
82 |
83 |
84 | onUiUpdate(function(){
85 | var arPreviewRect = gradioApp().querySelector('#imageARPreview');
86 | if(arPreviewRect){
87 | arPreviewRect.style.display = 'none';
88 | }
89 | var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
90 | if(inImg2img){
91 | let inputs = gradioApp().querySelectorAll('input');
92 | inputs.forEach(function(e){
93 | var is_width = e.parentElement.id == "img2img_width"
94 | var is_height = e.parentElement.id == "img2img_height"
95 |
96 | if((is_width || is_height) && !e.classList.contains('scrollwatch')){
97 | e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
98 | e.classList.add('scrollwatch')
99 | }
100 | if(is_width){
101 | currentWidth = e.value*1.0
102 | }
103 | if(is_height){
104 | currentHeight = e.value*1.0
105 | }
106 | })
107 | }
108 | });
109 |
--------------------------------------------------------------------------------
/scripts/sd_upscale.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import modules.scripts as scripts
4 | import gradio as gr
5 | from PIL import Image
6 |
7 | from modules import processing, shared, sd_samplers, images, devices
8 | from modules.processing import Processed
9 | from modules.shared import opts, cmd_opts, state
10 |
11 |
12 | class Script(scripts.Script):
13 | def title(self):
14 | return "SD upscale"
15 |
16 | def show(self, is_img2img):
17 | return is_img2img
18 |
19 | def ui(self, is_img2img):
20 | info = gr.HTML("
Will upscale the image to twice the dimensions; use width and height sliders to set tile size
") 21 | overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) 22 | upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") 23 | 24 | return [info, overlap, upscaler_index] 25 | 26 | def run(self, p, _, overlap, upscaler_index): 27 | processing.fix_seed(p) 28 | upscaler = shared.sd_upscalers[upscaler_index] 29 | 30 | p.extra_generation_params["SD upscale overlap"] = overlap 31 | p.extra_generation_params["SD upscale upscaler"] = upscaler.name 32 | 33 | initial_info = None 34 | seed = p.seed 35 | 36 | init_img = p.init_images[0] 37 | 38 | if(upscaler.name != "None"): 39 | img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path) 40 | else: 41 | img = init_img 42 | 43 | devices.torch_gc() 44 | 45 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap) 46 | 47 | batch_size = p.batch_size 48 | upscale_count = p.n_iter 49 | p.n_iter = 1 50 | p.do_not_save_grid = True 51 | p.do_not_save_samples = True 52 | 53 | work = [] 54 | 55 | for y, h, row in grid.tiles: 56 | for tiledata in row: 57 | work.append(tiledata[2]) 58 | 59 | batch_count = math.ceil(len(work) / batch_size) 60 | state.job_count = batch_count * upscale_count 61 | 62 | print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.") 63 | 64 | result_images = [] 65 | for n in range(upscale_count): 66 | start_seed = seed + n 67 | p.seed = start_seed 68 | 69 | work_results = [] 70 | for i in range(batch_count): 71 | p.batch_size = batch_size 72 | p.init_images = work[i*batch_size:(i+1)*batch_size] 73 | 74 | state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}" 75 | processed = processing.process_images(p) 76 | 77 | if initial_info is None: 78 | initial_info = processed.info 79 | 80 | p.seed = processed.seed + 1 81 | work_results += processed.images 82 | 83 | image_index = 0 84 | for y, h, row in grid.tiles: 85 | for tiledata in row: 86 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 87 | image_index += 1 88 | 89 | combined_image = images.combine_grid(grid) 90 | result_images.append(combined_image) 91 | 92 | if opts.samples_save: 93 | images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p) 94 | 95 | processed = Processed(p, result_images, seed, initial_info) 96 | 97 | return processed 98 | -------------------------------------------------------------------------------- /modules/lowvram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import devices 3 | 4 | module_in_gpu = None 5 | cpu = torch.device("cpu") 6 | 7 | 8 | def send_everything_to_cpu(): 9 | global module_in_gpu 10 | 11 | if module_in_gpu is not None: 12 | module_in_gpu.to(cpu) 13 | 14 | module_in_gpu = None 15 | 16 | 17 | def setup_for_low_vram(sd_model, use_medvram): 18 | parents = {} 19 | 20 | def send_me_to_gpu(module, _): 21 | """send this module to GPU; send whatever tracked module was previous in GPU to CPU; 22 | we add this as forward_pre_hook to a lot of modules and this way all but one of them will 23 | be in CPU 24 | """ 25 | global module_in_gpu 26 | 27 | module = parents.get(module, module) 28 | 29 | if module_in_gpu == module: 30 | return 31 | 32 | if module_in_gpu is not None: 33 | 
module_in_gpu.to(cpu) 34 | 35 | module.to(devices.device) 36 | module_in_gpu = module 37 | 38 | # see below for register_forward_pre_hook; 39 | # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is 40 | # useless here, and we just replace those methods 41 | 42 | first_stage_model = sd_model.first_stage_model 43 | first_stage_model_encode = sd_model.first_stage_model.encode 44 | first_stage_model_decode = sd_model.first_stage_model.decode 45 | 46 | def first_stage_model_encode_wrap(x): 47 | send_me_to_gpu(first_stage_model, None) 48 | return first_stage_model_encode(x) 49 | 50 | def first_stage_model_decode_wrap(z): 51 | send_me_to_gpu(first_stage_model, None) 52 | return first_stage_model_decode(z) 53 | 54 | # remove three big modules, cond, first_stage, and unet from the model and then 55 | # send the model to GPU. Then put modules back. the modules will be in CPU. 56 | stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model 57 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None 58 | sd_model.to(devices.device) 59 | sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored 60 | 61 | # register hooks for those the first two models 62 | sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) 63 | sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) 64 | sd_model.first_stage_model.encode = first_stage_model_encode_wrap 65 | sd_model.first_stage_model.decode = first_stage_model_decode_wrap 66 | parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model 67 | 68 | if use_medvram: 69 | sd_model.model.register_forward_pre_hook(send_me_to_gpu) 70 | else: 71 | diff_model = sd_model.model.diffusion_model 72 | 73 | # the third remaining model is still too big for 4 GB, so we also do the same for its submodules 74 | # so that only one of them is in GPU at a time 75 | stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed 76 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None 77 | sd_model.model.to(devices.device) 78 | diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored 79 | 80 | # install hooks for bits of third model 81 | diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) 82 | for block in diff_model.input_blocks: 83 | block.register_forward_pre_hook(send_me_to_gpu) 84 | diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) 85 | for block in diff_model.output_blocks: 86 | block.register_forward_pre_hook(send_me_to_gpu) 87 | -------------------------------------------------------------------------------- /modules/styles.py: -------------------------------------------------------------------------------- 1 | # We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime 2 | from __future__ import annotations 3 | 4 | import csv 5 | import os 6 | import os.path 7 | import typing 8 | import collections.abc as abc 9 | import tempfile 10 | import shutil 11 | 12 | if typing.TYPE_CHECKING: 13 | # Only import this when code is being type-checked, it doesn't have any effect at runtime 14 | from .processing import StableDiffusionProcessing 15 | 16 | 17 | class PromptStyle(typing.NamedTuple): 18 | name: str 19 | prompt: str 20 | negative_prompt: str 21 | 
22 | 23 | def merge_prompts(style_prompt: str, prompt: str) -> str: 24 | if "{prompt}" in style_prompt: 25 | res = style_prompt.replace("{prompt}", prompt) 26 | else: 27 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 28 | res = ", ".join(parts) 29 | 30 | return res 31 | 32 | 33 | def apply_styles_to_prompt(prompt, styles): 34 | for style in styles: 35 | prompt = merge_prompts(style, prompt) 36 | 37 | return prompt 38 | 39 | 40 | class StyleDatabase: 41 | def __init__(self, path: str): 42 | self.no_style = PromptStyle("None", "", "") 43 | self.styles = {"None": self.no_style} 44 | 45 | if not os.path.exists(path): 46 | return 47 | 48 | with open(path, "r", encoding="utf-8-sig", newline='') as file: 49 | reader = csv.DictReader(file) 50 | for row in reader: 51 | # Support loading old CSV format with "name, text"-columns 52 | prompt = row["prompt"] if "prompt" in row else row["text"] 53 | negative_prompt = row.get("negative_prompt", "") 54 | self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) 55 | 56 | def get_style_prompts(self, styles): 57 | return [self.styles.get(x, self.no_style).prompt for x in styles] 58 | 59 | def get_negative_style_prompts(self, styles): 60 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 61 | 62 | def apply_styles_to_prompt(self, prompt, styles): 63 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) 64 | 65 | def apply_negative_styles_to_prompt(self, prompt, styles): 66 | return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) 67 | 68 | def apply_styles(self, p: StableDiffusionProcessing) -> None: 69 | if isinstance(p.prompt, list): 70 | p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt] 71 | else: 72 | p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles) 73 | 74 | if isinstance(p.negative_prompt, list): 75 | p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt] 76 | else: 77 | p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) 78 | 79 | def save_styles(self, path: str) -> None: 80 | # Write to temporary file first, so we don't nuke the file if something goes wrong 81 | fd, temp_path = tempfile.mkstemp(".csv") 82 | with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: 83 | # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, 84 | # and collections.NamedTuple has explicit documentation for accessing _fields. 
Same goes for _asdict() 85 | writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) 86 | writer.writeheader() 87 | writer.writerows(style._asdict() for k, style in self.styles.items()) 88 | 89 | # Always keep a backup file around 90 | if os.path.exists(path): 91 | shutil.move(path, path + ".bak") 92 | shutil.move(temp_path, path) 93 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import facexlib 6 | import gfpgan 7 | 8 | import modules.face_restoration 9 | from modules import shared, devices, modelloader 10 | from modules.paths import models_path 11 | 12 | model_dir = "GFPGAN" 13 | user_path = None 14 | model_path = os.path.join(models_path, model_dir) 15 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 16 | have_gfpgan = False 17 | loaded_gfpgan_model = None 18 | 19 | 20 | def gfpgann(): 21 | global loaded_gfpgan_model 22 | global model_path 23 | if loaded_gfpgan_model is not None: 24 | loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) 25 | return loaded_gfpgan_model 26 | 27 | if gfpgan_constructor is None: 28 | return None 29 | 30 | models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") 31 | if len(models) == 1 and "http" in models[0]: 32 | model_file = models[0] 33 | elif len(models) != 0: 34 | latest_file = max(models, key=os.path.getctime) 35 | model_file = latest_file 36 | else: 37 | print("Unable to load gfpgan model!") 38 | return None 39 | model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) 40 | loaded_gfpgan_model = model 41 | 42 | return model 43 | 44 | 45 | def send_model_to(model, device): 46 | model.gfpgan.to(device) 47 | model.face_helper.face_det.to(device) 48 | model.face_helper.face_parse.to(device) 49 | 50 | 51 | def gfpgan_fix_faces(np_image): 52 | model = gfpgann() 53 | if model is None: 54 | return np_image 55 | 56 | send_model_to(model, devices.device_gfpgan) 57 | 58 | np_image_bgr = np_image[:, :, ::-1] 59 | cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) 60 | np_image = gfpgan_output_bgr[:, :, ::-1] 61 | 62 | model.face_helper.clean_all() 63 | 64 | if shared.opts.face_restoration_unload: 65 | send_model_to(model, devices.cpu) 66 | 67 | return np_image 68 | 69 | 70 | gfpgan_constructor = None 71 | 72 | 73 | def setup_model(dirname): 74 | global model_path 75 | if not os.path.exists(model_path): 76 | os.makedirs(model_path) 77 | 78 | try: 79 | from gfpgan import GFPGANer 80 | from facexlib import detection, parsing 81 | global user_path 82 | global have_gfpgan 83 | global gfpgan_constructor 84 | 85 | load_file_from_url_orig = gfpgan.utils.load_file_from_url 86 | facex_load_file_from_url_orig = facexlib.detection.load_file_from_url 87 | facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url 88 | 89 | def my_load_file_from_url(**kwargs): 90 | return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) 91 | 92 | def facex_load_file_from_url(**kwargs): 93 | return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) 94 | 95 | def facex_load_file_from_url2(**kwargs): 96 | return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) 97 | 98 | 
gfpgan.utils.load_file_from_url = my_load_file_from_url 99 | facexlib.detection.load_file_from_url = facex_load_file_from_url 100 | facexlib.parsing.load_file_from_url = facex_load_file_from_url2 101 | user_path = dirname 102 | have_gfpgan = True 103 | gfpgan_constructor = GFPGANer 104 | 105 | class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): 106 | def name(self): 107 | return "GFPGAN" 108 | 109 | def restore(self, np_image): 110 | return gfpgan_fix_faces(np_image) 111 | 112 | shared.face_restorers.append(FaceRestorerGFPGAN()) 113 | except Exception: 114 | print("Error setting up GFPGAN:", file=sys.stderr) 115 | print(traceback.format_exc(), file=sys.stderr) 116 | -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ################################################# 3 | # Please do not make any changes to this file, # 4 | # change the variables in webui-user.sh instead # 5 | ################################################# 6 | # Read variables from webui-user.sh 7 | # shellcheck source=/dev/null 8 | if [[ -f webui-user.sh ]] 9 | then 10 | source ./webui-user.sh 11 | fi 12 | 13 | # Set defaults 14 | # Install directory without trailing slash 15 | if [[ -z "${install_dir}" ]] 16 | then 17 | install_dir="/home/$(whoami)" 18 | fi 19 | 20 | # Name of the subdirectory (defaults to stable-diffusion-webui) 21 | if [[ -z "${clone_dir}" ]] 22 | then 23 | clone_dir="stable-diffusion-webui" 24 | fi 25 | 26 | # python3 executable 27 | if [[ -z "${python_cmd}" ]] 28 | then 29 | python_cmd="python3" 30 | fi 31 | 32 | # git executable 33 | if [[ -z "${GIT}" ]] 34 | then 35 | export GIT="git" 36 | fi 37 | 38 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 39 | if [[ -z "${venv_dir}" ]] 40 | then 41 | venv_dir="venv" 42 | fi 43 | 44 | if [[ -z "${LAUNCH_SCRIPT}" ]] 45 | then 46 | LAUNCH_SCRIPT="launch.py" 47 | fi 48 | 49 | # Disable sentry logging 50 | export ERROR_REPORTING=FALSE 51 | 52 | # Do not reinstall existing pip packages on Debian/Ubuntu 53 | export PIP_IGNORE_INSTALLED=0 54 | 55 | # Pretty print 56 | delimiter="################################################################" 57 | 58 | printf "\n%s\n" "${delimiter}" 59 | printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" 60 | printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" 61 | printf "\n%s\n" "${delimiter}" 62 | 63 | # Do not run as root 64 | if [[ $(id -u) -eq 0 ]] 65 | then 66 | printf "\n%s\n" "${delimiter}" 67 | printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" 68 | printf "\n%s\n" "${delimiter}" 69 | exit 1 70 | else 71 | printf "\n%s\n" "${delimiter}" 72 | printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)" 73 | printf "\n%s\n" "${delimiter}" 74 | fi 75 | 76 | if [[ -d .git ]] 77 | then 78 | printf "\n%s\n" "${delimiter}" 79 | printf "Repo already cloned, using it as install directory" 80 | printf "\n%s\n" "${delimiter}" 81 | install_dir="${PWD}/../" 82 | clone_dir="${PWD##*/}" 83 | fi 84 | 85 | # Check prerequisites 86 | for preq in "${GIT}" "${python_cmd}" 87 | do 88 | if ! hash "${preq}" &>/dev/null 89 | then 90 | printf "\n%s\n" "${delimiter}" 91 | printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}" 92 | printf "\n%s\n" "${delimiter}" 93 | exit 1 94 | fi 95 | done 96 | 97 | if ! 
"${python_cmd}" -c "import venv" &>/dev/null 98 | then 99 | printf "\n%s\n" "${delimiter}" 100 | printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" 101 | printf "\n%s\n" "${delimiter}" 102 | exit 1 103 | fi 104 | 105 | cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; } 106 | if [[ -d "${clone_dir}" ]] 107 | then 108 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 109 | else 110 | printf "\n%s\n" "${delimiter}" 111 | printf "Clone stable-diffusion-webui" 112 | printf "\n%s\n" "${delimiter}" 113 | "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}" 114 | cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 115 | fi 116 | 117 | printf "\n%s\n" "${delimiter}" 118 | printf "Create and activate python venv" 119 | printf "\n%s\n" "${delimiter}" 120 | cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } 121 | if [[ ! -d "${venv_dir}" ]] 122 | then 123 | "${python_cmd}" -m venv "${venv_dir}" 124 | first_launch=1 125 | fi 126 | # shellcheck source=/dev/null 127 | if [[ -f "${venv_dir}"/bin/activate ]] 128 | then 129 | source "${venv_dir}"/bin/activate 130 | else 131 | printf "\n%s\n" "${delimiter}" 132 | printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" 133 | printf "\n%s\n" "${delimiter}" 134 | exit 1 135 | fi 136 | 137 | printf "\n%s\n" "${delimiter}" 138 | printf "Launching launch.py..." 139 | printf "\n%s\n" "${delimiter}" 140 | "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" 141 | -------------------------------------------------------------------------------- /modules/upscaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import PIL 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | import modules.shared 10 | from modules import modelloader, shared 11 | 12 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 13 | NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) 14 | from modules.paths import models_path 15 | 16 | 17 | class Upscaler: 18 | name = None 19 | model_path = None 20 | model_name = None 21 | model_url = None 22 | enable = True 23 | filter = None 24 | model = None 25 | user_path = None 26 | scalers: [] 27 | tile = True 28 | 29 | def __init__(self, create_dirs=False): 30 | self.mod_pad_h = None 31 | self.tile_size = modules.shared.opts.ESRGAN_tile 32 | self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap 33 | self.device = modules.shared.device 34 | self.img = None 35 | self.output = None 36 | self.scale = 1 37 | self.half = not modules.shared.cmd_opts.no_half 38 | self.pre_pad = 0 39 | self.mod_scale = None 40 | 41 | if self.model_path is None and self.name: 42 | self.model_path = os.path.join(models_path, self.name) 43 | if self.model_path and create_dirs: 44 | os.makedirs(self.model_path, exist_ok=True) 45 | 46 | try: 47 | import cv2 48 | self.can_tile = True 49 | except: 50 | pass 51 | 52 | @abstractmethod 53 | def do_upscale(self, img: PIL.Image, selected_model: str): 54 | return img 55 | 56 | def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): 57 | self.scale = scale 58 | 
dest_w = img.width * scale 59 | dest_h = img.height * scale 60 | 61 | for i in range(3): 62 | shape = (img.width, img.height) 63 | 64 | img = self.do_upscale(img, selected_model) 65 | 66 | if shape == (img.width, img.height): 67 | break 68 | 69 | if img.width >= dest_w and img.height >= dest_h: 70 | break 71 | 72 | if img.width != dest_w or img.height != dest_h: 73 | img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) 74 | 75 | return img 76 | 77 | @abstractmethod 78 | def load_model(self, path: str): 79 | pass 80 | 81 | def find_models(self, ext_filter=None) -> list: 82 | return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) 83 | 84 | def update_status(self, prompt): 85 | print(f"\nextras: {prompt}", file=shared.progress_print_out) 86 | 87 | 88 | class UpscalerData: 89 | name = None 90 | data_path = None 91 | scale: int = 4 92 | scaler: Upscaler = None 93 | model: None 94 | 95 | def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): 96 | self.name = name 97 | self.data_path = path 98 | self.scaler = upscaler 99 | self.scale = scale 100 | self.model = model 101 | 102 | 103 | class UpscalerNone(Upscaler): 104 | name = "None" 105 | scalers = [] 106 | 107 | def load_model(self, path): 108 | pass 109 | 110 | def do_upscale(self, img, selected_model=None): 111 | return img 112 | 113 | def __init__(self, dirname=None): 114 | super().__init__(False) 115 | self.scalers = [UpscalerData("None", None, self)] 116 | 117 | 118 | class UpscalerLanczos(Upscaler): 119 | scalers = [] 120 | 121 | def do_upscale(self, img, selected_model=None): 122 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) 123 | 124 | def load_model(self, _): 125 | pass 126 | 127 | def __init__(self, dirname=None): 128 | super().__init__(False) 129 | self.name = "Lanczos" 130 | self.scalers = [UpscalerData("Lanczos", None, self)] 131 | 132 | 133 | class UpscalerNearest(Upscaler): 134 | scalers = [] 135 | 136 | def do_upscale(self, img, selected_model=None): 137 | return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) 138 | 139 | def load_model(self, _): 140 | pass 141 | 142 | def __init__(self, dirname=None): 143 | super().__init__(False) 144 | self.name = "Nearest" 145 | self.scalers = [UpscalerData("Nearest", None, self)] -------------------------------------------------------------------------------- /modules/textual_inversion/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | 9 | import random 10 | import tqdm 11 | from modules import devices, shared 12 | import re 13 | 14 | re_numbers_at_start = re.compile(r"^[-\d]+\s*") 15 | 16 | 17 | class DatasetEntry: 18 | def __init__(self, filename=None, latent=None, filename_text=None): 19 | self.filename = filename 20 | self.latent = latent 21 | self.filename_text = filename_text 22 | self.cond = None 23 | self.cond_text = None 24 | 25 | 26 | class PersonalizedBase(Dataset): 27 | def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): 28 | re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None 
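# note: re_word (compiled from the dataset_filename_word_regex setting) is only used further down, to pull caption words out of image filenames that have no matching .txt caption file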
29 | 30 | self.placeholder_token = placeholder_token 31 | 32 | self.batch_size = batch_size 33 | self.width = width 34 | self.height = height 35 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 36 | 37 | self.dataset = [] 38 | 39 | with open(template_file, "r") as file: 40 | lines = [x.strip() for x in file.readlines()] 41 | 42 | self.lines = lines 43 | 44 | assert data_root, 'dataset directory not specified' 45 | assert os.path.isdir(data_root), "Dataset directory doesn't exist" 46 | assert os.listdir(data_root), "Dataset directory is empty" 47 | 48 | cond_model = shared.sd_model.cond_stage_model 49 | 50 | self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] 51 | print("Preparing dataset...") 52 | for path in tqdm.tqdm(self.image_paths): 53 | try: 54 | image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) 55 | except Exception: 56 | continue 57 | 58 | text_filename = os.path.splitext(path)[0] + ".txt" 59 | filename = os.path.basename(path) 60 | 61 | if os.path.exists(text_filename): 62 | with open(text_filename, "r", encoding="utf8") as file: 63 | filename_text = file.read() 64 | else: 65 | filename_text = os.path.splitext(filename)[0] 66 | filename_text = re.sub(re_numbers_at_start, '', filename_text) 67 | if re_word: 68 | tokens = re_word.findall(filename_text) 69 | filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) 70 | 71 | npimage = np.array(image).astype(np.uint8) 72 | npimage = (npimage / 127.5 - 1.0).astype(np.float32) 73 | 74 | torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) 75 | torchdata = torch.moveaxis(torchdata, 2, 0) 76 | 77 | init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() 78 | init_latent = init_latent.to(devices.cpu) 79 | 80 | entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) 81 | 82 | if include_cond: 83 | entry.cond_text = self.create_text(filename_text) 84 | entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) 85 | 86 | self.dataset.append(entry) 87 | 88 | assert len(self.dataset) > 0, "No images have been found in the dataset." 
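# one epoch visits every image `repeats` times, and __getitem__ returns a whole batch per index, hence the integer division by batch_size below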
89 | self.length = len(self.dataset) * repeats // batch_size 90 | 91 | self.dataset_length = len(self.dataset) 92 | self.indexes = None 93 | self.shuffle() 94 | 95 | def shuffle(self): 96 | self.indexes = np.random.permutation(self.dataset_length) 97 | 98 | def create_text(self, filename_text): 99 | text = random.choice(self.lines) 100 | text = text.replace("[name]", self.placeholder_token) 101 | text = text.replace("[filewords]", filename_text) 102 | return text 103 | 104 | def __len__(self): 105 | return self.length 106 | 107 | def __getitem__(self, i): 108 | res = [] 109 | 110 | for j in range(self.batch_size): 111 | position = i * self.batch_size + j 112 | if position % len(self.indexes) == 0: 113 | self.shuffle() 114 | 115 | index = self.indexes[position % len(self.indexes)] 116 | entry = self.dataset[index] 117 | 118 | if entry.cond is None: 119 | entry.cond_text = self.create_text(entry.filename_text) 120 | 121 | res.append(entry) 122 | 123 | return res 124 | -------------------------------------------------------------------------------- /javascript/localization.js: -------------------------------------------------------------------------------- 1 | 2 | // localization = {} -- the dict with translations is created by the backend 3 | 4 | ignore_ids_for_localization={ 5 | setting_sd_hypernetwork: 'OPTION', 6 | setting_sd_model_checkpoint: 'OPTION', 7 | setting_realesrgan_enabled_models: 'OPTION', 8 | modelmerger_primary_model_name: 'OPTION', 9 | modelmerger_secondary_model_name: 'OPTION', 10 | modelmerger_tertiary_model_name: 'OPTION', 11 | train_embedding: 'OPTION', 12 | train_hypernetwork: 'OPTION', 13 | txt2img_style_index: 'OPTION', 14 | txt2img_style2_index: 'OPTION', 15 | img2img_style_index: 'OPTION', 16 | img2img_style2_index: 'OPTION', 17 | setting_random_artist_categories: 'SPAN', 18 | setting_face_restoration_model: 'SPAN', 19 | setting_realesrgan_enabled_models: 'SPAN', 20 | extras_upscaler_1: 'SPAN', 21 | extras_upscaler_2: 'SPAN', 22 | } 23 | 24 | re_num = /^[\.\d]+$/ 25 | re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u 26 | 27 | original_lines = {} 28 | translated_lines = {} 29 | 30 | function textNodesUnder(el){ 31 | var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); 32 | while(n=walk.nextNode()) a.push(n); 33 | return a; 34 | } 35 | 36 | function canBeTranslated(node, text){ 37 | if(! text) return false; 38 | if(! node.parentElement) return false; 39 | 40 | parentType = node.parentElement.nodeName 41 | if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; 42 | 43 | if (parentType=='OPTION' || parentType=='SPAN'){ 44 | pnode = node 45 | for(var level=0; level<4; level++){ 46 | pnode = pnode.parentElement 47 | if(! pnode) break; 48 | 49 | if(ignore_ids_for_localization[pnode.id] == parentType) return false; 50 | } 51 | } 52 | 53 | if(re_num.test(text)) return false; 54 | if(re_emoji.test(text)) return false; 55 | return true 56 | } 57 | 58 | function getTranslation(text){ 59 | if(! text) return undefined 60 | 61 | if(translated_lines[text] === undefined){ 62 | original_lines[text] = 1 63 | } 64 | 65 | tl = localization[text] 66 | if(tl !== undefined){ 67 | translated_lines[tl] = 1 68 | } 69 | 70 | return tl 71 | } 72 | 73 | function processTextNode(node){ 74 | text = node.textContent.trim() 75 | 76 | if(! 
canBeTranslated(node, text)) return 77 | 78 | tl = getTranslation(text) 79 | if(tl !== undefined){ 80 | node.textContent = tl 81 | } 82 | } 83 | 84 | function processNode(node){ 85 | if(node.nodeType == 3){ 86 | processTextNode(node) 87 | return 88 | } 89 | 90 | if(node.title){ 91 | tl = getTranslation(node.title) 92 | if(tl !== undefined){ 93 | node.title = tl 94 | } 95 | } 96 | 97 | if(node.placeholder){ 98 | tl = getTranslation(node.placeholder) 99 | if(tl !== undefined){ 100 | node.placeholder = tl 101 | } 102 | } 103 | 104 | textNodesUnder(node).forEach(function(node){ 105 | processTextNode(node) 106 | }) 107 | } 108 | 109 | function dumpTranslations(){ 110 | dumped = {} 111 | if (localization.rtl) { 112 | dumped.rtl = true 113 | } 114 | 115 | Object.keys(original_lines).forEach(function(text){ 116 | if(dumped[text] !== undefined) return 117 | 118 | dumped[text] = localization[text] || text 119 | }) 120 | 121 | return dumped 122 | } 123 | 124 | onUiUpdate(function(m){ 125 | m.forEach(function(mutation){ 126 | mutation.addedNodes.forEach(function(node){ 127 | processNode(node) 128 | }) 129 | }); 130 | }) 131 | 132 | 133 | document.addEventListener("DOMContentLoaded", function() { 134 | processNode(gradioApp()) 135 | 136 | if (localization.rtl) { // if the language is from right to left, 137 | (new MutationObserver((mutations, observer) => { // wait for the style to load 138 | mutations.forEach(mutation => { 139 | mutation.addedNodes.forEach(node => { 140 | if (node.tagName === 'STYLE') { 141 | observer.disconnect(); 142 | 143 | for (const x of node.sheet.rules) { // find all rtl media rules 144 | if (Array.from(x.media || []).includes('rtl')) { 145 | x.media.appendMedium('all'); // enable them 146 | } 147 | } 148 | } 149 | }) 150 | }); 151 | })).observe(gradioApp(), { childList: true }); 152 | } 153 | }) 154 | 155 | function download_localization() { 156 | text = JSON.stringify(dumpTranslations(), null, 4) 157 | 158 | var element = document.createElement('a'); 159 | element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); 160 | element.setAttribute('download', "localization.json"); 161 | element.style.display = 'none'; 162 | document.body.appendChild(element); 163 | 164 | element.click(); 165 | 166 | document.body.removeChild(element); 167 | } 168 | -------------------------------------------------------------------------------- /scripts/prompts_from_file.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import random 5 | import sys 6 | import traceback 7 | import shlex 8 | 9 | import modules.scripts as scripts 10 | import gradio as gr 11 | 12 | from modules.processing import Processed, process_images 13 | from PIL import Image 14 | from modules.shared import opts, cmd_opts, state 15 | 16 | 17 | def process_string_tag(tag): 18 | return tag 19 | 20 | 21 | def process_int_tag(tag): 22 | return int(tag) 23 | 24 | 25 | def process_float_tag(tag): 26 | return float(tag) 27 | 28 | 29 | def process_boolean_tag(tag): 30 | return True if (tag == "true") else False 31 | 32 | 33 | prompt_tags = { 34 | "sd_model": None, 35 | "outpath_samples": process_string_tag, 36 | "outpath_grids": process_string_tag, 37 | "prompt_for_display": process_string_tag, 38 | "prompt": process_string_tag, 39 | "negative_prompt": process_string_tag, 40 | "styles": process_string_tag, 41 | "seed": process_int_tag, 42 | "subseed_strength": process_float_tag, 43 | "subseed": process_int_tag, 44 | 
"seed_resize_from_h": process_int_tag, 45 | "seed_resize_from_w": process_int_tag, 46 | "sampler_index": process_int_tag, 47 | "batch_size": process_int_tag, 48 | "n_iter": process_int_tag, 49 | "steps": process_int_tag, 50 | "cfg_scale": process_float_tag, 51 | "width": process_int_tag, 52 | "height": process_int_tag, 53 | "restore_faces": process_boolean_tag, 54 | "tiling": process_boolean_tag, 55 | "do_not_save_samples": process_boolean_tag, 56 | "do_not_save_grid": process_boolean_tag 57 | } 58 | 59 | 60 | def cmdargs(line): 61 | args = shlex.split(line) 62 | pos = 0 63 | res = {} 64 | 65 | while pos < len(args): 66 | arg = args[pos] 67 | 68 | assert arg.startswith("--"), f'must start with "--": {arg}' 69 | tag = arg[2:] 70 | 71 | func = prompt_tags.get(tag, None) 72 | assert func, f'unknown commandline option: {arg}' 73 | 74 | assert pos+1 < len(args), f'missing argument for command line option {arg}' 75 | 76 | val = args[pos+1] 77 | 78 | res[tag] = func(val) 79 | 80 | pos += 2 81 | 82 | return res 83 | 84 | 85 | def load_prompt_file(file): 86 | if file is None: 87 | lines = [] 88 | else: 89 | lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] 90 | 91 | return None, "\n".join(lines), gr.update(lines=7) 92 | 93 | 94 | class Script(scripts.Script): 95 | def title(self): 96 | return "Prompts from file or textbox" 97 | 98 | def ui(self, is_img2img): 99 | checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) 100 | checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False) 101 | 102 | prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) 103 | file = gr.File(label="Upload prompt inputs", type='bytes') 104 | 105 | file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) 106 | 107 | # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. 108 | # We don't shrink back to 1, because that causes the control to ignore [enter], and it may 109 | # be unclear to the user that shift-enter is needed. 
110 | prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) 111 | return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] 112 | 113 | def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): 114 | lines = [x.strip() for x in prompt_txt.splitlines()] 115 | lines = [x for x in lines if len(x) > 0] 116 | 117 | p.do_not_save_grid = True 118 | 119 | job_count = 0 120 | jobs = [] 121 | 122 | for line in lines: 123 | if "--" in line: 124 | try: 125 | args = cmdargs(line) 126 | except Exception: 127 | print(f"Error parsing line {line} as commandline:", file=sys.stderr) 128 | print(traceback.format_exc(), file=sys.stderr) 129 | args = {"prompt": line} 130 | else: 131 | args = {"prompt": line} 132 | 133 | n_iter = args.get("n_iter", 1) 134 | if n_iter != 1: 135 | job_count += n_iter 136 | else: 137 | job_count += 1 138 | 139 | jobs.append(args) 140 | 141 | print(f"Will process {len(lines)} lines in {job_count} jobs.") 142 | if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1: 143 | p.seed = int(random.randrange(4294967294)) 144 | 145 | state.job_count = job_count 146 | 147 | images = [] 148 | for n, args in enumerate(jobs): 149 | state.job = f"{state.job_no + 1} out of {state.job_count}" 150 | 151 | copy_p = copy.copy(p) 152 | for k, v in args.items(): 153 | setattr(copy_p, k, v) 154 | 155 | proc = process_images(copy_p) 156 | images += proc.images 157 | 158 | if checkbox_iterate: 159 | p.seed = p.seed + (p.batch_size * p.n_iter) 160 | 161 | return Processed(p, images, p.seed, "") 162 | -------------------------------------------------------------------------------- /modules/realesrgan_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from realesrgan import RealESRGANer 9 | 10 | from modules.upscaler import Upscaler, UpscalerData 11 | from modules.shared import cmd_opts, opts 12 | 13 | 14 | class UpscalerRealESRGAN(Upscaler): 15 | def __init__(self, path): 16 | self.name = "RealESRGAN" 17 | self.user_path = path 18 | super().__init__() 19 | try: 20 | from basicsr.archs.rrdbnet_arch import RRDBNet 21 | from realesrgan import RealESRGANer 22 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 23 | self.enable = True 24 | self.scalers = [] 25 | scalers = self.load_models(path) 26 | for scaler in scalers: 27 | if scaler.name in opts.realesrgan_enabled_models: 28 | self.scalers.append(scaler) 29 | 30 | except Exception: 31 | print("Error importing Real-ESRGAN:", file=sys.stderr) 32 | print(traceback.format_exc(), file=sys.stderr) 33 | self.enable = False 34 | self.scalers = [] 35 | 36 | def do_upscale(self, img, path): 37 | if not self.enable: 38 | return img 39 | 40 | info = self.load_model(path) 41 | if not os.path.exists(info.data_path): 42 | print("Unable to load RealESRGAN model: %s" % info.name) 43 | return img 44 | 45 | upsampler = RealESRGANer( 46 | scale=info.scale, 47 | model_path=info.data_path, 48 | model=info.model(), 49 | half=not cmd_opts.no_half, 50 | tile=opts.ESRGAN_tile, 51 | tile_pad=opts.ESRGAN_tile_overlap, 52 | ) 53 | 54 | upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] 55 | 56 | image = Image.fromarray(upsampled) 57 | return image 58 | 59 | def load_model(self, path): 60 | try: 61 | info = None 62 | for scaler in self.scalers: 63 | 
if scaler.data_path == path: 64 | info = scaler 65 | 66 | if info is None: 67 | print(f"Unable to find model info: {path}") 68 | return None 69 | 70 | model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) 71 | info.data_path = model_file 72 | return info 73 | except Exception as e: 74 | print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) 75 | print(traceback.format_exc(), file=sys.stderr) 76 | return None 77 | 78 | def load_models(self, _): 79 | return get_realesrgan_models(self) 80 | 81 | 82 | def get_realesrgan_models(scaler): 83 | try: 84 | from basicsr.archs.rrdbnet_arch import RRDBNet 85 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 86 | models = [ 87 | UpscalerData( 88 | name="R-ESRGAN General 4xV3", 89 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 90 | scale=4, 91 | upscaler=scaler, 92 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 93 | ), 94 | UpscalerData( 95 | name="R-ESRGAN General WDN 4xV3", 96 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", 97 | scale=4, 98 | upscaler=scaler, 99 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') 100 | ), 101 | UpscalerData( 102 | name="R-ESRGAN AnimeVideo", 103 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", 104 | scale=4, 105 | upscaler=scaler, 106 | model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') 107 | ), 108 | UpscalerData( 109 | name="R-ESRGAN 4x+", 110 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 111 | scale=4, 112 | upscaler=scaler, 113 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) 114 | ), 115 | UpscalerData( 116 | name="R-ESRGAN 4x+ Anime6B", 117 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 118 | scale=4, 119 | upscaler=scaler, 120 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 121 | ), 122 | UpscalerData( 123 | name="R-ESRGAN 2x+", 124 | path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", 125 | scale=2, 126 | upscaler=scaler, 127 | model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 128 | ), 129 | ] 130 | return models 131 | except Exception as e: 132 | print("Error making Real-ESRGAN models list:", file=sys.stderr) 133 | print(traceback.format_exc(), file=sys.stderr) 134 | -------------------------------------------------------------------------------- /modules/safe.py: -------------------------------------------------------------------------------- 1 | # this code is adapted from the script contributed by anon from /h/ 2 | 3 | import io 4 | import pickle 5 | import collections 6 | import sys 7 | import traceback 8 | 9 | import torch 10 | import numpy 11 | import _codecs 12 | import zipfile 13 | import re 14 | 15 | 16 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage 17 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage 18 | 19 | 20 | def encode(*args): 21 | out = _codecs.encode(*args) 22 | return out 23 | 24 | 25 | 
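# The unpickler below swaps pickle's default global lookup for an explicit allow-list: only the
# classes needed to rebuild a torch checkpoint (storage/tensor helpers, OrderedDict, a few numpy
# and pytorch_lightning entries) resolve; e.g. find_class('collections', 'OrderedDict') returns
# collections.OrderedDict, while find_class('os', 'system') raises "global 'os/system' is forbidden".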
class RestrictedUnpickler(pickle.Unpickler): 26 | extra_handler = None 27 | 28 | def persistent_load(self, saved_id): 29 | assert saved_id[0] == 'storage' 30 | return TypedStorage() 31 | 32 | def find_class(self, module, name): 33 | if self.extra_handler is not None: 34 | res = self.extra_handler(module, name) 35 | if res is not None: 36 | return res 37 | 38 | if module == 'collections' and name == 'OrderedDict': 39 | return getattr(collections, name) 40 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: 41 | return getattr(torch._utils, name) 42 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']: 43 | return getattr(torch, name) 44 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']: 45 | return getattr(torch.nn.modules.container, name) 46 | if module == 'numpy.core.multiarray' and name == 'scalar': 47 | return numpy.core.multiarray.scalar 48 | if module == 'numpy' and name == 'dtype': 49 | return numpy.dtype 50 | if module == '_codecs' and name == 'encode': 51 | return encode 52 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': 53 | import pytorch_lightning.callbacks 54 | return pytorch_lightning.callbacks.model_checkpoint 55 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': 56 | import pytorch_lightning.callbacks.model_checkpoint 57 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 58 | if module == "__builtin__" and name == 'set': 59 | return set 60 | 61 | # Forbid everything else. 62 | raise Exception(f"global '{module}/{name}' is forbidden") 63 | 64 | 65 | allowed_zip_names = ["archive/data.pkl", "archive/version"] 66 | allowed_zip_names_re = re.compile(r"^archive/data/\d+$") 67 | 68 | 69 | def check_zip_filenames(filename, names): 70 | for name in names: 71 | if name in allowed_zip_names: 72 | continue 73 | if allowed_zip_names_re.match(name): 74 | continue 75 | 76 | raise Exception(f"bad file inside {filename}: {name}") 77 | 78 | 79 | def check_pt(filename, extra_handler): 80 | try: 81 | 82 | # new pytorch format is a zip file 83 | with zipfile.ZipFile(filename) as z: 84 | check_zip_filenames(filename, z.namelist()) 85 | 86 | with z.open('archive/data.pkl') as file: 87 | unpickler = RestrictedUnpickler(file) 88 | unpickler.extra_handler = extra_handler 89 | unpickler.load() 90 | 91 | except zipfile.BadZipfile: 92 | 93 | # if it's not a zip file, it's an old pytorch format, with five objects written to pickle 94 | with open(filename, "rb") as file: 95 | unpickler = RestrictedUnpickler(file) 96 | unpickler.extra_handler = extra_handler 97 | for i in range(5): 98 | unpickler.load() 99 | 100 | 101 | def load(filename, *args, **kwargs): 102 | return load_with_extra(filename, *args, **kwargs) 103 | 104 | 105 | def load_with_extra(filename, extra_handler=None, *args, **kwargs): 106 | """ 107 | this function is intended to be used by extensions that want to load models with 108 | some extra classes in them that the usual unpickler would find suspicious. 
109 | 110 | Use the extra_handler argument to specify a function that takes module and field name as text, 111 | and returns that field's value: 112 | 113 | ```python 114 | def extra(module, name): 115 | if module == 'collections' and name == 'OrderedDict': 116 | return collections.OrderedDict 117 | 118 | return None 119 | 120 | safe.load_with_extra('model.pt', extra_handler=extra) 121 | ``` 122 | 123 | The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is 124 | definitely unsafe. 125 | """ 126 | 127 | from modules import shared 128 | 129 | try: 130 | if not shared.cmd_opts.disable_safe_unpickle: 131 | check_pt(filename, extra_handler) 132 | 133 | except pickle.UnpicklingError: 134 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 135 | print(traceback.format_exc(), file=sys.stderr) 136 | print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) 137 | print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) 138 | return None 139 | 140 | except Exception: 141 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 142 | print(traceback.format_exc(), file=sys.stderr) 143 | print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) 144 | print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) 145 | return None 146 | 147 | return unsafe_torch_load(filename, *args, **kwargs) 148 | 149 | 150 | unsafe_torch_load = torch.load 151 | torch.load = load 152 | -------------------------------------------------------------------------------- /modules/img2img.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | import traceback 5 | 6 | import numpy as np 7 | from PIL import Image, ImageOps, ImageChops 8 | 9 | from modules import devices 10 | from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images 11 | from modules.shared import opts, state 12 | import modules.shared as shared 13 | import modules.processing as processing 14 | from modules.ui import plaintext_to_html 15 | import modules.images as images 16 | import modules.scripts 17 | 18 | 19 | def process_batch(p, input_dir, output_dir, args): 20 | processing.fix_seed(p) 21 | 22 | images = shared.listfiles(input_dir) 23 | 24 | print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") 25 | 26 | save_normally = output_dir == '' 27 | 28 | p.do_not_save_grid = True 29 | p.do_not_save_samples = not save_normally 30 | 31 | state.job_count = len(images) * p.n_iter 32 | 33 | for i, image in enumerate(images): 34 | state.job = f"{i+1} out of {len(images)}" 35 | if state.skipped: 36 | state.skipped = False 37 | 38 | if state.interrupted: 39 | break 40 | 41 | img = Image.open(image) 42 | # Use the EXIF orientation of photos taken by smartphones. 
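# (ImageOps.exif_transpose bakes the EXIF Orientation tag into the pixel data, so rotated phone photos come in upright)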
43 | img = ImageOps.exif_transpose(img) 44 | p.init_images = [img] * p.batch_size 45 | 46 | proc = modules.scripts.scripts_img2img.run(p, *args) 47 | if proc is None: 48 | proc = process_images(p) 49 | 50 | for n, processed_image in enumerate(proc.images): 51 | filename = os.path.basename(image) 52 | 53 | if n > 0: 54 | left, right = os.path.splitext(filename) 55 | filename = f"{left}-{n}{right}" 56 | 57 | if not save_normally: 58 | os.makedirs(output_dir, exist_ok=True) 59 | processed_image.save(os.path.join(output_dir, filename)) 60 | 61 | 62 | def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): 63 | is_inpaint = mode == 1 64 | is_batch = mode == 2 65 | 66 | if is_inpaint: 67 | # Drawn mask 68 | if mask_mode == 0: 69 | image = init_img_with_mask['image'] 70 | mask = init_img_with_mask['mask'] 71 | alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') 72 | mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') 73 | image = image.convert('RGB') 74 | # Uploaded mask 75 | else: 76 | image = init_img_inpaint 77 | mask = init_mask_inpaint 78 | # No mask 79 | else: 80 | image = init_img 81 | mask = None 82 | 83 | # Use the EXIF orientation of photos taken by smartphones. 84 | if image is not None: 85 | image = ImageOps.exif_transpose(image) 86 | 87 | assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' 88 | 89 | p = StableDiffusionProcessingImg2Img( 90 | sd_model=shared.sd_model, 91 | outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, 92 | outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, 93 | prompt=prompt, 94 | negative_prompt=negative_prompt, 95 | styles=[prompt_style, prompt_style2], 96 | seed=seed, 97 | subseed=subseed, 98 | subseed_strength=subseed_strength, 99 | seed_resize_from_h=seed_resize_from_h, 100 | seed_resize_from_w=seed_resize_from_w, 101 | seed_enable_extras=seed_enable_extras, 102 | sampler_index=sampler_index, 103 | batch_size=batch_size, 104 | n_iter=n_iter, 105 | steps=steps, 106 | cfg_scale=cfg_scale, 107 | width=width, 108 | height=height, 109 | restore_faces=restore_faces, 110 | tiling=tiling, 111 | init_images=[image], 112 | mask=mask, 113 | mask_blur=mask_blur, 114 | inpainting_fill=inpainting_fill, 115 | resize_mode=resize_mode, 116 | denoising_strength=denoising_strength, 117 | inpaint_full_res=inpaint_full_res, 118 | inpaint_full_res_padding=inpaint_full_res_padding, 119 | inpainting_mask_invert=inpainting_mask_invert, 120 | ) 121 | 122 | p.scripts = modules.scripts.scripts_img2img 123 | p.script_args = args 124 | 125 | if shared.cmd_opts.enable_console_prompts: 126 | print(f"\nimg2img: {prompt}", file=shared.progress_print_out) 127 | 128 | p.extra_generation_params["Mask blur"] = mask_blur 129 | 130 | if is_batch: 131 | assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" 132 | 133 | process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) 134 | 135 | processed = Processed(p, [], p.seed, "") 136 | else: 137 | processed = modules.scripts.scripts_img2img.run(p, *args) 138 | if processed is None: 139 | processed = process_images(p) 140 | 141 | p.close() 142 | 143 | shared.total_tqdm.clear() 144 | 145 | generation_info_js = processed.js() 146 | if opts.samples_log_stdout: 147 | print(generation_info_js) 148 | 149 | if opts.do_not_show_images: 150 | processed.images = [] 151 | 152 | return processed.images, generation_info_js, plaintext_to_html(processed.info) 153 | -------------------------------------------------------------------------------- /scripts/poor_mans_outpainting.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import modules.scripts as scripts 4 | import gradio as gr 5 | from PIL import Image, ImageDraw 6 | 7 | from modules import images, processing, devices 8 | from modules.processing import Processed, process_images 9 | from modules.shared import opts, cmd_opts, state 10 | 11 | 12 | 13 | class Script(scripts.Script): 14 | def title(self): 15 | return "Poor man's outpainting" 16 | 17 | def show(self, is_img2img): 18 | return is_img2img 19 | 20 | def ui(self, is_img2img): 21 | if not is_img2img: 22 | return None 23 | 24 | pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) 25 | mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) 26 | inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") 27 | direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) 28 | 29 | return [pixels, mask_blur, inpainting_fill, direction] 30 | 31 | def run(self, p, pixels, mask_blur, inpainting_fill, direction):
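# Rough flow: pad the source image onto a /64-aligned canvas grown by `pixels` in the chosen directions,
# build masks that are white only over the newly added border, split the canvas and masks into overlapping
# tiles, then inpaint only the tiles that touch that border before stitching the grid back together.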
32 | initial_seed = None 33 | initial_info = None 34 | 35 | p.mask_blur = mask_blur * 2 36 | p.inpainting_fill = inpainting_fill 37 | p.inpaint_full_res = False 38 | 39 | left = pixels if "left" in direction else 0 40 | right = pixels if "right" in direction else 0 41 | up = pixels if "up" in direction else 0 42 | down = pixels if "down" in direction else 0 43 | 44 | init_img = p.init_images[0] 45 | target_w = math.ceil((init_img.width + left + right) / 64) * 64 46 | target_h = math.ceil((init_img.height + up + down) / 64) * 64 47 | 48 | if left > 0: 49 | left = left * (target_w - init_img.width) // (left + right) 50 | if right > 0: 51 | right = target_w - init_img.width - left 52 | 53 | if up > 0: 54 | up = up * (target_h - init_img.height) // (up + down) 55 | 56 | if down > 0: 57 | down = target_h - init_img.height - up 58 | 59 | img = Image.new("RGB", (target_w, target_h)) 60 | img.paste(init_img, (left, up)) 61 | 62 | mask = Image.new("L", (img.width, img.height), "white") 63 | draw = ImageDraw.Draw(mask) 64 | draw.rectangle(( 65 | left + (mask_blur * 2 if left > 0 else 0), 66 | up + (mask_blur * 2 if up > 0 else 0), 67 | mask.width - right - (mask_blur * 2 if right > 0 else 0), 68 | mask.height - down - (mask_blur * 2 if down > 0 else 0) 69 | ), fill="black") 70 | 71 | latent_mask = Image.new("L", (img.width, img.height), "white") 72 | latent_draw = ImageDraw.Draw(latent_mask) 73 | latent_draw.rectangle(( 74 | left + (mask_blur//2 if left > 0 else 0), 75 | up + (mask_blur//2 if up > 0 else 0), 76 | mask.width - right - (mask_blur//2 if right > 0 else 0), 77 | mask.height - down - (mask_blur//2 if down > 0 else 0) 78 | ), fill="black") 79 | 80 | devices.torch_gc() 81 | 82 | grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) 83 | grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 84 | grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) 85 | 86 | p.n_iter = 1 87 | p.batch_size = 1 88 | p.do_not_save_grid = True 89 | p.do_not_save_samples = True 90 | 91 | work = [] 92 | work_mask = [] 93 | work_latent_mask = [] 94 | work_results = [] 95 | 96 | for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): 97 | for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): 98 | x, w = tiledata[0:2] 99 | 100 | if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 101 | continue 102 | 103 | work.append(tiledata[2]) 104 | work_mask.append(tiledata_mask[2]) 105 | work_latent_mask.append(tiledata_latent_mask[2]) 106 | 107 | batch_count = len(work) 108 | print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") 109 | 110 | state.job_count = batch_count 111 | 112 | for i in range(batch_count): 113 | p.init_images = [work[i]] 114 | p.image_mask = work_mask[i] 115 | p.latent_mask = work_latent_mask[i] 116 | 117 | state.job = f"Batch {i + 1} out of {batch_count}" 118 | processed = process_images(p) 119 | 120 | if initial_seed is None: 121 | initial_seed = processed.seed 122 | initial_info = processed.info 123 | 124 | p.seed = processed.seed + 1 125 | work_results += processed.images 126 | 127 | 128 | image_index = 0 129 | for y, h, row in grid.tiles: 130 | for tiledata in row: 131 | x, w = tiledata[0:2] 132 | 133 | if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: 134 | continue 
135 | 136 | tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) 137 | image_index += 1 138 | 139 | combined_image = images.combine_grid(grid) 140 | 141 | if opts.samples_save: 142 | images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p) 143 | 144 | processed = Processed(p, [combined_image], initial_seed, initial_info) 145 | 146 | return processed 147 | 148 | -------------------------------------------------------------------------------- /javascript/contextMenus.js: -------------------------------------------------------------------------------- 1 | 2 | contextMenuInit = function(){ 3 | let eventListenerApplied=false; 4 | let menuSpecs = new Map(); 5 | 6 | const uid = function(){ 7 | return Date.now().toString(36) + Math.random().toString(36).substr(2); 8 | } 9 | 10 | function showContextMenu(event,element,menuEntries){ 11 | let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; 12 | let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; 13 | 14 | let oldMenu = gradioApp().querySelector('#context-menu') 15 | if(oldMenu){ 16 | oldMenu.remove() 17 | } 18 | 19 | let tabButton = uiCurrentTab 20 | let baseStyle = window.getComputedStyle(tabButton) 21 | 22 | const contextMenu = document.createElement('nav') 23 | contextMenu.id = "context-menu" 24 | contextMenu.style.background = baseStyle.background 25 | contextMenu.style.color = baseStyle.color 26 | contextMenu.style.fontFamily = baseStyle.fontFamily 27 | contextMenu.style.top = posy+'px' 28 | contextMenu.style.left = posx+'px' 29 | 30 | 31 | 32 | const contextMenuList = document.createElement('ul') 33 | contextMenuList.className = 'context-menu-items'; 34 | contextMenu.append(contextMenuList); 35 | 36 | menuEntries.forEach(function(entry){ 37 | let contextMenuEntry = document.createElement('a') 38 | contextMenuEntry.innerHTML = entry['name'] 39 | contextMenuEntry.addEventListener("click", function(e) { 40 | entry['func'](); 41 | }) 42 | contextMenuList.append(contextMenuEntry); 43 | 44 | }) 45 | 46 | gradioApp().getRootNode().appendChild(contextMenu) 47 | 48 | let menuWidth = contextMenu.offsetWidth + 4; 49 | let menuHeight = contextMenu.offsetHeight + 4; 50 | 51 | let windowWidth = window.innerWidth; 52 | let windowHeight = window.innerHeight; 53 | 54 | if ( (windowWidth - posx) < menuWidth ) { 55 | contextMenu.style.left = windowWidth - menuWidth + "px"; 56 | } 57 | 58 | if ( (windowHeight - posy) < menuHeight ) { 59 | contextMenu.style.top = windowHeight - menuHeight + "px"; 60 | } 61 | 62 | } 63 | 64 | function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ 65 | 66 | currentItems = menuSpecs.get(targetEmementSelector) 67 | 68 | if(!currentItems){ 69 | currentItems = [] 70 | menuSpecs.set(targetEmementSelector,currentItems); 71 | } 72 | let newItem = {'id':targetEmementSelector+'_'+uid(), 73 | 'name':entryName, 74 | 'func':entryFunction, 75 | 'isNew':true} 76 | 77 | currentItems.push(newItem) 78 | return newItem['id'] 79 | } 80 | 81 | function removeContextMenuOption(uid){ 82 | menuSpecs.forEach(function(v,k) { 83 | let index = -1 84 | v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) 85 | if(index>=0){ 86 | v.splice(index, 1); 87 | } 88 | }) 89 | } 90 | 91 | function addContextMenuEventListener(){ 92 | if(eventListenerApplied){ 93 | return; 94 | } 95 | gradioApp().addEventListener("click", function(e) { 96 | 
let source = e.composedPath()[0] 97 | if(source.id && source.id.indexOf('check_progress')>-1){ 98 | return 99 | } 100 | 101 | let oldMenu = gradioApp().querySelector('#context-menu') 102 | if(oldMenu){ 103 | oldMenu.remove() 104 | } 105 | }); 106 | gradioApp().addEventListener("contextmenu", function(e) { 107 | let oldMenu = gradioApp().querySelector('#context-menu') 108 | if(oldMenu){ 109 | oldMenu.remove() 110 | } 111 | menuSpecs.forEach(function(v,k) { 112 | if(e.composedPath()[0].matches(k)){ 113 | showContextMenu(e,e.composedPath()[0],v) 114 | e.preventDefault() 115 | return 116 | } 117 | }) 118 | }); 119 | eventListenerApplied=true 120 | 121 | } 122 | 123 | return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] 124 | } 125 | 126 | initResponse = contextMenuInit(); 127 | appendContextMenuOption = initResponse[0]; 128 | removeContextMenuOption = initResponse[1]; 129 | addContextMenuEventListener = initResponse[2]; 130 | 131 | (function(){ 132 | //Start example Context Menu Items 133 | let generateOnRepeat = function(genbuttonid,interruptbuttonid){ 134 | let genbutton = gradioApp().querySelector(genbuttonid); 135 | let interruptbutton = gradioApp().querySelector(interruptbuttonid); 136 | if(!interruptbutton.offsetParent){ 137 | genbutton.click(); 138 | } 139 | clearInterval(window.generateOnRepeatInterval) 140 | window.generateOnRepeatInterval = setInterval(function(){ 141 | if(!interruptbutton.offsetParent){ 142 | genbutton.click(); 143 | } 144 | }, 145 | 500) 146 | } 147 | 148 | appendContextMenuOption('#txt2img_generate','Generate forever',function(){ 149 | generateOnRepeat('#txt2img_generate','#txt2img_interrupt'); 150 | }) 151 | appendContextMenuOption('#img2img_generate','Generate forever',function(){ 152 | generateOnRepeat('#img2img_generate','#img2img_interrupt'); 153 | }) 154 | 155 | let cancelGenerateForever = function(){ 156 | clearInterval(window.generateOnRepeatInterval) 157 | } 158 | 159 | appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) 160 | appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) 161 | appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever) 162 | appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever) 163 | 164 | appendContextMenuOption('#roll','Roll three', 165 | function(){ 166 | let rollbutton = get_uiCurrentTabContent().querySelector('#roll'); 167 | setTimeout(function(){rollbutton.click()},100) 168 | setTimeout(function(){rollbutton.click()},200) 169 | setTimeout(function(){rollbutton.click()},300) 170 | } 171 | ) 172 | })(); 173 | //End example Context Menu Items 174 | 175 | onUiUpdate(function(){ 176 | addContextMenuEventListener() 177 | }); 178 | -------------------------------------------------------------------------------- /modules/swinir_model.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from basicsr.utils.download_util import load_file_from_url 8 | from tqdm import tqdm 9 | 10 | from modules import modelloader, devices 11 | from modules.shared import cmd_opts, opts 12 | from modules.swinir_model_arch import SwinIR as net 13 | from modules.swinir_model_arch_v2 import Swin2SR as net2 14 | from modules.upscaler import Upscaler, UpscalerData 15 | 16 | precision_scope = ( 17 | torch.autocast if 
cmd_opts.precision == "autocast" else contextlib.nullcontext 18 | ) 19 | 20 | 21 | class UpscalerSwinIR(Upscaler): 22 | def __init__(self, dirname): 23 | self.name = "SwinIR" 24 | self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ 25 | "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ 26 | "-L_x4_GAN.pth " 27 | self.model_name = "SwinIR 4x" 28 | self.user_path = dirname 29 | super().__init__() 30 | scalers = [] 31 | model_files = self.find_models(ext_filter=[".pt", ".pth"]) 32 | for model in model_files: 33 | if "http" in model: 34 | name = self.model_name 35 | else: 36 | name = modelloader.friendly_name(model) 37 | model_data = UpscalerData(name, model, self) 38 | scalers.append(model_data) 39 | self.scalers = scalers 40 | 41 | def do_upscale(self, img, model_file): 42 | model = self.load_model(model_file) 43 | if model is None: 44 | return img 45 | model = model.to(devices.device_swinir) 46 | img = upscale(img, model) 47 | try: 48 | torch.cuda.empty_cache() 49 | except: 50 | pass 51 | return img 52 | 53 | def load_model(self, path, scale=4): 54 | if "http" in path: 55 | dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") 56 | filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) 57 | else: 58 | filename = path 59 | if filename is None or not os.path.exists(filename): 60 | return None 61 | if filename.endswith(".v2.pth"): 62 | model = net2( 63 | upscale=scale, 64 | in_chans=3, 65 | img_size=64, 66 | window_size=8, 67 | img_range=1.0, 68 | depths=[6, 6, 6, 6, 6, 6], 69 | embed_dim=180, 70 | num_heads=[6, 6, 6, 6, 6, 6], 71 | mlp_ratio=2, 72 | upsampler="nearest+conv", 73 | resi_connection="1conv", 74 | ) 75 | params = None 76 | else: 77 | model = net( 78 | upscale=scale, 79 | in_chans=3, 80 | img_size=64, 81 | window_size=8, 82 | img_range=1.0, 83 | depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], 84 | embed_dim=240, 85 | num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], 86 | mlp_ratio=2, 87 | upsampler="nearest+conv", 88 | resi_connection="3conv", 89 | ) 90 | params = "params_ema" 91 | 92 | pretrained_model = torch.load(filename) 93 | if params is not None: 94 | model.load_state_dict(pretrained_model[params], strict=True) 95 | else: 96 | model.load_state_dict(pretrained_model, strict=True) 97 | if not cmd_opts.no_half: 98 | model = model.half() 99 | return model 100 | 101 | 102 | def upscale( 103 | img, 104 | model, 105 | tile=opts.SWIN_tile, 106 | tile_overlap=opts.SWIN_tile_overlap, 107 | window_size=8, 108 | scale=4, 109 | ): 110 | img = np.array(img) 111 | img = img[:, :, ::-1] 112 | img = np.moveaxis(img, 2, 0) / 255 113 | img = torch.from_numpy(img).float() 114 | img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) 115 | with torch.no_grad(), precision_scope("cuda"): 116 | _, _, h_old, w_old = img.size() 117 | h_pad = (h_old // window_size + 1) * window_size - h_old 118 | w_pad = (w_old // window_size + 1) * window_size - w_old 119 | img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] 120 | img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] 121 | output = inference(img, model, tile, tile_overlap, window_size, scale) 122 | output = output[..., : h_old * scale, : w_old * scale] 123 | output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() 124 | if output.ndim == 3: 125 | output = np.transpose( 126 | output[[2, 1, 0], :, :], (1, 2, 0) 127 | ) # CHW-RGB to HCW-BGR 128 | output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 129 | return 
Image.fromarray(output, "RGB") 130 | 131 | 132 | def inference(img, model, tile, tile_overlap, window_size, scale): 133 | # test the image tile by tile 134 | b, c, h, w = img.size() 135 | tile = min(tile, h, w) 136 | assert tile % window_size == 0, "tile size should be a multiple of window_size" 137 | sf = scale 138 | 139 | stride = tile - tile_overlap 140 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile] 141 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile] 142 | E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) 143 | W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) 144 | 145 | with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: 146 | for h_idx in h_idx_list: 147 | for w_idx in w_idx_list: 148 | in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] 149 | out_patch = model(in_patch) 150 | out_patch_mask = torch.ones_like(out_patch) 151 | 152 | E[ 153 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 154 | ].add_(out_patch) 155 | W[ 156 | ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf 157 | ].add_(out_patch_mask) 158 | pbar.update(1) 159 | output = E.div_(W) 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /modules/codeformer_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import traceback 4 | 5 | import cv2 6 | import torch 7 | 8 | import modules.face_restoration 9 | import modules.shared 10 | from modules import shared, devices, modelloader 11 | from modules.paths import script_path, models_path 12 | 13 | # codeformer people made a choice to include modified basicsr library to their project which makes 14 | # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. 15 | # I am making a choice to include some files from codeformer to work around this issue. 
16 | model_dir = "Codeformer"
17 | model_path = os.path.join(models_path, model_dir)
18 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19 |
20 | have_codeformer = False
21 | codeformer = None
22 |
23 |
24 | def setup_model(dirname):
25 |     global model_path
26 |     if not os.path.exists(model_path):
27 |         os.makedirs(model_path)
28 |
29 |     path = modules.paths.paths.get("CodeFormer", None)
30 |     if path is None:
31 |         return
32 |
33 |     try:
34 |         from torchvision.transforms.functional import normalize
35 |         from modules.codeformer.codeformer_arch import CodeFormer
36 |         from basicsr.utils.download_util import load_file_from_url
37 |         from basicsr.utils import imwrite, img2tensor, tensor2img
38 |         from facelib.utils.face_restoration_helper import FaceRestoreHelper
39 |         from modules.shared import cmd_opts
40 |
41 |         net_class = CodeFormer
42 |
43 |         class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
44 |             def name(self):
45 |                 return "CodeFormer"
46 |
47 |             def __init__(self, dirname):
48 |                 self.net = None
49 |                 self.face_helper = None
50 |                 self.cmd_dir = dirname
51 |
52 |             def create_models(self):
53 |
54 |                 if self.net is not None and self.face_helper is not None:
55 |                     self.net.to(devices.device_codeformer)
56 |                     return self.net, self.face_helper
57 |                 model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
58 |                 if len(model_paths) != 0:
59 |                     ckpt_path = model_paths[0]
60 |                 else:
61 |                     print("Unable to load codeformer model.")
62 |                     return None, None
63 |                 net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
64 |                 checkpoint = torch.load(ckpt_path)['params_ema']
65 |                 net.load_state_dict(checkpoint)
66 |                 net.eval()
67 |
68 |                 face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
69 |
70 |                 self.net = net
71 |                 self.face_helper = face_helper
72 |
73 |                 return net, face_helper
74 |
75 |             def send_model_to(self, device):
76 |                 self.net.to(device)
77 |                 self.face_helper.face_det.to(device)
78 |                 self.face_helper.face_parse.to(device)
79 |
80 |             def restore(self, np_image, w=None):
81 |                 np_image = np_image[:, :, ::-1]
82 |
83 |                 original_resolution = np_image.shape[0:2]
84 |
85 |                 self.create_models()
86 |                 if self.net is None or self.face_helper is None:
87 |                     return np_image
88 |
89 |                 self.send_model_to(devices.device_codeformer)
90 |
91 |                 self.face_helper.clean_all()
92 |                 self.face_helper.read_image(np_image)
93 |                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
94 |                 self.face_helper.align_warp_face()
95 |
96 |                 for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
97 |                     cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
98 |                     normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
99 |                     cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
100 |
101 |                     try:
102 |                         with torch.no_grad():
103 |                             output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
104 |                             restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
105 |                             del output
106 |                             torch.cuda.empty_cache()
107 |                     except Exception as error:
108 |                         print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
109 |                         restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
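                        # If inference fails, the loop falls back to the normalized crop above, so a
                        # single problematic face degrades gracefully instead of aborting the whole restore.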
110 |
111 |                     restored_face = restored_face.astype('uint8')
112 |                     self.face_helper.add_restored_face(restored_face)
113 |
114 |                 self.face_helper.get_inverse_affine(None)
115 |
116 |                 restored_img = self.face_helper.paste_faces_to_input_image()
117 |                 restored_img = restored_img[:, :, ::-1]
118 |
119 |                 if original_resolution != restored_img.shape[0:2]:
120 |                     restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
121 |
122 |                 self.face_helper.clean_all()
123 |
124 |                 if shared.opts.face_restoration_unload:
125 |                     self.send_model_to(devices.cpu)
126 |
127 |                 return restored_img
128 |
129 |         global have_codeformer
130 |         have_codeformer = True
131 |
132 |         global codeformer
133 |         codeformer = FaceRestorerCodeFormer(dirname)
134 |         shared.face_restorers.append(codeformer)
135 |
136 |     except Exception:
137 |         print("Error setting up CodeFormer:", file=sys.stderr)
138 |         print(traceback.format_exc(), file=sys.stderr)
139 |
140 |     # sys.path = stored_sys_path
141 |
--------------------------------------------------------------------------------
/modules/modelloader.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import shutil
4 | import importlib
5 | from urllib.parse import urlparse
6 |
7 | from basicsr.utils.download_util import load_file_from_url
8 | from modules import shared
9 | from modules.upscaler import Upscaler
10 | from modules.paths import script_path, models_path
11 |
12 |
13 | def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
14 |     """
15 |     A one-and-done loader that tries to find the desired models in the specified directories.
16 |
17 |     @param download_name: Specify to download from model_url immediately.
18 |     @param model_url: If no other models are found, this will be downloaded on upscale.
19 |     @param model_path: The location to store/find models in.
20 |     @param command_path: A command-line argument to search for models in first.
21 |     @param ext_filter: An optional list of filename extensions to filter by
22 |     @return: A list of paths containing the desired model(s)
23 |     """
24 |     output = []
25 |
26 |     if ext_filter is None:
27 |         ext_filter = []
28 |
29 |     try:
30 |         places = []
31 |
32 |         if command_path is not None and command_path != model_path:
33 |             pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
34 |             if os.path.exists(pretrained_path):
35 |                 print(f"Appending path: {pretrained_path}")
36 |                 places.append(pretrained_path)
37 |         elif os.path.exists(command_path):
38 |             places.append(command_path)
39 |
40 |         places.append(model_path)
41 |
42 |     for place in places:
43 |             if os.path.exists(place):
44 |                 for file in glob.iglob(place + '**/**', recursive=True):
45 |                     full_path = file
46 |                     if os.path.isdir(full_path):
47 |                         continue
48 |                     if len(ext_filter) != 0:
49 |                         model_name, extension = os.path.splitext(file)
50 |                         if extension not in ext_filter:
51 |                             continue
52 |                     if file not in output:
53 |                         output.append(full_path)
54 |
55 |         if model_url is not None and len(output) == 0:
56 |             if download_name is not None:
57 |                 dl = load_file_from_url(model_url, model_path, True, download_name)
58 |                 output.append(dl)
59 |             else:
60 |                 output.append(model_url)
61 |
62 |     except Exception:
63 |         pass
64 |
65 |     return output
66 |
67 |
68 | def friendly_name(file: str):
69 |     if "http" in file:
70 |         file = urlparse(file).path
71 |
72 |     file = os.path.basename(file)
73 |     model_name, extension = os.path.splitext(file)
74 |     return model_name
75 |
76 |
77 | def cleanup_models():
78 |     # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
79 |     # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
80 |     # somehow auto-register and just do these things...
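    # The calls below migrate legacy locations into per-type folders under models/, e.g. (illustrative
    # paths) ./ESRGAN/4x_foo.pth -> models/ESRGAN/4x_foo.pth, and stray .ckpt files in models/ ->
    # models/Stable-diffusion/; BSRGAN models are folded into the ESRGAN folder.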
81 |     root_path = script_path
82 |     src_path = models_path
83 |     dest_path = os.path.join(models_path, "Stable-diffusion")
84 |     move_files(src_path, dest_path, ".ckpt")
85 |     src_path = os.path.join(root_path, "ESRGAN")
86 |     dest_path = os.path.join(models_path, "ESRGAN")
87 |     move_files(src_path, dest_path)
88 |     src_path = os.path.join(models_path, "BSRGAN")
89 |     dest_path = os.path.join(models_path, "ESRGAN")
90 |     move_files(src_path, dest_path, ".pth")
91 |     src_path = os.path.join(root_path, "gfpgan")
92 |     dest_path = os.path.join(models_path, "GFPGAN")
93 |     move_files(src_path, dest_path)
94 |     src_path = os.path.join(root_path, "SwinIR")
95 |     dest_path = os.path.join(models_path, "SwinIR")
96 |     move_files(src_path, dest_path)
97 |     src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
98 |     dest_path = os.path.join(models_path, "LDSR")
99 |     move_files(src_path, dest_path)
100 |
101 |
102 | def move_files(src_path: str, dest_path: str, ext_filter: str = None):
103 |     try:
104 |         if not os.path.exists(dest_path):
105 |             os.makedirs(dest_path)
106 |         if os.path.exists(src_path):
107 |             for file in os.listdir(src_path):
108 |                 fullpath = os.path.join(src_path, file)
109 |                 if os.path.isfile(fullpath):
110 |                     if ext_filter is not None:
111 |                         if ext_filter not in file:
112 |                             continue
113 |                     print(f"Moving {file} from {src_path} to {dest_path}.")
114 |                     try:
115 |                         shutil.move(fullpath, dest_path)
116 |                     except:
117 |                         pass
118 |             if len(os.listdir(src_path)) == 0:
119 |                 print(f"Removing empty folder: {src_path}")
120 |                 shutil.rmtree(src_path, True)
121 |     except:
122 |         pass
123 |
124 |
125 | def load_upscalers():
126 |     sd = shared.script_path
127 |     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
128 |     # so we'll try to import any _model.py files before looking in __subclasses__
129 |     modules_dir = os.path.join(sd, "modules")
130 |     for file in os.listdir(modules_dir):
131 |         if "_model.py" in file:
132 |             model_name = file.replace("_model.py", "")
133 |             full_model = f"modules.{model_name}_model"
134 |             try:
135 |                 importlib.import_module(full_model)
136 |             except:
137 |                 pass
138 |     datas = []
139 |     c_o = vars(shared.cmd_opts)
140 |     for cls in Upscaler.__subclasses__():
141 |         name = cls.__name__
142 |         module_name = cls.__module__
143 |         module = importlib.import_module(module_name)
144 |         class_ = getattr(module, name)
145 |         cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
146 |         opt_string = None
147 |         try:
148 |             if cmd_name in c_o:
149 |                 opt_string = c_o[cmd_name]
150 |         except:
151 |             pass
152 |         scaler = class_(opt_string)
153 |         for child in scaler.scalers:
154 |             datas.append(child)
155 |
156 |     shared.sd_upscalers = datas
157 |
--------------------------------------------------------------------------------
/javascript/progressbar.js:
--------------------------------------------------------------------------------
1 | // code related to showing and updating progressbar shown as the image is being made
2 | global_progressbars = {}
3 | galleries = {}
4 | galleryObservers = {}
5 |
6 | // this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
7 | timeoutIds = {}
8 |
9 | function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
10 |     // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with the same id
11 |     // every time you use gr.HTML(elem_id='xxx'), so we handle this here
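    // The lookup below therefore tries the nested "#id #id" selector first and falls back to a plain
    // getElementById when only a single element with that id exists (intent inferred from the code
    // rather than from gradio documentation).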
12 |     var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
13 |     var progressbarParent
14 |     if(progressbar){
15 |         progressbarParent = gradioApp().querySelector("#"+id_progressbar)
16 |     } else{
17 |         progressbar = gradioApp().getElementById(id_progressbar)
18 |         progressbarParent = null
19 |     }
20 |
21 |     var skip = id_skip ? gradioApp().getElementById(id_skip) : null
22 |     var interrupt = gradioApp().getElementById(id_interrupt)
23 |
24 |     if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
25 |         if(progressbar.innerText){
26 |             let newtitle = 'Stable Diffusion - ' + progressbar.innerText
27 |             if(document.title != newtitle){
28 |                 document.title = newtitle;
29 |             }
30 |         }else{
31 |             let newtitle = 'Stable Diffusion'
32 |             if(document.title != newtitle){
33 |                 document.title = newtitle;
34 |             }
35 |         }
36 |     }
37 |
38 |     if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
39 |         global_progressbars[id_progressbar] = progressbar
40 |
41 |         var mutationObserver = new MutationObserver(function(m){
42 |             if(timeoutIds[id_part]) return;
43 |
44 |             preview = gradioApp().getElementById(id_preview)
45 |             gallery = gradioApp().getElementById(id_gallery)
46 |
47 |             if(preview != null && gallery != null){
48 |                 preview.style.width = gallery.clientWidth + "px"
49 |                 preview.style.height = gallery.clientHeight + "px"
50 |                 if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
51 |
52 |                 //only watch gallery if there is a generation process going on
53 |                 check_gallery(id_gallery);
54 |
55 |                 var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
56 |                 if(progressDiv){
57 |                     timeoutIds[id_part] = window.setTimeout(function() {
58 |                         timeoutIds[id_part] = null
59 |                         requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
60 |                     }, 500)
61 |                 } else{
62 |                     if (skip) {
63 |                         skip.style.display = "none"
64 |                     }
65 |                     interrupt.style.display = "none"
66 |
67 |                     //disconnect observer once generation finished, so user can close selected image if they want
68 |                     if (galleryObservers[id_gallery]) {
69 |                         galleryObservers[id_gallery].disconnect();
70 |                         galleries[id_gallery] = null;
71 |                     }
72 |                 }
73 |             }
74 |
75 |         });
76 |         mutationObserver.observe( progressbar, { childList:true, subtree:true })
77 |     }
78 | }
79 |
80 | function check_gallery(id_gallery){
81 |     let gallery = gradioApp().getElementById(id_gallery)
82 |     // if the gallery has not changed, there is no need to set up the observer again.
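    // gradio appears to replace the gallery element on re-render, so comparing the stored node with the
    // current one below is what decides when a fresh MutationObserver is needed.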
83 |     if (gallery && galleries[id_gallery] !== gallery){
84 |         galleries[id_gallery] = gallery;
85 |         if(galleryObservers[id_gallery]){
86 |             galleryObservers[id_gallery].disconnect();
87 |         }
88 |         let prevSelectedIndex = selected_gallery_index();
89 |         galleryObservers[id_gallery] = new MutationObserver(function (){
90 |             let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
91 |             let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
92 |             if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
93 |                 // automatically re-open previously selected index (if exists)
94 |                 activeElement = gradioApp().activeElement;
95 |
96 |                 galleryButtons[prevSelectedIndex].click();
97 |                 showGalleryImage();
98 |
99 |                 if(activeElement){
100 |                     // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
101 |                     // if someone has a better solution, please by all means
102 |                     setTimeout(function() { activeElement.focus() }, 1);
103 |                 }
104 |             }
105 |         })
106 |         galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
107 |     }
108 | }
109 |
110 | onUiUpdate(function(){
111 |     check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
112 |     check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
113 |     check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
114 | })
115 |
116 | function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
117 |     btn = gradioApp().getElementById(id_part+"_check_progress");
118 |     if(btn==null) return;
119 |
120 |     btn.click();
121 |     var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
122 |     var skip = id_skip ? gradioApp().getElementById(id_skip) : null
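    // While the progress span exists, the skip/interrupt buttons are made visible below;
    // requestProgress() further down only triggers the initial "_check_progress_initial" poll.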
123 |     var interrupt = gradioApp().getElementById(id_interrupt)
124 |     if(progressDiv && interrupt){
125 |         if (skip) {
126 |             skip.style.display = "block"
127 |         }
128 |         interrupt.style.display = "block"
129 |     }
130 | }
131 |
132 | function requestProgress(id_part){
133 |     btn = gradioApp().getElementById(id_part+"_check_progress_initial");
134 |     if(btn==null) return;
135 |
136 |     btn.click();
137 | }
138 |
--------------------------------------------------------------------------------
/javascript/ui.js:
--------------------------------------------------------------------------------
1 | // various functions for interaction with ui.py not large enough to warrant putting them in separate files
2 |
3 | function set_theme(theme){
4 |     gradioURL = window.location.href
5 |     if (!gradioURL.includes('?__theme=')) {
6 |         window.location.replace(gradioURL + '?__theme=' + theme);
7 |     }
8 | }
9 |
10 | function selected_gallery_index(){
11 |     var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
12 |     var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
13 |
14 |     var result = -1
15 |     buttons.forEach(function(v, i){ if(v==button) { result = i } })
16 |
17 |     return result
18 | }
19 |
20 | function extract_image_from_gallery(gallery){
21 |     if(gallery.length == 1){
22 |         return gallery[0]
23 |     }
24 |
25 |     index = selected_gallery_index()
26 |
27 |     if (index < 0 || index >= gallery.length){
28 |         return [null]
29 |     }
30 |
31 |     return gallery[index];
32 | }
33 |
34 | function args_to_array(args){
35 |     res = []
36 |     for(var i=0;i