├── .gitignore ├── README.md ├── built-in-presets.json ├── install.py ├── kohya-sd-scripts-webui-colab.ipynb ├── launch.py ├── main.py ├── screenshots ├── installation-extension.png └── webui-01.png ├── script.js ├── scripts ├── main.py ├── ngrok.py ├── presets.py ├── runner.py ├── shared.py ├── tabs │ ├── networks │ │ ├── check_lora_weights.py │ │ ├── extract_lora_from_models.py │ │ ├── lora_interrogator.py │ │ ├── merge_lora.py │ │ ├── resize_lora.py │ │ └── svd_merge_lora.py │ ├── preparation │ │ ├── clean_captions_and_tags.py │ │ ├── make_captions.py │ │ ├── make_captions_by_git.py │ │ ├── merge_captions.py │ │ ├── merge_tags.py │ │ ├── prepare_latents.py │ │ └── tag_images_by_wd14tagger.py │ ├── tools │ │ ├── convert_diffusers.py │ │ ├── detect_face_rotate.py │ │ └── resize_images_to_resolution.py │ └── training │ │ ├── fine_tune.py │ │ ├── train_db.py │ │ ├── train_network.py │ │ └── train_textual_inversion.py ├── ui.py ├── ui_overrides.py └── utilities.py ├── style.css ├── sub.py ├── update.bat ├── update.sh ├── webui.bat └── webui.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | venv 3 | tmp 4 | 5 | kohya_ss 6 | wd14_tagger_model 7 | presets.json 8 | meta.json 9 | presets -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kohya sd-scripts webui 2 | 3 | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=for-the-badge)](https://colab.research.google.com/github/ddPn08/kohya-sd-scripts-webui/blob/main/kohya-sd-scripts-webui-colab.ipynb) 4 | 5 | Gradio wrapper for [sd-scripts](https://github.com/kohya-ss/sd-scripts) by kohya 6 | 7 | It can be used as an extension of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) or can be launched standalone. 
8 | 9 | ![](/screenshots/webui-01.png) 10 | 11 | # Usage 12 | ## As an extension of stable-diffusion-webui 13 | 14 | Go to `Extensions` > `Install from URL`, enter the following URL and press the install button. 15 | 16 | https://github.com/ddpn08/kohya-sd-scripts-webui.git 17 | 18 | ![](/screenshots/installation-extension.png) 19 | 20 | ## Start standalone 21 | 22 | Run `webui.bat` for Windows, `webui.sh` for Linux, MacOS 23 | -------------------------------------------------------------------------------- /built-in-presets.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_network": { 3 | "lora-x512": { 4 | "v2": null, 5 | "v_parameterization": null, 6 | "pretrained_model_name_or_path": null, 7 | "train_data_dir": null, 8 | "shuffle_caption": true, 9 | "caption_extension": ".caption", 10 | "caption_extention": null, 11 | "keep_tokens": null, 12 | "color_aug": null, 13 | "flip_aug": true, 14 | "face_crop_aug_range": null, 15 | "random_crop": null, 16 | "debug_dataset": null, 17 | "resolution": "512", 18 | "cache_latents": null, 19 | "enable_bucket": true, 20 | "min_bucket_reso": 256, 21 | "max_bucket_reso": 1024, 22 | "reg_data_dir": null, 23 | "in_json": null, 24 | "dataset_repeats": 1, 25 | "output_dir": null, 26 | "output_name": null, 27 | "save_precision": null, 28 | "save_every_n_epochs": 5, 29 | "save_n_epoch_ratio": null, 30 | "save_last_n_epochs": null, 31 | "save_last_n_epochs_state": null, 32 | "save_state": null, 33 | "resume": null, 34 | "train_batch_size": 1, 35 | "max_token_length": null, 36 | "use_8bit_adam": true, 37 | "mem_eff_attn": null, 38 | "xformers": true, 39 | "vae": null, 40 | "learning_rate": 0.0001, 41 | "max_train_steps": 1600, 42 | "max_train_epochs": null, 43 | "max_data_loader_n_workers": 8, 44 | "seed": null, 45 | "gradient_checkpointing": true, 46 | "gradient_accumulation_steps": 1, 47 | "mixed_precision": "no", 48 | "full_fp16": null, 49 | "clip_skip": 2, 50 | "logging_dir": 
null, 51 | "log_prefix": null, 52 | "lr_scheduler": "constant", 53 | "lr_warmup_steps": 0, 54 | "prior_loss_weight": 1.0, 55 | "no_metadata": null, 56 | "save_model_as": "safetensors", 57 | "unet_lr": null, 58 | "text_encoder_lr": null, 59 | "network_weights": null, 60 | "network_module": "networks.lora", 61 | "network_dim": 16, 62 | "network_alpha": 1.0, 63 | "network_args": null, 64 | "network_train_unet_only": null, 65 | "network_train_text_encoder_only": null, 66 | "training_comment": null 67 | } 68 | }, 69 | "train_db": { 70 | "db-x512": { 71 | "v2": null, 72 | "v_parameterization": null, 73 | "pretrained_model_name_or_path": null, 74 | "train_data_dir": null, 75 | "shuffle_caption": true, 76 | "caption_extension": ".caption", 77 | "caption_extention": null, 78 | "keep_tokens": null, 79 | "color_aug": null, 80 | "flip_aug": true, 81 | "face_crop_aug_range": null, 82 | "random_crop": null, 83 | "debug_dataset": null, 84 | "resolution": null, 85 | "cache_latents": null, 86 | "enable_bucket": true, 87 | "min_bucket_reso": 256, 88 | "max_bucket_reso": 1024, 89 | "reg_data_dir": null, 90 | "output_dir": null, 91 | "output_name": null, 92 | "save_precision": null, 93 | "save_every_n_epochs": 5, 94 | "save_n_epoch_ratio": null, 95 | "save_last_n_epochs": null, 96 | "save_last_n_epochs_state": null, 97 | "save_state": null, 98 | "resume": null, 99 | "train_batch_size": 1, 100 | "max_token_length": null, 101 | "use_8bit_adam": true, 102 | "mem_eff_attn": null, 103 | "xformers": true, 104 | "vae": null, 105 | "learning_rate": 1e-06, 106 | "max_train_steps": 1600, 107 | "max_train_epochs": null, 108 | "max_data_loader_n_workers": 8, 109 | "seed": null, 110 | "gradient_checkpointing": null, 111 | "gradient_accumulation_steps": 1, 112 | "mixed_precision": "no", 113 | "full_fp16": null, 114 | "clip_skip": 2, 115 | "logging_dir": null, 116 | "log_prefix": null, 117 | "lr_scheduler": "constant", 118 | "lr_warmup_steps": 0, 119 | "prior_loss_weight": 1.0, 120 | 
"save_model_as": "safetensors", 121 | "use_safetensors": null, 122 | "no_token_padding": null, 123 | "stop_text_encoder_training": null 124 | } 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import launch 3 | import platform 4 | import os 5 | import shutil 6 | import site 7 | import glob 8 | import re 9 | 10 | dirname = os.path.dirname(__file__) 11 | repo_dir = os.path.join(dirname, "kohya_ss") 12 | 13 | 14 | def prepare_environment(): 15 | torch_command = os.environ.get( 16 | "TORCH_COMMAND", 17 | "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118", 18 | ) 19 | sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git") 20 | sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main") 21 | requirements_file = os.environ.get("REQS_FILE", "requirements.txt") 22 | 23 | sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install") 24 | sys.argv, disable_strict_version = launch.extract_arg( 25 | sys.argv, "--disable-strict-version" 26 | ) 27 | sys.argv, skip_torch_cuda_test = launch.extract_arg( 28 | sys.argv, "--skip-torch-cuda-test" 29 | ) 30 | sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo") 31 | sys.argv, update = launch.extract_arg(sys.argv, "--update") 32 | sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers") 33 | sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch") 34 | xformers = "--xformers" in sys.argv 35 | ngrok = "--ngrok" in sys.argv 36 | 37 | if skip_install: 38 | return 39 | 40 | 41 | if ( 42 | reinstall_torch 43 | or not launch.is_installed("torch") 44 | or not launch.is_installed("torchvision") 45 | ) and not disable_strict_version: 46 | launch.run( 47 | f'"{launch.python}" -m 
{torch_command}', 48 | "Installing torch and torchvision", 49 | "Couldn't install torch", 50 | ) 51 | 52 | if not skip_torch_cuda_test: 53 | launch.run_python( 54 | "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'" 55 | ) 56 | 57 | if (not launch.is_installed("xformers") or reinstall_xformers) and xformers: 58 | launch.run_pip("install xformers --pre", "xformers") 59 | 60 | if update and os.path.exists(repo_dir): 61 | launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune') 62 | launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main') 63 | elif not os.path.exists(repo_dir): 64 | launch.run( 65 | f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"' 66 | ) 67 | 68 | if not skip_checkout_repo: 69 | launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}') 70 | 71 | if not launch.is_installed("gradio"): 72 | launch.run_pip("install gradio==3.16.2", "gradio") 73 | 74 | if not launch.is_installed("pyngrok") and ngrok: 75 | launch.run_pip("install pyngrok", "ngrok") 76 | 77 | if platform.system() == "Linux": 78 | if not launch.is_installed("triton"): 79 | launch.run_pip("install triton", "triton") 80 | 81 | if disable_strict_version: 82 | with open(os.path.join(repo_dir, requirements_file), "r") as f: 83 | txt = f.read() 84 | requirements = [ 85 | re.split("==|<|>", a)[0] 86 | for a in txt.split("\n") 87 | if (not a.startswith("#") and a != ".") 88 | ] 89 | requirements = " ".join(requirements) 90 | launch.run_pip( 91 | f'install "{requirements}" "{repo_dir}"', 92 | "requirements for kohya sd-scripts", 93 | ) 94 | else: 95 | launch.run( 96 | f'cd "{repo_dir}" && "{launch.python}" -m pip install -r requirements.txt', 97 | desc=f"Installing requirements for kohya sd-scripts", 98 | errdesc=f"Couldn't install requirements for kohya sd-scripts", 99 | ) 100 | 101 | if platform.system() == "Windows": 102 | for file in 
glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")): 103 | filename = os.path.basename(file) 104 | for dir in site.getsitepackages(): 105 | outfile = ( 106 | os.path.join(dir, "bitsandbytes", "cuda_setup", filename) 107 | if filename == "main.py" 108 | else os.path.join(dir, "bitsandbytes", filename) 109 | ) 110 | if not os.path.exists(os.path.dirname(outfile)): 111 | continue 112 | shutil.copy(file, outfile) 113 | 114 | 115 | if __name__ == "__main__": 116 | prepare_environment() 117 | -------------------------------------------------------------------------------- /kohya-sd-scripts-webui-colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "zSM6HuYmkYCt" 17 | }, 18 | "source": [ 19 | "# [kohya sd-scripts webui](https://github.com/ddPn08/kohya-sd-scripts-webui)\n", 20 | "\n", 21 | "This notebook is for running [sd-scripts](https://github.com/kohya-ss/sd-scripts) by [Kohya](https://github.com/kohya-ss).\n", 22 | "\n", 23 | "このノートブックは[Kohya](https://github.com/kohya-ss)さんによる[sd-scripts](https://github.com/kohya-ss/sd-scripts)を実行するためのものです。\n", 24 | "\n", 25 | "# Repository\n", 26 | "[kohya_ss/sd-scripts](https://github.com/kohya-ss/sd-scripts)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": { 33 | "id": "zXcznGdeyb2I" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "! nvidia-smi\n", 38 | "! nvcc -V\n", 39 | "! 
free -h" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "cellView": "form", 47 | "id": "tj65Tb_oyxtP" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "# @markdown # Mount Google Drive\n", 52 | "mount_gdrive = True # @param {type:\"boolean\"}\n", 53 | "gdrive_preset_path = \"/content/drive/MyDrive/AI/kohya-sd-scripts-webui/presets\" # @param {type:\"string\"}\n", 54 | "\n", 55 | "if mount_gdrive:\n", 56 | " from google.colab import drive\n", 57 | " drive.mount('/content/drive', force_remount=False)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "cellView": "form", 65 | "id": "FN7UJvSdzBFF" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# @markdown # Initialize environment\n", 70 | "\n", 71 | "! git clone https://github.com/ddPn08/kohya-sd-scripts-webui.git\n", 72 | "\n", 73 | "import os\n", 74 | "\n", 75 | "if not os.path.exists(gdrive_preset_path):\n", 76 | " os.makedirs(gdrive_preset_path, exist_ok=True)\n", 77 | "\n", 78 | "! rm -f ./kohya-sd-scripts-webui/presets.json\n", 79 | "! ln -s {gdrive_preset_path} ./kohya-sd-scripts-webui/presets\n", 80 | "\n", 81 | "conda_dir = \"/opt/conda\" # @param{type:\"string\"}\n", 82 | "conda_bin = os.path.join(conda_dir, \"bin\", \"conda\")\n", 83 | "\n", 84 | "if not os.path.exists(conda_bin):\n", 85 | " ! curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n", 86 | " ! chmod +x Miniconda3-latest-Linux-x86_64.sh\n", 87 | " ! bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p {conda_dir}\n", 88 | " ! rm Miniconda3-latest-Linux-x86_64.sh\n", 89 | "\n", 90 | "def run_script(s):\n", 91 | " ! 
{s}\n", 92 | "\n", 93 | "def make_args(d):\n", 94 | " arguments = \"\"\n", 95 | " for k, v in d.items():\n", 96 | " if type(v) == bool:\n", 97 | " arguments += f\"--{k} \" if v else \"\"\n", 98 | " elif type(v) == str and v:\n", 99 | " arguments += f\"--{k} \\\"{v}\\\" \"\n", 100 | " elif v:\n", 101 | " arguments += f\"--{k}={v} \"\n", 102 | " return arguments" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "cellView": "form", 110 | "id": "uetu1lShs6aJ" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# @markdown # Run\n", 115 | "\n", 116 | "# @markdown
\n", 117 | "\n", 118 | "# @markdown ## Optional | Ngrok Tunnel\n", 119 | "# @markdown Get token from [here](https://dashboard.ngrok.com/get-started/your-authtoken)\n", 120 | "\n", 121 | "ngrok_token = \"\" # @param {type:\"string\"}\n", 122 | "ngrok_region = \"us\" # @param [\"us\", \"eu\", \"au\", \"ap\", \"sa\", \"jp\", \"in\"]\n", 123 | "\n", 124 | "arguments = {\n", 125 | " \"ngrok\": ngrok_token,\n", 126 | " \"ngrok-region\": ngrok_region,\n", 127 | " \"share\": ngrok_token is None,\n", 128 | " \"xformers\": True,\n", 129 | " \"enable-console-log\": True\n", 130 | "}\n", 131 | "\n", 132 | "run_script(f\"\"\"\n", 133 | "eval \"$({conda_bin} shell.bash hook)\"\n", 134 | "cd kohya-sd-scripts-webui\n", 135 | "python launch.py {make_args(arguments)}\n", 136 | "\"\"\")" 137 | ] 138 | } 139 | ], 140 | "metadata": { 141 | "accelerator": "GPU", 142 | "colab": { 143 | "include_colab_link": true, 144 | "provenance": [] 145 | }, 146 | "gpuClass": "standard", 147 | "kernelspec": { 148 | "display_name": "Python 3", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "name": "python" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 0 157 | } 158 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import install 2 | import subprocess 3 | import os 4 | import sys 5 | import importlib.util 6 | 7 | python = sys.executable 8 | git = os.environ.get("GIT", "git") 9 | index_url = os.environ.get("INDEX_URL", "") 10 | skip_install = False 11 | 12 | 13 | def run(command, desc=None, errdesc=None, custom_env=None): 14 | if desc is not None: 15 | print(desc) 16 | 17 | result = subprocess.run( 18 | command, 19 | stdout=subprocess.PIPE, 20 | stderr=subprocess.PIPE, 21 | shell=True, 22 | env=os.environ if custom_env is None else custom_env, 23 | ) 24 | 25 | if result.returncode != 0: 26 | 27 | message = f"""{errdesc or 'Error running 
command'}. 28 | Command: {command} 29 | Error code: {result.returncode} 30 | stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''} 31 | stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''} 32 | """ 33 | raise RuntimeError(message) 34 | 35 | return result.stdout.decode(encoding="utf8", errors="ignore") 36 | 37 | 38 | def check_run(command): 39 | result = subprocess.run( 40 | command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True 41 | ) 42 | return result.returncode == 0 43 | 44 | 45 | def is_installed(package): 46 | try: 47 | spec = importlib.util.find_spec(package) 48 | except ModuleNotFoundError: 49 | return False 50 | 51 | return spec is not None 52 | 53 | 54 | def run_pip(args, desc=None): 55 | if skip_install: 56 | return 57 | 58 | index_url_line = f" --index-url {index_url}" if index_url != "" else "" 59 | return run( 60 | f'"{python}" -m pip {args} --prefer-binary{index_url_line}', 61 | desc=f"Installing {desc}", 62 | errdesc=f"Couldn't install {desc}", 63 | ) 64 | 65 | 66 | def run_python(code, desc=None, errdesc=None): 67 | return run(f'"{python}" -c "{code}"', desc, errdesc) 68 | 69 | 70 | def extract_arg(args, name): 71 | return [x for x in args if x != name], name in args 72 | 73 | 74 | if __name__ == "__main__": 75 | install.prepare_environment() 76 | 77 | from scripts import main 78 | 79 | main.launch() 80 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | import subprocess 4 | 5 | ps = subprocess.Popen( 6 | [sys.executable, "-u", "./sub.py"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT 7 | ) 8 | 9 | reader = io.TextIOWrapper(ps.stdout, encoding='utf8') 10 | while ps.poll() is None: 11 | char = reader.read(1) 12 | if char == '\n': 13 | print('break') 14 | sys.stdout.write(char) 15 | 
-------------------------------------------------------------------------------- /screenshots/installation-extension.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ddPn08/kohya-sd-scripts-webui/a36f7a5d65576ddc7f37925076d6a681b6e5db4d/screenshots/installation-extension.png -------------------------------------------------------------------------------- /screenshots/webui-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ddPn08/kohya-sd-scripts-webui/a36f7a5d65576ddc7f37925076d6a681b6e5db4d/screenshots/webui-01.png -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | function gradioApp() { 2 | const elems = document.getElementsByTagName('gradio-app') 3 | const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot 4 | return !!gradioShadowRoot ? gradioShadowRoot : document; 5 | } 6 | 7 | let executed = false 8 | 9 | /** @type {(() => void)[]} */ 10 | 11 | /** 12 | * @param {string} tab 13 | * @param {boolean} show 14 | */ 15 | function kohya_sd_webui__toggle_runner_button(tab, show) { 16 | gradioApp().getElementById(`kohya_sd_webui__${tab}_run_button`).style.display = show ? 'block' : 'none' 17 | gradioApp().getElementById(`kohya_sd_webui__${tab}_stop_button`).style.display = show ? 
'none' : 'block' 18 | } 19 | 20 | window.addEventListener('DOMContentLoaded', () => { 21 | const observer = new MutationObserver((m) => { 22 | if (!executed && gradioApp().querySelector('#kohya_sd_webui__root')) { 23 | executed = true; 24 | 25 | /** @type {Record} */ 26 | const helps = kohya_sd_webui__help_map 27 | /** @type {string[]} */ 28 | const all_tabs = kohya_sd_webui__all_tabs 29 | 30 | const initializeTerminalObserver = () => { 31 | const container = gradioApp().querySelector("#kohya_sd_webui__terminal_outputs") 32 | const parentContainer = container.parentElement 33 | const clearBtn = document.createElement('button') 34 | clearBtn.innerText = 'Clear The Terminal' 35 | clearBtn.style.color = 'yellow'; 36 | parentContainer.insertBefore(clearBtn, container) 37 | let clearTerminal = false; 38 | clearBtn.addEventListener('click', () => { 39 | container.innerHTML = '' 40 | clearTerminal = true 41 | }) 42 | setInterval(async () => { 43 | const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/terminal/outputs', { 44 | method: "POST", 45 | headers: { 'Content-Type': 'application/json' }, 46 | body: JSON.stringify({ 47 | output_index: container.children.length, 48 | clear_terminal: clearTerminal, 49 | }), 50 | }) 51 | clearTerminal = false 52 | const obj = await res.json() 53 | const isBottom = container.scrollHeight - container.scrollTop === container.clientHeight 54 | for(const line of obj.outputs){ 55 | const el = document.createElement('div') 56 | el.innerText = line 57 | container.appendChild(el) 58 | } 59 | if(isBottom) container.scrollTop = container.scrollHeight 60 | }, 1000) 61 | } 62 | 63 | const checkProcessIsAlive = () => { 64 | setInterval(async () => { 65 | const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/process/alive') 66 | const obj = await res.json() 67 | for (const tab of all_tabs) 68 | kohya_sd_webui__toggle_runner_button(tab, !obj.alive) 69 | 70 | }, 1000) 71 | } 72 | 73 | initializeTerminalObserver() 74 | 
checkProcessIsAlive() 75 | 76 | for (const tab of all_tabs) 77 | gradioApp().querySelector(`#kohya_sd_webui__${tab}_run_button`).addEventListener('click', () => kohya_sd_webui__toggle_runner_button(tab, false)) 78 | 79 | for (const [k, v] of Object.entries(helps)) { 80 | el = gradioApp().getElementById(k) 81 | if (!el) continue 82 | el.title = v 83 | } 84 | } 85 | }) 86 | observer.observe(gradioApp(), { childList: true, subtree: true }) 87 | }) -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import gradio.routes 6 | 7 | import scripts.runner as runner 8 | import scripts.shared as shared 9 | from scripts.shared import ROOT_DIR, is_webui_extension 10 | from scripts.ui import create_ui 11 | 12 | 13 | def create_js(): 14 | jsfile = os.path.join(ROOT_DIR, "script.js") 15 | with open(jsfile, mode="r") as f: 16 | js = f.read() 17 | 18 | js = js.replace("kohya_sd_webui__help_map", json.dumps(shared.help_title_map)) 19 | js = js.replace( 20 | "kohya_sd_webui__all_tabs", 21 | json.dumps(shared.loaded_tabs), 22 | ) 23 | return js 24 | 25 | 26 | def create_head(): 27 | head = f'' 28 | 29 | def template_response_for_webui(*args, **kwargs): 30 | res = shared.gradio_template_response_original(*args, **kwargs) 31 | res.body = res.body.replace(b"", f"{head}".encode("utf8")) 32 | return res 33 | 34 | def template_response(*args, **kwargs): 35 | res = template_response_for_webui(*args, **kwargs) 36 | res.init_headers() 37 | return res 38 | 39 | if is_webui_extension(): 40 | import modules.shared 41 | 42 | modules.shared.GradioTemplateResponseOriginal = template_response_for_webui 43 | else: 44 | gradio.routes.templates.TemplateResponse = template_response 45 | 46 | 47 | def wait_on_server(): 48 | while 1: 49 | time.sleep(0.5) 50 | 51 | 52 | def on_ui_tabs(): 53 | cssfile = os.path.join(ROOT_DIR, 
"style.css") 54 | with open(cssfile, mode="r") as f: 55 | css = f.read() 56 | sd_scripts = create_ui(css) 57 | create_head() 58 | return [(sd_scripts, "Kohya sd-scripts", "kohya_sd_scripts")] 59 | 60 | 61 | def launch(): 62 | block, _, _ = on_ui_tabs()[0] 63 | if shared.cmd_opts.ngrok is not None: 64 | import scripts.ngrok as ngrok 65 | 66 | address = ngrok.connect( 67 | shared.cmd_opts.ngrok, 68 | shared.cmd_opts.port if shared.cmd_opts.port is not None else 7860, 69 | shared.cmd_opts.ngrok_region, 70 | ) 71 | print("Running on ngrok URL: " + address) 72 | 73 | app, local_url, share_url = block.launch( 74 | share=shared.cmd_opts.share, 75 | server_port=shared.cmd_opts.port, 76 | server_name=shared.cmd_opts.host, 77 | prevent_thread_lock=True, 78 | ) 79 | 80 | runner.initialize_api(app) 81 | 82 | wait_on_server() 83 | 84 | 85 | if not hasattr(shared, "gradio_template_response_original"): 86 | shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse 87 | 88 | if is_webui_extension(): 89 | from modules import script_callbacks 90 | 91 | def initialize_api(_, app): 92 | runner.initialize_api(app) 93 | 94 | script_callbacks.on_ui_tabs(on_ui_tabs) 95 | script_callbacks.on_app_started(initialize_api) 96 | 97 | if __name__ == "__main__": 98 | launch() 99 | -------------------------------------------------------------------------------- /scripts/ngrok.py: -------------------------------------------------------------------------------- 1 | def connect(token, port, region): 2 | from pyngrok import conf, exception, ngrok 3 | 4 | account = None 5 | if token is None: 6 | token = "None" 7 | else: 8 | if ":" in token: 9 | account = token.split(":")[1] + ":" + token.split(":")[-1] 10 | token = token.split(":")[0] 11 | 12 | config = conf.PyngrokConfig(auth_token=token, region=region) 13 | try: 14 | if account is None: 15 | public_url = ngrok.connect( 16 | port, pyngrok_config=config, bind_tls=True 17 | ).public_url 18 | else: 19 | public_url = 
ngrok.connect( 20 | port, pyngrok_config=config, bind_tls=True, auth=account 21 | ).public_url 22 | except exception.PyngrokNgrokError: 23 | print( 24 | f"Invalid ngrok authtoken, ngrok connection aborted.\n" 25 | f"Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken" 26 | ) 27 | else: 28 | return public_url 29 | -------------------------------------------------------------------------------- /scripts/presets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import os 4 | from pathlib import Path 5 | import toml 6 | from kohya_ss.library import train_util, config_util 7 | 8 | import gradio as gr 9 | 10 | from scripts.shared import ROOT_DIR 11 | from scripts.utilities import gradio_to_args 12 | 13 | PRESET_DIR = os.path.join(ROOT_DIR, "presets") 14 | PRESET_PATH = os.path.join(ROOT_DIR, "presets.json") 15 | 16 | 17 | def get_arg_templates(fn): 18 | parser = argparse.ArgumentParser() 19 | args = [parser] 20 | sig = inspect.signature(fn) 21 | args.extend([True] * (len(sig.parameters) - 1)) 22 | fn(*args) 23 | keys = [ 24 | x.replace("--", "") for x in parser.__dict__["_option_string_actions"].keys() 25 | ] 26 | keys = [x for x in keys if x not in ["help", "-h"]] 27 | return keys, fn.__name__.replace("add_", "") 28 | 29 | 30 | arguments_functions = [ 31 | train_util.add_dataset_arguments, 32 | train_util.add_optimizer_arguments, 33 | train_util.add_sd_models_arguments, 34 | train_util.add_sd_saving_arguments, 35 | train_util.add_training_arguments, 36 | config_util.add_config_arguments, 37 | ] 38 | 39 | arg_templates = [get_arg_templates(x) for x in arguments_functions] 40 | 41 | 42 | def load_presets(): 43 | obj = {} 44 | os.makedirs(PRESET_DIR, exist_ok=True) 45 | preset_names = os.listdir(PRESET_DIR) 46 | for preset_name in preset_names: 47 | preset_path = os.path.join(PRESET_DIR, preset_name) 48 | obj[preset_name] = {} 49 | for key in 
os.listdir(preset_path): 50 | key = key.replace(".toml", "") 51 | obj[preset_name][key] = load_preset(preset_name, key) 52 | return obj 53 | 54 | 55 | def load_preset(key, name): 56 | filepath = os.path.join(PRESET_DIR, key, name + ".toml") 57 | if not os.path.exists(filepath): 58 | return {} 59 | with open(filepath, mode="r") as f: 60 | obj = toml.load(f) 61 | 62 | flatten = {} 63 | for k, v in obj.items(): 64 | if not isinstance(v, dict): 65 | flatten[k] = v 66 | else: 67 | for k2, v2 in v.items(): 68 | flatten[k2] = v2 69 | return flatten 70 | 71 | 72 | def save_preset(key, name, value): 73 | obj = {} 74 | for k, v in value.items(): 75 | if isinstance(v, Path): 76 | v = str(v) 77 | for (template, category) in arg_templates: 78 | if k in template: 79 | if category not in obj: 80 | obj[category] = {} 81 | obj[category][k] = v 82 | break 83 | else: 84 | obj[k] = v 85 | 86 | filepath = os.path.join(PRESET_DIR, key, name + ".toml") 87 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 88 | with open(filepath, mode="w") as f: 89 | toml.dump(obj, f) 90 | 91 | 92 | def delete_preset(key, name): 93 | filepath = os.path.join(PRESET_DIR, key, name + ".toml") 94 | if os.path.exists(filepath): 95 | os.remove(filepath) 96 | 97 | 98 | def create_ui(key, tmpls, opts): 99 | get_templates = lambda: tmpls() if callable(tmpls) else tmpls 100 | get_options = lambda: opts() if callable(opts) else opts 101 | 102 | presets = load_presets() 103 | 104 | if key not in presets: 105 | presets[key] = {} 106 | 107 | with gr.Box(): 108 | with gr.Row(): 109 | with gr.Column() as c: 110 | load_preset_button = gr.Button("Load preset", variant="primary") 111 | delete_preset_button = gr.Button("Delete preset") 112 | with gr.Column() as c: 113 | load_preset_name = gr.Dropdown( 114 | list(presets[key].keys()), show_label=False 115 | ).style(container=False) 116 | reload_presets_button = gr.Button("🔄️") 117 | with gr.Column() as c: 118 | c.scale = 0.5 119 | save_preset_name = gr.Textbox( 120 | 
"", placeholder="Preset name", lines=1, show_label=False 121 | ).style(container=False) 122 | save_preset_button = gr.Button("Save preset", variant="primary") 123 | 124 | def update_dropdown(): 125 | presets = load_presets() 126 | if key not in presets: 127 | presets[key] = {} 128 | return gr.Dropdown.update(choices=list(presets[key].keys())) 129 | 130 | def _save_preset(args): 131 | name = args[save_preset_name] 132 | if not name: 133 | return update_dropdown() 134 | args = gradio_to_args(get_templates(), get_options(), args) 135 | save_preset(key, name, args) 136 | return update_dropdown() 137 | 138 | def _load_preset(args): 139 | name = args[load_preset_name] 140 | if not name: 141 | return update_dropdown() 142 | args = gradio_to_args(get_templates(), get_options(), args) 143 | preset = load_preset(key, name) 144 | result = [] 145 | for k, _ in args.items(): 146 | if k == load_preset_name: 147 | continue 148 | if k not in preset: 149 | result.append(None) 150 | continue 151 | v = preset[k] 152 | if type(v) == list: 153 | v = " ".join(v) 154 | result.append(v) 155 | return result[0] if len(result) == 1 else result 156 | 157 | def _delete_preset(name): 158 | if not name: 159 | return update_dropdown() 160 | delete_preset(key, name) 161 | return update_dropdown() 162 | 163 | def init(): 164 | save_preset_button.click( 165 | _save_preset, 166 | set([save_preset_name, *get_options().values()]), 167 | [load_preset_name], 168 | ) 169 | load_preset_button.click( 170 | _load_preset, 171 | set([load_preset_name, *get_options().values()]), 172 | [*get_options().values()], 173 | ) 174 | delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name]) 175 | reload_presets_button.click( 176 | update_dropdown, inputs=[], outputs=[load_preset_name] 177 | ) 178 | 179 | return init 180 | -------------------------------------------------------------------------------- /scripts/runner.py: 
# ---- /scripts/runner.py ----
import io
import sys

