├── .gitignore
├── README.md
├── built-in-presets.json
├── install.py
├── kohya-sd-scripts-webui-colab.ipynb
├── launch.py
├── main.py
├── screenshots
├── installation-extension.png
└── webui-01.png
├── script.js
├── scripts
├── main.py
├── ngrok.py
├── presets.py
├── runner.py
├── shared.py
├── tabs
│ ├── networks
│ │ ├── check_lora_weights.py
│ │ ├── extract_lora_from_models.py
│ │ ├── lora_interrogator.py
│ │ ├── merge_lora.py
│ │ ├── resize_lora.py
│ │ └── svd_merge_lora.py
│ ├── preparation
│ │ ├── clean_captions_and_tags.py
│ │ ├── make_captions.py
│ │ ├── make_captions_by_git.py
│ │ ├── merge_captions.py
│ │ ├── merge_tags.py
│ │ ├── prepare_latents.py
│ │ └── tag_images_by_wd14tagger.py
│ ├── tools
│ │ ├── convert_diffusers.py
│ │ ├── detect_face_rotate.py
│ │ └── resize_images_to_resolution.py
│ └── training
│ │ ├── fine_tune.py
│ │ ├── train_db.py
│ │ ├── train_network.py
│ │ └── train_textual_inversion.py
├── ui.py
├── ui_overrides.py
└── utilities.py
├── style.css
├── sub.py
├── update.bat
├── update.sh
├── webui.bat
└── webui.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | venv
3 | tmp
4 |
5 | kohya_ss
6 | wd14_tagger_model
7 | presets.json
8 | meta.json
9 | presets
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kohya sd-scripts webui
2 |
3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ddPn08/kohya-sd-scripts-webui/blob/main/kohya-sd-scripts-webui-colab.ipynb)
4 |
5 | Gradio wrapper for [sd-scripts](https://github.com/kohya-ss/sd-scripts) by kohya
6 |
7 | It can be used as an extension of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) or can be launched standalone.
8 |
9 | ![kohya sd-scripts webui](./screenshots/webui-01.png)
10 |
11 | # Usage
12 | ## As an extension of stable-diffusion-webui
13 |
14 | Go to `Extensions` > `Install from URL`, enter the following URL and press the install button.
15 |
16 | https://github.com/ddpn08/kohya-sd-scripts-webui.git
17 |
18 | ![Install from URL](./screenshots/installation-extension.png)
19 |
20 | ## Start standalone
21 |
22 | Run `webui.bat` on Windows, or `webui.sh` on Linux and macOS.
23 |
--------------------------------------------------------------------------------
/built-in-presets.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_network": {
3 | "lora-x512": {
4 | "v2": null,
5 | "v_parameterization": null,
6 | "pretrained_model_name_or_path": null,
7 | "train_data_dir": null,
8 | "shuffle_caption": true,
9 | "caption_extension": ".caption",
10 | "caption_extention": null,
11 | "keep_tokens": null,
12 | "color_aug": null,
13 | "flip_aug": true,
14 | "face_crop_aug_range": null,
15 | "random_crop": null,
16 | "debug_dataset": null,
17 | "resolution": "512",
18 | "cache_latents": null,
19 | "enable_bucket": true,
20 | "min_bucket_reso": 256,
21 | "max_bucket_reso": 1024,
22 | "reg_data_dir": null,
23 | "in_json": null,
24 | "dataset_repeats": 1,
25 | "output_dir": null,
26 | "output_name": null,
27 | "save_precision": null,
28 | "save_every_n_epochs": 5,
29 | "save_n_epoch_ratio": null,
30 | "save_last_n_epochs": null,
31 | "save_last_n_epochs_state": null,
32 | "save_state": null,
33 | "resume": null,
34 | "train_batch_size": 1,
35 | "max_token_length": null,
36 | "use_8bit_adam": true,
37 | "mem_eff_attn": null,
38 | "xformers": true,
39 | "vae": null,
40 | "learning_rate": 0.0001,
41 | "max_train_steps": 1600,
42 | "max_train_epochs": null,
43 | "max_data_loader_n_workers": 8,
44 | "seed": null,
45 | "gradient_checkpointing": true,
46 | "gradient_accumulation_steps": 1,
47 | "mixed_precision": "no",
48 | "full_fp16": null,
49 | "clip_skip": 2,
50 | "logging_dir": null,
51 | "log_prefix": null,
52 | "lr_scheduler": "constant",
53 | "lr_warmup_steps": 0,
54 | "prior_loss_weight": 1.0,
55 | "no_metadata": null,
56 | "save_model_as": "safetensors",
57 | "unet_lr": null,
58 | "text_encoder_lr": null,
59 | "network_weights": null,
60 | "network_module": "networks.lora",
61 | "network_dim": 16,
62 | "network_alpha": 1.0,
63 | "network_args": null,
64 | "network_train_unet_only": null,
65 | "network_train_text_encoder_only": null,
66 | "training_comment": null
67 | }
68 | },
69 | "train_db": {
70 | "db-x512": {
71 | "v2": null,
72 | "v_parameterization": null,
73 | "pretrained_model_name_or_path": null,
74 | "train_data_dir": null,
75 | "shuffle_caption": true,
76 | "caption_extension": ".caption",
77 | "caption_extention": null,
78 | "keep_tokens": null,
79 | "color_aug": null,
80 | "flip_aug": true,
81 | "face_crop_aug_range": null,
82 | "random_crop": null,
83 | "debug_dataset": null,
84 | "resolution": null,
85 | "cache_latents": null,
86 | "enable_bucket": true,
87 | "min_bucket_reso": 256,
88 | "max_bucket_reso": 1024,
89 | "reg_data_dir": null,
90 | "output_dir": null,
91 | "output_name": null,
92 | "save_precision": null,
93 | "save_every_n_epochs": 5,
94 | "save_n_epoch_ratio": null,
95 | "save_last_n_epochs": null,
96 | "save_last_n_epochs_state": null,
97 | "save_state": null,
98 | "resume": null,
99 | "train_batch_size": 1,
100 | "max_token_length": null,
101 | "use_8bit_adam": true,
102 | "mem_eff_attn": null,
103 | "xformers": true,
104 | "vae": null,
105 | "learning_rate": 1e-06,
106 | "max_train_steps": 1600,
107 | "max_train_epochs": null,
108 | "max_data_loader_n_workers": 8,
109 | "seed": null,
110 | "gradient_checkpointing": null,
111 | "gradient_accumulation_steps": 1,
112 | "mixed_precision": "no",
113 | "full_fp16": null,
114 | "clip_skip": 2,
115 | "logging_dir": null,
116 | "log_prefix": null,
117 | "lr_scheduler": "constant",
118 | "lr_warmup_steps": 0,
119 | "prior_loss_weight": 1.0,
120 | "save_model_as": "safetensors",
121 | "use_safetensors": null,
122 | "no_token_padding": null,
123 | "stop_text_encoder_training": null
124 | }
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import launch
3 | import platform
4 | import os
5 | import shutil
6 | import site
7 | import glob
8 | import re
9 |
# Directory that contains this install script.
dirname = os.path.dirname(__file__)
# Local checkout of the sd-scripts repository managed by prepare_environment().
repo_dir = os.path.join(dirname, "kohya_ss")
12 |
13 |
def _loose_requirements(text):
    """Return the requirement entries of a requirements file with version pins removed.

    Blank lines, comments (``#`` at line start or after whitespace), the local
    package line ``.`` and pip option lines (``-e``, ``--extra-index-url``, ...)
    are skipped. URL requirements are kept verbatim.
    """
    names = []
    for raw in text.splitlines():
        entry = re.sub(r"(^|\s)#.*$", "", raw).strip()
        if not entry or entry == "." or entry.startswith("-"):
            continue
        if "://" in entry:
            # Direct URL reference: splitting on '=' would mangle the URL.
            names.append(entry)
            continue
        # Drop version specifiers (==, >=, <=, ~=, !=, <, >) and environment markers.
        name = re.split(r"[<>=!~;]", entry, maxsplit=1)[0].strip()
        if name:
            names.append(name)
    return names


def prepare_environment():
    """Install torch, xformers, sd-scripts and its requirements.

    Environment variables: TORCH_COMMAND, SD_SCRIPTS_REPO, SD_SCRIPTS_BRANCH, REQS_FILE.
    Flags consumed (removed from ``sys.argv``): --skip-install,
    --disable-strict-version, --skip-torch-cuda-test, --skip-checkout-repo,
    --update, --reinstall-xformers, --reinstall-torch. ``--xformers`` and
    ``--ngrok`` are only inspected and stay in ``sys.argv``.

    Raises:
        RuntimeError: propagated from ``launch.run`` when a command fails.
    """
    torch_command = os.environ.get(
        "TORCH_COMMAND",
        "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
    )
    sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git")
    sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
    requirements_file = os.environ.get("REQS_FILE", "requirements.txt")

    # Always strip the installer flags, even when installation is skipped.
    sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
    sys.argv, disable_strict_version = launch.extract_arg(
        sys.argv, "--disable-strict-version"
    )
    sys.argv, skip_torch_cuda_test = launch.extract_arg(
        sys.argv, "--skip-torch-cuda-test"
    )
    sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
    sys.argv, update = launch.extract_arg(sys.argv, "--update")
    sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
    xformers = "--xformers" in sys.argv
    ngrok = "--ngrok" in sys.argv

    if skip_install:
        return

    if (
        reinstall_torch
        or not launch.is_installed("torch")
        or not launch.is_installed("torchvision")
    ) and not disable_strict_version:
        launch.run(
            f'"{launch.python}" -m {torch_command}',
            "Installing torch and torchvision",
            "Couldn't install torch",
        )

    if not skip_torch_cuda_test:
        launch.run_python(
            "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
        )

    if xformers and (reinstall_xformers or not launch.is_installed("xformers")):
        launch.run_pip("install xformers --pre", "xformers")

    if update and os.path.exists(repo_dir):
        launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune')
        # NOTE(review): this resets the currently checked-out branch to origin/main
        # even when SD_SCRIPTS_BRANCH selects another branch; the checkout below
        # then switches to that branch without updating it — confirm intended.
        launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main')
    elif not os.path.exists(repo_dir):
        launch.run(f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"')

    if not skip_checkout_repo:
        launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}')

    if not launch.is_installed("gradio"):
        launch.run_pip("install gradio==3.16.2", "gradio")

    if ngrok and not launch.is_installed("pyngrok"):
        launch.run_pip("install pyngrok", "ngrok")

    if platform.system() == "Linux" and not launch.is_installed("triton"):
        launch.run_pip("install triton", "triton")

    if disable_strict_version:
        with open(os.path.join(repo_dir, requirements_file), "r") as f:
            names = _loose_requirements(f.read())
        # Quote every requirement on its own: wrapping the joined list in a single
        # pair of quotes made pip see one bogus requirement string.
        quoted = " ".join(f'"{name}"' for name in names)
        launch.run_pip(
            f'install {quoted} "{repo_dir}"',
            "requirements for kohya sd-scripts",
        )
    else:
        launch.run(
            f'cd "{repo_dir}" && "{launch.python}" -m pip install -r requirements.txt',
            desc="Installing requirements for kohya sd-scripts",
            errdesc="Couldn't install requirements for kohya sd-scripts",
        )

    if platform.system() == "Windows":
        # Overlay the Windows bitsandbytes files shipped with sd-scripts onto every
        # site-packages directory that already has a bitsandbytes package.
        for file in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
            filename = os.path.basename(file)
            for site_dir in site.getsitepackages():
                outfile = (
                    os.path.join(site_dir, "bitsandbytes", "cuda_setup", filename)
                    if filename == "main.py"
                    else os.path.join(site_dir, "bitsandbytes", filename)
                )
                if not os.path.exists(os.path.dirname(outfile)):
                    continue
                shutil.copy(file, outfile)


if __name__ == "__main__":
    prepare_environment()
117 |
--------------------------------------------------------------------------------
/kohya-sd-scripts-webui-colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "zSM6HuYmkYCt"
17 | },
18 | "source": [
19 | "# [kohya sd-scripts webui](https://github.com/ddPn08/kohya-sd-scripts-webui)\n",
20 | "\n",
21 | "This notebook is for running [sd-scripts](https://github.com/kohya-ss/sd-scripts) by [Kohya](https://github.com/kohya-ss).\n",
22 | "\n",
23 | "このノートブックは[Kohya](https://github.com/kohya-ss)さんによる[sd-scripts](https://github.com/kohya-ss/sd-scripts)を実行するためのものです。\n",
24 | "\n",
25 | "# Repository\n",
26 | "[kohya_ss/sd-scripts](https://github.com/kohya-ss/sd-scripts)"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {
33 | "id": "zXcznGdeyb2I"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "! nvidia-smi\n",
38 | "! nvcc -V\n",
39 | "! free -h"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {
46 | "cellView": "form",
47 | "id": "tj65Tb_oyxtP"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "# @markdown # Mount Google Drive\n",
52 | "mount_gdrive = True # @param {type:\"boolean\"}\n",
53 | "gdrive_preset_path = \"/content/drive/MyDrive/AI/kohya-sd-scripts-webui/presets\" # @param {type:\"string\"}\n",
54 | "\n",
55 | "if mount_gdrive:\n",
56 | " from google.colab import drive\n",
57 | " drive.mount('/content/drive', force_remount=False)"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "cellView": "form",
65 | "id": "FN7UJvSdzBFF"
66 | },
67 | "outputs": [],
68 | "source": [
69 | "# @markdown # Initialize environment\n",
70 | "\n",
71 | "! git clone https://github.com/ddPn08/kohya-sd-scripts-webui.git\n",
72 | "\n",
73 | "import os\n",
74 | "\n",
75 | "if not os.path.exists(gdrive_preset_path):\n",
76 | " os.makedirs(gdrive_preset_path, exist_ok=True)\n",
77 | "\n",
78 | "! rm -f ./kohya-sd-scripts-webui/presets.json\n",
79 | "! ln -s {gdrive_preset_path} ./kohya-sd-scripts-webui/presets\n",
80 | "\n",
81 | "conda_dir = \"/opt/conda\" # @param{type:\"string\"}\n",
82 | "conda_bin = os.path.join(conda_dir, \"bin\", \"conda\")\n",
83 | "\n",
84 | "if not os.path.exists(conda_bin):\n",
85 | " ! curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
86 | " ! chmod +x Miniconda3-latest-Linux-x86_64.sh\n",
87 | " ! bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p {conda_dir}\n",
88 | " ! rm Miniconda3-latest-Linux-x86_64.sh\n",
89 | "\n",
90 | "def run_script(s):\n",
91 | " ! {s}\n",
92 | "\n",
93 | "def make_args(d):\n",
94 | " arguments = \"\"\n",
95 | " for k, v in d.items():\n",
96 | " if type(v) == bool:\n",
97 | " arguments += f\"--{k} \" if v else \"\"\n",
98 | " elif type(v) == str and v:\n",
99 | " arguments += f\"--{k} \\\"{v}\\\" \"\n",
100 | " elif v:\n",
101 | " arguments += f\"--{k}={v} \"\n",
102 | " return arguments"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {
109 | "cellView": "form",
110 | "id": "uetu1lShs6aJ"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "# @markdown # Run\n",
115 | "\n",
116 | "# @markdown
\n",
117 | "\n",
118 | "# @markdown ## Optional | Ngrok Tunnel\n",
119 | "# @markdown Get token from [here](https://dashboard.ngrok.com/get-started/your-authtoken)\n",
120 | "\n",
121 | "ngrok_token = \"\" # @param {type:\"string\"}\n",
122 | "ngrok_region = \"us\" # @param [\"us\", \"eu\", \"au\", \"ap\", \"sa\", \"jp\", \"in\"]\n",
123 | "\n",
124 | "arguments = {\n",
125 | " \"ngrok\": ngrok_token,\n",
126 | " \"ngrok-region\": ngrok_region,\n",
127 | " \"share\": ngrok_token is None,\n",
128 | " \"xformers\": True,\n",
129 | " \"enable-console-log\": True\n",
130 | "}\n",
131 | "\n",
132 | "run_script(f\"\"\"\n",
133 | "eval \"$({conda_bin} shell.bash hook)\"\n",
134 | "cd kohya-sd-scripts-webui\n",
135 | "python launch.py {make_args(arguments)}\n",
136 | "\"\"\")"
137 | ]
138 | }
139 | ],
140 | "metadata": {
141 | "accelerator": "GPU",
142 | "colab": {
143 | "include_colab_link": true,
144 | "provenance": []
145 | },
146 | "gpuClass": "standard",
147 | "kernelspec": {
148 | "display_name": "Python 3",
149 | "name": "python3"
150 | },
151 | "language_info": {
152 | "name": "python"
153 | }
154 | },
155 | "nbformat": 4,
156 | "nbformat_minor": 0
157 | }
158 |
--------------------------------------------------------------------------------
/launch.py:
--------------------------------------------------------------------------------
1 | import install
2 | import subprocess
3 | import os
4 | import sys
5 | import importlib.util
6 |
# Interpreter used for every pip / python subprocess started from here.
python = sys.executable
# Git binary; can be overridden with the GIT environment variable.
git = os.environ.get("GIT", "git")
# Optional pip index URL from INDEX_URL; an empty string means pip's default index.
index_url = os.environ.get("INDEX_URL", "")
# Switch read by run_pip(); nothing in this module changes it.
skip_install = False
11 |
12 |
def run(command, desc=None, errdesc=None, custom_env=None):
    """Run ``command`` through the shell and return its stdout as text.

    Args:
        command: shell command line.
        desc: optional message printed before running.
        errdesc: headline used in the error message on failure.
        custom_env: environment mapping; defaults to ``os.environ``.

    Raises:
        RuntimeError: if the command exits non-zero; the message includes the
            command, exit code and captured stdout/stderr.
    """
    if desc is not None:
        print(desc)

    env = custom_env if custom_env is not None else os.environ
    completed = subprocess.run(
        command, shell=True, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out_text = completed.stdout.decode(encoding="utf8", errors="ignore")
    if completed.returncode == 0:
        return out_text

    err_text = completed.stderr.decode(encoding="utf8", errors="ignore")
    headline = errdesc or "Error running command"
    raise RuntimeError(
        f"{headline}.\n"
        f"Command: {command}\n"
        f"Error code: {completed.returncode}\n"
        f"stdout: {out_text}\n"
        f"stderr: {err_text}\n"
    )
36 |
37 |
def check_run(command):
    """Return True when shell ``command`` exits with status 0 (its output is captured and ignored)."""
    completed = subprocess.run(
        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    return not completed.returncode
43 |
44 |
def is_installed(package):
    """Return True if ``package`` can be located by the import system.

    A dotted name whose parent package is missing counts as not installed.
    """
    try:
        return importlib.util.find_spec(package) is not None
    except ModuleNotFoundError:
        return False
52 |
53 |
def run_pip(args, desc=None):
    """Run ``python -m pip <args> --prefer-binary`` unless installation is skipped.

    ``INDEX_URL`` (if set) is appended as ``--index-url``. Returns pip's stdout,
    or None when ``skip_install`` is set.
    """
    if skip_install:
        return

    extra = "" if index_url == "" else f" --index-url {index_url}"
    command = f'"{python}" -m pip {args} --prefer-binary{extra}'
    return run(command, desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
64 |
65 |
def run_python(code, desc=None, errdesc=None):
    """Execute ``code`` with ``python -c``; ``code`` must not contain double quotes."""
    command = f'"{python}" -c "{code}"'
    return run(command, desc, errdesc)
68 |
69 |
def extract_arg(args, name):
    """Remove every occurrence of flag ``name`` from ``args``.

    Returns:
        ``(remaining_args, was_present)`` — a new list and a bool.
    """
    present = name in args
    remaining = list(filter(lambda arg: arg != name, args))
    return remaining, present
72 |
73 |
if __name__ == "__main__":
    # Dependencies must be in place before the UI package is imported.
    install.prepare_environment()

    import scripts.main as main

    main.launch()
80 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 | import subprocess
4 |
# Scratch harness: run ./sub.py unbuffered and echo its combined stdout/stderr
# one character at a time, printing a 'break' marker at every newline.
ps = subprocess.Popen(
    [sys.executable, "-u", "./sub.py"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)

reader = io.TextIOWrapper(ps.stdout, encoding='utf8')
# Read until EOF rather than while poll() is None: output buffered right before
# the child exits is no longer dropped, and read(1) == '' cannot busy-spin.
for char in iter(lambda: reader.read(1), ''):
    if char == '\n':
        print('break')
    sys.stdout.write(char)
ps.wait()
15 |
--------------------------------------------------------------------------------
/screenshots/installation-extension.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ddPn08/kohya-sd-scripts-webui/a36f7a5d65576ddc7f37925076d6a681b6e5db4d/screenshots/installation-extension.png
--------------------------------------------------------------------------------
/screenshots/webui-01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ddPn08/kohya-sd-scripts-webui/a36f7a5d65576ddc7f37925076d6a681b6e5db4d/screenshots/webui-01.png
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
/**
 * Root node to query for UI elements: the shadow root of the first
 * <gradio-app> element when it has one, otherwise `document`.
 */
function gradioApp() {
    const app = document.getElementsByTagName('gradio-app')[0]
    const root = app ? app.shadowRoot : null
    return root || document
}
6 |
/** Set once the UI has been wired up, so the DOM observer initializes it only once. */
let executed = false
10 |
/**
 * Show either the Run or the Stop button of a tab.
 * @param {string} tab tab id used inside the button element ids
 * @param {boolean} show true shows Run and hides Stop; false does the opposite
 */
function kohya_sd_webui__toggle_runner_button(tab, show) {
    const root = gradioApp()
    const [runDisplay, stopDisplay] = show ? ['block', 'none'] : ['none', 'block']
    root.getElementById(`kohya_sd_webui__${tab}_run_button`).style.display = runDisplay
    root.getElementById(`kohya_sd_webui__${tab}_stop_button`).style.display = stopDisplay
}
19 |
window.addEventListener('DOMContentLoaded', () => {
    /**
     * Run `task` every `intervalMs`, starting the next wait only after the
     * previous run has settled. Unlike setInterval, a slow request cannot
     * overlap the next one (which duplicated terminal lines), and a failed
     * request is logged instead of becoming an unhandled promise rejection.
     * @param {() => Promise<void>} task
     * @param {number} intervalMs
     */
    const poll = (task, intervalMs) => {
        const tick = async () => {
            try {
                await task()
            } catch (e) {
                console.error('[kohya-sd-scripts-webui]', e)
            }
            setTimeout(tick, intervalMs)
        }
        setTimeout(tick, intervalMs)
    }

    const observer = new MutationObserver((_mutations, obs) => {
        if (executed || !gradioApp().querySelector('#kohya_sd_webui__root')) return
        executed = true
        // Initialization happens once; stop reacting to every DOM mutation.
        obs.disconnect()

        /** @type {Record<string, string>} */
        const helps = kohya_sd_webui__help_map
        /** @type {string[]} */
        const all_tabs = kohya_sd_webui__all_tabs

        const initializeTerminalObserver = () => {
            const container = gradioApp().querySelector('#kohya_sd_webui__terminal_outputs')
            const parentContainer = container.parentElement
            const clearBtn = document.createElement('button')
            clearBtn.innerText = 'Clear The Terminal'
            clearBtn.style.color = 'yellow'
            parentContainer.insertBefore(clearBtn, container)

            // A pending clear is sent with the next request and only dropped once
            // the server confirmed it; `generation` changes on every click so a
            // response that was already in flight is discarded as stale.
            let clearTerminal = false
            let generation = 0
            clearBtn.addEventListener('click', () => {
                container.innerHTML = ''
                clearTerminal = true
                generation++
            })

            poll(async () => {
                const requestGeneration = generation
                const clear = clearTerminal
                const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/terminal/outputs', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        output_index: container.children.length,
                        clear_terminal: clear,
                    }),
                })
                if (!res.ok) throw new Error(`terminal/outputs: HTTP ${res.status}`)
                const obj = await res.json()
                if (requestGeneration !== generation) return
                if (clear) clearTerminal = false
                // One pixel of slack: scrollTop can be fractional on zoomed pages.
                const isBottom = container.scrollHeight - container.scrollTop - container.clientHeight <= 1
                for (const line of obj.outputs) {
                    const el = document.createElement('div')
                    el.innerText = line
                    container.appendChild(el)
                }
                if (isBottom) container.scrollTop = container.scrollHeight
            }, 1000)
        }

        const checkProcessIsAlive = () => {
            poll(async () => {
                const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/process/alive')
                if (!res.ok) throw new Error(`process/alive: HTTP ${res.status}`)
                const obj = await res.json()
                for (const tab of all_tabs)
                    kohya_sd_webui__toggle_runner_button(tab, !obj.alive)
            }, 1000)
        }

        initializeTerminalObserver()
        checkProcessIsAlive()

        for (const tab of all_tabs)
            gradioApp().querySelector(`#kohya_sd_webui__${tab}_run_button`).addEventListener('click', () => kohya_sd_webui__toggle_runner_button(tab, false))

        for (const [id, title] of Object.entries(helps)) {
            const el = gradioApp().getElementById(id)
            if (!el) continue
            el.title = title
        }
    })
    observer.observe(gradioApp(), { childList: true, subtree: true })
})
--------------------------------------------------------------------------------
/scripts/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 |
5 | import gradio.routes
6 |
7 | import scripts.runner as runner
8 | import scripts.shared as shared
9 | from scripts.shared import ROOT_DIR, is_webui_extension
10 | from scripts.ui import create_ui
11 |
12 |
def create_js():
    """Return the contents of script.js with its two data placeholders filled in.

    The help-title map and the list of loaded tabs are substituted as JSON,
    in that order.
    """
    with open(os.path.join(ROOT_DIR, "script.js"), mode="r") as src:
        source = src.read()

    substitutions = (
        ("kohya_sd_webui__help_map", shared.help_title_map),
        ("kohya_sd_webui__all_tabs", shared.loaded_tabs),
    )
    for placeholder, value in substitutions:
        source = source.replace(placeholder, json.dumps(value))
    return source
