├── logs └── .keep ├── output └── .keep ├── huggingface ├── hub │ └── version.txt └── accelerate │ └── default_config.yaml ├── .gitattributes ├── sd-models └── put stable diffusion model here.txt ├── mikazuki ├── app │ ├── __init__.py │ ├── models.py │ ├── application.py │ ├── proxy.py │ └── api.py ├── tsconfig.json ├── scripts │ ├── fix_scripts_python_executable_path.py │ └── torch_check.py ├── utils │ ├── tk_window.py │ ├── devices.py │ └── train_utils.py ├── log.py ├── tagger │ ├── format.py │ ├── dbimutils.py │ └── interrogator.py ├── process.py ├── tasks.py ├── schema │ ├── lora-basic.ts │ ├── dreambooth-master.ts │ └── lora-master.ts ├── global.d.ts └── launch_utils.py ├── install-cn.ps1 ├── train_by_toml.ps1 ├── tensorboard.ps1 ├── assets ├── tensorboard-example.png └── gitconfig-cn ├── run_gui.sh ├── run_gui.ps1 ├── config ├── sample_prompts.txt ├── default.toml └── lora.toml ├── .gitmodules ├── .gitignore ├── Dockerfile ├── install.ps1 ├── requirements.txt ├── train_by_toml.sh ├── interrogate.ps1 ├── resize.ps1 ├── svd_merge.ps1 ├── tagger.sh ├── tagger.ps1 ├── run.ipynb ├── install.bash ├── gui.py ├── train.ipynb ├── README-zh.md ├── README.md ├── train.sh ├── train.ps1 └── LICENSE /logs/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /huggingface/hub/version.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ps1 text eol=crlf -------------------------------------------------------------------------------- /sd-models/put stable diffusion model here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mikazuki/app/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import application 2 | 3 | app = application.app -------------------------------------------------------------------------------- /mikazuki/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": [ 3 | "**/*.ts" 4 | ], 5 | } -------------------------------------------------------------------------------- /install-cn.ps1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrazyBoyM/lora-scripts/main/install-cn.ps1 -------------------------------------------------------------------------------- /train_by_toml.ps1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrazyBoyM/lora-scripts/main/train_by_toml.ps1 -------------------------------------------------------------------------------- /tensorboard.ps1: -------------------------------------------------------------------------------- 1 | $Env:TF_CPP_MIN_LOG_LEVEL = "3" 2 | 3 | .\venv\Scripts\activate 4 | tensorboard --logdir=logs -------------------------------------------------------------------------------- /assets/tensorboard-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrazyBoyM/lora-scripts/main/assets/tensorboard-example.png -------------------------------------------------------------------------------- /run_gui.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME=huggingface 4 | export PYTHONUTF8=1 5 | 6 | python gui.py "$@" 7 | 8 | -------------------------------------------------------------------------------- /run_gui.ps1: -------------------------------------------------------------------------------- 1 | .\venv\Scripts\activate 2 | 3 | $Env:HF_HOME = "huggingface" 4 | $Env:PYTHONUTF8 = "1" 5 | 6 | python gui.py -------------------------------------------------------------------------------- /config/sample_prompts.txt: -------------------------------------------------------------------------------- 1 | (masterpiece, best quality:1.2), 1girl, solo, --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, --w 512 --h 768 --l 7 --s 24 --d 1337 -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sd-scripts"] 2 | path = sd-scripts 3 | url = https://github.com/kohya-ss/sd-scripts.git 4 | [submodule "frontend"] 5 | path = frontend 6 | url = https://github.com/hanamizuki-ai/lora-gui-dist 7 | [submodule "mikazuki/dataset-tag-editor"] 8 | path = mikazuki/dataset-tag-editor 9 | url = https://github.com/Akegarasu/dataset-tag-editor 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | 4 | venv 5 | __pycache__ 6 | 7 | output/* 8 | !output/.keep 9 | 10 | py310 11 | python 12 | git 13 | wd14_tagger_model 14 | 15 | train/* 16 | logs/* 17 | sd-models/* 18 | toml/autosave/* 19 | config/autosave/* 20 | !sd-models/put stable diffusion model here.txt 21 | !logs/.keep 22 | 23 | tests/ 24 | 25 | huggingface/hub/models* 26 | huggingface/hub/version_diffusers_cache.txt 
-------------------------------------------------------------------------------- /assets/gitconfig-cn: -------------------------------------------------------------------------------- 1 | [url "https://jihulab.com/Akegarasu/lora-scripts"] 2 | insteadOf = https://github.com/Akegarasu/lora-scripts 3 | 4 | [url "https://jihulab.com/affair3547/sd-scripts"] 5 | insteadOf = https://github.com/kohya-ss/sd-scripts.git 6 | 7 | [url "https://jihulab.com/affair3547/lora-gui-dist"] 8 | insteadOf = https://github.com/hanamizuki-ai/lora-gui-dist 9 | 10 | [url "https://jihulab.com/Akegarasu/dataset-tag-editor"] 11 | insteadOf = https://github.com/Akegarasu/dataset-tag-editor 12 | -------------------------------------------------------------------------------- /huggingface/accelerate/default_config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: {} 5 | distributed_type: 'NO' 6 | downcast_bf16: 'no' 7 | dynamo_backend: 'NO' 8 | fsdp_config: {} 9 | gpu_ids: all 10 | machine_rank: 0 11 | main_process_ip: null 12 | main_process_port: null 13 | main_training_function: main 14 | megatron_lm_config: {} 15 | mixed_precision: fp16 16 | num_machines: 1 17 | num_processes: 1 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_name: null 21 | tpu_zone: null 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.07-py3 2 | 3 | EXPOSE 28000 4 | 5 | ENV TZ=Asia/Shanghai 6 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y 7 | 8 | RUN mkdir /app 9 | 10 | WORKDIR /app 11 | RUN git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts 12 | 13 | WORKDIR /app/lora-scripts 14 | RUN pip install xformers==0.0.21 --no-deps && pip install -r requirements.txt 15 | 16 | WORKDIR /app/lora-scripts/sd-scripts 17 | RUN pip install -r requirements.txt 18 | 19 | WORKDIR /app/lora-scripts 20 | 21 | CMD ["python", "gui.py", "--listen"] -------------------------------------------------------------------------------- /install.ps1: -------------------------------------------------------------------------------- 1 | $Env:HF_HOME = "huggingface" 2 | 3 | if (!(Test-Path -Path "venv")) { 4 | Write-Output "Creating venv for python..." 5 | python -m venv venv 6 | } 7 | .\venv\Scripts\activate 8 | 9 | Write-Output "Installing deps..." 10 | Set-Location .\sd-scripts 11 | pip install torch==2.2.1+cu118 torchvision==0.17.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 12 | pip install -U -I --no-deps xformers==0.0.25+cu118 13 | pip install --upgrade -r requirements.txt 14 | 15 | Set-Location .. 
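# (editor's note) The line below installs the GUI's own dependencies from the repository-root requirements.txt; the requirements.txt installed above belongs to the sd-scripts submodule.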
16 | pip install --upgrade -r requirements.txt 17 | 18 | Write-Output "Install completed" 19 | Read-Host | Out-Null ; -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | transformers==4.36.2 3 | diffusers[torch]==0.25.0 4 | ftfy==6.1.1 5 | # albumentations==1.3.0 6 | opencv-python==4.7.0.68 7 | einops==0.7.0 8 | pytorch-lightning==1.9.0 9 | bitsandbytes==0.43.0 10 | prodigyopt==1.0 11 | lion-pytorch==0.1.2 12 | tensorboard==2.10.1 13 | safetensors==0.4.2 14 | # gradio==3.16.2 15 | altair==4.2.2 16 | easygui==0.98.3 17 | toml==0.10.2 18 | voluptuous==0.13.1 19 | huggingface-hub==0.20.1 20 | # for Image utils 21 | imagesize==1.4.1 22 | # for ui 23 | rich 24 | pandas 25 | scipy 26 | requests 27 | pillow 28 | numpy 29 | gradio==3.44.2 30 | fastapi==0.95.1 31 | uvicorn==0.22.0 32 | wandb==0.16.2 33 | httpx==0.24.1 34 | # extra 35 | open-clip-torch==2.20.0 36 | lycoris-lora==2.1.0.post3 37 | dadaptation==3.1 -------------------------------------------------------------------------------- /mikazuki/scripts/fix_scripts_python_executable_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pathlib import Path 5 | 6 | py_path = sys.executable 7 | scripts_path = Path(sys.executable).parent 8 | 9 | if scripts_path.name != "Scripts": 10 | print("It seems your environment is not a venv. Do you want to continue? [y/n]") 11 | sure = input() 12 | if sure != "y": 13 | sys.exit(1) 14 | 15 | scripts_list = os.listdir(scripts_path) 16 | 17 | for script in scripts_list: 18 | if not script.endswith(".exe") or script in ["python.exe", "pythonw.exe"]: 19 | continue 20 | 21 | with open(os.path.join(scripts_path, script), "rb+") as f: 22 | s = f.read() 23 | spl = re.split(rb'(#!.*python\.exe)', s) 24 | if len(spl) == 3: 25 | spl[1] = bytes(b"#!"+sys.executable.encode()) 26 | f.seek(0) 27 | f.write(b''.join(spl)) 28 | print(f"fixed {script}") -------------------------------------------------------------------------------- /train_by_toml.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # LoRA train script by @Akegarasu 3 | 4 | config_file="./config/default.toml" # config file | 使用 toml 文件指定训练参数 5 | sample_prompts="./config/sample_prompts.txt" # prompt file for sample | 采样 prompts 文件, 留空则不启用采样功能 6 | 7 | sdxl=0 # train sdxl LoRA | 训练 SDXL LoRA 8 | multi_gpu=0 # multi gpu | 多显卡训练 该参数仅限在显卡数 >= 2 使用 9 | 10 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 ===================== 11 | 12 | export HF_HOME="huggingface" 13 | export TF_CPP_MIN_LOG_LEVEL=3 14 | export PYTHONUTF8=1 15 | 16 | extArgs=() 17 | launchArgs=() 18 | 19 | if [[ $multi_gpu == 1 ]]; then 20 | launchArgs+=("--multi_gpu") 21 | launchArgs+=("--num_processes=2") 22 | fi 23 | 24 | # run train 25 | if [[ $sdxl == 1 ]]; then 26 | script_name="sdxl_train_network.py" 27 | else 28 | script_name="train_network.py" 29 | fi 30 | 31 | python -m accelerate.commands.launch "${launchArgs[@]}" --num_cpu_threads_per_process=8 "./sd-scripts/$script_name" \ 32 | --config_file="$config_file" \ 33 | --sample_prompts="$sample_prompts" \ 34 | "${extArgs[@]}" 35 | -------------------------------------------------------------------------------- /mikazuki/utils/tk_window.py: -------------------------------------------------------------------------------- 1 | from mikazuki.log import log 2 | try: 3
| import tkinter 4 | from tkinter.filedialog import askdirectory, askopenfilename 5 | except ImportError: 6 | tkinter = None 7 | askdirectory = None 8 | askopenfilename = None 9 | log.warning("tkinter not found, file selector will not work.") 10 | 11 | 12 | def tk_window(): 13 | window = tkinter.Tk() 14 | window.wm_attributes('-topmost', 1) 15 | window.withdraw() 16 | 17 | 18 | def open_file_selector( 19 | initialdir, 20 | title, 21 | filetypes) -> str: 22 | try: 23 | tk_window() 24 | filename = askopenfilename( 25 | initialdir=initialdir, title=title, 26 | filetypes=filetypes 27 | ) 28 | return filename 29 | except: 30 | return "" 31 | 32 | 33 | def open_directory_selector(initialdir) -> str: 34 | try: 35 | tk_window() 36 | directory = askdirectory( 37 | initialdir=initialdir 38 | ) 39 | return directory 40 | except: 41 | return "" 42 | -------------------------------------------------------------------------------- /mikazuki/scripts/torch_check.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def check_torch_gpu(): 4 | try: 5 | import torch 6 | print(f'Torch {torch.__version__}') 7 | if torch.cuda.is_available(): 8 | if torch.version.cuda: 9 | print( 10 | f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}') 11 | for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]: 12 | print(f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}') 13 | else: 14 | print("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check") 15 | except Exception as e: 16 | print(f'Could not load torch: {e}') 17 | sys.exit(1) 18 | 19 | check_torch_gpu() -------------------------------------------------------------------------------- /mikazuki/app/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import List, Optional, Union, Dict, Any 3 | 4 | 5 | class TaggerInterrogateRequest(BaseModel): 6 | path: str 7 | interrogator_model: str = Field( 8 | default="wd14-convnextv2-v2" 9 | ) 10 | threshold: float = Field( 11 | default=0.35, 12 | ge=0, 13 | le=1 14 | ) 15 | additional_tags: str = "" 16 | exclude_tags: str = "" 17 | escape_tag: bool = True 18 | batch_input_recursive: bool = False 19 | batch_output_action_on_conflict: str = "ignore" 20 | replace_underscore: bool = True 21 | replace_underscore_excludes: str = Field( 22 | default="0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||" 23 | ) 24 | 25 | 26 | class APIResponse(BaseModel): 27 | status: str 28 | message: Optional[str] 29 | data: Optional[Dict] 30 | 31 | 32 | class APIResponseSuccess(APIResponse): 33 | status: str = "success" 34 | 35 | 36 | class APIResponseFail(APIResponse): 37 | status: str = "fail" 38 | -------------------------------------------------------------------------------- /interrogate.ps1: -------------------------------------------------------------------------------- 1 | # LoRA interrogate script by @bdsqlsz 2 | 3 | $v2 = 0 # load Stable Diffusion v2.x model / Stable Diffusion 2.x模型读取 4 | $sd_model = "./sd-models/sd_model.safetensors" # Stable Diffusion model to 
load: ckpt or safetensors file | 读取的基础SD模型, 保存格式 ckpt 或 safetensors 5 | $model = "./output/LoRA.safetensors" # LoRA model to interrogate: ckpt or safetensors file | 需要调查关键字的LORA模型, 保存格式 ckpt 或 safetensors 6 | $batch_size = 64 # batch size for processing with Text Encoder | 使用 Text Encoder 处理时的批量大小,默认16,推荐64/128 7 | $clip_skip = 1 # use output of nth layer from back of text encoder (n>=1) | 使用文本编码器倒数第 n 层的输出,n 可以是大于等于 1 的整数 8 | 9 | 10 | # Activate python venv 11 | .\venv\Scripts\activate 12 | 13 | $Env:HF_HOME = "huggingface" 14 | $ext_args = [System.Collections.ArrayList]::new() 15 | 16 | if ($v2) { 17 | [void]$ext_args.Add("--v2") 18 | } 19 | 20 | # run interrogate 21 | accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/networks/lora_interrogator.py" ` 22 | --sd_model=$sd_model ` 23 | --model=$model ` 24 | --batch_size=$batch_size ` 25 | --clip_skip=$clip_skip ` 26 | $ext_args 27 | 28 | Write-Output "Interrogate finished" 29 | Read-Host | Out-Null ; 30 | -------------------------------------------------------------------------------- /mikazuki/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | log = logging.getLogger('sd-trainer') 5 | log.setLevel(logging.DEBUG) 6 | 7 | try: 8 | from rich.console import Console 9 | from rich.logging import RichHandler 10 | from rich.pretty import install as pretty_install 11 | from rich.theme import Theme 12 | 13 | console = Console( 14 | log_time=True, 15 | log_time_format='%H:%M:%S-%f', 16 | theme=Theme( 17 | { 18 | 'traceback.border': 'black', 19 | 'traceback.border.syntax_error': 'black', 20 | 'inspect.value.border': 'black', 21 | } 22 | ), 23 | ) 24 | pretty_install(console=console) 25 | rh = RichHandler( 26 | show_time=True, 27 | omit_repeated_times=False, 28 | show_level=True, 29 | show_path=False, 30 | markup=False, 31 | rich_tracebacks=True, 32 | log_time_format='%H:%M:%S-%f', 33 | level=logging.INFO, 34 | console=console, 35 | ) 36 | rh.set_name(logging.INFO) 37 | while log.hasHandlers() and len(log.handlers) > 0: 38 | log.removeHandler(log.handlers[0]) 39 | log.addHandler(rh) 40 | 41 | except ModuleNotFoundError: 42 | pass 43 | 44 | -------------------------------------------------------------------------------- /mikazuki/tagger/format.py: -------------------------------------------------------------------------------- 1 | import re 2 | import hashlib 3 | 4 | from typing import Dict, Callable, NamedTuple 5 | from pathlib import Path 6 | 7 | 8 | class Info(NamedTuple): 9 | path: Path 10 | output_ext: str 11 | 12 | 13 | def hash(i: Info, algo='sha1') -> str: 14 | try: 15 | hash = hashlib.new(algo) 16 | except ValueError: 17 | raise ValueError(f"'{algo}' is not a valid hash algorithm") 18 | 19 | # TODO: is it okay to hash large images?
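# Editor's note (not part of the original file): file.read() below loads the entire image into memory before hashing. # If large files become a concern, a chunked loop such as 'while chunk := file.read(1 << 20): hash.update(chunk)' keeps memory bounded; Python 3.11+ also provides hashlib.file_digest(file, algo) for the same purpose.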
20 | with open(i.path, 'rb') as file: 21 | hash.update(file.read()) 22 | 23 | return hash.hexdigest() 24 | 25 | 26 | pattern = re.compile(r'\[([\w:]+)\]') 27 | 28 | # all function must returns string or raise TypeError or ValueError 29 | # other errors will cause the extension error 30 | available_formats: Dict[str, Callable] = { 31 | 'name': lambda i: i.path.stem, 32 | 'extension': lambda i: i.path.suffix[1:], 33 | 'hash': hash, 34 | 35 | 'output_extension': lambda i: i.output_ext 36 | } 37 | 38 | 39 | def format(match: re.Match, info: Info) -> str: 40 | matches = match[1].split(':') 41 | name, args = matches[0], matches[1:] 42 | 43 | if name not in available_formats: 44 | return match[0] 45 | 46 | return available_formats[name](info, *args) 47 | -------------------------------------------------------------------------------- /config/default.toml: -------------------------------------------------------------------------------- 1 | [model] 2 | v2 = false 3 | v_parameterization = false 4 | pretrained_model_name_or_path = "./sd-models/model.ckpt" 5 | 6 | [dataset] 7 | train_data_dir = "./train/input" 8 | reg_data_dir = "" 9 | prior_loss_weight = 1 10 | cache_latents = true 11 | shuffle_caption = true 12 | enable_bucket = true 13 | 14 | [additional_network] 15 | network_dim = 32 16 | network_alpha = 16 17 | network_train_unet_only = false 18 | network_train_text_encoder_only = false 19 | network_module = "networks.lora" 20 | network_args = [] 21 | 22 | [optimizer] 23 | unet_lr = 1e-4 24 | text_encoder_lr = 1e-5 25 | optimizer_type = "AdamW8bit" 26 | lr_scheduler = "cosine_with_restarts" 27 | lr_warmup_steps = 0 28 | lr_restart_cycles = 1 29 | 30 | [training] 31 | resolution = "512,512" 32 | train_batch_size = 1 33 | max_train_epochs = 10 34 | noise_offset = 0.0 35 | keep_tokens = 0 36 | xformers = true 37 | lowram = false 38 | clip_skip = 2 39 | mixed_precision = "fp16" 40 | save_precision = "fp16" 41 | 42 | [sample_prompt] 43 | sample_sampler = "euler_a" 44 | sample_every_n_epochs = 1 45 | 46 | [saving] 47 | output_name = "output_name" 48 | save_every_n_epochs = 1 49 | save_n_epoch_ratio = 0 50 | save_last_n_epochs = 499 51 | save_state = false 52 | save_model_as = "safetensors" 53 | output_dir = "./output" 54 | logging_dir = "./logs" 55 | log_prefix = "output_name" 56 | 57 | [others] 58 | min_bucket_reso = 256 59 | max_bucket_reso = 1024 60 | caption_extension = ".txt" 61 | max_token_length = 225 62 | seed = 1337 63 | -------------------------------------------------------------------------------- /config/lora.toml: -------------------------------------------------------------------------------- 1 | [model_arguments] 2 | v2 = false 3 | v_parameterization = false 4 | pretrained_model_name_or_path = "./sd-models/model.ckpt" 5 | 6 | [dataset_arguments] 7 | train_data_dir = "./train/aki" 8 | reg_data_dir = "" 9 | resolution = "512,512" 10 | prior_loss_weight = 1 11 | 12 | [additional_network_arguments] 13 | network_dim = 32 14 | network_alpha = 16 15 | network_train_unet_only = false 16 | network_train_text_encoder_only = false 17 | network_module = "networks.lora" 18 | network_args = [] 19 | 20 | [optimizer_arguments] 21 | unet_lr = 1e-4 22 | text_encoder_lr = 1e-5 23 | 24 | optimizer_type = "AdamW8bit" 25 | lr_scheduler = "cosine_with_restarts" 26 | lr_warmup_steps = 0 27 | lr_restart_cycles = 1 28 | 29 | [training_arguments] 30 | train_batch_size = 1 31 | noise_offset = 0.0 32 | keep_tokens = 0 33 | min_bucket_reso = 256 34 | max_bucket_reso = 1024 35 | caption_extension = ".txt" 36 | 
max_token_length = 225 37 | seed = 1337 38 | xformers = true 39 | lowram = false 40 | max_train_epochs = 10 41 | resolution = "512,512" 42 | clip_skip = 2 43 | mixed_precision = "fp16" 44 | 45 | [sample_prompt_arguments] 46 | sample_sampler = "euler_a" 47 | sample_every_n_epochs = 5 48 | 49 | [saving_arguments] 50 | output_name = "output_name" 51 | save_every_n_epochs = 1 52 | save_state = false 53 | save_model_as = "safetensors" 54 | output_dir = "./output" 55 | logging_dir = "./logs" 56 | log_prefix = "" 57 | save_precision = "fp16" 58 | 59 | [others] 60 | cache_latents = true 61 | shuffle_caption = true 62 | enable_bucket = true -------------------------------------------------------------------------------- /mikazuki/utils/devices.py: -------------------------------------------------------------------------------- 1 | from mikazuki.log import log 2 | 3 | available_devices = [] 4 | printable_devices = [] 5 | 6 | 7 | def check_torch_gpu(): 8 | try: 9 | import torch 10 | log.info(f'Torch {torch.__version__}') 11 | if not torch.cuda.is_available(): 12 | log.error("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check") 13 | log.error("!!!Torch 无法使用 GPU,您无法正常开始训练!!!\n您的显卡可能并不支持,或是 torch 安装有误。请检查您的 torch 安装。") 14 | return 15 | 16 | if torch.version.cuda: 17 | log.info( 18 | f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}') 19 | elif torch.version.hip: 20 | log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}') 21 | 22 | devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())] 23 | 24 | for pos, device in enumerate(devices): 25 | name = torch.cuda.get_device_name(device) 26 | memory = torch.cuda.get_device_properties(device).total_memory 27 | available_devices.append(device) 28 | printable_devices.append(f"GPU {pos}: {name} ({round(memory / (1024**3))} GB)") 29 | log.info( 30 | f'Torch detected GPU: {name} VRAM {round(memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}') 31 | except Exception as e: 32 | log.error(f'Could not load torch: {e}') 33 | -------------------------------------------------------------------------------- /resize.ps1: -------------------------------------------------------------------------------- 1 | # LoRA resize script by @bdsqlsz 2 | 3 | $save_precision = "fp16" # precision in saving, default float | 保存精度, 可选 float、fp16、bf16, 默认 float 4 | $new_rank = 4 # dim rank of output LoRA | dim rank等级, 默认 4 5 | $model = "./output/lora_name.safetensors" # original LoRA model path need to resize, save as cpkt or safetensors | 需要调整大小的模型路径, 保存格式 cpkt 或 safetensors 6 | $save_to = "./output/lora_name_new.safetensors" # output LoRA model path, save as ckpt or safetensors | 输出路径, 保存格式 cpkt 或 safetensors 7 | $device = "cuda" # device to use, cuda for GPU | 使用 GPU跑, 默认 CPU 8 | $verbose = 1 # display verbose resizing information | rank变更时, 显示详细信息 9 | $dynamic_method = "" # Specify dynamic resizing method, --new_rank is used as a hard limit for max rank | 动态调节大小,可选"sv_ratio", "sv_fro", "sv_cumulative",默认无 10 | $dynamic_param = "" # Specify target for dynamic reduction | 动态参数,sv_ratio模式推荐1~2, sv_cumulative模式0~1, sv_fro模式0~1, 比sv_cumulative要高 11 | 12 | 13 | # Activate python venv 14 | .\venv\Scripts\activate 15 | 16 | $Env:HF_HOME = "huggingface" 17 | $ext_args = [System.Collections.ArrayList]::new() 18 | 19 | if 
($verbose) { 20 | [void]$ext_args.Add("--verbose") 21 | } 22 | 23 | if ($dynamic_method) { 24 | [void]$ext_args.Add("--dynamic_method=" + $dynamic_method) 25 | } 26 | 27 | if ($dynamic_param) { 28 | [void]$ext_args.Add("--dynamic_param=" + $dynamic_param) 29 | } 30 | 31 | # run resize 32 | accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/networks/resize_lora.py" ` 33 | --save_precision=$save_precision ` 34 | --new_rank=$new_rank ` 35 | --model=$model ` 36 | --save_to=$save_to ` 37 | --device=$device ` 38 | $ext_args 39 | 40 | Write-Output "Resize finished" 41 | Read-Host | Out-Null ; 42 | -------------------------------------------------------------------------------- /mikazuki/tagger/dbimutils.py: -------------------------------------------------------------------------------- 1 | # DanBooru IMage Utility functions 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED): 9 | if img.endswith(".gif"): 10 | img = Image.open(img) 11 | img = img.convert("RGB") 12 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 13 | else: 14 | img = cv2.imread(img, flag) 15 | return img 16 | 17 | 18 | def smart_24bit(img): 19 | if img.dtype is np.dtype(np.uint16): 20 | img = (img / 257).astype(np.uint8) 21 | 22 | if len(img.shape) == 2: 23 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 24 | elif img.shape[2] == 4: 25 | trans_mask = img[:, :, 3] == 0 26 | img[trans_mask] = [255, 255, 255, 255] 27 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) 28 | return img 29 | 30 | 31 | def make_square(img, target_size): 32 | old_size = img.shape[:2] 33 | desired_size = max(old_size) 34 | desired_size = max(desired_size, target_size) 35 | 36 | delta_w = desired_size - old_size[1] 37 | delta_h = desired_size - old_size[0] 38 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 39 | left, right = delta_w // 2, delta_w - (delta_w // 2) 40 | 41 | color = [255, 255, 255] 42 | new_im = cv2.copyMakeBorder( 43 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 44 | ) 45 | return new_im 46 | 47 | 48 | def smart_resize(img, size): 49 | # Assumes the image has already gone through make_square 50 | if img.shape[0] > size: 51 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 52 | elif img.shape[0] < size: 53 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 54 | return img 55 | -------------------------------------------------------------------------------- /svd_merge.ps1: -------------------------------------------------------------------------------- 1 | # LoRA svd_merge script by @bdsqlsz 2 | 3 | $save_precision = "fp16" # precision in saving, default float | 保存精度, 可选 float、fp16、bf16, 默认 和源文件相同 4 | $precision = "float" # precision in merging (float is recommended) | 合并时计算精度, 可选 float、fp16、bf16, 推荐float 5 | $new_rank = 4 # dim rank of output LoRA | dim rank等级, 默认 4 6 | $models = "./output/modelA.safetensors ./output/modelB.safetensors" # original LoRA model path need to resize, save as cpkt or safetensors | 需要合并的模型路径, 保存格式 cpkt 或 safetensors,多个用空格隔开 7 | $ratios = "1.0 -1.0" # ratios for each model / LoRA模型合并比例,数量等于模型数量,多个用空格隔开 8 | $save_to = "./output/lora_name_new.safetensors" # output LoRA model path, save as ckpt or safetensors | 输出路径, 保存格式 cpkt 或 safetensors 9 | $device = "cuda" # device to use, cuda for GPU | 使用 GPU跑, 默认 CPU 10 | $new_conv_rank = 0 # Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank | Conv2d 3x3输出,没有默认同new_rank 11 | 12 | # Activate python venv 13 | 
.\venv\Scripts\activate 14 | 15 | $Env:HF_HOME = "huggingface" 16 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 17 | $ext_args = [System.Collections.ArrayList]::new() 18 | 19 | [void]$ext_args.Add("--models") 20 | foreach ($model in $models.Split(" ")) { 21 | [void]$ext_args.Add($model) 22 | } 23 | 24 | [void]$ext_args.Add("--ratios") 25 | foreach ($ratio in $ratios.Split(" ")) { 26 | [void]$ext_args.Add([float]$ratio) 27 | } 28 | 29 | if ($new_conv_rank) { 30 | [void]$ext_args.Add("--new_conv_rank=" + $new_conv_rank) 31 | } 32 | 33 | # run svd_merge 34 | accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/networks/svd_merge_lora.py" ` 35 | --save_precision=$save_precision ` 36 | --precision=$precision ` 37 | --new_rank=$new_rank ` 38 | --save_to=$save_to ` 39 | --device=$device ` 40 | $ext_args 41 | 42 | Write-Output "SVD Merge finished" 43 | Read-Host | Out-Null ; 44 | -------------------------------------------------------------------------------- /mikazuki/process.py: -------------------------------------------------------------------------------- 1 | 2 | import asyncio 3 | import os 4 | import sys 5 | from typing import Optional 6 | 7 | from mikazuki.app.models import APIResponse 8 | from mikazuki.log import log 9 | from mikazuki.tasks import tm 10 | 11 | 12 | def run_train(toml_path: str, 13 | trainer_file: str = "./sd-scripts/train_network.py", 14 | gpu_ids: Optional[list] = None, 15 | cpu_threads: Optional[int] = 2): 16 | log.info(f"Training started with config file / 训练开始,使用配置文件: {toml_path}") 17 | args = [ 18 | sys.executable, "-m", "accelerate.commands.launch", # use -m to avoid python script executable error 19 | "--num_cpu_threads_per_process", str(cpu_threads), # cpu threads 20 | "--quiet", # silence accelerate error message 21 | trainer_file, 22 | "--config_file", toml_path, 23 | ] 24 | 25 | customize_env = os.environ.copy() 26 | customize_env["ACCELERATE_DISABLE_RICH"] = "1" 27 | customize_env["PYTHONUNBUFFERED"] = "1" 28 | 29 | if gpu_ids: 30 | customize_env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids) 31 | log.info(f"Using GPU(s) / 使用 GPU: {gpu_ids}") 32 | 33 | if len(gpu_ids) > 1: 34 | args[3:3] = ["--multi_gpu", "--num_processes", str(len(gpu_ids))] 35 | 36 | if not (task := tm.create_task(args, customize_env)): 37 | return APIResponse(status="error", message="Failed to create task / 无法创建训练任务") 38 | 39 | def _run(): 40 | try: 41 | task.execute() 42 | result = task.communicate() 43 | if result.returncode != 0: 44 | log.error(f"Training failed / 训练失败") 45 | else: 46 | log.info(f"Training finished / 训练完成") 47 | except Exception as e: 48 | log.error(f"An error occurred when training / 训练出现致命错误: {e}") 49 | 50 | coro = asyncio.to_thread(_run) 51 | asyncio.create_task(coro) 52 | 53 | return APIResponse(status="success", message=f"Training started / 训练开始 ID: {task.task_id}") 54 | -------------------------------------------------------------------------------- /tagger.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # tagger script by @bdsqlsz 3 | # Train data path 4 | train_data_dir="./input" # input images path | 图片输入路径 5 | repo_id="SmilingWolf/wd-v1-4-swinv2-tagger-v2" # model repo id from huggingface |huggingface模型repoID 6 | model_dir="" # model dir path | 本地模型文件夹路径 7 | batch_size=12 # batch size in inference 批处理大小,越大越快 8 | max_data_loader_n_workers=0 # enable image reading by DataLoader with this number of workers (faster) | 0最快 9 | thresh=0.35 # concept thresh | 最小识别阈值 10 | general_threshold=0.35 # general 
threshold | 总体识别阈值 11 | character_threshold=0.1 # character threshold | 人物姓名识别阈值 12 | remove_underscore=0 # remove_underscore | 下划线转空格,1为开,0为关 13 | undesired_tags="" # no need tags | 排除标签 14 | recursive=0 # search for images in subfolders recursively | 递归搜索下层文件夹,1为开,0为关 15 | frequency_tags=0 # order by frequency tags | 从大到小按识别率排序标签,1为开,0为关 16 | 17 | 18 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 ===================== 19 | 20 | export HF_HOME="huggingface" 21 | export TF_CPP_MIN_LOG_LEVEL=3 22 | extArgs=() 23 | 24 | if [ -n "$repo_id" ]; then 25 | extArgs+=( "--repo_id=$repo_id" ) 26 | fi 27 | 28 | if [ -n "$model_dir" ]; then 29 | extArgs+=( "--model_dir=$model_dir" ) 30 | fi 31 | 32 | if [[ $batch_size -ne 0 ]]; then 33 | extArgs+=( "--batch_size=$batch_size" ) 34 | fi 35 | 36 | if [ -n "$max_data_loader_n_workers" ]; then 37 | extArgs+=( "--max_data_loader_n_workers=$max_data_loader_n_workers" ) 38 | fi 39 | 40 | if [ -n "$general_threshold" ]; then 41 | extArgs+=( "--general_threshold=$general_threshold" ) 42 | fi 43 | 44 | if [ -n "$character_threshold" ]; then 45 | extArgs+=( "--character_threshold=$character_threshold" ) 46 | fi 47 | 48 | if [ "$remove_underscore" -eq 1 ]; then 49 | extArgs+=( "--remove_underscore" ) 50 | fi 51 | 52 | if [ -n "$undesired_tags" ]; then 53 | extArgs+=( "--undesired_tags=$undesired_tags" ) 54 | fi 55 | 56 | if [ "$recursive" -eq 1 ]; then 57 | extArgs+=( "--recursive" ) 58 | fi 59 | 60 | if [ "$frequency_tags" -eq 1 ]; then 61 | extArgs+=( "--frequency_tags" ) 62 | fi 63 | 64 | 65 | # run tagger 66 | accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/finetune/tag_images_by_wd14_tagger.py" \ 67 | $train_data_dir \ 68 | --thresh=$thresh \ 69 | --caption_extension .txt \ 70 | ${extArgs[@]} 71 | -------------------------------------------------------------------------------- /tagger.ps1: -------------------------------------------------------------------------------- 1 | # tagger script by @bdsqlsz 2 | 3 | # Train data path 4 | $train_data_dir = "./input" # input images path | 图片输入路径 5 | $repo_id = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" # model repo id from huggingface |huggingface模型repoID 6 | $model_dir = "" # model dir path | 本地模型文件夹路径 7 | $batch_size = 4 # batch size in inference 批处理大小,越大越快 8 | $max_data_loader_n_workers = 0 # enable image reading by DataLoader with this number of workers (faster) | 0最快 9 | $thresh = 0.35 # concept thresh | 最小识别阈值 10 | $general_threshold = 0.35 # general threshold | 总体识别阈值 11 | $character_threshold = 0.1 # character threshold | 人物姓名识别阈值 12 | $remove_underscore = 0 # remove_underscore | 下划线转空格,1为开,0为关 13 | $undesired_tags = "" # no need tags | 排除标签 14 | $recursive = 0 # search for images in subfolders recursively | 递归搜索下层文件夹,1为开,0为关 15 | $frequency_tags = 0 # order by frequency tags | 从大到小按识别率排序标签,1为开,0为关 16 | 17 | 18 | # Activate python venv 19 | .\venv\Scripts\activate 20 | 21 | $Env:HF_HOME = "huggingface" 22 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 23 | $ext_args = [System.Collections.ArrayList]::new() 24 | 25 | if ($repo_id) { 26 | [void]$ext_args.Add("--repo_id=" + $repo_id) 27 | } 28 | 29 | if ($model_dir) { 30 | [void]$ext_args.Add("--model_dir=" + $model_dir) 31 | } 32 | 33 | if ($batch_size) { 34 | [void]$ext_args.Add("--batch_size=" + $batch_size) 35 | } 36 | 37 | if ($max_data_loader_n_workers) { 38 | [void]$ext_args.Add("--max_data_loader_n_workers=" + $max_data_loader_n_workers) 39 | } 40 | 41 | if ($general_threshold) { 42 | [void]$ext_args.Add("--general_threshold=" + 
$general_threshold) 43 | } 44 | 45 | if ($character_threshold) { 46 | [void]$ext_args.Add("--character_threshold=" + $character_threshold) 47 | } 48 | 49 | if ($remove_underscore) { 50 | [void]$ext_args.Add("--remove_underscore") 51 | } 52 | 53 | if ($undesired_tags) { 54 | [void]$ext_args.Add("--undesired_tags=" + $undesired_tags) 55 | } 56 | 57 | if ($recursive) { 58 | [void]$ext_args.Add("--recursive") 59 | } 60 | 61 | if ($frequency_tags) { 62 | [void]$ext_args.Add("--frequency_tags") 63 | } 64 | 65 | # run tagger 66 | accelerate launch --num_cpu_threads_per_process=8 "./sd-scripts/finetune/tag_images_by_wd14_tagger.py" ` 67 | $train_data_dir ` 68 | --thresh=$thresh ` 69 | --caption_extension .txt ` 70 | $ext_args 71 | 72 | Write-Output "Tagger finished" 73 | Read-Host | Out-Null ; 74 | -------------------------------------------------------------------------------- /run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5e35269a-ec20-41a3-93a6-da798c3a8401", 6 | "metadata": {}, 7 | "source": [ 8 | "# LoRA Train UI: SD-Trainer\n", 9 | "\n", 10 | "LoRA Training UI By [Akegarasu](https://github.com/Akegarasu)\n", 11 | "User Guide:https://github.com/Akegarasu/lora-scripts/blob/main/README.md\n", 12 | "\n", 13 | "LoRA 训练 By [秋葉aaaki@bilibili](https://space.bilibili.com/12566101)\n", 14 | "使用方法:https://www.bilibili.com/read/cv24050162/" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "12c2a3d0-9aec-4680-9b8a-cb02cac48de6", 20 | "metadata": {}, 21 | "source": [ 22 | "### Run | 运行" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "7ae0678f-69df-4a12-a0bc-1325e52e9122", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import sys\n", 33 | "!export HF_HOME=huggingface && $sys.executable gui.py --host 0.0.0.0" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "99edaa2b-9ba2-4fde-9b2e-af5dc8bf7062", 39 | "metadata": {}, 40 | "source": [ 41 | "## Update | 更新" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "### Github" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "!git pull && git submodule init && git submodule update" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### 国内镜像加速" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "!export GIT_CONFIG_GLOBAL=./assets/gitconfig-cn && export GIT_TERMINAL_PROMPT=false && git pull && git submodule init && git submodule update" 74 | ] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "Python 3 (ipykernel)", 80 | "language": "python", 81 | "name": "python3" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.10.8" 94 | } 95 | }, 96 | "nbformat": 4, 97 | "nbformat_minor": 5 98 | } 99 | -------------------------------------------------------------------------------- /mikazuki/app/application.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import mimetypes 3 | import os 4 | import sys 5 
| import webbrowser 6 | from contextlib import asynccontextmanager 7 | 8 | from fastapi import FastAPI 9 | from fastapi.middleware.cors import CORSMiddleware 10 | from fastapi.responses import FileResponse 11 | from fastapi.staticfiles import StaticFiles 12 | from starlette.exceptions import HTTPException 13 | 14 | from mikazuki.app.api import load_schemas 15 | from mikazuki.app.api import router as api_router 16 | # from mikazuki.app.ipc import router as ipc_router 17 | from mikazuki.app.proxy import router as proxy_router 18 | from mikazuki.utils.devices import check_torch_gpu 19 | 20 | mimetypes.add_type("application/javascript", ".js") 21 | mimetypes.add_type("text/css", ".css") 22 | 23 | 24 | class SPAStaticFiles(StaticFiles): 25 | async def get_response(self, path: str, scope): 26 | try: 27 | return await super().get_response(path, scope) 28 | except HTTPException as ex: 29 | if ex.status_code == 404: 30 | return await super().get_response("index.html", scope) 31 | else: 32 | raise ex 33 | 34 | 35 | async def app_startup(): 36 | await asyncio.to_thread(check_torch_gpu) 37 | await load_schemas() 38 | if sys.platform == "win32" and os.environ.get("MIKAZUKI_DEV", "0") != "1": 39 | webbrowser.open(f'http://{os.environ["MIKAZUKI_HOST"]}:{os.environ["MIKAZUKI_PORT"]}') 40 | 41 | 42 | @asynccontextmanager 43 | async def lifespan(app: FastAPI): 44 | await app_startup() 45 | yield 46 | 47 | 48 | app = FastAPI(lifespan=lifespan) 49 | app.include_router(proxy_router) 50 | 51 | 52 | cors_config = os.environ.get("MIKAZUKI_APP_CORS", "") 53 | if cors_config != "": 54 | if cors_config == "1": 55 | cors_config = ["http://localhost:8004", "*"] 56 | else: 57 | cors_config = cors_config.split(";") 58 | app.add_middleware( 59 | CORSMiddleware, 60 | allow_origins=cors_config, 61 | allow_credentials=True, 62 | allow_methods=["*"], 63 | allow_headers=["*"], 64 | ) 65 | 66 | 67 | @app.middleware("http") 68 | async def add_cache_control_header(request, call_next): 69 | response = await call_next(request) 70 | response.headers["Cache-Control"] = "max-age=0" 71 | return response 72 | 73 | app.include_router(api_router, prefix="/api") 74 | # app.include_router(ipc_router, prefix="/ipc") 75 | 76 | 77 | @app.get("/") 78 | async def index(): 79 | return FileResponse("./frontend/dist/index.html") 80 | 81 | 82 | app.mount("/", SPAStaticFiles(directory="frontend/dist", html=True), name="static") 83 | -------------------------------------------------------------------------------- /install.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | create_venv=true 5 | 6 | while [ -n "$1" ]; do 7 | case "$1" in 8 | --disable-venv) 9 | create_venv=false 10 | shift 11 | ;; 12 | *) 13 | shift 14 | ;; 15 | esac 16 | done 17 | 18 | if $create_venv; then 19 | echo "Creating python venv..." 20 | python3 -m venv venv 21 | source "$script_dir/venv/bin/activate" 22 | echo "active venv" 23 | fi 24 | 25 | echo "Installing torch & xformers..." 26 | 27 | cuda_version=$(nvidia-smi | grep -oiP 'CUDA Version: \K[\d\.]+') 28 | 29 | if [ -z "$cuda_version" ]; then 30 | cuda_version=$(nvcc --version | grep -oiP 'release \K[\d\.]+') 31 | fi 32 | cuda_major_version=$(echo "$cuda_version" | awk -F'.' '{print $1}') 33 | cuda_minor_version=$(echo "$cuda_version" | awk -F'.' 
'{print $2}') 34 | 35 | echo "CUDA Version: $cuda_version" 36 | 37 | 38 | if (( cuda_major_version >= 12 )); then 39 | echo "install torch 2.2.1+cu121" 40 | pip install torch==2.2.1+cu121 torchvision==0.17.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121 41 | pip install --no-deps xformers==0.0.25 --extra-index-url https://download.pytorch.org/whl/cu118 42 | elif (( cuda_major_version == 11 && cuda_minor_version >= 8 )); then 43 | echo "install torch 2.2.1+cu118" 44 | pip install torch==2.2.1+cu118 torchvision==0.17.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 45 | pip install --no-deps xformers==0.0.25+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 46 | elif (( cuda_major_version == 11 && cuda_minor_version >= 6 )); then 47 | echo "install torch 1.12.1+cu116" 48 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 49 | # for RTX3090+cu113/cu116 xformers, we need to install this version from source. You can also try xformers==0.0.18 50 | pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd 51 | pip install triton==2.0.0.dev20221202 52 | elif (( cuda_major_version == 11 && cuda_minor_version >= 2 )); then 53 | echo "install torch 1.12.1+cu113" 54 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu116 55 | pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd 56 | pip install triton==2.0.0.dev20221202 57 | else 58 | echo "Unsupported cuda version:$cuda_version" 59 | exit 1 60 | fi 61 | 62 | echo "Installing deps..." 63 | cd "$script_dir/sd-scripts" || exit 64 | 65 | pip install --upgrade -r requirements.txt 66 | 67 | cd "$script_dir" || exit 68 | 69 | pip install --upgrade -r requirements.txt 70 | 71 | echo "Install completed" 72 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import locale 3 | import os 4 | import platform 5 | import subprocess 6 | import sys 7 | 8 | from mikazuki.launch_utils import (base_dir_path, catch_exception, 9 | prepare_environment) 10 | from mikazuki.log import log 11 | 12 | parser = argparse.ArgumentParser(description="GUI for stable diffusion training") 13 | parser.add_argument("--host", type=str, default="127.0.0.1") 14 | parser.add_argument("--port", type=int, default=28000, help="Port to run the server on") 15 | parser.add_argument("--listen", action="store_true") 16 | parser.add_argument("--skip-prepare-environment", action="store_true") 17 | parser.add_argument("--disable-tensorboard", action="store_true") 18 | parser.add_argument("--disable-tageditor", action="store_true") 19 | parser.add_argument("--tensorboard-host", type=str, default="127.0.0.1", help="Port to run the tensorboard") 20 | parser.add_argument("--tensorboard-port", type=int, default=6006, help="Port to run the tensorboard") 21 | parser.add_argument("--localization", type=str) 22 | parser.add_argument("--dev", action="store_true") 23 | 24 | 25 | @catch_exception 26 | def run_tensorboard(): 27 | log.info("Starting tensorboard...") 28 | subprocess.Popen([sys.executable, "-m", "tensorboard.main", "--logdir", "logs", 29 | "--host", args.tensorboard_host, "--port", str(args.tensorboard_port)]) 30 | 31 | 32 | @catch_exception 33 | def run_tag_editor(): 34 
| log.info("Starting tageditor...") 35 | cmd = [ 36 | sys.executable, 37 | base_dir_path() / "mikazuki/dataset-tag-editor/scripts/launch.py", 38 | "--port", "28001", 39 | "--shadow-gradio-output", 40 | "--root-path", "/proxy/tageditor" 41 | ] 42 | if args.localization: 43 | cmd.extend(["--localization", args.localization]) 44 | else: 45 | l = locale.getdefaultlocale()[0] 46 | if l and l.startswith("zh"): 47 | cmd.extend(["--localization", "zh-Hans"]) 48 | subprocess.Popen(cmd) 49 | 50 | 51 | def launch(): 52 | log.info("Starting SD-Trainer Mikazuki GUI...") 53 | log.info(f"Base directory: {base_dir_path()}, Working directory: {os.getcwd()}") 54 | log.info(f'{platform.system()} Python {platform.python_version()} {sys.executable}') 55 | 56 | if not args.skip_prepare_environment: 57 | prepare_environment() 58 | 59 | os.environ["MIKAZUKI_HOST"] = args.host 60 | os.environ["MIKAZUKI_PORT"] = str(args.port) 61 | os.environ["MIKAZUKI_TENSORBOARD_HOST"] = args.tensorboard_host 62 | os.environ["MIKAZUKI_TENSORBOARD_PORT"] = str(args.tensorboard_port) 63 | os.environ["MIKAZUKI_DEV"] = "1" if args.dev else "0" 64 | 65 | if args.listen: 66 | args.host = "0.0.0.0" 67 | args.tensorboard_host = "0.0.0.0" 68 | 69 | if not args.disable_tageditor: 70 | run_tag_editor() 71 | 72 | if not args.disable_tensorboard: 73 | run_tensorboard() 74 | 75 | import uvicorn 76 | log.info(f"Server started at http://{args.host}:{args.port}") 77 | uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error") 78 | 79 | 80 | if __name__ == "__main__": 81 | args, _ = parser.parse_known_args() 82 | launch() 83 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "# Train data path | 设置训练用模型、图片\n", 14 | "pretrained_model = \"./sd-models/model.ckpt\" # base model path | 底模路径\n", 15 | "train_data_dir = \"./train/aki\" # train dataset path | 训练数据集路径\n", 16 | "\n", 17 | "# Train related params | 训练相关参数\n", 18 | "resolution = \"512,512\" # image resolution w,h. 
图片分辨率,宽,高。支持非正方形,但必须是 64 倍数。\n", 19 | "batch_size = 1 # batch size\n", 20 | "max_train_epoches = 10 # max train epoches | 最大训练 epoch\n", 21 | "save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次\n", 22 | "network_dim = 32 # network dim | 常用 4~128,不是越大越好\n", 23 | "network_alpha= 32 # network alpha | 常用与 network_dim 相同的值或者采用较小的值,如 network_dim的一半 防止下溢。默认值为 1,使用较小的 alpha 需要提升学习率。\n", 24 | "clip_skip = 2 # clip skip | 玄学 一般用 2\n", 25 | "train_unet_only = 0 # train U-Net only | 仅训练 U-Net,开启这个会牺牲效果大幅减少显存使用。6G显存可以开启\n", 26 | "train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器\n", 27 | "\n", 28 | "# Learning rate | 学习率\n", 29 | "lr = \"1e-4\"\n", 30 | "unet_lr = \"1e-4\"\n", 31 | "text_encoder_lr = \"1e-5\"\n", 32 | "lr_scheduler = \"cosine_with_restarts\" # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n", 33 | "\n", 34 | "# Output settings | 输出设置\n", 35 | "output_name = \"aki\" # output model name | 模型保存名称\n", 36 | "save_model_as = \"safetensors\" # model save ext | 模型保存格式 ckpt, pt, safetensors" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "pycharm": { 44 | "name": "#%%\n" 45 | } 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "!accelerate launch --num_cpu_threads_per_process=8 \"./sd-scripts/train_network.py\" \\\n", 50 | " --enable_bucket \\\n", 51 | " --pretrained_model_name_or_path=$pretrained_model \\\n", 52 | " --train_data_dir=$train_data_dir \\\n", 53 | " --output_dir=\"./output\" \\\n", 54 | " --logging_dir=\"./logs\" \\\n", 55 | " --resolution=$resolution \\\n", 56 | " --network_module=networks.lora \\\n", 57 | " --max_train_epochs=$max_train_epoches \\\n", 58 | " --learning_rate=$lr \\\n", 59 | " --unet_lr=$unet_lr \\\n", 60 | " --text_encoder_lr=$text_encoder_lr \\\n", 61 | " --network_dim=$network_dim \\\n", 62 | " --network_alpha=$network_alpha \\\n", 63 | " --output_name=$output_name \\\n", 64 | " --lr_scheduler=$lr_scheduler \\\n", 65 | " --train_batch_size=$batch_size \\\n", 66 | " --save_every_n_epochs=$save_every_n_epochs \\\n", 67 | " --mixed_precision=\"fp16\" \\\n", 68 | " --save_precision=\"fp16\" \\\n", 69 | " --seed=\"1337\" \\\n", 70 | " --cache_latents \\\n", 71 | " --clip_skip=$clip_skip \\\n", 72 | " --prior_loss_weight=1 \\\n", 73 | " --max_token_length=225 \\\n", 74 | " --caption_extension=\".txt\" \\\n", 75 | " --save_model_as=$save_model_as \\\n", 76 | " --xformers --shuffle_caption --use_8bit_adam" 77 | ] 78 | } 79 | ], 80 | "metadata": { 81 | "kernelspec": { 82 | "display_name": "Python 3", 83 | "language": "python", 84 | "name": "python3" 85 | }, 86 | "language_info": { 87 | "name": "python", 88 | "version": "3.10.7 (tags/v3.10.7:6cc6b13, Sep 5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]" 89 | }, 90 | "orig_nbformat": 4, 91 | "vscode": { 92 | "interpreter": { 93 | "hash": "675b13e958f0d0236d13cdfe08a1df3882cae564fa23a2e7e5eb1f2c6c632b02" 94 | } 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 2 99 | } -------------------------------------------------------------------------------- /mikazuki/app/proxy.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import httpx 5 | import starlette 6 | import websockets 7 | from fastapi import APIRouter, Request, WebSocket 8 | from httpx import ConnectError 9 | from starlette.background import BackgroundTask 10 | from starlette.requests import Request 11 | from starlette.responses import 
PlainTextResponse, StreamingResponse 12 | 13 | from mikazuki.log import log 14 | 15 | router = APIRouter() 16 | 17 | 18 | def reverse_proxy_maker(url_type: str, full_path: bool = False): 19 | if url_type == "tensorboard": 20 | host = os.environ.get("MIKAZUKI_TENSORBOARD_HOST", "127.0.0.1") 21 | port = os.environ.get("MIKAZUKI_TENSORBOARD_PORT", "6006") 22 | elif url_type == "tageditor": 23 | host = os.environ.get("MIKAZUKI_TAGEDITOR_HOST", "127.0.0.1") 24 | port = os.environ.get("MIKAZUKI_TAGEDITOR_PORT", "28001") 25 | 26 | client = httpx.AsyncClient(base_url=f"http://{host}:{port}/", proxies={}, trust_env=False, timeout=360) 27 | 28 | async def _reverse_proxy(request: Request): 29 | if full_path: 30 | url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8")) 31 | else: 32 | url = httpx.URL( 33 | path=request.path_params.get("path", ""), 34 | query=request.url.query.encode("utf-8") 35 | ) 36 | rp_req = client.build_request( 37 | request.method, url, 38 | headers=request.headers.raw, 39 | content=request.stream() if request.method != "GET" else None 40 | ) 41 | try: 42 | rp_resp = await client.send(rp_req, stream=True) 43 | except ConnectError: 44 | return PlainTextResponse( 45 | content="The requested service not started yet or service started fail. This may cost a while when you first time startup\n请求的服务尚未启动或启动失败。若是第一次启动,可能需要等待一段时间后再刷新网页。", 46 | status_code=502 47 | ) 48 | return StreamingResponse( 49 | rp_resp.aiter_raw(), 50 | status_code=rp_resp.status_code, 51 | headers=rp_resp.headers, 52 | background=BackgroundTask(rp_resp.aclose), 53 | ) 54 | 55 | return _reverse_proxy 56 | 57 | 58 | async def proxy_ws_forward(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol): 59 | while True: 60 | try: 61 | data = await ws_a.receive_text() 62 | await ws_b.send(data) 63 | except starlette.websockets.WebSocketDisconnect as e: 64 | break 65 | except Exception as e: 66 | log.error(f"Error when proxy data client -> backend: {e}") 67 | break 68 | 69 | 70 | async def proxy_ws_reverse(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol): 71 | while True: 72 | try: 73 | data = await ws_b.recv() 74 | await ws_a.send_text(data) 75 | except websockets.exceptions.ConnectionClosedOK as e: 76 | break 77 | except Exception as e: 78 | log.error(f"Error when proxy data backend -> client: {e}") 79 | break 80 | 81 | 82 | @router.websocket("/proxy/tageditor/queue/join") 83 | async def websocket_a(ws_a: WebSocket): 84 | # for temp use 85 | ws_b_uri = "ws://127.0.0.1:28001/queue/join" 86 | await ws_a.accept() 87 | async with websockets.connect(ws_b_uri, timeout=360) as ws_b_client: 88 | fwd_task = asyncio.create_task(proxy_ws_forward(ws_a, ws_b_client)) 89 | rev_task = asyncio.create_task(proxy_ws_reverse(ws_a, ws_b_client)) 90 | await asyncio.gather(fwd_task, rev_task) 91 | 92 | router.add_route("/proxy/tensorboard/{path:path}", reverse_proxy_maker("tensorboard"), ["GET", "POST"]) 93 | router.add_route("/font-roboto/{path:path}", reverse_proxy_maker("tensorboard", full_path=True), ["GET", "POST"]) 94 | router.add_route("/proxy/tageditor/{path:path}", reverse_proxy_maker("tageditor"), ["GET", "POST"]) 95 | -------------------------------------------------------------------------------- /mikazuki/tasks.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import os 4 | import threading 5 | import uuid 6 | from enum import Enum 7 | from typing import Dict, List 8 | from subprocess import Popen, PIPE, 
TimeoutExpired, CalledProcessError, CompletedProcess 9 | import psutil 10 | 11 | from mikazuki.log import log 12 | 13 | try: 14 | import msvcrt 15 | import _winapi 16 | _mswindows = True 17 | except ModuleNotFoundError: 18 | _mswindows = False 19 | 20 | 21 | def kill_proc_tree(pid, including_parent=True): 22 | parent = psutil.Process(pid) 23 | children = parent.children(recursive=True) 24 | for child in children: 25 | child.kill() 26 | gone, still_alive = psutil.wait_procs(children, timeout=5) 27 | if including_parent: 28 | parent.kill() 29 | parent.wait(5) 30 | 31 | 32 | class TaskStatus(Enum): 33 | CREATED = 0 34 | RUNNING = 1 35 | FINISHED = 2 36 | TERMINATED = 3 37 | 38 | 39 | class Task: 40 | def __init__(self, task_id, command, environ=None): 41 | self.task_id = task_id 42 | self.lock = threading.Lock() 43 | self.command = command 44 | self.status = TaskStatus.CREATED 45 | self.environ = environ or os.environ 46 | 47 | def communicate(self, input=None, timeout=None): 48 | try: 49 | stdout, stderr = self.process.communicate(input, timeout=timeout) 50 | except TimeoutExpired as exc: 51 | self.process.kill() 52 | if _mswindows: 53 | exc.stdout, exc.stderr = self.process.communicate() 54 | else: 55 | self.process.wait() 56 | raise 57 | except: 58 | self.process.kill() 59 | raise 60 | retcode = self.process.poll() 61 | self.status = TaskStatus.FINISHED 62 | return CompletedProcess(self.process.args, retcode, stdout, stderr) 63 | 64 | def wait(self): 65 | self.process.wait() 66 | self.status = TaskStatus.FINISHED 67 | 68 | def execute(self): 69 | self.status = TaskStatus.RUNNING 70 | self.process = subprocess.Popen(self.command, env=self.environ) 71 | 72 | def terminate(self): 73 | try: 74 | kill_proc_tree(self.process.pid, False) 75 | except Exception as e: 76 | log.error(f"Error when killing process: {e}") 77 | return 78 | finally: 79 | self.status = TaskStatus.TERMINATED 80 | 81 | 82 | class TaskManager: 83 | def __init__(self, max_concurrent=1) -> None: 84 | self.max_concurrent = max_concurrent 85 | self.tasks: Dict[Task] = {} 86 | 87 | def create_task(self, command: List[str], environ): 88 | running_tasks = [t for _, t in self.tasks.items() if t.status == TaskStatus.RUNNING] 89 | if len(running_tasks) >= self.max_concurrent: 90 | log.error( 91 | f"Unable to create a task because there are already {len(running_tasks)} tasks running, reaching the maximum concurrent limit. 
/ 无法创建任务,因为已经有 {len(running_tasks)} 个任务正在运行,已达到最大并发限制。") 92 | return None 93 | task_id = str(uuid.uuid4()) 94 | task = Task(task_id=task_id, command=command, environ=environ) 95 | self.tasks[task_id] = task 96 | # task.execute() # breaking change 97 | log.info(f"Task {task_id} created") 98 | return task 99 | 100 | def add_task(self, task_id: str, task: Task): 101 | self.tasks[task_id] = task 102 | 103 | def terminate_task(self, task_id: str): 104 | if task_id in self.tasks: 105 | task = self.tasks[task_id] 106 | task.terminate() 107 | 108 | def wait_for_process(self, task_id: str): 109 | if task_id in self.tasks: 110 | task: Task = self.tasks[task_id] 111 | task.wait() 112 | 113 | def dump(self) -> List[Dict]: 114 | return [ 115 | { 116 | "id": task.task_id, 117 | "status": task.status.name, 118 | } 119 | for task in self.tasks.values() 120 | ] 121 | 122 | 123 | tm = TaskManager() 124 | -------------------------------------------------------------------------------- /mikazuki/schema/lora-basic.ts: -------------------------------------------------------------------------------- 1 | Schema.intersect([ 2 | Schema.object({ 3 | pretrained_model_name_or_path: Schema.string().role('filepicker').default("./sd-models/model.safetensors").description("底模文件路径"), 4 | }).description("训练用模型"), 5 | 6 | Schema.object({ 7 | train_data_dir: Schema.string().role('filepicker', { type: "folder" }).default("./train/aki").description("训练数据集路径"), 8 | reg_data_dir: Schema.string().role('filepicker', { type: "folder" }).description("正则化数据集路径。默认留空,不使用正则化图像"), 9 | resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"), 10 | }).description("数据集设置"), 11 | 12 | Schema.object({ 13 | output_name: Schema.string().default("aki").description("模型保存名称"), 14 | output_dir: Schema.string().default("./output").role('filepicker', { type: "folder" }).description("模型保存文件夹"), 15 | save_every_n_epochs: Schema.number().default(2).description("每 N epoch(轮)自动保存一次模型"), 16 | }).description("保存设置"), 17 | 18 | Schema.object({ 19 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"), 20 | train_batch_size: Schema.number().min(1).default(1).description("批量大小"), 21 | }).description("训练相关参数"), 22 | 23 | Schema.intersect([ 24 | Schema.object({ 25 | unet_lr: Schema.string().default("1e-4").description("U-Net 学习率"), 26 | text_encoder_lr: Schema.string().default("1e-5").description("文本编码器学习率"), 27 | lr_scheduler: Schema.union([ 28 | "cosine", 29 | "cosine_with_restarts", 30 | "constant", 31 | "constant_with_warmup", 32 | ]).default("cosine_with_restarts").description("学习率调度器设置"), 33 | lr_warmup_steps: Schema.number().default(0).description('学习率预热步数'), 34 | }).description("学习率与优化器设置"), 35 | Schema.union([ 36 | Schema.object({ 37 | lr_scheduler: Schema.const('cosine_with_restarts'), 38 | lr_scheduler_num_cycles: Schema.number().default(1).description('重启次数'), 39 | }), 40 | Schema.object({}), 41 | ]), 42 | Schema.object({ 43 | optimizer_type: Schema.union([ 44 | "AdamW8bit", 45 | "Lion", 46 | ]).default("AdamW8bit").description("优化器设置"), 47 | }) 48 | ]), 49 | 50 | Schema.intersect([ 51 | Schema.object({ 52 | enable_preview: Schema.boolean().default(false).description('启用训练预览图'), 53 | }).description('训练预览图设置'), 54 | 55 | Schema.union([ 56 | Schema.object({ 57 | enable_preview: Schema.const(true).required(), 58 | sample_prompts: Schema.string().role('textarea').default(window.__MIKAZUKI__.SAMPLE_PROMPTS_DEFAULT).description(window.__MIKAZUKI__.SAMPLE_PROMPTS_DESCRIPTION), 59 | 
sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"), 60 | sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"), 61 | }), 62 | Schema.object({}), 63 | ]), 64 | ]), 65 | 66 | Schema.intersect([ 67 | Schema.object({ 68 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"), 69 | network_dim: Schema.number().min(8).max(256).step(8).default(32).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"), 70 | network_alpha: Schema.number().min(1).default(32).description( 71 | "常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率。" 72 | ), 73 | }).description("网络设置"), 74 | ]), 75 | 76 | Schema.object({ 77 | shuffle_caption: Schema.boolean().default(true).description("训练时随机打乱 tokens"), 78 | keep_tokens: Schema.number().min(0).max(255).step(1).default(0).description("在随机打乱 tokens 时,保留前 N 个不变"), 79 | }).description("caption 选项"), 80 | 81 | Schema.object({ 82 | mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("fp16").description("混合精度, RTX30系列以后也可以指定`bf16`"), 83 | no_half_vae: Schema.boolean().description("不使用半精度 VAE,当出现 NaN detected in latents 报错时使用"), 84 | xformers: Schema.boolean().default(true).description("启用 xformers"), 85 | cache_latents: Schema.boolean().default(true).description("缓存图像 latent, 缓存 VAE 输出以减少 VRAM 使用") 86 | }).description("速度优化选项"), 87 | ]); 88 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | SD-Trainer 4 | 5 | # SD-Trainer 6 | 7 | _✨ 享受 Stable Diffusion 训练! ✨_ 8 | 9 | 10 | 11 | 12 | 13 | GitHub 仓库星标 14 | 15 | 16 | GitHub 仓库分支 17 | 18 | 19 | 许可证 20 | 21 | 22 | 发布版本 23 | 24 | 25 | 26 | 27 | 下载 28 | · 29 | 文档 30 | · 31 | 中文README 32 |
33 | 34 | LoRA-scripts(又名 SD-Trainer) 35 | 36 | LoRA & Dreambooth 训练图形界面 & 脚本预设 & 一键训练环境,用于 [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts.git) 37 | 38 | ## ✨新特性: 训练 WebUI 39 | 40 | Stable Diffusion 训练工作台。一切集成于一个 WebUI 中。 41 | 42 | 按照下面的安装指南安装 GUI,然后运行 `run_gui.ps1`(Windows) 或 `run_gui.sh`(Linux) 来启动 GUI。 43 | 44 | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/d3fcf5ad-fb8f-4e1d-81f9-c903376c19c6) 45 | 46 | | Tensorboard | WD 1.4 标签器 | 标签编辑器 | 47 | | ------------ | ------------ | ------------ | 48 | | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/b2ac5c36-3edf-43a6-9719-cb00b757fc76) | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/9504fad1-7d77-46a7-a68f-91fbbdbc7407) | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/4597917b-caa8-4e90-b950-8b01738996f2) | 49 | 50 | 51 | # 使用方法 52 | 53 | ### 必要依赖 54 | 55 | Python 3.10 和 Git 56 | 57 | ### 克隆带子模块的仓库 58 | 59 | ```sh 60 | git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts 61 | ``` 62 | 63 | ## ✨ SD-Trainer GUI 64 | 65 | ### Windows 66 | 67 | #### 安装 68 | 69 | 运行 `install-cn.ps1` 将自动为您创建虚拟环境并安装必要的依赖。 70 | 71 | #### 训练 72 | 73 | 运行 `run_gui.ps1`,程序将自动打开 [http://127.0.0.1:28000](http://127.0.0.1:28000) 74 | 75 | ### Linux 76 | 77 | #### 安装 78 | 79 | 运行 `install.bash` 将创建虚拟环境并安装必要的依赖。 80 | 81 | #### 训练 82 | 83 | 运行 `bash run_gui.bash`,程序将自动打开 [http://127.0.0.1:28000](http://127.0.0.1:28000) 84 | 85 | ## 通过手动运行脚本的传统训练方式 86 | 87 | ### Windows 88 | 89 | #### 安装 90 | 91 | 运行 `install.ps1` 将自动为您创建虚拟环境并安装必要的依赖。 92 | 93 | #### 训练 94 | 95 | 编辑 `train.ps1`,然后运行它。 96 | 97 | ### Linux 98 | 99 | #### 安装 100 | 101 | 运行 `install.bash` 将创建虚拟环境并安装必要的依赖。 102 | 103 | #### 训练 104 | 105 | 训练 106 | 107 | 脚本 `train.sh` **不会** 为您激活虚拟环境。您应该先激活虚拟环境。 108 | 109 | ```sh 110 | source venv/bin/activate 111 | ``` 112 | 113 | 编辑 `train.sh`,然后运行它。 114 | 115 | #### TensorBoard 116 | 117 | 运行 `tensorboard.ps1` 将在 http://localhost:6006/ 启动 TensorBoard 118 | 119 | ## 程序参数 120 | 121 | | 参数名称 | 类型 | 默认值 | 描述 | 122 | |------------------------------|-------|--------------|-------------------------------------------------| 123 | | `--host` | str | "127.0.0.1" | 服务器的主机名 | 124 | | `--port` | int | 28000 | 运行服务器的端口 | 125 | | `--listen` | bool | false | 启用服务器的监听模式 | 126 | | `--skip-prepare-environment` | bool | false | 跳过环境准备步骤 | 127 | | `--disable-tensorboard` | bool | false | 禁用 TensorBoard | 128 | | `--disable-tageditor` | bool | false | 禁用标签编辑器 | 129 | | `--tensorboard-host` | str | "127.0.0.1" | 运行 TensorBoard 的主机 | 130 | | `--tensorboard-port` | int | 6006 | 运行 TensorBoard 的端口 | 131 | | `--localization` | str | | 界面的本地化设置 | 132 | | `--dev` | bool | false | 开发者模式,用于禁用某些检查 | 133 | -------------------------------------------------------------------------------- /mikazuki/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import glob 3 | import os 4 | import re 5 | import shutil 6 | import sys 7 | 8 | from mikazuki.log import log 9 | 10 | python_bin = sys.executable 11 | 12 | 13 | class ModelType(Enum): 14 | UNKNOWN = -1 15 | SD15 = 1 16 | SD2 = 2 17 | SDXL = 3 18 | SD3 = 4 19 | LoRA = 5 20 | 21 | 22 | def is_promopt_like(s): 23 | for p in ["--n", "--s", "--l", "--d"]: 24 | if p in s: 25 | return True 26 | return False 27 | 28 | 29 | def validate_model(model_name: str, training_type: str = "sd-lora"): 30 | if os.path.exists(model_name): 31 | try: 32 | with open(model_name, "rb") as f: 33 | 
content = f.read(1024 * 200) 34 | 35 | model_type = match_model_type(content) 36 | 37 | if model_type == ModelType.UNKNOWN: 38 | log.error(f"Can't match model type from {model_name}") 39 | 40 | if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SDXL]: 41 | return False, "Pretrained model is not a Stable Diffusion checkpoint / 校验失败:底模不是 Stable Diffusion 模型" 42 | elif model_type == ModelType.SD3: 43 | return False, "Pretrained model not supported yet / 校验失败:SD3 模型暂不支持" 44 | elif model_type == ModelType.SDXL and training_type == "sd-lora": 45 | return False, "Pretrained model is SDXL, but you are training with LoRA / 校验失败:你选择的是 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。" 46 | 47 | except Exception as e: 48 | log.warn(f"model file {model_name} can't open: {e}") 49 | return True, "" 50 | 51 | return True, "ok" 52 | 53 | # huggerface model repo 54 | if model_name.count("/") <= 1: 55 | return True, "ok" 56 | 57 | return False, "model not found" 58 | 59 | 60 | def match_model_type(sig_content: bytes): 61 | if b"model.diffusion_model.x_embedder.proj.weight" in sig_content: 62 | return ModelType.SD3 63 | 64 | if b"conditioner.embedders.1.model.transformer.resblocks" in sig_content: 65 | return ModelType.SDXL 66 | 67 | if b"model.diffusion_model" in sig_content or b"cond_stage_model.transformer.text_model" in sig_content: 68 | return ModelType.SD15 69 | 70 | if b"lora_unet" in sig_content or b"lora_te" in sig_content: 71 | return ModelType.LoRA 72 | 73 | return ModelType.UNKNOWN 74 | 75 | 76 | def validate_data_dir(path): 77 | if not os.path.exists(path): 78 | log.error(f"Data dir {path} not exists, check your params") 79 | return False 80 | 81 | dir_content = os.listdir(path) 82 | 83 | if len(dir_content) == 0: 84 | log.error(f"Data dir {path} is empty, check your params") 85 | 86 | subdirs = [f for f in dir_content if os.path.isdir(os.path.join(path, f))] 87 | 88 | if len(subdirs) == 0: 89 | log.warn(f"No subdir found in data dir") 90 | 91 | ok_dir = [d for d in subdirs if re.findall(r"^\d+_.+", d)] 92 | 93 | if len(ok_dir) == 0: 94 | log.warning(f"No leagal dataset found. 
Try find avaliable images") 95 | imgs = get_total_images(path, False) 96 | captions = glob.glob(path + '/*.txt') 97 | log.info(f"{len(imgs)} images found, {len(captions)} captions found") 98 | if len(imgs) > 0: 99 | num_repeat = suggest_num_repeat(len(imgs)) 100 | dataset_path = os.path.join(path, f"{num_repeat}_zkz") 101 | os.makedirs(dataset_path) 102 | for i in imgs: 103 | shutil.move(i, dataset_path) 104 | if len(captions) > 0: 105 | for c in captions: 106 | shutil.move(c, dataset_path) 107 | log.info(f"Auto dataset created {dataset_path}") 108 | else: 109 | log.error("No image found in data dir") 110 | return False 111 | 112 | return True 113 | 114 | 115 | def suggest_num_repeat(img_count): 116 | if img_count <= 10: 117 | return 7 118 | elif 10 < img_count <= 50: 119 | return 5 120 | elif 50 < img_count <= 100: 121 | return 3 122 | 123 | return 1 124 | 125 | 126 | def check_training_params(data): 127 | potential_path = [ 128 | "train_data_dir", "reg_data_dir", "output_dir" 129 | ] 130 | file_paths = [ 131 | "sample_prompts" 132 | ] 133 | for p in potential_path: 134 | if p in data and not os.path.exists(data[p]): 135 | return False 136 | 137 | for f in file_paths: 138 | if f in data and not os.path.exists(data[f]): 139 | return False 140 | return True 141 | 142 | 143 | def get_total_images(path, recursive=True): 144 | if recursive: 145 | image_files = glob.glob(path + '/**/*.jpg', recursive=True) 146 | image_files += glob.glob(path + '/**/*.jpeg', recursive=True) 147 | image_files += glob.glob(path + '/**/*.png', recursive=True) 148 | else: 149 | image_files = glob.glob(path + '/*.jpg') 150 | image_files += glob.glob(path + '/*.jpeg') 151 | image_files += glob.glob(path + '/*.png') 152 | return image_files 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | SD-Trainer 4 | 5 | # SD-Trainer 6 | 7 | _✨ Enjoy Stable Diffusion Training! ✨_ 8 | 9 | 10 | 11 | 12 | 13 | GitHub Repo stars 14 | 15 | 16 | GitHub forks 17 | 18 | 19 | license 20 | 21 | 22 | release 23 | 24 | 25 | 26 | 27 | Download 28 | · 29 | Documents 30 | · 31 | 中文README 32 |
33 | 34 | LoRA-scripts (a.k.a. SD-Trainer) 35 | 36 | LoRA & Dreambooth training GUI & script presets & one-click training environment for [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts.git) 37 | 38 | ## ✨NEW: Train WebUI 39 | 40 | The **REAL** Stable Diffusion Training Studio. Everything in one WebUI. 41 | 42 | Follow the installation guide below to install the GUI, then run `run_gui.ps1` (Windows) or `run_gui.sh` (Linux) to start it. 43 | 44 | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/d3fcf5ad-fb8f-4e1d-81f9-c903376c19c6) 45 | 46 | | Tensorboard | WD 1.4 Tagger | Tag Editor | 47 | | ------------ | ------------ | ------------ | 48 | | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/b2ac5c36-3edf-43a6-9719-cb00b757fc76) | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/9504fad1-7d77-46a7-a68f-91fbbdbc7407) | ![image](https://github.com/Akegarasu/lora-scripts/assets/36563862/4597917b-caa8-4e90-b950-8b01738996f2) | 49 | 50 | 51 | # Usage 52 | 53 | ### Required Dependencies 54 | 55 | Python 3.10 and Git 56 | 57 | ### Clone repo with submodules 58 | 59 | ```sh 60 | git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts 61 | ``` 62 | 63 | ## ✨ SD-Trainer GUI 64 | 65 | ### Windows 66 | 67 | #### Installation 68 | 69 | Running `install.ps1` will automatically create a venv for you and install the necessary deps. 70 | If you are in mainland China, please use `install-cn.ps1` instead. 71 | 72 | #### Train 73 | 74 | Run `run_gui.ps1`; the program will open [http://127.0.0.1:28000](http://127.0.0.1:28000) automatically. 75 | 76 | ### Linux 77 | 78 | #### Installation 79 | 80 | Running `install.bash` will create a venv and install the necessary deps. 81 | 82 | #### Train 83 | 84 | Run `bash run_gui.sh`; the program will open [http://127.0.0.1:28000](http://127.0.0.1:28000) automatically. 85 | 86 | ## Legacy training by running the scripts manually 87 | 88 | ### Windows 89 | 90 | #### Installation 91 | 92 | Running `install.ps1` will automatically create a venv for you and install the necessary deps. 93 | 94 | #### Train 95 | 96 | Edit `train.ps1`, then run it. 97 | 98 | ### Linux 99 | 100 | #### Installation 101 | 102 | Running `install.bash` will create a venv and install the necessary deps. 103 | 104 | #### Train 105 | 106 | The training script `train.sh` **will not** activate the venv for you. You should activate the venv first. 107 | 108 | ```sh 109 | source venv/bin/activate 110 | ``` 111 | 112 | Edit `train.sh`, then run it.
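For example, a minimal Linux run might look like the sketch below. The model and dataset paths are placeholders; the variable names (`pretrained_model`, `train_data_dir`, `output_name`) are the ones defined near the top of `train.sh`:

```sh
# Minimal sketch of a legacy training run (paths below are placeholders).
source venv/bin/activate

# In train.sh, point these variables at your own files before running:
#   pretrained_model="./sd-models/model.ckpt"   # base model checkpoint
#   train_data_dir="./train/aki"                # dataset folders named like "<repeats>_<concept>", e.g. 6_aki
#   output_name="aki"                           # name of the saved LoRA

bash train.sh
```

With the script's defaults, saved models land in `./output` and TensorBoard logs in `./logs`.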
113 | 114 | #### TensorBoard 115 | 116 | Run `tensorboard.ps1` will start TensorBoard at http://localhost:6006/ 117 | 118 | ## Program arguments 119 | 120 | | Parameter Name | Type | Default Value | Description | 121 | |-------------------------------|-------|---------------|--------------------------------------------------| 122 | | `--host` | str | "127.0.0.1" | Hostname for the server | 123 | | `--port` | int | 28000 | Port to run the server | 124 | | `--listen` | bool | false | Enable listening mode for the server | 125 | | `--skip-prepare-environment` | bool | false | Skip the environment preparation step | 126 | | `--disable-tensorboard` | bool | false | Disable TensorBoard | 127 | | `--disable-tageditor` | bool | false | Disable tag editor | 128 | | `--tensorboard-host` | str | "127.0.0.1" | Host to run TensorBoard | 129 | | `--tensorboard-port` | int | 6006 | Port to run TensorBoard | 130 | | `--localization` | str | | Localization settings for the interface | 131 | | `--dev` | bool | false | Developer mode to disale some checks | 132 | -------------------------------------------------------------------------------- /mikazuki/global.d.ts: -------------------------------------------------------------------------------- 1 | interface Window { 2 | __MIKAZUKI__: any; 3 | } 4 | 5 | type Dict = { 6 | [key in K]: T; 7 | }; 8 | 9 | declare const kSchema: unique symbol; 10 | 11 | declare namespace Schemastery { 12 | type From = X extends string | number | boolean ? SchemaI : X extends SchemaI ? X : X extends typeof String ? SchemaI : X extends typeof Number ? SchemaI : X extends typeof Boolean ? SchemaI : X extends typeof Function ? SchemaI any> : X extends Constructor ? SchemaI : never; 13 | type TypeS1 = X extends SchemaI ? S : never; 14 | type Inverse = X extends SchemaI ? (arg: Y) => void : never; 15 | type TypeS = TypeS1>; 16 | type TypeT = ReturnType>; 17 | type Resolve = (data: any, schema: SchemaI, options?: Options, strict?: boolean) => [any, any?]; 18 | type IntersectS = From extends SchemaI ? S : never; 19 | type IntersectT = Inverse> extends ((arg: infer T) => void) ? T : never; 20 | type TupleS = X extends readonly [infer L, ...infer R] ? [TypeS?, ...TupleS] : any[]; 21 | type TupleT = X extends readonly [infer L, ...infer R] ? 
[TypeT?, ...TupleT] : any[]; 22 | type ObjectS = { 23 | [K in keyof X]?: TypeS | null; 24 | } & Dict; 25 | type ObjectT = { 26 | [K in keyof X]: TypeT; 27 | } & Dict; 28 | type Constructor = new (...args: any[]) => T; 29 | interface Static { 30 | (options: Partial>): SchemaI; 31 | new (options: Partial>): SchemaI; 32 | prototype: SchemaI; 33 | resolve: Resolve; 34 | from(source?: X): From; 35 | extend(type: string, resolve: Resolve): void; 36 | any(): SchemaI; 37 | never(): SchemaI; 38 | const(value: T): SchemaI; 39 | string(): SchemaI; 40 | number(): SchemaI; 41 | natural(): SchemaI; 42 | percent(): SchemaI; 43 | boolean(): SchemaI; 44 | date(): SchemaI; 45 | bitset(bits: Partial>): SchemaI; 46 | function(): SchemaI any>; 47 | is(constructor: Constructor): SchemaI; 48 | array(inner: X): SchemaI[], TypeT[]>; 49 | dict = SchemaI>(inner: X, sKey?: Y): SchemaI, TypeS>, Dict, TypeT>>; 50 | tuple(list: X): SchemaI, TupleT>; 51 | object(dict: X): SchemaI, ObjectT>; 52 | union(list: readonly X[]): SchemaI, TypeT>; 53 | intersect(list: readonly X[]): SchemaI, IntersectT>; 54 | transform(inner: X, callback: (value: TypeS) => T, preserve?: boolean): SchemaI, T>; 55 | } 56 | interface Options { 57 | autofix?: boolean; 58 | } 59 | interface Meta { 60 | default?: T extends {} ? Partial : T; 61 | required?: boolean; 62 | disabled?: boolean; 63 | collapse?: boolean; 64 | badges?: { 65 | text: string; 66 | type: string; 67 | }[]; 68 | hidden?: boolean; 69 | loose?: boolean; 70 | role?: string; 71 | extra?: any; 72 | link?: string; 73 | description?: string | Dict; 74 | comment?: string; 75 | pattern?: { 76 | source: string; 77 | flags?: string; 78 | }; 79 | max?: number; 80 | min?: number; 81 | step?: number; 82 | } 83 | 84 | interface Schemastery { 85 | (data?: S | null, options?: Schemastery.Options): T; 86 | new(data?: S | null, options?: Schemastery.Options): T; 87 | [kSchema]: true; 88 | uid: number; 89 | meta: Schemastery.Meta; 90 | type: string; 91 | sKey?: SchemaI; 92 | inner?: SchemaI; 93 | list?: SchemaI[]; 94 | dict?: Dict; 95 | bits?: Dict; 96 | callback?: Function; 97 | value?: T; 98 | refs?: Dict; 99 | preserve?: boolean; 100 | toString(inline?: boolean): string; 101 | toJSON(): SchemaI; 102 | required(value?: boolean): SchemaI; 103 | hidden(value?: boolean): SchemaI; 104 | loose(value?: boolean): SchemaI; 105 | role(text: string, extra?: any): SchemaI; 106 | link(link: string): SchemaI; 107 | default(value: T): SchemaI; 108 | comment(text: string): SchemaI; 109 | description(text: string): SchemaI; 110 | disabled(value?: boolean): SchemaI; 111 | collapse(value?: boolean): SchemaI; 112 | deprecated(): SchemaI; 113 | experimental(): SchemaI; 114 | pattern(regexp: RegExp): SchemaI; 115 | max(value: number): SchemaI; 116 | min(value: number): SchemaI; 117 | step(value: number): SchemaI; 118 | set(key: string, value: SchemaI): SchemaI; 119 | push(value: SchemaI): SchemaI; 120 | simplify(value?: any): any; 121 | i18n(messages: Dict): SchemaI; 122 | extra(key: K, value: Schemastery.Meta[K]): SchemaI; 123 | } 124 | 125 | } 126 | 127 | type SchemaI = Schemastery.Schemastery; 128 | 129 | declare const Schema: Schemastery.Static -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # LoRA train script by @Akegarasu 3 | 4 | # Train data path | 设置训练用模型、图片 5 | pretrained_model="./sd-models/model.ckpt" # base model path | 底模路径 6 | model_type="sd1.5" # option: sd1.5 
sd2.0 sdxl | 可选 sd1.5 sd2.0 sdxl。SD2.0模型 2.0模型下 clip_skip 默认无效 7 | parameterization=0 # parameterization | 参数化 本参数需要在 model_type 为 sd2.0 时才可启用 8 | 9 | train_data_dir="./train/aki" # train dataset path | 训练数据集路径 10 | reg_data_dir="" # directory for regularization images | 正则化数据集路径,默认不使用正则化图像。 11 | 12 | # Network settings | 网络设置 13 | network_module="networks.lora" # 在这里将会设置训练的网络种类,默认为 networks.lora 也就是 LoRA 训练。如果你想训练 LyCORIS(LoCon、LoHa) 等,则修改这个值为 lycoris.kohya 14 | network_weights="" # pretrained weights for LoRA network | 若需要从已有的 LoRA 模型上继续训练,请填写 LoRA 模型路径。 15 | network_dim=32 # network dim | 常用 4~128,不是越大越好 16 | network_alpha=32 # network alpha | 常用与 network_dim 相同的值或者采用较小的值,如 network_dim的一半 防止下溢。默认值为 1,使用较小的 alpha 需要提升学习率。 17 | 18 | # Train related params | 训练相关参数 19 | resolution="512,512" # image resolution w,h. 图片分辨率,宽,高。支持非正方形,但必须是 64 倍数。 20 | batch_size=1 # batch size 21 | max_train_epoches=10 # max train epoches | 最大训练 epoch 22 | save_every_n_epochs=2 # save every n epochs | 每 N 个 epoch 保存一次 23 | 24 | train_unet_only=0 # train U-Net only | 仅训练 U-Net,开启这个会牺牲效果大幅减少显存使用。6G显存可以开启 25 | train_text_encoder_only=0 # train Text Encoder only | 仅训练 文本编码器 26 | stop_text_encoder_training=0 # stop text encoder training | 在第N步时停止训练文本编码器 27 | 28 | noise_offset="0" # noise offset | 在训练中添加噪声偏移来改良生成非常暗或者非常亮的图像,如果启用,推荐参数为0.1 29 | keep_tokens=0 # keep heading N tokens when shuffling caption tokens | 在随机打乱 tokens 时,保留前 N 个不变。 30 | min_snr_gamma=0 # minimum signal-to-noise ratio (SNR) value for gamma-ray | 伽马射线事件的最小信噪比(SNR)值 默认为 0 31 | 32 | # Learning rate | 学习率 33 | lr="1e-4" # learning rate | 学习率,在分别设置下方 U-Net 和 文本编码器 的学习率时,该参数失效 34 | unet_lr="1e-4" # U-Net learning rate | U-Net 学习率 35 | text_encoder_lr="1e-5" # Text Encoder learning rate | 文本编码器 学习率 36 | lr_scheduler="cosine_with_restarts" # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "adafactor" 37 | lr_warmup_steps=0 # warmup steps | 学习率预热步数,lr_scheduler 为 constant 或 adafactor 时该值需要设为0。 38 | lr_restart_cycles=1 # cosine_with_restarts restart cycles | 余弦退火重启次数,仅在 lr_scheduler 为 cosine_with_restarts 时起效。 39 | 40 | # Optimizer settings | 优化器设置 41 | optimizer_type="AdamW8bit" # Optimizer type | 优化器类型 默认为 AdamW8bit,可选:AdamW AdamW8bit Lion Lion8bit SGDNesterov SGDNesterov8bit DAdaptation AdaFactor prodigy 42 | 43 | # Output settings | 输出设置 44 | output_name="aki" # output model name | 模型保存名称 45 | save_model_as="safetensors" # model save ext | 模型保存格式 ckpt, pt, safetensors 46 | 47 | # Resume training state | 恢复训练设置 48 | save_state=0 # save state | 保存训练状态 名称类似于 -??????-state ?????? 
表示 epoch 数 49 | resume="" # resume from state | 从某个状态文件夹中恢复训练 需配合上方参数同时使用 由于规范文件限制 epoch 数和全局步数不会保存 即使恢复时它们也从 1 开始 与 network_weights 的具体实现操作并不一致 50 | 51 | # 其他设置 52 | min_bucket_reso=256 # arb min resolution | arb 最小分辨率 53 | max_bucket_reso=1024 # arb max resolution | arb 最大分辨率 54 | persistent_data_loader_workers=1 # persistent dataloader workers | 保留加载训练集的worker,减少每个 epoch 之间的停顿 55 | clip_skip=2 # clip skip | 玄学 一般用 2 56 | multi_gpu=0 # multi gpu | 多显卡训练 该参数仅限在显卡数 >= 2 使用 57 | lowram=0 # lowram mode | 低内存模式 该模式下会将 U-net 文本编码器 VAE 转移到 GPU 显存中 启用该模式可能会对显存有一定影响 58 | 59 | # LyCORIS 训练设置 60 | algo="lora" # LyCORIS network algo | LyCORIS 网络算法 可选 lora、loha、lokr、ia3、dylora。lora即为locon 61 | conv_dim=4 # conv dim | 类似于 network_dim,推荐为 4 62 | conv_alpha=4 # conv alpha | 类似于 network_alpha,可以采用与 conv_dim 一致或者更小的值 63 | dropout="0" # dropout | dropout 概率, 0 为不使用 dropout, 越大则 dropout 越多,推荐 0~0.5, LoHa/LoKr/(IA)^3暂时不支持 64 | 65 | # Remote logging | 远程记录设置 66 | use_wandb=0 # use_wandb | 启用wandb远程记录功能 67 | wandb_api_key="" # wandb_api_key | API,通过 https://wandb.ai/authorize 获取 68 | log_tracker_name="" # log_tracker_name | wandb项目名称,留空则为"network_train" 69 | 70 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 ===================== 71 | export HF_HOME="huggingface" 72 | export TF_CPP_MIN_LOG_LEVEL=3 73 | 74 | extArgs=() 75 | launchArgs=() 76 | 77 | trainer_file="./sd-scripts/train_network.py" 78 | 79 | if [ $model_type == "sd1.5" ]; then 80 | ext_args+=("--clip_skip=$clip_skip") 81 | elif [ $model_type == "sd2.0" ]; then 82 | ext_args+=("--v2") 83 | elif [ $model_type == "sdxl" ]; then 84 | trainer_file="./sd-scripts/sdxl_train_network.py" 85 | fi 86 | 87 | if [[ $multi_gpu == 1 ]]; then 88 | launchArgs+=("--multi_gpu") 89 | launchArgs+=("--num_processes=2") 90 | fi 91 | 92 | if [[ $lowram ]]; then extArgs+=("--lowram"); fi 93 | 94 | if [[ $parameterization == 1 ]]; then extArgs+=("--v_parameterization"); fi 95 | 96 | if [[ $train_unet_only == 1 ]]; then extArgs+=("--network_train_unet_only"); fi 97 | 98 | if [[ $train_text_encoder_only == 1 ]]; then extArgs+=("--network_train_text_encoder_only"); fi 99 | 100 | if [[ $network_weights ]]; then extArgs+=("--network_weights $network_weights"); fi 101 | 102 | if [[ $reg_data_dir ]]; then extArgs+=("--reg_data_dir $reg_data_dir"); fi 103 | 104 | if [[ $optimizer_type ]]; then extArgs+=("--optimizer_type $optimizer_type"); fi 105 | 106 | if [[ $optimizer_type == "DAdaptation" ]]; then extArgs+=("--optimizer_args decouple=True"); fi 107 | 108 | if [[ $save_state == 1 ]]; then extArgs+=("--save_state"); fi 109 | 110 | if [[ $resume ]]; then extArgs+=("--resume $resume"); fi 111 | 112 | if [[ $persistent_data_loader_workers == 1 ]]; then extArgs+=("--persistent_data_loader_workers"); fi 113 | 114 | if [[ $network_module == "lycoris.kohya" ]]; then 115 | extArgs+=("--network_args conv_dim=$conv_dim conv_alpha=$conv_alpha algo=$algo dropout=$dropout") 116 | fi 117 | 118 | if [[ $stop_text_encoder_training -ne 0 ]]; then extArgs+=("--stop_text_encoder_training $stop_text_encoder_training"); fi 119 | 120 | if [[ $noise_offset != "0" ]]; then extArgs+=("--noise_offset $noise_offset"); fi 121 | 122 | if [[ $min_snr_gamma -ne 0 ]]; then extArgs+=("--min_snr_gamma $min_snr_gamma"); fi 123 | 124 | if [[ $use_wandb == 1 ]]; then 125 | extArgs+=("--log_with=all") 126 | if [[ $wandb_api_key ]]; then extArgs+=("--wandb_api_key $wandb_api_key"); fi 127 | if [[ $log_tracker_name ]]; then extArgs+=("--log_tracker_name $log_tracker_name"); fi 128 | else 129 | 
extArgs+=("--log_with=tensorboard") 130 | fi 131 | 132 | python -m accelerate.commands.launch ${launchArgs[@]} --num_cpu_threads_per_process=4 $trainer_file \ 133 | --enable_bucket \ 134 | --pretrained_model_name_or_path=$pretrained_model \ 135 | --train_data_dir=$train_data_dir \ 136 | --output_dir="./output" \ 137 | --logging_dir="./logs" \ 138 | --log_prefix=$output_name \ 139 | --resolution=$resolution \ 140 | --network_module=$network_module \ 141 | --max_train_epochs=$max_train_epoches \ 142 | --learning_rate=$lr \ 143 | --unet_lr=$unet_lr \ 144 | --text_encoder_lr=$text_encoder_lr \ 145 | --lr_scheduler=$lr_scheduler \ 146 | --lr_warmup_steps=$lr_warmup_steps \ 147 | --lr_scheduler_num_cycles=$lr_restart_cycles \ 148 | --network_dim=$network_dim \ 149 | --network_alpha=$network_alpha \ 150 | --output_name=$output_name \ 151 | --train_batch_size=$batch_size \ 152 | --save_every_n_epochs=$save_every_n_epochs \ 153 | --mixed_precision="fp16" \ 154 | --save_precision="fp16" \ 155 | --seed="1337" \ 156 | --cache_latents \ 157 | --prior_loss_weight=1 \ 158 | --max_token_length=225 \ 159 | --caption_extension=".txt" \ 160 | --save_model_as=$save_model_as \ 161 | --min_bucket_reso=$min_bucket_reso \ 162 | --max_bucket_reso=$max_bucket_reso \ 163 | --keep_tokens=$keep_tokens \ 164 | --xformers --shuffle_caption ${extArgs[@]} 165 | -------------------------------------------------------------------------------- /mikazuki/app/api.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | from datetime import datetime 5 | from pathlib import Path 6 | 7 | import toml 8 | import hashlib 9 | from fastapi import APIRouter, BackgroundTasks, Request 10 | from starlette.requests import Request 11 | 12 | import mikazuki.process as process 13 | from mikazuki import launch_utils 14 | from mikazuki.app.models import (APIResponse, APIResponseFail, 15 | APIResponseSuccess, TaggerInterrogateRequest) 16 | from mikazuki.log import log 17 | from mikazuki.tagger.interrogator import (available_interrogators, 18 | on_interrogate) 19 | from mikazuki.tasks import tm 20 | from mikazuki.utils import train_utils 21 | from mikazuki.utils.devices import printable_devices 22 | from mikazuki.utils.tk_window import (open_directory_selector, 23 | open_file_selector) 24 | 25 | router = APIRouter() 26 | 27 | avaliable_scripts = [ 28 | "networks/extract_lora_from_models.py", 29 | "networks/extract_lora_from_dylora.py", 30 | "networks/merge_lora.py", 31 | "tools/merge_models.py", 32 | ] 33 | 34 | avaliable_schemas = [] 35 | 36 | trainer_mapping = { 37 | "sd-lora": "./sd-scripts/train_network.py", 38 | "sdxl-lora": "./sd-scripts/sdxl_train_network.py", 39 | "sd-dreambooth": "./sd-scripts/train_db.py", 40 | "sdxl-finetune": "./sd-scripts/sdxl_train.py", 41 | } 42 | 43 | 44 | async def load_schemas(): 45 | avaliable_schemas.clear() 46 | 47 | schema_dir = os.path.join(os.getcwd(), "mikazuki", "schema") 48 | schemas = os.listdir(schema_dir) 49 | 50 | def lambda_hash(x): 51 | return hashlib.md5(x.encode()).hexdigest() 52 | 53 | for schema_name in schemas: 54 | with open(os.path.join(schema_dir, schema_name), encoding="utf-8") as f: 55 | content = f.read() 56 | avaliable_schemas.append({ 57 | "name": schema_name.strip(".ts"), 58 | "schema": content, 59 | "hash": lambda_hash(content) 60 | }) 61 | 62 | 63 | @router.post("/run") 64 | async def create_toml_file(request: Request): 65 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 66 | toml_file = 
os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}.toml") 67 | json_data = await request.body() 68 | config: dict = json.loads(json_data.decode("utf-8")) 69 | 70 | gpu_ids = config.pop("gpu_ids", None) 71 | 72 | suggest_cpu_threads = 8 if len(train_utils.get_total_images(config["train_data_dir"])) > 200 else 2 73 | model_train_type = config.pop("model_train_type", "sd-lora") 74 | trainer_file = trainer_mapping[model_train_type] 75 | 76 | if model_train_type != "sdxl-finetune": 77 | if not train_utils.validate_data_dir(config["train_data_dir"]): 78 | return APIResponseFail(message="训练数据集路径不存在或没有图片,请检查目录。") 79 | 80 | validated, message = train_utils.validate_model(config["pretrained_model_name_or_path"], model_train_type) 81 | if not validated: 82 | return APIResponseFail(message=message) 83 | 84 | sample_prompts = config.get("sample_prompts", None) 85 | if sample_prompts is not None and not os.path.exists(sample_prompts) and train_utils.is_promopt_like(sample_prompts): 86 | sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt") 87 | with open(sample_prompts_file, "w", encoding="utf-8") as f: 88 | f.write(sample_prompts) 89 | config["sample_prompts"] = sample_prompts_file 90 | log.info(f"Wrote promopts to file {sample_prompts_file}") 91 | 92 | with open(toml_file, "w", encoding="utf-8") as f: 93 | f.write(toml.dumps(config)) 94 | 95 | result = process.run_train(toml_file, trainer_file, gpu_ids, suggest_cpu_threads) 96 | 97 | return result 98 | 99 | 100 | @router.post("/run_script") 101 | async def run_script(request: Request, background_tasks: BackgroundTasks): 102 | paras = await request.body() 103 | j = json.loads(paras.decode("utf-8")) 104 | script_name = j["script_name"] 105 | if script_name not in avaliable_scripts: 106 | return APIResponseFail(message="Script not found") 107 | del j["script_name"] 108 | result = [] 109 | for k, v in j.items(): 110 | result.append(f"--{k}") 111 | if not isinstance(v, bool): 112 | value = str(v) 113 | if " " in value: 114 | value = f'"{v}"' 115 | result.append(value) 116 | script_args = " ".join(result) 117 | script_path = Path(os.getcwd()) / "sd-scripts" / script_name 118 | cmd = f"{launch_utils.python_bin} {script_path} {script_args}" 119 | background_tasks.add_task(launch_utils.run, cmd) 120 | return APIResponseSuccess() 121 | 122 | 123 | @router.post("/interrogate") 124 | async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: BackgroundTasks): 125 | interrogator = available_interrogators.get(req.interrogator_model, available_interrogators["wd14-convnextv2-v2"]) 126 | background_tasks.add_task( 127 | on_interrogate, 128 | image=None, 129 | batch_input_glob=req.path, 130 | batch_input_recursive=req.batch_input_recursive, 131 | batch_output_dir="", 132 | batch_output_filename_format="[name].[output_extension]", 133 | batch_output_action_on_conflict=req.batch_output_action_on_conflict, 134 | batch_remove_duplicated_tag=True, 135 | batch_output_save_json=False, 136 | interrogator=interrogator, 137 | threshold=req.threshold, 138 | additional_tags=req.additional_tags, 139 | exclude_tags=req.exclude_tags, 140 | sort_by_alphabetical_order=False, 141 | add_confident_as_weight=False, 142 | replace_underscore=req.replace_underscore, 143 | replace_underscore_excludes=req.replace_underscore_excludes, 144 | escape_tag=req.escape_tag, 145 | unload_model_after_running=True 146 | ) 147 | return APIResponseSuccess() 148 | 149 | 150 | @router.get("/pick_file") 151 | async def 
pick_file(picker_type: str): 152 | if picker_type == "folder": 153 | coro = asyncio.to_thread(open_directory_selector, os.getcwd()) 154 | elif picker_type == "modelfile": 155 | file_types = [("checkpoints", "*.safetensors;*.ckpt;*.pt"), ("all files", "*.*")] 156 | coro = asyncio.to_thread(open_file_selector, os.getcwd(), "Select file", file_types) 157 | 158 | result = await coro 159 | if result == "": 160 | return APIResponseFail(message="用户取消选择") 161 | 162 | return APIResponseSuccess(data={ 163 | "path": result 164 | }) 165 | 166 | 167 | @router.get("/tasks", response_model_exclude_none=True) 168 | async def get_tasks() -> APIResponse: 169 | return APIResponseSuccess(data={ 170 | "tasks": tm.dump() 171 | }) 172 | 173 | 174 | @router.get("/tasks/terminate/{task_id}", response_model_exclude_none=True) 175 | async def terminate_task(task_id: str): 176 | tm.terminate_task(task_id) 177 | return APIResponseSuccess() 178 | 179 | 180 | @router.get("/graphic_cards") 181 | async def list_avaliable_cards() -> APIResponse: 182 | if not printable_devices: 183 | return APIResponse(status="pending") 184 | 185 | return APIResponseSuccess(data={ 186 | "cards": printable_devices 187 | }) 188 | 189 | 190 | @router.get("/schemas/hashes") 191 | async def list_schema_hashes() -> APIResponse: 192 | if os.environ.get("MIKAZUKI_SCHEMA_HOT_RELOAD", "0") == "1": 193 | log.info("Hot reloading schemas") 194 | await load_schemas() 195 | 196 | return APIResponseSuccess(data={ 197 | "schemas": [ 198 | { 199 | "name": schema["name"], 200 | "hash": schema["hash"] 201 | } 202 | for schema in avaliable_schemas 203 | ] 204 | }) 205 | 206 | 207 | @router.get("/schemas/all") 208 | async def get_all_schemas() -> APIResponse: 209 | return APIResponseSuccess(data={ 210 | "schemas": avaliable_schemas 211 | }) 212 | -------------------------------------------------------------------------------- /train.ps1: -------------------------------------------------------------------------------- 1 | # LoRA train script by @Akegarasu 2 | 3 | # Train data path | 设置训练用模型、图片 4 | $pretrained_model = "./sd-models/model.ckpt" # base model path | 底模路径 5 | $model_type = "sd1.5" # sd1.5 sd2.0 sdxl model | 可选 sd1.5 sd2.0 sdxl。SD2.0模型 2.0模型下 clip_skip 默认无效 6 | $parameterization = 0 # parameterization | 参数化 本参数需要在 model_type 为 sd2.0 时才可启用 7 | 8 | $train_data_dir = "./train/aki" # train dataset path | 训练数据集路径 9 | $reg_data_dir = "" # directory for regularization images | 正则化数据集路径,默认不使用正则化图像。 10 | 11 | # Network settings | 网络设置 12 | $network_module = "networks.lora" # 在这里将会设置训练的网络种类,默认为 networks.lora 也就是 LoRA 训练。如果你想训练 LyCORIS(LoCon、LoHa) 等,则修改这个值为 lycoris.kohya 13 | $network_weights = "" # pretrained weights for LoRA network | 若需要从已有的 LoRA 模型上继续训练,请填写 LoRA 模型路径。 14 | $network_dim = 32 # network dim | 常用 4~128,不是越大越好 15 | $network_alpha = 32 # network alpha | 常用与 network_dim 相同的值或者采用较小的值,如 network_dim的一半 防止下溢。默认值为 1,使用较小的 alpha 需要提升学习率。 16 | 17 | # Train related params | 训练相关参数 18 | $resolution = "512,512" # image resolution w,h. 
图片分辨率,宽,高。支持非正方形,但必须是 64 倍数。 19 | $batch_size = 1 # batch size | batch 大小 20 | $max_train_epoches = 10 # max train epoches | 最大训练 epoch 21 | $save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次 22 | 23 | $train_unet_only = 0 # train U-Net only | 仅训练 U-Net,开启这个会牺牲效果大幅减少显存使用。6G显存可以开启 24 | $train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器 25 | $stop_text_encoder_training = 0 # stop text encoder training | 在第 N 步时停止训练文本编码器 26 | 27 | $noise_offset = 0 # noise offset | 在训练中添加噪声偏移来改良生成非常暗或者非常亮的图像,如果启用,推荐参数为 0.1 28 | $keep_tokens = 0 # keep heading N tokens when shuffling caption tokens | 在随机打乱 tokens 时,保留前 N 个不变。 29 | $min_snr_gamma = 0 # minimum signal-to-noise ratio (SNR) value for gamma-ray | 伽马射线事件的最小信噪比(SNR)值 默认为 0 30 | 31 | # Learning rate | 学习率 32 | $lr = "1e-4" # learning rate | 学习率,在分别设置下方 U-Net 和 文本编码器 的学习率时,该参数失效 33 | $unet_lr = "1e-4" # U-Net learning rate | U-Net 学习率 34 | $text_encoder_lr = "1e-5" # Text Encoder learning rate | 文本编码器 学习率 35 | $lr_scheduler = "cosine_with_restarts" # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" 36 | $lr_warmup_steps = 0 # warmup steps | 学习率预热步数,lr_scheduler 为 constant 或 adafactor 时该值需要设为0。 37 | $lr_restart_cycles = 1 # cosine_with_restarts restart cycles | 余弦退火重启次数,仅在 lr_scheduler 为 cosine_with_restarts 时起效。 38 | 39 | # Optimizer settings | 优化器设置 40 | $optimizer_type = "AdamW8bit" # Optimizer type | 优化器类型 默认为 AdamW8bit,可选:AdamW AdamW8bit Lion Lion8bit SGDNesterov SGDNesterov8bit DAdaptation AdaFactor prodigy 41 | 42 | # Output settings | 输出设置 43 | $output_name = "aki" # output model name | 模型保存名称 44 | $save_model_as = "safetensors" # model save ext | 模型保存格式 ckpt, pt, safetensors 45 | 46 | # Resume training state | 恢复训练设置 47 | $save_state = 0 # save training state | 保存训练状态 名称类似于 -??????-state ?????? 
表示 epoch 数 48 | $resume = "" # resume from state | 从某个状态文件夹中恢复训练 需配合上方参数同时使用 由于规范文件限制 epoch 数和全局步数不会保存 即使恢复时它们也从 1 开始 与 network_weights 的具体实现操作并不一致 49 | 50 | # 其他设置 51 | $min_bucket_reso = 256 # arb min resolution | arb 最小分辨率 52 | $max_bucket_reso = 1024 # arb max resolution | arb 最大分辨率 53 | $persistent_data_loader_workers = 1 # persistent dataloader workers | 保留加载训练集的worker,减少每个 epoch 之间的停顿 54 | $clip_skip = 2 # clip skip | 玄学 一般用 2 55 | $multi_gpu = 0 # multi gpu | 多显卡训练 该参数仅限在显卡数 >= 2 使用 56 | $lowram = 0 # lowram mode | 低内存模式 该模式下会将 U-net 文本编码器 VAE 转移到 GPU 显存中 启用该模式可能会对显存有一定影响 57 | 58 | # LyCORIS 训练设置 59 | $algo = "lora" # LyCORIS network algo | LyCORIS 网络算法 可选 lora、loha、lokr、ia3、dylora。lora即为locon 60 | $conv_dim = 4 # conv dim | 类似于 network_dim,推荐为 4 61 | $conv_alpha = 4 # conv alpha | 类似于 network_alpha,可以采用与 conv_dim 一致或者更小的值 62 | $dropout = "0" # dropout | dropout 概率, 0 为不使用 dropout, 越大则 dropout 越多,推荐 0~0.5, LoHa/LoKr/(IA)^3 暂时不支持 63 | 64 | # 远程记录设置 65 | $use_wandb = 0 # enable wandb logging | 启用wandb远程记录功能 66 | $wandb_api_key = "" # wandb api key | API,通过 https://wandb.ai/authorize 获取 67 | $log_tracker_name = "" # wandb log tracker name | wandb项目名称,留空则为"network_train" 68 | 69 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 ===================== 70 | # Activate python venv 71 | .\venv\Scripts\activate 72 | 73 | $Env:HF_HOME = "huggingface" 74 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1" 75 | $ext_args = [System.Collections.ArrayList]::new() 76 | $launch_args = [System.Collections.ArrayList]::new() 77 | 78 | $trainer_file = "./sd-scripts/train_network.py" 79 | 80 | if ($model_type -eq "sd1.5") { 81 | [void]$ext_args.Add("--clip_skip=$clip_skip") 82 | } elseif ($model_type -eq "sd2.0") { 83 | [void]$ext_args.Add("--v2") 84 | } elseif ($model_type -eq "sdxl") { 85 | $trainer_file = "./sd-scripts/sdxl_train_network.py" 86 | } 87 | 88 | if ($multi_gpu) { 89 | [void]$launch_args.Add("--multi_gpu") 90 | [void]$launch_args.Add("--num_processes=2") 91 | } 92 | 93 | if ($lowram) { 94 | [void]$ext_args.Add("--lowram") 95 | } 96 | 97 | if ($parameterization) { 98 | [void]$ext_args.Add("--v_parameterization") 99 | } 100 | 101 | if ($train_unet_only) { 102 | [void]$ext_args.Add("--network_train_unet_only") 103 | } 104 | 105 | if ($train_text_encoder_only) { 106 | [void]$ext_args.Add("--network_train_text_encoder_only") 107 | } 108 | 109 | if ($network_weights) { 110 | [void]$ext_args.Add("--network_weights=" + $network_weights) 111 | } 112 | 113 | if ($reg_data_dir) { 114 | [void]$ext_args.Add("--reg_data_dir=" + $reg_data_dir) 115 | } 116 | 117 | if ($optimizer_type) { 118 | [void]$ext_args.Add("--optimizer_type=" + $optimizer_type) 119 | } 120 | 121 | if ($optimizer_type -eq "DAdaptation") { 122 | [void]$ext_args.Add("--optimizer_args") 123 | [void]$ext_args.Add("decouple=True") 124 | } 125 | 126 | if ($network_module -eq "lycoris.kohya") { 127 | [void]$ext_args.Add("--network_args") 128 | [void]$ext_args.Add("conv_dim=$conv_dim") 129 | [void]$ext_args.Add("conv_alpha=$conv_alpha") 130 | [void]$ext_args.Add("algo=$algo") 131 | [void]$ext_args.Add("dropout=$dropout") 132 | } 133 | 134 | if ($noise_offset -ne 0) { 135 | [void]$ext_args.Add("--noise_offset=$noise_offset") 136 | } 137 | 138 | if ($stop_text_encoder_training -ne 0) { 139 | [void]$ext_args.Add("--stop_text_encoder_training=$stop_text_encoder_training") 140 | } 141 | 142 | if ($save_state -eq 1) { 143 | [void]$ext_args.Add("--save_state") 144 | } 145 | 146 | if ($resume) { 147 | [void]$ext_args.Add("--resume=" + $resume) 148 | 
} 149 | 150 | if ($min_snr_gamma -ne 0) { 151 | [void]$ext_args.Add("--min_snr_gamma=$min_snr_gamma") 152 | } 153 | 154 | if ($persistent_data_loader_workers) { 155 | [void]$ext_args.Add("--persistent_data_loader_workers") 156 | } 157 | 158 | if ($use_wandb -eq 1) { 159 | [void]$ext_args.Add("--log_with=all") 160 | if ($wandb_api_key) { 161 | [void]$ext_args.Add("--wandb_api_key=" + $wandb_api_key) 162 | } 163 | 164 | if ($log_tracker_name) { 165 | [void]$ext_args.Add("--log_tracker_name=" + $log_tracker_name) 166 | } 167 | } 168 | else { 169 | [void]$ext_args.Add("--log_with=tensorboard") 170 | } 171 | 172 | # run train 173 | python -m accelerate.commands.launch $launch_args --num_cpu_threads_per_process=4 $trainer_file ` 174 | --enable_bucket ` 175 | --pretrained_model_name_or_path=$pretrained_model ` 176 | --train_data_dir=$train_data_dir ` 177 | --output_dir="./output" ` 178 | --logging_dir="./logs" ` 179 | --log_prefix=$output_name ` 180 | --resolution=$resolution ` 181 | --network_module=$network_module ` 182 | --max_train_epochs=$max_train_epoches ` 183 | --learning_rate=$lr ` 184 | --unet_lr=$unet_lr ` 185 | --text_encoder_lr=$text_encoder_lr ` 186 | --lr_scheduler=$lr_scheduler ` 187 | --lr_warmup_steps=$lr_warmup_steps ` 188 | --lr_scheduler_num_cycles=$lr_restart_cycles ` 189 | --network_dim=$network_dim ` 190 | --network_alpha=$network_alpha ` 191 | --output_name=$output_name ` 192 | --train_batch_size=$batch_size ` 193 | --save_every_n_epochs=$save_every_n_epochs ` 194 | --mixed_precision="fp16" ` 195 | --save_precision="fp16" ` 196 | --seed="1337" ` 197 | --cache_latents ` 198 | --prior_loss_weight=1 ` 199 | --max_token_length=225 ` 200 | --caption_extension=".txt" ` 201 | --save_model_as=$save_model_as ` 202 | --min_bucket_reso=$min_bucket_reso ` 203 | --max_bucket_reso=$max_bucket_reso ` 204 | --keep_tokens=$keep_tokens ` 205 | --xformers --shuffle_caption $ext_args 206 | Write-Output "Train finished" 207 | Read-Host | Out-Null ; 208 | -------------------------------------------------------------------------------- /mikazuki/launch_utils.py: -------------------------------------------------------------------------------- 1 | import locale 2 | import os 3 | import platform 4 | import re 5 | import shutil 6 | import subprocess 7 | import sys 8 | import sysconfig 9 | from typing import List 10 | from pathlib import Path 11 | from typing import Optional 12 | 13 | import pkg_resources 14 | 15 | from mikazuki.log import log 16 | 17 | python_bin = sys.executable 18 | 19 | 20 | def base_dir_path(): 21 | return Path(__file__).parents[1].absolute() 22 | 23 | 24 | def find_windows_git(): 25 | possible_paths = ["git\\bin\\git.exe", "git\\cmd\\git.exe", "Git\\mingw64\\libexec\\git-core\\git.exe"] 26 | for path in possible_paths: 27 | if os.path.exists(path): 28 | return path 29 | 30 | 31 | def prepare_submodules(): 32 | frontend_path = base_dir_path() / "frontend" / "dist" 33 | tag_editor_path = base_dir_path() / "mikazuki" / "dataset-tag-editor" / "scripts" 34 | 35 | if not os.path.exists(frontend_path) or not os.path.exists(tag_editor_path): 36 | log.info("submodule not found, try clone...") 37 | log.info("checking git installation...") 38 | if not shutil.which("git"): 39 | if sys.platform == "win32": 40 | git_path = find_windows_git() 41 | 42 | if git_path is not None: 43 | log.info(f"Git not found, but found git in {git_path}, add it to PATH") 44 | os.environ["PATH"] += os.pathsep + os.path.dirname(git_path) 45 | return 46 | else: 47 | log.error("git not found, please install git 
first") 48 | sys.exit(1) 49 | subprocess.run(["git", "submodule", "init"]) 50 | subprocess.run(["git", "submodule", "update"]) 51 | 52 | 53 | def check_dirs(dirs: List): 54 | for d in dirs: 55 | if not os.path.exists(d): 56 | os.makedirs(d) 57 | 58 | 59 | def run(command, 60 | desc: Optional[str] = None, 61 | errdesc: Optional[str] = None, 62 | custom_env: Optional[list] = None, 63 | live: Optional[bool] = True, 64 | shell: Optional[bool] = None): 65 | 66 | if shell is None: 67 | shell = False if sys.platform == "win32" else True 68 | 69 | if desc is not None: 70 | print(desc) 71 | 72 | if live: 73 | result = subprocess.run(command, shell=shell, env=os.environ if custom_env is None else custom_env) 74 | if result.returncode != 0: 75 | raise RuntimeError(f"""{errdesc or 'Error running command'}. 76 | Command: {command} 77 | Error code: {result.returncode}""") 78 | 79 | return "" 80 | 81 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 82 | shell=shell, env=os.environ if custom_env is None else custom_env) 83 | 84 | if result.returncode != 0: 85 | message = f"""{errdesc or 'Error running command'}. 86 | Command: {command} 87 | Error code: {result.returncode} 88 | stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else ''} 89 | stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else ''} 90 | """ 91 | raise RuntimeError(message) 92 | 93 | return result.stdout.decode(encoding="utf8", errors="ignore") 94 | 95 | 96 | def is_installed(package, friendly: str = None): 97 | # 98 | # This function was adapted from code written by vladimandic: https://github.com/vladmandic/automatic/commits/master 99 | # 100 | 101 | # Remove brackets and their contents from the line using regular expressions 102 | # e.g., diffusers[torch]==0.10.2 becomes diffusers==0.10.2 103 | package = re.sub(r'\[.*?\]', '', package) 104 | 105 | try: 106 | if friendly: 107 | pkgs = friendly.split() 108 | else: 109 | pkgs = [ 110 | p 111 | for p in package.split() 112 | if not p.startswith('-') and not p.startswith('=') 113 | ] 114 | pkgs = [ 115 | p.split('/')[-1] for p in pkgs 116 | ] # get only package name if installing from URL 117 | 118 | for pkg in pkgs: 119 | if '>=' in pkg: 120 | pkg_name, pkg_version = [x.strip() for x in pkg.split('>=')] 121 | elif '==' in pkg: 122 | pkg_name, pkg_version = [x.strip() for x in pkg.split('==')] 123 | else: 124 | pkg_name, pkg_version = pkg.strip(), None 125 | 126 | spec = pkg_resources.working_set.by_key.get(pkg_name, None) 127 | if spec is None: 128 | spec = pkg_resources.working_set.by_key.get(pkg_name.lower(), None) 129 | if spec is None: 130 | spec = pkg_resources.working_set.by_key.get(pkg_name.replace('_', '-'), None) 131 | 132 | if spec is not None: 133 | version = pkg_resources.get_distribution(pkg_name).version 134 | # log.debug(f'Package version found: {pkg_name} {version}') 135 | 136 | if pkg_version is not None: 137 | if '>=' in pkg: 138 | ok = version >= pkg_version 139 | else: 140 | ok = version == pkg_version 141 | 142 | if not ok: 143 | log.info(f'Package wrong version: {pkg_name} {version} required {pkg_version}') 144 | return False 145 | else: 146 | log.warning(f'Package version not found: {pkg_name}') 147 | return False 148 | 149 | return True 150 | except ModuleNotFoundError: 151 | log.warning(f'Package not installed: {pkgs}') 152 | return False 153 | 154 | 155 | def validate_requirements(requirements_file: str): 156 | with open(requirements_file, 'r', 
encoding='utf8') as f: 157 | lines = [ 158 | line.strip() 159 | for line in f.readlines() 160 | if line.strip() != '' 161 | and not line.startswith("#") 162 | and not (line.startswith("-") and not line.startswith("--index-url ")) 163 | and line is not None 164 | and "# skip_verify" not in line 165 | ] 166 | 167 | index_url = "" 168 | for line in lines: 169 | if line.startswith("--index-url "): 170 | index_url = line.replace("--index-url ", "") 171 | continue 172 | 173 | if not is_installed(line): 174 | if index_url != "": 175 | run_pip(f"install {line} --index-url {index_url}", line, live=True) 176 | else: 177 | run_pip(f"install {line}", line, live=True) 178 | 179 | 180 | def setup_windows_bitsandbytes(): 181 | if sys.platform != "win32": 182 | return 183 | 184 | # bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple") 185 | bnb_package = "bitsandbytes==0.43.0" 186 | bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes") 187 | 188 | installed_bnb = is_installed(bnb_package) 189 | bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0 190 | 191 | if not installed_bnb or not bnb_cuda_setup: 192 | log.error("detected wrong install of bitsandbytes, reinstall it") 193 | run_pip(f"uninstall bitsandbytes -y", "bitsandbytes", live=True) 194 | run_pip(f"install {bnb_package}", bnb_package, live=True) 195 | 196 | 197 | def setup_onnxruntime(): 198 | onnx_version = "1.17.1" 199 | 200 | if sys.platform == "linux": 201 | libc_ver = platform.libc_ver() 202 | if libc_ver[0] == "glibc" and libc_ver[1] <= "2.27": 203 | onnx_version = "1.16.3" 204 | 205 | if not is_installed(f"onnxruntime-gpu=={onnx_version}"): 206 | log.info("uninstalling wrong onnxruntime version") 207 | # run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True) 208 | run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True) 209 | run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True) 210 | 211 | log.info(f"installing onnxruntime") 212 | run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True) 213 | run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True) 214 | 215 | 216 | def run_pip(command, desc=None, live=False): 217 | return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) 218 | 219 | 220 | def check_run(file: str) -> bool: 221 | result = subprocess.run([python_bin, file], capture_output=True, shell=False) 222 | log.info(result.stdout.decode("utf-8").strip()) 223 | return result.returncode == 0 224 | 225 | 226 | def prepare_environment(): 227 | if sys.platform == "win32": 228 | # disable triton on windows 229 | os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1" 230 | 231 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 232 | os.environ["BITSANDBYTES_NOWELCOME"] = "1" 233 | os.environ["PYTHONWARNINGS"] = "ignore::UserWarning" 234 | os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" 235 | 236 | if locale.getdefaultlocale()[0] == "zh_CN": 237 | log.info("detected locale zh_CN, use pip mirrors") 238 | os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html") 239 | os.environ.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple") 240 | 241 | if not os.environ.get("PATH"): 242 | os.environ["PATH"] = os.path.dirname(sys.executable) 243 | 244 | prepare_submodules() 245 | 246 | 
check_dirs(["config/autosave", "logs"]) 247 | 248 | # if not check_run("mikazuki/scripts/torch_check.py"): 249 | # sys.exit(1) 250 | 251 | validate_requirements("requirements.txt") 252 | setup_windows_bitsandbytes() 253 | setup_onnxruntime() 254 | 255 | 256 | def catch_exception(f): 257 | def wrapper(*args, **kwargs): 258 | try: 259 | return f(*args, **kwargs) 260 | except Exception as e: 261 | log.error(f"An error occurred: {e}") 262 | return wrapper 263 | -------------------------------------------------------------------------------- /mikazuki/schema/dreambooth-master.ts: -------------------------------------------------------------------------------- 1 | Schema.intersect([ 2 | Schema.intersect([ 3 | Schema.object({ 4 | model_train_type: Schema.union(["sd-dreambooth", "sdxl-finetune"]).default("sd-dreambooth").description("训练种类"), 5 | pretrained_model_name_or_path: Schema.string().role("filepicker").default("./sd-models/model.safetensors").description("底模文件路径"), 6 | resume: Schema.string().role("filepicker").description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"), 7 | vae: Schema.string().role("filepicker").description("(可选) VAE 模型文件路径,使用外置 VAE 文件覆盖模型内本身的"), 8 | }).description("训练用模型"), 9 | 10 | Schema.union([ 11 | Schema.object({ 12 | model_train_type: Schema.const("sd-dreambooth"), 13 | v2: Schema.boolean().default(false).description("底模为 sd2.0 以后的版本需要启用"), 14 | }), 15 | Schema.object({}), 16 | ]), 17 | 18 | Schema.union([ 19 | Schema.object({ 20 | model_train_type: Schema.const("sd-dreambooth"), 21 | v2: Schema.const(true).required(), 22 | v_parameterization: Schema.boolean().default(false).description("v-parameterization 学习"), 23 | scale_v_pred_loss_like_noise_pred: Schema.boolean().default(false).description("缩放 v-prediction 损失(与v-parameterization配合使用)"), 24 | }), 25 | Schema.object({}), 26 | ]), 27 | ]), 28 | 29 | Schema.object({ 30 | train_data_dir: Schema.string().role("filepicker", { type: "folder" }).default("./train/aki").description("训练数据集路径"), 31 | reg_data_dir: Schema.string().role("filepicker", { type: "folder" }).description("正则化数据集路径。默认留空,不使用正则化图像"), 32 | prior_loss_weight: Schema.number().step(0.1).description("正则化 - 先验损失权重"), 33 | resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"), 34 | enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"), 35 | min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"), 36 | max_bucket_reso: Schema.number().default(1024).description("arb 桶最大分辨率"), 37 | bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位,SDXL 可以使用 32"), 38 | }).description("数据集设置"), 39 | 40 | Schema.object({ 41 | output_name: Schema.string().default("aki").description("模型保存名称"), 42 | output_dir: Schema.string().role("filepicker", { type: "folder" }).default("./output").description("模型保存文件夹"), 43 | save_model_as: Schema.union(["safetensors", "pt", "ckpt"]).default("safetensors").description("模型保存格式"), 44 | save_precision: Schema.union(["fp16", "float", "bf16"]).default("fp16").description("模型保存精度"), 45 | save_every_n_epochs: Schema.number().default(2).description("每 N epoch(轮)自动保存一次模型"), 46 | save_state: Schema.boolean().description("保存训练状态 配合 `resume` 参数可以继续从某个状态训练"), 47 | }).description("保存设置"), 48 | 49 | Schema.object({ 50 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"), 51 | train_batch_size: Schema.number().min(1).default(1).description("批量大小"), 52 | stop_text_encoder_training: Schema.number().min(-1).description("仅 sd-dreambooth 
可用。在第 N 步时,停止训练文本编码器。设置为 -1 不训练文本编码器"), 53 | gradient_checkpointing: Schema.boolean().default(false).description("梯度检查点"), 54 | gradient_accumulation_steps: Schema.number().min(1).description("梯度累加步数"), 55 | }).description("训练相关参数"), 56 | 57 | 58 | Schema.intersect([ 59 | Schema.object({ 60 | learning_rate: Schema.string().default("1e-6").description("学习率"), 61 | }).description("学习率与优化器设置"), 62 | 63 | Schema.union([ 64 | Schema.object({ 65 | model_train_type: Schema.const("sd-dreambooth"), 66 | learning_rate_te: Schema.string().default("5e-7").description("文本编码器学习率"), 67 | }), 68 | Schema.object({}), 69 | ]), 70 | 71 | Schema.union([ 72 | Schema.object({ 73 | model_train_type: Schema.const("sdxl-finetune").required(), 74 | learning_rate_te1: Schema.string().default("5e-7").description("SDXL 文本编码器 1 (ViT-L) 学习率"), 75 | learning_rate_te2: Schema.string().default("5e-7").description("SDXL 文本编码器 2 (BiG-G) 学习率"), 76 | }), 77 | Schema.object({}), 78 | ]), 79 | 80 | Schema.object({ 81 | lr_scheduler: Schema.union([ 82 | "linear", 83 | "cosine", 84 | "cosine_with_restarts", 85 | "polynomial", 86 | "constant", 87 | "constant_with_warmup", 88 | ]).default("cosine_with_restarts").description("学习率调度器设置"), 89 | lr_warmup_steps: Schema.number().default(0).description("学习率预热步数"), 90 | }), 91 | 92 | Schema.union([ 93 | Schema.object({ 94 | lr_scheduler: Schema.const("cosine_with_restarts"), 95 | lr_scheduler_num_cycles: Schema.number().default(1).description("重启次数"), 96 | }), 97 | Schema.object({}), 98 | ]), 99 | 100 | Schema.object({ 101 | optimizer_type: Schema.union([ 102 | "AdamW", 103 | "AdamW8bit", 104 | "PagedAdamW8bit", 105 | "Lion", 106 | "Lion8bit", 107 | "PagedLion8bit", 108 | "SGDNesterov", 109 | "SGDNesterov8bit", 110 | "DAdaptation", 111 | "DAdaptAdam", 112 | "DAdaptAdaGrad", 113 | "DAdaptAdanIP", 114 | "DAdaptLion", 115 | "DAdaptSGD", 116 | "AdaFactor", 117 | "Prodigy" 118 | ]).default("AdamW8bit").description("优化器设置"), 119 | min_snr_gamma: Schema.number().step(0.1).description("最小信噪比伽马值,如果启用推荐为 5"), 120 | }), 121 | 122 | Schema.union([ 123 | Schema.object({ 124 | optimizer_type: Schema.const("Prodigy").required(), 125 | prodigy_d0: Schema.string(), 126 | prodigy_d_coef: Schema.string().default("2.0"), 127 | }), 128 | Schema.object({}), 129 | ]), 130 | 131 | Schema.object({ 132 | optimizer_args_custom: Schema.array(String).role("table").description("自定义 optimizer_args,一行一个"), 133 | }) 134 | ]), 135 | 136 | Schema.intersect([ 137 | Schema.object({ 138 | enable_preview: Schema.boolean().default(false).description("启用训练预览图"), 139 | }).description("训练预览图设置"), 140 | 141 | Schema.union([ 142 | Schema.object({ 143 | enable_preview: Schema.const(true).required(), 144 | sample_prompts: Schema.string().role("textarea").default("(masterpiece, best quality:1.2), 1girl, solo, --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, --w 512 --h 768 --l 7 --s 24 --d 1337").description("预览图生成参数。`--n` 后方为反向提示词,
`--w`宽,`--h`高
`--l`: CFG Scale
`--s`: 迭代步数
`--d`: 种子"), 145 | sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"), 146 | sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"), 147 | }), 148 | Schema.object({}), 149 | ]), 150 | ]), 151 | 152 | Schema.intersect([ 153 | Schema.object({ 154 | log_with: Schema.union(["tensorboard", "wandb"]).default("tensorboard").description("日志模块"), 155 | log_prefix: Schema.string().description("日志前缀"), 156 | log_tracker_name: Schema.string().description("日志追踪器名称"), 157 | logging_dir: Schema.string().default("./logs").description("日志保存文件夹"), 158 | }).description("日志设置"), 159 | 160 | Schema.union([ 161 | Schema.object({ 162 | log_with: Schema.const("wandb").required(), 163 | wandb_api_key: Schema.string().required().description("wandb 的 api 密钥"), 164 | }), 165 | Schema.object({}), 166 | ]), 167 | ]), 168 | 169 | Schema.object({ 170 | caption_extension: Schema.string().default(".txt").description("Tag 文件扩展名"), 171 | shuffle_caption: Schema.boolean().default(true).description("训练时随机打乱 tokens"), 172 | weighted_captions: Schema.boolean().default(false).description("使用带权重的 token,不推荐与 shuffle_caption 一同开启"), 173 | keep_tokens: Schema.number().min(0).max(255).step(1).default(0).description("在随机打乱 tokens 时,保留前 N 个不变"), 174 | keep_tokens_separator: Schema.string().description("保留 tokens 时使用的分隔符"), 175 | max_token_length: Schema.number().default(255).description("最大 token 长度"), 176 | caption_dropout_rate: Schema.number().min(0).max(1).step(0.1).description("丢弃全部标签的概率,对一个图片概率不使用 caption 或 class token"), 177 | caption_dropout_every_n_epochs: Schema.number().min(0).max(100).step(1).description("每 N 个 epoch 丢弃全部标签"), 178 | caption_tag_dropout_rate: Schema.number().min(0).max(1).step(0.1).description("按逗号分隔的标签来随机丢弃 tag 的概率"), 179 | }).description("caption(Tag)选项"), 180 | 181 | Schema.object({ 182 | noise_offset: Schema.number().step(0.0001).description("在训练中添加噪声偏移来改良生成非常暗或者非常亮的图像,如果启用推荐为 0.1"), 183 | multires_noise_iterations: Schema.number().step(1).description("多分辨率(金字塔)噪声迭代次数 推荐 6-10。无法与 noise_offset 一同启用"), 184 | multires_noise_discount: Schema.number().step(0.1).description("多分辨率(金字塔)衰减率 推荐 0.3-0.8,须同时与上方参数 multires_noise_iterations 一同启用"), 185 | }).description("噪声设置"), 186 | 187 | Schema.object({ 188 | seed: Schema.number().default(1337).description("随机种子"), 189 | clip_skip: Schema.number().role("slider").min(0).max(12).step(1).default(2).description("CLIP 跳过层数 *玄学*"), 190 | no_token_padding: Schema.boolean().default(false).description("禁用 token 填充(与 Diffusers 的旧 Dreambooth 脚本一致)"), 191 | }).description("高级设置"), 192 | 193 | Schema.object({ 194 | mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("fp16").description("训练混合精度"), 195 | full_fp16: Schema.boolean().description("完全使用 FP16 精度"), 196 | full_bf16: Schema.boolean().description("完全使用 BF16 精度 仅支持 SDXL"), 197 | xformers: Schema.boolean().default(true).description("启用 xformers"), 198 | lowram: Schema.boolean().default(false).description("低内存模式 该模式下会将 U-net、文本编码器、VAE 直接加载到显存中"), 199 | cache_latents: Schema.boolean().default(true).description("缓存图像 latent"), 200 | cache_latents_to_disk: Schema.boolean().default(true).description("缓存图像 latent 到磁盘"), 201 | persistent_data_loader_workers: Schema.boolean().default(true).description("保留加载训练集的worker,减少每个 epoch 之间的停顿。"), 202 | }).description("速度优化选项"), 203 | 204 | Schema.object({ 205 | 
ddp_timeout: Schema.number().min(0).description("分布式训练超时时间"), 206 | ddp_gradient_as_bucket_view: Schema.boolean(), 207 | }).description("分布式训练"), 208 | ]); 209 | -------------------------------------------------------------------------------- /mikazuki/tagger/interrogator.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/toriato/stable-diffusion-webui-wd14-tagger 2 | import json 3 | import os 4 | import re 5 | from collections import OrderedDict 6 | from glob import glob 7 | from pathlib import Path 8 | from typing import Dict, List, Tuple 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from PIL import Image 13 | from PIL import UnidentifiedImageError 14 | from huggingface_hub import hf_hub_download 15 | 16 | from mikazuki.tagger import dbimutils, format 17 | 18 | tag_escape_pattern = re.compile(r'([\\()])') 19 | 20 | 21 | class Interrogator: 22 | @staticmethod 23 | def postprocess_tags( 24 | tags: Dict[str, float], 25 | 26 | threshold=0.35, 27 | additional_tags: List[str] = [], 28 | exclude_tags: List[str] = [], 29 | sort_by_alphabetical_order=False, 30 | add_confident_as_weight=False, 31 | replace_underscore=False, 32 | replace_underscore_excludes: List[str] = [], 33 | escape_tag=False 34 | ) -> Dict[str, float]: 35 | for t in additional_tags: 36 | tags[t] = 1.0 37 | 38 | # those lines are totally not "pythonic" but looks better to me 39 | tags = { 40 | t: c 41 | 42 | # sort by tag name or confident 43 | for t, c in sorted( 44 | tags.items(), 45 | key=lambda i: i[0 if sort_by_alphabetical_order else 1], 46 | reverse=not sort_by_alphabetical_order 47 | ) 48 | 49 | # filter tags 50 | if ( 51 | c >= threshold 52 | and t not in exclude_tags 53 | ) 54 | } 55 | 56 | new_tags = [] 57 | for tag in list(tags): 58 | new_tag = tag 59 | 60 | if replace_underscore and tag not in replace_underscore_excludes: 61 | new_tag = new_tag.replace('_', ' ') 62 | 63 | if escape_tag: 64 | new_tag = tag_escape_pattern.sub(r'\\\1', new_tag) 65 | 66 | if add_confident_as_weight: 67 | new_tag = f'({new_tag}:{tags[tag]})' 68 | 69 | new_tags.append((new_tag, tags[tag])) 70 | tags = dict(new_tags) 71 | 72 | return tags 73 | 74 | def __init__(self, name: str) -> None: 75 | self.name = name 76 | 77 | def load(self): 78 | raise NotImplementedError() 79 | 80 | def unload(self) -> bool: 81 | unloaded = False 82 | 83 | if hasattr(self, 'model') and self.model is not None: 84 | del self.model 85 | unloaded = True 86 | print(f'Unloaded {self.name}') 87 | 88 | if hasattr(self, 'tags'): 89 | del self.tags 90 | 91 | return unloaded 92 | 93 | def interrogate( 94 | self, 95 | image: Image 96 | ) -> Tuple[ 97 | Dict[str, float], # rating confidents 98 | Dict[str, float] # tag confidents 99 | ]: 100 | raise NotImplementedError() 101 | 102 | 103 | class WaifuDiffusionInterrogator(Interrogator): 104 | def __init__( 105 | self, 106 | name: str, 107 | model_path='model.onnx', 108 | tags_path='selected_tags.csv', 109 | **kwargs 110 | ) -> None: 111 | super().__init__(name) 112 | self.model_path = model_path 113 | self.tags_path = tags_path 114 | self.kwargs = kwargs 115 | 116 | def download(self) -> Tuple[os.PathLike, os.PathLike]: 117 | print(f"Loading {self.name} model file from {self.kwargs['repo_id']}") 118 | 119 | model_path = Path(hf_hub_download( 120 | **self.kwargs, filename=self.model_path)) 121 | tags_path = Path(hf_hub_download( 122 | **self.kwargs, filename=self.tags_path)) 123 | return model_path, tags_path 124 | 125 | def load(self) -> None: 126 | 
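        # What follows: fetch model.onnx / selected_tags.csv from the Hugging Face Hub,
        # install onnxruntime on demand, and build an InferenceSession that prefers the
        # CUDA execution provider with a CPU fallback.
        #
        # Minimal usage sketch ('image.png' is a placeholder path; every other name is
        # defined in this module):
        #
        #     from PIL import Image
        #     from mikazuki.tagger.interrogator import available_interrogators, Interrogator
        #
        #     tagger = available_interrogators['wd14-vit-v2']
        #     ratings, tags = tagger.interrogate(Image.open('image.png'))
        #     tags = Interrogator.postprocess_tags(tags, threshold=0.35)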
model_path, tags_path = self.download() 127 | 128 | # only one of these packages should be installed at a time in any one environment 129 | # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime 130 | # TODO: remove old package when the environment changes? 131 | from mikazuki.launch_utils import is_installed, run_pip 132 | if not is_installed('onnxruntime'): 133 | package = os.environ.get( 134 | 'ONNXRUNTIME_PACKAGE', 135 | 'onnxruntime-gpu' 136 | ) 137 | 138 | run_pip(f'install {package}', 'onnxruntime') 139 | 140 | # Load torch to load cuda libs built in torch for onnxruntime, do not delete this. 141 | import torch 142 | from onnxruntime import InferenceSession 143 | 144 | # https://onnxruntime.ai/docs/execution-providers/ 145 | # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 146 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 147 | 148 | self.model = InferenceSession(str(model_path), providers=providers) 149 | 150 | print(f'Loaded {self.name} model from {model_path}') 151 | 152 | self.tags = pd.read_csv(tags_path) 153 | 154 | def interrogate( 155 | self, 156 | image: Image 157 | ) -> Tuple[ 158 | Dict[str, float], # rating confidents 159 | Dict[str, float] # tag confidents 160 | ]: 161 | # init model 162 | if not hasattr(self, 'model') or self.model is None: 163 | self.load() 164 | 165 | # code for converting the image and running the model is taken from the link below 166 | # thanks, SmilingWolf! 167 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py 168 | 169 | # convert an image to fit the model 170 | _, height, _, _ = self.model.get_inputs()[0].shape 171 | 172 | # alpha to white 173 | image = image.convert('RGBA') 174 | new_image = Image.new('RGBA', image.size, 'WHITE') 175 | new_image.paste(image, mask=image) 176 | image = new_image.convert('RGB') 177 | image = np.asarray(image) 178 | 179 | # PIL RGB to OpenCV BGR 180 | image = image[:, :, ::-1] 181 | 182 | image = dbimutils.make_square(image, height) 183 | image = dbimutils.smart_resize(image, height) 184 | image = image.astype(np.float32) 185 | image = np.expand_dims(image, 0) 186 | 187 | # evaluate model 188 | input_name = self.model.get_inputs()[0].name 189 | label_name = self.model.get_outputs()[0].name 190 | confidents = self.model.run([label_name], {input_name: image})[0] 191 | 192 | tags = self.tags[:][['name']] 193 | tags['confidents'] = confidents[0] 194 | 195 | # first 4 items are for rating (general, sensitive, questionable, explicit) 196 | ratings = dict(tags[:4].values) 197 | 198 | # rest are regular tags 199 | tags = dict(tags[4:].values) 200 | 201 | return ratings, tags 202 | 203 | 204 | available_interrogators = { 205 | 'wd-convnext-v3': WaifuDiffusionInterrogator( 206 | 'wd-convnext-v3', 207 | repo_id='SmilingWolf/wd-convnext-tagger-v3', 208 | ), 209 | 'wd-swinv2-v3': WaifuDiffusionInterrogator( 210 | 'wd-swinv2-v3', 211 | repo_id='SmilingWolf/wd-swinv2-tagger-v3', 212 | ), 213 | 'wd-vit-v3': WaifuDiffusionInterrogator( 214 | 'wd14-vit-v3', 215 | repo_id='SmilingWolf/wd-vit-tagger-v3', 216 | ), 217 | 'wd14-convnextv2-v2': WaifuDiffusionInterrogator( 218 | 'wd14-convnextv2-v2', repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2', 219 | revision='v2.0' 220 | ), 221 | 'wd14-swinv2-v2': WaifuDiffusionInterrogator( 222 | 'wd14-swinv2-v2', repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', 223 | revision='v2.0' 224 | ), 225 | 'wd14-vit-v2': WaifuDiffusionInterrogator( 226 | 
'wd14-vit-v2', repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', 227 | revision='v2.0' 228 | ), 229 | 'wd14-moat-v2': WaifuDiffusionInterrogator( 230 | 'wd-v1-4-moat-tagger-v2', 231 | repo_id='SmilingWolf/wd-v1-4-moat-tagger-v2', 232 | revision='v2.0' 233 | ), 234 | } 235 | 236 | 237 | def split_str(s: str, separator=',') -> List[str]: 238 | return [x.strip() for x in s.split(separator) if x] 239 | 240 | 241 | def on_interrogate( 242 | image: Image, 243 | batch_input_glob: str, 244 | batch_input_recursive: bool, 245 | batch_output_dir: str, 246 | batch_output_filename_format: str, 247 | batch_output_action_on_conflict: str, 248 | batch_remove_duplicated_tag: bool, 249 | batch_output_save_json: bool, 250 | 251 | interrogator: Interrogator, 252 | threshold: float, 253 | additional_tags: str, 254 | exclude_tags: str, 255 | sort_by_alphabetical_order: bool, 256 | add_confident_as_weight: bool, 257 | replace_underscore: bool, 258 | replace_underscore_excludes: str, 259 | escape_tag: bool, 260 | 261 | unload_model_after_running: bool 262 | ): 263 | postprocess_opts = ( 264 | threshold, 265 | split_str(additional_tags), 266 | split_str(exclude_tags), 267 | sort_by_alphabetical_order, 268 | add_confident_as_weight, 269 | replace_underscore, 270 | split_str(replace_underscore_excludes), 271 | escape_tag 272 | ) 273 | 274 | # batch process 275 | batch_input_glob = batch_input_glob.strip() 276 | batch_output_dir = batch_output_dir.strip() 277 | batch_output_filename_format = batch_output_filename_format.strip() 278 | 279 | if batch_input_glob != '': 280 | # if there is no glob pattern, insert it automatically 281 | if not batch_input_glob.endswith('*'): 282 | if not batch_input_glob.endswith(os.sep): 283 | batch_input_glob += os.sep 284 | batch_input_glob += '*' 285 | 286 | if batch_input_recursive: 287 | batch_input_glob += '*' 288 | 289 | # get root directory of input glob pattern 290 | base_dir = batch_input_glob.replace('?', '*') 291 | base_dir = base_dir.split(os.sep + '*').pop(0) 292 | 293 | # check the input directory path 294 | if not os.path.isdir(base_dir): 295 | print('input path is not a directory / 输入的路径不是文件夹,终止识别') 296 | return 'input path is not a directory' 297 | 298 | # this line is moved here because some reason 299 | # PIL.Image.registered_extensions() returns only PNG if you call too early 300 | supported_extensions = [ 301 | e 302 | for e, f in Image.registered_extensions().items() 303 | if f in Image.OPEN 304 | ] 305 | 306 | paths = [ 307 | Path(p) 308 | for p in glob(batch_input_glob, recursive=batch_input_recursive) 309 | if '.' + p.split('.').pop().lower() in supported_extensions 310 | ] 311 | 312 | print(f'found {len(paths)} image(s)') 313 | 314 | for path in paths: 315 | try: 316 | image = Image.open(path) 317 | except UnidentifiedImageError: 318 | # just in case, user has mysterious file... 
319 | print(f'${path} is not supported image type') 320 | continue 321 | 322 | # guess the output path 323 | base_dir_last = Path(base_dir).parts[-1] 324 | base_dir_last_idx = path.parts.index(base_dir_last) 325 | output_dir = Path( 326 | batch_output_dir) if batch_output_dir else Path(base_dir) 327 | output_dir = output_dir.joinpath( 328 | *path.parts[base_dir_last_idx + 1:]).parent 329 | 330 | output_dir.mkdir(0o777, True, True) 331 | 332 | # format output filename 333 | format_info = format.Info(path, 'txt') 334 | 335 | try: 336 | formatted_output_filename = format.pattern.sub( 337 | lambda m: format.format(m, format_info), 338 | batch_output_filename_format 339 | ) 340 | except (TypeError, ValueError) as error: 341 | return str(error) 342 | 343 | output_path = output_dir.joinpath( 344 | formatted_output_filename 345 | ) 346 | 347 | output = [] 348 | 349 | if output_path.is_file(): 350 | output.append(output_path.read_text(errors='ignore').strip()) 351 | 352 | if batch_output_action_on_conflict == 'ignore': 353 | print(f'skipping {path}') 354 | continue 355 | 356 | ratings, tags = interrogator.interrogate(image) 357 | processed_tags = Interrogator.postprocess_tags( 358 | tags, 359 | *postprocess_opts 360 | ) 361 | 362 | # TODO: switch for less print 363 | print( 364 | f'found {len(processed_tags)} tags out of {len(tags)} from {path}' 365 | ) 366 | 367 | plain_tags = ', '.join(processed_tags) 368 | 369 | if batch_output_action_on_conflict == 'copy': 370 | output = [plain_tags] 371 | elif batch_output_action_on_conflict == 'prepend': 372 | output.insert(0, plain_tags) 373 | else: 374 | output.append(plain_tags) 375 | 376 | if batch_remove_duplicated_tag: 377 | output_path.write_text( 378 | ', '.join( 379 | OrderedDict.fromkeys( 380 | map(str.strip, ','.join(output).split(',')) 381 | ) 382 | ), 383 | encoding='utf-8' 384 | ) 385 | else: 386 | output_path.write_text( 387 | ', '.join(output), 388 | encoding='utf-8' 389 | ) 390 | 391 | if batch_output_save_json: 392 | output_path.with_suffix('.json').write_text( 393 | json.dumps([ratings, tags]) 394 | ) 395 | 396 | print('all done / 识别完成') 397 | 398 | if unload_model_after_running: 399 | interrogator.unload() 400 | 401 | return 'Succeed' 402 | -------------------------------------------------------------------------------- /mikazuki/schema/lora-master.ts: -------------------------------------------------------------------------------- 1 | Schema.intersect([ 2 | Schema.intersect([ 3 | Schema.object({ 4 | model_train_type: Schema.union(["sd-lora", "sdxl-lora"]).default("sd-lora").description("训练种类"), 5 | pretrained_model_name_or_path: Schema.string().role('filepicker').default("./sd-models/model.safetensors").description("底模文件路径"), 6 | resume: Schema.string().role('filepicker').description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"), 7 | vae: Schema.string().role('filepicker').description("(可选) VAE 模型文件路径,使用外置 VAE 文件覆盖模型内本身的"), 8 | }).description("训练用模型"), 9 | 10 | Schema.union([ 11 | Schema.object({ 12 | model_train_type: Schema.const("sd-lora"), 13 | v2: Schema.boolean().default(false).description("底模为 sd2.0 以后的版本需要启用"), 14 | }), 15 | Schema.object({}), 16 | ]), 17 | 18 | Schema.union([ 19 | Schema.object({ 20 | model_train_type: Schema.const("sd-lora"), 21 | v2: Schema.const(true).required(), 22 | v_parameterization: Schema.boolean().default(false).description("v-parameterization 学习"), 23 | scale_v_pred_loss_like_noise_pred: Schema.boolean().default(false).description("缩放 v-prediction 损失(与v-parameterization配合使用)"), 24 | }), 25 | 
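        // Pattern note: each Schema.union([...]) in this file pairs an object keyed on a const
        // discriminator (model_train_type, v2, lr_scheduler, optimizer_type, lycoris_algo,
        // enable_block_weights, enable_base_weight, enable_preview, log_with, ...) with an empty
        // Schema.object({}) fallback, so the extra options only appear in the form when the
        // discriminator takes the matching value.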
Schema.object({}), 26 | ]), 27 | ]), 28 | 29 | Schema.object({ 30 | train_data_dir: Schema.string().role('filepicker', { type: "folder" }).default("./train/aki").description("训练数据集路径"), 31 | reg_data_dir: Schema.string().role('filepicker', { type: "folder" }).description("正则化数据集路径。默认留空,不使用正则化图像"), 32 | prior_loss_weight: Schema.number().step(0.1).default(1.0).description("正则化 - 先验损失权重"), 33 | resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"), 34 | enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"), 35 | min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"), 36 | max_bucket_reso: Schema.number().default(1024).description("arb 桶最大分辨率"), 37 | bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位,SDXL 可以使用 32 (SDXL低于32时失效)"), 38 | }).description("数据集设置"), 39 | 40 | Schema.object({ 41 | output_name: Schema.string().default("aki").description("模型保存名称"), 42 | output_dir: Schema.string().role('filepicker', { type: "folder" }).default("./output").description("模型保存文件夹"), 43 | save_model_as: Schema.union(["safetensors", "pt", "ckpt"]).default("safetensors").description("模型保存格式"), 44 | save_precision: Schema.union(["fp16", "float", "bf16"]).default("fp16").description("模型保存精度"), 45 | save_every_n_epochs: Schema.number().default(2).description("每 N epoch(轮)自动保存一次模型"), 46 | save_state: Schema.boolean().description("保存训练状态 配合 `resume` 参数可以继续从某个状态训练"), 47 | }).description("保存设置"), 48 | 49 | Schema.object({ 50 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"), 51 | train_batch_size: Schema.number().min(1).default(1).description("批量大小, 越高显存占用越高"), 52 | gradient_checkpointing: Schema.boolean().default(false).description("梯度检查点"), 53 | gradient_accumulation_steps: Schema.number().min(1).description("梯度累加步数"), 54 | network_train_unet_only: Schema.boolean().default(false).description("仅训练 U-Net 训练SDXL Lora时推荐开启"), 55 | network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"), 56 | }).description("训练相关参数"), 57 | 58 | Schema.intersect([ 59 | Schema.object({ 60 | learning_rate: Schema.string().default("1e-4").description("总学习率, 在分开设置 U-Net 与文本编码器学习率后这个值失效。"), 61 | unet_lr: Schema.string().default("1e-4").description("U-Net 学习率"), 62 | text_encoder_lr: Schema.string().default("1e-5").description("文本编码器学习率"), 63 | lr_scheduler: Schema.union([ 64 | "linear", 65 | "cosine", 66 | "cosine_with_restarts", 67 | "polynomial", 68 | "constant", 69 | "constant_with_warmup", 70 | ]).default("cosine_with_restarts").description("学习率调度器设置"), 71 | lr_warmup_steps: Schema.number().default(0).description('学习率预热步数'), 72 | }).description("学习率与优化器设置"), 73 | 74 | Schema.union([ 75 | Schema.object({ 76 | lr_scheduler: Schema.const('cosine_with_restarts'), 77 | lr_scheduler_num_cycles: Schema.number().default(1).description('重启次数'), 78 | }), 79 | Schema.object({}), 80 | ]), 81 | 82 | Schema.object({ 83 | optimizer_type: Schema.union([ 84 | "AdamW", 85 | "AdamW8bit", 86 | "PagedAdamW8bit", 87 | "Lion", 88 | "Lion8bit", 89 | "PagedLion8bit", 90 | "SGDNesterov", 91 | "SGDNesterov8bit", 92 | "DAdaptation", 93 | "DAdaptAdam", 94 | "DAdaptAdaGrad", 95 | "DAdaptAdanIP", 96 | "DAdaptLion", 97 | "DAdaptSGD", 98 | "AdaFactor", 99 | "Prodigy" 100 | ]).default("AdamW8bit").description("优化器设置"), 101 | min_snr_gamma: Schema.number().step(0.1).description("最小信噪比伽马值, 如果启用推荐为 5"), 102 | }), 103 | 104 | Schema.union([ 105 | Schema.object({ 106 | optimizer_type: 
Schema.const('Prodigy').required(), 107 | prodigy_d0: Schema.string(), 108 | prodigy_d_coef: Schema.string().default("2.0"), 109 | }), 110 | Schema.object({}), 111 | ]), 112 | 113 | Schema.object({ 114 | optimizer_args_custom: Schema.array(String).role('table').description('自定义 optimizer_args,一行一个'), 115 | }) 116 | ]), 117 | 118 | Schema.intersect([ 119 | Schema.object({ 120 | network_module: Schema.union(["networks.lora", "networks.dylora", "networks.oft", "lycoris.kohya"]).default("networks.lora").description("训练网络模块"), 121 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"), 122 | network_dim: Schema.number().min(1).default(32).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"), 123 | network_alpha: Schema.number().min(1).default(32).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"), 124 | network_dropout: Schema.number().step(0.01).default(0).description('dropout 概率 (与 lycoris 不兼容,需要用 lycoris 自带的)'), 125 | scale_weight_norms: Schema.number().step(0.01).min(0).description("最大范数正则化。如果使用,推荐为 1"), 126 | network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'), 127 | enable_block_weights: Schema.boolean().default(false).description('启用分层学习率训练(只支持网络模块 networks.lora)'), 128 | enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'), 129 | }).description("网络设置"), 130 | 131 | Schema.union([ 132 | Schema.object({ 133 | network_module: Schema.const('lycoris.kohya').required(), 134 | lycoris_algo: Schema.union(["locon", "loha", "lokr", "ia3", "dylora", "glora", "diag-oft", "boft"]).default("locon").description('LyCORIS 网络算法'), 135 | conv_dim: Schema.number().default(4), 136 | conv_alpha: Schema.number().default(1), 137 | dropout: Schema.number().step(0.01).default(0).description('dropout 概率。推荐 0~0.5,LoHa/LoKr/(IA)^3暂不支持'), 138 | train_norm: Schema.boolean().default(false).description('训练 Norm 层,不支持 (IA)^3'), 139 | }), 140 | Schema.object({ 141 | network_module: Schema.const('networks.dylora').required(), 142 | dylora_unit: Schema.number().min(1).default(4).description(' dylora 分割块数单位,最小 1 也最慢。一般 4、8、12、16 这几个选'), 143 | }), 144 | Schema.object({}), 145 | ]), 146 | 147 | Schema.union([ 148 | Schema.object({ 149 | lycoris_algo: Schema.const('lokr').required(), 150 | lokr_factor: Schema.number().min(-1).default(-1).description('常用 `4~无穷`(填写 -1 为无穷)'), 151 | }), 152 | Schema.object({}), 153 | ]), 154 | 155 | Schema.union([ 156 | Schema.object({ 157 | enable_block_weights: Schema.const(true).required(), 158 | down_lr_weight: Schema.string().role('folder').default("1,1,1,1,1,1,1,1,1,1,1,1").description("U-Net 的 Encoder 层分层学习率权重,共 12 层"), 159 | mid_lr_weight: Schema.string().role('folder').default("1").description("U-Net 的 Mid 层分层学习率权重,共 1 层"), 160 | up_lr_weight: Schema.string().role('folder').default("1,1,1,1,1,1,1,1,1,1,1,1").description("U-Net 的 Decoder 层分层学习率权重,共 12 层"), 161 | block_lr_zero_threshold: Schema.number().step(0.01).default(0).description("分层学习率置 0 阈值"), 162 | }), 163 | Schema.object({}), 164 | ]), 165 | 166 | Schema.union([ 167 | Schema.object({ 168 | enable_base_weight: Schema.const(true).required(), 169 | base_weights: Schema.string().role('textarea').description("合并入底模的 LoRA 路径,一行一个路径"), 170 | base_weights_multiplier: Schema.string().role('textarea').description("合并入底模的 LoRA 权重,一行一个数字"), 171 | }), 172 | Schema.object({}), 173 | ]), 174 | ]), 175 | 176 | Schema.intersect([ 177 | Schema.object({ 178 | enable_preview: 
Schema.boolean().default(false).description('启用训练预览图'), 179 | }).description('训练预览图设置'), 180 | 181 | Schema.union([ 182 | Schema.object({ 183 | enable_preview: Schema.const(true).required(), 184 | sample_prompts: Schema.string().role('textarea').default(window.__MIKAZUKI__.SAMPLE_PROMPTS_DEFAULT).description(window.__MIKAZUKI__.SAMPLE_PROMPTS_DESCRIPTION), 185 | sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"), 186 | sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"), 187 | }), 188 | Schema.object({}), 189 | ]), 190 | ]), 191 | 192 | Schema.intersect([ 193 | Schema.object({ 194 | log_with: Schema.union(["tensorboard", "wandb"]).default("tensorboard").description("日志模块"), 195 | log_prefix: Schema.string().description("日志前缀"), 196 | log_tracker_name: Schema.string().description("日志追踪器名称"), 197 | logging_dir: Schema.string().default("./logs").description("日志保存文件夹"), 198 | }).description('日志设置'), 199 | 200 | Schema.union([ 201 | Schema.object({ 202 | log_with: Schema.const("wandb").required(), 203 | wandb_api_key: Schema.string().required().description("wandb 的 api 密钥"), 204 | }), 205 | Schema.object({}), 206 | ]), 207 | ]), 208 | 209 | Schema.object({ 210 | caption_extension: Schema.string().default(".txt").description("Tag 文件扩展名"), 211 | shuffle_caption: Schema.boolean().default(true).description("训练时随机打乱 tokens"), 212 | weighted_captions: Schema.boolean().description("使用带权重的 token,不推荐与 shuffle_caption 一同开启"), 213 | keep_tokens: Schema.number().min(0).max(255).step(1).default(0).description("在随机打乱 tokens 时,保留前 N 个不变"), 214 | keep_tokens_separator: Schema.string().description("保留 tokens 时使用的分隔符"), 215 | max_token_length: Schema.number().default(255).description("最大 token 长度"), 216 | caption_dropout_rate: Schema.number().min(0).step(0.01).description("丢弃全部标签的概率,对一个图片概率不使用 caption 或 class token"), 217 | caption_dropout_every_n_epochs: Schema.number().min(0).max(100).step(1).description("每 N 个 epoch 丢弃全部标签"), 218 | caption_tag_dropout_rate: Schema.number().min(0).step(0.01).description("按逗号分隔的标签来随机丢弃 tag 的概率"), 219 | }).description("caption(Tag)选项"), 220 | 221 | Schema.object({ 222 | noise_offset: Schema.number().step(0.0001).description("在训练中添加噪声偏移来改良生成非常暗或者非常亮的图像,如果启用推荐为 0.1"), 223 | multires_noise_iterations: Schema.number().step(1).description("多分辨率(金字塔)噪声迭代次数 推荐 6-10。无法与 noise_offset 一同启用"), 224 | multires_noise_discount: Schema.number().step(0.01).description("多分辨率(金字塔)衰减率 推荐 0.3-0.8,须同时与上方参数 multires_noise_iterations 一同启用"), 225 | }).description("噪声设置"), 226 | 227 | Schema.object({ 228 | color_aug: Schema.boolean().description("颜色改变"), 229 | flip_aug: Schema.boolean().description("图像翻转"), 230 | random_crop: Schema.boolean().description("随机剪裁"), 231 | }).description("数据增强"), 232 | 233 | Schema.object({ 234 | seed: Schema.number().default(1337).description("随机种子"), 235 | clip_skip: Schema.number().role("slider").min(0).max(12).step(1).default(2).description("CLIP 跳过层数 *玄学*"), 236 | no_metadata: Schema.boolean().description("不保存模型元数据"), 237 | ui_custom_params: Schema.string().role('textarea').description("**危险** 自定义参数,请输入 TOML 格式,将会直接覆盖当前界面内任何参数。实时更新,推荐写完后再粘贴过来"), 238 | }).description("高级设置"), 239 | 240 | Schema.object({ 241 | mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("fp16").description("训练混合精度, RTX30系列以后也可以指定`bf16`"), 242 | full_fp16: 
Schema.boolean().description("完全使用 FP16 精度"), 243 | full_bf16: Schema.boolean().description("完全使用 BF16 精度"), 244 | fp8_base: Schema.boolean().description("对基础模型使用 FP8 精度"), 245 | no_half_vae: Schema.boolean().description("不使用半精度 VAE"), 246 | xformers: Schema.boolean().default(true).description("启用 xformers"), 247 | lowram: Schema.boolean().default(false).description("低内存模式 该模式下会将 U-net、文本编码器、VAE 直接加载到显存中"), 248 | cache_latents: Schema.boolean().default(true).description("缓存图像 latent, 缓存 VAE 输出以减少 VRAM 使用"), 249 | cache_latents_to_disk: Schema.boolean().default(true).description("缓存图像 latent 到磁盘"), 250 | cache_text_encoder_outputs: Schema.boolean().description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"), 251 | cache_text_encoder_outputs_to_disk: Schema.boolean().description("缓存文本编码器的输出到磁盘"), 252 | persistent_data_loader_workers: Schema.boolean().default(true).description("保留加载训练集的worker,减少每个 epoch 之间的停顿。"), 253 | }).description("速度优化选项"), 254 | 255 | Schema.object({ 256 | ddp_timeout: Schema.number().min(0).description("分布式训练超时时间"), 257 | ddp_gradient_as_bucket_view: Schema.boolean(), 258 | }).description("分布式训练"), 259 | ]); 260 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 
41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 
110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 
234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. 
A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 
348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 
409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 
464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. 
You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. 
If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 632 | <one line to give the program's name and a brief idea of what it does.> 633 | Copyright (C) <year>  <name of author> 634 | 635 | This program is free software: you can redistribute it and/or modify 636 | it under the terms of the GNU Affero General Public License as published 637 | by the Free Software Foundation, either version 3 of the License, or 638 | (at your option) any later version. 639 | 640 | This program is distributed in the hope that it will be useful, 641 | but WITHOUT ANY WARRANTY; without even the implied warranty of 642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 643 | GNU Affero General Public License for more details. 644 | 645 | You should have received a copy of the GNU Affero General Public License 646 | along with this program. If not, see <https://www.gnu.org/licenses/>. 647 | 648 | Also add information on how to contact you by electronic and paper mail. 649 | 650 | If your software can interact with users remotely through a computer 651 | network, you should also make sure that it provides a way for users to 652 | get its source. For example, if your program is a web application, its 653 | interface could display a "Source" link that leads users to an archive 654 | of the code. There are many ways you could offer source, and different 655 | solutions will be better for different programs; see section 13 for the 656 | specific requirements. 657 | 658 | You should also get your employer (if you work as a programmer) or school, 659 | if any, to sign a "copyright disclaimer" for the program, if necessary. 660 | For more information on this, and how to apply and follow the GNU AGPL, see 661 | <https://www.gnu.org/licenses/>. 662 | --------------------------------------------------------------------------------