import fastapi
import gradio as gr
from pydantic import BaseModel, Field

import scripts.shared as shared
from scripts.utilities import run_python

# The child process currently launched by any tab (None when idle) and the
# lines it has printed so far. Both are module-level, so every tab and the
# HTTP API share them.
proc = None
outputs = []


def alive():
    """Return True while a process started by the runner is still attached."""
    return proc is not None


def initialize_runner(script_file, tmpls, opts):
    """Create the Run/Stop buttons for a tab and return a deferred binder.

    Args:
        script_file: path of the sd-scripts script to launch.
        tmpls: argument templates, or a zero-argument callable returning them.
        opts: dict of option name -> gradio component, or a zero-argument
            callable returning one. Callables let a tab register components
            after this function has been called.

    Returns:
        ``init()``. Call it once every option component exists; it attaches
        the click handlers to the buttons.
    """
    run_button = gr.Button(
        "Run",
        variant="primary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_run_button",
    )
    stop_button = gr.Button(
        "Stop",
        variant="secondary",
        elem_id=f"kohya_sd_webui__{shared.current_tab}_stop_button",
    )
    get_templates = lambda: tmpls() if callable(tmpls) else tmpls
    get_options = lambda: opts() if callable(opts) else opts

    def run(args):
        """Launch the script and stream its stdout into ``outputs`` line by line."""
        global proc
        if alive():
            return
        child = run_python(script_file, get_templates(), get_options(), args)
        proc = child
        # errors="replace": a single undecodable byte must not wedge the reader.
        reader = io.TextIOWrapper(
            child.stdout, encoding="utf-8-sig", errors="replace"
        )
        line = ""
        try:
            # Read until EOF instead of until poll() reports an exit code;
            # otherwise output still buffered in the pipe when the child exits
            # (typically the final traceback) is lost. ``proc is child`` makes
            # the loop stop as soon as stop() detaches this process.
            while proc is child:
                try:
                    char = reader.read(1)
                except (OSError, ValueError):
                    break  # the pipe is unusable; nothing more can be read
                if not char:
                    break  # EOF: the child closed its stdout
                if shared.cmd_opts.enable_console_log:
                    try:
                        sys.stdout.write(char)
                    except (UnicodeEncodeError, OSError):
                        # Best effort only: the host console may not be able
                        # to render this char, but the web terminal still gets it.
                        pass
                if char == "\n":
                    outputs.append(line)
                    line = ""
                else:
                    line += char
            if line:
                outputs.append(line)  # last line had no trailing newline
        finally:
            reader.close()
            # Only clear the slot if it still holds our child: stop() followed
            # by a new run may already have replaced it.
            if proc is child:
                proc = None

    def stop():
        """Kill the running process, if any. Does nothing when idle."""
        global proc
        child = proc
        if child is None:
            return
        # Detach first so the reader loop in run() exits promptly.
        proc = None
        print("killed the running process")
        child.kill()

    def init():
        run_button.click(
            run,
            set(get_options().values()),
        )
        stop_button.click(stop)

    return init