24 |
25 |
def create_head():
    """Wrap Gradio's template response so rendered pages get ``head`` spliced in.

    As a stable-diffusion-webui extension the wrapper is assigned to
    ``modules.shared.GradioTemplateResponseOriginal``; standalone it replaces
    ``gradio.routes.templates.TemplateResponse`` directly.
    """
    # NOTE(review): in this copy both the ``head`` f-string and the bytes pattern
    # passed to ``replace`` below are empty. ``create_js()`` is not called anywhere
    # else in the visible code, so presumably ``head`` should embed it in a script
    # tag and the pattern should be the closing head tag; as written, an empty
    # pattern would splice ``head`` between every byte. Confirm against upstream.
    head = f''

    def template_response_for_webui(*args, **kwargs):
        # Render with the saved original TemplateResponse, then edit the body.
        res = shared.gradio_template_response_original(*args, **kwargs)
        res.body = res.body.replace(b"", f"{head}".encode("utf8"))
        return res

    def template_response(*args, **kwargs):
        res = template_response_for_webui(*args, **kwargs)
        # The body was changed after construction; presumably this recomputes
        # headers such as Content-Length (Starlette Response.init_headers).
        res.init_headers()
        return res

    if is_webui_extension():
        import modules.shared

        # presumably stable-diffusion-webui consults this hook when building its
        # own template responses — TODO confirm.
        modules.shared.GradioTemplateResponseOriginal = template_response_for_webui
    else:
        gradio.routes.templates.TemplateResponse = template_response
