├── __init__.py
├── .env
├── ext_modules
│   ├── __init__.py
│   ├── text_analyzer.py
│   ├── vram_manager.py
│   └── image_generator.py
├── .gitignore
├── .flake8
├── .gitattributes
├── assets
│   ├── demo1.png
│   ├── demo2.png
│   ├── example_face.jpg
│   └── example_face_readme.txt
├── requirements.txt
├── .vscode
│   ├── settings.json
│   ├── extensions.json
│   └── launch.json
├── context.py
├── pyproject.toml
├── style.css
├── LICENSE
├── transformers_logits.py
├── README.md
├── script.py
├── sd_client.py
├── params.py
├── settings.debug.yaml
└── ui.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.env:
--------------------------------------------------------------------------------
1 | PYTHONPATH=../..
--------------------------------------------------------------------------------
/ext_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[cod]
2 | __pycache__/
3 | .mypy_cache/
4 | outputs/
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203, E704, W503
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/assets/demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trojaner/text-generation-webui-stable_diffusion/HEAD/assets/demo1.png
--------------------------------------------------------------------------------
/assets/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trojaner/text-generation-webui-stable_diffusion/HEAD/assets/demo2.png
--------------------------------------------------------------------------------
/assets/example_face.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Trojaner/text-generation-webui-stable_diffusion/HEAD/assets/example_face.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | outlines
2 | debugpy
3 | webuiapi
4 | pyright
5 | pylint
6 | Pillow
7 | types-Pillow
8 | types-requests
9 | stringcase
10 | partial-json-parser
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "[python]": {
3 | "editor.defaultFormatter": "ms-python.black-formatter",
4 | "editor.formatOnSave": false,
5 | "editor.codeActionsOnSave": {
6 | "source.organizeImports": "explicit"
7 | }
8 | }
9 | }
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "ms-python.black-formatter",
4 | "ms-python.flake8",
5 | "ms-python.isort",
6 | "magicstack.magicpython",
7 | "ms-python.vscode-pylance",
8 | "kevinrose.vsc-python-indent"
9 | ]
10 | }
--------------------------------------------------------------------------------
/assets/example_face_readme.txt:
--------------------------------------------------------------------------------
1 | The person depicted in the example_face.jpg file does not exist and has been computer-generated by https://thispersondoesnotexist.com/
2 | for demonstration purposes only. As computer-generated content, the image is not subject to any copyright and can be used freely for any purpose.
3 | Any similarity to actual persons, living or dead, is purely coincidental and unintentional.
4 |
5 | If the person depicted still looks like a real person, please report this by creating an issue in the GitHub repository here:
6 | https://github.com/Trojaner/text-generation-webui-stable_diffusion/issues/new
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.2.0",
3 | "configurations": [
4 | {
5 | "name": "Text Generation WebUI",
6 | "type": "debugpy",
7 | "request": "launch",
8 | "program": "server.py",
9 | "cwd": "${workspaceFolder}/../..",
10 | "args": [
11 | "--listen",
12 | "--listen-port=7862",
13 | "--verbose",
14 | "--settings=extensions/stable_diffusion/settings.debug.yaml",
15 | "--model",
16 | "mistral-7b-instruct-v0.1.Q5_K_M.gguf"
17 | ],
18 | "console": "integratedTerminal",
19 | "justMyCode": false
20 | }
21 | ]
22 | }
--------------------------------------------------------------------------------
/context.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from .params import StableDiffusionWebUiExtensionParams
3 | from .sd_client import SdWebUIApi
4 |
5 |
6 | @dataclass
7 | class GenerationContext(object):
8 | params: StableDiffusionWebUiExtensionParams
9 | sd_client: SdWebUIApi
10 | input_text: str | None = None
11 | output_text: str | None = None
12 | is_completed: bool = False
13 | state: dict | None = None
14 |
15 |
16 | _current_context: GenerationContext | None = None
17 |
18 |
19 | def get_current_context() -> GenerationContext | None:
20 | """
21 | Gets the current generation context. Must be called inside a generation request.
22 | """
23 |
24 | return _current_context
25 |
26 |
27 | def set_current_context(context: GenerationContext | None) -> None:
28 | """
29 | Sets the current generation context. Must be called inside a generation request.
30 | """
31 |
32 | global _current_context
33 | _current_context = context
34 |
--------------------------------------------------------------------------------
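
Usage sketch for context.py (illustrative, not a file from the repository): the getter/setter pair above implements a module-level "current context" slot that the other modules read during a generation request. Assuming the extension is installed as extensions/stable_diffusion and text-generation-webui is on the Python path (see /.env), the lifecycle looks roughly like this:

    from extensions.stable_diffusion.context import (
        GenerationContext,
        get_current_context,
        set_current_context,
    )
    from extensions.stable_diffusion.params import StableDiffusionWebUiExtensionParams
    from extensions.stable_diffusion.sd_client import SdWebUIApi

    params = StableDiffusionWebUiExtensionParams()
    client = SdWebUIApi(baseurl=params.api_endpoint)

    # Install a context for the duration of one generation request...
    context = GenerationContext(params=params, sd_client=client, state={})
    set_current_context(context)
    assert get_current_context() is context

    # ...and tear it down once the request has completed.
    context.is_completed = True
    set_current_context(None)
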
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "stable_diffusion"
3 | version = "1.6.1"
4 | authors = [{ name = "Enes Sadık Özbek", email = "es.ozbek@outlook.com" }]
5 | description = "Stable Diffusion integration for text-generation-webui"
6 | readme = { file = "README.md", content-type = "text/markdown" }
7 | requires-python = ">=3.10"
8 | license = { file = "LICENSE" }
9 |
10 | [project.urls]
11 | Repository = "https://github.com/Trojaner/text-generation-webui-stable_diffusion"
12 |
13 | [tool.pyright]
14 | typeCheckingMode = "standard"
15 | useLibraryCodeForTypes = true
16 | reportMissingImports = true
17 | reportMissingTypeStubs = false
18 | reportPrivateImportUsage = false
19 | reportOptionalMemberAccess = false
20 | reportFunctionMemberAccess = true
21 | reportPossiblyUnboundVariable = true
22 | reportMissingModuleSource = false
23 |
24 | pythonVersion = "3.10"
25 |
26 | [tool.isort]
27 | profile = "black"
28 | py_version = 310
29 | known_first_party = ["modules"]
30 | no_lines_before = [
31 | "FUTURE",
32 | "STANDARD_LIBRARY",
33 | "THIRDPARTY",
34 | "FIRSTPARTY",
35 | "LOCALFOLDER",
36 | ]
37 |
--------------------------------------------------------------------------------
/ext_modules/text_analyzer.py:
--------------------------------------------------------------------------------
1 | import html
2 | import re
3 | from ..params import StableDiffusionWebUiExtensionParams
4 |
5 |
6 | def try_get_description_prompt(
7 | message: str, params: StableDiffusionWebUiExtensionParams
8 | ) -> bool | str:
9 | """
10 | Checks if the given message contains any triggers and returns the prompt if it does.
11 | """
12 |
13 | trigger_regex = params.interactive_mode_input_trigger_regex
14 | subject_regex = params.interactive_mode_subject_regex
15 | default_subject = params.interactive_mode_default_subject
16 | default_description_prompt = params.interactive_mode_description_prompt
17 | normalized_message = html.unescape(message).strip()
18 |
19 | if not trigger_regex or not re.match(
20 | trigger_regex, normalized_message, re.IGNORECASE
21 | ):
22 | return False
23 |
24 | subject = default_subject
25 |
26 | if subject_regex:
27 | match = re.match(subject_regex, normalized_message, re.IGNORECASE)
28 | if match:
 29 |             subject = match.group(2) or default_subject  # group 2 of the default regex captures the subject
30 |
31 | return default_description_prompt.replace("[subject]", subject)
32 |
--------------------------------------------------------------------------------
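
A quick interactive-mode sketch (illustrative, using the default regexes from params.py; assumes text-generation-webui is on the Python path, see /.env): try_get_description_prompt returns False when the message does not match the input trigger regex, and otherwise returns the description prompt with "[subject]" substituted.

    from extensions.stable_diffusion.ext_modules.text_analyzer import (
        try_get_description_prompt,
    )
    from extensions.stable_diffusion.params import StableDiffusionWebUiExtensionParams

    params = StableDiffusionWebUiExtensionParams()

    # Matches the default input trigger regex, so a description prompt is returned.
    prompt = try_get_description_prompt("Please send me a picture of a cat", params)
    assert isinstance(prompt, str) and "cat" in prompt

    # No trigger words, so the message is not treated as an image request.
    assert try_get_description_prompt("How are you today?", params) is False
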
/style.css:
--------------------------------------------------------------------------------
1 | .SDAP #sampler_box {
2 | padding-top: var(--spacing-sm);
3 | padding-bottom: var(--spacing-sm);
4 | border: 0;
5 | }
6 |
7 | .SDAP #steps_box {
8 | border-radius: 0 0 var(--block-radius) var(--block-radius);
9 | }
10 |
11 | .SDAP #sampler_row {
12 | border-bottom: 0;
13 | box-shadow: var(--block-shadow);
14 | border-width: var(--block-border-width);
15 | border-color: var(--block-border-color);
16 | border-radius: var(--block-radius) var(--block-radius) 0 0;
17 | background: var(--block-background-fill);
18 | gap: 0;
19 | }
20 |
21 | .SDAP .refresh-button {
22 | margin-bottom: var(--spacing-sm);
23 | margin-right: var(--spacing-lg);
24 | }
25 |
26 | .SDAP #clip_skip_box,
27 | .SDAP #seed_box,
28 | .SDAP #cfg_box {
29 | padding-top: var(--spacing-md);
30 | }
31 |
32 | .SDAP #sampler_box span,
33 | .SDAP #seed_box span,
34 | .SDAP #cfg_box span,
35 | .SDAP #steps_box span {
36 | margin-bottom: var(--spacing-sm);
37 | }
38 |
39 | .SDAP svg.dropdown-arrow {
40 | flex-shrink: 0 !important;
41 | margin: 0px !important;
42 | }
43 |
44 | .SDAP .hires_opts input[type="number"] {
45 | width: 6em !important;
46 | }
47 |
48 | .SDAP #status-text {
49 | font-weight: bold;
50 | }
51 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) Enes Sadik Özbek
2 |
 3 | All code that is not explicitly licensed under a different license is licensed
4 | under the MIT license.
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining
7 | a copy of this software and associated documentation files (the
8 | "Software"), to deal in the Software without restriction, including
9 | without limitation the rights to use, copy, modify, merge, publish,
10 | distribute, sublicense, and/or sell copies of the Software, and to
11 | permit persons to whom the Software is furnished to do so, subject to
12 | the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be
15 | included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/ext_modules/vram_manager.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from modules.logging_colors import logger
 3 | from modules.models import load_model, unload_model
4 | from ..context import GenerationContext
5 | import modules.shared as shared
6 |
7 | loaded_model = "None"
8 |
9 | class VramReallocationTarget(Enum):
10 | """
11 | Defines the target for VRAM reallocation.
12 | """
13 |
14 | STABLE_DIFFUSION = 1
15 | LLM = 2
16 |
17 |
18 | def attempt_vram_reallocation(
19 | target: VramReallocationTarget, context: GenerationContext
20 | ) -> None:
21 | """
22 | Reallocates VRAM for the given target if dynamic VRAM reallocations are enabled.
23 | """
24 |
25 | if not context.params.dynamic_vram_reallocation_enabled:
26 | return
27 |
28 | _reallocate_vram_for_target(target, context)
29 |
30 |
31 | def _reallocate_vram_for_target(
32 | target: VramReallocationTarget, context: GenerationContext
33 | ) -> None:
34 | match target:
35 | case VramReallocationTarget.STABLE_DIFFUSION:
36 | _allocate_vram_for_stable_diffusion(context)
37 | case VramReallocationTarget.LLM:
38 | _allocate_vram_for_llm(context)
39 | case _:
40 | raise ValueError(f"Invalid VRAM reallocation target: {target}")
41 |
42 |
43 | def _allocate_vram_for_stable_diffusion(context: GenerationContext) -> None:
44 | global loaded_model
45 | logger.info("SD Extension: unloading the LLM model for SD")
46 | loaded_model = shared.model_name
47 | unload_model()
48 | context.sd_client.reload_checkpoint()
49 |
50 |
51 | def _allocate_vram_for_llm(context: GenerationContext) -> None:
52 | logger.info("SD Extension: unloading the SD model for LLM")
53 | context.sd_client.unload_checkpoint()
54 | shared.model, shared.tokenizer = load_model(loaded_model)
55 |
--------------------------------------------------------------------------------
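
A sketch of how the reallocation helpers above are meant to bracket an image generation call (illustrative; assumes a populated context and that dynamic_vram_reallocation_enabled is set, otherwise the calls are no-ops):

    from extensions.stable_diffusion.context import get_current_context
    from extensions.stable_diffusion.ext_modules.vram_manager import (
        VramReallocationTarget,
        attempt_vram_reallocation,
    )

    context = get_current_context()
    assert context is not None

    # Unload the LLM to free VRAM for Stable Diffusion...
    attempt_vram_reallocation(VramReallocationTarget.STABLE_DIFFUSION, context)
    try:
        result = context.sd_client.txt2img(prompt="a cat", width=512, height=512)
    finally:
        # ...then unload the SD checkpoint and reload the LLM.
        attempt_vram_reallocation(VramReallocationTarget.LLM, context)
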
/transformers_logits.py:
--------------------------------------------------------------------------------
1 | # Implementation taken from "outlines":
2 | # https://github.com/outlines-dev/outlines
3 | #
4 | # License: Apache License 2.0:
5 | # https://github.com/outlines-dev/outlines/blob/68b71ae810e0d6815a83df525da6d707cd4e971a/LICENSE
6 |
7 | from typing import Optional, Type, Union
8 | import torch
9 | from outlines.fsm.guide import Guide, RegexGuide
10 | from outlines.fsm.json_schema import build_regex_from_schema
11 | from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str
12 | from pydantic import BaseModel
13 | from transformers import LogitsProcessor, PreTrainedTokenizerBase
14 | from typing_extensions import override
15 |
16 |
17 | class FsmLogitsProcessor(LogitsProcessor):
18 | def __init__(self, tokenizer: PreTrainedTokenizerBase, fsm: Guide):
19 | self.fsm = fsm
20 | self._tokenizer = tokenizer
21 | self._fsm_state = 0
22 | self._is_first_token = True
23 |
24 | @override
25 | def __call__(
26 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
27 | ) -> torch.FloatTensor:
28 | is_first_token = self._is_first_token
29 | if self._is_first_token:
30 | self._is_first_token = False
31 |
32 | mask = torch.full_like(scores, -float("inf"))
33 |
34 | for i in range(len(input_ids)):
35 | if not is_first_token:
36 | last_token = int(input_ids[i][-1].item())
37 | self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token)
38 |
39 | allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens
40 | mask[i][allowed_tokens] = 0
41 |
42 | biased_scores = scores + mask
43 | return biased_scores # type: ignore
44 |
45 | def copy(self) -> "FsmLogitsProcessor":
46 | return FsmLogitsProcessor(tokenizer=self._tokenizer, fsm=self.fsm.copy())
47 |
48 |
49 | class RegexLogitsProcessor(FsmLogitsProcessor):
50 | def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
51 | assert isinstance(tokenizer, PreTrainedTokenizerBase)
52 |
53 | fsm = RegexGuide(regex_string, tokenizer)
54 | super().__init__(tokenizer=tokenizer, fsm=fsm)
55 |
56 |
57 | class JSONLogitsProcessor(RegexLogitsProcessor):
58 | def __init__(
59 | self,
60 | schema: Union[dict, Type[BaseModel], str],
61 | tokenizer: PreTrainedTokenizerBase,
62 | whitespace_pattern: Optional[str] = None,
63 | ):
64 | schema_str = convert_json_schema_to_str(json_schema=schema)
65 | regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
66 | tokenizer = adapt_tokenizer(tokenizer=tokenizer)
67 | super().__init__(regex_string=regex_string, tokenizer=tokenizer)
68 |
--------------------------------------------------------------------------------
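
A minimal sketch of constraining generation with the processors above (illustrative; the model name is a placeholder and any small causal LM works):

    from pydantic import BaseModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    from extensions.stable_diffusion.transformers_logits import JSONLogitsProcessor


    class ImagePrompt(BaseModel):
        prompt: str
        negative_prompt: str


    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Masks the logits at every step so the output must parse as an ImagePrompt.
    processor = JSONLogitsProcessor(schema=ImagePrompt, tokenizer=tokenizer)

    inputs = tokenizer("Describe the image as JSON: ", return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        logits_processor=LogitsProcessorList([processor]),
    )
    print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:]))
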
/README.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion Extension for text-generation-webui
 2 | Integrates image generation capabilities into [text-generation-webui](https://github.com/oobabooga/text-generation-webui) using Stable Diffusion.
 3 | Requires a stable-diffusion-webui instance with the API enabled.
4 |
5 |
6 | > [!WARNING]
 7 | > DO NOT DOWNLOAD OR CLONE THIS REPOSITORY AS-IS. PLEASE FOLLOW [THE INSTALLATION INSTRUCTIONS](#installation) INSTEAD. YOU WILL GET ERRORS OTHERWISE.
8 |
9 | **Demo:**
 10 |
 11 | ![Demo 1](./assets/demo1.png)
 12 |
 13 | ![Demo 2](./assets/demo2.png)
 14 |
15 | > [!WARNING]
 16 | > This extension has been developed mainly with [SD.Next](https://github.com/vladmandic/automatic) rather than [AUTOMATIC1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui). Please report any bugs you come across when using AUTOMATIC1111.
17 | > You **MUST** use the `--api` CLI flag on AUTOMATIC1111 or this extension will not work.
18 |
19 | > [!NOTE]
20 | > Some features, such as IP Adapter, FaceID and clip skip, are only available when using SD.Next.
21 |
22 | ## Features
 23 | - Dynamically generate images in text-generation-webui chat by utilizing the SD.Next or AUTOMATIC1111 API.
 24 | - Well-documented [settings](./settings.debug.yaml) file for quick and easy configuration.
 25 | - Configure image generation parameters such as width, height, sampler, sampling steps, cfg scale, clip skip, seed, etc.
 26 | - Post-process generated images with upscaling, face restoration and HiRes.fix.
 27 | - Run stable-diffusion-webui and text-generation-webui on the same GPU, even with low VRAM, using the dynamic VRAM reallocation feature.
 28 | - Use various image generation modes such as continuous (generate an image for each message) and interactive (generate an image when asked for one in chat).
 29 | - Define generation rules for when and how to generate images. These can be used for character-specific prompts and parameters, dynamically adding LoRAs on trigger words, etc.
 30 | - Face swap for generating consistent character images using FaceSwapLab, ReActor or FaceID (see [Ethical Guidelines](#ethical-guidelines)).
 31 | - Generate images roughly based on a reference image using IP Adapter.
32 |
33 | ## Supported Stable Diffusion WebUI Extensions and Features
34 | - [FaceSwapLab](https://github.com/glucauze/sd-webui-faceswaplab)
35 | - [ReActor](https://github.com/Gourieff/sd-webui-reactor)
36 | - [FaceID](https://github.com/vladmandic/automatic)
37 | - Can be used to force a specific face to be used while generating images.
38 | - Unlike FaceSwapLab and ReActor, FaceID supports various art styles such as cartoon, anime, etc.
39 | - Requires SD.Next.
 40 | - The insightface, ip_adapter and onnxruntime-gpu pip packages must be installed for SD.Next first.
41 | - [IP Adapter](https://github.com/vladmandic/automatic)
42 | - Can be used to generate images based roughly on a base reference image.
43 | - Requires SD.Next.
 44 | - The ip_adapter and onnxruntime-gpu pip packages must be installed for SD.Next first.
45 |
46 | ## Installation
47 | - Open a shell with cmd_linux.sh/cmd_macos.sh/cmd_windows.bat inside your text-generation-webui folder.
48 | - Run `git clone https://github.com/Trojaner/text-generation-webui-stable_diffusion extensions/stable_diffusion`.
49 | - Run `pip install -r extensions/stable_diffusion/requirements.txt` to install the required dependencies.
 50 | - Open the [settings.debug.yaml](./settings.debug.yaml) file and copy the extension-related settings to your own settings.json in the text-generation-webui directory.
51 | - Add `stable_diffusion` to the enabled extensions in settings.json.
52 |
53 | ## Development Environment Setup
54 |
55 | **Pre-requisites**
56 | text-generation-webui, Visual Studio Code and Python 3.10 are required for development.
57 |
58 | **Setting up Visual Studio Code for development**
59 | - [Install the extension first](#installation) if you haven't already.
 60 | - Start Visual Studio Code and open the stable_diffusion directory, then trust the repository if prompted.
61 | - Install the [recommended extensions](./.vscode/extensions.json) as they are required for code completion, linting and auto formatting.
62 | - Adjust `.vscode/launch.json` to use your preferred model for debugging or install the default model [mistral-7b-instruct-v0.1.Q5_K_M.gguf](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/blob/main/mistral-7b-v0.1.Q5_K_M.gguf) instead.
63 | - Adjust `settings.debug.yaml` if needed. These settings will be used during debugging.
 64 | - To test your changes, hit F5 (*Start Debugging*) to launch text-generation-webui with this extension pre-installed and with the `settings.debug.yaml` file as the settings file. You can also use Ctrl + Shift + F5 (*Restart Debugging*) to apply code changes by restarting the server from scratch. Check out [Key Bindings for Visual Studio Code](https://code.visualstudio.com/docs/getstarted/keybindings) for more shortcuts.
65 | - Be sure to check out the [Contribution Guidelines](#contribution-guidelines) below before submitting a pull request.
66 |
67 | ## Contribution Guidelines
 68 | - This project relies heavily on type hints; please add them to your code as well, or your pull request will likely be rejected.
 69 | - Always reformat your code using [Black](https://github.com/psf/black) and [isort](https://github.com/PyCQA/isort) before committing (this should happen automatically on save if you have installed the recommended extensions).
 70 | - Make sure that neither [pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) nor [flake8](https://github.com/PyCQA/flake8) reports any linting errors.
 71 | - Prefix local functions and variables with an underscore (`_`) to indicate that they are not meant to be used outside the current file.
 72 | - Use snake_case for function and variable names, PascalCase for class names, and UPPER_CASE for constants.
 73 | - Do not use abbreviations for variable names (such as `ctx` instead of `context`) unless they are simple and common, like `i` for an index or `n` for a number.
 74 | - Always document and include new parameters in the `settings.debug.yaml` file.
 75 | - Last but not least, ensure that you do not accidentally commit unintentional changes to the `settings.debug.yaml` or `launch.json` files.
76 |
77 | ## Ethical Guidelines
 78 | This extension integrates with various face swap extensions for stable-diffusion-webui and hence allows swapping faces in generated images. It is not intended for the creation of non-consensual deepfake content. Please use this extension responsibly and do not use it to create such content. The main purpose of the face swapping functionality is to allow the creation of consistent images for text-generation-webui characters. If you are unsure whether your use case is ethical, please refrain from using this extension.
79 |
 80 | The maintainers and contributors of this extension cannot be held liable for any misuse but will try to prevent such misuse by all means.
81 |
82 | ## Todo
83 | - ~~Some basic Gradio UI for fine-tuning the extension parameters at runtime~~
84 | - ~~Support [ReActor](https://github.com/Gourieff/sd-webui-reactor) as alternative faceswap integration [[api implementation](https://github.com/Gourieff/sd-webui-reactor/blob/main/scripts/reactor_api.py)]~~
85 | - ~~Character specific parameters~~
 86 | - Integrate with the aDetailer extension for stable-diffusion-webui.
87 | - Integrate with other SD extensions / scripts?
88 | - Add tag system for grouping rules and making them mutually exclusive or dependent.
89 | - Add a sentiment analysis for defining image generation rules based on the sentiment of the generated text.
90 | - Add a custom LogitsProcessor or grammar implementation for generating proper and weighted SD image generation prompts.
91 |
92 |
93 | ## See also
94 | - [sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures) - the original stable-diffusion-webui extension which inspired this one
95 | - [sd_api_pictures_tag_injection](https://github.com/GuizzyQC/sd_api_pictures_tag_injection) - a fork of sd_api_pictures with tag injection support
96 |
97 | ## License
98 | [MIT](./LICENSE)
99 |
--------------------------------------------------------------------------------
/script.py:
--------------------------------------------------------------------------------
1 | import html
2 | import re
3 | from dataclasses import asdict
4 | from os import path
5 | from typing import Any, List
6 | from transformers import LogitsProcessor, PreTrainedTokenizerBase
7 | from modules import chat, shared
8 | from modules.logging_colors import logger
9 | from .context import GenerationContext, get_current_context, set_current_context
10 | from .ext_modules.image_generator import generate_html_images_for_context
11 | from .ext_modules.text_analyzer import try_get_description_prompt
12 | from .params import (
13 | InteractiveModePromptGenerationMode,
14 | StableDiffusionWebUiExtensionParams,
15 | TriggerMode,
16 | )
17 | from .sd_client import SdWebUIApi
18 | from .transformers_logits import JSONLogitsProcessor
19 | from .ui import render_ui
20 |
21 | ui_params: Any = StableDiffusionWebUiExtensionParams()
22 | params = asdict(ui_params)
23 |
24 | context: GenerationContext | None = None
25 |
26 | picture_processing_message = "*Is sending a picture...*"
27 | default_processing_message = shared.processing_message
28 | cached_schema: str | None = None
29 | cached_schema_logits: LogitsProcessor | None = None
30 |
31 | EXTENSION_DIRECTORY_NAME = path.basename(path.dirname(path.realpath(__file__)))
32 |
33 |
34 | def get_or_create_context(state: dict | None = None) -> GenerationContext:
35 | global context, params, ui_params
36 |
37 | for key in ui_params.__dict__:
38 | params[key] = ui_params.__dict__[key]
39 |
40 | sd_client = SdWebUIApi(
41 | baseurl=params["api_endpoint"],
42 | username=params["api_username"],
43 | password=params["api_password"],
44 | )
45 |
46 | if context is not None and not context.is_completed:
47 | context.state = (context.state or {}) | (state or {})
48 | context.sd_client = sd_client
49 | return context
50 |
51 | ext_params = StableDiffusionWebUiExtensionParams(**params)
52 | ext_params.normalize()
53 |
54 | context = (
55 | GenerationContext(
56 | params=ext_params,
57 | sd_client=sd_client,
58 | input_text=None,
59 | state=state or {},
60 | )
61 | if context is None or context.is_completed
62 | else context
63 | )
64 |
65 | set_current_context(context)
66 | return context
67 |
68 |
69 | def custom_generate_chat_prompt(text: str, state: dict, **kwargs: dict) -> str:
70 | """
71 | Modifies the user input string in chat mode (visible_text).
72 | You can also modify the internal representation of the user
73 | input (text) to change how it will appear in the prompt.
74 | """
75 |
76 | # bug: this does not trigger on regeneration and hence
77 | # no context is created in that case
78 |
79 | prompt: str = chat.generate_chat_prompt(text, state, **kwargs) # type: ignore
80 | input_text = text
81 |
82 | context = get_or_create_context(state)
83 | context.input_text = input_text
84 | context.state = state
85 |
86 | if (
87 | context is not None and not context.is_completed
88 | ) or context.params.trigger_mode == TriggerMode.MANUAL:
89 | # A manual trigger was used
90 | return prompt
91 |
92 | if context.params.trigger_mode == TriggerMode.INTERACTIVE:
93 | description_prompt = try_get_description_prompt(text, context.params)
94 |
95 | if description_prompt is False:
96 | # did not match image trigger
97 | return prompt
98 |
99 | assert isinstance(description_prompt, str)
100 |
101 | prompt = (
102 | description_prompt
103 | if context.params.interactive_mode_prompt_generation_mode
104 | == InteractiveModePromptGenerationMode.DYNAMIC
105 | else text
106 | )
107 |
108 | return prompt
109 |
110 |
111 | def state_modifier(state: dict) -> dict:
112 | """
113 | Modifies the state variable, which is a dictionary containing the input
114 | values in the UI like sliders and checkboxes.
115 | """
116 |
117 | context = get_or_create_context(state)
118 |
119 | if context is None or context.is_completed:
120 | return state
121 |
122 | if (
123 | context.params.trigger_mode == TriggerMode.TOOL
124 | or context.params.dont_stream_when_generating_images
125 | ):
126 | state["stream"] = False
127 |
128 | shared.processing_message = (
129 | picture_processing_message
130 | if context.params.dont_stream_when_generating_images
131 | else default_processing_message
132 | )
133 |
134 | return state
135 |
136 |
137 | def history_modifier(history: List[str]) -> List[str]:
138 | """
139 | Modifies the chat history.
140 | Only used in chat mode.
141 | """
142 |
143 | context = get_current_context()
144 |
145 | if context is None or context.is_completed:
146 | return history
147 |
 148 |     # todo: strip <img> tags from history
149 | return history
150 |
151 |
152 | def cleanup_context() -> None:
153 | context = get_current_context()
154 |
155 | if context is not None:
156 | context.is_completed = True
157 |
158 | set_current_context(None)
159 | shared.processing_message = default_processing_message
160 | pass
161 |
162 |
163 | def output_modifier(string: str, state: dict, is_chat: bool = False) -> str:
164 | """
165 | Modifies the LLM output before it gets presented.
166 |
167 | In chat mode, the modified version goes into history['visible'],
168 | and the original version goes into history['internal'].
169 | """
170 |
171 | global params
172 |
173 | if not is_chat:
174 | cleanup_context()
175 | return string
176 |
177 | context = get_current_context()
178 |
179 | if context is None or context.is_completed:
180 | ext_params = StableDiffusionWebUiExtensionParams(**params)
181 | ext_params.normalize()
182 |
183 | if ext_params.trigger_mode == TriggerMode.INTERACTIVE:
184 | output_regex = ext_params.interactive_mode_output_trigger_regex
185 |
186 | normalized_message = html.unescape(string).strip()
187 |
188 | if output_regex and re.match(
189 | output_regex, normalized_message, re.IGNORECASE
190 | ):
191 | sd_client = SdWebUIApi(
192 | baseurl=ext_params.api_endpoint,
193 | username=ext_params.api_username,
194 | password=ext_params.api_password,
195 | )
196 |
197 | context = GenerationContext(
198 | params=ext_params,
199 | sd_client=sd_client,
200 | input_text=state.get("input", ""),
201 | state=state,
202 | )
203 |
204 | set_current_context(context)
205 |
206 | if context is None or context.is_completed:
207 | cleanup_context()
208 | return string
209 |
210 | context.state = state
211 | context.output_text = string
212 |
 [... script.py lines 213-282 are missing from this dump ...]
 283 | def ui() ->
284 | """
285 | Gets executed when the UI is drawn. Custom gradio elements and
286 | their corresponding event handlers should be defined here.
287 |
288 | To learn about gradio components, check out the docs:
289 | https://gradio.app/docs/
290 | """
291 |
292 | global ui_params
293 |
294 | ui_params = StableDiffusionWebUiExtensionParams(**params)
295 | render_ui(ui_params)
296 |
--------------------------------------------------------------------------------
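
The interactive-mode output trigger referenced in output_modifier above decides whether a finished LLM reply should spawn an image generation. A tiny sketch of that check in isolation (illustrative; the regex is the default from params.py):

    import html
    import re

    output_regex = (
        r".*[*([](sends|uploads|adds|shows|attaches|generates|here (is|are))"
        r"\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)"
    )

    reply = "*sends a picture of the sunset*"
    normalized = html.unescape(reply).strip()

    # re.match anchors at the start; re.IGNORECASE mirrors output_modifier.
    assert re.match(output_regex, normalized, re.IGNORECASE)
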
/sd_client.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | from asyncio import Task
4 | from dataclasses import dataclass
5 | from io import BytesIO
6 | from typing import Any, List
7 | from PIL import Image
8 | from webuiapi import HiResUpscaler, WebUIApi, WebUIApiResult
9 | from .params import FaceSwapLabParams, ReactorParams
10 |
11 |
12 | @dataclass
13 | class FaceSwapLabFaceSwapResponse:
14 | images: List[Image.Image]
15 | infos: List[str]
16 |
17 | @property
18 | def image(self) -> Image.Image:
19 | return self.images[0]
20 |
21 |
22 | @dataclass
23 | class ReactorFaceSwapResponse:
24 | image: Image.Image
25 |
26 |
27 | class SdWebUIApi(WebUIApi):
28 | """
29 | This class extends the WebUIApi with some additional api endpoints.
30 | """
31 |
32 | def __init__(self, *args: Any, **kwargs: Any) -> None:
33 | super().__init__(*args, **kwargs)
34 |
35 | def unload_checkpoint(self, use_async: bool = False) -> Task[None] | None:
36 | """
37 | Unload the current checkpoint from VRAM.
38 | """
39 | return self.post_and_get_api_result( # type: ignore
40 | f"{self.baseurl}/unload-checkpoint", "", use_async
41 | )
42 |
43 | def reload_checkpoint(self, use_async: bool = False) -> Task[None] | None:
44 | """
45 | Reload the current checkpoint into VRAM.
46 | """
47 |
48 | return self.post_and_get_api_result( # type: ignore
49 | f"{self.baseurl}/reload-checkpoint", "", use_async
50 | )
51 |
52 | def txt2img( # type: ignore
53 | self,
54 | enable_hr: bool = False,
55 | denoising_strength: float = 0.7,
56 | firstphase_width: int = 0,
57 | firstphase_height: int = 0,
58 | hr_scale: float = 2,
59 | hr_upscaler: str = HiResUpscaler.Latent,
60 | hr_second_pass_steps: int = 0,
61 | hr_resize_x: float = 0,
62 | hr_resize_y: float = 0,
63 | hr_sampler: str = "UniPC",
64 | hr_force: bool = False,
65 | prompt: str = "",
66 | styles: List[str] = [],
67 | seed: int = -1,
68 | subseed: int = -1,
69 | subseed_strength: float = 0.0,
70 | seed_resize_from_h: float = 0,
71 | seed_resize_from_w: float = 0,
72 | sampler_name: str | None = None,
73 | batch_size: int = 1,
74 | n_iter: int = 1,
75 | steps: int | None = None,
76 | cfg_scale: float = 6.0,
77 | width: int = 512,
78 | height: int = 512,
79 | restore_faces: bool = False,
80 | tiling: bool = False,
81 | do_not_save_samples: bool = False,
82 | do_not_save_grid: bool = False,
83 | negative_prompt: str = "",
84 | eta: float = 1.0,
85 | s_churn: int = 0,
86 | s_tmax: int = 0,
87 | s_tmin: int = 0,
88 | s_noise: int = 1,
89 | script_args: dict | None = None,
90 | script_name: str | None = None,
91 | send_images: bool = True,
92 | save_images: bool = False,
93 | full_quality: bool = True,
94 | faceid_enabled: bool = False,
95 | faceid_mode: list[str] = ["FaceID", "FaceSwap"],
96 | faceid_model: str = "FaceID Plus v2",
97 | faceid_image: str | None = None,
98 | faceid_scale: float = 1,
99 | faceid_structure: float = 1,
100 | faceid_rank: int = 128,
101 | faceid_tokens: int = 4,
102 | faceid_override_sampler: bool = True,
103 | faceid_cache_model: bool = False,
104 | ipadapter_enabled: bool = False,
105 | ipadapter_adapter: str = "Base",
106 | ipadapter_scale: float = 0.7,
107 | ipadapter_image: str | None = None,
108 | alwayson_scripts: dict = {},
109 | use_async: bool = False,
110 | ) -> Task[WebUIApiResult] | WebUIApiResult:
111 | if sampler_name is None:
112 | sampler_name = self.default_sampler
113 | if steps is None:
114 | steps = self.default_steps
115 | if script_args is None:
116 | script_args = {}
117 | payload = {
118 | "enable_hr": enable_hr or hr_force,
119 | "hr_scale": hr_scale,
120 | "hr_upscaler": hr_upscaler,
121 | "hr_second_pass_steps": hr_second_pass_steps,
122 | "hr_resize_x": hr_resize_x,
123 | "hr_resize_y": hr_resize_y,
124 | "hr_force": hr_force, # SD.Next (no equivalent in AUTOMATIC1111)
125 | "hr_sampler_name": hr_sampler, # AUTOMATIC1111
126 | "latent_sampler": hr_sampler, # SD.Next
127 | "denoising_strength": denoising_strength,
128 | "firstphase_width": firstphase_width,
129 | "firstphase_height": firstphase_height,
130 | "prompt": prompt,
131 | "styles": styles,
132 | "seed": seed,
133 | "full_quality": full_quality, # SD.Next
134 | "subseed": subseed,
135 | "subseed_strength": subseed_strength,
136 | "seed_resize_from_h": seed_resize_from_h,
137 | "seed_resize_from_w": seed_resize_from_w,
138 | "batch_size": batch_size,
139 | "n_iter": n_iter,
140 | "steps": steps,
141 | "cfg_scale": cfg_scale,
142 | "width": width,
143 | "height": height,
144 | "restore_faces": restore_faces,
145 | "tiling": tiling,
146 | "do_not_save_samples": do_not_save_samples,
147 | "do_not_save_grid": do_not_save_grid,
148 | "negative_prompt": negative_prompt,
149 | "eta": eta,
150 | "s_churn": s_churn,
151 | "s_tmax": s_tmax,
152 | "s_tmin": s_tmin,
153 | "s_noise": s_noise,
154 | "sampler_name": sampler_name,
155 | "send_images": send_images,
156 | "save_images": save_images,
157 | }
158 |
159 | if faceid_enabled:
160 | payload["face_id"] = {
161 | "mode": faceid_mode,
162 | "model": faceid_model,
163 | "image": faceid_image,
164 | "scale": faceid_scale,
165 | "structure": faceid_structure,
166 | "rank": faceid_rank,
167 | "override_sampler": faceid_override_sampler,
168 | "tokens": faceid_tokens,
169 | "cache_model": faceid_cache_model,
170 | }
171 | print(json.dumps(payload["face_id"], indent=2))
172 |
173 | if alwayson_scripts:
174 | payload["alwayson_scripts"] = alwayson_scripts
175 |
176 | if script_name:
177 | payload["script_name"] = script_name
178 | payload["script_args"] = script_args
179 |
180 | if ipadapter_enabled:
181 | payload["ip_adapter"] = {
182 | "adapter": ipadapter_adapter,
183 | "scale": ipadapter_scale,
184 | "image": ipadapter_image,
185 | }
186 |
187 | return self.post_and_get_api_result(
188 | f"{self.baseurl}/txt2img", payload, use_async
189 | )
190 |
191 | def reactor_swap_face(
192 | self,
193 | target_image: Image.Image,
194 | params: ReactorParams,
195 | use_async: bool = False,
196 | ) -> Task[ReactorFaceSwapResponse] | ReactorFaceSwapResponse:
197 | """
198 | Swaps a face in an image using the ReActor extension.
199 | """
200 | buffer = BytesIO()
201 | target_image.save(buffer, format="PNG")
202 | target_image_base64 = base64.b64encode(buffer.getvalue()).decode()
203 |
204 | source_image_base64 = None
205 | source_model = None
206 |
207 | reference_face_image_path = params.reactor_source_face
208 | reference_face_source = 0
209 |
210 | if reference_face_image_path.startswith("checkpoint://"):
211 | source_model = reference_face_image_path.replace("checkpoint://", "")
212 | reference_face_source = 1
213 | elif reference_face_image_path.startswith("data:image"):
214 | source_image_base64 = reference_face_image_path.split(",")[1]
215 | elif reference_face_image_path.startswith("file:///"):
216 | # todo: ensure path is inside text-generation-webui folder
217 | path = reference_face_image_path.replace("file:///", "")
218 |
219 | with open(path, "rb") as image_file:
220 | source_image_base64 = base64.b64encode(image_file.read()).decode()
221 | else:
222 | raise Exception(f"Failed to parse source face: {reference_face_image_path}")
223 |
224 | payload = {
225 | "source_image": source_image_base64 if reference_face_source == 0 else "",
226 | "target_image": target_image_base64,
227 | "source_faces_index": [params.reactor_source_face_index],
228 | "face_index": [params.reactor_target_face_index],
229 | "upscaler": params.reactor_upscaling_upscaler
230 | if params.reactor_upscaling_enabled
231 | else "None",
232 | "scale": params.reactor_upscaling_scale,
233 | "upscale_visibility": params.reactor_upscaling_visibility,
234 | "face_restorer": params.reactor_restore_face_model
235 | if params.reactor_restore_face_enabled
236 | else "None",
237 | "restorer_visibility": params.reactor_restore_face_visibility,
238 | "codeformer_weight": params.reactor_restore_face_codeformer_weight,
239 | "restore_first": 0 if params.reactor_restore_face_upscale_first else 1,
240 | "model": params.reactor_model,
241 | "gender_source": params.reactor_source_gender,
242 | "gender_target": params.reactor_target_gender,
243 | "save_to_file": 0,
244 | "result_file_path": "",
245 | "device": params.reactor_device,
246 | "mask_face": 1 if params.reactor_mask_face else 0,
247 | "select_source": reference_face_source,
248 | "face_model": source_model if reference_face_source == 1 else "None",
249 | "source_folder": "",
250 | }
251 |
252 | return self.post_and_get_api_result( # type: ignore
253 | f"{self.baseurl.replace('/sdapi/v1', '/reactor')}/image",
254 | payload,
255 | use_async,
256 | )
257 |
258 | def faceswaplab_swap_face(
259 | self,
260 | target_image: Image.Image,
261 | params: FaceSwapLabParams,
262 | use_async: bool = False,
263 | ) -> Task[FaceSwapLabFaceSwapResponse] | FaceSwapLabFaceSwapResponse:
264 | """
265 | Swaps a face in an image using the FaceSwapLab extension.
266 | """
267 |
268 | buffer = BytesIO()
269 | target_image.save(buffer, format="PNG")
270 | target_image_base64 = base64.b64encode(buffer.getvalue()).decode()
271 |
272 | source_image_base64 = None
273 | source_face_checkpoint = None
274 |
275 | reference_face_image_path = params.faceswaplab_source_face
276 |
277 | if reference_face_image_path.startswith("checkpoint://"):
278 | source_face_checkpoint = reference_face_image_path.replace(
279 | "checkpoint://", ""
280 | )
281 | elif reference_face_image_path.startswith("data:image"):
282 | source_image_base64 = reference_face_image_path.split(",")[1]
283 | elif reference_face_image_path.startswith("file:///"):
284 | # todo: ensure path is inside text-generation-webui folder
285 | path = reference_face_image_path.replace("file:///", "")
286 |
287 | with open(path, "rb") as image_file:
288 | source_image_base64 = base64.b64encode(image_file.read()).decode()
289 | else:
290 | raise Exception(f"Failed to parse source face: {reference_face_image_path}")
291 |
292 | payload = {
293 | "image": target_image_base64,
294 | "units": [
295 | {
296 | "source_img": source_image_base64,
297 | "source_face": source_face_checkpoint,
298 | "blend_faces": params.faceswaplab_blend_faces,
299 | "same_gender": params.faceswaplab_same_gender_only,
300 | "sort_by_size": params.faceswaplab_sort_by_size,
301 | "check_similarity": False,
302 | "compute_similarity": False,
303 | "min_sim": 0,
304 | "min_ref_sim": 0,
305 | "faces_index": [params.faceswaplab_target_face_index],
306 | "reference_face_index": params.faceswaplab_source_face_index,
307 | "pre_inpainting": {
308 | "inpainting_denoising_strengh": 0,
309 | "inpainting_prompt": "Portrait of a [gender]",
310 | "inpainting_negative_prompt": "blurry",
311 | "inpainting_steps": 20,
312 | "inpainting_sampler": "Default",
313 | "inpainting_model": "Current",
314 | "inpainting_seed": 0,
315 | },
316 | "swapping_options": {
317 | "face_restorer_name": params.faceswaplab_restore_face_model
318 | if params.faceswaplab_restore_face_enabled
319 | else "None",
320 | "restorer_visibility": params.faceswaplab_restore_face_visibility, # noqa: E501
321 | "codeformer_weight": params.faceswaplab_restore_face_codeformer_weight, # noqa: E501
322 | "upscaler_name": params.faceswaplab_upscaling_upscaler
323 | if params.faceswaplab_upscaling_enabled
324 | else "None",
325 | "improved_mask": params.faceswaplab_mask_improved_mask_enabled,
326 | "erosion_factor": params.faceswaplab_mask_erosion_factor,
327 | "color_corrections": params.faceswaplab_color_corrections_enabled, # noqa: E501
328 | "sharpen": params.faceswaplab_sharpen_face,
329 | },
330 | "post_inpainting": {
331 | "inpainting_denoising_strengh": 0,
332 | "inpainting_prompt": "Portrait of a [gender]",
333 | "inpainting_negative_prompt": "blurry",
334 | "inpainting_steps": 20,
335 | "inpainting_sampler": "Default",
336 | "inpainting_model": "Current",
337 | "inpainting_seed": 0,
338 | },
339 | }
340 | ],
341 | "postprocessing": {
342 | "face_restorer_name": params.faceswaplab_postprocessing_restore_face_model # noqa: E501
343 | if params.faceswaplab_postprocessing_restore_face_enabled
344 | else "None",
345 | "restorer_visibility": params.faceswaplab_postprocessing_restore_face_visibility, # noqa: E501
346 | "codeformer_weight": params.faceswaplab_postprocessing_restore_face_codeformer_weight, # noqa: E501
347 | "upscaler_name": params.faceswaplab_postprocessing_upscaling_upscaler
348 | if params.faceswaplab_postprocessing_upscaling_enabled
349 | else "None",
350 | "scale": params.faceswaplab_postprocessing_upscaling_scale,
351 | "upscaler_visibility": params.faceswaplab_postprocessing_upscaling_visibility, # noqa: E501
352 | "inpainting_when": "After Upscaling/Before Restore Face",
353 | "inpainting_options": {
354 | "inpainting_denoising_strengh": 0,
355 | "inpainting_prompt": "Portrait of a [gender]",
356 | "inpainting_negative_prompt": "blurry",
357 | "inpainting_steps": 20,
358 | "inpainting_sampler": "Default",
359 | "inpainting_model": "Current",
360 | "inpainting_seed": 0,
361 | },
362 | },
363 | }
364 |
365 | return self.post_and_get_api_result( # type: ignore
366 | f"{self.baseurl.replace('/sdapi/v1', '/faceswaplab')}/swap_face",
367 | payload,
368 | use_async,
369 | )
370 |
371 | def refresh_vae(self) -> Any:
372 | response = self.session.post(url=f"{self.baseurl}/refresh-vae")
373 | return response.json()
374 |
--------------------------------------------------------------------------------
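
A basic txt2img sketch against a local SD.Next or AUTOMATIC1111 instance with the API enabled (illustrative; endpoint, prompts and parameters are placeholders):

    from extensions.stable_diffusion.sd_client import SdWebUIApi

    client = SdWebUIApi(baseurl="http://127.0.0.1:7860/sdapi/v1")

    result = client.txt2img(
        prompt="RAW photo, a cat, 8k uhd, dslr, soft lighting",
        negative_prompt="blurry, low quality, jpeg artifacts",
        sampler_name="UniPC",
        steps=25,
        cfg_scale=6.0,
        width=512,
        height=512,
        seed=-1,
    )

    # WebUIApiResult exposes the decoded PIL images; .image is the first one.
    result.image.save("cat.png")
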
/params.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from dataclasses import dataclass, field, fields
3 | from enum import Enum
4 | import requests
5 | from typing_extensions import Self
6 | from modules.logging_colors import logger
7 |
8 | default_description_prompt = """
9 | You are now a text generator for the Stable Diffusion AI image generator. You will generate a text prompt for it.
10 |
11 | Describe [subject] using comma-separated tags only. Do not use sentences.
12 | Include many tags such as tags for the environment, gender, clothes, age, location, light, daytime, angle, pose, etc.
13 |
14 | Do not write anything else. Do not ask any questions. Do not talk.
15 | """ # noqa E501
16 |
17 |
18 | class TriggerMode(str, Enum):
19 | TOOL = "tool"
20 | CONTINUOUS = "continuous"
21 | INTERACTIVE = "interactive"
22 | MANUAL = "manual"
23 |
24 | @classmethod
25 | def index_of(cls, mode: Self) -> int:
26 | return list(TriggerMode).index(mode)
27 |
28 | @classmethod
29 | def from_index(cls, index: int) -> Self:
30 | return list(TriggerMode)[index] # type: ignore
31 |
32 | def __str__(self) -> str:
33 | return self
34 |
35 |
36 | class IPAdapterAdapter(str, Enum):
37 | BASE = "Base"
38 | LIGHT = "Light"
39 | PLUS = "Plus"
40 | PLUS_FACE = "Plus Face"
41 | FULL_FACE = "Full face"
42 | BASE_SDXL = "Base SDXL"
43 |
44 | @classmethod
45 | def index_of(cls, mode: Self) -> int:
46 | return list(IPAdapterAdapter).index(mode)
47 |
48 | @classmethod
49 | def from_index(cls, index: int) -> Self:
50 | return list(IPAdapterAdapter)[index] # type: ignore
51 |
52 | def __str__(self) -> str:
53 | return self
54 |
55 |
56 | class ContinuousModePromptGenerationMode(str, Enum):
57 | STATIC = "static"
58 | GENERATED_TEXT = "generated_text"
59 |
60 | @classmethod
61 | def index_of(cls, mode: Self) -> int:
62 | return list(ContinuousModePromptGenerationMode).index(mode)
63 |
64 | @classmethod
65 | def from_index(cls, index: int) -> Self:
66 | return list(ContinuousModePromptGenerationMode)[index] # type: ignore
67 |
68 | def __str__(self) -> str:
69 | return self
70 |
71 |
72 | class InteractiveModePromptGenerationMode(str, Enum):
73 | STATIC = "static"
74 | GENERATED_TEXT = "generated_text"
75 | DYNAMIC = "dynamic"
76 |
77 | @classmethod
78 | def index_of(cls, mode: Self) -> int:
79 | return list(InteractiveModePromptGenerationMode).index(mode)
80 |
81 | @classmethod
82 | def from_index(cls, index: int) -> Self:
83 | return list(InteractiveModePromptGenerationMode)[index] # type: ignore
84 |
85 | def __str__(self) -> str:
86 | return self
87 |
88 |
89 | class ReactorFace(int, Enum):
90 | NONE = 0
91 | FEMALE = 1
92 | MALE = 2
93 |
94 | @classmethod
95 | def index_of(cls, mode: Self) -> int:
96 | return list(ReactorFace).index(mode)
97 |
98 | @classmethod
99 | def from_index(cls, index: int) -> Self:
100 | return list(ReactorFace)[index] # type: ignore
101 |
102 |
103 | @dataclass
104 | class StableDiffusionClientParams:
105 | api_endpoint: str = field(default="http://127.0.0.1:7860/sdapi/v1")
106 | api_username: str | None = field(default=None)
107 | api_password: str | None = field(default=None)
108 |
109 |
110 | @dataclass
111 | class StableDiffusionGenerationParams:
112 | base_prompt: str = field(
113 | default=(
114 | "RAW photo, subject, 8k uhd, dslr, soft lighting, high quality, "
115 | "film grain, Fujifilm XT3"
116 | )
117 | )
118 | base_negative_prompt: str = field(
119 | default=(
120 | "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, "
121 | "sketch, cartoon, drawing, anime), text, cropped, out of frame, "
122 | "worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, "
123 | "mutilated, extra fingers, mutated hands, poorly drawn hands, "
124 | "poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, "
125 | "bad proportions, extra limbs, cloned face, disfigured, gross proportions, "
126 | "malformed limbs, missing arms, missing legs, extra arms, extra legs, "
127 | "fused fingers, too many fingers, long neck"
128 | )
129 | )
130 | sampler_name: str = field(default="DPM SDE")
131 | sampling_steps: int = field(default=25)
132 | width: int = field(default=512)
133 | height: int = field(default=512)
134 | cfg_scale: float = field(default=6)
135 | clip_skip: int = field(default=1)
136 | seed: int = field(default=-1)
137 |
138 |
139 | @dataclass
140 | class StableDiffusionPostProcessingParams:
141 | upscaling_enabled: bool = field(default=False)
142 | upscaling_upscaler: str = field(default="RealESRGAN 4x+")
143 | upscaling_scale: float = field(default=2)
144 | hires_fix_enabled: bool = field(default=False)
145 | hires_fix_denoising_strength: float = field(default=0.2)
146 | hires_fix_sampler: str = field(default="UniPC")
147 | hires_fix_sampling_steps: int = field(default=10)
148 | restore_faces_enabled: bool = field(default=False)
149 |
150 |
 151 | class RegexGenerationRuleMatch(str, Enum):
 152 |     INPUT = "input"
 153 |     INPUT_SENTENCE = "input_sentence"
 154 |     OUTPUT = "output"
 155 |     OUTPUT_SENTENCE = "output_sentence"
 156 |     CHARACTER_NAME = "character_name"
 157 |
158 |
159 | def __str__(self) -> str:
160 | return self
161 |
162 |
163 | @dataclass
164 | class RegexGenerationAction:
165 | name: str
166 | args: str | None
167 |
168 |
169 | @dataclass
170 | class RegexGenerationRule:
171 | regex: str | None
172 | negative_regex: str | None
173 | match: list[RegexGenerationRuleMatch] | None
174 | actions: list[RegexGenerationAction]
175 |
176 |
177 | @dataclass
178 | class UserPreferencesParams:
179 | save_images: bool = field(default=True)
180 | trigger_mode: TriggerMode = field(default=TriggerMode.TOOL)
181 | tool_mode_force_json_output_enabled: bool = field(default=True)
182 | tool_mode_force_json_output_schema: str = field(default="")
183 | interactive_mode_input_trigger_regex: str = field(
184 | default=".*(send|upload|add|show|attach|generate)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)" # noqa E501
185 | )
186 | interactive_mode_output_trigger_regex: str = field(
187 | default=".*[*([](sends|uploads|adds|shows|attaches|generates|here (is|are))\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)" # noqa E501
188 | )
189 | interactive_mode_prompt_generation_mode: InteractiveModePromptGenerationMode = (
190 | field(default=InteractiveModePromptGenerationMode.DYNAMIC)
191 | )
192 | interactive_mode_subject_regex: str = field(default=".*\\b(of)\\b(.+?)(?:[.!?]|$)")
193 | interactive_mode_description_prompt: str = field(default=default_description_prompt)
194 | interactive_mode_default_subject: str = field(
195 | default="your appearance, your surroundings and what you are doing right now"
196 | )
197 | continuous_mode_prompt_generation_mode: ContinuousModePromptGenerationMode = field(
198 | default=ContinuousModePromptGenerationMode.GENERATED_TEXT
199 | )
200 | dynamic_vram_reallocation_enabled: bool = field(default=False)
201 | dont_stream_when_generating_images: bool = field(default=True)
202 | generation_rules: dict | None = field(
203 | default=None
204 | ) # list[RegexGenerationRule] | None = field(default=None)
205 |
206 |
207 | @dataclass
208 | class FaceIDParams:
209 | faceid_enabled: bool = field(default=False)
210 | faceid_source_face: str = field(
211 | default=("file:///extensions/stable_diffusion/assets/example_face.jpg")
212 | )
213 | faceid_mode: list[str] = field(default_factory=lambda: ["FaceID", "FaceSwap"])
214 | faceid_model: str = field(default="FaceID Plus v2")
215 | faceid_override_sampler: bool = field(default=True)
216 | faceid_strength: float = field(default=1.0)
217 | faceid_structure: float = field(default=1.0)
218 | faceid_rank: int = field(default=128)
219 | faceid_tokens: int = field(default=4)
220 | faceid_cache_model: bool = field(default=False)
221 |
222 |
223 | @dataclass
224 | class IPAdapterParams:
225 | ipadapter_enabled: bool = field(default=False)
226 | ipadapter_adapter: IPAdapterAdapter = field(default=IPAdapterAdapter.BASE)
227 | ipadapter_reference_image: str = field(
228 | default=("file:///extensions/stable_diffusion/assets/example_face.jpg")
229 | )
230 | ipadapter_scale: float = field(default=0.5)
231 |
232 |
233 | @dataclass
234 | class FaceSwapLabParams:
235 | faceswaplab_enabled: bool = field(default=False)
236 | faceswaplab_source_face: str = field(
237 | default=("file:///extensions/stable_diffusion/assets/example_face.jpg")
238 | )
239 | faceswaplab_upscaling_enabled: bool = field(default=False)
240 | faceswaplab_upscaling_upscaler: str = field(default="RealESRGAN 4x+")
241 | faceswaplab_upscaling_scale: float = field(default=2)
242 | faceswaplab_upscaling_visibility: float = field(default=1)
243 | faceswaplab_postprocessing_upscaling_enabled: bool = field(default=False)
244 | faceswaplab_postprocessing_upscaling_upscaler: str = field(default="RealESRGAN 4x+")
245 | faceswaplab_postprocessing_upscaling_scale: float = field(default=2)
246 | faceswaplab_postprocessing_upscaling_visibility: float = field(default=1)
247 | faceswaplab_same_gender_only: bool = field(default=True)
248 | faceswaplab_sort_by_size: bool = field(default=True)
249 | faceswaplab_source_face_index: int = field(default=0)
250 | faceswaplab_target_face_index: int = field(default=0)
251 | faceswaplab_restore_face_enabled: bool = field(default=False)
252 | faceswaplab_restore_face_model: str = field(default="CodeFormer")
253 | faceswaplab_restore_face_visibility: float = field(default=1)
254 | faceswaplab_restore_face_codeformer_weight: float = field(default=1)
255 | faceswaplab_postprocessing_restore_face_enabled: bool = field(default=False)
256 | faceswaplab_postprocessing_restore_face_model: str = field(default="CodeFormer")
257 | faceswaplab_postprocessing_restore_face_visibility: float = field(default=1)
258 | faceswaplab_postprocessing_restore_face_codeformer_weight: float = field(default=1)
259 | faceswaplab_color_corrections_enabled: bool = field(default=False)
260 | faceswaplab_mask_erosion_factor: float = field(default=1)
261 | faceswaplab_mask_improved_mask_enabled: bool = field(default=False)
262 | faceswaplab_sharpen_face: bool = field(default=False)
263 | faceswaplab_blend_faces: bool = field(default=True)
264 |
265 |
266 | @dataclass
267 | class ReactorParams:
268 | reactor_enabled: bool = field(default=False)
269 | reactor_source_face: str = field(
270 | default=("file:///extensions/stable_diffusion/assets/example_face.jpg")
271 | )
272 | reactor_source_gender: ReactorFace = field(default=ReactorFace.NONE)
273 | reactor_target_gender: ReactorFace = field(default=ReactorFace.NONE)
274 | reactor_source_face_index: int = field(default=0)
275 | reactor_target_face_index: int = field(default=0)
276 | reactor_restore_face_enabled: bool = field(default=False)
277 | reactor_restore_face_model: str = field(default="CodeFormer")
278 | reactor_restore_face_visibility: float = field(default=1)
279 | reactor_restore_face_codeformer_weight: float = field(default=1)
280 | reactor_restore_face_upscale_first: bool = field(default=False)
281 | reactor_upscaling_enabled: bool = field(default=False)
282 | reactor_upscaling_upscaler: str = field(default="RealESRGAN 4x+")
283 | reactor_upscaling_scale: float = field(default=2)
284 | reactor_upscaling_visibility: float = field(default=1)
285 | reactor_mask_face: bool = field(default=False)
286 | reactor_model: str = field(default="inswapper_128.onnx")
287 | reactor_device: str = field(default="CPU")
288 |
289 |
290 | @dataclass(kw_only=True)
291 | class StableDiffusionWebUiExtensionParams(
292 | StableDiffusionClientParams,
293 | StableDiffusionGenerationParams,
294 | StableDiffusionPostProcessingParams,
295 | UserPreferencesParams,
296 | FaceSwapLabParams,
297 | ReactorParams,
298 | FaceIDParams,
299 | IPAdapterParams,
300 | ):
301 | display_name: str = field(default="Stable Diffusion")
302 | is_tab: bool = field(default=True)
303 | debug_mode_enabled: bool = field(default=False)
304 |
305 | def update(self, params: dict) -> None:
306 | """
307 | Updates the parameters.
308 | """
309 |
310 | for f in params.keys():
311 | assert f in [x.name for x in fields(self)], f"Invalid field for params: {f}"
312 |
313 | val = params[f]
314 | setattr(self, f, val)
315 |
316 | def normalize(self) -> None:
317 | """
318 | Normalizes the parameters. This should be called after changing any parameters.
319 | """
320 |
321 | if self.api_username is not None and self.api_username.strip() == "":
322 | self.api_username = None
323 |
324 | if self.api_password is not None and self.api_password.strip() == "":
325 | self.api_password = None
326 |
327 |         if isinstance(self.reactor_source_gender, str):
328 |             self.reactor_source_gender = ReactorFace.__members__.get(
329 |                 self.reactor_source_gender.upper(), ReactorFace.NONE
330 |             )
331 |
332 |         if isinstance(self.reactor_target_gender, str):
333 |             self.reactor_target_gender = ReactorFace.__members__.get(
334 |                 self.reactor_target_gender.upper(), ReactorFace.NONE
335 |             )
336 |
337 |         # TODO: images are redownloaded and files are reread every time a text is generated.  # noqa: E501
338 |         # This happens because normalize() is called on every generation and the downloaded values are not cached.  # noqa: E501
339 |
340 | if self.faceswaplab_enabled and (
341 | self.faceswaplab_source_face.startswith("http://")
342 | or self.faceswaplab_source_face.startswith("https://")
343 | ):
344 | try:
345 | self.faceswaplab_source_face = base64.b64encode(
346 | requests.get(self.faceswaplab_source_face).content
347 | ).decode()
348 | except Exception as e:
349 | logger.exception(
350 | "Failed to load FaceSwapLab source face image: %s", e, exc_info=True
351 | )
352 | self.faceswaplab_enabled = False
353 |
354 | if self.reactor_enabled and (
355 | self.reactor_source_face.startswith("http://")
356 | or self.reactor_source_face.startswith("https://")
357 | ):
358 | try:
359 | self.reactor_source_face = base64.b64encode(
360 | requests.get(self.reactor_source_face).content
361 | ).decode()
362 | except Exception as e:
363 | logger.exception(
364 | "Failed to load ReActor source face image: %s", e, exc_info=True
365 | )
366 | self.reactor_enabled = False
367 |
368 | if self.faceid_enabled:
369 | try:
370 | if self.faceid_source_face.startswith(
371 | "http://"
372 | ) or self.faceid_source_face.startswith("https://"):
373 | self.faceid_source_face = base64.b64encode(
374 | requests.get(self.faceid_source_face).content
375 | ).decode()
376 |
377 | if self.faceid_source_face.startswith("file:///"):
378 | with open(
379 | self.faceid_source_face.replace("file:///", ""), "rb"
380 | ) as f:
381 | self.faceid_source_face = base64.b64encode(f.read()).decode()
382 | except Exception as e:
383 | logger.exception(
384 | "Failed to load FaceID source face image: %s", e, exc_info=True
385 | )
386 | self.faceid_enabled = False
387 |
388 | if self.ipadapter_enabled:
389 | try:
390 | if self.ipadapter_reference_image.startswith(
391 | "http://"
392 | ) or self.ipadapter_reference_image.startswith("https://"):
393 | self.ipadapter_reference_image = base64.b64encode(
394 | requests.get(self.ipadapter_reference_image).content
395 | ).decode()
396 |
397 | if self.ipadapter_reference_image.startswith("file:///"):
398 | with open(
399 | self.ipadapter_reference_image.replace("file:///", ""), "rb"
400 | ) as f:
401 | self.ipadapter_reference_image = base64.b64encode(
402 | f.read()
403 | ).decode()
404 | except Exception as e:
405 | logger.exception(
406 | "Failed to load IP Adapter reference image: %s", e, exc_info=True
407 | )
408 | self.ipadapter_enabled = False
409 |
--------------------------------------------------------------------------------
/ext_modules/image_generator.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import dataclasses
3 | import html
4 | import io
5 | import re
6 | import time
7 | from datetime import date
8 | from pathlib import Path
9 | from typing import Any, cast
10 | from partial_json_parser import loads
11 | from PIL import Image
12 | from webuiapi import WebUIApiResult
13 | from modules.logging_colors import logger
14 | from ..context import GenerationContext
15 | from ..params import (
16 | ContinuousModePromptGenerationMode,
17 | InteractiveModePromptGenerationMode,
18 | RegexGenerationRuleMatch,
19 | TriggerMode,
20 | )
21 | from .vram_manager import VramReallocationTarget, attempt_vram_reallocation
22 |
23 |
24 | def normalize_regex(regex: str) -> str:
25 | if not regex.startswith("^") and not regex.startswith(".*"):
26 | regex = f".*{regex}"
27 |
28 | if not regex.endswith("$") and not regex.endswith(".*"):
29 | regex = f"{regex}.*"
30 |
31 | return regex
32 |
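# Example: normalize_regex("draw") -> ".*draw.*", so bare substrings in
# generation rules match anywhere in the text, while anchored patterns
# such as "^Assistant$" pass through unchanged.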
33 |
34 | def normalize_prompt(prompt: str) -> str:
35 | if prompt is None:
36 | return ""
37 |
38 | result = (
39 | prompt.replace("*", "")
40 | .replace('"', "")
41 | .replace("!", ",")
42 | .replace("?", ",")
43 | .replace("&", "")
44 | .replace("\r", "")
45 | .replace("\n", ", ")
46 | .replace("*", "")
47 | .replace("#", "")
48 | .replace(".,", ",")
49 | .replace(",,", ",")
50 | .replace(", ,", ",")
51 | .replace(";", ",")
52 | .strip()
53 | .strip(",")
54 | .strip()
55 | )
56 |
57 | # deduplicate tags
58 | tags = set([x.strip() for x in result.split(",")])
59 | return ", ".join(tags)
60 |
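# Note: deduplicating via a set does not preserve tag order; e.g.
# normalize_prompt("a photo!\nof a cat, a photo") yields the unique tags
# "a photo" and "of a cat" joined in arbitrary order.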
61 |
62 | def generate_html_images_for_context(
63 | context: GenerationContext,
64 | ) -> tuple[str, str | None, str | None, str | None, str | None, str | None]:
65 | """
66 | Generates images for the given context using Stable Diffusion
67 | and returns the result as HTML output
68 | """
69 |
70 | attempt_vram_reallocation(VramReallocationTarget.STABLE_DIFFUSION, context)
71 |
72 | sd_client = context.sd_client
73 |
74 | output_text = context.output_text or ""
75 |
76 | rules_prompt = ""
77 | rules_negative_prompt = ""
78 |
79 | faceswaplab_force_enabled: bool | None = None
80 | faceswaplab_overwrite_source_face: str | None = None
81 |
82 | reactor_force_enabled: bool | None = None
83 | reactor_overwrite_source_face: str | None = None
84 |
85 | if context.params.generation_rules:
86 | for rule in context.params.generation_rules:
87 | try:
88 | match_against = []
89 |
90 | delimiters = ".", ",", "!", "?", "\n", "*", '"'
91 | delimiters_regex_pattern = "|".join(map(re.escape, delimiters))
92 |
93 | if "match" in rule:
94 | if (
95 | context.input_text
96 | and context.input_text != ""
97 | and RegexGenerationRuleMatch.INPUT.value in rule["match"]
98 | ):
99 | match_against.append(context.input_text.strip())
100 |
101 | if (
102 | context.input_text
103 | and context.input_text != ""
104 | and RegexGenerationRuleMatch.INPUT_SENTENCE.value
105 | in rule["match"]
106 | ):
107 | match_against += [
108 | x.strip()
109 | for x in re.split(
110 | delimiters_regex_pattern, context.input_text
111 | )
112 | if x.strip() != ""
113 | ]
114 |
115 | if (
116 | output_text
117 | and output_text != ""
118 | and RegexGenerationRuleMatch.OUTPUT.value in rule["match"]
119 | ):
120 | match_against.append(html.unescape(output_text).strip())
121 |
122 | if (
123 | output_text
124 | and output_text != ""
125 | and RegexGenerationRuleMatch.OUTPUT_SENTENCE.value
126 | in rule["match"]
127 | ):
128 | match_against += [
129 | x.strip()
130 | for x in re.split(delimiters_regex_pattern, output_text)
131 | if x.strip() != ""
132 | ]
133 |
134 | if (
135 | context.state
136 | and "character_menu" in context.state
137 | and context.state["character_menu"]
138 | and context.state["character_menu"] != ""
139 | and RegexGenerationRuleMatch.CHARACTER_NAME.value
140 | in rule["match"]
141 | ):
142 | match_against.append(context.state["character_menu"])
143 |
144 | if "negative_regex" in rule and any(
145 | re.match(
146 | normalize_regex(rule["negative_regex"]), x, re.IGNORECASE
147 | )
148 | for x in match_against
149 | ):
150 | continue
151 |
152 | if "regex" in rule and not any(
153 | re.match(normalize_regex(rule["regex"]), x, re.IGNORECASE)
154 | for x in match_against
155 | ):
156 | continue
157 |
158 | if "actions" not in rule:
159 | continue
160 |
161 | for action in rule["actions"]:
162 | if action["name"] == "skip_generation":
163 | return (
164 | output_text,
165 | None,
166 | "",
167 | "",
168 | context.params.base_prompt,
169 | context.params.base_negative_prompt,
170 | )
171 |
172 | if action["name"] == "prompt_append" and "args" in action:
173 | rules_prompt = _combine_prompts(rules_prompt, action["args"])
174 |
175 |                     if action["name"] == "negative_prompt_append" and "args" in action:
176 |                         rules_negative_prompt = _combine_prompts(
177 |                             rules_negative_prompt, action["args"]
178 |                         )
179 |
180 | if action["name"] == "faceswaplab_enable":
181 | faceswaplab_force_enabled = True
182 |
183 | if action["name"] == "faceswaplab_disable":
184 | faceswaplab_force_enabled = False
185 |
186 | if (
187 | action["name"] == "faceswaplab_set_source_face"
188 | and "args" in action
189 | ):
190 | faceswaplab_overwrite_source_face = action["args"]
191 |
192 | if action["name"] == "reactor_enable":
193 | reactor_force_enabled = True
194 |
195 | if action["name"] == "reactor_disable":
196 | reactor_force_enabled = False
197 |
198 | if action["name"] == "reactor_set_source_face" and "args" in action:
199 | reactor_overwrite_source_face = action["args"]
200 |
201 | except Exception as e:
202 | logger.error(
203 |                     f"[SD WebUI Integration] Failed to apply rule: {rule.get('regex')}: %s",  # noqa: E501
204 | e,
205 | exc_info=True,
206 | )
207 |
208 | context_prompt = None
209 |
210 |     if context.params.trigger_mode == TriggerMode.INTERACTIVE and (
211 |         context.params.interactive_mode_prompt_generation_mode
212 |         in (InteractiveModePromptGenerationMode.GENERATED_TEXT,
213 |             InteractiveModePromptGenerationMode.DYNAMIC)
214 |     ):
215 | context_prompt = html.unescape(output_text or "")
216 |
217 | if context.params.trigger_mode == TriggerMode.CONTINUOUS and (
218 | context.params.continuous_mode_prompt_generation_mode
219 | == ContinuousModePromptGenerationMode.GENERATED_TEXT
220 | ):
221 | context_prompt = html.unescape(output_text or "")
222 |
223 | if context.params.trigger_mode == TriggerMode.TOOL:
224 | output_text = html.unescape(output_text or "").strip()
225 |
226 | json_search = re.search(
227 | r"(\b)?([{\[].*[\]}])(\b)?", output_text, flags=re.I | re.M | re.S | re.U
228 | )
229 |
230 | if not json_search:
231 | logger.warning(
232 | "No JSON output found in the output text: %s.\nTry enabling JSON grammar rules to avoid such errors.",
233 | output_text,
234 | )
235 |
236 | json_text_original = json_search.group(0) if json_search else "{}"
237 |
238 | try:
239 | json_text = (
240 | json_text_original.strip()
241 | .replace("\r\n", "\n")
242 | .replace("'", "")
243 | .replace("“", '"') # yes, this actually happened.
244 | .replace("”", '"') # llms are really creative and crazy...
245 | .replace(
246 | "{{", "{ {"
247 |             )  # for some reason the json parser doesn't like this
248 | .replace("}}", "} }")
249 | )
250 | except Exception as e:
251 | logger.warning(
252 | "JSON extraction from text failed: %s\n%s.\n\nTry enabling JSON grammar rules to avoid such errors.",
253 | repr(e),
254 | output_text,
255 | )
256 |
257 | json_text = "{}"
258 |
259 | output_text = (
260 | output_text.replace(json_text_original + "\n", "")
261 | .replace("\n" + json_text_original, "")
262 | .replace(json_text_original, "")
263 | .replace("Action: ```json\n", "")
264 | .replace("Action: ```json", "")
265 | .replace("Action:\n", "")
266 | .replace("Action:", "")
267 | .replace("\n```json", "")
268 | .replace("```json", "")
269 | .replace("```json\n", "")
270 | .replace("\n```", "")
271 | .replace("```", "")
272 | .strip("\r\n")
273 | .strip("\n")
274 | .strip()
275 | )
276 |
277 | json = None
278 |
279 | if json_search and json_text and json_text not in ["[]", "{}", "()"]:
280 | try:
281 | json = loads(json_text)
282 | except Exception as e:
283 | logger.warning(
284 | "Failed to parse JSON from output text: %s\n%s\n\nTry enabling JSON grammar rules to avoid such errors.",
285 | repr(e),
286 | json_text,
287 | exc_info=True,
288 | )
289 |
290 | if json is not None:
291 | tools: list[Any] = json if isinstance(json, list) else [json]
292 |
293 | for tool in tools:
294 | tool_name: str = (
295 | tool.get("tool", None)
296 | or tool.get("tool name", None)
297 | or tool.get("tool_name", None)
298 | or tool.get("tool call", None)
299 | or tool.get("tool_call", None)
300 | or tool.get("name", None)
301 | or tool.get("function", None)
302 | or tool.get("function_name", None)
303 | or tool.get("function name", None)
304 | or tool.get("function_call", None)
305 | or tool.get("function call", None)
306 | )
307 |
308 | tool_params: dict = (
309 | tool.get("tool_parameters", None)
310 | or tool.get("tool parameters", None)
311 | or tool.get("parameters", None)
312 | or tool.get("tool_params", None)
313 | or tool.get("tool params", None)
314 | or tool.get("params", None)
315 | or tool.get("tool_arguments", None)
316 | or tool.get("tool arguments", None)
317 | or tool.get("arguments", None)
318 | or tool.get("tool_args", None)
319 | or tool.get("tool args", None)
320 | or tool.get("args", None)
321 | )
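            # both common tool-call shapes are handled here, e.g.
            # {"tool_name": "generate_image", "parameters": {...}} and
            # {"function": "add_text", "args": {...}}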
322 |
323 | if not tool_name or not tool_params:
324 | continue
325 |
326 | if tool_name.lower() in [
327 | "generate_image",
328 | "generate image",
329 | "generateimage",
330 | ]:
331 | context_prompt = (
332 | tool_params.get("text", None)
333 | or tool_params.get("prompt", None)
334 | or tool_params.get("query", None)
335 | or ""
336 | )
337 |
338 | if tool_name.lower() in ["add_text", "add text", "addtext"]:
339 | tool_text = (
340 | tool_params.get("text", None)
341 | or tool_params.get("prompt", None)
342 | or tool_params.get("query", None)
343 | or ""
344 | )
345 | output_text = tool_text + (
346 | "\n" + output_text if output_text else ""
347 | )
348 |
349 | if context_prompt is None:
350 | return (
351 | output_text,
352 | None,
353 | None,
354 | None,
355 | None,
356 | None,
357 | )
358 |
359 | if ":" in context_prompt:
360 | context_prompt = (
361 | ", ".join(context_prompt.split(":")[1:])
362 | .replace(".", ",")
363 | .replace(":", ",")
364 | .strip()
365 | .strip("\n")
366 | .strip()
367 | .split("\n")[0]
368 | .strip()
369 | .lower()
370 | )
371 |
372 | generated_prompt = _combine_prompts(rules_prompt, normalize_prompt(context_prompt))
373 | generated_negative_prompt = rules_negative_prompt
374 |
375 | full_prompt = _combine_prompts(generated_prompt, context.params.base_prompt)
376 |
377 | full_negative_prompt = _combine_prompts(
378 | generated_negative_prompt, context.params.base_negative_prompt
379 | )
380 |
381 | debug_info = (
382 | (
383 | f"\n"
384 | f" Prompt: {full_prompt}\n"
385 | f" Negative Prompt: {full_negative_prompt}"
386 | )
387 | if context.params.debug_mode_enabled
388 | else ""
389 | )
390 |
391 | logger.info(
392 | "[SD WebUI Integration] Using stable-diffusion-webui to generate images. %s",
393 | debug_info,
394 | )
395 |
396 | try:
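        # the faceid_* and ipadapter_* arguments below correspond to SD.Next
        # features (see settings.debug.yaml); they are assumed to be handled
        # by the custom SdWebUIApi client in sd_client.py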
397 | response = sd_client.txt2img(
398 | prompt=full_prompt,
399 | negative_prompt=full_negative_prompt,
400 | seed=context.params.seed,
401 | sampler_name=context.params.sampler_name,
402 | full_quality=True,
403 | enable_hr=context.params.upscaling_enabled
404 | or context.params.hires_fix_enabled,
405 | hr_scale=context.params.upscaling_scale,
406 | hr_upscaler=context.params.upscaling_upscaler,
407 | denoising_strength=context.params.hires_fix_denoising_strength,
408 | hr_sampler=context.params.hires_fix_sampler,
409 | hr_force=context.params.hires_fix_enabled,
410 | hr_second_pass_steps=(
411 | context.params.hires_fix_sampling_steps
412 | if context.params.hires_fix_enabled
413 | else 0
414 | ),
415 | steps=context.params.sampling_steps,
416 | cfg_scale=context.params.cfg_scale,
417 | width=context.params.width,
418 | height=context.params.height,
419 | restore_faces=context.params.restore_faces_enabled,
420 | faceid_enabled=context.params.faceid_enabled,
421 | faceid_mode=context.params.faceid_mode,
422 | faceid_model=context.params.faceid_model,
423 | faceid_image=context.params.faceid_source_face,
424 | faceid_scale=context.params.faceid_strength,
425 | faceid_structure=context.params.faceid_structure,
426 | faceid_rank=context.params.faceid_rank,
427 | faceid_override_sampler=context.params.faceid_override_sampler,
428 | faceid_tokens=context.params.faceid_tokens,
429 | faceid_cache_model=context.params.faceid_cache_model,
430 | ipadapter_enabled=context.params.ipadapter_enabled,
431 | ipadapter_adapter=context.params.ipadapter_adapter,
432 | ipadapter_scale=context.params.ipadapter_scale,
433 | ipadapter_image=context.params.ipadapter_reference_image,
434 | use_async=False,
435 | )
436 |
437 | response = cast(WebUIApiResult, response)
438 |
439 | if len(response.images) == 0:
440 | logger.error("[SD WebUI Integration] Failed to generate any images.")
441 | return (
442 | output_text,
443 | None,
444 | generated_prompt,
445 | generated_negative_prompt,
446 | full_prompt,
447 | full_negative_prompt,
448 | )
449 |
450 | formatted_result = ""
451 | style = 'style="width: 100%; max-height: 100vh;"'
452 |
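        # imported locally to avoid a potential circular import with script.py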
453 | from ..script import EXTENSION_DIRECTORY_NAME
454 |
455 | image: Image.Image
456 | for image in response.images:
457 | if faceswaplab_force_enabled or (
458 | faceswaplab_force_enabled is None and context.params.faceswaplab_enabled
459 | ):
460 | if context.params.debug_mode_enabled:
461 | logger.info(
462 | "[SD WebUI Integration] Using FaceSwapLab to swap faces."
463 | )
464 |
465 | try:
466 | response = sd_client.faceswaplab_swap_face(
467 | image,
468 | params=dataclasses.replace(
469 | context.params,
470 | faceswaplab_source_face=(
471 | faceswaplab_overwrite_source_face
472 | if faceswaplab_overwrite_source_face is not None
473 | else context.params.faceswaplab_source_face
474 | ).replace(
475 | "{STABLE_DIFFUSION_EXTENSION_DIRECTORY}",
476 | f"./extensions/{EXTENSION_DIRECTORY_NAME}",
477 | ),
478 | ),
479 | use_async=False,
480 | )
481 | image = response.image # type: ignore
482 | except Exception as e:
483 | logger.error(
484 | "[SD WebUI Integration] FaceSwapLab failed to swap faces: %s",
485 | e,
486 | exc_info=True,
487 | )
488 |
489 | if reactor_force_enabled or (
490 | reactor_force_enabled is None and context.params.reactor_enabled
491 | ):
492 | if context.params.debug_mode_enabled:
493 | logger.info("[SD WebUI Integration] Using ReActor to swap faces.")
494 |
495 | try:
496 | response = sd_client.reactor_swap_face(
497 | image,
498 | params=dataclasses.replace(
499 | context.params,
500 | reactor_source_face=(
501 | reactor_overwrite_source_face
502 | if reactor_overwrite_source_face is not None
503 | else context.params.reactor_source_face
504 | ).replace(
505 | "{STABLE_DIFFUSION_EXTENSION_DIRECTORY}",
506 | f"./extensions/{EXTENSION_DIRECTORY_NAME}",
507 | ),
508 | ),
509 | use_async=False,
510 | )
511 | image = response.image # type: ignore
512 | except Exception as e:
513 | logger.error(
514 | "[SD WebUI Integration] ReActor failed to swap faces: %s",
515 | e,
516 | exc_info=True,
517 | )
518 |
519 | if context.params.save_images:
520 | character = (
521 | context.state.get("character_menu", "Default")
522 | if context.state
523 | else "Default"
524 | )
525 |
526 | file = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}' # noqa: E501
527 |
528 | # todo: do not hardcode extension path
529 | output_file = Path(
530 | f"extensions/{EXTENSION_DIRECTORY_NAME}/outputs/{file}.png"
531 | )
532 | output_file.parent.mkdir(parents=True, exist_ok=True)
533 |
534 | image.save(output_file)
535 | image_source = f"/file/{output_file}"
536 | else:
537 | # resize image to avoid huge logs
538 | image.thumbnail((512, int(512 * image.height / image.width)))
539 |
540 | buffered = io.BytesIO()
541 | image.save(buffered, format="JPEG")
542 | buffered.seek(0)
543 | image_bytes = buffered.getvalue()
544 | image_base64 = (
545 | "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
546 | )
547 | image_source = image_base64
548 |
549 |             formatted_result += f'<img src="{image_source}" {style}>\n'
550 |
551 | finally:
552 | attempt_vram_reallocation(VramReallocationTarget.LLM, context)
553 |
554 | return (
555 | output_text,
556 | formatted_result.rstrip("\n"),
557 | generated_prompt,
558 | generated_negative_prompt,
559 | full_prompt,
560 | full_negative_prompt,
561 | )
562 |
563 |
564 | def _combine_prompts(prompt1: str, prompt2: str) -> str:
565 | if prompt1 is None and prompt2 is None:
566 | return ""
567 |
568 | if prompt1 is None or prompt1 == "":
569 | return prompt2.strip(",").strip()
570 |
571 | if prompt2 is None or prompt2 == "":
572 | return prompt1.strip(",").strip()
573 |
574 | return prompt1.strip(",").strip() + ", " + prompt2.strip(",").strip()
575 |
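# Example: _combine_prompts("best quality,", "a cat") -> "best quality, a cat";
# if either prompt is empty or None, the other is returned with stray commas stripped.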
--------------------------------------------------------------------------------
/settings.debug.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------------------#
2 | # STABLE DIFFUSION EXTENSION #
3 | # This file contains the documentation and debug settings for the Stable Diffusion extension. #
4 | # --------------------------------------------------------------------------------------------#
5 |
6 | dark_theme: true
7 | show_controls: true
8 | mode: chat
9 | chat_style: cai-chat
10 | character: Assistant
11 | preset: Debug-deterministic
12 | seed: 1337
13 | stream: true
14 |
15 | default_extensions:
16 | - gallery
17 | - stable_diffusion
18 |
19 | #----------------------#
20 | # API ENDPOINT DETAILS #
21 | #----------------------#
22 |
23 | ## Sets the API endpoint to use for generating images.
24 | ## If you are using the default stable-diffusion-webui settings, you do not need to change this.
25 | stable_diffusion-api_endpoint: "http://127.0.0.1:7860/sdapi/v1"
26 |
27 | ## Leave as-is if you did not set up any authentication for the API.
28 | stable_diffusion-api_username: ""
29 | stable_diffusion-api_password: ""
30 |
31 | #-----------------------------#
32 | # IMAGE GENERATION PARAMETERS #
33 | #-----------------------------#
34 |
35 | stable_diffusion-base_prompt: "RAW photo, subject, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
36 | stable_diffusion-base_negative_prompt: "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
37 | stable_diffusion-sampler_name: "DPM SDE"
38 | stable_diffusion-sampling_steps: 25
39 | stable_diffusion-width: 512
40 | stable_diffusion-height: 512
41 | stable_diffusion-cfg_scale: 6
42 | stable_diffusion-clip_skip: 1
43 | stable_diffusion-seed: -1
44 |
45 | #------------------#
46 | # USER PREFERENCES #
47 | #------------------#
48 |
49 | ## Sets if debug mode (e.g. for additional logs) should be enabled
50 | stable_diffusion-debug_mode_enabled: true
51 |
52 | ## Sets if generated images should be saved to the "outputs" folder inside the stable_diffusion extension directory.
53 | stable_diffusion-save_images: true
54 |
55 | ## Defines how image generation should be triggered. Possible values:
56 | ## - "tool": Generate images using tool calls (requires special models and prompt modifications).
57 | ##
58 | ## This is overall the best and most accurate option, as the LLM model itself triggers and generates the prompt for image generation,
59 | ## similar to how ChatGPT does it. This is also the only option that supports image generation while remembering the chat history.
60 | ## However, it is also the most difficult option to set up, as it requires special text generation models and some
61 | ## minor system prompt adjustments.
62 | ##
63 | ## First, you will need a model that supports function / tool calls, such as Command-R or Llama-2-chat-7b-hf-function-calling-v2.
64 | ## This option works best with models that are compatible with the OpenAI tools spec. Once you have found a suitable model, define two
65 | ## tools in your system prompt: add_text (with a "text" string parameter) and generate_image (with a "prompt" string parameter). Then
66 | ## instruct your chatbot or character to use these tools whenever needed.
67 | ##
68 | ## It is also highly recommended to either enable tool_mode_force_json_output_enabled or to force-start the prompt directly in a JSON format,
69 | ## e.g. with this start_with for Command-R:
70 | ##
71 | ## start_with: |-
72 | ## Action: ```json
73 | ## [
74 | ## {
75 | ## "tool_name": "generate_image",
76 | ## "parameters": {
77 | ## "prompt": "
78 | ##
79 | ## Consult the documentation for your model for more information regarding how tool / function calls work, how tools are set up, and
80 | ## what the response must start with. For Command-R, see https://docs.cohere.com/docs/prompting-command-r.
81 | ##
82 | ## - "continuous": Generate images for all replies without any specific triggers.
83 | ## This option supports including the output text as image generation prompt but will ignore any past chat history.
84 | ##
85 | ## - "interactive": Generate images only if a message with a triggering text was sent or received.
86 | ## This option supports including the output text as image generation prompt but will also ignore any past chat history.
87 | ##
88 | ## - "manual": Generates images only if the image generation button was pressed.
89 | ## This option is currently not implemented and can be used as an off-switch until then.
90 |
91 | stable_diffusion-trigger_mode: "tool"
92 |
93 | ## Forces the model to output valid JSON for tool calls.
94 | ## Only works with transformers / HF based loaders; ignored if using a different loader.
95 | stable_diffusion-tool_mode_force_json_output_enabled: false
96 | stable_diffusion-tool_mode_force_json_output_schema: |-
97 | {
98 | "type": "array",
99 | "items": {
100 | "type": "object",
101 | "properties": {
102 | "tool_name": {
103 | "type": "string"
104 | },
105 | "parameters": {
106 | "type": "object",
107 | "additionalProperties": {"type": "string"}
108 | }
109 | },
110 | "required": ["tool_name", "parameters"]
111 | }
112 | }
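## For reference, a model response conforming to this schema could look like:
## [{"tool_name": "generate_image", "parameters": {"prompt": "a red fox, forest, golden hour"}}]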
113 |
114 | ## Sets how the prompt for image generation should be generated. Possible values:
115 | ## - "static": Uses the prompt option as-is ignoring any chat context.
116 | ## - "generated_text": Uses the generated output as-is as prompt.
117 | ## - "dynamic": Generates a dynamic prompt using the subject_regex, default_subject and description_prompt options.
118 | ## The result is combined with the base_prompt and base_negative_prompt options.
119 | stable_diffusion-interactive_mode_prompt_generation_mode: "dynamic"
120 |
121 | ## Defines the regex pattern for the input message which triggers image generation in interactive mode.
122 | stable_diffusion-interactive_mode_input_trigger_regex: >-
123 | .*(draw|paint|create|send|upload|add|show|attach|generate)\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)
124 |
125 | ## Defines the regex pattern for the generated output message which triggers image generation in interactive mode.
126 | stable_diffusion-interactive_mode_output_trigger_regex: >-
127 | .*[*([]?(draws|paints|creates|sends|uploads|adds|shows|attaches|generates|here (is|are))\b.+?\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)
128 |
129 | ## Defines the regex pattern for extracting the subject of the message for dynamic prompt generation in interactive mode.
130 | ## Only used when prompt_generation_mode is set to "dynamic".
131 | stable_diffusion-interactive_mode_subject_regex: >-
132 | .*\b(of)\b(.+?)(?:[.!?]|$)
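## Example: for the input "Can you draw an image of a castle at night?", the input
## trigger regex above matches and this pattern captures "a castle at night"
## (modulo surrounding whitespace) as the subject for dynamic prompt generation.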
133 |
134 | ## Sets the default subject to use instead if no subject was found in the input message using the subject_regex option.
135 | ## Only used when prompt_generation_mode is set to "dynamic".
136 | stable_diffusion-interactive_mode_default_subject: "your appearance, your surroundings and what you are doing right now"
137 |
138 | ## The text to use for generating a stable diffusion compatible description of the given subject. Replaces the original input message.
139 | ## Only used when prompt_generation_mode is set to "dynamic".
140 | stable_diffusion-interactive_mode_description_prompt: >
141 | You are now a text generator for the Stable Diffusion AI image generator. You will generate a text prompt for it.
142 |
143 | Describe [subject] using comma-separated tags only. Do not use sentences.
144 | Include many tags such as tags for the environment, gender, clothes, age, location, light, daytime, angle, pose, etc.
145 |
146 | Very important: only write the comma-separated tags. Do not write anything else. Do not ask any questions. Do not talk.
147 |
148 | ## Sets how the base prompt for image generation should be generated. Possible values:
149 | ## - "static": Uses the prompt option as-is ignoring any chat context.
150 | ## - "generated_text": Uses the generated output as-is as prompt.
151 | ## The result is combined with the base_prompt and base_negative_prompt options.
152 | stable_diffusion-continuous_mode_prompt_generation_mode: "generated_text"
153 |
154 | ## If enabled, will automatically unload the LLM model from VRAM and then load the SD model instead when generating images.
155 | ## After the image is generated, it will unload the SD model again and then reload the LLM model.
156 | ## Saves VRAM but will slow down generation speed. Only recommended if you have a low amount of VRAM or use very large models.
157 | stable_diffusion-dynamic_vram_reallocation_enabled: false
158 |
159 | ## Do not stream messages if generating images at the same time. Improves generation speed.
160 | stable_diffusion-dont_stream_when_generating_images: true
161 |
162 | ## Defines regex-based rules that trigger the given actions.
163 | ## regex: The regex pattern that triggers the action (optional)
164 | ## negative_regex: Do not trigger the action if the text matches this regex (optional)
165 | ## match: A list of where to match the regex. Available options:
166 | ## - "input": match on input text.
167 | ## - "input_sentence": match on any sentence in input text.
168 | ## - "output": match on generated output text.
169 | ## - "output_sentence": match on any sentence in generated output text.
170 | ## - "character_name": match on current character name (only if using gallery extension).
171 | ## actions: A list of actions to perform if the regex matches.
172 | ## - name: The name of the action to perform.
173 | ## Available options:
174 | ## - "prompt_append": appends the given text in "args" to the image generation prompt if an image is to be generated.
175 | ## - "negative_prompt_append": appends the given text in "args" to the image generation negative prompt if an image is to be generated.
176 | ## - "skip_generation": force skips any image generation.
177 | ## - "faceswaplab_enable": force enables face swap with FaceSwapLab.
178 | ## - "faceswaplab_disable": force disables face swap with FaceSwapLab.
179 | ## - "faceswaplab_set_source_face": sets source face for FaceSwapLab (requires "args" to be set to a valid source face).
180 | ## - "reactor_enable": force enables face swap with ReActor.
181 | ## - "reactor_disable": force disables face swap with ReActor.
182 | ## - "reactor_set_source_face": sets source face for ReActor (requires "args" to be set to a valid source face).
183 | ## - args: The arguments to pass to the action (optional).
184 | stable_diffusion-generation_rules:
185 | # Add details to the prompt if the input text or output text contains the word "detailed".
186 | - regex: .*\b(detailed)\b
187 | match: ["input", "output"]
188 | actions:
189 | - name: "prompt_append"
190 | args: "(high resolution, detailed, realistic, vivid: 1.2), hdr, 8k, "
191 |
192 |   # Append a prompt and negative prompt describing the character's look if the character's name equals "Assistant".
193 | - regex: ^Assistant$
194 | match: ["character_name"]
195 | actions:
196 | - name: "prompt_append"
197 | args: "small cute robot, monochrome, droid, 3d render, white reflective plastic body, simple, 3DMM, "
198 | - name: "negative_prompt_append"
199 | args: "humanoid, human, person, animal, anthropomorphic"
200 |
201 | # Enable face swap via FaceSwapLab (see below for documentation) if the character's name equals "Example".
202 | - regex: ^Example$
203 | match: ["character_name"]
204 | actions:
205 | - name: "faceswaplab_enable"
206 | - name: "faceswaplab_set_source_face"
207 | args: "file:///{STABLE_DIFFUSION_EXTENSION_DIRECTORY}/assets/example_face.jpg"
208 |
209 | #-----------------#
210 | # POST PROCESSING #
211 | #-----------------#
212 |
213 | ## Sets if generated images should be upscaled.
214 | stable_diffusion-upscaling_enabled: false
215 |
216 | ## Sets the upscaler to use for upscaling generated images.
217 | ## Some examples are: Latent, LDSR, RealESRGAN 4x+, Lanczos, Nearest, etc.
218 | stable_diffusion-upscaling_upscaler: "RealESRGAN 4x+"
219 |
220 | ## Amount to upscale by (1 = 100%, 2 = 200%, etc.).
221 | stable_diffusion-upscaling_scale: 2
222 |
223 | ## Sets if HiRes.fix should be enabled.
224 | stable_diffusion-hires_fix_enabled: false
225 |
226 | ## Sets the sampler to use for HiRes.fix.
227 | stable_diffusion-hires_fix_sampler: "UniPC"
228 |
229 | ## Sets the amount of steps for HiRes.fix.
230 | stable_diffusion-hires_fix_sampling_steps: 10
231 |
232 | ## Sets the denoising strength for HiRes.fix.
233 | stable_diffusion-hires_fix_denoising_strength: 0.2
234 |
235 | ## Sets if faces should be enhanced (or "restored") in generated images.
236 | stable_diffusion-restore_faces_enabled: false
237 |
238 | #-------------#
239 | # FACESWAPLAB #
240 | #-------------#
241 |
242 | ## Apply face swapping using FaceSwapLab.
243 | ## Requires the sd-webui-faceswaplab extension to be installed.
244 | ## Repository: https://github.com/glucauze/sd-webui-faceswaplab
245 |
246 | ## Sets if faces should be swapped in generated images.
247 | stable_diffusion-faceswaplab_enabled: false
248 |
249 | ## Sets the source image with the face to use for face swapping.
250 | ## It's possible to set it in 3 ways:
251 | ## 1. Local file: "file:///./example.jpg"
252 | ## 2. URL: "https://some-site.com/example.png"
253 | ## 3. (recommended) FaceSwapLab face checkpoint: "checkpoint://example"
254 | ## You can see the list of available checkpoints in the "models/faceswaplab/faces" directory inside your Stable Diffusion WebUI directory.
255 | ## See https://github.com/glucauze/sd-webui-faceswaplab#build-and-use-checkpoints- for more on how to create face checkpoints.
256 | stable_diffusion-faceswaplab_source_face: "file:///{STABLE_DIFFUSION_EXTENSION_DIRECTORY}/assets/example_face.jpg"
257 |
258 | ## Only swap faces if same gender.
259 | stable_diffusion-faceswaplab_same_gender_only: true
260 |
261 | ## If enabled, order source faces by size.
262 | ## Otherwise, order source faces from left to right.
263 | stable_diffusion-faceswaplab_sort_by_size: true
264 |
265 | ## Use the nth face in the source face image as reference face
266 | ## Note: the first face is 0, the second face is 1, etc.
267 | ##
268 | ## Example:
269 | ## If you have 3 faces in the source image and set this to 1, it will use the second face from left to right if sort_by_size is set to false.
270 | ## If sort_by_size is true, it will use the second largest face instead.
271 | stable_diffusion-faceswaplab_source_face_index: 0
272 |
273 | ## Use the nth face in the generated image as the face to swap out
274 | ## See source_face_index for more info
275 | stable_diffusion-faceswaplab_target_face_index: 0
276 |
277 | ## Sets if the face should be upscaled
278 | stable_diffusion-faceswaplab_upscaling_enabled: false
279 |
280 | ## Sets the upscaler to use for upscaling faces
281 | ## Some examples are: Latent, LDSR, RealESRGAN 4x+, Lanczos, Nearest, etc.
282 | stable_diffusion-faceswaplab_upscaling_upscaler: "RealESRGAN 4x+"
283 |
284 | ## Amount to upscale the face by (1 = 100%, 2 = 200%, etc.)
285 | stable_diffusion-faceswaplab_upscaling_scale: 2
286 |
287 | ## Visibility of the upscaled face (0.0 - 1.0)
288 | stable_diffusion-faceswaplab_upscaling_visibility: 1
289 |
290 | ## Sets if the final result should be upscaled
291 | stable_diffusion-faceswaplab_postprocessing_upscaling_enabled: false
292 |
293 | ## Sets the upscaler to use for upscaling final result image after swapping
294 | ## Some examples are: Latent, LDSR, RealESRGAN 4x+, Lanczos, Nearest, etc.
295 | stable_diffusion-faceswaplab_postprocessing_upscaling_upscaler: "RealESRGAN 4x+"
296 |
297 | ## Amount to upscale the final result by (1 = 100%, 2 = 200%, etc.)
298 | stable_diffusion-faceswaplab_postprocessing_upscaling_scale: 2
299 |
300 | ## Visibility of the final result upscale (0.0 - 1.0)
301 | stable_diffusion-faceswaplab_postprocessing_upscaling_visibility: 1
302 |
303 | ## Sets if the face should be enhanced (or "restored") during swapping
304 | stable_diffusion-faceswaplab_restore_face_enabled: false
305 |
306 | ## Model to use for enhancing the face (CodeFormer, GFPGAN)
307 | stable_diffusion-faceswaplab_restore_face_model: "CodeFormer"
308 |
309 | ## Visibility of the restored face (0.0 - 1.0)
310 | stable_diffusion-faceswaplab_restore_face_visibility: 1
311 |
312 | ## Weight of the CodeFormer model (0.0 - 1.0)
313 | stable_diffusion-faceswaplab_restore_face_codeformer_weight: 1
314 |
315 | ## Sets if the faces should be enhanced (or "restored") in the final result image after swapping
316 | stable_diffusion-faceswaplab_postprocessing_restore_face_enabled: false
317 |
318 | ## Model to use for restoring the faces (CodeFormer, GFPGAN)
319 | stable_diffusion-faceswaplab_postprocessing_restore_face_model: "CodeFormer"
320 |
321 | ## Visibility of the restored faces (0.0 - 1.0)
322 | stable_diffusion-faceswaplab_postprocessing_restore_face_visibility: 1
323 |
324 | ## Weight of the CodeFormer model (0.0 - 1.0)
325 | stable_diffusion-faceswaplab_postprocessing_restore_face_codeformer_weight: 1
326 |
327 | ## Sets if color corrections should be applied
328 | stable_diffusion-faceswaplab_color_corrections_enabled: false
329 |
330 | ## Sets the erosion factor for the mask
331 | stable_diffusion-faceswaplab_mask_erosion_factor: 1
332 |
333 | ## Use improved segmented mask (use pastenet to mask only the face)
334 | ## Note: you should enable upscaling if you enable this option
335 | stable_diffusion-faceswaplab_mask_improved_mask_enabled: false
336 |
337 | ## Sharpen the face
338 | stable_diffusion-faceswaplab_sharpen_face: false
339 |
340 | ## Sets if faces should be blended in generated images
341 | stable_diffusion-faceswaplab_blend_faces: true
342 |
343 | #---------#
344 | # ReActor #
345 | #---------#
346 |
347 | ## Apply face swapping using ReActor.
348 | ## Requires the sd-webui-reactor extension to be installed.
349 | ## Repository: https://github.com/Gourieff/sd-webui-reactor
350 |
351 | ## Sets if faces should be swapped in generated images.
352 | stable_diffusion-reactor_enabled: true
353 |
354 | ## Sets the source image with the face to use for face swapping.
355 | ## It's possible to set it in 3 ways:
356 | ## 1. Local file: "file:///./example.jpg"
357 | ## 2. URL: "https://some-site.com/example.png"
358 | ## 3. (recommended) ReActor face model: "checkpoint://example"
359 | ## You can see the list of available checkpoints in the "models/reactor/faces" directory inside your Stable Diffusion WebUI directory.
360 | stable_diffusion-reactor_source_face: "file:///{STABLE_DIFFUSION_EXTENSION_DIRECTORY}/assets/example_face.jpg"
361 |
362 | ## Sets the gender for the face in the source image (supported values: none, male, female)
363 | ## In other words, will only use the face with this gender as source face
364 | stable_diffusion-reactor_source_gender: "none"
365 |
366 | ## Sets which gender to target in the generated image for swapping (supported values: none, male, female)
367 | ## In other words, will only swap a face of this gender
368 | stable_diffusion-reactor_target_gender: "none"
369 |
370 | ## Use the nth face in the source face image as reference face
371 | ## Note: the first face is 0, the second face is 1, etc.
372 | ##
373 | ## Example:
374 | ## If you have 3 faces in the source image and set this to 1, it will use the second face from left to right and top to bottom.
375 | stable_diffusion-reactor_source_face_index: 0
376 |
377 | ## Use the nth face in the generated image as the face to swap out
378 | ## See source_face_index for more info
379 | stable_diffusion-reactor_target_face_index: 0
380 |
381 | ## Sets if the face should be enhanced (or "restored") after swapping
382 | stable_diffusion-reactor_restore_face_enabled: false
383 |
384 | ## Model to use for restoring the face (CodeFormer, GFPGAN)
385 | stable_diffusion-reactor_restore_face_model: "CodeFormer"
386 |
387 | ## Visibility of the restored face (0.0 - 1.0)
388 | stable_diffusion-reactor_restore_face_visibility: 1
389 |
390 | ## Weight of the CodeFormer model (0.0 - 1.0)
391 | stable_diffusion-reactor_restore_face_codeformer_weight: 1
392 |
393 | ## Upscale the face before restoring it; otherwise the face is restored first and then upscaled
394 | stable_diffusion-reactor_restore_face_upscale_first: false
395 |
396 | ## Sets if the face should be upscaled.
397 | stable_diffusion-reactor_upscaling_enabled: false
398 |
399 | ## Sets the upscaler to use for upscaling faces.
400 | ## Some examples are: Latent, LDSR, RealESRGAN 4x+, Lanczos, Nearest, etc.
401 | stable_diffusion-reactor_upscaling_upscaler: "RealESRGAN 4x+"
402 |
403 | ## Amount to upscale the face by (1 = 100%, 2 = 200%, etc.).
404 | stable_diffusion-reactor_upscaling_scale: 2
405 |
406 | ## Visibility of the upscaled face (0.0 - 1.0)
407 | stable_diffusion-reactor_upscaling_visibility: 1
408 |
409 | ## Sets if face mask correction should be enabled to fix pixelation around face contours
410 | stable_diffusion-reactor_mask_face: false
411 |
412 | ## Model to use for swapping faces
413 | stable_diffusion-reactor_model: "inswapper_128.onnx"
414 |
415 | ## Device to use for swapping faces (CPU, CUDA).
416 | ## CUDA recommended for faster inference if you have an NVIDIA GPU.
417 | ## Note: CUDA requires installation of the onnxruntime-gpu package instead of onnxruntime in stable-diffusion-webui
418 | stable_diffusion-reactor_device: "CPU"
419 |
420 | #---------#
421 | # FaceID #
422 | #---------#
423 |
424 | ## Apply face swapping using the FaceID feature of SD.Next (a fork of AUTOMATIC1111).
425 | ## See: https://github.com/vladmandic/automatic for SD.Next repository.
426 | ##
427 | ## Works much better than ReActor or FaceSwapLab, as
428 | ## the face is not actually swapped but instead generated
429 | ## directly to resemble the source face while the image
430 | ## itself is still being generated.
431 | ##
432 | ## Works with stylized images too, e.g. 3D renders, drawings, cartoon, paintings etc.
433 | ##
434 | ## WARNING: DOES NOT WORK WITH VANILLA AUTOMATIC1111. YOU _MUST_ USE SD.NEXT INSTEAD.
435 | ##
436 | ## Requires "insightface", "ip_adapter" and "onnxruntime-gpu" PIP packages to be installed in SD.Next.
437 |
438 | ## Sets if faces should be swapped in generated images.
439 | stable_diffusion-faceid_enabled: false
440 |
441 | ## Sets the source image with the face to use for face swapping.
442 | ## It's possible to set it in 2 different ways:
443 | ## 1. Local file: "file:///./example.jpg"
444 | ## 2. URL: "https://some-site.com/example.png"
445 | stable_diffusion-faceid_source_face: "file:///{STABLE_DIFFUSION_EXTENSION_DIRECTORY}/assets/example_face.jpg"
446 |
447 | ## FaceID mode
448 | stable_diffusion-faceid_mode: ["FaceID", "FaceSwap"]
449 |
450 | ## Model to use for FaceID
451 | ## Available options:
452 | ## - FaceID Base
453 | ## - FaceID Plus
454 | ## - FaceID Plus v2
455 | ## - FaceID XL
456 | stable_diffusion-faceid_model: "FaceID Plus v2"
457 |
458 | ## Use recommended sampler for FaceID
459 | stable_diffusion-faceid_override_sampler: true
460 |
461 | ## Cache FaceID model for faster generation
462 | stable_diffusion-faceid_cache_model: false
463 |
464 | ## FaceID strength (0.0 - 2.0)
465 | stable_diffusion-faceid_strength: 1.0
466 |
467 | ## FaceID structure (0.0 - 1.0)
468 | stable_diffusion-faceid_structure: 1.0
469 |
470 | ## FaceID rank (4 - 256)
471 | stable_diffusion-faceid_rank: 128
472 |
473 | ## FaceID token count (1 - 16)
474 | stable_diffusion-faceid_tokens: 4
475 |
476 | #-------------#
477 | # IP ADAPTER #
478 | #-------------#
479 |
480 | ## Adjust the IP Adapter integration feature of SD.Next (a fork of AUTOMATIC1111).
481 | ## See: https://github.com/vladmandic/automatic for SD.Next repository.
482 | ## See: https://ip-adapter.github.io/ for IP Adapter paper.
483 | ##
484 | ## Can be used for face swapping as well, similar to the FaceID feature
485 | ## (by using the "Plus Face" or the "Full face" adapter).
486 | ##
487 | ## WARNING: DOES NOT WORK WITH VANILLA AUTOMATIC1111. YOU _MUST_ USE SD.NEXT INSTEAD.
488 | ## Requires "ip_adapter" and "onnxruntime-gpu" PIP packages to be installed in SD.Next.
489 |
490 | ## Sets if IP adapter should be enabled.
491 | stable_diffusion-ipadapter_enabled: false
492 |
493 | ## Sets the reference image to use for guiding image generation (or for face swapping with a face adapter).
494 | ## It's possible to set it in 2 different ways:
495 | ## 1. Local file: "file:///./example.jpg"
496 | ## 2. URL: "https://some-site.com/example.png"
497 | stable_diffusion-ipadapter_reference_image: "file:///{STABLE_DIFFUSION_EXTENSION_DIRECTORY}/assets/example_face.jpg"
498 |
499 | ## The adapter to use.
500 | ## Possible values:
501 | ## - "Base"
502 | ## - "Light"
503 | ## - "Plus"
504 | ## - "Plus Face"
505 | ## - "Full face"
506 | ## - "Base SDXL"
507 | stable_diffusion-ipadapter_adapter: "Base"
508 |
509 | ## Scale for the source face during image generation (0.0 - 1.0)
510 | stable_diffusion-ipadapter_scale: 0.5
511 |
--------------------------------------------------------------------------------
/ui.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 | import gradio as gr
3 | from stringcase import sentencecase
4 | from modules.logging_colors import logger
5 | from modules.ui import refresh_symbol
6 | from .context import GenerationContext
7 | from .ext_modules.vram_manager import VramReallocationTarget, attempt_vram_reallocation
8 | from .params import (
9 | ContinuousModePromptGenerationMode,
10 | InteractiveModePromptGenerationMode,
11 | IPAdapterAdapter,
12 | )
13 | from .params import StableDiffusionWebUiExtensionParams as Params
14 | from .params import TriggerMode
15 | from .sd_client import SdWebUIApi
16 |
17 | STATUS_SUCCESS = "#00FF00"
18 | STATUS_PROGRESS = "#FFFF00"
19 | STATUS_FAILURE = "#FF0000"
20 |
21 | refresh_listeners: List[Any] = []
22 | connect_listeners: List[Any] = []
23 |
24 | status: gr.Label | None = None
25 | status_text: str = ""
26 |
27 | refresh_button: gr.Button | None = None
28 |
29 | sd_client: SdWebUIApi | None = None
30 | sd_samplers: List[str] = []
31 | sd_upscalers: List[str] = []
32 | sd_checkpoints: List[str] = []
33 | sd_current_checkpoint: str = ""
34 | sd_vaes: List[str] = []
35 | sd_current_vae: str = ""
36 |
37 | sd_connected: bool = True
38 | sd_options: Any = None
39 |
40 |
41 | def render_ui(params: Params) -> None:
42 | _render_status()
43 | _refresh_sd_data(params)
44 |
45 | _render_connection_details(params)
46 | _render_prompts(params)
47 | _render_models(params)
48 | _render_generation_parameters(params)
49 |
50 | with gr.Row():
51 | _render_chat_config(params)
52 |
53 | with gr.Row():
54 | _render_faceswaplab_config(params)
55 | _render_reactor_config(params)
56 |
57 | with gr.Row():
58 | _render_faceid_config(params)
59 | _render_ipadapter_config(params)
60 |
61 |
62 | def _render_connection_details(params: Params) -> None:
63 | global refresh_button
64 |
65 | with gr.Accordion("Connection details", open=True):
66 | with gr.Row():
67 | with gr.Column():
68 | api_username = gr.Textbox(
69 | label="Username",
70 | placeholder="Leave empty if no authentication is required",
71 | value=lambda: params.api_username or "",
72 | )
73 | api_username.change(
74 | lambda new_username: params.update({"api_username": new_username}),
75 | api_username,
76 | None,
77 | )
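                # each control is wired the same way: a lambda forwards the new
                # value to Params.update(), which asserts the field name exists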
78 |
79 | api_password = gr.Textbox(
80 | label="Password",
81 | placeholder="Leave empty if no authentication is required",
82 | value=lambda: params.api_password or "",
83 | type="password",
84 | )
85 | api_password.change(
86 | lambda new_api_password: params.update(
87 | {"api_password": new_api_password}
88 | ),
89 | api_password,
90 | None,
91 | )
92 |
93 | with gr.Column():
94 | api_endpoint = gr.Textbox(
95 | label="API Endpoint",
96 | placeholder=params.api_endpoint,
97 | value=lambda: params.api_endpoint,
98 | )
99 | api_endpoint.change(
100 | lambda new_api_endpoint: params.update(
101 | {"api_endpoint": new_api_endpoint}
102 | ),
103 | api_endpoint,
104 | None,
105 | )
106 |
107 | refresh_button = gr.Button(
108 | refresh_symbol + " Connect / refresh data",
109 | interactive=True,
110 | )
111 | refresh_button.click(
112 | lambda: _refresh_sd_data(params, force_refetch=True),
113 | inputs=[],
114 | outputs=refresh_listeners,
115 | )
116 |
117 |
118 | def _render_prompts(params: Params) -> None:
119 | with gr.Accordion("Prompt Settings", open=True, visible=sd_connected) as prompts:
120 | connect_listeners.append(prompts)
121 |
122 | with gr.Row():
123 | prompt = gr.Textbox(
124 | label="Base prompt used for image generation",
125 | placeholder=params.base_prompt,
126 | value=lambda: params.base_prompt,
127 | )
128 | prompt.change(
129 | lambda new_prompt: params.update({"base_prompt": new_prompt}),
130 | prompt,
131 | None,
132 | )
133 |
134 | negative_prompt = gr.Textbox(
135 | label="Base negative prompt used for image generation",
136 | placeholder=params.base_negative_prompt,
137 | value=lambda: params.base_negative_prompt,
138 | )
139 | negative_prompt.change(
140 | lambda new_prompt: params.update({"base_negative_prompt": new_prompt}),
141 | negative_prompt,
142 | None,
143 | )
144 |
145 |
146 | def _render_models(params: Params) -> None:
147 | with gr.Accordion("Models", open=True, visible=sd_connected) as models:
148 | connect_listeners.append(models)
149 |
150 | with gr.Row():
151 | global sd_current_checkpoint, sd_current_vae
152 |
153 | checkpoint = gr.Dropdown(
154 | label="Checkpoint",
155 | choices=sd_checkpoints, # type: ignore
156 | value=lambda: sd_current_checkpoint, # checkpoint is not defined in params # noqa: E501
157 | )
158 | checkpoint.change(
159 | lambda new_checkpoint: _load_checkpoint(new_checkpoint, params),
160 | checkpoint,
161 | None,
162 | )
163 | refresh_listeners.append(checkpoint)
164 |
165 | vae = gr.Dropdown(
166 | label="VAE",
167 |                 choices=sd_vaes + ["None"],  # type: ignore
168 | value=lambda: sd_current_vae, # vae is not defined in params
169 | )
170 | vae.change(
171 | lambda new_vae: _load_vae(new_vae, params),
172 | vae,
173 | None,
174 | )
175 | refresh_listeners.append(vae)
176 |
177 |
178 | def _render_generation_parameters(params: Params) -> None:
179 | with gr.Accordion(
180 | "Generation Parameters", open=True, visible=sd_connected
181 | ) as generation_params:
182 | connect_listeners.append(generation_params)
183 |
184 | with gr.Row():
185 | with gr.Row():
186 | width = gr.Number(
187 | label="Width",
188 | maximum=2048,
189 | value=lambda: params.width,
190 | )
191 | width.change(
192 | lambda new_width: params.update({"width": new_width}),
193 | width,
194 | None,
195 | )
196 |
197 | height = gr.Number(
198 | label="Height",
199 | maximum=2048,
200 | value=lambda: params.height,
201 | )
202 | height.change(
203 | lambda new_height: params.update({"height": new_height}),
204 | height,
205 | None,
206 | )
207 |
208 | with gr.Column():
209 | with gr.Row(elem_id="sampler_row"):
210 | sampler_name = gr.Dropdown(
211 | label="Sampling method",
212 | choices=sd_samplers, # type: ignore
213 | value=lambda: params.sampler_name,
214 | elem_id="sampler_box",
215 | )
216 | sampler_name.change(
217 | lambda new_sampler_name: params.update(
218 | {"sampler_name": new_sampler_name}
219 | ),
220 | sampler_name,
221 | None,
222 | )
223 | refresh_listeners.append(sampler_name)
224 |
225 | steps = gr.Slider(
226 | label="Sampling steps",
227 | minimum=1,
228 | maximum=150,
229 | value=lambda: params.sampling_steps,
230 | step=1,
231 | elem_id="steps_box",
232 | )
233 | steps.change(
234 | lambda new_steps: params.update({"sampling_steps": new_steps}),
235 | steps,
236 | None,
237 | )
238 |
239 | clip_skip = gr.Slider(
240 | label="CLIP skip",
241 | minimum=1,
242 | maximum=4,
243 | value=lambda: params.clip_skip,
244 | step=1,
245 | elem_id="clip_skip_box",
246 | )
247 | clip_skip.change(
248 | lambda new_clip_skip: params.update(
249 | {"clip_skip": new_clip_skip}
250 | ),
251 | clip_skip,
252 | None,
253 | )
254 |
255 | with gr.Row():
256 | seed = gr.Number(
257 | label="Seed (use -1 for random)",
258 | value=lambda: params.seed,
259 | elem_id="seed_box",
260 | )
261 | seed.change(lambda new_seed: params.update({"seed": new_seed}), seed, None)
262 |
263 | cfg_scale = gr.Slider(
264 | label="CFG Scale",
265 | value=lambda: params.cfg_scale,
266 | minimum=1,
267 | maximum=30,
268 | elem_id="cfg_box",
269 | step=0.5,
270 | )
271 | cfg_scale.change(
272 | lambda new_cfg_scale: params.update({"cfg_scale": new_cfg_scale}),
273 | cfg_scale,
274 | None,
275 | )
276 |
277 |         with gr.Column():
278 | restore_faces = gr.Checkbox(
279 | label="Restore faces", value=lambda: params.restore_faces_enabled
280 | )
281 | restore_faces.change(
282 | lambda new_value: params.update(
283 | {"restore_faces_enabled": new_value}
284 | ),
285 | restore_faces,
286 | None,
287 | )
288 |
289 | enable_hr = gr.Checkbox(
290 | label="Upscale image", value=lambda: params.upscaling_enabled
291 | )
292 | enable_hr.change(
293 | lambda new_value: params.update({"upscaling_enabled": new_value}),
294 | enable_hr,
295 | None,
296 | )
297 |
298 | with gr.Row(
299 | visible=params.upscaling_enabled, elem_classes="hires_opts"
300 | ) as hr_options:
301 | connect_listeners.append(hr_options)
302 |
303 | enable_hr.change(
304 | lambda enabled: hr_options.update(visible=enabled),
305 | enable_hr,
306 | hr_options,
307 | )
308 |
309 | hr_upscaler = gr.Dropdown(
310 | label="Upscaler",
311 | choices=sd_upscalers, # type: ignore
312 | value=lambda: params.upscaling_upscaler,
313 | allow_custom_value=True,
314 | )
315 | hr_upscaler.change(
316 | lambda new_upscaler: params.update(
317 | {"upscaling_upscaler": new_upscaler}
318 | ),
319 | hr_upscaler,
320 | None,
321 | )
322 | refresh_listeners.append(hr_upscaler)
323 |
324 | hr_scale = gr.Slider(
325 | label="Upscale amount",
326 | minimum=1,
327 | maximum=4,
328 | value=lambda: params.upscaling_scale,
329 | step=0.01,
330 | )
331 | hr_scale.change(
332 | lambda new_value: params.update({"upscaling_scale": new_value}),
333 | hr_scale,
334 | None,
335 | )
336 |
337 | hires_fix_denoising_strength = gr.Slider(
338 | label="Denoising strength",
339 | minimum=0,
340 | maximum=1,
341 | value=lambda: params.hires_fix_denoising_strength,
342 | step=0.01,
343 | )
344 | hires_fix_denoising_strength.change(
345 | lambda new_value: params.update(
346 | {"hires_fix_denoising_strength": new_value}
347 | ),
348 | hires_fix_denoising_strength,
349 | None,
350 | )
351 |
352 |
353 | def _render_faceswaplab_config(params: Params) -> None:
354 | with gr.Accordion(
355 | "FaceSwapLab", open=True, visible=sd_connected
356 | ) as faceswap_config:
357 | connect_listeners.append(faceswap_config)
358 |
359 | with gr.Column():
360 | faceswap_enabled = gr.Checkbox(
361 | label="Enabled", value=lambda: params.faceswaplab_enabled
362 | )
363 |
364 | faceswap_enabled.change(
365 | lambda new_enabled: params.update({"faceswaplab_enabled": new_enabled}),
366 | faceswap_enabled,
367 | None,
368 | )
369 |
370 | faceswap_source_face = gr.Text(
371 | label="Source face",
372 | placeholder="See documentation for details...",
373 | value=lambda: params.faceswaplab_source_face,
374 | )
375 |
376 | faceswap_source_face.change(
377 | lambda new_source_face: params.update(
378 | {"faceswaplab_source_face": new_source_face}
379 | ),
380 | faceswap_source_face,
381 | None,
382 | )
383 |
384 |
385 | def _render_reactor_config(params: Params) -> None:
386 | with gr.Accordion("ReActor", open=True, visible=sd_connected) as reactor_config:
387 | connect_listeners.append(reactor_config)
388 |
389 | with gr.Column():
390 | reactor_enabled = gr.Checkbox(
391 | label="Enabled", value=lambda: params.reactor_enabled
392 | )
393 |
394 | reactor_enabled.change(
395 | lambda new_enabled: params.update({"reactor_enabled": new_enabled}),
396 | reactor_enabled,
397 | None,
398 | )
399 |
400 | reactor_source_face = gr.Text(
401 | label="Source face",
402 | placeholder="See documentation for details...",
403 | value=lambda: params.reactor_source_face,
404 | )
405 |
406 | reactor_source_face.change(
407 | lambda new_source_face: params.update(
408 | {"reactor_source_face": new_source_face}
409 | ),
410 | reactor_source_face,
411 | None,
412 | )
413 |
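    | # _render_faceswaplab_config and _render_reactor_config above differ only in
    | # their labels and param keys. A sketch of a shared renderer (hypothetical;
    | # the original keeps the blocks separate, which also reads fine):
    | def _render_face_source_config(
    |     params: Params, title: str, enabled_key: str, source_key: str
    | ) -> None:
    |     with gr.Accordion(title, open=True, visible=sd_connected) as config:
    |         connect_listeners.append(config)
    |
    |         with gr.Column():
    |             enabled = gr.Checkbox(
    |                 label="Enabled", value=lambda: getattr(params, enabled_key)
    |             )
    |             enabled.change(
    |                 lambda new_value: params.update({enabled_key: new_value}),
    |                 enabled,
    |                 None,
    |             )
    |
    |             source = gr.Text(
    |                 label="Source face",
    |                 placeholder="See documentation for details...",
    |                 value=lambda: getattr(params, source_key),
    |             )
    |             source.change(
    |                 lambda new_value: params.update({source_key: new_value}),
    |                 source,
    |                 None,
    |             )
    |
    | # e.g. _render_face_source_config(
    | #     params, "ReActor", "reactor_enabled", "reactor_source_face"
    | # )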
414 |
415 | def _render_faceid_config(params: Params) -> None:
416 | with gr.Accordion("FaceID (SD.Next only)", open=True, visible=sd_connected) as faceid_config: # noqa: E501
417 | connect_listeners.append(faceid_config)
418 |
419 | with gr.Column():
420 | faceid_enabled = gr.Checkbox(
421 | label="Enabled", value=lambda: params.faceid_enabled
422 | )
423 |
424 | faceid_enabled.change(
425 | lambda new_enabled: params.update({"faceid_enabled": new_enabled}),
426 | faceid_enabled,
427 | None,
428 | )
429 |
430 | faceid_source_face = gr.Text(
431 | label="Source face",
432 | placeholder="See documentation for details...",
433 | value=lambda: params.faceid_source_face,
434 | )
435 |
436 | faceid_source_face.change(
437 | lambda new_source_face: params.update(
438 | {"faceid_source_face": new_source_face}
439 | ),
440 | faceid_source_face,
441 | None,
442 | )
443 |
444 | faceid_mode = gr.Dropdown(
445 | label="Mode",
446 | choices=["FaceID", "FaceSwap"],
447 | value=lambda: params.faceid_mode,
448 | )
449 |
450 | faceid_mode.change(
451 | lambda new_mode: params.update({"faceid_mode": new_mode}),
452 | faceid_mode,
453 | None,
454 | )
455 |
456 | faceid_model = gr.Dropdown(
457 | label="Model",
458 | choices=["FaceID Base", "FaceID Plus", "FaceID Plus v2", "FaceID XL"],
459 | value=lambda: params.faceid_model,
460 | )
461 |
462 | faceid_model.change(
463 | lambda new_model: params.update({"faceid_model": new_model}),
464 | faceid_model,
465 | None,
466 | )
467 |
468 | faceid_strength = gr.Slider(
469 | label="Strength",
470 | value=lambda: params.faceid_strength,
471 | minimum=0,
472 | maximum=2,
473 | step=0.01,
474 | )
475 |
476 | faceid_strength.change(
477 | lambda new_strength: params.update({"faceid_strength": new_strength}),
478 | faceid_strength,
479 | None,
480 | )
481 |
482 | faceid_structure = gr.Slider(
483 | label="Structure",
484 | value=lambda: params.faceid_structure,
485 | minimum=0,
486 | maximum=1,
487 | step=0.01,
488 | )
489 |
490 | faceid_structure.change(
491 | lambda new_structure: params.update(
492 | {"faceid_structure": new_structure}
493 | ),
494 | faceid_structure,
495 | None,
496 | )
497 |
498 |
499 | def _render_ipadapter_config(params: Params) -> None:
500 | with gr.Accordion(
501 | "IP Adapter (SD.Next only)", open=True, visible=sd_connected
502 | ) as ipadapter_config:
503 | connect_listeners.append(ipadapter_config)
504 |
505 | with gr.Column():
506 | ipadapter_enabled = gr.Checkbox(
507 | label="Enabled", value=lambda: params.ipadapter_enabled
508 | )
509 |
510 | ipadapter_enabled.change(
511 | lambda new_enabled: params.update({"ipadapter_enabled": new_enabled}),
512 | ipadapter_enabled,
513 | None,
514 | )
515 |
516 | ipadapter_adapter = gr.Dropdown(
517 | label="Adapter",
518 |                 choices=list(IPAdapterAdapter),
519 | value=lambda: params.ipadapter_adapter,
520 | type="index",
521 | )
522 |
523 | ipadapter_adapter.change(
524 | lambda index: params.update(
525 | {"ipadapter_adapter": IPAdapterAdapter.from_index(index)}
526 | ),
527 | ipadapter_adapter,
528 | None,
529 | )
530 |
531 | ipadapter_reference_image = gr.Text(
532 | label="Reference image",
533 | placeholder="See documentation for details...",
534 | value=lambda: params.ipadapter_reference_image,
535 | )
536 |
537 | ipadapter_reference_image.change(
538 | lambda new_reference_image: params.update(
539 | {"ipadapter_reference_image": new_reference_image}
540 | ),
541 | ipadapter_reference_image,
542 | None,
543 | )
544 |
545 | ipadapter_scale = gr.Slider(
546 | label="Scale",
547 | minimum=0,
548 | maximum=1,
549 | value=lambda: params.ipadapter_scale,
550 | step=0.1,
551 | )
552 |
553 | ipadapter_scale.change(
554 | lambda new_scale: params.update({"ipadapter_scale": new_scale}),
555 | ipadapter_scale,
556 | None,
557 | )
558 |
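    | # With type="index", Gradio passes the handler the selected option's
    | # position rather than its label, so the enum member is recovered
    | # positionally and the display string never has to be parsed back. A
    | # minimal sketch of the from_index helper this relies on (an assumption --
    | # the real implementation lives with the enum definitions):
    | #
    | #     @classmethod
    | #     def from_index(cls, index: int) -> "IPAdapterAdapter":
    | #         return list(cls)[index]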
559 |
560 | def _render_chat_config(params: Params) -> None:
561 | with gr.Accordion("Chat Settings", open=True, visible=True) as chat_config:
562 | connect_listeners.append(chat_config)
563 |
564 | with gr.Column():
565 | trigger_mode = gr.Dropdown(
566 | label="Image generation trigger mode",
567 | choices=[sentencecase(mode) for mode in TriggerMode],
568 | value=lambda: sentencecase(params.trigger_mode),
569 | type="index",
570 | )
571 |
572 | trigger_mode.change(
573 | lambda index: params.update(
574 | {"trigger_mode": TriggerMode.from_index(index)}
575 | ),
576 | trigger_mode,
577 | None,
578 | )
579 |
580 | interactive_prompt_generation_mode = gr.Dropdown(
581 | label="Interactive mode prompt generation mode",
582 | choices=[
583 | sentencecase(mode) for mode in InteractiveModePromptGenerationMode
584 | ],
585 | value=lambda: sentencecase(
586 | params.interactive_mode_prompt_generation_mode
587 | ),
588 | type="index",
589 | )
590 |
591 | interactive_prompt_generation_mode.change(
592 | lambda index: params.update(
593 | {
594 | "interactive_mode_prompt_generation_mode": InteractiveModePromptGenerationMode.from_index( # noqa: E501
595 | index
596 | )
597 | }
598 | ),
599 | interactive_prompt_generation_mode,
600 | None,
601 | )
602 |
603 | continuous_prompt_generation_mode = gr.Dropdown(
604 |                 label="Continuous mode prompt generation mode",
605 | choices=[
606 | sentencecase(mode) for mode in ContinuousModePromptGenerationMode
607 | ],
608 | value=lambda: sentencecase(
609 | params.continuous_mode_prompt_generation_mode
610 | ),
611 | type="index",
612 | )
613 |
614 | continuous_prompt_generation_mode.change(
615 | lambda index: params.update(
616 | {
617 | "continuous_mode_prompt_generation_mode": ContinuousModePromptGenerationMode.from_index( # noqa: E501
618 | index
619 | )
620 | }
621 | ),
622 | continuous_prompt_generation_mode,
623 | None,
624 | )
625 |
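    | # The three dropdowns above use the same index-based round trip as the IP
    | # Adapter one: sentencecase() only prettifies the displayed labels (e.g. a
    | # hypothetical TriggerMode.INTERACTIVE member would render as
    | # "Interactive"), while type="index" hands the handler a stable position
    | # that maps straight back to the enum member via from_index.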
626 |
627 | def _render_status() -> None:
628 | global status
629 | status = gr.Label(lambda: status_text, label="Status", show_label=True)
630 | _set_status("Ready.", STATUS_SUCCESS)
631 |
632 |
633 | def _refresh_sd_data(params: Params, force_refetch: bool = False) -> None:
634 | global sd_client, sd_connected, refresh_button
635 |
636 | sd_client = SdWebUIApi(
637 | baseurl=params.api_endpoint,
638 | username=params.api_username,
639 | password=params.api_password,
640 | )
641 |
642 |     _set_status("Connecting to Stable Diffusion WebUI...", STATUS_PROGRESS)
643 |     sd_connected = True  # optimistic; the _fetch_* helpers clear it on failure
644 |
645 | if sd_connected and (force_refetch or sd_options is None):
646 | _fetch_sd_options(sd_client)
647 |
648 | if sd_connected and (force_refetch or len(sd_samplers) == 0):
649 | _fetch_samplers(sd_client)
650 |
651 | if sd_connected and (force_refetch or len(sd_upscalers) == 0):
652 | _fetch_upscalers(sd_client)
653 |
654 | if sd_connected and (force_refetch or len(sd_checkpoints) == 0):
655 | _fetch_checkpoints(sd_client)
656 |
657 | if sd_connected and (force_refetch or len(sd_vaes) == 0):
658 | _fetch_vaes(sd_client)
659 |
660 | for listener in connect_listeners:
661 | listener.set_visibility(sd_connected)
662 |
663 | if not sd_connected:
664 | _set_status("Stable Diffusion WebUI connection failed", STATUS_FAILURE)
665 | return
666 |
667 | _set_status("✓ Connected to Stable Diffusion WebUI", STATUS_SUCCESS)
668 |
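    | # A sketch of how this refresh is expected to be driven (hypothetical
    | # wiring; the actual refresh_button hookup lives elsewhere in this file):
    | #
    | #     refresh_button.click(
    | #         lambda: _refresh_sd_data(params, force_refetch=True), None, None
    | #     )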
669 |
670 | def _fetch_sd_options(sd_client: SdWebUIApi) -> None:
671 | _set_status("Fetching Stable Diffusion WebUI options...", STATUS_PROGRESS)
672 |
673 | global sd_options, sd_connected
674 |
675 | try:
676 | sd_options = sd_client.get_options()
677 | except BaseException as error:
678 | logger.error(error, exc_info=True)
679 | sd_connected = False
680 |
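    | # The _fetch_* helpers in this section share one contract: report progress,
    | # attempt a single API call, and on failure log the error and clear
    | # sd_connected so _refresh_sd_data can bail out. A generic version could
    | # look like this (hypothetical; the explicit helpers are kept so each
    | # resource gets its own status message):
    | from typing import Callable  # would normally sit at the top of the file
    |
    | def _try_fetch(description: str, fetch: Callable[[], None]) -> None:
    |     global sd_connected
    |     _set_status(f"Fetching {description}...", STATUS_PROGRESS)
    |     try:
    |         fetch()
    |     except BaseException as error:
    |         logger.error(error, exc_info=True)
    |         sd_connected = False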
681 |
682 | def _fetch_samplers(sd_client: SdWebUIApi) -> None:
683 | _set_status("Fetching Stable Diffusion samplers...", STATUS_PROGRESS)
684 |
685 | global sd_samplers, sd_connected
686 |
687 | try:
688 | sd_samplers = [
689 | sampler if isinstance(sampler, str) else sampler["name"]
690 | for sampler in sd_client.get_samplers()
691 | ]
692 | except BaseException as error:
693 | logger.error(error, exc_info=True)
694 | sd_connected = False
695 |
696 |
697 | def _fetch_upscalers(sd_client: SdWebUIApi) -> None:
698 | _set_status("Fetching Stable Diffusion upscalers...", STATUS_PROGRESS)
699 |
700 | global sd_upscalers, sd_connected
701 |
702 | try:
703 | sd_upscalers = [
704 | upscaler if isinstance(upscaler, str) else upscaler["name"]
705 | for upscaler in sd_client.get_upscalers()
706 | ]
707 | except BaseException as error:
708 | logger.error(error, exc_info=True)
709 | sd_connected = False
710 |
711 |
712 | def _fetch_checkpoints(sd_client: SdWebUIApi) -> None:
713 | _set_status("Fetching Stable Diffusion checkpoints...", STATUS_PROGRESS)
714 |
715 | global sd_checkpoints, sd_current_checkpoint, sd_connected
716 |
717 | try:
718 | sd_client.refresh_checkpoints()
719 |
720 | sd_current_checkpoint = sd_options["sd_model_checkpoint"]
721 | sd_checkpoints = [
722 | checkpoint["title"] for checkpoint in sd_client.get_sd_models()
723 | ]
724 | except BaseException as error:
725 | logger.error(error, exc_info=True)
726 | sd_connected = False
727 |
728 |
729 | def _fetch_vaes(sd_client: SdWebUIApi) -> None:
730 | _set_status("Fetching Stable Diffusion VAEs...", STATUS_PROGRESS)
731 |
732 | global sd_vaes, sd_current_vae, sd_connected
733 |
734 | try:
735 | sd_client.refresh_vae()
736 | sd_current_vae = sd_options["sd_vae"]
737 |         sd_vaes = [vae["model_name"] for vae in sd_client.get_sd_vae()]
738 | except BaseException as error:
739 | logger.error(error, exc_info=True)
740 | sd_connected = False
741 |
742 |
743 | def _load_checkpoint(checkpoint: str, params: Params) -> None:
744 | global sd_client, sd_current_checkpoint
745 | sd_current_checkpoint = checkpoint
746 |
747 | assert sd_client is not None
748 | sd_client.set_options({"sd_model_checkpoint": checkpoint})
749 |
750 | # apply changes if dynamic VRAM allocation is not enabled
751 | # todo: check if model is loaded in VRAM via SD API instead of relying on vram reallocation check # noqa: E501
752 | if not params.dynamic_vram_reallocation_enabled:
753 | _set_status(
754 | f"Loading Stable Diffusion checkpoint: {checkpoint}...", STATUS_PROGRESS
755 | )
756 | sd_client.reload_checkpoint()
757 |
758 |         _set_status("Reloading LLM model...", STATUS_PROGRESS)
759 |
760 | attempt_vram_reallocation(
761 | VramReallocationTarget.LLM,
762 | GenerationContext(params=params, sd_client=sd_client),
763 | )
764 |
765 | _set_status(f"Stable Diffusion checkpoint ready: {checkpoint}.", STATUS_SUCCESS)
766 |
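    | # Note on the sequence above: set_options() only records the new checkpoint
    | # selection, while reload_checkpoint() forces the load immediately. The
    | # immediate reload is skipped when dynamic VRAM reallocation is enabled,
    | # presumably because that path loads the model on demand (see the todo
    | # above); attempt_vram_reallocation(VramReallocationTarget.LLM, ...) then
    | # returns the freed VRAM to the language model.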
767 |
768 | def _load_vae(vae: str, params: Params) -> None:
769 | global sd_client, sd_current_vae
770 | sd_current_vae = vae
771 |
772 | assert sd_client is not None
773 | sd_client.set_options({"sd_vae": vae})
774 |
775 | # apply changes if dynamic VRAM allocation is not enabled
776 | # todo: check if model is loaded in VRAM via SD API instead of relying on vram reallocation check # noqa: E501
777 | if not params.dynamic_vram_reallocation_enabled:
778 | _set_status(f"Loading Stable Diffusion VAE: {vae}...", STATUS_PROGRESS)
779 | sd_client.reload_checkpoint()
780 |
781 | attempt_vram_reallocation(
782 | VramReallocationTarget.LLM,
783 | GenerationContext(params=params, sd_client=sd_client),
784 | )
785 |
786 | _set_status(f"Stable Diffusion VAE ready: {vae}.", STATUS_SUCCESS)
787 |
788 |
789 | def _set_status(text: str, status_color: str) -> None:
790 | global status, status_text
791 | assert status is not None
792 |
793 | status_text = text
794 | logger.info("[SD WebUI Integration] " + status_text)
795 |
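    | # Typical calls, as used throughout this file:
    | #
    | #     _set_status("Connecting to Stable Diffusion WebUI...", STATUS_PROGRESS)
    | #     _set_status("✓ Connected to Stable Diffusion WebUI", STATUS_SUCCESS)
    | #     _set_status("Stable Diffusion WebUI connection failed", STATUS_FAILURE)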
--------------------------------------------------------------------------------