class GetOutputRequest(BaseModel):
    output_index: int = Field(
        default=0, title="Index of the beginning of the log to retrieve"
    )
    clear_terminal: bool = Field(
        default=False, title="Whether to clear the terminal"
    )


class GetOutputResponse(BaseModel):
    outputs: list = Field(title="List of terminal output")


class ProcessAliveResponse(BaseModel):
    alive: bool = Field(title="Whether the process is running.")


def api_get_outputs(req: GetOutputRequest):
    """Return captured lines from ``req.output_index`` on, optionally clearing first."""
    global outputs
    if req.clear_terminal:
        outputs = []
    # Slicing past the end already yields [], so no length guard is needed.
    return GetOutputResponse(outputs=outputs[req.output_index:])


def api_get_isalive(req: fastapi.Request):
    """Report whether a runner process is currently attached."""
    return ProcessAliveResponse(alive=alive())


def initialize_api(app: fastapi.FastAPI):
    """Register the terminal-output and process-status routes on ``app``."""
    app.add_api_route(
        "/internal/extensions/kohya-sd-scripts-webui/terminal/outputs",
        api_get_outputs,
        methods=["POST"],
        response_model=GetOutputResponse,
    )
    app.add_api_route(
        "/internal/extensions/kohya-sd-scripts-webui/process/alive",
        api_get_isalive,
        methods=["GET"],
        response_model=ProcessAliveResponse,
    )