45 |
46 |
def wait_on_server():
    """Block the calling thread forever (the Gradio server runs without a thread lock)."""
    while True:
        time.sleep(0.5)
50 |
51 |
def on_ui_tabs():
    """Build the UI with style.css applied, install the page hook, and describe the tab.

    Returns:
        A one-element list of ``(block, title, element_id)``.
    """
    with open(os.path.join(ROOT_DIR, "style.css"), mode="r") as stylesheet:
        css = stylesheet.read()
    ui_block = create_ui(css)
    create_head()
    return [(ui_block, "Kohya sd-scripts", "kohya_sd_scripts")]
59 |
60 |
def launch():
    """Serve the UI standalone, optionally behind an ngrok tunnel, and block forever.

    The runner's HTTP endpoints are registered on the Gradio app before blocking.
    """
    block, _, _ = on_ui_tabs()[0]
    if shared.cmd_opts.ngrok is not None:
        import scripts.ngrok as ngrok

        address = ngrok.connect(
            shared.cmd_opts.ngrok,
            shared.cmd_opts.port if shared.cmd_opts.port is not None else 7860,
            shared.cmd_opts.ngrok_region,
        )
        # connect() prints its own message and returns None when the tunnel
        # cannot be opened; concatenating None would abort startup with TypeError.
        if address is not None:
            print("Running on ngrok URL: " + address)

    app, _local_url, _share_url = block.launch(
        share=shared.cmd_opts.share,
        server_port=shared.cmd_opts.port,
        server_name=shared.cmd_opts.host,
        prevent_thread_lock=True,
    )

    # Terminal-output and process-alive endpoints polled by script.js.
    runner.initialize_api(app)

    wait_on_server()
83 |
84 |
# Save Gradio's TemplateResponse only the first time; presumably guards against
# re-import so the saved "original" is never an already-installed wrapper.
if not hasattr(shared, "gradio_template_response_original"):
    shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse

if is_webui_extension():
    from modules import script_callbacks

    def initialize_api(_, app):
        """on_app_started callback: register the runner endpoints on the webui app."""
        runner.initialize_api(app)

    script_callbacks.on_ui_tabs(on_ui_tabs)
    script_callbacks.on_app_started(initialize_api)

if __name__ == "__main__":
    launch()
99 |
--------------------------------------------------------------------------------
/scripts/ngrok.py:
--------------------------------------------------------------------------------
def _mask_token(token):
    """Return ``token`` with everything after its first four characters replaced by ``*``."""
    if len(token) <= 4:
        return token
    return token[:4] + "*" * (len(token) - 4)


def connect(token, port, region):
    """Open an HTTPS ngrok tunnel to local ``port`` and return its public URL.

    Args:
        token: ngrok authtoken, optionally followed by basic-auth credentials as
            ``"<authtoken>:<user>:<password>"``. ``None`` is passed to pyngrok as
            the literal string ``"None"`` (kept for compatibility).
        port: local port to expose.
        region: ngrok region code.

    Returns:
        The tunnel's public URL, or ``None`` after printing a message when
        pyngrok raises ``PyngrokNgrokError``.
    """
    from pyngrok import conf, exception, ngrok

    account = None
    if token is None:
        token = "None"
    elif ":" in token:
        parts = token.split(":")
        account = parts[1] + ":" + parts[-1]
        token = parts[0]

    config = conf.PyngrokConfig(auth_token=token, region=region)
    options = {"pyngrok_config": config, "bind_tls": True}
    if account is not None:
        options["auth"] = account
    try:
        tunnel = ngrok.connect(port, **options)
    except exception.PyngrokNgrokError:
        # Never print the full secret: console output (e.g. Colab cells) is often shared.
        print(
            f"Invalid ngrok authtoken, ngrok connection aborted.\n"
            f"Your token: {_mask_token(token)}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken"
        )
        return None
    return tunnel.public_url
29 |
--------------------------------------------------------------------------------
/scripts/presets.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 | import os
4 | from pathlib import Path
5 | import toml
6 | from kohya_ss.library import train_util, config_util
7 |
8 | import gradio as gr
9 |
10 | from scripts.shared import ROOT_DIR
11 | from scripts.utilities import gradio_to_args
12 |
# Root folder of saved presets, laid out as PRESET_DIR/<tab key>/<preset name>.toml.
PRESET_DIR = os.path.join(ROOT_DIR, "presets")
# Path of a presets.json file; not read anywhere in this module — TODO confirm it is still used elsewhere.
PRESET_PATH = os.path.join(ROOT_DIR, "presets.json")
15 |
16 |
def get_arg_templates(fn):
    """Collect the option names registered by an ``add_*`` argument helper.

    ``fn`` is called with a fresh ``ArgumentParser`` plus ``True`` for each of its
    remaining positional parameters.

    Returns:
        ``(option_names, category)``: option strings with ``--`` removed (help
        options excluded), and ``fn.__name__`` with ``add_`` removed.
    """
    parser = argparse.ArgumentParser()
    extra_params = len(inspect.signature(fn).parameters) - 1
    fn(parser, *([True] * extra_params))
    registered = vars(parser)["_option_string_actions"]
    names = [option.replace("--", "") for option in registered.keys()]
    return [n for n in names if n not in ("help", "-h")], fn.__name__.replace("add_", "")
28 |
29 |
# sd-scripts helpers that register CLI options; save_preset() files each option
# under the TOML table named after the first helper that registers it.
arguments_functions = [
    train_util.add_dataset_arguments,
    train_util.add_optimizer_arguments,
    train_util.add_sd_models_arguments,
    train_util.add_sd_saving_arguments,
    train_util.add_training_arguments,
    config_util.add_config_arguments,
]

# One (option_names, category) pair per helper above, in the same order.
arg_templates = list(map(get_arg_templates, arguments_functions))
40 |
41 |
def load_presets():
    """Load every saved preset.

    Presets live in ``PRESET_DIR/<key>/<name>.toml``; the result maps each key
    (sub-directory) to ``{name: flattened preset dict}``. Stray files directly in
    ``PRESET_DIR`` and non-``.toml`` entries inside a key directory are ignored
    instead of crashing ``os.listdir`` or showing up as empty phantom presets.
    """
    obj = {}
    os.makedirs(PRESET_DIR, exist_ok=True)
    for key in os.listdir(PRESET_DIR):
        key_dir = os.path.join(PRESET_DIR, key)
        if not os.path.isdir(key_dir):
            continue
        obj[key] = {}
        for filename in os.listdir(key_dir):
            name, ext = os.path.splitext(filename)
            if ext != ".toml":
                continue
            obj[key][name] = load_preset(key, name)
    return obj
53 |
54 |
def load_preset(key, name):
    """Read ``PRESET_DIR/<key>/<name>.toml`` and flatten it into one dict.

    Values stored inside TOML tables (the categories written by save_preset)
    are lifted to the top level. Returns ``{}`` when the file does not exist.
    """
    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(filepath):
        return {}
    try:
        # TOML documents are UTF-8 by specification.
        with open(filepath, mode="r", encoding="utf-8") as f:
            obj = toml.load(f)
    except UnicodeDecodeError:
        # Presets saved in the platform default encoding remain loadable.
        with open(filepath, mode="r") as f:
            obj = toml.load(f)

    flatten = {}
    for k, v in obj.items():
        if isinstance(v, dict):
            flatten.update(v)
        else:
            flatten[k] = v
    return flatten