# ---- /scripts/shared.py ----
import argparse
import importlib
import os
import sys


def is_webui_extension():
    """Return True when the ``webui`` module (stable-diffusion-webui) is importable."""
    try:
        importlib.import_module("webui")
        return True
    except Exception:
        # Deliberate best effort; unlike a bare except this lets
        # KeyboardInterrupt/SystemExit propagate.
        return False


# Repository root: the extension's base dir under the webui, otherwise the
# parent of this scripts/ directory.
ROOT_DIR = (
    importlib.import_module("modules.scripts").basedir()
    if is_webui_extension()
    else os.path.dirname(os.path.dirname(__file__))
)

# Mutable UI-building state shared across tabs.
current_tab = None
loaded_tabs = []
help_title_map = {}

parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--ngrok", type=str, default=None)
parser.add_argument("--ngrok-region", type=str, default="us")
parser.add_argument("--enable-console-log", action="store_true")
# parse_known_args: unrelated flags (e.g. the host webui's own) are ignored.
cmd_opts, _ = parser.parse_known_args(sys.argv)


# ---- /scripts/tabs/networks/check_lora_weights.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Check lora weights"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template("networks", "check_lora_weights.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/networks/extract_lora_from_models.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Extract lora from models"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template(
        "networks", "extract_lora_from_models.py"
    )

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/networks/lora_interrogator.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Lora interrogator"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template("networks", "lora_interrogator.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/networks/merge_lora.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Merge lora"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template("networks", "merge_lora.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/networks/resize_lora.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Resize lora"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template("networks", "resize_lora.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/networks/svd_merge_lora.py ----
import gradio as gr