def save_preset(key, name, value):
    """Write ``value`` to ``PRESET_DIR/<key>/<name>.toml`` as UTF-8.

    Each option goes under the TOML table of the first ``arg_templates``
    category that lists it; other options are written at top level. ``Path``
    values are stored as strings.
    """
    obj = {}
    for k, v in value.items():
        if isinstance(v, Path):
            v = str(v)
        for template, category in arg_templates:
            if k in template:
                obj.setdefault(category, {})[k] = v
                break
        else:
            obj[k] = v

    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, mode="w", encoding="utf-8") as f:
        toml.dump(obj, f)
90 |
91 |
def delete_preset(key, name):
    """Remove ``PRESET_DIR/<key>/<name>.toml``; does nothing if it does not exist."""
    target = os.path.join(PRESET_DIR, key, name + ".toml")
    if not os.path.exists(target):
        return
    os.remove(target)
96 |
97 |
def create_ui(key, tmpls, opts):
    """Render the preset toolbar for tab ``key`` and return a function that binds its events.

    ``tmpls`` and ``opts`` may be values or zero-argument callables and are
    resolved lazily. ``opts`` is used as a mapping whose values are the option
    components (presumably keyed by argument name — confirm with
    ``gradio_to_args``). Click handlers are attached only when the returned
    ``init()`` is called.
    """
    get_templates = lambda: tmpls() if callable(tmpls) else tmpls
    get_options = lambda: opts() if callable(opts) else opts

    presets = load_presets()
    presets.setdefault(key, {})

    with gr.Box():
        with gr.Row():
            with gr.Column():
                load_preset_button = gr.Button("Load preset", variant="primary")
                delete_preset_button = gr.Button("Delete preset")
            with gr.Column():
                load_preset_name = gr.Dropdown(
                    list(presets[key].keys()), show_label=False
                ).style(container=False)
                reload_presets_button = gr.Button("🔄️")
            with gr.Column() as c:
                c.scale = 0.5
                save_preset_name = gr.Textbox(
                    "", placeholder="Preset name", lines=1, show_label=False
                ).style(container=False)
                save_preset_button = gr.Button("Save preset", variant="primary")

    def update_dropdown():
        """Re-scan the preset folder and refresh the dropdown's choices."""
        return gr.Dropdown.update(choices=list(load_presets().get(key, {}).keys()))

    def _no_change():
        """Return a no-op update for each option component (the outputs of Load)."""
        count = len(get_options())
        if count == 1:
            return gr.update()
        return [gr.update() for _ in range(count)]

    def _save_preset(args):
        name = args[save_preset_name]
        if not name:
            return update_dropdown()
        args = gradio_to_args(get_templates(), get_options(), args)
        save_preset(key, name, args)
        return update_dropdown()

    def _load_preset(args):
        name = args[load_preset_name]
        if not name:
            # Load's outputs are the option components, not the dropdown, so a
            # dropdown update does not match them.
            return _no_change()
        args = gradio_to_args(get_templates(), get_options(), args)
        preset = load_preset(key, name)
        result = []
        for k in args:
            if k == load_preset_name:
                continue
            v = preset.get(k)
            if isinstance(v, list):
                # List-valued options are shown as space-separated text.
                v = " ".join(map(str, v))
            result.append(v)
        return result[0] if len(result) == 1 else result

    def _delete_preset(name):
        if not name:
            return update_dropdown()
        delete_preset(key, name)
        return update_dropdown()

    def init():
        """Attach the click handlers; presumably called once all option components exist."""
        option_components = list(get_options().values())
        save_preset_button.click(
            _save_preset,
            {save_preset_name, *option_components},
            [load_preset_name],
        )
        load_preset_button.click(
            _load_preset,
            {load_preset_name, *option_components},
            option_components,
        )
        delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name])
        reload_presets_button.click(
            update_dropdown, inputs=[], outputs=[load_preset_name]
        )

    return init
180 |
--------------------------------------------------------------------------------
/scripts/runner.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 |
4 | import fastapi
5 | import gradio as gr
6 | from pydantic import BaseModel, Field
7 |
8 | import scripts.shared as shared
9 | from scripts.utilities import run_python
10 |
# Currently attached script process (presumably a subprocess.Popen from run_python), or None.
proc = None
# Completed terminal lines, served incrementally to the browser by api_get_outputs().
outputs = []


def alive():
    """Return True while a script process is attached to the runner."""
    return not (proc is None)
17 |
18 |
def initialize_runner(script_file, tmpls, opts):
    """Create the Run/Stop buttons for the current tab and return a deferred binder.

    ``tmpls``/``opts`` may be values or zero-argument callables, resolved when
    the buttons are bound or clicked. Only one script runs at a time (module
    global ``proc``); its output is appended to ``outputs`` line by line.
    """
    run_button = gr.Button(
        "Run",
        variant="primary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_run_button",
    )
    stop_button = gr.Button(
        "Stop",
        variant="secondary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_stop_button",
    )
    get_templates = lambda: tmpls() if callable(tmpls) else tmpls
    get_options = lambda: opts() if callable(opts) else opts

    def run(args):
        """Start the script and stream its output until EOF or until stop() detaches it."""
        global proc
        if alive():
            return
        process = run_python(script_file, get_templates(), get_options(), args)
        proc = process
        line = ""
        try:
            echo = getattr(shared.cmd_opts, "enable_console_log", False)
            # errors="replace": one undecodable byte must not stop reading while
            # the child keeps writing into a pipe nobody drains.
            reader = io.TextIOWrapper(
                process.stdout, encoding="utf-8-sig", errors="replace"
            )
            # Read to EOF instead of polling the exit status so output written just
            # before exit is kept and an exhausted pipe cannot busy-spin.
            while proc is process:
                char = reader.read(1)
                if not char:
                    break
                if echo:
                    try:
                        sys.stdout.write(char)
                    except (UnicodeEncodeError, OSError):
                        pass  # console echo is best-effort; the UI still gets the char
                if char == "\n":
                    outputs.append(line)
                    line = ""
                else:
                    line += char
        except (OSError, ValueError) as e:
            print(f"Failed to read process output: {e}")
        finally:
            if line:
                outputs.append(line)  # final line without a trailing newline
            if proc is process:
                process.poll()  # reap the child if it has already exited
                proc = None

    def stop():
        """Kill the attached process, if any, and detach it."""
        global proc
        process = proc
        if process is None:
            return
        print("killed the running process")
        process.kill()
        proc = None

    def init():
        run_button.click(
            run,
            set(get_options().values()),
        )
        stop_button.click(stop)

    return init
69 |
70 |
class GetOutputRequest(BaseModel):
    """Request body for POST /internal/extensions/kohya-sd-scripts-webui/terminal/outputs."""

    # index of the first captured line to return
    output_index: int = Field(0, title="Index of the beginning of the log to retrieve")
    # when true, the captured log is discarded before slicing
    clear_terminal: bool = Field(False, title="Whether to clear the terminal")
78 |
79 |
class GetOutputResponse(BaseModel):
    """Response body for the terminal-outputs endpoint."""

    # required: captured lines from the requested index onward
    outputs: list = Field(title="List of terminal output")
82 |
83 |
class ProcessAliveResponse(BaseModel):
    """Response body for the process-alive endpoint."""

    # required: result of alive()
    alive: bool = Field(title="Whether the process is running.")
86 |
87 |
def api_get_outputs(req: GetOutputRequest):
    """POST handler: return captured terminal lines from ``req.output_index`` onward.

    With ``req.clear_terminal`` the buffer is reset first, so the response is empty.
    """
    global outputs
    if req.clear_terminal:
        outputs = []
    # Slicing past the end already yields [], so no length check is needed.
    return GetOutputResponse(outputs=outputs[req.output_index :])
95 |
96 |
def api_get_isalive(req: fastapi.Request):
    """GET handler: report whether a script process is running (*req* is unused)."""
    running = alive()
    return ProcessAliveResponse(alive=running)
99 |
100 |
def initialize_api(app: fastapi.FastAPI):
    """Mount this extension's HTTP endpoints on *app*.

    - POST .../terminal/outputs -> api_get_outputs
    - GET  .../process/alive    -> api_get_isalive
    """
    routes = (
        (
            "/internal/extensions/kohya-sd-scripts-webui/terminal/outputs",
            api_get_outputs,
            "POST",
            GetOutputResponse,
        ),
        (
            "/internal/extensions/kohya-sd-scripts-webui/process/alive",
            api_get_isalive,
            "GET",
            ProcessAliveResponse,
        ),
    )
    for path, endpoint, method, response_model in routes:
        app.add_api_route(
            path,
            endpoint,
            methods=[method],
            response_model=response_model,
        )
114 |
--------------------------------------------------------------------------------
/scripts/shared.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import os
4 | import sys
5 |
6 |
def is_webui_extension():
    """Return True when a top-level ``webui`` module can be imported.

    That is taken to mean we run as a stable-diffusion-webui extension.
    """
    try:
        importlib.import_module("webui")
    except BaseException:
        # As broad as a bare except: any failure at all means "standalone".
        return False
    return True
13 |
14 |
# Root directory of this project. Inside stable-diffusion-webui it is obtained
# from webui's modules.scripts.basedir(); standalone it is the parent of the
# directory containing this file.
if is_webui_extension():
    ROOT_DIR = importlib.import_module("modules.scripts").basedir()
else:
    ROOT_DIR = os.path.dirname(os.path.dirname(__file__))

# Mutable state shared between modules while the UI is being built.
current_tab = None  # underscore-joined module name of the tab being built (set by ui.create_ui)
loaded_tabs = []  # module names of every tab loaded so far
help_title_map = {}  # NOTE(review): not referenced in the visible code

parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--ngrok", type=str, default=None)
parser.add_argument("--ngrok-region", type=str, default="us")
parser.add_argument("--enable-console-log", action="store_true")
# parse_known_args tolerates flags meant for a host application; the full
# sys.argv is passed, so argv[0] simply lands among the ignored extras.
cmd_opts, _ = parser.parse_known_args(sys.argv)
33 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/check_lora_weights.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    # Fixed typo: previously displayed as "Check lora wights".
    return "Check lora weights"