from scripts import ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Svd merge lora"


def create_ui():
    """Build the tab: runner buttons plus one input per script option."""
    options = {}
    templates, script_file = load_args_template("networks", "svd_merge_lora.py")

    with gr.Column():
        init = initialize_runner(script_file, templates, options)
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(templates, options)

    init()


# ---- /scripts/tabs/preparation/clean_captions_and_tags.py ----
import traceback

import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Tab label shown in the web UI."""
    return "Clean captions and tags"


def create_ui():
    """Build the tab; a failure is printed rather than aborting the whole UI."""
    try:
        options = {}
        templates, script_file = load_args_template(
            "finetune", "clean_captions_and_tags.py"
        )

        with gr.Column():
            init_runner = initialize_runner(script_file, templates, options)
            with gr.Box():
                with gr.Row():
                    init_ui = presets.create_ui(
                        "finetune.clean_captions_and_tags", templates, options
                    )
            with gr.Box():
                ui.title("Options")
                with gr.Column():
                    options_to_gradio(templates, options)

        init_runner()
        init_ui()

    except Exception:
        # Deliberate best effort; no longer swallows KeyboardInterrupt/SystemExit.
        traceback.print_exc()
# ---- /scripts/tabs/preparation/make_captions.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Make captions"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template("finetune", "make_captions.py")

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.make_captions", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/preparation/make_captions_by_git.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Make captions by GIT"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "finetune", "make_captions_by_git.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.make_captions_by_git", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()
# ---- /scripts/tabs/preparation/merge_captions.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Merge captions"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "finetune", "merge_captions_to_metadata.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.merge_captions_to_metadata", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/preparation/merge_tags.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Merge tags"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "finetune", "merge_dd_tags_to_metadata.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.merge_dd_tags_to_metadata", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()
# ---- /scripts/tabs/preparation/prepare_latents.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Prepare latents"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "finetune", "prepare_buckets_latents.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.prepare_buckets_latents", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/preparation/tag_images_by_wd14tagger.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Tag images by wd1.4tagger"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "finetune", "tag_images_by_wd14_tagger.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "finetune.tag_images_by_wd14_tagger", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/tools/convert_diffusers.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Convert Diffusers"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template(
        "tools", "convert_diffusers20_original_sd.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            # NOTE(review): unlike sibling tools this preset key has no "tools."
            # prefix; kept as-is so existing saved presets still load.
            wire_presets = presets.create_ui(
                "convert_diffusers20_original_sd", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/tools/detect_face_rotate.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Detect face rotate"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}
    arg_templates, script_path = load_args_template("tools", "detect_face_rotate.py")

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "tools.detect_face_rotate", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()
# ---- /scripts/tabs/tools/resize_images_to_resolution.py ----
import gradio as gr

from scripts import presets, ui
from scripts.runner import initialize_runner
from scripts.utilities import load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Resize images to resolution"


def create_ui():
    """Lay out runner buttons, preset controls and the script's option inputs."""
    components = {}

    arg_templates, script_path = load_args_template(
        "tools", "resize_images_to_resolution.py"
    )

    with gr.Column():
        wire_runner = initialize_runner(script_path, arg_templates, components)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui(
                "tools.resize_images_to_resolution", arg_templates, components
            )
        with gr.Box():
            ui.title("Options")
            with gr.Column():
                options_to_gradio(arg_templates, components)

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/training/fine_tune.py ----
import argparse

import gradio as gr

from kohya_ss.library import train_util, config_util
from scripts import presets, ui, ui_overrides
from scripts.runner import initialize_runner
from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Fine tune"


def _args_box(heading, arg_parser, components, *overrides):
    """Render a titled box with one gradio input per option of ``arg_parser``."""
    with gr.Box():
        ui.title(heading)
        args_to_gradio(arg_parser, components, *overrides)


def create_ui():
    """Build the fine-tuning tab from sd-scripts' shared argument groups."""
    model_parser = argparse.ArgumentParser()
    dataset_parser = argparse.ArgumentParser()
    training_parser = argparse.ArgumentParser()
    saving_parser = argparse.ArgumentParser()
    optimizer_parser = argparse.ArgumentParser()
    config_parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(model_parser)
    train_util.add_dataset_arguments(dataset_parser, False, True, True)
    train_util.add_training_arguments(training_parser, False)
    train_util.add_sd_saving_arguments(saving_parser)
    train_util.add_optimizer_arguments(optimizer_parser)
    config_util.add_config_arguments(config_parser)
    model_opts, dataset_opts, training_opts = {}, {}, {}
    saving_opts, optimizer_opts, config_opts = {}, {}, {}
    script_opts = {}

    script_templates, script_path = load_args_template("fine_tune.py")

    def get_options():
        # Later groups win on duplicate keys, so the order is significant.
        merged = {}
        for group in (
            model_opts,
            dataset_opts,
            training_opts,
            saving_opts,
            optimizer_opts,
            script_opts,
            config_opts,
        ):
            merged.update(group)
        return merged

    def get_templates():
        # Parser actions first; the script's own templates override them.
        merged = {}
        for arg_parser in (
            model_parser,
            dataset_parser,
            training_parser,
            saving_parser,
            optimizer_parser,
            config_parser,
        ):
            merged.update(vars(arg_parser)["_option_string_actions"])
        merged.update(script_templates)
        return merged

    with gr.Column():
        wire_runner = initialize_runner(script_path, get_templates, get_options)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui("fine_tune", get_templates, get_options)
        with gr.Row():
            with gr.Group():
                with gr.Box():
                    ui.title("Fine tune options")
                    options_to_gradio(script_templates, script_opts)
                _args_box("Model options", model_parser, model_opts)
                _args_box("Dataset options", dataset_parser, dataset_opts)
                _args_box("Dataset Config options", config_parser, config_opts)
                _args_box("Training options", training_parser, training_opts)
            with gr.Group():
                _args_box("Save options", saving_parser, saving_opts)
                _args_box(
                    "Optimizer options",
                    optimizer_parser,
                    optimizer_opts,
                    ui_overrides.OPTIMIZER_OPTIONS,
                )

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/training/train_db.py ----
import argparse

import gradio as gr

from kohya_ss.library import train_util, config_util
from scripts import presets, ui, ui_overrides
from scripts.runner import initialize_runner
from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio


def title():
    """Label used for this tab in the web UI."""
    return "Train dreambooth"


def _args_box(heading, arg_parser, components, *overrides):
    """Render a titled box with one gradio input per option of ``arg_parser``."""
    with gr.Box():
        ui.title(heading)
        args_to_gradio(arg_parser, components, *overrides)


def create_ui():
    """Build the DreamBooth tab from sd-scripts' shared argument groups."""
    model_parser = argparse.ArgumentParser()
    dataset_parser = argparse.ArgumentParser()
    training_parser = argparse.ArgumentParser()
    saving_parser = argparse.ArgumentParser()
    optimizer_parser = argparse.ArgumentParser()
    config_parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(model_parser)
    train_util.add_dataset_arguments(dataset_parser, True, False, True)
    train_util.add_training_arguments(training_parser, True)
    train_util.add_sd_saving_arguments(saving_parser)
    train_util.add_optimizer_arguments(optimizer_parser)
    config_util.add_config_arguments(config_parser)
    model_opts, dataset_opts, training_opts = {}, {}, {}
    saving_opts, optimizer_opts, config_opts = {}, {}, {}
    script_opts = {}

    script_templates, script_path = load_args_template("train_db.py")

    def get_options():
        # Later groups win on duplicate keys, so the order is significant.
        merged = {}
        for group in (
            model_opts,
            dataset_opts,
            training_opts,
            saving_opts,
            optimizer_opts,
            config_opts,
            script_opts,
        ):
            merged.update(group)
        return merged

    def get_templates():
        # Parser actions first; the script's own templates override them.
        merged = {}
        for arg_parser in (
            model_parser,
            dataset_parser,
            training_parser,
            saving_parser,
            optimizer_parser,
            config_parser,
        ):
            merged.update(vars(arg_parser)["_option_string_actions"])
        merged.update(script_templates)
        return merged

    with gr.Column():
        wire_runner = initialize_runner(script_path, get_templates, get_options)
        with gr.Box(), gr.Row():
            wire_presets = presets.create_ui("train_db", get_templates, get_options)
        with gr.Row():
            with gr.Group():
                with gr.Box():
                    ui.title("Dreambooth options")
                    options_to_gradio(script_templates, script_opts)
                _args_box("Model options", model_parser, model_opts)
                _args_box("Dataset options", dataset_parser, dataset_opts)
                _args_box("Dataset Config options", config_parser, config_opts)
                _args_box("Training options", training_parser, training_opts)
            with gr.Group():
                _args_box("Save options", saving_parser, saving_opts)
                _args_box(
                    "Optimizer options",
                    optimizer_parser,
                    optimizer_opts,
                    ui_overrides.OPTIMIZER_OPTIONS,
                )

    # Handlers are attached only after every option component exists.
    wire_runner()
    wire_presets()


# ---- /scripts/tabs/training/train_network.py ----
import argparse