def create_ui():
    """Build the "Check lora weights" tab: Run/Stop buttons above the script's options."""
    options = {}  # populated by options_to_gradio below; read lazily by the runner
    templates, script_file = load_args_template("networks", "check_lora_weights.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    init()
24 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/extract_lora_from_models.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Extract lora from models"


def create_ui():
    """Build the "Extract lora from models" tab: Run/Stop buttons above the options."""
    templates, script_file = load_args_template(
        "networks", "extract_lora_from_models.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by the runner

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    wire_runner()
26 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/lora_interrogator.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Lora interrogator"


def create_ui():
    """Build the "Lora interrogator" tab: Run/Stop buttons above the script's options."""
    templates, script_file = load_args_template("networks", "lora_interrogator.py")
    options = {}  # populated by options_to_gradio below; read lazily by the runner

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    wire_runner()
24 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/merge_lora.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Merge lora"


def create_ui():
    """Build the "Merge lora" tab: Run/Stop buttons above the script's options."""
    templates, script_file = load_args_template("networks", "merge_lora.py")
    options = {}  # populated by options_to_gradio below; read lazily by the runner

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    wire_runner()
24 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/resize_lora.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Resize lora"


def create_ui():
    """Build the "Resize lora" tab: Run/Stop buttons above the script's options."""
    templates, script_file = load_args_template("networks", "resize_lora.py")
    options = {}  # populated by options_to_gradio below; read lazily by the runner

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    wire_runner()
24 |
--------------------------------------------------------------------------------
/scripts/tabs/networks/svd_merge_lora.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Svd merge lora"


def create_ui():
    """Build the "Svd merge lora" tab: Run/Stop buttons above the script's options."""
    templates, script_file = load_args_template("networks", "svd_merge_lora.py")
    options = {}  # populated by options_to_gradio below; read lazily by the runner

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Bind Run/Stop only after the option components exist.
    wire_runner()
24 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/clean_captions_and_tags.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Clean captions and tags"


def create_ui():
    """Build the "Clean captions and tags" tab: Run/Stop, preset controls and options.

    An exception while building is printed with its traceback and swallowed,
    so one broken tab does not abort the rest of the UI.
    """
    import traceback

    try:
        options = {}  # populated by options_to_gradio below
        templates, script_file = load_args_template(
            "finetune", "clean_captions_and_tags.py"
        )

        with gr.Column():
            init_runner = initialize_runner(script_file, templates, options)
            with gr.Box():
                with gr.Row():
                    init_ui = presets.create_ui(
                        "finetune.clean_captions_and_tags", templates, options
                    )
            with gr.Box():
                ui.title("Options")
                with gr.Column():
                    options_to_gradio(templates, options)

        # Event handlers need the option components, so wire them last.
        init_runner()
        init_ui()

    except Exception:
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        traceback.print_exc()
38 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/make_captions.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Make captions"


def create_ui():
    """Build the "Make captions" tab: Run/Stop, preset save/load controls, and options."""
    templates, script_file = load_args_template("finetune", "make_captions.py")
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.make_captions", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
30 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/make_captions_by_git.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Make captions by GIT"


def create_ui():
    """Build the "Make captions by GIT" tab: Run/Stop, preset controls, and options."""
    templates, script_file = load_args_template("finetune", "make_captions_by_git.py")
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.make_captions_by_git", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
30 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/merge_captions.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Merge captions"


def create_ui():
    """Build the "Merge captions" tab (merge_captions_to_metadata.py): Run/Stop, presets, options."""
    templates, script_file = load_args_template(
        "finetune", "merge_captions_to_metadata.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.merge_captions_to_metadata", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
32 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/merge_tags.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Merge tags"


def create_ui():
    """Build the "Merge tags" tab (merge_dd_tags_to_metadata.py): Run/Stop, presets, options."""
    templates, script_file = load_args_template(
        "finetune", "merge_dd_tags_to_metadata.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.merge_dd_tags_to_metadata", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
32 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/prepare_latents.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Prepare latents"


def create_ui():
    """Build the "Prepare latents" tab (prepare_buckets_latents.py): Run/Stop, presets, options."""
    templates, script_file = load_args_template(
        "finetune", "prepare_buckets_latents.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.prepare_buckets_latents", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
32 |
--------------------------------------------------------------------------------
/scripts/tabs/preparation/tag_images_by_wd14tagger.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Tag images by wd1.4tagger"


def create_ui():
    """Build the WD1.4 tagger tab (tag_images_by_wd14_tagger.py): Run/Stop, presets, options."""
    templates, script_file = load_args_template(
        "finetune", "tag_images_by_wd14_tagger.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "finetune.tag_images_by_wd14_tagger", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
32 |
--------------------------------------------------------------------------------
/scripts/tabs/tools/convert_diffusers.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Convert Diffusers"


def create_ui():
    """Build the "Convert Diffusers" tab (convert_diffusers20_original_sd.py): Run/Stop, presets, options."""
    templates, script_file = load_args_template(
        "tools", "convert_diffusers20_original_sd.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                # NOTE(review): unlike the sibling tool tabs this preset key has no
                # "tools." prefix; renaming it would orphan presets saved under it.
                wire_presets = presets.create_ui(
                    "convert_diffusers20_original_sd", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
32 |
--------------------------------------------------------------------------------
/scripts/tabs/tools/detect_face_rotate.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Detect face rotate"


def create_ui():
    """Build the "Detect face rotate" tab: Run/Stop, preset controls, and options."""
    templates, script_file = load_args_template("tools", "detect_face_rotate.py")
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "tools.detect_face_rotate", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
31 |
--------------------------------------------------------------------------------
/scripts/tabs/tools/resize_images_to_resolution.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from scripts import presets, ui
4 | from scripts.runner import initialize_runner
5 | from scripts.utilities import load_args_template, options_to_gradio
6 |
7 |
def title():
    """Return the label of this tab."""
    return "Resize images to resolution"


def create_ui():
    """Build the "Resize images to resolution" tab: Run/Stop, preset controls, and options."""
    templates, script_file = load_args_template(
        "tools", "resize_images_to_resolution.py"
    )
    options = {}  # populated by options_to_gradio below; read lazily by runner and presets

    with gr.Column():
        wire_runner = initialize_runner(script_file, tmpls=templates, opts=options)
        with gr.Box():
            with gr.Row():
                wire_presets = presets.create_ui(
                    "tools.resize_images_to_resolution", templates, options
                )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    # Event handlers need the option components, so wire them last.
    wire_runner()
    wire_presets()
33 |
--------------------------------------------------------------------------------
/scripts/tabs/training/fine_tune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 |
5 | from kohya_ss.library import train_util, config_util
6 | from scripts import presets, ui, ui_overrides
7 | from scripts.runner import initialize_runner
8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
9 |
10 |
11 | def title():  # Label of this tab; ui.create_ui renders create_ui() inside gr.TabItem(title()).
12 | return "Fine tune"
13 |
14 |
15 | def create_ui():  # Build the "Fine tune" tab: Run/Stop buttons, preset controls and grouped option boxes.
16 | sd_models_arguments = argparse.ArgumentParser()  # one parser per sd-scripts argument group; never used to parse, only its declared arguments are read
17 | dataset_arguments = argparse.ArgumentParser()
18 | training_arguments = argparse.ArgumentParser()
19 | sd_saving_arguments = argparse.ArgumentParser()
20 | optimizer_arguments = argparse.ArgumentParser()
21 | config_arguments = argparse.ArgumentParser()
22 | train_util.add_sd_models_arguments(sd_models_arguments)
23 | train_util.add_dataset_arguments(dataset_arguments, False, True, True)  # NOTE(review): these positional flags differ per training tab; their meaning is defined in sd-scripts' train_util
24 | train_util.add_training_arguments(training_arguments, False)  # NOTE(review): False here, True in the other training tabs; confirm meaning in train_util
25 | train_util.add_sd_saving_arguments(sd_saving_arguments)
26 | train_util.add_optimizer_arguments(optimizer_arguments)
27 | config_util.add_config_arguments(config_arguments)
28 | sd_models_options = {}  # per-group option dicts; presumably filled with the created components by args_to_gradio / options_to_gradio below
29 | dataset_options = {}
30 | training_options = {}
31 | sd_saving_options = {}
32 | optimizer_options = {}
33 | config_options = {}
34 | finetune_options = {}
35 |
36 | templates, script_file = load_args_template("fine_tune.py")  # fine_tune.py's own arguments, shown in the "Fine tune options" box
37 |
38 | get_options = lambda: {  # merged lazily so it reflects components created later; later dicts win on duplicate keys
39 | **sd_models_options,
40 | **dataset_options,
41 | **training_options,
42 | **sd_saving_options,
43 | **optimizer_options,
44 | **finetune_options,  # NOTE(review): placed before config_options here but last in the other training tabs; confirm whether the order matters
45 | **config_options,
46 | }
47 |
48 | get_templates = lambda: {  # argparse actions keyed by option string for every group (including each parser's -h/--help), then the script's own templates, which win on duplicate keys
49 | **sd_models_arguments.__dict__["_option_string_actions"],  # relies on argparse's private _option_string_actions mapping
50 | **dataset_arguments.__dict__["_option_string_actions"],
51 | **training_arguments.__dict__["_option_string_actions"],
52 | **sd_saving_arguments.__dict__["_option_string_actions"],
53 | **optimizer_arguments.__dict__["_option_string_actions"],
54 | **config_arguments.__dict__["_option_string_actions"],
55 | **templates,
56 | }
57 |
58 | with gr.Column():
59 | init_runner = initialize_runner(script_file, get_templates, get_options)  # creates Run/Stop and returns the callback that wires them
60 | with gr.Box():
61 | with gr.Row():
62 | init_ui = presets.create_ui("fine_tune", get_templates, get_options)  # preset controls; returns their wiring callback
63 | with gr.Row():
64 | with gr.Group():
65 | with gr.Box():
66 | ui.title("Fine tune options")
67 | options_to_gradio(templates, finetune_options)
68 | with gr.Box():
69 | ui.title("Model options")
70 | args_to_gradio(sd_models_arguments, sd_models_options)
71 | with gr.Box():
72 | ui.title("Dataset options")
73 | args_to_gradio(dataset_arguments, dataset_options)
74 | with gr.Box():
75 | ui.title("Dataset Config options")
76 | args_to_gradio(config_arguments, config_options)
77 | with gr.Box():
78 | ui.title("Training options")
79 | args_to_gradio(training_arguments, training_options)
80 | with gr.Group():
81 | with gr.Box():
82 | ui.title("Save options")
83 | args_to_gradio(sd_saving_arguments, sd_saving_options)
84 | with gr.Box():
85 | ui.title("Optimizer options")
86 | args_to_gradio(
87 | optimizer_arguments,
88 | optimizer_options,
89 | ui_overrides.OPTIMIZER_OPTIONS,  # overrides optimizer_type / lr_scheduler with fixed choice lists
90 | )
91 |
92 | init_runner()  # wire handlers only now that every option component exists
93 | init_ui()
94 |
--------------------------------------------------------------------------------
/scripts/tabs/training/train_db.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 |
5 | from kohya_ss.library import train_util, config_util
6 | from scripts import presets, ui, ui_overrides
7 | from scripts.runner import initialize_runner
8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
9 |
10 |
11 | def title():  # Label of this tab; ui.create_ui renders create_ui() inside gr.TabItem(title()).
12 | return "Train dreambooth"
13 |
14 |
15 | def create_ui():  # Build the "Train dreambooth" tab: Run/Stop buttons, preset controls and grouped option boxes.
16 | sd_models_arguments = argparse.ArgumentParser()  # one parser per sd-scripts argument group; never used to parse, only its declared arguments are read
17 | dataset_arguments = argparse.ArgumentParser()
18 | training_arguments = argparse.ArgumentParser()
19 | sd_saving_arguments = argparse.ArgumentParser()
20 | optimizer_arguments = argparse.ArgumentParser()
21 | config_arguments = argparse.ArgumentParser()
22 | train_util.add_sd_models_arguments(sd_models_arguments)
23 | train_util.add_dataset_arguments(dataset_arguments, True, False, True)  # NOTE(review): these positional flags differ per training tab; their meaning is defined in sd-scripts' train_util
24 | train_util.add_training_arguments(training_arguments, True)
25 | train_util.add_sd_saving_arguments(sd_saving_arguments)
26 | train_util.add_optimizer_arguments(optimizer_arguments)
27 | config_util.add_config_arguments(config_arguments)
28 | sd_models_options = {}  # per-group option dicts; presumably filled with the created components by args_to_gradio / options_to_gradio below
29 | dataset_options = {}
30 | training_options = {}
31 | sd_saving_options = {}
32 | optimizer_options = {}
33 | config_options = {}
34 | dreambooth_options = {}
35 |
36 | templates, script_file = load_args_template("train_db.py")  # train_db.py's own arguments, shown in the "Dreambooth options" box
37 |
38 | get_options = lambda: {  # merged lazily so it reflects components created later; later dicts win on duplicate keys
39 | **sd_models_options,
40 | **dataset_options,
41 | **training_options,
42 | **sd_saving_options,
43 | **optimizer_options,
44 | **config_options,
45 | **dreambooth_options,
46 | }
47 |
48 | get_templates = lambda: {  # argparse actions keyed by option string for every group (including each parser's -h/--help), then the script's own templates, which win on duplicate keys
49 | **sd_models_arguments.__dict__["_option_string_actions"],  # relies on argparse's private _option_string_actions mapping
50 | **dataset_arguments.__dict__["_option_string_actions"],
51 | **training_arguments.__dict__["_option_string_actions"],
52 | **sd_saving_arguments.__dict__["_option_string_actions"],
53 | **optimizer_arguments.__dict__["_option_string_actions"],
54 | **config_arguments.__dict__["_option_string_actions"],
55 | **templates,
56 | }
57 |
58 | with gr.Column():
59 | init_runner = initialize_runner(script_file, get_templates, get_options)  # creates Run/Stop and returns the callback that wires them
60 | with gr.Box():
61 | with gr.Row():
62 | init_ui = presets.create_ui("train_db", get_templates, get_options)  # preset controls; returns their wiring callback
63 | with gr.Row():
64 | with gr.Group():
65 | with gr.Box():
66 | ui.title("Dreambooth options")
67 | options_to_gradio(templates, dreambooth_options)
68 | with gr.Box():
69 | ui.title("Model options")
70 | args_to_gradio(sd_models_arguments, sd_models_options)
71 | with gr.Box():
72 | ui.title("Dataset options")
73 | args_to_gradio(dataset_arguments, dataset_options)
74 | with gr.Box():
75 | ui.title("Dataset Config options")
76 | args_to_gradio(config_arguments, config_options)
77 | with gr.Box():
78 | ui.title("Training options")
79 | args_to_gradio(training_arguments, training_options)
80 | with gr.Group():
81 | with gr.Box():
82 | ui.title("Save options")
83 | args_to_gradio(sd_saving_arguments, sd_saving_options)
84 | with gr.Box():
85 | ui.title("Optimizer options")
86 | args_to_gradio(
87 | optimizer_arguments,
88 | optimizer_options,
89 | ui_overrides.OPTIMIZER_OPTIONS,  # overrides optimizer_type / lr_scheduler with fixed choice lists
90 | )
91 |
92 | init_runner()  # wire handlers only now that every option component exists
93 | init_ui()
94 |
--------------------------------------------------------------------------------
/scripts/tabs/training/train_network.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 |
5 | from kohya_ss.library import train_util, config_util
6 | from scripts import presets, ui, ui_overrides
7 | from scripts.runner import initialize_runner
8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
9 |
10 |
11 | def title():  # Label of this tab; ui.create_ui renders create_ui() inside gr.TabItem(title()).
12 | return "Train network"
13 |
14 |
15 | def create_ui():  # Build the "Train network" tab: Run/Stop buttons, preset controls and grouped option boxes.
16 | sd_models_arguments = argparse.ArgumentParser()  # one parser per sd-scripts argument group; never used to parse, only its declared arguments are read
17 | dataset_arguments = argparse.ArgumentParser()
18 | training_arguments = argparse.ArgumentParser()
19 | optimizer_arguments = argparse.ArgumentParser()
20 | config_arguments = argparse.ArgumentParser()
21 | train_util.add_sd_models_arguments(sd_models_arguments)
22 | train_util.add_dataset_arguments(dataset_arguments, True, True, True)  # NOTE(review): these positional flags differ per training tab; their meaning is defined in sd-scripts' train_util
23 | train_util.add_training_arguments(training_arguments, True)
24 | train_util.add_optimizer_arguments(optimizer_arguments)
25 | config_util.add_config_arguments(config_arguments)
26 | sd_models_options = {}  # per-group option dicts; presumably filled with the created components by args_to_gradio / options_to_gradio below
27 | dataset_options = {}
28 | training_options = {}
29 | optimizer_options = {}
30 | config_options = {}
31 | network_options = {}
32 |
33 | templates, script_file = load_args_template("train_network.py")  # train_network.py's own arguments, shown in the "Network options" box
34 |
35 | get_options = lambda: {  # merged lazily so it reflects components created later; later dicts win on duplicate keys
36 | **sd_models_options,
37 | **dataset_options,
38 | **training_options,
39 | **optimizer_options,
40 | **config_options,
41 | **network_options,
42 | }
43 |
44 | get_templates = lambda: {  # argparse actions keyed by option string for every group (including each parser's -h/--help), then the script's own templates, which win on duplicate keys
45 | **sd_models_arguments.__dict__["_option_string_actions"],  # relies on argparse's private _option_string_actions mapping
46 | **dataset_arguments.__dict__["_option_string_actions"],
47 | **training_arguments.__dict__["_option_string_actions"],
48 | **optimizer_arguments.__dict__["_option_string_actions"],
49 | **config_arguments.__dict__["_option_string_actions"],
50 | **templates,
51 | }
52 |
53 | with gr.Column():
54 | init_runner = initialize_runner(script_file, get_templates, get_options)  # creates Run/Stop and returns the callback that wires them
55 | with gr.Box():
56 | with gr.Row():
57 | init_id = presets.create_ui("train_network", get_templates, get_options)  # preset controls' wiring callback (called init_ui in the sibling tabs)
58 | with gr.Row():
59 | with gr.Group():
60 | with gr.Box():
61 | ui.title("Network options")
62 | options_to_gradio(templates, network_options)
63 | with gr.Box():
64 | ui.title("Model options")
65 | args_to_gradio(sd_models_arguments, sd_models_options)
66 | with gr.Box():
67 | ui.title("Dataset Config options")
68 | args_to_gradio(config_arguments, config_options)
69 | with gr.Box():
70 | ui.title("Dataset options")
71 | args_to_gradio(dataset_arguments, dataset_options)
72 | with gr.Box():
73 | ui.title("Training options")
74 | args_to_gradio(training_arguments, training_options)
75 | with gr.Box():
76 | ui.title("Optimizer options")
77 | args_to_gradio(
78 | optimizer_arguments,
79 | optimizer_options,
80 | ui_overrides.OPTIMIZER_OPTIONS,  # overrides optimizer_type / lr_scheduler with fixed choice lists
81 | )
82 |
83 | init_runner()  # wire handlers only now that every option component exists
84 | init_id()
85 |
--------------------------------------------------------------------------------
/scripts/tabs/training/train_textual_inversion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gradio as gr
4 |
5 | from kohya_ss.library import train_util, config_util
6 | from scripts import presets, ui, ui_overrides
7 | from scripts.runner import initialize_runner
8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
9 |
10 |
11 | def title():  # Label of this tab; ui.create_ui renders create_ui() inside gr.TabItem(title()).
12 | return "Train textual inversion"
13 |
14 |
15 | def create_ui():  # Build the "Train textual inversion" tab: Run/Stop buttons, preset controls and grouped option boxes.
16 | sd_models_arguments = argparse.ArgumentParser()  # one parser per sd-scripts argument group; never used to parse, only its declared arguments are read
17 | dataset_arguments = argparse.ArgumentParser()
18 | training_arguments = argparse.ArgumentParser()
19 | optimizer_arguments = argparse.ArgumentParser()
20 | config_arguments = argparse.ArgumentParser()
21 | train_util.add_sd_models_arguments(sd_models_arguments)
22 | train_util.add_dataset_arguments(dataset_arguments, True, True, False)  # NOTE(review): these positional flags differ per training tab; their meaning is defined in sd-scripts' train_util
23 | train_util.add_training_arguments(training_arguments, True)
24 | train_util.add_optimizer_arguments(optimizer_arguments)
25 | config_util.add_config_arguments(config_arguments)
26 | sd_models_options = {}  # per-group option dicts; presumably filled with the created components by args_to_gradio / options_to_gradio below
27 | dataset_options = {}
28 | training_options = {}
29 | optimizer_options = {}
30 | config_options = {}
31 | ti_options = {}
32 |
33 | templates, script_file = load_args_template("train_textual_inversion.py")  # the script's own arguments, shown in the "Textual inversion options" box
34 |
35 | get_options = lambda: {  # merged lazily so it reflects components created later; later dicts win on duplicate keys
36 | **sd_models_options,
37 | **dataset_options,
38 | **training_options,
39 | **optimizer_options,
40 | **config_options,
41 | **ti_options,
42 | }
43 |
44 | get_templates = lambda: {  # argparse actions keyed by option string for every group (including each parser's -h/--help), then the script's own templates, which win on duplicate keys
45 | **sd_models_arguments.__dict__["_option_string_actions"],  # relies on argparse's private _option_string_actions mapping
46 | **dataset_arguments.__dict__["_option_string_actions"],
47 | **training_arguments.__dict__["_option_string_actions"],
48 | **optimizer_arguments.__dict__["_option_string_actions"],
49 | **config_arguments.__dict__["_option_string_actions"],
50 | **templates,
51 | }
52 |
53 | with gr.Column():
54 | init_runner = initialize_runner(script_file, get_templates, get_options)  # creates Run/Stop and returns the callback that wires them
55 | with gr.Box():
56 | with gr.Row():
57 | init_ui = presets.create_ui(  # preset controls; returns their wiring callback
58 | "train_textual_inversion", get_templates, get_options
59 | )
60 | with gr.Row():
61 | with gr.Group():
62 | with gr.Box():
63 | ui.title("Textual inversion options")
64 | options_to_gradio(templates, ti_options)
65 | with gr.Box():
66 | ui.title("Model options")
67 | args_to_gradio(sd_models_arguments, sd_models_options)
68 | with gr.Box():
69 | ui.title("Dataset Config options")
70 | args_to_gradio(config_arguments, config_options)
71 | with gr.Box():
72 | ui.title("Dataset options")
73 | args_to_gradio(dataset_arguments, dataset_options)
74 | with gr.Box():
75 | ui.title("Training options")
76 | args_to_gradio(training_arguments, training_options)
77 | with gr.Box():
78 | ui.title("Optimizer options")
79 | args_to_gradio(
80 | optimizer_arguments,
81 | optimizer_options,
82 | ui_overrides.OPTIMIZER_OPTIONS,  # overrides optimizer_type / lr_scheduler with fixed choice lists
83 | )
84 |
85 | init_runner()  # wire handlers only now that every option component exists
86 | init_ui()
87 |
--------------------------------------------------------------------------------
/scripts/ui.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import importlib
3 | import os
4 | import sys
5 |
6 | import gradio as gr
7 |
8 | import scripts.shared as shared
9 | from scripts.shared import ROOT_DIR
10 | from scripts.utilities import path_to_module
11 |
12 |
13 | def title(txt):  # Render *txt* as a title element via gr.HTML (the tabs use it for headers such as "Options").
14 | gr.HTML(  # NOTE(review): the HTML markup of this f-string appears to have been stripped in this copy (the literal is split across two lines); restore it from the repository.
15 | f'{txt}
',
16 | )
17 |
18 |
def create_ui(css):
    """Build the top-level Gradio app containing every tab under ``scripts/tabs``.

    Each sub-directory of ``<ROOT_DIR>/scripts/tabs`` that contains ``*.py``
    files becomes a top-level tab named after the directory, and each module
    inside it becomes a nested tab. A tab module must provide ``title()`` (the
    tab label) and ``create_ui()`` (which builds its components). A module that
    fails to import or build is reported on stdout and skipped, so one broken
    tab does not take down the whole UI.

    Categories and modules are visited in sorted order so the tab layout is
    the same on every platform (``os.listdir``/``glob`` order is arbitrary).

    Args:
        css: Stylesheet text passed to ``gr.Blocks``.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    extra_paths = [
        os.path.join(ROOT_DIR, "kohya_ss", "library"),
        ROOT_DIR,
    ]
    # Temporarily expose sd-scripts' library and the repo root to tab modules.
    # Only the entries appended here are removed afterwards, so paths that were
    # already on sys.path before this call (e.g. ROOT_DIR) stay there.
    sys.path.extend(extra_paths)
    try:
        with gr.Blocks(css=css, analytics_enabled=False) as ui:
            with gr.Tabs(elem_id="kohya_sd_webui__root"):
                tabs_dir = os.path.join(ROOT_DIR, "scripts", "tabs")
                for category in sorted(os.listdir(tabs_dir)):
                    category_dir = os.path.join(tabs_dir, category)
                    tab_files = sorted(glob.glob(os.path.join(category_dir, "*.py")))
                    # Check before touching sys.path so skipped entries never leak.
                    if not tab_files:
                        continue
                    sys.path.append(category_dir)
                    try:
                        with gr.TabItem(category):
                            for lib in tab_files:
                                _load_tab(lib)
                    finally:
                        _remove_last(sys.path, category_dir)
                with gr.TabItem("terminal"):
                    # NOTE(review): this HTML literal looks stripped in this copy
                    # of the source; style.css styles #kohya_sd_webui__terminal_outputs.
                    gr.HTML('')
    finally:
        for path in extra_paths:
            _remove_last(sys.path, path)
    return ui


def _load_tab(lib):
    """Import one tab module and render it as a nested tab; report failures."""
    module_path = lib  # label used in the error message if conversion fails
    try:
        module_path = path_to_module(lib)
        module_name = module_path.replace(".", "_")

        module = importlib.import_module(module_path)
        # options_to_gradio reads shared.current_tab to build element ids.
        shared.current_tab = module_name
        shared.loaded_tabs.append(module_name)

        with gr.TabItem(module.title()):
            module.create_ui()
    except Exception as e:
        print(f"Failed to load {module_path}")
        print(e)


def _remove_last(items, value):
    """Delete the last occurrence of *value* from the list *items*, if present."""
    # Removing the *last* copy undoes an append while leaving any earlier,
    # pre-existing copy at its original position.
    for index in range(len(items) - 1, -1, -1):
        if items[index] == value:
            del items[index]
            return
54 |
--------------------------------------------------------------------------------
/scripts/ui_overrides.py:
--------------------------------------------------------------------------------
# Per-argument overrides for the optimizer section of the training tabs.
# Each entry forces the argument to be rendered as a radio group
# (``type=list``) limited to the listed ``choices``, regardless of how the
# argument itself is declared.
OPTIMIZER_OPTIONS = dict(
    optimizer_type=dict(
        type=list,
        choices=[
            "AdamW",
            "AdamW8bit",
            "Lion",
            "SGDNesterov",
            "SGDNesterov8bit",
            "DAdaptation",
            "AdaFactor",
        ],
    ),
    lr_scheduler=dict(
        type=list,
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
            "adafactor",
        ],
    ),
)
27 |
--------------------------------------------------------------------------------
/scripts/utilities.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import importlib
3 | import os
4 | import subprocess
5 | import sys
6 |
7 | import gradio as gr
8 |
9 | import scripts.shared as shared
10 | from scripts.shared import ROOT_DIR
11 |
# Interpreter that run_python() uses to start sd-scripts in child processes —
# the same executable that is running this web UI.
python = sys.executable
13 |
14 |
def path_to_module(filepath):
    """Convert a ``.py`` file path into a dotted module name relative to ROOT_DIR.

    Example: ``<ROOT_DIR>/scripts/tabs/training/train_db.py`` becomes
    ``scripts.tabs.training.train_db``.

    Only a trailing ``.py`` extension is removed. A ``.py`` occurring elsewhere
    in the path (e.g. in a directory name) is preserved instead of being
    deleted.
    """
    relative = os.path.relpath(filepath, ROOT_DIR)
    stem, extension = os.path.splitext(relative)
    if extension != ".py":
        stem = relative
    return stem.replace(os.path.sep, ".")
19 |
20 |
def which(program):
    """Locate an executable, similar to the Unix ``which`` command.

    If *program* contains a directory component, it is returned unchanged when
    it names an executable file. Otherwise each ``PATH`` entry is searched in
    order, with surrounding double quotes stripped from entries. When ``PATH``
    is unset, ``os.defpath`` is searched instead of raising ``KeyError``.

    Returns:
        The path of the executable, or ``None`` when nothing matches.
    """

    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    directory, _ = os.path.split(program)
    if directory:
        return program if is_exe(program) else None

    for entry in os.environ.get("PATH", os.defpath).split(os.pathsep):
        candidate = os.path.join(entry.strip('"'), program)
        if is_exe(candidate):
            return candidate
    return None
37 |
38 |
def literal_eval(v, module=None):
    """Evaluate one argument value taken from ``add_argument(...)`` source text.

    Resolution order:
      1. The type names ``"str"``, ``"int"``, ``"float"``, ``"bool"`` and
         ``"list"`` (or the ``list`` type itself) map to the builtin types.
      2. If *module* is given and can be imported, an attribute of that module
         named *v* is returned (e.g. a custom ``type=`` converter).
      3. Otherwise *v* goes to ``ast.literal_eval``, which handles strings,
         numbers, ``None``, booleans, lists, and so on. It raises
         ``ValueError``/``SyntaxError`` for anything that is not a literal.

    Surrounding whitespace in a string *v* is ignored.
    """
    if v is list:
        return list
    if isinstance(v, str):
        v = v.strip()
        builtin_types = {"str": str, "int": int, "float": float, "bool": bool, "list": list}
        if v in builtin_types:
            return builtin_types[v]

    if module:
        try:
            m = importlib.import_module(module)
            if hasattr(m, v):
                return getattr(m, v)
        except Exception:
            # Best effort: a module that cannot be imported (or fails while
            # importing) just means the name cannot be resolved there.
            pass

    return ast.literal_eval(v)
58 |
59 |
def compile_arg_parser(txt, module_path=None):
    """Extract ``parser.add_argument(...)`` calls from Python source text.

    This is a small character scanner, not a real Python parser. For each call
    it splits the argument list on top-level commas. Commas inside string
    literals (including escaped quotes) or nested ``()``/``[]``/``{}`` do not
    split. Each piece is evaluated with ``literal_eval``. ``#`` starts a
    comment only outside string literals.

    Args:
        txt: Python source text, typically the body of ``setup_parser()``.
        module_path: Dotted module name forwarded to ``literal_eval`` so that
            names such as ``type=some_helper`` can be resolved.

    Returns:
        A dict mapping each call's first positional argument (e.g. ``"--foo"``)
        to ``{"dest": ..., **keyword_arguments}``. ``dest`` is the first
        positional argument with ``"--"`` removed; an explicit ``dest=``
        keyword overrides it. Further positional option strings (e.g. a short
        alias) are ignored.
    """
    marker = "parser.add_argument("
    parsers = {}
    args = []
    arg = ""
    in_parser = False
    in_str = None  # quote character of the currently open string literal
    escaped = False  # previous character was a backslash inside a string
    depth = 0  # bracket nesting inside the current call

    def compile_piece(piece):
        piece = piece.strip()
        key, sep, value = piece.partition("=")
        # Only ``name=value`` is a keyword argument; a positional literal that
        # merely contains "=" is evaluated whole. Splitting at the first "="
        # keeps any "=" inside the value intact.
        if sep and key.strip().isidentifier():
            return (key.strip(), literal_eval(value.strip(), module_path))
        return literal_eval(piece, module_path)

    for line in txt.split("\n"):
        escaped = False  # a trailing backslash escapes the newline, not this line
        if in_str is None:
            # Look for the call only in the non-comment part of the line.
            if marker in line.split("#", 1)[0]:
                in_parser = True
                depth = 0
                arg = ""
                args = []
                line = line.split(marker, 1)[1]

        if not in_parser:
            continue

        for c in line:
            if in_str is not None:
                arg += c
                if escaped:
                    escaped = False
                elif c == "\\":
                    escaped = True
                elif c == in_str:
                    in_str = None
                continue

            if c == "#":
                break  # the rest of the line is a comment
            if c == ")" and depth == 0:
                # Closing parenthesis of parser.add_argument( ... ).
                if arg.strip():
                    args.append(compile_piece(arg))
                if args:
                    dest, *others = args
                    kwargs = dict(a for a in others if isinstance(a, tuple))
                    parsers[dest] = {"dest": dest.replace("--", ""), **kwargs}
                in_parser = False
                arg = ""
                args = []
                break
            if c == "," and depth == 0:
                if arg.strip():
                    args.append(compile_piece(arg))
                arg = ""
                continue

            if c in "\"'":
                in_str = c
            elif c in "([{":
                depth += 1
            elif c in ")]}":
                depth = max(depth - 1, 0)
            arg += c

        if in_parser and in_str is None:
            # Keep tokens from consecutive lines separated.
            arg += " "

    return parsers
121 |
122 |
def load_args_template(*filename):
    """Parse the argument definitions of an sd-scripts source file.

    *filename* components are joined onto ``<ROOT_DIR>/kohya_ss``. Every line
    after the first one containing ``def setup_parser()`` is passed to
    ``compile_arg_parser``. If the marker is absent, that text is empty.

    Returns:
        A ``(parsers, filepath)`` tuple: the dict built by
        ``compile_arg_parser`` and the path that was read.
    """
    filepath = os.path.join(ROOT_DIR, "kohya_ss", *filename)
    with open(filepath, mode="r", encoding="utf-8_sig") as f:
        lines = f.readlines()

    marker = "def setup_parser()"
    start = len(lines)
    for index, line in enumerate(lines):
        if marker in line:
            start = index + 1
            break

    txt = "".join(lines[start:])
    return compile_arg_parser(txt, path_to_module(filepath)), filepath
137 |
138 |
def check_key(d, k):
    """Report whether *d* has key *k* bound to something other than None."""
    if k not in d:
        return False
    return d[k] is not None
141 |
142 |
def get_arg_type(d):
    """Infer which Python type an argument definition should be edited as.

    Precedence:
      * a non-None ``choices`` means ``list``;
      * otherwise a non-None ``type`` is returned as-is;
      * otherwise a ``store_true``/``store_false`` action or a ``bool``
        ``const`` means ``bool``;
      * everything else falls back to ``str``.
    """

    def present(key):
        # A key counts only when it exists and holds a non-None value.
        return key in d and d[key] is not None

    if present("choices"):
        return list
    if present("type"):
        return d["type"]

    is_flag_action = present("action") and d["action"] in ("store_true", "store_false")
    has_bool_const = present("const") and type(d["const"]) == bool
    if is_flag_action or has_bool_const:
        return bool
    return str
155 |
156 |
def options_to_gradio(options, out, overrides=None):
    """Create one Gradio input component per argument definition.

    Component choice:
      * ``list`` type → ``gr.Radio``;
      * ``bool`` type → ``gr.Checkbox``;
      * anything else → ``gr.Textbox``.

    The type is the override's ``type`` if given, else ``get_arg_type``. The
    component's initial value is the argument's ``default`` when set.

    Args:
        options: Mapping whose values are argument definitions, either dicts or
            objects whose ``__dict__`` supplies ``dest``, ``help``, ``default``,
            ``choices``, and so on.
        out: Dict that receives ``dest -> component`` for each new component.
        overrides: Optional ``dest -> {"type": ..., "choices": [...]}`` mapping
            that replaces the inferred type and/or the choices.

    Arguments whose ``dest`` is ``help`` are skipped. When several keys of
    *options* share one ``dest``, only the first produces a component. This
    happens because argparse lists an action once per option string, e.g.
    ``-f`` and ``--foo``. Skipping the rest keeps the page free of duplicate
    inputs with the same ``elem_id``.
    """
    if overrides is None:
        overrides = {}
    seen = set()

    for item in options.values():
        item = item.__dict__ if hasattr(item, "__dict__") else item
        key = item["dest"]
        if key == "help" or key in seen:
            continue
        seen.add(key)
        override = overrides.get(key, {})

        help_text = item.get("help") or ""
        elem_id = f"kohya_sd_webui__{shared.current_tab.replace('.', '_')}_{key}"
        arg_type = override["type"] if "type" in override else get_arg_type(item)
        has_default = check_key(item, "default")

        if arg_type == list:
            raw_choices = override["choices"] if "choices" in override else item["choices"]
            # Radio choices must be displayable, so None is shown as "None".
            choices = [c if c is not None else "None" for c in raw_choices]
            component = gr.Radio(
                choices=choices,
                value=item["default"] if has_default else choices[0],
                label=key,
                elem_id=elem_id,
                interactive=True,
            )
        elif arg_type == bool:
            component = gr.Checkbox(
                value=item["default"] if has_default else False,
                label=key,
                elem_id=elem_id,
                interactive=True,
            )
        else:
            component = gr.Textbox(
                value=item["default"] if has_default else "",
                label=key,
                elem_id=elem_id,
                interactive=True,
            ).style()

        shared.help_title_map[elem_id] = help_text
        out[key] = component
200 |
201 |
def args_to_gradio(args, out, overrides=None):
    """Create Gradio components for every option registered on *args*.

    Args:
        args: An object exposing argparse's private ``_option_string_actions``
            mapping (option string -> action) — presumably an
            ``argparse.ArgumentParser``.
        out: Dict that receives ``dest -> component``.
        overrides: Optional per-``dest`` overrides forwarded to
            ``options_to_gradio``.
    """
    options_to_gradio(
        args._option_string_actions,
        out,
        overrides if overrides is not None else {},
    )
204 |
205 |
def gradio_to_args(arguments, options, args, strarg=False):
    """Convert submitted Gradio values back into typed argument values.

    Args:
        arguments: Mapping ``raw key -> argument definition`` (dicts, or objects
            whose ``__dict__`` is used), e.g. ``"--foo" -> {"dest": "foo", ...}``.
        options: Mapping ``dest -> gradio component``.
        args: Mapping indexed by those components (``args[options[dest]]``)
            that holds each submitted value.
        strarg: Selects the return shape (see below).

    Returns:
        With ``strarg=False``: ``{dest: value}`` for every key of *options*.
        With ``strarg=True``: a ``(main, optional)`` pair. *main* lists the
        non-None values of arguments whose raw key does not start with ``--``.
        *optional* maps ``dest -> value`` for ``--`` options.

    Values are converted with the argument's inferred type. ``None``/``"None"``
    become ``None``, and an empty string means "not set" (``None``). Values of
    ``nargs="*"`` arguments are split on whitespace into lists; whitespace-only
    input counts as "not set".
    """
    # Index definitions by dest once instead of rescanning *arguments* for each
    # key. The first definition for a dest wins.
    by_dest = {}
    for raw_key, definition in arguments.items():
        if hasattr(definition, "__dict__"):
            definition = definition.__dict__
        by_dest.setdefault(definition.get("dest"), (raw_key, definition))

    def get_value(key):
        item = args[options[key]]
        raw_key, arg = by_dest.get(key, (None, None))
        arg_type = get_arg_type(arg)
        multiple = "nargs" in arg and arg["nargs"] == "*"

        def set_type(x):
            if x is None or x == "None":
                return None
            if arg_type is None or arg_type == list:
                return x
            return arg_type(x)

        if item is None or item == "":
            return raw_key, None
        if multiple:
            # split() without an argument ignores repeated and trailing
            # whitespace. Empty pieces would otherwise break typed
            # conversion, e.g. float("").
            values = [set_type(x) for x in item.split()]
            return raw_key, values or None
        return raw_key, set_type(item)

    if strarg:
        main = []
        optional = {}
        for k in options:
            raw_key, value = get_value(k)
            if raw_key.startswith("--"):
                optional[k.replace("--", "")] = value
            else:
                main.append(value)
        return [x for x in main if x is not None], optional

    return {k: get_value(k)[1] for k in options}
257 |
258 |
def make_args(d):
    """Render a ``{dest: value}`` mapping as command-line tokens.

    * ``True`` produces ``--dest``; ``False`` produces nothing.
    * A non-empty list/tuple produces ``--dest v1 v2 ...``, dropping ``None``
      and ``""`` items.
    * ``None``, ``""`` and empty sequences produce nothing, leaving the script
      default in effect.
    * Anything else produces ``--dest str(value)``. Numeric zero is kept, so an
      explicit ``0``/``0.0`` still reaches the script.

    Every returned token is a ``str``, as ``subprocess`` requires.
    """
    tokens = []
    for key, value in d.items():
        flag = f"--{key}"
        if isinstance(value, bool):
            if value:
                tokens.append(flag)
        elif isinstance(value, (list, tuple)):
            items = [str(x) for x in value if x is not None and x != ""]
            if items:
                tokens.extend([flag, *items])
        elif value is not None and value != "":
            tokens.extend([flag, str(value)])
    return tokens
271 |
272 |
def run_python(script, templates, options, args):
    """Launch *script* in a child Python process with arguments from the UI.

    Args:
        script: Path of the script to run. Relative paths resolve against the
            ``kohya_ss`` working directory used for the child.
        templates: Argument definitions (``raw key -> definition``) of *script*.
        options: ``dest -> gradio component`` mapping.
        args: Submitted values indexed by component (see ``gradio_to_args``).

    Returns:
        The ``subprocess.Popen`` handle. The child runs with the interpreter
        unbuffered (``-u``), and stderr is merged into the ``stdout`` pipe.
    """
    main, optional = gradio_to_args(templates, options, args, strarg=True)
    # subprocess needs str tokens. Drop unset (None) and empty values, but keep
    # legitimate falsy values such as 0, which a plain `if x` filter would lose.
    cli_args = [
        str(token)
        for token in [*main, *make_args(optional)]
        if token is not None and token != ""
    ]
    proc_args = [python, "-u", script, *cli_args]
    print("Start process: ", " ".join(proc_args))

    ps = subprocess.Popen(
        proc_args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=os.path.join(ROOT_DIR, "kohya_ss"),
    )
    return ps
286 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
/* Output pane: capped at 80% of the viewport height, scrolls vertically. */
#kohya_sd_webui__terminal_outputs {
    overflow-y: auto;
    height: 80vh;
}

/* Buttons whose id matches "kohya_sd_webui__*_stop_button" are hidden
   (presumably shown again by script.js while a job runs — confirm). */
button[id^="kohya_sd_webui__"][id$="_stop_button"] {
    display: none;
}
--------------------------------------------------------------------------------
/sub.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from time import sleep
3 |
# Dummy long-running script that reports progress through tqdm: one step every
# SLEEP_SECONDS seconds. Presumably used to exercise subprocess output
# streaming in the web UI — confirm.
STEPS = 10
SLEEP_SECONDS = 5

# The context manager closes (and finally renders) the bar deterministically;
# `total` replaces an iterable that was never iterated.
with tqdm(total=STEPS, smoothing=0, desc="steps") as progress_bar:
    for i in range(STEPS):
        sleep(SLEEP_SECONDS)
        progress_bar.update(1)
        progress_bar.set_postfix({"log": f"sleeping {SLEEP_SECONDS} sec {i}"})
9 |
--------------------------------------------------------------------------------
/update.bat:
--------------------------------------------------------------------------------
@echo off
rem Update this checkout to the latest origin/main.
rem WARNING: "git reset --hard" discards local modifications to tracked files,
rem so it only runs after a successful fetch.
git fetch --prune
if errorlevel 1 (
    echo Failed to fetch from remote; local files were left untouched.
    goto :end
)
git reset --hard origin/main
:end
pause
--------------------------------------------------------------------------------
/update.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Update this checkout to the latest origin/main.
# WARNING: `git reset --hard` discards local modifications to tracked files,
# so it only runs after a successful fetch.
git fetch --prune || { echo "git fetch failed; local files were left untouched." >&2; exit 1; }
git reset --hard origin/main
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
@echo off

if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")

set ERROR_REPORTING=FALSE

mkdir tmp 2>NUL

%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :start_venv
echo Couldn't launch python
goto :show_stdout_stderr

:start_venv
rem VENV_DIR set to "-" skips creating/using a virtual environment.
if ["%VENV_DIR%"] == ["-"] goto :skip_venv

dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv

for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
echo Unable to create venv in directory "%VENV_DIR%"
goto :show_stdout_stderr

:activate_venv
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
rem Accept both ACCELERATE=True and ACCELERATE="True" (quotes kept in the value).
if /i "%ACCELERATE%" == "True" goto :accelerate
if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch

:skip_venv

:accelerate
echo "Checking for accelerate"
set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch

:launch
%PYTHON% launch.py %*
pause
exit /b

:accelerate_launch
echo "Accelerating"
rem Forward the command-line arguments, exactly like the plain launch above.
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py %*
pause
exit /b

:show_stdout_stderr

echo.
echo exit code: %errorlevel%

for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
if %size% equ 0 goto :show_stderr
echo.
echo stdout:
type tmp\stdout.txt

:show_stderr
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
rem An empty stderr must jump past this section; jumping to :show_stderr looped forever.
if %size% equ 0 goto :endofscript
echo.
echo stderr:
type tmp\stderr.txt

:endofscript

echo.
echo Launch unsuccessful. Exiting.
pause
75 |
--------------------------------------------------------------------------------
/webui.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Create/activate a Python venv and start the web UI.
# Environment overrides: python_cmd, GIT, venv_dir, LAUNCH_SCRIPT and
# ACCELERATE (set to "True" to start through `accelerate launch`).

# python3 executable
if [[ -z "${python_cmd}" ]]
then
    python_cmd="python3"
fi

# git executable
if [[ -z "${GIT}" ]]
then
    export GIT="git"
fi

# python3 venv without trailing slash
if [[ -z "${venv_dir}" ]]
then
    venv_dir="venv"
fi

if [[ -z "${LAUNCH_SCRIPT}" ]]
then
    LAUNCH_SCRIPT="launch.py"
fi

delimiter="################################################################"

printf "\n%s\n" "${delimiter}"
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
if [[ ! -d "${venv_dir}" ]]
then
    "${python_cmd}" -m venv "${venv_dir}"
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]
then
    source "${venv_dir}"/bin/activate
else
    printf "\n%s\n" "${delimiter}"
    printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
    printf "\n%s\n" "${delimiter}"
    exit 1
fi

# Spaces around "==" matter. Without them the test is a single word, which
# succeeds for any non-empty value (even ACCELERATE=False).
if [[ "${ACCELERATE}" == "True" ]] && [ -x "$(command -v accelerate)" ]
then
    printf "\n%s\n" "${delimiter}"
    printf "Accelerating launch.py..."
    printf "\n%s\n" "${delimiter}"
    exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
    printf "\n%s\n" "${delimiter}"
    printf "Launching launch.py..."
    printf "\n%s\n" "${delimiter}"
    exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi
59 |
--------------------------------------------------------------------------------