import gradio as gr
4 | 5 | from kohya_ss.library import train_util, config_util 6 | from scripts import presets, ui, ui_overrides 7 | from scripts.runner import initialize_runner 8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio 9 | 10 | 11 | def title(): 12 | return "Train network" 13 | 14 | 15 | def create_ui(): 16 | sd_models_arguments = argparse.ArgumentParser() 17 | dataset_arguments = argparse.ArgumentParser() 18 | training_arguments = argparse.ArgumentParser() 19 | optimizer_arguments = argparse.ArgumentParser() 20 | config_arguments = argparse.ArgumentParser() 21 | train_util.add_sd_models_arguments(sd_models_arguments) 22 | train_util.add_dataset_arguments(dataset_arguments, True, True, True) 23 | train_util.add_training_arguments(training_arguments, True) 24 | train_util.add_optimizer_arguments(optimizer_arguments) 25 | config_util.add_config_arguments(config_arguments) 26 | sd_models_options = {} 27 | dataset_options = {} 28 | training_options = {} 29 | optimizer_options = {} 30 | config_options = {} 31 | network_options = {} 32 | 33 | templates, script_file = load_args_template("train_network.py") 34 | 35 | get_options = lambda: { 36 | **sd_models_options, 37 | **dataset_options, 38 | **training_options, 39 | **optimizer_options, 40 | **config_options, 41 | **network_options, 42 | } 43 | 44 | get_templates = lambda: { 45 | **sd_models_arguments.__dict__["_option_string_actions"], 46 | **dataset_arguments.__dict__["_option_string_actions"], 47 | **training_arguments.__dict__["_option_string_actions"], 48 | **optimizer_arguments.__dict__["_option_string_actions"], 49 | **config_arguments.__dict__["_option_string_actions"], 50 | **templates, 51 | } 52 | 53 | with gr.Column(): 54 | init_runner = initialize_runner(script_file, get_templates, get_options) 55 | with gr.Box(): 56 | with gr.Row(): 57 | init_id = presets.create_ui("train_network", get_templates, get_options) 58 | with gr.Row(): 59 | with gr.Group(): 60 | with gr.Box(): 61 | 
ui.title("Network options") 62 | options_to_gradio(templates, network_options) 63 | with gr.Box(): 64 | ui.title("Model options") 65 | args_to_gradio(sd_models_arguments, sd_models_options) 66 | with gr.Box(): 67 | ui.title("Dataset Config options") 68 | args_to_gradio(config_arguments, config_options) 69 | with gr.Box(): 70 | ui.title("Dataset options") 71 | args_to_gradio(dataset_arguments, dataset_options) 72 | with gr.Box(): 73 | ui.title("Training options") 74 | args_to_gradio(training_arguments, training_options) 75 | with gr.Box(): 76 | ui.title("Optimizer options") 77 | args_to_gradio( 78 | optimizer_arguments, 79 | optimizer_options, 80 | ui_overrides.OPTIMIZER_OPTIONS, 81 | ) 82 | 83 | init_runner() 84 | init_id() 85 | -------------------------------------------------------------------------------- /scripts/tabs/training/train_textual_inversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gradio as gr 4 | 5 | from kohya_ss.library import train_util, config_util 6 | from scripts import presets, ui, ui_overrides 7 | from scripts.runner import initialize_runner 8 | from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio 9 | 10 | 11 | def title(): 12 | return "Train textual inversion" 13 | 14 | 15 | def create_ui(): 16 | sd_models_arguments = argparse.ArgumentParser() 17 | dataset_arguments = argparse.ArgumentParser() 18 | training_arguments = argparse.ArgumentParser() 19 | optimizer_arguments = argparse.ArgumentParser() 20 | config_arguments = argparse.ArgumentParser() 21 | train_util.add_sd_models_arguments(sd_models_arguments) 22 | train_util.add_dataset_arguments(dataset_arguments, True, True, False) 23 | train_util.add_training_arguments(training_arguments, True) 24 | train_util.add_optimizer_arguments(optimizer_arguments) 25 | config_util.add_config_arguments(config_arguments) 26 | sd_models_options = {} 27 | dataset_options = {} 28 | training_options 
= {} 29 | optimizer_options = {} 30 | config_options = {} 31 | ti_options = {} 32 | 33 | templates, script_file = load_args_template("train_textual_inversion.py") 34 | 35 | get_options = lambda: { 36 | **sd_models_options, 37 | **dataset_options, 38 | **training_options, 39 | **optimizer_options, 40 | **config_options, 41 | **ti_options, 42 | } 43 | 44 | get_templates = lambda: { 45 | **sd_models_arguments.__dict__["_option_string_actions"], 46 | **dataset_arguments.__dict__["_option_string_actions"], 47 | **training_arguments.__dict__["_option_string_actions"], 48 | **optimizer_arguments.__dict__["_option_string_actions"], 49 | **config_arguments.__dict__["_option_string_actions"], 50 | **templates, 51 | } 52 | 53 | with gr.Column(): 54 | init_runner = initialize_runner(script_file, get_templates, get_options) 55 | with gr.Box(): 56 | with gr.Row(): 57 | init_ui = presets.create_ui( 58 | "train_textual_inversion", get_templates, get_options 59 | ) 60 | with gr.Row(): 61 | with gr.Group(): 62 | with gr.Box(): 63 | ui.title("Textual inversion options") 64 | options_to_gradio(templates, ti_options) 65 | with gr.Box(): 66 | ui.title("Model options") 67 | args_to_gradio(sd_models_arguments, sd_models_options) 68 | with gr.Box(): 69 | ui.title("Dataset Config options") 70 | args_to_gradio(config_arguments, config_options) 71 | with gr.Box(): 72 | ui.title("Dataset options") 73 | args_to_gradio(dataset_arguments, dataset_options) 74 | with gr.Box(): 75 | ui.title("Training options") 76 | args_to_gradio(training_arguments, training_options) 77 | with gr.Box(): 78 | ui.title("Optimizer options") 79 | args_to_gradio( 80 | optimizer_arguments, 81 | optimizer_options, 82 | ui_overrides.OPTIMIZER_OPTIONS, 83 | ) 84 | 85 | init_runner() 86 | init_ui() 87 | -------------------------------------------------------------------------------- /scripts/ui.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | import 
os 4 | import sys 5 | 6 | import gradio as gr 7 | 8 | import scripts.shared as shared 9 | from scripts.shared import ROOT_DIR 10 | from scripts.utilities import path_to_module 11 | 12 | 13 | def title(txt): 14 | gr.HTML( 15 | f'

{txt}

', 16 | ) 17 | 18 | 19 | def create_ui(css): 20 | PATHS = [ 21 | os.path.join(ROOT_DIR, "kohya_ss", "library"), 22 | ROOT_DIR, 23 | ] 24 | sys.path.extend(PATHS) 25 | with gr.Blocks(css=css, analytics_enabled=False) as ui: 26 | with gr.Tabs(elem_id="kohya_sd_webui__root"): 27 | tabs_dir = os.path.join(ROOT_DIR, "scripts", "tabs") 28 | for category in os.listdir(tabs_dir): 29 | dir = os.path.join(tabs_dir, category) 30 | tabs = glob.glob(os.path.join(dir, "*.py")) 31 | sys.path.append(dir) 32 | if len(tabs) < 1: 33 | continue 34 | with gr.TabItem(category): 35 | for lib in tabs: 36 | try: 37 | module_path = path_to_module(lib) 38 | module_name = module_path.replace(".", "_") 39 | 40 | module = importlib.import_module(module_path) 41 | shared.current_tab = module_name 42 | shared.loaded_tabs.append(module_name) 43 | 44 | with gr.TabItem(module.title()): 45 | module.create_ui() 46 | except Exception as e: 47 | print(f"Failed to load {module_path}") 48 | print(e) 49 | sys.path.remove(dir) 50 | with gr.TabItem("terminal"): 51 | gr.HTML('
') 52 | sys.path = [x for x in sys.path if x not in PATHS] 53 | return ui 54 | -------------------------------------------------------------------------------- /scripts/ui_overrides.py: -------------------------------------------------------------------------------- 1 | OPTIMIZER_OPTIONS = { 2 | "optimizer_type": { 3 | "type": list, 4 | "choices": [ 5 | "AdamW", 6 | "AdamW8bit", 7 | "Lion", 8 | "SGDNesterov", 9 | "SGDNesterov8bit", 10 | "DAdaptation", 11 | "AdaFactor", 12 | ], 13 | }, 14 | "lr_scheduler": { 15 | "type": list, 16 | "choices": [ 17 | "linear", 18 | "cosine", 19 | "cosine_with_restarts", 20 | "polynomial", 21 | "constant", 22 | "constant_with_warmup", 23 | "adafactor", 24 | ], 25 | }, 26 | } 27 | -------------------------------------------------------------------------------- /scripts/utilities.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import importlib 3 | import os 4 | import subprocess 5 | import sys 6 | 7 | import gradio as gr 8 | 9 | import scripts.shared as shared 10 | from scripts.shared import ROOT_DIR 11 | 12 | python = sys.executable 13 | 14 | 15 | def path_to_module(filepath): 16 | return ( 17 | os.path.relpath(filepath, ROOT_DIR).replace(os.path.sep, ".").replace(".py", "") 18 | ) 19 | 20 | 21 | def which(program): 22 | def is_exe(fpath): 23 | return os.path.isfile(fpath) and os.access(fpath, os.X_OK) 24 | 25 | fpath, _ = os.path.split(program) 26 | if fpath: 27 | if is_exe(program): 28 | return program 29 | else: 30 | for path in os.environ["PATH"].split(os.pathsep): 31 | path = path.strip('"') 32 | exe_file = os.path.join(path, program) 33 | if is_exe(exe_file): 34 | return exe_file 35 | 36 | return None 37 | 38 | 39 | def literal_eval(v, module=None): 40 | if v == "str": 41 | return str 42 | elif v == "int": 43 | return int 44 | elif v == "float": 45 | return float 46 | elif v == list: 47 | return list 48 | else: 49 | if module: 50 | try: 51 | m = 
importlib.import_module(module) 52 | if hasattr(m, v): 53 | return getattr(m, v) 54 | except: 55 | () 56 | 57 | return ast.literal_eval(v) 58 | 59 | 60 | def compile_arg_parser(txt, module_path=None): 61 | in_parser = False 62 | parsers = {} 63 | args = [] 64 | arg = "" 65 | in_list = False 66 | in_str = None 67 | 68 | def compile(arg): 69 | arg = arg.strip() 70 | matches = arg.split("=") 71 | 72 | if len(matches) > 1: 73 | k = "".join(matches[:1]) 74 | v = literal_eval("".join(matches[1:]), module_path) 75 | return (k, v) 76 | else: 77 | return literal_eval(arg, module_path) 78 | 79 | for line in txt.split("\n"): 80 | line = line.split("#")[0] 81 | 82 | if "parser.add_argument(" in line: 83 | in_parser = True 84 | line = line.replace("parser.add_argument(", "") 85 | 86 | if not in_parser: 87 | continue 88 | 89 | for c in line: 90 | 91 | if in_str is None and c == ")": 92 | if arg.strip(): 93 | args.append(compile(arg)) 94 | in_parser = False 95 | [dest, *others] = args 96 | parsers[dest] = {"dest": dest.replace("--", ""), **dict(others)} 97 | arg = "" 98 | args = [] 99 | break 100 | 101 | if c == "[": 102 | in_list = True 103 | elif c == "]": 104 | in_list = False 105 | if c == '"' or c == "'": 106 | if in_str is not None and in_str == c: 107 | in_str = None 108 | elif in_str is None: 109 | in_str = c 110 | 111 | if c == "," and not in_list and in_str is None: 112 | args.append(compile(arg)) 113 | arg = "" 114 | continue 115 | 116 | arg += c 117 | 118 | if arg.strip(): 119 | args.append(compile(arg)) 120 | return parsers 121 | 122 | 123 | def load_args_template(*filename): 124 | repo_dir = os.path.join(ROOT_DIR, "kohya_ss") 125 | filepath = os.path.join(repo_dir, *filename) 126 | with open(filepath, mode="r", encoding="utf-8_sig") as f: 127 | lines = f.readlines() 128 | add = False 129 | txt = "" 130 | for line in lines: 131 | if add == True: 132 | txt += line 133 | if "def setup_parser()" in line: 134 | add = True 135 | continue 136 | return 
compile_arg_parser(txt, path_to_module(filepath)), filepath 137 | 138 | 139 | def check_key(d, k): 140 | return k in d and d[k] is not None 141 | 142 | 143 | def get_arg_type(d): 144 | if check_key(d, "choices"): 145 | return list 146 | if check_key(d, "type"): 147 | return d["type"] 148 | if check_key(d, "action") and ( 149 | d["action"] == "store_true" or d["action"] == "store_false" 150 | ): 151 | return bool 152 | if check_key(d, "const") and type(d["const"]) == bool: 153 | return bool 154 | return str 155 | 156 | 157 | def options_to_gradio(options, out, overrides={}): 158 | for _, item in options.items(): 159 | item = item.__dict__ if hasattr(item, "__dict__") else item 160 | key = item["dest"] 161 | if key == "help": 162 | continue 163 | override = overrides[key] if key in overrides else {} 164 | component = None 165 | 166 | help = item["help"] if "help" in item else "" 167 | id = f"kohya_sd_webui__{shared.current_tab.replace('.', '_')}_{key}" 168 | type = override["type"] if "type" in override else get_arg_type(item) 169 | if type == list: 170 | choices = [ 171 | c if c is not None else "None" 172 | for c in ( 173 | override["choices"] if "choices" in override else item["choices"] 174 | ) 175 | ] 176 | component = gr.Radio( 177 | choices=choices, 178 | value=item["default"] if check_key(item, "default") else choices[0], 179 | label=key, 180 | elem_id=id, 181 | interactive=True, 182 | ) 183 | elif type == bool: 184 | component = gr.Checkbox( 185 | value=item["default"] if check_key(item, "default") else False, 186 | label=key, 187 | elem_id=id, 188 | interactive=True, 189 | ) 190 | else: 191 | component = gr.Textbox( 192 | value=item["default"] if check_key(item, "default") else "", 193 | label=key, 194 | elem_id=id, 195 | interactive=True, 196 | ).style() 197 | 198 | shared.help_title_map[id] = help 199 | out[key] = component 200 | 201 | 202 | def args_to_gradio(args, out, overrides={}): 203 | options_to_gradio(args.__dict__["_option_string_actions"], out, 
overrides) 204 | 205 | 206 | def gradio_to_args(arguments, options, args, strarg=False): 207 | def find_arg(key): 208 | for k, arg in arguments.items(): 209 | arg = arg.__dict__ if hasattr(arg, "__dict__") else arg 210 | if arg["dest"] == key: 211 | return k, arg 212 | return None, None 213 | 214 | def get_value(key): 215 | item = args[options[key]] 216 | raw_key, arg = find_arg(key) 217 | arg_type = get_arg_type(arg) 218 | multiple = "nargs" in arg and arg["nargs"] == "*" 219 | 220 | def set_type(x): 221 | if x is None or x == "None": 222 | return None 223 | elif arg_type is None: 224 | return x 225 | elif arg_type == list: 226 | return x 227 | return arg_type(x) 228 | 229 | if multiple and item is None or item == "": 230 | return raw_key, None 231 | 232 | return raw_key, ( 233 | [set_type(x) for x in item.split(" ")] if multiple else set_type(item) 234 | ) 235 | 236 | if strarg: 237 | main = [] 238 | optional = {} 239 | 240 | for k in options: 241 | key, v = get_value(k) 242 | if key.startswith("--"): 243 | key = k.replace("--", "") 244 | optional[key] = v 245 | else: 246 | main.append(v) 247 | 248 | main = [x for x in main if x is not None] 249 | 250 | return main, optional 251 | else: 252 | result = {} 253 | for k in options: 254 | _, v = get_value(k) 255 | result[k] = v 256 | return result 257 | 258 | 259 | def make_args(d): 260 | arguments = [] 261 | for k, v in d.items(): 262 | if type(v) == bool: 263 | arguments.append(f"--{k}" if v else "") 264 | elif type(v) == list and len(v) > 0: 265 | arguments.extend([f"--{k}", *v]) 266 | elif type(v) == str and v: 267 | arguments.extend([f"--{k}", f"{v}"]) 268 | elif v: 269 | arguments.extend([f"--{k}", f"{v}"]) 270 | return arguments 271 | 272 | 273 | def run_python(script, templates, options, args): 274 | main, optional = gradio_to_args(templates, options, args, strarg=True) 275 | args = [x for x in [*main, *make_args(optional)] if x] 276 | proc_args = [python, "-u", script, *args] 277 | print("Start process: ", " 
".join(proc_args)) 278 | 279 | ps = subprocess.Popen( 280 | proc_args, 281 | stdout=subprocess.PIPE, 282 | stderr=subprocess.STDOUT, 283 | cwd=os.path.join(ROOT_DIR, "kohya_ss"), 284 | ) 285 | return ps 286 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | #kohya_sd_webui__terminal_outputs { 2 | height: 80vh; 3 | overflow-y: auto; 4 | } 5 | 6 | button[id^='kohya_sd_webui__'][id$='_stop_button'] { 7 | display: none; 8 | } -------------------------------------------------------------------------------- /sub.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from time import sleep 3 | 4 | progress_bar = tqdm(range(10), smoothing=0, desc="steps") 5 | for i in range(10): 6 | sleep(5) 7 | progress_bar.update(1) 8 | progress_bar.set_postfix({"log": f"sleeping 5 sec {i}"}) 9 | -------------------------------------------------------------------------------- /update.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | git fetch --prune 3 | git reset --hard origin/main 4 | pause -------------------------------------------------------------------------------- /update.sh: -------------------------------------------------------------------------------- 1 | git fetch --prune 2 | git reset --hard origin/main -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if not defined PYTHON (set PYTHON=python) 4 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") 5 | 6 | set ERROR_REPORTING=FALSE 7 | 8 | mkdir tmp 2>NUL 9 | 10 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 11 | if %ERRORLEVEL% == 0 goto :start_venv 12 | echo Couldn't launch python 13 | goto :show_stdout_stderr 14 | 15 | :start_venv 16 | 
if ["%VENV_DIR%"] == ["-"] goto :skip_venv 17 | 18 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt 19 | if %ERRORLEVEL% == 0 goto :activate_venv 20 | 21 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 22 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 23 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt 24 | if %ERRORLEVEL% == 0 goto :activate_venv 25 | echo Unable to create venv in directory "%VENV_DIR%" 26 | goto :show_stdout_stderr 27 | 28 | :activate_venv 29 | set PYTHON="%VENV_DIR%\Scripts\Python.exe" 30 | echo venv %PYTHON% 31 | if [%ACCELERATE%] == ["True"] goto :accelerate 32 | goto :launch 33 | 34 | :skip_venv 35 | 36 | :accelerate 37 | echo "Checking for accelerate" 38 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" 39 | if EXIST %ACCELERATE% goto :accelerate_launch 40 | 41 | :launch 42 | %PYTHON% launch.py %* 43 | pause 44 | exit /b 45 | 46 | :accelerate_launch 47 | echo "Accelerating" 48 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py 49 | pause 50 | exit /b 51 | 52 | :show_stdout_stderr 53 | 54 | echo. 55 | echo exit code: %errorlevel% 56 | 57 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 58 | if %size% equ 0 goto :show_stderr 59 | echo. 60 | echo stdout: 61 | type tmp\stdout.txt 62 | 63 | :show_stderr 64 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 65 | if %size% equ 0 goto :show_stderr 66 | echo. 67 | echo stderr: 68 | type tmp\stderr.txt 69 | 70 | :endofscript 71 | 72 | echo. 73 | echo Launch unsuccessful. Exiting. 
74 | pause 75 | -------------------------------------------------------------------------------- /webui.sh: -------------------------------------------------------------------------------- 1 | # python3 executable 2 | if [[ -z "${python_cmd}" ]] 3 | then 4 | python_cmd="python3" 5 | fi 6 | 7 | # git executable 8 | if [[ -z "${GIT}" ]] 9 | then 10 | export GIT="git" 11 | fi 12 | 13 | # python3 venv without trailing slash 14 | if [[ -z "${venv_dir}" ]] 15 | then 16 | venv_dir="venv" 17 | fi 18 | 19 | if [[ -z "${LAUNCH_SCRIPT}" ]] 20 | then 21 | LAUNCH_SCRIPT="launch.py" 22 | fi 23 | 24 | # this script cannot be run as root by default 25 | can_run_as_root=0 26 | delimiter="################################################################" 27 | 28 | printf "\n%s\n" "${delimiter}" 29 | printf "Create and activate python venv" 30 | printf "\n%s\n" "${delimiter}" 31 | if [[ ! -d "${venv_dir}" ]] 32 | then 33 | "${python_cmd}" -m venv "${venv_dir}" 34 | first_launch=1 35 | fi 36 | # shellcheck source=/dev/null 37 | if [[ -f "${venv_dir}"/bin/activate ]] 38 | then 39 | source "${venv_dir}"/bin/activate 40 | else 41 | printf "\n%s\n" "${delimiter}" 42 | printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" 43 | printf "\n%s\n" "${delimiter}" 44 | exit 1 45 | fi 46 | 47 | if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] 48 | then 49 | printf "\n%s\n" "${delimiter}" 50 | printf "Accelerating launch.py..." 51 | printf "\n%s\n" "${delimiter}" 52 | exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" 53 | else 54 | printf "\n%s\n" "${delimiter}" 55 | printf "Launching launch.py..." 56 | printf "\n%s\n" "${delimiter}" 57 | exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" 58 | fi 59 | --------------------------------------------------------------------------------