├── .gitattributes
├── .gitignore
├── .gitmodules
├── Dockerfile
├── Dockerfile-for-Mainland-China
├── LICENSE
├── README-zh.md
├── README.md
├── assets
│   ├── favicon.ico
│   ├── gitconfig-cn
│   └── tensorboard-example.png
├── config
│   ├── default.toml
│   ├── lora.toml
│   ├── presets
│   │   └── example.toml
│   └── sample_prompts.txt
├── gui.py
├── huggingface
│   ├── accelerate
│   │   └── default_config.yaml
│   └── hub
│       └── version.txt
├── install-cn.ps1
├── install.bash
├── install.ps1
├── interrogate.ps1
├── logs
│   └── .keep
├── mikazuki
│   ├── app
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── application.py
│   │   ├── config.py
│   │   ├── models.py
│   │   └── proxy.py
│   ├── global.d.ts
│   ├── hook
│   │   ├── i18n.json
│   │   └── sitecustomize.py
│   ├── launch_utils.py
│   ├── log.py
│   ├── process.py
│   ├── schema
│   │   ├── dreambooth.ts
│   │   ├── flux-lora.ts
│   │   ├── lora-basic.ts
│   │   ├── lora-master.ts
│   │   ├── lumina2-lora.ts
│   │   ├── sd3-lora.ts
│   │   └── shared.ts
│   ├── scripts
│   │   ├── fix_scripts_python_executable_path.py
│   │   └── torch_check.py
│   ├── tagger
│   │   ├── dbimutils.py
│   │   ├── format.py
│   │   └── interrogator.py
│   ├── tasks.py
│   ├── tsconfig.json
│   └── utils
│       ├── devices.py
│       ├── tk_window.py
│       └── train_utils.py
├── output
│   └── .keep
├── requirements.txt
├── resize.ps1
├── run.ipynb
├── run_gui.ps1
├── run_gui.sh
├── run_gui_cn.sh
├── scripts
│   ├── dev
│   │   ├── .gitignore
│   │   ├── COMMIT_ID
│   │   ├── LICENSE.md
│   │   ├── README-ja.md
│   │   ├── README.md
│   │   ├── XTI_hijack.py
│   │   ├── _typos.toml
│   │   ├── fine_tune.py
│   │   ├── finetune
│   │   │   ├── blip
│   │   │   │   ├── blip.py
│   │   │   │   ├── med.py
│   │   │   │   ├── med_config.json
│   │   │   │   └── vit.py
│   │   │   ├── clean_captions_and_tags.py
│   │   │   ├── hypernetwork_nai.py
│   │   │   ├── make_captions.py
│   │   │   ├── make_captions_by_git.py
│   │   │   ├── merge_captions_to_metadata.py
│   │   │   ├── merge_dd_tags_to_metadata.py
│   │   │   ├── prepare_buckets_latents.py
│   │   │   └── tag_images_by_wd14_tagger.py
│   │   ├── flux_minimal_inference.py
│   │   ├── flux_train.py
│   │   ├── flux_train_control_net.py
│   │   ├── flux_train_network.py
│   │   ├── gen_img.py
│   │   ├── gen_img_diffusers.py
│   │   ├── library
│   │   │   ├── __init__.py
│   │   │   ├── adafactor_fused.py
│   │   │   ├── attention_processors.py
│   │   │   ├── config_util.py
│   │   │   ├── custom_offloading_utils.py
│   │   │   ├── custom_train_functions.py
│   │   │   ├── deepspeed_utils.py
│   │   │   ├── device_utils.py
│   │   │   ├── flux_models.py
│   │   │   ├── flux_train_utils.py
│   │   │   ├── flux_utils.py
│   │   │   ├── huggingface_util.py
│   │   │   ├── hypernetwork.py
│   │   │   ├── ipex
│   │   │   │   ├── __init__.py
│   │   │   │   ├── attention.py
│   │   │   │   ├── diffusers.py
│   │   │   │   ├── gradscaler.py
│   │   │   │   └── hijacks.py
│   │   │   ├── lpw_stable_diffusion.py
│   │   │   ├── model_util.py
│   │   │   ├── original_unet.py
│   │   │   ├── sai_model_spec.py
│   │   │   ├── sd3_models.py
│   │   │   ├── sd3_train_utils.py
│   │   │   ├── sd3_utils.py
│   │   │   ├── sdxl_lpw_stable_diffusion.py
│   │   │   ├── sdxl_model_util.py
│   │   │   ├── sdxl_original_control_net.py
│   │   │   ├── sdxl_original_unet.py
│   │   │   ├── sdxl_train_util.py
│   │   │   ├── slicing_vae.py
│   │   │   ├── strategy_base.py
│   │   │   ├── strategy_flux.py
│   │   │   ├── strategy_sd.py
│   │   │   ├── strategy_sd3.py
│   │   │   ├── strategy_sdxl.py
│   │   │   ├── train_util.py
│   │   │   └── utils.py
│   │   ├── networks
│   │   │   ├── check_lora_weights.py
│   │   │   ├── control_net_lllite.py
│   │   │   ├── control_net_lllite_for_train.py
│   │   │   ├── convert_flux_lora.py
│   │   │   ├── dylora.py
│   │   │   ├── extract_lora_from_dylora.py
│   │   │   ├── extract_lora_from_models.py
│   │   │   ├── flux_extract_lora.py
│   │   │   ├── flux_merge_lora.py
│   │   │   ├── lora.py
│   │   │   ├── lora_diffusers.py
│   │   │   ├── lora_fa.py
│   │   │   ├── lora_flux.py
│   │   │   ├── lora_interrogator.py
│   │   │   ├── lora_sd3.py
│   │   │   ├── merge_lora.py
│   │   │   ├── merge_lora_old.py
│   │   │   ├── oft.py
│   │   │   ├── oft_flux.py
│   │   │   ├── resize_lora.py
│   │   │   ├── sdxl_merge_lora.py
│   │   │   └── svd_merge_lora.py
│   │   ├── pytest.ini
│   │   ├── requirements.txt
│   │   ├── sd3_minimal_inference.py
│   │   ├── sd3_train.py
│   │   ├── sd3_train_network.py
│   │   ├── sdxl_gen_img.py
│   │   ├── sdxl_minimal_inference.py
│   │   ├── sdxl_train.py
│   │   ├── sdxl_train_control_net.py
│   │   ├── sdxl_train_control_net_lllite.py
│   │   ├── sdxl_train_control_net_lllite_old.py
│   │   ├── sdxl_train_network.py
│   │   ├── sdxl_train_textual_inversion.py
│   │   ├── setup.py
│   │   ├── tools
│   │   │   ├── cache_latents.py
│   │   │   ├── cache_text_encoder_outputs.py
│   │   │   ├── canny.py
│   │   │   ├── convert_diffusers20_original_sd.py
│   │   │   ├── convert_diffusers_to_flux.py
│   │   │   ├── detect_face_rotate.py
│   │   │   ├── latent_upscaler.py
│   │   │   ├── merge_models.py
│   │   │   ├── merge_sd3_safetensors.py
│   │   │   ├── original_control_net.py
│   │   │   ├── resize_images_to_resolution.py
│   │   │   └── show_metadata.py
│   │   ├── train_control_net.py
│   │   ├── train_controlnet.py
│   │   ├── train_db.py
│   │   ├── train_network.py
│   │   ├── train_textual_inversion.py
│   │   └── train_textual_inversion_XTI.py
│   └── stable
│       ├── .gitignore
│       ├── COMMIT_ID
│       ├── LICENSE.md
│       ├── README-ja.md
│       ├── README.md
│       ├── XTI_hijack.py
│       ├── _typos.toml
│       ├── fine_tune.py
│       ├── finetune
│       │   ├── blip
│       │   │   ├── blip.py
│       │   │   ├── med.py
│       │   │   ├── med_config.json
│       │   │   └── vit.py
│       │   ├── clean_captions_and_tags.py
│       │   ├── hypernetwork_nai.py
│       │   ├── make_captions.py
│       │   ├── make_captions_by_git.py
│       │   ├── merge_captions_to_metadata.py
│       │   ├── merge_dd_tags_to_metadata.py
│       │   ├── prepare_buckets_latents.py
│       │   └── tag_images_by_wd14_tagger.py
│       ├── gen_img.py
│       ├── gen_img_diffusers.py
│       ├── library
│       │   ├── __init__.py
│       │   ├── adafactor_fused.py
│       │   ├── attention_processors.py
│       │   ├── config_util.py
│       │   ├── custom_train_functions.py
│       │   ├── deepspeed_utils.py
│       │   ├── device_utils.py
│       │   ├── huggingface_util.py
│       │   ├── hypernetwork.py
│       │   ├── ipex
│       │   │   ├── __init__.py
│       │   │   ├── attention.py
│       │   │   ├── diffusers.py
│       │   │   ├── gradscaler.py
│       │   │   └── hijacks.py
│       │   ├── lpw_stable_diffusion.py
│       │   ├── model_util.py
│       │   ├── original_unet.py
│       │   ├── sai_model_spec.py
│       │   ├── sdxl_lpw_stable_diffusion.py
│       │   ├── sdxl_model_util.py
│       │   ├── sdxl_original_unet.py
│       │   ├── sdxl_train_util.py
│       │   ├── slicing_vae.py
│       │   ├── train_util.py
│       │   └── utils.py
│       ├── networks
│       │   ├── check_lora_weights.py
│       │   ├── control_net_lllite.py
│       │   ├── control_net_lllite_for_train.py
│       │   ├── dylora.py
│       │   ├── extract_lora_from_dylora.py
│       │   ├── extract_lora_from_models.py
│       │   ├── lora.py
│       │   ├── lora_diffusers.py
│       │   ├── lora_fa.py
│       │   ├── lora_interrogator.py
│       │   ├── merge_lora.py
│       │   ├── merge_lora_old.py
│       │   ├── oft.py
│       │   ├── resize_lora.py
│       │   ├── sdxl_merge_lora.py
│       │   └── svd_merge_lora.py
│       ├── requirements.txt
│       ├── sdxl_gen_img.py
│       ├── sdxl_minimal_inference.py
│       ├── sdxl_train.py
│       ├── sdxl_train_control_net_lllite.py
│       ├── sdxl_train_control_net_lllite_old.py
│       ├── sdxl_train_network.py
│       ├── sdxl_train_textual_inversion.py
│       ├── setup.py
│       ├── tools
│       │   ├── cache_latents.py
│       │   ├── cache_text_encoder_outputs.py
│       │   ├── canny.py
│       │   ├── convert_diffusers20_original_sd.py
│       │   ├── detect_face_rotate.py
│       │   ├── latent_upscaler.py
│       │   ├── merge_models.py
│       │   ├── original_control_net.py
│       │   ├── resize_images_to_resolution.py
│       │   └── show_metadata.py
│       ├── train_controlnet.py
│       ├── train_db.py
│       ├── train_network.py
│       ├── train_textual_inversion.py
│       └── train_textual_inversion_XTI.py
├── sd-models
│   └── put stable diffusion model here.txt
├── svd_merge.ps1
├── tagger.ps1
├── tagger.sh
├── tensorboard.ps1
├── train.ipynb
├── train.ps1
├── train.sh
├── train_by_toml.ps1
└── train_by_toml.sh
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ps1 text eol=crlf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | .idea
3 |
4 | venv
5 | __pycache__
6 |
7 | output/*
8 | !output/.keep
9 |
10 | assets/config.json
11 |
12 | py310
13 | python
14 | git
15 | wd14_tagger_model
16 |
17 | train/*
18 | logs/*
19 | sd-models/*
20 | toml/autosave/*
21 | config/autosave/*
22 | config/presets/test*.toml
23 |
24 | !sd-models/put stable diffusion model here.txt
25 | !logs/.keep
26 |
27 | tests/
28 |
29 | huggingface/hub/models*
30 | huggingface/hub/version_diffusers_cache.txt
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "frontend"]
2 | path = frontend
3 | url = https://github.com/hanamizuki-ai/lora-gui-dist
4 | [submodule "mikazuki/dataset-tag-editor"]
5 | path = mikazuki/dataset-tag-editor
6 | url = https://github.com/Akegarasu/dataset-tag-editor
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:24.07-py3
2 |
3 | EXPOSE 28000
4 |
5 | ENV TZ=Asia/Shanghai
6 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y
7 |
8 | RUN mkdir /app
9 |
10 | WORKDIR /app
11 | RUN git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts
12 |
13 | WORKDIR /app/lora-scripts
14 | RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt
15 |
16 | WORKDIR /app/lora-scripts/scripts
17 | RUN pip install -r requirements.txt
18 |
19 | WORKDIR /app/lora-scripts
20 |
21 | CMD ["python", "gui.py", "--listen"]
--------------------------------------------------------------------------------
/Dockerfile-for-Mainland-China:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:24.07-py3
2 |
3 | EXPOSE 28000
4 |
5 | ENV TZ=Asia/Shanghai
6 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && apt update && apt install python3-tk -y
7 |
8 | RUN mkdir /app
9 |
10 | WORKDIR /app
11 | RUN git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts
12 |
13 | WORKDIR /app/lora-scripts
14 |
15 | # Configure the pip package index to use a China-mainland mirror
16 | RUN pip config set global.index-url 'https://pypi.tuna.tsinghua.edu.cn/simple' && \
17 | pip config set install.trusted-host 'pypi.tuna.tsinghua.edu.cn'
18 |
19 | # Initial dependency installation
20 | RUN pip install xformers==0.0.27.post2 --no-deps && pip install -r requirements.txt
21 |
22 | # Install dependencies for the stable version of the trainer scripts
23 | WORKDIR /app/lora-scripts/scripts/stable
24 | RUN pip install -r requirements.txt
25 |
26 | # Install dependencies for the dev version of the trainer scripts
27 | WORKDIR /app/lora-scripts/scripts/dev
28 | RUN pip install -r requirements.txt
29 |
30 | WORKDIR /app/lora-scripts
31 |
32 | # Fix runtime errors and dependencies missing from the base image
33 | # ref
34 | # - https://soulteary.com/2024/01/07/fix-opencv-dependency-errors-opencv-fixer.html
35 | # - https://blog.csdn.net/qq_50195602/article/details/124188467
36 | RUN pip install opencv-fixer==0.2.5 && python -c "from opencv_fixer import AutoFix; AutoFix()" && \
37 |     pip install opencv-python-headless && apt install ffmpeg libsm6 libxext6 libgl1 -y
38 |
39 | CMD ["python", "gui.py", "--listen"]
--------------------------------------------------------------------------------
/README-zh.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # SD-Trainer
6 |
7 | _✨ 享受 Stable Diffusion 训练! ✨_
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | 下载
28 | ·
29 | 文档
30 | ·
31 | 中文README
32 |
33 |
34 | LoRA-scripts(又名 SD-Trainer)
35 |
36 | LoRA & Dreambooth 训练图形界面 & 脚本预设 & 一键训练环境,用于 [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts.git)
37 |
38 | ## ✨新特性: 训练 WebUI
39 |
40 | Stable Diffusion 训练工作台。一切集成于一个 WebUI 中。
41 |
42 | 按照下面的安装指南安装 GUI,然后运行 `run_gui.ps1`(Windows) 或 `run_gui.sh`(Linux) 来启动 GUI。
43 |
44 | 
45 |
46 | | Tensorboard | WD 1.4 标签器 | 标签编辑器 |
47 | | ------------ | ------------ | ------------ |
48 | |  |  |  |
49 |
50 |
51 | # 使用方法
52 |
53 | ### 必要依赖
54 |
55 | Python 3.10 和 Git
56 |
57 | ### 克隆带子模块的仓库
58 |
59 | ```sh
60 | git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts
61 | ```
62 |
63 | ## ✨ SD-Trainer GUI
64 |
65 | ### Windows
66 |
67 | #### 安装
68 |
69 | 运行 `install-cn.ps1` 将自动为您创建虚拟环境并安装必要的依赖。
70 |
71 | #### 训练
72 |
73 | 运行 `run_gui.ps1`,程序将自动打开 [http://127.0.0.1:28000](http://127.0.0.1:28000)
74 |
75 | ### Linux
76 |
77 | #### 安装
78 |
79 | 运行 `install.bash` 将创建虚拟环境并安装必要的依赖。
80 |
81 | #### 训练
82 |
83 | 运行 `bash run_gui.sh`,程序将自动打开 [http://127.0.0.1:28000](http://127.0.0.1:28000)
84 |
85 | ### Docker
86 |
87 | #### 编译镜像
88 |
89 | ```bash
90 | # 国内镜像优化版本
91 | # 其中 akegarasu_lora-scripts:latest 为镜像及其 tag 名,根据镜像托管服务商实际进行修改
92 | docker build -t akegarasu_lora-scripts:latest -f Dockerfile-for-Mainland-China .
93 | docker push akegarasu_lora-scripts:latest
94 | ```
95 |
96 | #### 使用镜像
97 |
98 | > 提供一个本人已打包好并推送到 `aliyuncs` 上的镜像,此镜像压缩归档大小约 `10G` 左右,请耐心等待拉取。
99 |
100 | ```bash
101 | docker run --gpus all -p 28000:28000 -p 6006:6006 registry.cn-hangzhou.aliyuncs.com/go-to-mirror/akegarasu_lora-scripts:latest
102 | ```
103 |
104 | 或者使用 `docker-compose.yaml` 。
105 |
106 | ```yaml
107 | services:
108 | lora-scripts:
109 | container_name: lora-scripts
110 | build:
111 | context: .
112 | dockerfile: Dockerfile-for-Mainland-China
113 | image: "registry.cn-hangzhou.aliyuncs.com/go-to-mirror/akegarasu_lora-scripts:latest"
114 | ports:
115 | - "28000:28000"
116 | - "6006:6006"
117 | # 共享本地文件夹(请根据实际修改)
118 | #volumes:
119 | # - "/data/srv/lora-scripts:/app/lora-scripts"
120 | # 共享 comfyui 大模型
121 | # - "/data/srv/comfyui/models/checkpoints:/app/lora-scripts/sd-models/comfyui"
122 | # 共享 sd-webui 大模型
123 | # - "/data/srv/stable-diffusion-webui/models/Stable-diffusion:/app/lora-scripts/sd-models/sd-webui"
124 | environment:
125 | - HF_HOME=huggingface
126 | - PYTHONUTF8=1
127 | security_opt:
128 | - "label=type:nvidia_container_t"
129 | runtime: nvidia
130 | deploy:
131 | resources:
132 | reservations:
133 | devices:
134 | - driver: nvidia
135 | device_ids: ['0']
136 | capabilities: [gpu]
137 | ```
138 |
139 | 关于容器使用 GPU 相关依赖安装问题,请自行搜索查阅资料解决。
140 |
141 | ## 通过手动运行脚本的传统训练方式
142 |
143 | ### Windows
144 |
145 | #### 安装
146 |
147 | 运行 `install.ps1` 将自动为您创建虚拟环境并安装必要的依赖。
148 |
149 | #### 训练
150 |
151 | 编辑 `train.ps1`,然后运行它。
152 |
153 | ### Linux
154 |
155 | #### 安装
156 |
157 | 运行 `install.bash` 将创建虚拟环境并安装必要的依赖。
158 |
159 | #### 训练
160 |
161 |
162 |
163 | 训练脚本 `train.sh` **不会** 为您激活虚拟环境。您应该先激活虚拟环境。
164 |
165 | ```sh
166 | source venv/bin/activate
167 | ```
168 |
169 | 编辑 `train.sh`,然后运行它。
170 |
171 | #### TensorBoard
172 |
173 | 运行 `tensorboard.ps1` 将在 http://localhost:6006/ 启动 TensorBoard
174 |
175 | ## 程序参数
176 |
177 | | 参数名称 | 类型 | 默认值 | 描述 |
178 | |------------------------------|-------|--------------|-------------------------------------------------|
179 | | `--host` | str | "127.0.0.1" | 服务器的主机名 |
180 | | `--port` | int | 28000 | 运行服务器的端口 |
181 | | `--listen` | bool | false | 启用服务器的监听模式 |
182 | | `--skip-prepare-environment` | bool | false | 跳过环境准备步骤 |
183 | | `--disable-tensorboard` | bool | false | 禁用 TensorBoard |
184 | | `--disable-tageditor` | bool | false | 禁用标签编辑器 |
185 | | `--tensorboard-host` | str | "127.0.0.1" | 运行 TensorBoard 的主机 |
186 | | `--tensorboard-port` | int | 6006 | 运行 TensorBoard 的端口 |
187 | | `--localization` | str | | 界面的本地化设置 |
188 | | `--dev` | bool | false | 开发者模式,用于禁用某些检查 |
189 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # SD-Trainer
6 |
7 | _✨ Enjoy Stable Diffusion Training! ✨_
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | Download
28 | ·
29 | Documents
30 | ·
31 | 中文README
32 |
33 |
34 | LoRA-scripts (a.k.a SD-Trainer)
35 |
36 | LoRA & Dreambooth training GUI, script presets, and one-click training environment for [kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts.git)
37 |
38 | ## ✨NEW: Train WebUI
39 |
40 | The **REAL** Stable Diffusion Training Studio. Everything in one WebUI.
41 |
42 | Follow the installation guide below to install the GUI, then run `run_gui.ps1` (Windows) or `run_gui.sh` (Linux) to start the GUI.
43 |
44 | 
45 |
46 | | Tensorboard | WD 1.4 Tagger | Tag Editor |
47 | | ------------ | ------------ | ------------ |
48 | |  |  |  |
49 |
50 |
51 | # Usage
52 |
53 | ### Required Dependencies
54 |
55 | Python 3.10 and Git
56 |
57 | ### Clone repo with submodules
58 |
59 | ```sh
60 | git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts
61 | ```
62 |
63 | ## ✨ SD-Trainer GUI
64 |
65 | ### Windows
66 |
67 | #### Installation
68 |
69 | Running `install.ps1` will automatically create a venv for you and install the necessary dependencies.
70 | If you are in mainland China, please use `install-cn.ps1` instead.
71 |
72 | #### Train
73 |
74 | Run `run_gui.ps1`; the program will then open [http://127.0.0.1:28000](http://127.0.0.1:28000) automatically.
75 |
76 | ### Linux
77 |
78 | #### Installation
79 |
80 | Running `install.bash` will create a venv and install the necessary dependencies.
81 |
82 | #### Train
83 |
84 | Run `bash run_gui.sh`; the program will then open [http://127.0.0.1:28000](http://127.0.0.1:28000) automatically.
85 |
86 | ## Legacy training by running scripts manually
87 |
88 | ### Windows
89 |
90 | #### Installation
91 |
92 | Running `install.ps1` will automatically create a venv for you and install the necessary dependencies.
93 |
94 | #### Train
95 |
96 | Edit `train.ps1`, and run it.
97 |
98 | ### Linux
99 |
100 | #### Installation
101 |
102 | Running `install.bash` will create a venv and install the necessary dependencies.
103 |
104 | #### Train
105 |
106 | The training script `train.sh` **will not** activate the venv for you. Activate it first:
107 |
108 | ```sh
109 | source venv/bin/activate
110 | ```
111 |
112 | Edit `train.sh`, and run it.
113 |
114 | #### TensorBoard
115 |
116 | Running `tensorboard.ps1` will start TensorBoard at http://localhost:6006/
117 |
118 | ## Program arguments
119 |
120 | | Parameter Name | Type | Default Value | Description |
121 | |-------------------------------|-------|---------------|--------------------------------------------------|
122 | | `--host` | str | "127.0.0.1" | Hostname for the server |
123 | | `--port` | int | 28000 | Port to run the server |
124 | | `--listen` | bool | false | Enable listening mode for the server |
125 | | `--skip-prepare-environment` | bool | false | Skip the environment preparation step |
126 | | `--disable-tensorboard` | bool | false | Disable TensorBoard |
127 | | `--disable-tageditor` | bool | false | Disable tag editor |
128 | | `--tensorboard-host` | str | "127.0.0.1" | Host to run TensorBoard |
129 | | `--tensorboard-port` | int | 6006 | Port to run TensorBoard |
130 | | `--localization` | str | | Localization settings for the interface |
131 | | `--dev`                       | bool  | false         | Developer mode to disable some checks             |
132 |
--------------------------------------------------------------------------------
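For reference, here is a typical invocation that combines several of the arguments documented in the README above; the values shown are the documented defaults and are purely illustrative:

```sh
# Listen on all interfaces and use the default GUI and TensorBoard ports
python gui.py --listen --port 28000 --tensorboard-port 6006
```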
/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/assets/favicon.ico
--------------------------------------------------------------------------------
/assets/gitconfig-cn:
--------------------------------------------------------------------------------
1 | [url "https://jihulab.com/Akegarasu/lora-scripts"]
2 | insteadOf = https://github.com/Akegarasu/lora-scripts
3 |
4 | [url "https://jihulab.com/Akegarasu/sd-scripts"]
5 | insteadOf = https://github.com/Akegarasu/sd-scripts
6 |
7 | [url "https://jihulab.com/affair3547/sd-scripts"]
8 | insteadOf = https://github.com/kohya-ss/sd-scripts.git
9 |
10 | [url "https://jihulab.com/affair3547/lora-gui-dist"]
11 | insteadOf = https://github.com/hanamizuki-ai/lora-gui-dist
12 |
13 | [url "https://jihulab.com/Akegarasu/dataset-tag-editor"]
14 | insteadOf = https://github.com/Akegarasu/dataset-tag-editor
15 |
--------------------------------------------------------------------------------
/assets/tensorboard-example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/assets/tensorboard-example.png
--------------------------------------------------------------------------------
/config/default.toml:
--------------------------------------------------------------------------------
1 | [model]
2 | v2 = false
3 | v_parameterization = false
4 | pretrained_model_name_or_path = "./sd-models/model.ckpt"
5 |
6 | [dataset]
7 | train_data_dir = "./train/input"
8 | reg_data_dir = ""
9 | prior_loss_weight = 1
10 | cache_latents = true
11 | shuffle_caption = true
12 | enable_bucket = true
13 |
14 | [additional_network]
15 | network_dim = 32
16 | network_alpha = 16
17 | network_train_unet_only = false
18 | network_train_text_encoder_only = false
19 | network_module = "networks.lora"
20 | network_args = []
21 |
22 | [optimizer]
23 | unet_lr = 1e-4
24 | text_encoder_lr = 1e-5
25 | optimizer_type = "AdamW8bit"
26 | lr_scheduler = "cosine_with_restarts"
27 | lr_warmup_steps = 0
28 | lr_restart_cycles = 1
29 |
30 | [training]
31 | resolution = "512,512"
32 | train_batch_size = 1
33 | max_train_epochs = 10
34 | noise_offset = 0.0
35 | keep_tokens = 0
36 | xformers = true
37 | lowram = false
38 | clip_skip = 2
39 | mixed_precision = "fp16"
40 | save_precision = "fp16"
41 |
42 | [sample_prompt]
43 | sample_sampler = "euler_a"
44 | sample_every_n_epochs = 1
45 |
46 | [saving]
47 | output_name = "output_name"
48 | save_every_n_epochs = 1
49 | save_n_epoch_ratio = 0
50 | save_last_n_epochs = 499
51 | save_state = false
52 | save_model_as = "safetensors"
53 | output_dir = "./output"
54 | logging_dir = "./logs"
55 | log_prefix = "output_name"
56 |
57 | [others]
58 | min_bucket_reso = 256
59 | max_bucket_reso = 1024
60 | caption_extension = ".txt"
61 | max_token_length = 225
62 | seed = 1337
63 |
--------------------------------------------------------------------------------
/config/lora.toml:
--------------------------------------------------------------------------------
1 | [model_arguments]
2 | v2 = false
3 | v_parameterization = false
4 | pretrained_model_name_or_path = "./sd-models/model.ckpt"
5 |
6 | [dataset_arguments]
7 | train_data_dir = "./train/aki"
8 | reg_data_dir = ""
9 | resolution = "512,512"
10 | prior_loss_weight = 1
11 |
12 | [additional_network_arguments]
13 | network_dim = 32
14 | network_alpha = 16
15 | network_train_unet_only = false
16 | network_train_text_encoder_only = false
17 | network_module = "networks.lora"
18 | network_args = []
19 |
20 | [optimizer_arguments]
21 | unet_lr = 1e-4
22 | text_encoder_lr = 1e-5
23 |
24 | optimizer_type = "AdamW8bit"
25 | lr_scheduler = "cosine_with_restarts"
26 | lr_warmup_steps = 0
27 | lr_restart_cycles = 1
28 |
29 | [training_arguments]
30 | train_batch_size = 1
31 | noise_offset = 0.0
32 | keep_tokens = 0
33 | min_bucket_reso = 256
34 | max_bucket_reso = 1024
35 | caption_extension = ".txt"
36 | max_token_length = 225
37 | seed = 1337
38 | xformers = true
39 | lowram = false
40 | max_train_epochs = 10
41 | resolution = "512,512"
42 | clip_skip = 2
43 | mixed_precision = "fp16"
44 |
45 | [sample_prompt_arguments]
46 | sample_sampler = "euler_a"
47 | sample_every_n_epochs = 5
48 |
49 | [saving_arguments]
50 | output_name = "output_name"
51 | save_every_n_epochs = 1
52 | save_state = false
53 | save_model_as = "safetensors"
54 | output_dir = "./output"
55 | logging_dir = "./logs"
56 | log_prefix = ""
57 | save_precision = "fp16"
58 |
59 | [others]
60 | cache_latents = true
61 | shuffle_caption = true
62 | enable_bucket = true
--------------------------------------------------------------------------------
/config/presets/example.toml:
--------------------------------------------------------------------------------
1 | # 模板显示的信息
2 | [metadata]
3 | name = "中分辨率训练" # 模板名称
4 | version = "0.0.1" # 模板版本
5 | author = "秋叶" # 模板作者
6 | # train_type 参数可以设置 lora-basic,lora-master,flux-lora 等内容。用于过滤显示的。不填就全部显示
7 | # train_type = ""
8 | description = "这是一个样例模板,提高训练分辨率。如果你想加入自己的模板,可以按照 config/presets 中的文件,修改后前往 Github 发起 PR"
9 |
10 | # 模板配置内容。请只填写修改了的内容,未修改无需填写。
11 | [data]
12 | resolution = "768,768"
13 | enable_bucket = true
14 | min_bucket_reso = 256
15 | max_bucket_reso = 2048
16 | bucket_no_upscale = true
--------------------------------------------------------------------------------
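To illustrate the preset mechanism described in the comments above, here is a hypothetical sketch (not the GUI's actual merge code) of overlaying a preset's `[data]` section on top of `config/default.toml`. It assumes Python 3.11+ for the standard-library `tomllib` module and that it is run from the repository root:

```python
# Hypothetical sketch: apply a preset's [data] overrides on top of the base defaults.
# The real merge happens inside the GUI; this only mirrors the idea.
import tomllib

with open("config/default.toml", "rb") as f:
    # Flatten the sectioned defaults into one dict of option -> value.
    base = {k: v for section in tomllib.load(f).values() for k, v in section.items()}

with open("config/presets/example.toml", "rb") as f:
    preset = tomllib.load(f)

merged = {**base, **preset["data"]}                       # preset values win over the defaults
print(merged["resolution"], merged["max_bucket_reso"])    # 768,768 2048
```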
/config/sample_prompts.txt:
--------------------------------------------------------------------------------
1 | (masterpiece, best quality:1.2), 1girl, solo, --n lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, --w 512 --h 768 --l 7 --s 24 --d 1337
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import locale
3 | import os
4 | import platform
5 | import subprocess
6 | import sys
7 |
8 | from mikazuki.launch_utils import (base_dir_path, catch_exception, git_tag,
9 | prepare_environment, check_port_avaliable, find_avaliable_ports)
10 | from mikazuki.log import log
11 |
12 | parser = argparse.ArgumentParser(description="GUI for stable diffusion training")
13 | parser.add_argument("--host", type=str, default="127.0.0.1")
14 | parser.add_argument("--port", type=int, default=28000, help="Port to run the server on")
15 | parser.add_argument("--listen", action="store_true")
16 | parser.add_argument("--skip-prepare-environment", action="store_true")
17 | parser.add_argument("--skip-prepare-onnxruntime", action="store_true")
18 | parser.add_argument("--disable-tensorboard", action="store_true")
19 | parser.add_argument("--disable-tageditor", action="store_true")
20 | parser.add_argument("--disable-auto-mirror", action="store_true")
21 | parser.add_argument("--tensorboard-host", type=str, default="127.0.0.1", help="Host to run the tensorboard")
22 | parser.add_argument("--tensorboard-port", type=int, default=6006, help="Port to run the tensorboard")
23 | parser.add_argument("--localization", type=str)
24 | parser.add_argument("--dev", action="store_true")
25 |
26 |
27 | @catch_exception
28 | def run_tensorboard():
29 | log.info("Starting tensorboard...")
30 | subprocess.Popen([sys.executable, "-m", "tensorboard.main", "--logdir", "logs",
31 | "--host", args.tensorboard_host, "--port", str(args.tensorboard_port)])
32 |
33 |
34 | @catch_exception
35 | def run_tag_editor():
36 | log.info("Starting tageditor...")
37 | cmd = [
38 | sys.executable,
39 | base_dir_path() / "mikazuki/dataset-tag-editor/scripts/launch.py",
40 | "--port", "28001",
41 | "--shadow-gradio-output",
42 | "--root-path", "/proxy/tageditor"
43 | ]
44 | if args.localization:
45 | cmd.extend(["--localization", args.localization])
46 | else:
47 | l = locale.getdefaultlocale()[0]
48 | if l and l.startswith("zh"):
49 | cmd.extend(["--localization", "zh-Hans"])
50 | subprocess.Popen(cmd)
51 |
52 |
53 | def launch():
54 | log.info("Starting SD-Trainer Mikazuki GUI...")
55 | log.info(f"Base directory: {base_dir_path()}, Working directory: {os.getcwd()}")
56 | log.info(f"{platform.system()} Python {platform.python_version()} {sys.executable}")
57 |
58 | if not args.skip_prepare_environment:
59 | prepare_environment(disable_auto_mirror=args.disable_auto_mirror)
60 |
61 | if not check_port_avaliable(args.port):
62 | avaliable = find_avaliable_ports(30000, 30000+20)
63 | if avaliable:
64 | args.port = avaliable
65 | else:
66 |             log.error("Port fallback failed: no available port found in 30000-30020")
67 |
68 | log.info(f"SD-Trainer Version: {git_tag(base_dir_path())}")
69 |
70 | os.environ["MIKAZUKI_HOST"] = args.host
71 | os.environ["MIKAZUKI_PORT"] = str(args.port)
72 | os.environ["MIKAZUKI_TENSORBOARD_HOST"] = args.tensorboard_host
73 | os.environ["MIKAZUKI_TENSORBOARD_PORT"] = str(args.tensorboard_port)
74 | os.environ["MIKAZUKI_DEV"] = "1" if args.dev else "0"
75 |
76 | if args.listen:
77 | args.host = "0.0.0.0"
78 | args.tensorboard_host = "0.0.0.0"
79 |
80 | if not args.disable_tageditor:
81 | run_tag_editor()
82 |
83 | if not args.disable_tensorboard:
84 | run_tensorboard()
85 |
86 | import uvicorn
87 | log.info(f"Server started at http://{args.host}:{args.port}")
88 | uvicorn.run("mikazuki.app:app", host=args.host, port=args.port, log_level="error", reload=args.dev)
89 |
90 |
91 | if __name__ == "__main__":
92 | args, _ = parser.parse_known_args()
93 | launch()
94 |
--------------------------------------------------------------------------------
/huggingface/accelerate/default_config.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | dynamo_backend: 'NO'
8 | fsdp_config: {}
9 | gpu_ids: all
10 | machine_rank: 0
11 | main_process_ip: null
12 | main_process_port: null
13 | main_training_function: main
14 | megatron_lm_config: {}
15 | mixed_precision: fp16
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_name: null
21 | tpu_zone: null
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/huggingface/hub/version.txt:
--------------------------------------------------------------------------------
1 | 1
--------------------------------------------------------------------------------
/install-cn.ps1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/install-cn.ps1
--------------------------------------------------------------------------------
/install.bash:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
4 | create_venv=true
5 |
6 | while [ -n "$1" ]; do
7 | case "$1" in
8 | --disable-venv)
9 | create_venv=false
10 | shift
11 | ;;
12 | *)
13 | shift
14 | ;;
15 | esac
16 | done
17 |
18 | if $create_venv; then
19 | echo "Creating python venv..."
20 |   python3 -m venv "$script_dir/venv"
21 |   source "$script_dir/venv/bin/activate"
22 |   echo "venv activated"
23 | fi
24 |
25 | echo "Installing torch & xformers..."
26 |
27 | cuda_version=$(nvidia-smi | grep -oiP 'CUDA Version: \K[\d\.]+')
28 |
29 | if [ -z "$cuda_version" ]; then
30 | cuda_version=$(nvcc --version | grep -oiP 'release \K[\d\.]+')
31 | fi
32 | cuda_major_version=$(echo "$cuda_version" | awk -F'.' '{print $1}')
33 | cuda_minor_version=$(echo "$cuda_version" | awk -F'.' '{print $2}')
34 |
35 | echo "CUDA Version: $cuda_version"
36 |
37 |
38 | if (( cuda_major_version >= 12 )); then
39 | echo "install torch 2.7.0+cu128"
40 | pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
41 | pip install --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
42 | elif (( cuda_major_version == 11 && cuda_minor_version >= 8 )); then
43 | echo "install torch 2.4.0+cu118"
44 | pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
45 | pip install --no-deps xformers==0.0.27.post2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
46 | elif (( cuda_major_version == 11 && cuda_minor_version >= 6 )); then
47 | echo "install torch 1.12.1+cu116"
48 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
49 | # for RTX3090+cu113/cu116 xformers, we need to install this version from source. You can also try xformers==0.0.18
50 | pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd
51 | pip install triton==2.0.0.dev20221202
52 | elif (( cuda_major_version == 11 && cuda_minor_version >= 2 )); then
53 | echo "install torch 1.12.1+cu113"
54 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu116
55 | pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd
56 | pip install triton==2.0.0.dev20221202
57 | else
58 | echo "Unsupported cuda version:$cuda_version"
59 | exit 1
60 | fi
61 |
62 | echo "Installing deps..."
63 |
64 | cd "$script_dir" || exit
65 | pip install --upgrade -r requirements.txt
66 |
67 | echo "Install completed"
68 |
--------------------------------------------------------------------------------
/install.ps1:
--------------------------------------------------------------------------------
1 | $Env:HF_HOME = "huggingface"
2 |
3 | if (!(Test-Path -Path "venv")) {
4 | Write-Output "Creating venv for python..."
5 | python -m venv venv
6 | }
7 | .\venv\Scripts\activate
8 |
9 | Write-Output "Installing deps..."
10 |
11 | pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
12 | pip install -U -I --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
13 | pip install --upgrade -r requirements.txt
14 |
15 | Write-Output "Install completed"
16 | Read-Host | Out-Null ;
17 |
--------------------------------------------------------------------------------
/interrogate.ps1:
--------------------------------------------------------------------------------
1 | # LoRA interrogate script by @bdsqlsz
2 |
3 | $v2 = 0 # load Stable Diffusion v2.x model / Stable Diffusion 2.x模型读取
4 | $sd_model = "./sd-models/sd_model.safetensors" # Stable Diffusion model to load: ckpt or safetensors file | 读取的基础SD模型, 保存格式 ckpt 或 safetensors
5 | $model = "./output/LoRA.safetensors" # LoRA model to interrogate: ckpt or safetensors file | 需要调查关键字的LORA模型, 保存格式 ckpt 或 safetensors
6 | $batch_size = 64 # batch size for processing with Text Encoder | 使用 Text Encoder 处理时的批量大小,默认16,推荐64/128
7 | $clip_skip = 1 # use output of nth layer from back of text encoder (n>=1) | 使用文本编码器倒数第 n 层的输出,n 可以是大于等于 1 的整数
8 |
9 |
10 | # Activate python venv
11 | .\venv\Scripts\activate
12 |
13 | $Env:HF_HOME = "huggingface"
14 | $ext_args = [System.Collections.ArrayList]::new()
15 |
16 | if ($v2) {
17 | [void]$ext_args.Add("--v2")
18 | }
19 |
20 | # run interrogate
21 | accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/networks/lora_interrogator.py" `
22 | --sd_model=$sd_model `
23 | --model=$model `
24 | --batch_size=$batch_size `
25 | --clip_skip=$clip_skip `
26 | $ext_args
27 |
28 | Write-Output "Interrogate finished"
29 | Read-Host | Out-Null ;
30 |
--------------------------------------------------------------------------------
/logs/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/logs/.keep
--------------------------------------------------------------------------------
/mikazuki/app/__init__.py:
--------------------------------------------------------------------------------
1 | from . import application
2 |
3 | app = application.app
--------------------------------------------------------------------------------
/mikazuki/app/application.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import mimetypes
3 | import os
4 | import sys
5 | import webbrowser
6 | from contextlib import asynccontextmanager
7 |
8 | from fastapi import FastAPI
9 | from fastapi.middleware.cors import CORSMiddleware
10 | from fastapi.responses import FileResponse
11 | from fastapi.staticfiles import StaticFiles
12 | from starlette.exceptions import HTTPException
13 |
14 | from mikazuki.app.config import app_config
15 | from mikazuki.app.api import load_schemas, load_presets
16 | from mikazuki.app.api import router as api_router
17 | # from mikazuki.app.ipc import router as ipc_router
18 | from mikazuki.app.proxy import router as proxy_router
19 | from mikazuki.utils.devices import check_torch_gpu
20 |
21 | mimetypes.add_type("application/javascript", ".js")
22 | mimetypes.add_type("text/css", ".css")
23 |
24 |
25 | class SPAStaticFiles(StaticFiles):
26 | async def get_response(self, path: str, scope):
27 | try:
28 | return await super().get_response(path, scope)
29 | except HTTPException as ex:
30 | if ex.status_code == 404:
31 | return await super().get_response("index.html", scope)
32 | else:
33 | raise ex
34 |
35 |
36 | async def app_startup():
37 | app_config.load_config()
38 |
39 | await load_schemas()
40 | await load_presets()
41 | await asyncio.to_thread(check_torch_gpu)
42 |
43 | if sys.platform == "win32" and os.environ.get("MIKAZUKI_DEV", "0") != "1":
44 | webbrowser.open(f'http://{os.environ["MIKAZUKI_HOST"]}:{os.environ["MIKAZUKI_PORT"]}')
45 |
46 |
47 | @asynccontextmanager
48 | async def lifespan(app: FastAPI):
49 | await app_startup()
50 | yield
51 |
52 |
53 | app = FastAPI(lifespan=lifespan)
54 | app.include_router(proxy_router)
55 |
56 |
57 | cors_config = os.environ.get("MIKAZUKI_APP_CORS", "")
58 | if cors_config != "":
59 | if cors_config == "1":
60 | cors_config = ["http://localhost:8004", "*"]
61 | else:
62 | cors_config = cors_config.split(";")
63 | app.add_middleware(
64 | CORSMiddleware,
65 | allow_origins=cors_config,
66 | allow_credentials=True,
67 | allow_methods=["*"],
68 | allow_headers=["*"],
69 | )
70 |
71 |
72 | @app.middleware("http")
73 | async def add_cache_control_header(request, call_next):
74 | response = await call_next(request)
75 | response.headers["Cache-Control"] = "max-age=0"
76 | return response
77 |
78 | app.include_router(api_router, prefix="/api")
79 | # app.include_router(ipc_router, prefix="/ipc")
80 |
81 |
82 | @app.get("/")
83 | async def index():
84 | return FileResponse("./frontend/dist/index.html")
85 |
86 |
87 | @app.get("/favicon.ico", response_class=FileResponse)
88 | async def favicon():
89 | return FileResponse("assets/favicon.ico")
90 |
91 | app.mount("/", SPAStaticFiles(directory="frontend/dist", html=True), name="static")
92 |
--------------------------------------------------------------------------------
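A minimal sketch of serving the FastAPI app above directly with uvicorn, assuming the repository root is the working directory and on `sys.path`; `gui.py` normally sets the `MIKAZUKI_*` environment variables before making the equivalent call:

```python
# Minimal sketch: run mikazuki.app:app without going through gui.py.
# Host/port values are illustrative and mirror the project defaults.
import os

import uvicorn

os.environ.setdefault("MIKAZUKI_HOST", "127.0.0.1")
os.environ.setdefault("MIKAZUKI_PORT", "28000")
os.environ.setdefault("MIKAZUKI_TENSORBOARD_HOST", "127.0.0.1")
os.environ.setdefault("MIKAZUKI_TENSORBOARD_PORT", "6006")

uvicorn.run("mikazuki.app:app", host="127.0.0.1", port=28000, log_level="error")
```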
/mikazuki/app/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pathlib import Path
4 | from mikazuki.log import log
5 |
6 | class Config:
7 |
8 | def __init__(self, path: str):
9 | self.path = path
10 | self._stored = {}
11 | self._default = {
12 | "last_path": "",
13 | "saved_params": {}
14 | }
15 | self.lock = False
16 |
17 | def load_config(self):
18 | log.info(f"Loading config from {self.path}")
19 | if not os.path.exists(self.path):
20 | self._stored = self._default
21 | self.save_config()
22 | return
23 |
24 | try:
25 | with open(self.path, "r", encoding="utf-8") as f:
26 | self._stored = json.load(f)
27 | except Exception as e:
28 | log.error(f"Error loading config: {e}")
29 | self._stored = self._default
30 | return
31 |
32 | def save_config(self):
33 | try:
34 | with open(self.path, "w", encoding="utf-8") as f:
35 | json.dump(self._stored, f, indent=4, ensure_ascii=False)
36 | except Exception as e:
37 | log.error(f"Error saving config: {e}")
38 |
39 | def __getitem__(self, key):
40 |
41 | return self._stored.get(key, None)
42 |
43 | def __setitem__(self, key, value):
44 | self._stored[key] = value
45 |
46 |
47 | app_config = Config(Path(__file__).parents[2].absolute() / "assets" / "config.json")
48 |
--------------------------------------------------------------------------------
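A short usage sketch for the `Config` helper above; the JSON path is illustrative and the import assumes the repository root is on `sys.path`:

```python
# Minimal usage sketch for mikazuki.app.config.Config.
from mikazuki.app.config import Config

cfg = Config("./assets/config.json")   # any writable JSON path works
cfg.load_config()                      # falls back to the built-in defaults if the file is missing

cfg["last_path"] = "./train/aki"       # __setitem__ only updates the in-memory store
cfg.save_config()                      # ...so persist it explicitly

print(cfg["last_path"])                # __getitem__ returns None for unknown keys
```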
/mikazuki/app/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import List, Optional, Union, Dict, Any
3 |
4 |
5 | class TaggerInterrogateRequest(BaseModel):
6 | path: str
7 | interrogator_model: str = Field(
8 | default="wd14-convnextv2-v2"
9 | )
10 | threshold: float = Field(
11 | default=0.35,
12 | ge=0,
13 | le=1
14 | )
15 | additional_tags: str = ""
16 | exclude_tags: str = ""
17 | escape_tag: bool = True
18 | batch_input_recursive: bool = False
19 | batch_output_action_on_conflict: str = "ignore"
20 | replace_underscore: bool = True
21 | replace_underscore_excludes: str = Field(
22 | default="0_0, (o)_(o), +_+, +_-, ._., _, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||"
23 | )
24 |
25 |
26 | class APIResponse(BaseModel):
27 | status: str
28 | message: Optional[str]
29 | data: Optional[Dict]
30 |
31 |
32 | class APIResponseSuccess(APIResponse):
33 | status: str = "success"
34 |
35 |
36 | class APIResponseFail(APIResponse):
37 | status: str = "fail"
38 |
--------------------------------------------------------------------------------
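A small sketch of how the request/response models above can be constructed; the values are illustrative and the import assumes the repository root is on `sys.path`:

```python
# Minimal usage sketch for mikazuki.app.models.
from mikazuki.app.models import APIResponseSuccess, TaggerInterrogateRequest

req = TaggerInterrogateRequest(path="./train/aki")   # every other field falls back to its default
print(req.interrogator_model, req.threshold)         # wd14-convnextv2-v2 0.35

resp = APIResponseSuccess(message="ok", data={"tags": ["1girl", "solo"]})
print(resp.status)                                   # success
```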
/mikazuki/app/proxy.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | import httpx
5 | import starlette
6 | import websockets
7 | from fastapi import APIRouter, Request, WebSocket
8 | from httpx import ConnectError
9 | from starlette.background import BackgroundTask
10 | from starlette.requests import Request
11 | from starlette.responses import PlainTextResponse, StreamingResponse
12 |
13 | from mikazuki.log import log
14 |
15 | router = APIRouter()
16 |
17 |
18 | def reverse_proxy_maker(url_type: str, full_path: bool = False):
19 | if url_type == "tensorboard":
20 | host = os.environ.get("MIKAZUKI_TENSORBOARD_HOST", "127.0.0.1")
21 | port = os.environ.get("MIKAZUKI_TENSORBOARD_PORT", "6006")
22 | elif url_type == "tageditor":
23 | host = os.environ.get("MIKAZUKI_TAGEDITOR_HOST", "127.0.0.1")
24 | port = os.environ.get("MIKAZUKI_TAGEDITOR_PORT", "28001")
25 |
26 | client = httpx.AsyncClient(base_url=f"http://{host}:{port}/", proxies={}, trust_env=False, timeout=360)
27 |
28 | async def _reverse_proxy(request: Request):
29 | if full_path:
30 | url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8"))
31 | else:
32 | url = httpx.URL(
33 | path=request.path_params.get("path", ""),
34 | query=request.url.query.encode("utf-8")
35 | )
36 | rp_req = client.build_request(
37 | request.method, url,
38 | headers=request.headers.raw,
39 | content=request.stream() if request.method != "GET" else None
40 | )
41 | try:
42 | rp_resp = await client.send(rp_req, stream=True)
43 | except ConnectError:
44 | return PlainTextResponse(
45 |                 content="The requested service has not started yet or failed to start. The first startup may take a while; please refresh later.\n请求的服务尚未启动或启动失败。若是第一次启动,可能需要等待一段时间后再刷新网页。",
46 | status_code=502
47 | )
48 | return StreamingResponse(
49 | rp_resp.aiter_raw(),
50 | status_code=rp_resp.status_code,
51 | headers=rp_resp.headers,
52 | background=BackgroundTask(rp_resp.aclose),
53 | )
54 |
55 | return _reverse_proxy
56 |
57 |
58 | async def proxy_ws_forward(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
59 | while True:
60 | try:
61 | data = await ws_a.receive_text()
62 | await ws_b.send(data)
63 | except starlette.websockets.WebSocketDisconnect as e:
64 | break
65 | except Exception as e:
66 | log.error(f"Error when proxy data client -> backend: {e}")
67 | break
68 |
69 |
70 | async def proxy_ws_reverse(ws_a: WebSocket, ws_b: websockets.WebSocketClientProtocol):
71 | while True:
72 | try:
73 | data = await ws_b.recv()
74 | await ws_a.send_text(data)
75 | except websockets.exceptions.ConnectionClosedOK as e:
76 | break
77 | except Exception as e:
78 | log.error(f"Error when proxy data backend -> client: {e}")
79 | break
80 |
81 |
82 | @router.websocket("/proxy/tageditor/queue/join")
83 | async def websocket_a(ws_a: WebSocket):
84 | # for temp use
85 | ws_b_uri = "ws://127.0.0.1:28001/queue/join"
86 | await ws_a.accept()
87 | async with websockets.connect(ws_b_uri, timeout=360, ping_timeout=None) as ws_b_client:
88 | fwd_task = asyncio.create_task(proxy_ws_forward(ws_a, ws_b_client))
89 | rev_task = asyncio.create_task(proxy_ws_reverse(ws_a, ws_b_client))
90 | await asyncio.gather(fwd_task, rev_task)
91 |
92 | router.add_route("/proxy/tensorboard/{path:path}", reverse_proxy_maker("tensorboard"), ["GET", "POST"])
93 | router.add_route("/font-roboto/{path:path}", reverse_proxy_maker("tensorboard", full_path=True), ["GET", "POST"])
94 | router.add_route("/proxy/tageditor/{path:path}", reverse_proxy_maker("tageditor"), ["GET", "POST"])
95 |
--------------------------------------------------------------------------------
/mikazuki/global.d.ts:
--------------------------------------------------------------------------------
1 | interface Window {
2 | __MIKAZUKI__: any;
3 | }
4 |
5 | type Dict = {
6 | [key in K]: T;
7 | };
8 |
9 | declare const kSchema: unique symbol;
10 |
11 | declare namespace Schemastery {
12 | type From = X extends string | number | boolean ? SchemaI : X extends SchemaI ? X : X extends typeof String ? SchemaI : X extends typeof Number ? SchemaI : X extends typeof Boolean ? SchemaI : X extends typeof Function ? SchemaI any> : X extends Constructor ? SchemaI : never;
13 | type TypeS1 = X extends SchemaI ? S : never;
14 | type Inverse = X extends SchemaI ? (arg: Y) => void : never;
15 | type TypeS = TypeS1>;
16 | type TypeT = ReturnType>;
17 | type Resolve = (data: any, schema: SchemaI, options?: Options, strict?: boolean) => [any, any?];
18 | type IntersectS = From extends SchemaI ? S : never;
19 | type IntersectT = Inverse> extends ((arg: infer T) => void) ? T : never;
20 | type TupleS = X extends readonly [infer L, ...infer R] ? [TypeS?, ...TupleS] : any[];
21 | type TupleT = X extends readonly [infer L, ...infer R] ? [TypeT?, ...TupleT] : any[];
22 | type ObjectS = {
23 | [K in keyof X]?: TypeS | null;
24 | } & Dict;
25 | type ObjectT = {
26 | [K in keyof X]: TypeT;
27 | } & Dict;
28 | type Constructor = new (...args: any[]) => T;
29 | interface Static {
30 | (options: Partial>): SchemaI;
31 | new (options: Partial>): SchemaI;
32 | prototype: SchemaI;
33 | resolve: Resolve;
34 | from(source?: X): From;
35 | extend(type: string, resolve: Resolve): void;
36 | any(): SchemaI;
37 | never(): SchemaI;
38 | const(value: T): SchemaI;
39 | string(): SchemaI;
40 | number(): SchemaI;
41 | natural(): SchemaI;
42 | percent(): SchemaI;
43 | boolean(): SchemaI;
44 | date(): SchemaI;
45 | bitset(bits: Partial>): SchemaI;
46 | function(): SchemaI any>;
47 | is(constructor: Constructor): SchemaI;
48 | array(inner: X): SchemaI[], TypeT[]>;
49 | dict = SchemaI>(inner: X, sKey?: Y): SchemaI, TypeS>, Dict, TypeT>>;
50 | tuple(list: X): SchemaI, TupleT>;
51 | object(dict: X): SchemaI, ObjectT>;
52 | union(list: readonly X[]): SchemaI, TypeT>;
53 | intersect(list: readonly X[]): SchemaI, IntersectT>;
54 | transform(inner: X, callback: (value: TypeS) => T, preserve?: boolean): SchemaI, T>;
55 | }
56 | interface Options {
57 | autofix?: boolean;
58 | }
59 | interface Meta {
60 | default?: T extends {} ? Partial : T;
61 | required?: boolean;
62 | disabled?: boolean;
63 | collapse?: boolean;
64 | badges?: {
65 | text: string;
66 | type: string;
67 | }[];
68 | hidden?: boolean;
69 | loose?: boolean;
70 | role?: string;
71 | extra?: any;
72 | link?: string;
73 | description?: string | Dict;
74 | comment?: string;
75 | pattern?: {
76 | source: string;
77 | flags?: string;
78 | };
79 | max?: number;
80 | min?: number;
81 | step?: number;
82 | }
83 |
84 | interface Schemastery {
85 | (data?: S | null, options?: Schemastery.Options): T;
86 | new(data?: S | null, options?: Schemastery.Options): T;
87 | [kSchema]: true;
88 | uid: number;
89 | meta: Schemastery.Meta;
90 | type: string;
91 | sKey?: SchemaI;
92 | inner?: SchemaI;
93 | list?: SchemaI[];
94 | dict?: Dict;
95 | bits?: Dict;
96 | callback?: Function;
97 | value?: T;
98 | refs?: Dict;
99 | preserve?: boolean;
100 | toString(inline?: boolean): string;
101 | toJSON(): SchemaI;
102 | required(value?: boolean): SchemaI;
103 | hidden(value?: boolean): SchemaI;
104 | loose(value?: boolean): SchemaI;
105 | role(text: string, extra?: any): SchemaI;
106 | link(link: string): SchemaI;
107 | default(value: T): SchemaI;
108 | comment(text: string): SchemaI;
109 | description(text: string): SchemaI;
110 | disabled(value?: boolean): SchemaI;
111 | collapse(value?: boolean): SchemaI;
112 | deprecated(): SchemaI;
113 | experimental(): SchemaI;
114 | pattern(regexp: RegExp): SchemaI;
115 | max(value: number): SchemaI;
116 | min(value: number): SchemaI;
117 | step(value: number): SchemaI;
118 | set(key: string, value: SchemaI): SchemaI;
119 | push(value: SchemaI): SchemaI;
120 | simplify(value?: any): any;
121 | i18n(messages: Dict): SchemaI;
122 | extra(key: K, value: Schemastery.Meta[K]): SchemaI;
123 | }
124 |
125 | }
126 |
127 | type SchemaI = Schemastery.Schemastery;
128 |
129 | declare const Schema: Schemastery.Static
130 |
131 | declare const SHARED_SCHEMAS: Dict
132 |
133 | declare function UpdateSchema(origin: Record, modify?: Record, toDelete?: string[]): Record;
134 |
--------------------------------------------------------------------------------
/mikazuki/hook/i18n.json:
--------------------------------------------------------------------------------
1 | {
2 | "指定エポックまでのステップ数": "指定 epoch 之前的步数",
3 | "学習開始": "训练开始",
4 | "学習画像の数×繰り返し回数": "训练图像数量×重复次数",
5 | "正則化画像の数": "正则化图像数量",
6 | "正則化画像の数×繰り返し回数": "正则化图像数量×重复次数",
7 | "1epochのバッチ数": "每个 epoch 的 batch 数",
8 | "バッチサイズ": "batch 大小",
9 | "学習率": "学习率",
10 | "勾配を合計するステップ数": "梯度累积步数",
11 | "学習率の減衰率": "学习率衰减率",
12 | "学習ステップ数": "训练步数"
13 | }
--------------------------------------------------------------------------------
/mikazuki/hook/sitecustomize.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | _origin_print = print
4 |
5 |
6 | def i18n_print(data, *args, **kwargs):
7 | _origin_print(data, *args, **kwargs)
8 | _origin_print("i18n_print")
9 |
10 |
11 | __builtins__["print"] = i18n_print
12 |
--------------------------------------------------------------------------------
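The hook above is effectively a stub. Below is a hypothetical sketch of the idea it points at, replacing known Japanese log fragments with the translations from the adjacent `i18n.json` before printing; this is an assumption for illustration, not the repository's actual implementation:

```python
# Hypothetical sketch of an i18n print hook; assumes it lives next to i18n.json.
import builtins
import json
from pathlib import Path

_origin_print = builtins.print
_i18n = json.loads((Path(__file__).parent / "i18n.json").read_text(encoding="utf-8"))


def i18n_print(data, *args, **kwargs):
    if isinstance(data, str):
        for src, dst in _i18n.items():
            data = data.replace(src, dst)   # substitute translated fragments
    _origin_print(data, *args, **kwargs)


builtins.print = i18n_print
```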
/mikazuki/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | log = logging.getLogger('sd-trainer')
5 | log.setLevel(logging.DEBUG)
6 |
7 | try:
8 | from rich.console import Console
9 | from rich.logging import RichHandler
10 | from rich.pretty import install as pretty_install
11 | from rich.theme import Theme
12 |
13 | console = Console(
14 | log_time=True,
15 | log_time_format='%H:%M:%S-%f',
16 | theme=Theme(
17 | {
18 | 'traceback.border': 'black',
19 | 'traceback.border.syntax_error': 'black',
20 | 'inspect.value.border': 'black',
21 | }
22 | ),
23 | )
24 | pretty_install(console=console)
25 | rh = RichHandler(
26 | show_time=True,
27 | omit_repeated_times=False,
28 | show_level=True,
29 | show_path=False,
30 | markup=False,
31 | rich_tracebacks=True,
32 | log_time_format='%H:%M:%S-%f',
33 | level=logging.INFO,
34 | console=console,
35 | )
36 | rh.set_name(logging.INFO)
37 | while log.hasHandlers() and len(log.handlers) > 0:
38 | log.removeHandler(log.handlers[0])
39 | log.addHandler(rh)
40 |
41 | except ModuleNotFoundError:
42 | pass
43 |
44 |
--------------------------------------------------------------------------------
/mikazuki/process.py:
--------------------------------------------------------------------------------
1 |
2 | import asyncio
3 | import os
4 | import sys
5 | from typing import Optional
6 |
7 | from mikazuki.app.models import APIResponse
8 | from mikazuki.log import log
9 | from mikazuki.tasks import tm
10 | from mikazuki.launch_utils import base_dir_path
11 |
12 |
13 | def run_train(toml_path: str,
14 | trainer_file: str = "./scripts/train_network.py",
15 | gpu_ids: Optional[list] = None,
16 | cpu_threads: Optional[int] = 2):
17 | log.info(f"Training started with config file / 训练开始,使用配置文件: {toml_path}")
18 | args = [
19 | sys.executable, "-m", "accelerate.commands.launch", # use -m to avoid python script executable error
20 | "--num_cpu_threads_per_process", str(cpu_threads), # cpu threads
21 | "--quiet", # silence accelerate error message
22 | trainer_file,
23 | "--config_file", toml_path,
24 | ]
25 |
26 | customize_env = os.environ.copy()
27 | customize_env["ACCELERATE_DISABLE_RICH"] = "1"
28 | customize_env["PYTHONUNBUFFERED"] = "1"
29 | customize_env["PYTHONWARNINGS"] = "ignore::FutureWarning,ignore::UserWarning"
30 |
31 | if gpu_ids:
32 | customize_env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
33 | log.info(f"Using GPU(s) / 使用 GPU: {gpu_ids}")
34 |
35 | if len(gpu_ids) > 1:
36 | args[3:3] = ["--multi_gpu", "--num_processes", str(len(gpu_ids))]
37 | if sys.platform == "win32":
38 | customize_env["USE_LIBUV"] = "0"
39 | args[3:3] = ["--rdzv_backend", "c10d"]
40 |
41 | if not (task := tm.create_task(args, customize_env)):
42 | return APIResponse(status="error", message="Failed to create task / 无法创建训练任务")
43 |
44 | def _run():
45 | try:
46 | task.execute()
47 | result = task.communicate()
48 | if result.returncode != 0:
49 | log.error(f"Training failed / 训练失败")
50 | else:
51 | log.info(f"Training finished / 训练完成")
52 | except Exception as e:
53 | log.error(f"An error occurred when training / 训练出现致命错误: {e}")
54 |
55 | coro = asyncio.to_thread(_run)
56 | asyncio.create_task(coro)
57 |
58 | return APIResponse(status="success", message=f"Training started / 训练开始 ID: {task.task_id}")
59 |
--------------------------------------------------------------------------------
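A minimal sketch of calling `run_train` above from an async context; the TOML path and GPU id are illustrative, and it assumes the repository root is the working directory with its dependencies installed:

```python
# Minimal usage sketch for mikazuki.process.run_train.
import asyncio

from mikazuki.process import run_train


async def main():
    resp = run_train("./config/autosave/example.toml", gpu_ids=["0"])
    print(resp.status, resp.message)
    # run_train hands the blocking training process to a background thread via
    # asyncio.to_thread, so keep the event loop alive while training runs.
    await asyncio.sleep(5)


asyncio.run(main())
```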
/mikazuki/schema/flux-lora.ts:
--------------------------------------------------------------------------------
1 | Schema.intersect([
2 | Schema.object({
3 | model_train_type: Schema.string().default("flux-lora").disabled().description("训练种类"),
4 | pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("Flux 模型路径"),
5 | ae: Schema.string().role('filepicker', { type: "model-file" }).description("AE 模型文件路径"),
6 | clip_l: Schema.string().role('filepicker', { type: "model-file" }).description("clip_l 模型文件路径"),
7 | t5xxl: Schema.string().role('filepicker', { type: "model-file" }).description("t5xxl 模型文件路径"),
8 | resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"),
9 | }).description("训练用模型"),
10 |
11 | Schema.object({
12 | timestep_sampling: Schema.union(["sigma", "uniform", "sigmoid", "shift"]).default("sigmoid").description("时间步采样"),
13 | sigmoid_scale: Schema.number().step(0.001).default(1.0).description("sigmoid 缩放"),
14 | model_prediction_type: Schema.union(["raw", "additive", "sigma_scaled"]).default("raw").description("模型预测类型"),
15 | discrete_flow_shift: Schema.number().step(0.001).default(1.0).description("Euler 调度器离散流位移"),
16 | loss_type: Schema.union(["l1", "l2", "huber", "smooth_l1"]).default("l2").description("损失函数类型"),
17 | guidance_scale: Schema.number().step(0.01).default(1.0).description("CFG 引导缩放"),
18 | t5xxl_max_token_length: Schema.number().step(1).description("T5XXL 最大 token 长度(不填写使用自动)"),
19 | train_t5xxl: Schema.boolean().default(false).description("训练 T5XXL(不推荐)"),
20 | }).description("Flux 专用参数"),
21 |
22 | Schema.object(
23 | UpdateSchema(SHARED_SCHEMAS.RAW.DATASET_SETTINGS, {
24 | resolution: Schema.string().default("768,768").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"),
25 | enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"),
26 | min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"),
27 | max_bucket_reso: Schema.number().default(2048).description("arb 桶最大分辨率"),
28 | bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位,FLUX 需大于 64"),
29 | })
30 | ).description("数据集设置"),
31 |
32 | // 保存设置
33 | SHARED_SCHEMAS.SAVE_SETTINGS,
34 |
35 | Schema.object({
36 | max_train_epochs: Schema.number().min(1).default(20).description("最大训练 epoch(轮数)"),
37 | train_batch_size: Schema.number().min(1).default(1).description("批量大小, 越高显存占用越高"),
38 | gradient_checkpointing: Schema.boolean().default(true).description("梯度检查点"),
39 | gradient_accumulation_steps: Schema.number().min(1).default(1).description("梯度累加步数"),
40 | network_train_unet_only: Schema.boolean().default(true).description("仅训练 U-Net"),
41 | network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"),
42 | }).description("训练相关参数"),
43 |
44 | // 学习率&优化器设置
45 | SHARED_SCHEMAS.LR_OPTIMIZER,
46 |
47 | Schema.intersect([
48 | Schema.object({
49 | network_module: Schema.union(["networks.lora_flux", "networks.oft_flux", "lycoris.kohya"]).default("networks.lora_flux").description("训练网络模块"),
50 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"),
51 | network_dim: Schema.number().min(1).default(2).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
52 | network_alpha: Schema.number().min(1).default(16).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"),
53 | network_dropout: Schema.number().step(0.01).default(0).description('dropout 概率 (与 lycoris 不兼容,需要用 lycoris 自带的)'),
54 | scale_weight_norms: Schema.number().step(0.01).min(0).description("最大范数正则化。如果使用,推荐为 1"),
55 | network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'),
56 | enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'),
57 | }).description("网络设置"),
58 |
59 | // lycoris 参数
60 | SHARED_SCHEMAS.LYCORIS_MAIN,
61 | SHARED_SCHEMAS.LYCORIS_LOKR,
62 |
63 | SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT,
64 | ]),
65 |
66 | // 预览图设置
67 | SHARED_SCHEMAS.PREVIEW_IMAGE,
68 |
69 | // 日志设置
70 | SHARED_SCHEMAS.LOG_SETTINGS,
71 |
72 | // caption 选项
73 | // FLUX 去除 max_token_length
74 | Schema.object(UpdateSchema(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS, {}, ["max_token_length"])).description("caption(Tag)选项"),
75 |
76 | // 噪声设置
77 | SHARED_SCHEMAS.NOISE_SETTINGS,
78 |
79 | // 数据增强
80 | SHARED_SCHEMAS.DATA_ENCHANCEMENT,
81 |
82 | // 其他选项
83 | SHARED_SCHEMAS.OTHER,
84 |
85 | // 速度优化选项
86 | Schema.object(
87 | UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, {
88 | fp8_base: Schema.boolean().default(true).description("对基础模型使用 FP8 精度"),
89 | fp8_base_unet: Schema.boolean().description("仅对 U-Net 使用 FP8 精度(CLIP-L不使用)"),
90 | sdpa: Schema.boolean().default(true).description("启用 sdpa"),
91 | cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
92 | cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
93 | }, ["xformers"])
94 | ).description("速度优化选项"),
95 |
96 | // 分布式训练
97 | SHARED_SCHEMAS.DISTRIBUTED_TRAINING
98 | ]);
99 |
--------------------------------------------------------------------------------
/mikazuki/schema/lora-basic.ts:
--------------------------------------------------------------------------------
1 | Schema.intersect([
2 | Schema.object({
3 | pretrained_model_name_or_path: Schema.string().role('filepicker', {type: "model-file"}).default("./sd-models/model.safetensors").description("底模文件路径"),
4 | }).description("训练用模型"),
5 |
6 | Schema.object({
7 | train_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).default("./train/aki").description("训练数据集路径"),
8 | reg_data_dir: Schema.string().role('filepicker', { type: "folder", internal: "train-dir" }).description("正则化数据集路径。默认留空,不使用正则化图像"),
9 | resolution: Schema.string().default("512,512").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"),
10 | }).description("数据集设置"),
11 |
12 | Schema.object({
13 | output_name: Schema.string().default("aki").description("模型保存名称"),
14 | output_dir: Schema.string().default("./output").role('filepicker', { type: "folder" }).description("模型保存文件夹"),
15 | save_every_n_epochs: Schema.number().default(2).description("每 N epoch(轮)自动保存一次模型"),
16 | }).description("保存设置"),
17 |
18 | Schema.object({
19 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"),
20 | train_batch_size: Schema.number().min(1).default(1).description("批量大小"),
21 | }).description("训练相关参数"),
22 |
23 | Schema.intersect([
24 | Schema.object({
25 | unet_lr: Schema.string().default("1e-4").description("U-Net 学习率"),
26 | text_encoder_lr: Schema.string().default("1e-5").description("文本编码器学习率"),
27 | lr_scheduler: Schema.union([
28 | "cosine",
29 | "cosine_with_restarts",
30 | "constant",
31 | "constant_with_warmup",
32 | ]).default("cosine_with_restarts").description("学习率调度器设置"),
33 | lr_warmup_steps: Schema.number().default(0).description('学习率预热步数'),
34 | }).description("学习率与优化器设置"),
35 | Schema.union([
36 | Schema.object({
37 | lr_scheduler: Schema.const('cosine_with_restarts'),
38 | lr_scheduler_num_cycles: Schema.number().default(1).description('重启次数'),
39 | }),
40 | Schema.object({}),
41 | ]),
42 | Schema.object({
43 | optimizer_type: Schema.union([
44 | "AdamW8bit",
45 | "Lion",
46 | ]).default("AdamW8bit").description("优化器设置"),
47 | })
48 | ]),
49 |
50 | Schema.intersect([
51 | Schema.object({
52 | enable_preview: Schema.boolean().default(false).description('启用训练预览图'),
53 | }).description('训练预览图设置'),
54 |
55 | Schema.union([
56 | Schema.object({
57 | enable_preview: Schema.const(true).required(),
58 | sample_prompts: Schema.string().role('textarea').default(window.__MIKAZUKI__.SAMPLE_PROMPTS_DEFAULT).description(window.__MIKAZUKI__.SAMPLE_PROMPTS_DESCRIPTION),
59 | sample_sampler: Schema.union(["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver", "dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]).default("euler_a").description("生成预览图所用采样器"),
60 | sample_every_n_epochs: Schema.number().default(2).description("每 N 个 epoch 生成一次预览图"),
61 | }),
62 | Schema.object({}),
63 | ]),
64 | ]),
65 |
66 | Schema.intersect([
67 | Schema.object({
68 | network_weights: Schema.string().role('filepicker', { type: "model-file", internal: "model-saved-file" }).description("从已有的 LoRA 模型上继续训练,填写路径"),
69 | network_dim: Schema.number().min(8).max(256).step(8).default(32).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
70 | network_alpha: Schema.number().min(1).default(32).description(
71 | "常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率。"
72 | ),
73 | }).description("网络设置"),
74 | ]),
75 |
76 | Schema.object({
77 | shuffle_caption: Schema.boolean().default(true).description("训练时随机打乱 tokens"),
78 | keep_tokens: Schema.number().min(0).max(255).step(1).default(0).description("在随机打乱 tokens 时,保留前 N 个不变"),
79 | }).description("caption 选项"),
80 |
81 | Schema.object({
82 | mixed_precision: Schema.union(["no", "fp16", "bf16"]).default("fp16").description("混合精度, RTX30系列以后也可以指定`bf16`"),
83 | no_half_vae: Schema.boolean().description("不使用半精度 VAE,当出现 NaN detected in latents 报错时使用"),
84 | xformers: Schema.boolean().default(true).description("启用 xformers"),
85 | cache_latents: Schema.boolean().default(true).description("缓存图像 latent, 缓存 VAE 输出以减少 VRAM 使用")
86 | }).description("速度优化选项"),
87 | ]);
88 |
--------------------------------------------------------------------------------
/mikazuki/schema/lora-master.ts:
--------------------------------------------------------------------------------
1 | Schema.intersect([
2 | Schema.intersect([
3 | Schema.object({
4 | model_train_type: Schema.union(["sd-lora", "sdxl-lora"]).default("sd-lora").description("训练种类"),
5 | pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("底模文件路径"),
6 | resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"),
7 | vae: Schema.string().role('filepicker', { type: "model-file" }).description("(可选) VAE 模型文件路径,使用外置 VAE 文件覆盖模型内本身的"),
8 | }).description("训练用模型"),
9 |
10 | Schema.union([
11 | Schema.object({
12 | model_train_type: Schema.const("sd-lora"),
13 | v2: Schema.boolean().default(false).description("底模为 sd2.0 以后的版本需要启用"),
14 | }),
15 | Schema.object({}),
16 | ]),
17 |
18 | Schema.union([
19 | Schema.object({
20 | model_train_type: Schema.const("sd-lora"),
21 | v2: Schema.const(true).required(),
22 | v_parameterization: Schema.boolean().default(false).description("v-parameterization 学习"),
23 | scale_v_pred_loss_like_noise_pred: Schema.boolean().default(false).description("缩放 v-prediction 损失(与v-parameterization配合使用)"),
24 | }),
25 | Schema.object({}),
26 | ]),
27 | ]),
28 |
29 | // 数据集设置
30 | Schema.object(SHARED_SCHEMAS.RAW.DATASET_SETTINGS).description("数据集设置"),
31 |
32 | // 保存设置
33 | SHARED_SCHEMAS.SAVE_SETTINGS,
34 |
35 | Schema.object({
36 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"),
37 | train_batch_size: Schema.number().min(1).default(1).description("批量大小, 越高显存占用越高"),
38 | gradient_checkpointing: Schema.boolean().default(false).description("梯度检查点"),
39 | gradient_accumulation_steps: Schema.number().min(1).description("梯度累加步数"),
40 | network_train_unet_only: Schema.boolean().default(false).description("仅训练 U-Net,训练 SDXL LoRA 时推荐开启"),
41 | network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"),
42 | }).description("训练相关参数"),
43 |
44 | // 学习率&优化器设置
45 | SHARED_SCHEMAS.LR_OPTIMIZER,
46 |
47 | Schema.intersect([
48 | Schema.object({
49 | network_module: Schema.union(["networks.lora", "networks.dylora", "networks.oft", "lycoris.kohya"]).default("networks.lora").description("训练网络模块"),
50 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"),
51 | network_dim: Schema.number().min(1).default(32).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
52 | network_alpha: Schema.number().min(1).default(32).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"),
53 | network_dropout: Schema.number().step(0.01).default(0).description('dropout 概率 (与 lycoris 不兼容,需要用 lycoris 自带的)'),
54 | scale_weight_norms: Schema.number().step(0.01).min(0).description("最大范数正则化。如果使用,推荐为 1"),
55 | network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'),
56 | enable_block_weights: Schema.boolean().default(false).description('启用分层学习率训练(只支持网络模块 networks.lora)'),
57 | enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'),
58 | }).description("网络设置"),
59 |
60 | // lycoris 参数
61 | SHARED_SCHEMAS.LYCORIS_MAIN,
62 | SHARED_SCHEMAS.LYCORIS_LOKR,
63 |
64 | // dylora 参数
65 | SHARED_SCHEMAS.NETWORK_OPTION_DYLORA,
66 |
67 | // 分层学习率参数
68 | SHARED_SCHEMAS.NETWORK_OPTION_BLOCK_WEIGHTS,
69 |
70 | SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT,
71 | ]),
72 |
73 | // 预览图设置
74 | SHARED_SCHEMAS.PREVIEW_IMAGE,
75 |
76 | // 日志设置
77 | SHARED_SCHEMAS.LOG_SETTINGS,
78 |
79 | // caption 选项
80 | Schema.object(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS).description("caption(Tag)选项"),
81 |
82 | // 噪声设置
83 | SHARED_SCHEMAS.NOISE_SETTINGS,
84 |
85 | // 数据增强
86 | SHARED_SCHEMAS.DATA_ENCHANCEMENT,
87 |
88 | // 其他选项
89 | SHARED_SCHEMAS.OTHER,
90 |
91 | // 速度优化选项
92 | Schema.object(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH).description("速度优化选项"),
93 |
94 | // 分布式训练
95 | SHARED_SCHEMAS.DISTRIBUTED_TRAINING
96 | ]);
97 |
--------------------------------------------------------------------------------
/mikazuki/schema/lumina2-lora.ts:
--------------------------------------------------------------------------------
1 | // 使用 sd-scripts 的配置
2 | Schema.intersect([
3 | Schema.object({
4 | model_train_type: Schema.string().default("lumina-lora").disabled().description("训练种类"),
5 | pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("Lumina 模型路径"),
6 | ae: Schema.string().role('filepicker', { type: "model-file" }).description("AE 模型文件路径"),
7 | gemma2: Schema.string().role('filepicker', { type: "model-file" }).description("gemma2 模型文件路径"),
8 | resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"),
9 | }).description("训练用模型"),
10 |
11 | Schema.object({
12 | timestep_sampling: Schema.union(["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"]).default("nextdit_shift").description("时间步采样"),
13 | sigmoid_scale: Schema.number().step(0.001).default(1.0).description("sigmoid 缩放"),
14 | model_prediction_type: Schema.union(["raw", "additive", "sigma_scaled"]).default("raw").description("模型预测类型"),
15 | discrete_flow_shift: Schema.number().step(0.001).default(3.185).description("Euler 调度器离散流位移"),
16 | //loss_type: Schema.union(["l1", "l2", "huber", "smooth_l1"]).default("l2").description("损失函数类型"),
17 | guidance_scale: Schema.number().step(0.01).default(1.0).description("CFG 引导缩放"),
18 | use_flash_attn: Schema.boolean().default(false).description("是否使用 Flash Attention"),
19 | cfg_trunc: Schema.number().step(0.01).default(0.25).description("CFG 截断"),
20 | renorm_cfg: Schema.number().step(0.01).default(1.0).description("重归一化 CFG"),
21 | system_prompt: Schema.string().default("You are an assistant designed to generate high-quality images based on user prompts. ").description("Gemma2b的系统提示"),
22 | }).description("Lumina 专用参数"),
23 |
24 | Schema.object(
25 | UpdateSchema(SHARED_SCHEMAS.RAW.DATASET_SETTINGS, {
26 | resolution: Schema.string().default("1024,1024").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"),
27 | enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"),
28 | min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"),
29 | max_bucket_reso: Schema.number().default(2048).description("arb 桶最大分辨率"),
30 | bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位"),
31 | })
32 | ).description("数据集设置"),
33 |
34 | // 保存设置
35 | SHARED_SCHEMAS.SAVE_SETTINGS,
36 |
37 | Schema.object({
38 | max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"),
39 | train_batch_size: Schema.number().min(1).default(2).description("批量大小, 越高显存占用越高"),
40 | gradient_checkpointing: Schema.boolean().default(true).description("梯度检查点"),
41 | gradient_accumulation_steps: Schema.number().min(1).default(1).description("梯度累加步数"),
42 | network_train_unet_only: Schema.boolean().default(true).description("仅训练 U-Net"),
43 | network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"),
44 | }).description("训练相关参数"),
45 |
46 | // 学习率&优化器设置
47 | SHARED_SCHEMAS.LR_OPTIMIZER,
48 |
49 | Schema.intersect([
50 | Schema.object({
51 | network_module: Schema.union(["networks.lora_lumina", "networks.oft_lumina", "lycoris.kohya"]).default("networks.lora_lumina").description("训练网络模块"),
52 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"),
53 | network_dim: Schema.number().min(1).default(16).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
54 | network_alpha: Schema.number().min(1).default(8).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"),
55 | network_dropout: Schema.number().step(0.01).default(0).description('dropout 概率 (与 lycoris 不兼容,需要用 lycoris 自带的)'),
56 | scale_weight_norms: Schema.number().step(0.01).min(0).default(1.0).description("最大范数正则化。如果使用,推荐为 1"),
57 | network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'),
58 | enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'),
59 | }).description("网络设置"),
60 |
61 | // lycoris 参数
62 | SHARED_SCHEMAS.LYCORIS_MAIN,
63 | SHARED_SCHEMAS.LYCORIS_LOKR,
64 |
65 | SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT,
66 | ]),
67 |
68 | // 预览图设置
69 | SHARED_SCHEMAS.PREVIEW_IMAGE,
70 |
71 | // 日志设置
72 | SHARED_SCHEMAS.LOG_SETTINGS,
73 |
74 | // caption 选项
75 | Schema.object(UpdateSchema(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS, {}, ["max_token_length"])).description("caption(Tag)选项"),
76 |
77 | // 噪声设置
78 | SHARED_SCHEMAS.NOISE_SETTINGS,
79 |
80 | // 数据增强
81 | SHARED_SCHEMAS.DATA_ENCHANCEMENT,
82 |
83 | // 其他选项
84 | SHARED_SCHEMAS.OTHER,
85 |
86 | // 速度优化选项
87 | Schema.object(
88 | UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, {
89 | fp8_base: Schema.boolean().default(false).description("对基础模型使用 FP8 精度"), // lumina 默认为 false
90 | fp8_base_unet: Schema.boolean().default(false).description("仅对 U-Net 使用 FP8 精度(CLIP-L不使用)"), // lumina 默认为 false
91 | sdpa: Schema.boolean().default(true).description("启用 sdpa"), // 脚本中未明确指定,但通常建议开启
92 | cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
93 | cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
94 | }, ["xformers"])
95 | ).description("速度优化选项"),
96 |
97 | // 分布式训练
98 | SHARED_SCHEMAS.DISTRIBUTED_TRAINING
99 | ]);
--------------------------------------------------------------------------------
/mikazuki/schema/sd3-lora.ts:
--------------------------------------------------------------------------------
1 | Schema.intersect([
2 | Schema.object({
3 | model_train_type: Schema.string().default("sd3-lora").disabled().description("训练种类"),
4 | pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("SD3 模型路径"),
5 | clip_l: Schema.string().role('filepicker', { type: "model-file" }).description("clip_l 模型文件路径"),
6 | clip_g: Schema.string().role('filepicker', { type: "model-file" }).description("clip_g 模型文件路径"),
7 | t5xxl: Schema.string().role('filepicker', { type: "model-file" }).description("t5xxl 模型文件路径"),
8 | resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"),
9 | }).description("训练用模型"),
10 |
11 | Schema.object({
12 | t5xxl_max_token_length: Schema.number().step(1).description("T5XXL 最大 token 长度(不填写使用自动)"),
13 | train_t5xxl: Schema.boolean().default(false).description("训练 T5XXL(不推荐)"),
14 | }).description("SD3 专用参数"),
15 |
16 | Schema.object(
17 | UpdateSchema(SHARED_SCHEMAS.RAW.DATASET_SETTINGS, {
18 | resolution: Schema.string().default("768,768").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"),
19 | enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"),
20 | min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"),
21 | max_bucket_reso: Schema.number().default(2048).description("arb 桶最大分辨率"),
22 | bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位"),
23 | })
24 | ).description("数据集设置"),
25 |
26 | // 保存设置
27 | SHARED_SCHEMAS.SAVE_SETTINGS,
28 |
29 | Schema.object({
30 | max_train_epochs: Schema.number().min(1).default(20).description("最大训练 epoch(轮数)"),
31 | train_batch_size: Schema.number().min(1).default(1).description("批量大小, 越高显存占用越高"),
32 | gradient_checkpointing: Schema.boolean().default(true).description("梯度检查点"),
33 | gradient_accumulation_steps: Schema.number().min(1).default(1).description("梯度累加步数"),
34 | network_train_unet_only: Schema.boolean().default(true).description("仅训练 U-Net"),
35 | network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"),
36 | }).description("训练相关参数"),
37 |
38 | // 学习率&优化器设置
39 | SHARED_SCHEMAS.LR_OPTIMIZER,
40 |
41 | Schema.intersect([
42 | Schema.object({
43 | network_module: Schema.union(["networks.lora_sd3", "lycoris.kohya"]).default("networks.lora_sd3").description("训练网络模块"),
44 | network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"),
45 | network_dim: Schema.number().min(1).default(4).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
46 | network_alpha: Schema.number().min(1).default(1).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"),
47 | network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'),
48 | enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'),
49 | }).description("网络设置"),
50 |
51 | // lycoris 参数
52 | SHARED_SCHEMAS.LYCORIS_MAIN,
53 | SHARED_SCHEMAS.LYCORIS_LOKR,
54 |
55 | SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT,
56 | ]),
57 |
58 | // 预览图设置
59 | SHARED_SCHEMAS.PREVIEW_IMAGE,
60 |
61 | // 日志设置
62 | SHARED_SCHEMAS.LOG_SETTINGS,
63 |
64 | // caption 选项
65 | // SD3 去除 max_token_length
66 | Schema.object(UpdateSchema(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS, {}, ["max_token_length"])).description("caption(Tag)选项"),
67 |
68 | // 噪声设置
69 | SHARED_SCHEMAS.NOISE_SETTINGS,
70 |
71 | // 数据增强
72 | SHARED_SCHEMAS.DATA_ENCHANCEMENT,
73 |
74 | // 其他选项
75 | SHARED_SCHEMAS.OTHER,
76 |
77 | // 速度优化选项
78 | Schema.object(
79 | UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, {
80 | fp8_base: Schema.boolean().default(true).description("对基础模型使用 FP8 精度"),
81 | fp8_base_unet: Schema.boolean().description("仅对 U-Net 使用 FP8 精度(CLIP-L不使用)"),
82 | sdpa: Schema.boolean().default(true).description("启用 sdpa"),
83 | cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
84 | cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
85 | }, ["xformers"])
86 | ).description("速度优化选项"),
87 |
88 | // 分布式训练
89 | SHARED_SCHEMAS.DISTRIBUTED_TRAINING
90 | ]);
91 |
--------------------------------------------------------------------------------
/mikazuki/scripts/fix_scripts_python_executable_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | from pathlib import Path
5 |
6 | py_path = sys.executable
7 | scripts_path = Path(sys.executable).parent
8 |
9 | if scripts_path.name != "Scripts":
10 | print("Seems your env not venv, do you want to continue? [y/n]")
11 | sure = input()
12 | if sure != "y":
13 | sys.exit(1)
14 |
15 | scripts_list = os.listdir(scripts_path)
16 |
17 | for script in scripts_list:
18 | if not script.endswith(".exe") or script in ["python.exe", "pythonw.exe"]:
19 | continue
20 |
21 | with open(os.path.join(scripts_path, script), "rb+") as f:
22 | s = f.read()
23 | spl = re.split(rb'(#!.*python\.exe)', s)
24 | if len(spl) == 3:
25 | spl[1] = bytes(b"#!"+sys.executable.encode())
26 | f.seek(0)
27 | f.write(b''.join(spl))
28 | print(f"fixed {script}")
--------------------------------------------------------------------------------
/mikazuki/scripts/torch_check.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | def check_torch_gpu():
4 | try:
5 | import torch
6 | print(f'Torch {torch.__version__}')
7 | if torch.cuda.is_available():
8 | if torch.version.cuda:
9 | print(
10 | f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}')
11 | for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]:
12 | print(f'Torch detected GPU: {torch.cuda.get_device_name(device)} VRAM {round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}')
13 | else:
14 | print("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check")
15 | except Exception as e:
16 | print(f'Could not load torch: {e}')
17 | sys.exit(1)
18 |
19 | check_torch_gpu()
--------------------------------------------------------------------------------
/mikazuki/tagger/dbimutils.py:
--------------------------------------------------------------------------------
1 | # DanBooru IMage Utility functions
2 |
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
9 | if img.endswith(".gif"):
10 | img = Image.open(img)
11 | img = img.convert("RGB")
12 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
13 | else:
14 | img = cv2.imread(img, flag)
15 | return img
16 |
17 |
18 | def smart_24bit(img):
19 | if img.dtype is np.dtype(np.uint16):
20 | img = (img / 257).astype(np.uint8)
21 |
22 | if len(img.shape) == 2:
23 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
24 | elif img.shape[2] == 4:
25 | trans_mask = img[:, :, 3] == 0
26 | img[trans_mask] = [255, 255, 255, 255]
27 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
28 | return img
29 |
30 |
31 | def make_square(img, target_size):
32 | old_size = img.shape[:2]
33 | desired_size = max(old_size)
34 | desired_size = max(desired_size, target_size)
35 |
36 | delta_w = desired_size - old_size[1]
37 | delta_h = desired_size - old_size[0]
38 | top, bottom = delta_h // 2, delta_h - (delta_h // 2)
39 | left, right = delta_w // 2, delta_w - (delta_w // 2)
40 |
41 | color = [255, 255, 255]
42 | new_im = cv2.copyMakeBorder(
43 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
44 | )
45 | return new_im
46 |
47 |
48 | def smart_resize(img, size):
49 | # Assumes the image has already gone through make_square
50 | if img.shape[0] > size:
51 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
52 | elif img.shape[0] < size:
53 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
54 | return img
55 |
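A minimal usage sketch (not repository code) showing the order the tagger typically chains these helpers in; the import path follows this file's location and the 448-pixel target size is only an assumed example.

```python
from mikazuki.tagger import dbimutils  # assumed import path for this module

img = dbimutils.smart_imread("example.jpg")  # hypothetical input image
img = dbimutils.smart_24bit(img)             # normalize 16-bit / grayscale / RGBA input to 8-bit BGR
img = dbimutils.make_square(img, 448)        # pad with white to a square canvas
img = dbimutils.smart_resize(img, 448)       # INTER_AREA to shrink, INTER_CUBIC to enlarge
```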
--------------------------------------------------------------------------------
/mikazuki/tagger/format.py:
--------------------------------------------------------------------------------
1 | import re
2 | import hashlib
3 |
4 | from typing import Dict, Callable, NamedTuple
5 | from pathlib import Path
6 |
7 |
8 | class Info(NamedTuple):
9 | path: Path
10 | output_ext: str
11 |
12 |
13 | def hash(i: Info, algo='sha1') -> str:
14 | try:
15 | hash = hashlib.new(algo)
16 | except ValueError:
17 | raise ValueError(f"'{algo}' is not a valid hash algorithm")
18 |
19 | # TODO: is okay to hash large image?
20 | with open(i.path, 'rb') as file:
21 | hash.update(file.read())
22 |
23 | return hash.hexdigest()
24 |
25 |
26 | pattern = re.compile(r'\[([\w:]+)\]')
27 |
28 | # every formatter must return a string or raise TypeError or ValueError
29 | # any other error is treated as an extension error
30 | available_formats: Dict[str, Callable] = {
31 | 'name': lambda i: i.path.stem,
32 | 'extension': lambda i: i.path.suffix[1:],
33 | 'hash': hash,
34 |
35 | 'output_extension': lambda i: i.output_ext
36 | }
37 |
38 |
39 | def format(match: re.Match, info: Info) -> str:
40 | matches = match[1].split(':')
41 | name, args = matches[0], matches[1:]
42 |
43 | if name not in available_formats:
44 | return match[0]
45 |
46 | return available_formats[name](info, *args)
47 |
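A minimal sketch (assumption, not repository code) of how `pattern` and `format` work together: each `[token]` in an output-name template is replaced through the callables registered in `available_formats`.

```python
from pathlib import Path

from mikazuki.tagger.format import Info, format, pattern  # assumed import path

info = Info(path=Path("dataset/cat_001.png"), output_ext="txt")  # hypothetical image
template = "[name].[output_extension]"

result = pattern.sub(lambda m: format(m, info), template)
print(result)  # -> "cat_001.txt"
```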
--------------------------------------------------------------------------------
/mikazuki/tasks.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import os
4 | import threading
5 | import uuid
6 | from enum import Enum
7 | from typing import Dict, List
8 | from subprocess import Popen, PIPE, TimeoutExpired, CalledProcessError, CompletedProcess
9 | import psutil
10 |
11 | from mikazuki.log import log
12 |
13 | try:
14 | import msvcrt
15 | import _winapi
16 | _mswindows = True
17 | except ModuleNotFoundError:
18 | _mswindows = False
19 |
20 |
21 | def kill_proc_tree(pid, including_parent=True):
22 | parent = psutil.Process(pid)
23 | children = parent.children(recursive=True)
24 | for child in children:
25 | child.kill()
26 | gone, still_alive = psutil.wait_procs(children, timeout=5)
27 | if including_parent:
28 | parent.kill()
29 | parent.wait(5)
30 |
31 |
32 | class TaskStatus(Enum):
33 | CREATED = 0
34 | RUNNING = 1
35 | FINISHED = 2
36 | TERMINATED = 3
37 |
38 |
39 | class Task:
40 | def __init__(self, task_id, command, environ=None):
41 | self.task_id = task_id
42 | self.lock = threading.Lock()
43 | self.command = command
44 | self.status = TaskStatus.CREATED
45 | self.environ = environ or os.environ
46 |
47 | def communicate(self, input=None, timeout=None):
48 | try:
49 | stdout, stderr = self.process.communicate(input, timeout=timeout)
50 | except TimeoutExpired as exc:
51 | self.process.kill()
52 | if _mswindows:
53 | exc.stdout, exc.stderr = self.process.communicate()
54 | else:
55 | self.process.wait()
56 | raise
57 | except:
58 | self.process.kill()
59 | raise
60 | retcode = self.process.poll()
61 | self.status = TaskStatus.FINISHED
62 | return CompletedProcess(self.process.args, retcode, stdout, stderr)
63 |
64 | def wait(self):
65 | self.process.wait()
66 | self.status = TaskStatus.FINISHED
67 |
68 | def execute(self):
69 | self.status = TaskStatus.RUNNING
70 | self.process = subprocess.Popen(self.command, env=self.environ)
71 |
72 | def terminate(self):
73 | try:
74 | kill_proc_tree(self.process.pid, False)
75 | except Exception as e:
76 | log.error(f"Error when killing process: {e}")
77 | return
78 | finally:
79 | self.status = TaskStatus.TERMINATED
80 |
81 |
82 | class TaskManager:
83 | def __init__(self, max_concurrent=1) -> None:
84 | self.max_concurrent = max_concurrent
85 | self.tasks: Dict[str, Task] = {}
86 |
87 | def create_task(self, command: List[str], environ):
88 | running_tasks = [t for _, t in self.tasks.items() if t.status == TaskStatus.RUNNING]
89 | if len(running_tasks) >= self.max_concurrent:
90 | log.error(
91 | f"Unable to create a task because there are already {len(running_tasks)} tasks running, reaching the maximum concurrent limit. / 无法创建任务,因为已经有 {len(running_tasks)} 个任务正在运行,已达到最大并发限制。")
92 | return None
93 | task_id = str(uuid.uuid4())
94 | task = Task(task_id=task_id, command=command, environ=environ)
95 | self.tasks[task_id] = task
96 | # task.execute() # breaking change
97 | log.info(f"Task {task_id} created")
98 | return task
99 |
100 | def add_task(self, task_id: str, task: Task):
101 | self.tasks[task_id] = task
102 |
103 | def terminate_task(self, task_id: str):
104 | if task_id in self.tasks:
105 | task = self.tasks[task_id]
106 | task.terminate()
107 |
108 | def wait_for_process(self, task_id: str):
109 | if task_id in self.tasks:
110 | task: Task = self.tasks[task_id]
111 | task.wait()
112 |
113 | def dump(self) -> List[Dict]:
114 | return [
115 | {
116 | "id": task.task_id,
117 | "status": task.status.name,
118 | }
119 | for task in self.tasks.values()
120 | ]
121 |
122 |
123 | tm = TaskManager()
124 |
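A minimal sketch (assumption, not repository code) of driving the module-level `tm` manager: `create_task` only registers the task, and `execute()` has to be called separately, matching the commented-out call above.

```python
import os

from mikazuki.tasks import tm  # the module-level TaskManager() defined above

task = tm.create_task(["python", "-V"], environ=os.environ.copy())  # hypothetical command
if task is not None:
    task.execute()                     # spawns the subprocess, status -> RUNNING
    tm.wait_for_process(task.task_id)  # blocks until it exits, status -> FINISHED
    print(tm.dump())                   # e.g. [{"id": "...", "status": "FINISHED"}]
```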
--------------------------------------------------------------------------------
/mikazuki/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "module": "commonjs",
5 | },
6 | "include": [
7 | "**/*.ts"
8 | ],
9 | }
--------------------------------------------------------------------------------
/mikazuki/utils/devices.py:
--------------------------------------------------------------------------------
1 | from mikazuki.log import log
2 | from packaging.version import Version
3 |
4 | available_devices = []
5 | printable_devices = []
6 |
7 |
8 | def check_torch_gpu():
9 | try:
10 | import torch
11 | log.info(f'Torch {torch.__version__}')
12 | if not torch.cuda.is_available():
13 | log.error("Torch is not able to use GPU, please check your torch installation.\n Use --skip-prepare-environment to disable this check")
14 | log.error("!!!Torch 无法使用 GPU,您无法正常开始训练!!!\n您的显卡可能并不支持,或是 torch 安装有误。请检查您的 torch 安装。")
15 | if "cpu" in torch.__version__:
16 | log.error("You are using torch CPU, please install torch GPU version by run install script again.")
17 | log.error("!!!您正在使用 CPU 版本的 torch,无法正常开始训练。请重新运行安装脚本!!!")
18 | return
19 |
20 | if Version(torch.__version__) < Version("2.3.0"):
21 | log.warning("Torch version is lower than 2.3.0, which may not be able to train FLUX model properly. Please re-run the installation script (install.ps1 or install.bash) to upgrade Torch.")
22 | log.warning("!!!Torch 版本低于 2.3.0,将无法正常训练 FLUX 模型。请考虑重新运行安装脚本以升级 Torch!!!")
23 | log.warning("!!!若您正在使用训练包,请直接下载最新训练包!!!")
24 |
25 | if torch.version.cuda:
26 | log.info(
27 | f'Torch backend: nVidia CUDA {torch.version.cuda} cuDNN {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}')
28 | elif torch.version.hip:
29 | log.info(f'Torch backend: AMD ROCm HIP {torch.version.hip}')
30 |
31 | devices = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
32 |
33 | for pos, device in enumerate(devices):
34 | name = torch.cuda.get_device_name(device)
35 | memory = torch.cuda.get_device_properties(device).total_memory
36 | available_devices.append(device)
37 | printable_devices.append(f"GPU {pos}: {name} ({round(memory / (1024**3))} GB)")
38 | log.info(
39 | f'Torch detected GPU: {name} VRAM {round(memory / 1024 / 1024)} Arch {torch.cuda.get_device_capability(device)} Cores {torch.cuda.get_device_properties(device).multi_processor_count}')
40 | except Exception as e:
41 | log.error(f'Could not load torch: {e}')
42 |
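A minimal sketch (assumption, not repository code) of how a caller could consume the module-level lists this check populates.

```python
from mikazuki.utils import devices  # assumed import path for this module

devices.check_torch_gpu()
for line in devices.printable_devices:
    print(line)  # e.g. "GPU 0: NVIDIA GeForce RTX 4090 (24 GB)"
```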
--------------------------------------------------------------------------------
/mikazuki/utils/tk_window.py:
--------------------------------------------------------------------------------
1 | import os
2 | from mikazuki.log import log
3 | try:
4 | import tkinter
5 | from tkinter.filedialog import askdirectory, askopenfilename
6 | except ImportError:
7 | tkinter = None
8 | askdirectory = None
9 | askopenfilename = None
10 | log.warning("tkinter not found, file selector will not work.")
11 |
12 | last_dir = ""
13 |
14 |
15 | def tk_window():
16 | window = tkinter.Tk()
17 | window.wm_attributes('-topmost', 1)
18 | window.withdraw()
19 |
20 |
21 | def open_file_selector(
22 | initialdir="",
23 | title="Select a file",
24 | filetypes="*") -> str:
25 | global last_dir
26 | if last_dir != "":
27 | initialdir = last_dir
28 | elif initialdir == "":
29 | initialdir = os.getcwd()
30 | try:
31 | tk_window()
32 | filename = askopenfilename(
33 | initialdir=initialdir, title=title,
34 | filetypes=filetypes
35 | )
36 | last_dir = os.path.dirname(filename)
37 | return filename
38 | except:
39 | return ""
40 |
41 |
42 | def open_directory_selector(initialdir) -> str:
43 | global last_dir
44 | if last_dir != "":
45 | initialdir = last_dir
46 | elif initialdir == "":
47 | initialdir = os.getcwd()
48 | try:
49 | tk_window()
50 | directory = askdirectory(
51 | initialdir=initialdir
52 | )
53 | last_dir = directory
54 | return directory
55 | except:
56 | return ""
57 |
--------------------------------------------------------------------------------
/output/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/output/.keep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.33.0
2 | transformers==4.44.0
3 | diffusers[torch]==0.25.0
4 | ftfy==6.1.1
5 | # albumentations==1.3.0
6 | opencv-python==4.8.1.78
7 | einops==0.7.0
8 | pytorch-lightning==1.9.0
9 | bitsandbytes==0.46.0
10 | lion-pytorch==0.1.2
11 | schedulefree==1.4
12 | pytorch-optimizer==3.5.0
13 | prodigy-plus-schedule-free==1.9.0
14 | prodigyopt==1.1.2
15 | tensorboard==2.10.1
16 | safetensors==0.4.4
17 | # prodigy-plus-schedule-free  # duplicate of the pinned entry above
18 | # gradio==3.16.2
19 | altair==4.2.2
20 | easygui==0.98.3
21 | toml==0.10.2
22 | voluptuous==0.13.1
23 | huggingface-hub==0.24.5
24 | # for Image utils
25 | imagesize==1.4.1
26 | # for T5XXL tokenizer (SD3/FLUX)
27 | sentencepiece==0.2.0
28 | # for ui
29 | rich==13.7.0
30 | pandas
31 | scipy
32 | requests
33 | pillow
34 | numpy==1.26.4
35 | # <=2.0.0
36 | gradio==3.44.2
37 | fastapi==0.95.1
38 | uvicorn==0.22.0
39 | wandb==0.16.2
40 | httpx==0.24.1
41 | # extra
42 | open-clip-torch==2.20.0
43 | lycoris-lora==2.1.0.post3
44 | dadaptation==3.1
45 |
--------------------------------------------------------------------------------
/resize.ps1:
--------------------------------------------------------------------------------
1 | # LoRA resize script by @bdsqlsz
2 |
3 | $save_precision = "fp16" # precision in saving, default float | 保存精度, 可选 float、fp16、bf16, 默认 float
4 | $new_rank = 4 # dim rank of output LoRA | dim rank等级, 默认 4
5 | $model = "./output/lora_name.safetensors" # original LoRA model path that needs resizing, saved as ckpt or safetensors | 需要调整大小的模型路径, 保存格式 ckpt 或 safetensors
6 | $save_to = "./output/lora_name_new.safetensors" # output LoRA model path, saved as ckpt or safetensors | 输出路径, 保存格式 ckpt 或 safetensors
7 | $device = "cuda" # device to use, cuda for GPU | 使用 GPU跑, 默认 CPU
8 | $verbose = 1 # display verbose resizing information | rank变更时, 显示详细信息
9 | $dynamic_method = "" # Specify dynamic resizing method, --new_rank is used as a hard limit for max rank | 动态调节大小,可选"sv_ratio", "sv_fro", "sv_cumulative",默认无
10 | $dynamic_param = "" # Specify target for dynamic reduction | 动态参数,sv_ratio模式推荐1~2, sv_cumulative模式0~1, sv_fro模式0~1, 比sv_cumulative要高
11 |
12 |
13 | # Activate python venv
14 | .\venv\Scripts\activate
15 |
16 | $Env:HF_HOME = "huggingface"
17 | $ext_args = [System.Collections.ArrayList]::new()
18 |
19 | if ($verbose) {
20 | [void]$ext_args.Add("--verbose")
21 | }
22 |
23 | if ($dynamic_method) {
24 | [void]$ext_args.Add("--dynamic_method=" + $dynamic_method)
25 | }
26 |
27 | if ($dynamic_param) {
28 | [void]$ext_args.Add("--dynamic_param=" + $dynamic_param)
29 | }
30 |
31 | # run resize
32 | accelerate launch --num_cpu_threads_per_process=8 "./scripts/networks/resize_lora.py" `
33 | --save_precision=$save_precision `
34 | --new_rank=$new_rank `
35 | --model=$model `
36 | --save_to=$save_to `
37 | --device=$device `
38 | $ext_args
39 |
40 | Write-Output "Resize finished"
41 | Read-Host | Out-Null ;
42 |
--------------------------------------------------------------------------------
/run.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5e35269a-ec20-41a3-93a6-da798c3a8401",
6 | "metadata": {},
7 | "source": [
8 | "# LoRA Train UI: SD-Trainer\n",
9 | "\n",
10 | "LoRA Training UI By [Akegarasu](https://github.com/Akegarasu)\n",
11 | "User Guide:https://github.com/Akegarasu/lora-scripts/blob/main/README.md\n",
12 | "\n",
13 | "LoRA 训练 By [秋葉aaaki@bilibili](https://space.bilibili.com/12566101)\n",
14 | "使用方法:https://www.bilibili.com/read/cv24050162/"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "id": "12c2a3d0-9aec-4680-9b8a-cb02cac48de6",
20 | "metadata": {},
21 | "source": [
22 | "### Run | 运行"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "7ae0678f-69df-4a12-a0bc-1325e52e9122",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import sys\n",
33 | "!export HF_HOME=huggingface && $sys.executable gui.py --host 0.0.0.0"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "id": "99edaa2b-9ba2-4fde-9b2e-af5dc8bf7062",
39 | "metadata": {},
40 | "source": [
41 | "## Update | 更新"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "### Github"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "!git pull && git submodule init && git submodule update"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "### 国内镜像加速"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "!export GIT_CONFIG_GLOBAL=./assets/gitconfig-cn && export GIT_TERMINAL_PROMPT=false && git pull && git submodule init && git submodule update"
74 | ]
75 | }
76 | ],
77 | "metadata": {
78 | "kernelspec": {
79 | "display_name": "Python 3 (ipykernel)",
80 | "language": "python",
81 | "name": "python3"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.10.8"
94 | }
95 | },
96 | "nbformat": 4,
97 | "nbformat_minor": 5
98 | }
99 |
--------------------------------------------------------------------------------
/run_gui.ps1:
--------------------------------------------------------------------------------
1 | $Env:HF_HOME = "huggingface"
2 | $Env:PYTHONUTF8 = "1"
3 |
4 | if (Test-Path -Path "venv\Scripts\activate") {
5 | Write-Host -ForegroundColor green "Activating virtual environment..."
6 | .\venv\Scripts\activate
7 | }
8 | elseif (Test-Path -Path "python\python.exe") {
9 | Write-Host -ForegroundColor green "Using python from python folder..."
10 | $py_path = (Get-Item "python").FullName
11 | $env:PATH = "$py_path;$env:PATH"
12 | }
13 | else {
14 | Write-Host -ForegroundColor Blue "No virtual environment found, using system python..."
15 | }
16 |
17 | python gui.py
--------------------------------------------------------------------------------
/run_gui.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export HF_HOME=huggingface
4 | export PYTHONUTF8=1
5 |
6 | python gui.py "$@"
7 |
8 |
--------------------------------------------------------------------------------
/run_gui_cn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source "./venv/bin/activate"
4 |
5 | export HF_HOME=huggingface
6 | export HF_ENDPOINT=https://hf-mirror.com
7 | export PIP_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"
8 | export PYTHONUTF8=1
9 |
10 | python gui.py "$@"
11 |
12 |
13 |
--------------------------------------------------------------------------------
/scripts/dev/.gitignore:
--------------------------------------------------------------------------------
1 | logs
2 | __pycache__
3 | wd14_tagger_model
4 | venv
5 | *.egg-info
6 | build
7 | .vscode
8 | wandb
9 |
--------------------------------------------------------------------------------
/scripts/dev/COMMIT_ID:
--------------------------------------------------------------------------------
1 | 6364379f17d50add1696b0672f39c25c08a006b6
--------------------------------------------------------------------------------
/scripts/dev/README-ja.md:
--------------------------------------------------------------------------------
1 | ## リポジトリについて
2 | Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
3 |
4 | [README in English](./README.md) ←更新情報はこちらにあります
5 |
6 | 開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
7 |
8 | FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
9 |
10 | GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
11 |
12 | 以下のスクリプトがあります。
13 |
14 | * DreamBooth、U-NetおよびText Encoderの学習をサポート
15 | * fine-tuning、同上
16 | * LoRAの学習をサポート
17 | * 画像生成
18 | * モデル変換(Stable Diffusion ckpt/safetensorsとDiffusersの相互変換)
19 |
20 | ## 使用法について
21 |
22 | * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
23 | * [データセット設定](./docs/config_README-ja.md)
24 | * [SDXL学習](./docs/train_SDXL-en.md) (英語版)
25 | * [DreamBoothの学習について](./docs/train_db_README-ja.md)
26 | * [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
27 | * [LoRAの学習について](./docs/train_network_README-ja.md)
28 | * [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
29 | * [画像生成スクリプト](./docs/gen_img_README-ja.md)
30 | * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
31 |
32 | ## Windowsでの動作に必要なプログラム
33 |
34 | Python 3.10.6およびGitが必要です。
35 |
36 | - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
37 | - git: https://git-scm.com/download/win
38 |
39 | Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
40 |
41 | PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
42 | (venvに限らずスクリプトの実行が可能になりますので注意してください。)
43 |
44 | - PowerShellを管理者として開きます。
45 | - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
46 | - 管理者のPowerShellを閉じます。
47 |
48 | ## Windows環境でのインストール
49 |
50 | スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
51 |
52 | (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
53 |
54 | PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
55 |
56 | ```powershell
57 | git clone https://github.com/kohya-ss/sd-scripts.git
58 | cd sd-scripts
59 |
60 | python -m venv venv
61 | .\venv\Scripts\activate
62 |
63 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
64 | pip install --upgrade -r requirements.txt
65 | pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
66 |
67 | accelerate config
68 | ```
69 |
70 | コマンドプロンプトでも同一です。
71 |
72 | 注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
73 |
74 | この例では PyTorch および xformers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
75 |
76 | PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
77 |
78 | accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
79 |
80 | ```txt
81 | - This machine
82 | - No distributed training
83 | - NO
84 | - NO
85 | - NO
86 | - all
87 | - fp16
88 | ```
89 |
90 | ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
91 | ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
92 |
93 | ## アップグレード
94 |
95 | 新しいリリースがあった場合、以下のコマンドで更新できます。
96 |
97 | ```powershell
98 | cd sd-scripts
99 | git pull
100 | .\venv\Scripts\activate
101 | pip install --use-pep517 --upgrade -r requirements.txt
102 | ```
103 |
104 | コマンドが成功すれば新しいバージョンが使用できます。
105 |
106 | ## 謝意
107 |
108 | LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
109 |
110 | Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
111 |
112 | ## ライセンス
113 |
114 | スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
115 |
116 | [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
117 |
118 | [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
119 |
120 | [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
121 |
122 | ## その他の情報
123 |
124 | ### LoRAの名称について
125 |
126 | `train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
127 |
128 | 1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
129 |
130 | Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
131 |
132 | 2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
133 |
134 | 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
135 |
136 | デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
137 |
138 |
143 |
144 | ### 学習中のサンプル画像生成
145 |
146 | プロンプトファイルは例えば以下のようになります。
147 |
148 | ```
149 | # prompt 1
150 | masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
151 |
152 | # prompt 2
153 | masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
154 | ```
155 |
156 | `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能です。
157 |
158 | * `--n` Negative prompt up to the next option.
159 | * `--w` Specifies the width of the generated image.
160 | * `--h` Specifies the height of the generated image.
161 | * `--d` Specifies the seed of the generated image.
162 | * `--l` Specifies the CFG scale of the generated image.
163 | * `--s` Specifies the number of steps in the generation.
164 |
165 | `( )` や `[ ]` などの重みづけも動作します。
166 |
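An additional illustrative prompt line (not from the repository) combining the `( )` / `[ ]` weighting with the options above might look like this:

```
# prompt 3 (hypothetical example with weighting)
masterpiece, best quality, ((1girl)), [simple background], upper body, looking at viewer --n (low quality, worst quality), bad anatomy --w 512 --h 768 --d 3 --l 7.0 --s 30
```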
--------------------------------------------------------------------------------
/scripts/dev/_typos.toml:
--------------------------------------------------------------------------------
1 | # Files for typos
2 | # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
3 |
4 | [default.extend-identifiers]
5 | ddPn08="ddPn08"
6 |
7 | [default.extend-words]
8 | NIN="NIN"
9 | parms="parms"
10 | nin="nin"
11 | extention="extention" # Intentionally left
12 | nd="nd"
13 | shs="shs"
14 | sts="sts"
15 | scs="scs"
16 | cpc="cpc"
17 | coc="coc"
18 | cic="cic"
19 | msm="msm"
20 | usu="usu"
21 | ici="ici"
22 | lvl="lvl"
23 | dii="dii"
24 | muk="muk"
25 | ori="ori"
26 | hru="hru"
27 | rik="rik"
28 | koo="koo"
29 | yos="yos"
30 | wn="wn"
31 | hime="hime"
32 |
33 |
34 | [files]
35 | extend-exclude = ["_typos.toml", "venv"]
36 |
--------------------------------------------------------------------------------
/scripts/dev/finetune/blip/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30524,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
/scripts/dev/finetune/hypernetwork_nai.py:
--------------------------------------------------------------------------------
1 | # NAI compatible
2 |
3 | import torch
4 |
5 |
6 | class HypernetworkModule(torch.nn.Module):
7 | def __init__(self, dim, multiplier=1.0):
8 | super().__init__()
9 |
10 | linear1 = torch.nn.Linear(dim, dim * 2)
11 | linear2 = torch.nn.Linear(dim * 2, dim)
12 | linear1.weight.data.normal_(mean=0.0, std=0.01)
13 | linear1.bias.data.zero_()
14 | linear2.weight.data.normal_(mean=0.0, std=0.01)
15 | linear2.bias.data.zero_()
16 | linears = [linear1, linear2]
17 |
18 | self.linear = torch.nn.Sequential(*linears)
19 | self.multiplier = multiplier
20 |
21 | def forward(self, x):
22 | return x + self.linear(x) * self.multiplier
23 |
24 |
25 | class Hypernetwork(torch.nn.Module):
26 | enable_sizes = [320, 640, 768, 1280]
27 | # return self.modules[Hypernetwork.enable_sizes.index(size)]
28 |
29 | def __init__(self, multiplier=1.0) -> None:
30 | super().__init__()
31 | self.modules = []
32 | for size in Hypernetwork.enable_sizes:
33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34 | self.register_module(f"{size}_0", self.modules[-1][0])
35 | self.register_module(f"{size}_1", self.modules[-1][1])
36 |
37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39 | for block in blocks:
40 | for subblk in block:
41 | if 'SpatialTransformer' in str(type(subblk)):
42 | for tf_block in subblk.transformer_blocks:
43 | for attn in [tf_block.attn1, tf_block.attn2]:
44 | size = attn.context_dim
45 | if size in Hypernetwork.enable_sizes:
46 | attn.hypernetwork = self
47 | else:
48 | attn.hypernetwork = None
49 |
50 | def apply_to_diffusers(self, text_encoder, vae, unet):
51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52 | for block in blocks:
53 | if hasattr(block, 'attentions'):
54 | for subblk in block.attentions:
55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56 | for tf_block in subblk.transformer_blocks:
57 | for attn in [tf_block.attn1, tf_block.attn2]:
58 | size = attn.to_k.in_features
59 | if size in Hypernetwork.enable_sizes:
60 | attn.hypernetwork = self
61 | else:
62 | attn.hypernetwork = None
63 | return True # TODO error checking
64 |
65 | def forward(self, x, context):
66 | size = context.shape[-1]
67 | assert size in Hypernetwork.enable_sizes
68 | module = self.modules[Hypernetwork.enable_sizes.index(size)]
69 | return module[0].forward(context), module[1].forward(context)
70 |
71 | def load_from_state_dict(self, state_dict):
72 | # old ver to new ver
73 | changes = {
74 | 'linear1.bias': 'linear.0.bias',
75 | 'linear1.weight': 'linear.0.weight',
76 | 'linear2.bias': 'linear.1.bias',
77 | 'linear2.weight': 'linear.1.weight',
78 | }
79 | for key_from, key_to in changes.items():
80 | if key_from in state_dict:
81 | state_dict[key_to] = state_dict[key_from]
82 | del state_dict[key_from]
83 |
84 | for size, sd in state_dict.items():
85 | if type(size) == int:
86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88 | return True
89 |
90 | def get_state_dict(self):
91 | state_dict = {}
92 | for i, size in enumerate(Hypernetwork.enable_sizes):
93 | sd0 = self.modules[i][0].state_dict()
94 | sd1 = self.modules[i][1].state_dict()
95 | state_dict[size] = [sd0, sd1]
96 | return state_dict
97 |
--------------------------------------------------------------------------------
/scripts/dev/finetune/merge_captions_to_metadata.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import List
5 | from tqdm import tqdm
6 | import library.train_util as train_util
7 | import os
8 | from library.utils import setup_logging
9 |
10 | setup_logging()
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def main(args):
17 | assert not args.recursive or (
18 | args.recursive and args.full_path
19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
20 |
21 | train_data_dir_path = Path(args.train_data_dir)
22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
23 | logger.info(f"found {len(image_paths)} images.")
24 |
25 | if args.in_json is None and Path(args.out_json).is_file():
26 | args.in_json = args.out_json
27 |
28 | if args.in_json is not None:
29 | logger.info(f"loading existing metadata: {args.in_json}")
30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
31 | logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
32 | else:
33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
34 | metadata = {}
35 |
36 | logger.info("merge caption texts to metadata json.")
37 | for image_path in tqdm(image_paths):
38 | caption_path = image_path.with_suffix(args.caption_extension)
39 | if not os.path.exists(caption_path):
40 | # fall back to appending the extension (e.g. "img.png.caption") instead of replacing it
41 | caption_path = Path(str(image_path) + args.caption_extension)
42 | caption = caption_path.read_text(encoding="utf-8").strip()
43 |
44 | image_key = str(image_path) if args.full_path else image_path.stem
45 | if image_key not in metadata:
46 | metadata[image_key] = {}
47 |
48 | metadata[image_key]["caption"] = caption
49 | if args.debug:
50 | logger.info(f"{image_key} {caption}")
51 |
52 | # metadataを書き出して終わり
53 | logger.info(f"writing metadata: {args.out_json}")
54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
55 | logger.info("done!")
56 |
57 |
58 | def setup_parser() -> argparse.ArgumentParser:
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
61 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
62 | parser.add_argument(
63 | "--in_json",
64 | type=str,
65 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
66 | )
67 | parser.add_argument(
68 | "--caption_extention",
69 | type=str,
70 | default=None,
71 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
72 | )
73 | parser.add_argument(
74 | "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
75 | )
76 | parser.add_argument(
77 | "--full_path",
78 | action="store_true",
79 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
80 | )
81 | parser.add_argument(
82 | "--recursive",
83 | action="store_true",
84 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
85 | )
86 | parser.add_argument("--debug", action="store_true", help="debug mode")
87 |
88 | return parser
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = setup_parser()
93 |
94 | args = parser.parse_args()
95 |
96 | # スペルミスしていたオプションを復元する
97 | if args.caption_extention is not None:
98 | args.caption_extension = args.caption_extention
99 |
100 | main(args)
101 |
--------------------------------------------------------------------------------
/scripts/dev/finetune/merge_dd_tags_to_metadata.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import List
5 | from tqdm import tqdm
6 | import library.train_util as train_util
7 | import os
8 | from library.utils import setup_logging
9 |
10 | setup_logging()
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def main(args):
17 | assert not args.recursive or (
18 | args.recursive and args.full_path
19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
20 |
21 | train_data_dir_path = Path(args.train_data_dir)
22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
23 | logger.info(f"found {len(image_paths)} images.")
24 |
25 | if args.in_json is None and Path(args.out_json).is_file():
26 | args.in_json = args.out_json
27 |
28 | if args.in_json is not None:
29 | logger.info(f"loading existing metadata: {args.in_json}")
30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
31 | logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
32 | else:
33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
34 | metadata = {}
35 |
36 | logger.info("merge tags to metadata json.")
37 | for image_path in tqdm(image_paths):
38 | tags_path = image_path.with_suffix(args.caption_extension)
39 |         if not tags_path.exists():
40 |             # also accept tag files named <image file name> + extension (e.g. image.png.txt)
41 |             tags_path = Path(str(image_path) + args.caption_extension)
42 |         tags = tags_path.read_text(encoding="utf-8").strip()
43 |
44 | image_key = str(image_path) if args.full_path else image_path.stem
45 | if image_key not in metadata:
46 | metadata[image_key] = {}
47 |
48 | metadata[image_key]["tags"] = tags
49 | if args.debug:
50 | logger.info(f"{image_key} {tags}")
51 |
52 |     # write out the metadata and we're done
53 | logger.info(f"writing metadata: {args.out_json}")
54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
55 |
56 | logger.info("done!")
57 |
58 |
59 | def setup_parser() -> argparse.ArgumentParser:
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
62 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
63 | parser.add_argument(
64 | "--in_json",
65 | type=str,
66 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
67 | )
68 | parser.add_argument(
69 | "--full_path",
70 | action="store_true",
71 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
72 | )
73 | parser.add_argument(
74 | "--recursive",
75 | action="store_true",
76 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
77 | )
78 | parser.add_argument(
79 | "--caption_extension",
80 | type=str,
81 | default=".txt",
82 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
83 | )
84 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
85 |
86 | return parser
87 |
88 |
89 | if __name__ == "__main__":
90 | parser = setup_parser()
91 |
92 | args = parser.parse_args()
93 | main(args)
94 |
--------------------------------------------------------------------------------
/scripts/dev/library/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/scripts/dev/library/__init__.py
--------------------------------------------------------------------------------
/scripts/dev/library/adafactor_fused.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from transformers import Adafactor
4 |
5 | # stochastic rounding for bfloat16
6 | # The implementation was provided by 2kpr. Thank you very much!
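# How it works: the fp32 source is reinterpreted as int32, a random value is added to
# the low 16 bits (the mantissa bits that bfloat16 discards), and those bits are then
# masked off. The random carry makes the truncation round up or down with probability
# proportional to the discarded fraction, so the rounding stays unbiased on average.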
7 |
8 | def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
9 | """
10 | copies source into target using stochastic rounding
11 |
12 | Args:
13 | target: the target tensor with dtype=bfloat16
14 |         source: the source tensor with dtype=float32
15 | """
16 | # create a random 16 bit integer
17 | result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
18 |
19 | # add the random number to the lower 16 bit of the mantissa
20 | result.add_(source.view(dtype=torch.int32))
21 |
22 | # mask off the lower 16 bit of the mantissa
23 | result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
24 |
25 | # copy the higher 16 bit into the target tensor
26 | target.copy_(result.view(dtype=torch.float32))
27 |
28 | del result
29 |
30 |
31 | @torch.no_grad()
32 | def adafactor_step_param(self, p, group):
33 | if p.grad is None:
34 | return
35 | grad = p.grad
36 | if grad.dtype in {torch.float16, torch.bfloat16}:
37 | grad = grad.float()
38 | if grad.is_sparse:
39 | raise RuntimeError("Adafactor does not support sparse gradients.")
40 |
41 | state = self.state[p]
42 | grad_shape = grad.shape
43 |
44 | factored, use_first_moment = Adafactor._get_options(group, grad_shape)
45 | # State Initialization
46 | if len(state) == 0:
47 | state["step"] = 0
48 |
49 | if use_first_moment:
50 | # Exponential moving average of gradient values
51 | state["exp_avg"] = torch.zeros_like(grad)
52 | if factored:
53 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
54 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
55 | else:
56 | state["exp_avg_sq"] = torch.zeros_like(grad)
57 |
58 | state["RMS"] = 0
59 | else:
60 | if use_first_moment:
61 | state["exp_avg"] = state["exp_avg"].to(grad)
62 | if factored:
63 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
64 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
65 | else:
66 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
67 |
68 | p_data_fp32 = p
69 | if p.dtype in {torch.float16, torch.bfloat16}:
70 | p_data_fp32 = p_data_fp32.float()
71 |
72 | state["step"] += 1
73 | state["RMS"] = Adafactor._rms(p_data_fp32)
74 | lr = Adafactor._get_lr(group, state)
75 |
76 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
77 | update = (grad**2) + group["eps"][0]
78 | if factored:
79 | exp_avg_sq_row = state["exp_avg_sq_row"]
80 | exp_avg_sq_col = state["exp_avg_sq_col"]
81 |
82 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
83 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
84 |
85 | # Approximation of exponential moving average of square of gradient
86 | update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
87 | update.mul_(grad)
88 | else:
89 | exp_avg_sq = state["exp_avg_sq"]
90 |
91 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
92 | update = exp_avg_sq.rsqrt().mul_(grad)
93 |
94 | update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
95 | update.mul_(lr)
96 |
97 | if use_first_moment:
98 | exp_avg = state["exp_avg"]
99 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
100 | update = exp_avg
101 |
102 | if group["weight_decay"] != 0:
103 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
104 |
105 | p_data_fp32.add_(-update)
106 |
107 | # if p.dtype in {torch.float16, torch.bfloat16}:
108 | # p.copy_(p_data_fp32)
109 |
110 | if p.dtype == torch.bfloat16:
111 | copy_stochastic_(p, p_data_fp32)
112 | elif p.dtype == torch.float16:
113 | p.copy_(p_data_fp32)
114 |
115 |
116 | @torch.no_grad()
117 | def adafactor_step(self, closure=None):
118 | """
119 | Performs a single optimization step
120 |
121 | Arguments:
122 | closure (callable, optional): A closure that reevaluates the model
123 | and returns the loss.
124 | """
125 | loss = None
126 | if closure is not None:
127 | loss = closure()
128 |
129 | for group in self.param_groups:
130 | for p in group["params"]:
131 | adafactor_step_param(self, p, group)
132 |
133 | return loss
134 |
135 |
136 | def patch_adafactor_fused(optimizer: Adafactor):
137 | optimizer.step_param = adafactor_step_param.__get__(optimizer)
138 | optimizer.step = adafactor_step.__get__(optimizer)
139 |
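# Minimal usage sketch (optimizer construction is illustrative): after patching, the
# training loop can call step_param() per parameter (e.g. from a gradient hook, to fuse
# the optimizer step with the backward pass) instead of a single optimizer.step().
#
#   optimizer = Adafactor(model.parameters(), lr=1e-4, scale_parameter=False, relative_step=False)
#   patch_adafactor_fused(optimizer)
#   for group in optimizer.param_groups:
#       for p in group["params"]:
#           optimizer.step_param(p, group)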
--------------------------------------------------------------------------------
/scripts/dev/library/deepspeed_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from accelerate import DeepSpeedPlugin, Accelerator
5 |
6 | from .utils import setup_logging
7 |
8 | setup_logging()
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15 | # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16 | parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17 | parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18 | parser.add_argument(
19 | "--offload_optimizer_device",
20 | type=str,
21 | default=None,
22 | choices=[None, "cpu", "nvme"],
23 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24 | )
25 | parser.add_argument(
26 | "--offload_optimizer_nvme_path",
27 | type=str,
28 | default=None,
29 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30 | )
31 | parser.add_argument(
32 | "--offload_param_device",
33 | type=str,
34 | default=None,
35 | choices=[None, "cpu", "nvme"],
36 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37 | )
38 | parser.add_argument(
39 | "--offload_param_nvme_path",
40 | type=str,
41 | default=None,
42 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43 | )
44 | parser.add_argument(
45 | "--zero3_init_flag",
46 | action="store_true",
47 | help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
48 | "Only applicable with ZeRO Stage-3.",
49 | )
50 | parser.add_argument(
51 | "--zero3_save_16bit_model",
52 | action="store_true",
53 | help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54 | )
55 | parser.add_argument(
56 | "--fp16_master_weights_and_gradients",
57 | action="store_true",
58 | help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
59 | )
60 |
61 |
62 | def prepare_deepspeed_args(args: argparse.Namespace):
63 | if not args.deepspeed:
64 | return
65 |
66 | # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67 | args.max_data_loader_n_workers = 1
68 |
69 |
70 | def prepare_deepspeed_plugin(args: argparse.Namespace):
71 | if not args.deepspeed:
72 | return None
73 |
74 | try:
75 | import deepspeed
76 | except ImportError as e:
77 | logger.error(
78 | "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
79 | )
80 | exit(1)
81 |
82 | deepspeed_plugin = DeepSpeedPlugin(
83 | zero_stage=args.zero_stage,
84 | gradient_accumulation_steps=args.gradient_accumulation_steps,
85 | gradient_clipping=args.max_grad_norm,
86 | offload_optimizer_device=args.offload_optimizer_device,
87 | offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88 | offload_param_device=args.offload_param_device,
89 | offload_param_nvme_path=args.offload_param_nvme_path,
90 | zero3_init_flag=args.zero3_init_flag,
91 | zero3_save_16bit_model=args.zero3_save_16bit_model,
92 | )
93 | deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94 | deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95 | args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96 | )
97 | deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98 | if args.mixed_precision.lower() == "fp16":
99 | deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100 | if args.full_fp16 or args.fp16_master_weights_and_gradients:
101 | if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102 | deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103 | logger.info("[DeepSpeed] full fp16 enable.")
104 | else:
105 | logger.info(
106 | "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
107 | )
108 |
109 | if args.offload_optimizer_device is not None:
110 | logger.info("[DeepSpeed] start to manually build cpu_adam.")
111 | deepspeed.ops.op_builder.CPUAdamBuilder().load()
112 | logger.info("[DeepSpeed] building cpu_adam done.")
113 |
114 | return deepspeed_plugin
115 |
116 |
117 | # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
118 | def prepare_deepspeed_model(args: argparse.Namespace, **models):
119 | # remove None from models
120 | models = {k: v for k, v in models.items() if v is not None}
121 |
122 | class DeepSpeedWrapper(torch.nn.Module):
123 | def __init__(self, **kw_models) -> None:
124 | super().__init__()
125 | self.models = torch.nn.ModuleDict()
126 |
127 | for key, model in kw_models.items():
128 | if isinstance(model, list):
129 | model = torch.nn.ModuleList(model)
130 | assert isinstance(
131 | model, torch.nn.Module
132 | ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133 | self.models.update(torch.nn.ModuleDict({key: model}))
134 |
135 | def get_models(self):
136 | return self.models
137 |
138 | ds_model = DeepSpeedWrapper(**models)
139 | return ds_model
140 |
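# Minimal usage sketch (model names are illustrative): wrap every trainable model into
# the single DeepSpeedWrapper, then let Accelerate prepare the wrapper instead of each
# model separately.
#
#   ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
#   ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)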
--------------------------------------------------------------------------------
/scripts/dev/library/device_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import gc
3 |
4 | import torch
5 | try:
6 | # intel gpu support for pytorch older than 2.5
7 | # ipex is not needed after pytorch 2.5
8 | import intel_extension_for_pytorch as ipex # noqa
9 | except Exception:
10 | pass
11 |
12 |
13 | try:
14 | HAS_CUDA = torch.cuda.is_available()
15 | except Exception:
16 | HAS_CUDA = False
17 |
18 | try:
19 | HAS_MPS = torch.backends.mps.is_available()
20 | except Exception:
21 | HAS_MPS = False
22 |
23 | try:
24 | HAS_XPU = torch.xpu.is_available()
25 | except Exception:
26 | HAS_XPU = False
27 |
28 |
29 | def clean_memory():
30 | gc.collect()
31 | if HAS_CUDA:
32 | torch.cuda.empty_cache()
33 | if HAS_XPU:
34 | torch.xpu.empty_cache()
35 | if HAS_MPS:
36 | torch.mps.empty_cache()
37 |
38 |
39 | def clean_memory_on_device(device: torch.device):
40 | r"""
41 | Clean memory on the specified device, will be called from training scripts.
42 | """
43 | gc.collect()
44 |
45 |     # device may be "cuda" or "cuda:0", so check the device type
46 | if device.type == "cuda":
47 | torch.cuda.empty_cache()
48 | if device.type == "xpu":
49 | torch.xpu.empty_cache()
50 | if device.type == "mps":
51 | torch.mps.empty_cache()
52 |
53 |
54 | @functools.lru_cache(maxsize=None)
55 | def get_preferred_device() -> torch.device:
56 | r"""
57 | Do not call this function from training scripts. Use accelerator.device instead.
58 | """
59 | if HAS_CUDA:
60 | device = torch.device("cuda")
61 | elif HAS_XPU:
62 | device = torch.device("xpu")
63 | elif HAS_MPS:
64 | device = torch.device("mps")
65 | else:
66 | device = torch.device("cpu")
67 | print(f"get_preferred_device() -> {device}")
68 | return device
69 |
70 |
71 | def init_ipex():
72 | """
73 | Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
74 |
75 | This function should run right after importing torch and before doing anything else.
76 |
77 | If xpu is not available, this function does nothing.
78 | """
79 | try:
80 | if HAS_XPU:
81 | from library.ipex import ipex_init
82 |
83 | is_initialized, error_message = ipex_init()
84 | if not is_initialized:
85 | print("failed to initialize ipex:", error_message)
86 | else:
87 | return
88 | except Exception as e:
89 | print("failed to initialize ipex:", e)
90 |
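# Minimal usage sketch for standalone tools (training scripts should use accelerator.device
# instead of get_preferred_device()):
#
#   device = get_preferred_device()
#   model.to(device)
#   ...  # heavy work
#   clean_memory_on_device(device)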
--------------------------------------------------------------------------------
/scripts/dev/library/huggingface_util.py:
--------------------------------------------------------------------------------
1 | from typing import Union, BinaryIO
2 | from huggingface_hub import HfApi
3 | from pathlib import Path
4 | import argparse
5 | import os
6 | from library.utils import fire_in_thread
7 | from library.utils import setup_logging
8 | setup_logging()
9 | import logging
10 | logger = logging.getLogger(__name__)
11 |
12 | def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13 | api = HfApi(
14 | token=token,
15 | )
16 | try:
17 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18 | return True
19 |     except Exception:
20 | return False
21 |
22 |
23 | def upload(
24 | args: argparse.Namespace,
25 | src: Union[str, Path, bytes, BinaryIO],
26 | dest_suffix: str = "",
27 | force_sync_upload: bool = False,
28 | ):
29 | repo_id = args.huggingface_repo_id
30 | repo_type = args.huggingface_repo_type
31 | token = args.huggingface_token
32 | path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34 | api = HfApi(token=token)
35 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36 | try:
37 | api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38 |         except Exception as e:  # RepositoryNotFoundError has been confirmed, but other errors may occur and should not abort training
39 | logger.error("===========================================")
40 | logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41 | logger.error("===========================================")
42 |
43 | is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44 |
45 | def uploader():
46 | try:
47 | if is_folder:
48 | api.upload_folder(
49 | repo_id=repo_id,
50 | repo_type=repo_type,
51 | folder_path=src,
52 | path_in_repo=path_in_repo,
53 | )
54 | else:
55 | api.upload_file(
56 | repo_id=repo_id,
57 | repo_type=repo_type,
58 | path_or_fileobj=src,
59 | path_in_repo=path_in_repo,
60 | )
61 |         except Exception as e:  # RuntimeError has been confirmed, but other errors may occur and should not abort training
62 | logger.error("===========================================")
63 | logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64 | logger.error("===========================================")
65 |
66 | if args.async_upload and not force_sync_upload:
67 | fire_in_thread(uploader)
68 | else:
69 | uploader()
70 |
71 |
72 | def list_dir(
73 | repo_id: str,
74 | subfolder: str,
75 | repo_type: str,
76 | revision: str = "main",
77 | token: str = None,
78 | ):
79 | api = HfApi(
80 | token=token,
81 | )
82 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83 | file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84 | return file_list
85 |
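# Minimal usage sketch (paths are illustrative): `args` is expected to carry the
# huggingface_* options defined by the training scripts (repo id/type, token, visibility,
# async_upload), as accessed in upload() above.
#
#   upload(args, "output/my_lora.safetensors", dest_suffix="/step-1000")
#   files = list_dir("user/my-repo", "models", repo_type="model")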
--------------------------------------------------------------------------------
/scripts/dev/library/ipex/diffusers.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import torch
3 | import diffusers # pylint: disable=import-error
4 |
5 | # pylint: disable=protected-access, missing-function-docstring, line-too-long
6 |
7 |
8 | # Diffusers FreeU
9 | original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
10 | @wraps(diffusers.utils.torch_utils.fourier_filter)
11 | def fourier_filter(x_in, threshold, scale):
12 | return_dtype = x_in.dtype
13 | return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
14 |
15 |
16 | # FluxPosEmbed reimplementation that avoids fp64 (fp64 is unsupported on some XPU devices)
17 | class FluxPosEmbed(torch.nn.Module):
18 | def __init__(self, theta: int, axes_dim):
19 | super().__init__()
20 | self.theta = theta
21 | self.axes_dim = axes_dim
22 |
23 | def forward(self, ids: torch.Tensor) -> torch.Tensor:
24 | n_axes = ids.shape[-1]
25 | cos_out = []
26 | sin_out = []
27 | pos = ids.float()
28 | for i in range(n_axes):
29 | cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
30 | self.axes_dim[i],
31 | pos[:, i],
32 | theta=self.theta,
33 | repeat_interleave_real=True,
34 | use_real=True,
35 | freqs_dtype=torch.float32,
36 | )
37 | cos_out.append(cos)
38 | sin_out.append(sin)
39 | freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
40 | freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
41 | return freqs_cos, freqs_sin
42 |
43 |
44 | def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
45 | diffusers.utils.torch_utils.fourier_filter = fourier_filter
46 | if not device_supports_fp64:
47 | diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
48 |
--------------------------------------------------------------------------------
/scripts/dev/networks/check_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from safetensors.torch import load_file
5 | from library.utils import setup_logging
6 | setup_logging()
7 | import logging
8 | logger = logging.getLogger(__name__)
9 |
10 | def main(file):
11 | logger.info(f"loading: {file}")
12 | if os.path.splitext(file)[1] == ".safetensors":
13 | sd = load_file(file)
14 | else:
15 | sd = torch.load(file, map_location="cpu")
16 |
17 | values = []
18 |
19 | keys = list(sd.keys())
20 | for key in keys:
21 | if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
22 | values.append((key, sd[key]))
23 | print(f"number of LoRA modules: {len(values)}")
24 |
25 | if args.show_all_keys:
26 |         for key in [k for k in keys if k not in [v[0] for v in values]]:
27 | values.append((key, sd[key]))
28 | print(f"number of all modules: {len(values)}")
29 |
30 | for key, value in values:
31 | value = value.to(torch.float32)
32 | print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
33 |
34 |
35 | def setup_parser() -> argparse.ArgumentParser:
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
38 | parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
39 |
40 | return parser
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = setup_parser()
45 |
46 | args = parser.parse_args()
47 |
48 | main(args.file)
49 |
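# Example invocation (path is illustrative):
#   python networks/check_lora_weights.py output/my_lora.safetensors
# Each printed line is: key, weight shape, mean(|w|), min(|w|).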
--------------------------------------------------------------------------------
/scripts/dev/networks/extract_lora_from_dylora.py:
--------------------------------------------------------------------------------
1 | # Split a DyLoRA model into LoRA models of lower, fixed ranks (multiples of --unit)
2 | # This code is based on extract_lora_from_models.py, which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3 | # Thanks to cloneofsimo
4 |
5 | import argparse
6 | import math
7 | import os
8 | import torch
9 | from safetensors.torch import load_file, save_file, safe_open
10 | from tqdm import tqdm
11 | from library import train_util, model_util
12 | import numpy as np
13 | from library.utils import setup_logging
14 | setup_logging()
15 | import logging
16 | logger = logging.getLogger(__name__)
17 |
18 | def load_state_dict(file_name):
19 | if model_util.is_safetensors(file_name):
20 | sd = load_file(file_name)
21 | with safe_open(file_name, framework="pt") as f:
22 | metadata = f.metadata()
23 | else:
24 | sd = torch.load(file_name, map_location="cpu")
25 | metadata = None
26 |
27 | return sd, metadata
28 |
29 |
30 | def save_to_file(file_name, model, metadata):
31 | if model_util.is_safetensors(file_name):
32 | save_file(model, file_name, metadata)
33 | else:
34 | torch.save(model, file_name)
35 |
36 |
37 | def split_lora_model(lora_sd, unit):
38 | max_rank = 0
39 |
40 | # Extract loaded lora dim and alpha
41 | for key, value in lora_sd.items():
42 | if "lora_down" in key:
43 | rank = value.size()[0]
44 | if rank > max_rank:
45 | max_rank = rank
46 | logger.info(f"Max rank: {max_rank}")
47 |
48 | rank = unit
49 | split_models = []
50 | new_alpha = None
51 | while rank < max_rank:
52 | logger.info(f"Splitting rank {rank}")
53 | new_sd = {}
54 | for key, value in lora_sd.items():
55 | if "lora_down" in key:
56 | new_sd[key] = value[:rank].contiguous()
57 | elif "lora_up" in key:
58 | new_sd[key] = value[:, :rank].contiguous()
59 | else:
60 |                 # for some reason, scaling here breaks the result...
61 | # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
62 | # scale = math.sqrt(this_rank / rank) # rank is > unit
63 | # logger.info(key, value.size(), this_rank, rank, value, scale)
64 | # new_alpha = value * scale # always same
65 | # new_sd[key] = new_alpha
66 | new_sd[key] = value
67 |
68 | split_models.append((new_sd, rank, new_alpha))
69 | rank += unit
70 |
71 | return max_rank, split_models
72 |
73 |
74 | def split(args):
75 | logger.info("loading Model...")
76 | lora_sd, metadata = load_state_dict(args.model)
77 |
78 | logger.info("Splitting Model...")
79 | original_rank, split_models = split_lora_model(lora_sd, args.unit)
80 |
81 |     comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
82 | for state_dict, new_rank, new_alpha in split_models:
83 | # update metadata
84 | if metadata is None:
85 | new_metadata = {}
86 | else:
87 | new_metadata = metadata.copy()
88 |
89 | new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
90 | new_metadata["ss_network_dim"] = str(new_rank)
91 | # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
92 |
93 |         model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
94 |         new_metadata["sshs_model_hash"] = model_hash
95 |         new_metadata["sshs_legacy_hash"] = legacy_hash
96 |
97 | filename, ext = os.path.splitext(args.save_to)
98 | model_file_name = filename + f"-{new_rank:04d}{ext}"
99 |
100 | logger.info(f"saving model to: {model_file_name}")
101 | save_to_file(model_file_name, state_dict, new_metadata)
102 |
103 |
104 | def setup_parser() -> argparse.ArgumentParser:
105 | parser = argparse.ArgumentParser()
106 |
107 | parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
108 | parser.add_argument(
109 | "--save_to",
110 | type=str,
111 | default=None,
112 | help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
113 | )
114 | parser.add_argument(
115 | "--model",
116 | type=str,
117 | default=None,
118 | help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
119 | )
120 |
121 | return parser
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = setup_parser()
126 |
127 | args = parser.parse_args()
128 | split(args)
129 |
--------------------------------------------------------------------------------
/scripts/dev/networks/lora_interrogator.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from tqdm import tqdm
4 | from library import model_util
5 | import library.train_util as train_util
6 | import argparse
7 | from transformers import CLIPTokenizer
8 |
9 | import torch
10 | from library.device_utils import init_ipex, get_preferred_device
11 | init_ipex()
12 |
13 | import library.model_util as model_util
14 | import lora
15 | from library.utils import setup_logging
16 | setup_logging()
17 | import logging
18 | logger = logging.getLogger(__name__)
19 |
20 | TOKENIZER_PATH = "openai/clip-vit-large-patch14"
21 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this model
22 |
23 | DEVICE = get_preferred_device()
24 |
25 |
26 | def interrogate(args):
27 | weights_dtype = torch.float16
28 |
29 |     # prepare everything we need
30 | logger.info(f"loading SD model: {args.sd_model}")
31 | args.pretrained_model_name_or_path = args.sd_model
32 | args.vae = None
33 |     text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)
34 |
35 | logger.info(f"loading LoRA: {args.model}")
36 | network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
37 |
38 |     # check whether the LoRA has weights for the Text Encoder (ideally the lora module itself should do this)
39 | has_te_weight = False
40 | for key in weights_sd.keys():
41 | if 'lora_te' in key:
42 | has_te_weight = True
43 | break
44 | if not has_te_weight:
45 | logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
46 | return
47 | del vae
48 |
49 | logger.info("loading tokenizer")
50 | if args.v2:
51 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
52 | else:
53 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
54 |
55 | text_encoder.to(DEVICE, dtype=weights_dtype)
56 | text_encoder.eval()
57 | unet.to(DEVICE, dtype=weights_dtype)
58 |     unet.eval()  # not strictly necessary since the U-Net is never called here
59 |
60 |     # probe the tokens one by one
61 | token_id_start = 0
62 | token_id_end = max(tokenizer.all_special_ids)
63 | logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
64 |
65 | def get_all_embeddings(text_encoder):
66 | embs = []
67 | with torch.no_grad():
68 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
69 | batch = []
70 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
71 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
72 |                     # tokens = [tid]  # this variant gives somewhat worse results
73 | batch.append(tokens)
74 |
75 |                 # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu")  # including bos/eos seems to bring out the differences better [:, 1]
76 |                 # handle clip_skip
77 | batch = torch.tensor(batch).to(DEVICE)
78 | if args.clip_skip is None:
79 | encoder_hidden_states = text_encoder(batch)[0]
80 | else:
81 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
82 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
83 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
84 | encoder_hidden_states = encoder_hidden_states.to("cpu")
85 |
86 | embs.extend(encoder_hidden_states)
87 | return torch.stack(embs)
88 |
89 | logger.info("get original text encoder embeddings.")
90 | orig_embs = get_all_embeddings(text_encoder)
91 |
92 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
93 | info = network.load_state_dict(weights_sd, strict=False)
94 | logger.info(f"Loading LoRA weights: {info}")
95 |
96 | network.to(DEVICE, dtype=weights_dtype)
97 | network.eval()
98 |
99 | del unet
100 |
101 | logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
102 | logger.info("get text encoder embeddings with lora.")
103 | lora_embs = get_all_embeddings(text_encoder)
104 |
105 |     # compare: for now, simply use the mean absolute difference
106 | logger.info("comparing...")
107 | diffs = {}
108 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
109 | diff = torch.mean(torch.abs(orig_emb - lora_emb))
110 |         # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1))  # fails to detect the differences well
111 | diff = float(diff.detach().to('cpu').numpy())
112 | diffs[token_id_start + i] = diff
113 |
114 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
115 |
116 |     # show the results
117 | print("top 100:")
118 | for i, (token, diff) in enumerate(diffs_sorted[:100]):
119 | # if diff < 1e-6:
120 | # break
121 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
122 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
123 |
124 |
125 | def setup_parser() -> argparse.ArgumentParser:
126 | parser = argparse.ArgumentParser()
127 |
128 | parser.add_argument("--v2", action='store_true',
129 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
130 | parser.add_argument("--sd_model", type=str, default=None,
131 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
132 | parser.add_argument("--model", type=str, default=None,
133 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
134 | parser.add_argument("--batch_size", type=int, default=16,
135 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
136 | parser.add_argument("--clip_skip", type=int, default=None,
137 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
138 |
139 | return parser
140 |
141 |
142 | if __name__ == '__main__':
143 | parser = setup_parser()
144 |
145 | args = parser.parse_args()
146 | interrogate(args)
147 |
--------------------------------------------------------------------------------
/scripts/dev/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | minversion = 6.0
3 | testpaths =
4 | tests
5 | filterwarnings =
6 | ignore::DeprecationWarning
7 | ignore::UserWarning
8 | ignore::FutureWarning
9 |
--------------------------------------------------------------------------------
/scripts/dev/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.33.0
2 | transformers==4.44.0
3 | diffusers[torch]==0.25.0
4 | ftfy==6.1.1
5 | # albumentations==1.3.0
6 | opencv-python==4.8.1.78
7 | einops==0.7.0
8 | pytorch-lightning==1.9.0
9 | bitsandbytes==0.44.0
10 | lion-pytorch==0.0.6
11 | schedulefree==1.4
12 | pytorch-optimizer==3.5.0
13 | prodigy-plus-schedule-free==1.9.0
14 | prodigyopt==1.1.2
15 | tensorboard
16 | safetensors==0.4.4
17 | # gradio==3.16.2
18 | altair==4.2.2
19 | easygui==0.98.3
20 | toml==0.10.2
21 | voluptuous==0.13.1
22 | huggingface-hub==0.24.5
23 | # for Image utils
24 | imagesize==1.4.1
25 | numpy<=2.0
26 | # for BLIP captioning
27 | # requests==2.28.2
28 | # timm==0.6.12
29 | # fairscale==0.4.13
30 | # for WD14 captioning (tensorflow)
31 | # tensorflow==2.10.1
32 | # for WD14 captioning (onnx)
33 | # onnx==1.15.0
34 | # onnxruntime-gpu==1.17.1
35 | # onnxruntime==1.17.1
36 | # for cuda 12.1(default 11.8)
37 | # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
38 |
39 | # this is for onnx:
40 | # protobuf==3.20.3
41 | # open clip for SDXL
42 | # open-clip-torch==2.20.0
43 | # For logging
44 | rich==13.7.0
45 | # for T5XXL tokenizer (SD3/FLUX)
46 | sentencepiece==0.2.0
47 | # for kohya_ss library
48 | -e .
49 |
--------------------------------------------------------------------------------
/scripts/dev/sdxl_train_textual_inversion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Optional, Union
4 |
5 | import regex
6 |
7 | import torch
8 | from library.device_utils import init_ipex
9 |
10 | init_ipex()
11 |
12 | from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
13 | import train_textual_inversion
14 |
15 |
16 | class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
17 | def __init__(self):
18 | super().__init__()
19 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
20 | self.is_sdxl = True
21 |
22 | def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
23 | super().assert_extra_args(args, train_dataset_group, val_dataset_group)
24 | sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
25 |
26 | train_dataset_group.verify_bucket_reso_steps(32)
27 | if val_dataset_group is not None:
28 | val_dataset_group.verify_bucket_reso_steps(32)
29 |
30 | def load_target_model(self, args, weight_dtype, accelerator):
31 | (
32 | load_stable_diffusion_format,
33 | text_encoder1,
34 | text_encoder2,
35 | vae,
36 | unet,
37 | logit_scale,
38 | ckpt_info,
39 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
40 |
41 | self.load_stable_diffusion_format = load_stable_diffusion_format
42 | self.logit_scale = logit_scale
43 | self.ckpt_info = ckpt_info
44 |
45 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
46 |
47 | def get_tokenize_strategy(self, args):
48 | return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
49 |
50 | def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
51 | return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
52 |
53 | def get_latents_caching_strategy(self, args):
54 | latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
55 | False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
56 | )
57 | return latents_caching_strategy
58 |
59 | def get_text_encoding_strategy(self, args):
60 | return strategy_sdxl.SdxlTextEncodingStrategy()
61 |
62 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
63 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
64 |
65 | # get size embeddings
66 | orig_size = batch["original_sizes_hw"]
67 | crop_size = batch["crop_top_lefts"]
68 | target_size = batch["target_sizes_hw"]
69 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
70 |
71 | # concat embeddings
72 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
73 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
74 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
75 |
76 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
77 | return noise_pred
78 |
79 | def sample_images(
80 | self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
81 | ):
82 | sdxl_train_util.sample_images(
83 | accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
84 | )
85 |
86 | def save_weights(self, file, updated_embs, save_dtype, metadata):
87 | state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
88 |
89 | if save_dtype is not None:
90 | for key in list(state_dict.keys()):
91 | v = state_dict[key]
92 | v = v.detach().clone().to("cpu").to(save_dtype)
93 | state_dict[key] = v
94 |
95 | if os.path.splitext(file)[1] == ".safetensors":
96 | from safetensors.torch import save_file
97 |
98 | save_file(state_dict, file, metadata)
99 | else:
100 | torch.save(state_dict, file)
101 |
102 | def load_weights(self, file):
103 | if os.path.splitext(file)[1] == ".safetensors":
104 | from safetensors.torch import load_file
105 |
106 | data = load_file(file)
107 | else:
108 | data = torch.load(file, map_location="cpu")
109 |
110 | emb_l = data.get("clip_l", None) # ViT-L text encoder 1
111 | emb_g = data.get("clip_g", None) # BiG-G text encoder 2
112 |
113 | assert (
114 | emb_l is not None or emb_g is not None
115 | ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
116 |
117 | return [emb_l, emb_g]
118 |
119 |
120 | def setup_parser() -> argparse.ArgumentParser:
121 | parser = train_textual_inversion.setup_parser()
122 | sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
123 | return parser
124 |
125 |
126 | if __name__ == "__main__":
127 | parser = setup_parser()
128 |
129 | args = parser.parse_args()
130 | train_util.verify_command_line_training_args(args)
131 | args = train_util.read_config_from_file(args, parser)
132 |
133 | trainer = SdxlTextualInversionTrainer()
134 | trainer.train(args)
135 |
--------------------------------------------------------------------------------
/scripts/dev/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name = "library", packages = find_packages())
--------------------------------------------------------------------------------
/scripts/dev/tools/canny.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 |
4 | import logging
5 | from library.utils import setup_logging
6 | setup_logging()
7 | logger = logging.getLogger(__name__)
8 |
9 | def canny(args):
10 | img = cv2.imread(args.input)
11 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
12 |
13 | canny_img = cv2.Canny(img, args.thres1, args.thres2)
14 | # canny_img = 255 - canny_img
15 |
16 | cv2.imwrite(args.output, canny_img)
17 | logger.info("done!")
18 |
19 |
20 | def setup_parser() -> argparse.ArgumentParser:
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--input", type=str, default=None, help="input path")
23 | parser.add_argument("--output", type=str, default=None, help="output path")
24 | parser.add_argument("--thres1", type=int, default=32, help="thres1")
25 | parser.add_argument("--thres2", type=int, default=224, help="thres2")
26 |
27 | return parser
28 |
29 |
30 | if __name__ == '__main__':
31 | parser = setup_parser()
32 |
33 | args = parser.parse_args()
34 | canny(args)
35 |
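# Example invocation (paths are illustrative):
#   python tools/canny.py --input image.png --output image_canny.png --thres1 32 --thres2 224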
--------------------------------------------------------------------------------
/scripts/dev/tools/show_metadata.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from safetensors import safe_open
4 | from library.utils import setup_logging
5 | setup_logging()
6 | import logging
7 | logger = logging.getLogger(__name__)
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--model", type=str, required=True)
11 | args = parser.parse_args()
12 |
13 | with safe_open(args.model, framework="pt") as f:
14 | metadata = f.metadata()
15 |
16 | if metadata is None:
17 | logger.error("No metadata found")
18 | else:
19 | # metadata is json dict, but not pretty printed
20 | # sort by key and pretty print
21 | print(json.dumps(metadata, indent=4, sort_keys=True))
22 |
23 |
24 |
--------------------------------------------------------------------------------
/scripts/dev/train_controlnet.py:
--------------------------------------------------------------------------------
1 | from library.utils import setup_logging
2 |
3 | setup_logging()
4 | import logging
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | from library import train_util
10 | from train_control_net import setup_parser, train
11 |
12 | if __name__ == "__main__":
13 | logger.warning(
14 | "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead"
15 | " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。"
16 | )
17 | parser = setup_parser()
18 |
19 | args = parser.parse_args()
20 | train_util.verify_command_line_training_args(args)
21 | args = train_util.read_config_from_file(args, parser)
22 |
23 | train(args)
24 |
--------------------------------------------------------------------------------
/scripts/stable/.gitignore:
--------------------------------------------------------------------------------
1 | logs
2 | __pycache__
3 | wd14_tagger_model
4 | venv
5 | *.egg-info
6 | build
7 | .vscode
8 | wandb
9 |
--------------------------------------------------------------------------------
/scripts/stable/COMMIT_ID:
--------------------------------------------------------------------------------
1 | 8f4ee8fc343b047965cd8976fca65c3a35b7593a
--------------------------------------------------------------------------------
/scripts/stable/README-ja.md:
--------------------------------------------------------------------------------
1 | ## リポジトリについて
2 | Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
3 |
4 | [README in English](./README.md) ←更新情報はこちらにあります
5 |
6 | 開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
7 |
8 | FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
9 |
10 | GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
11 |
12 | 以下のスクリプトがあります。
13 |
14 | * DreamBooth、U-NetおよびText Encoderの学習をサポート
15 | * fine-tuning、同上
16 | * LoRAの学習をサポート
17 | * 画像生成
18 | * モデル変換(Stable Diffusion ckpt/safetensorsとDiffusersの相互変換)
19 |
20 | ## 使用法について
21 |
22 | * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
23 | * [データセット設定](./docs/config_README-ja.md)
24 | * [SDXL学習](./docs/train_SDXL-en.md) (英語版)
25 | * [DreamBoothの学習について](./docs/train_db_README-ja.md)
26 | * [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
27 | * [LoRAの学習について](./docs/train_network_README-ja.md)
28 | * [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
29 | * [画像生成スクリプト](./docs/gen_img_README-ja.md)
30 | * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
31 |
32 | ## Windowsでの動作に必要なプログラム
33 |
34 | Python 3.10.6およびGitが必要です。
35 |
36 | - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
37 | - git: https://git-scm.com/download/win
38 |
39 | Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
40 |
41 | PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
42 | (venvに限らずスクリプトの実行が可能になりますので注意してください。)
43 |
44 | - PowerShellを管理者として開きます。
45 | - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
46 | - 管理者のPowerShellを閉じます。
47 |
48 | ## Windows環境でのインストール
49 |
50 | スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
51 |
52 | (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
53 |
54 | PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
55 |
56 | ```powershell
57 | git clone https://github.com/kohya-ss/sd-scripts.git
58 | cd sd-scripts
59 |
60 | python -m venv venv
61 | .\venv\Scripts\activate
62 |
63 | pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
64 | pip install --upgrade -r requirements.txt
65 | pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
66 |
67 | accelerate config
68 | ```
69 |
70 | コマンドプロンプトでも同一です。
71 |
72 | 注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
73 |
74 | この例では PyTorch および xformers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えてください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
75 |
76 | PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
77 |
78 | accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
79 |
80 | ```txt
81 | - This machine
82 | - No distributed training
83 | - NO
84 | - NO
85 | - NO
86 | - all
87 | - fp16
88 | ```
89 |
90 | ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
91 | ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
92 |
93 | ## アップグレード
94 |
95 | 新しいリリースがあった場合、以下のコマンドで更新できます。
96 |
97 | ```powershell
98 | cd sd-scripts
99 | git pull
100 | .\venv\Scripts\activate
101 | pip install --use-pep517 --upgrade -r requirements.txt
102 | ```
103 |
104 | コマンドが成功すれば新しいバージョンが使用できます。
105 |
106 | ## 謝意
107 |
108 | LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
109 |
110 | Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
111 |
112 | ## ライセンス
113 |
114 | スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
115 |
116 | [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
117 |
118 | [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
119 |
120 | [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
121 |
122 | ## その他の情報
123 |
124 | ### LoRAの名称について
125 |
126 | `train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
127 |
128 | 1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
129 |
130 | Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
131 |
132 | 2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
133 |
134 | 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
135 |
136 | デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
137 |
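LoRA-C3Lier を学習する場合のコマンドライン指定の一例です(`--network_dim` や `conv_dim` などの値は一例です)。

```
--network_module networks.lora --network_dim 16 --network_alpha 8 --network_args "conv_dim=4" "conv_alpha=1"
```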
138 |
143 |
144 | ### 学習中のサンプル画像生成
145 |
146 | プロンプトファイルは例えば以下のようになります。
147 |
148 | ```
149 | # prompt 1
150 | masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
151 |
152 | # prompt 2
153 | masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
154 | ```
155 |
156 | `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下のオプションが使用できます。
157 |
158 | * `--n` Negative prompt up to the next option.
159 | * `--w` Specifies the width of the generated image.
160 | * `--h` Specifies the height of the generated image.
161 | * `--d` Specifies the seed of the generated image.
162 | * `--l` Specifies the CFG scale of the generated image.
163 | * `--s` Specifies the number of steps in the generation.
164 |
165 | `( )` や `[ ]` などの重みづけも動作します。
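
このプロンプトファイルは学習スクリプトの `--sample_prompts` オプションで渡します。あわせて生成の頻度やサンプラーも指定できます(ファイル名や値は一例です)。

```
--sample_prompts="prompt_file.txt" --sample_every_n_epochs=1 --sample_sampler="euler_a"
```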
166 |
--------------------------------------------------------------------------------
/scripts/stable/_typos.toml:
--------------------------------------------------------------------------------
1 | # Files for typos
2 | # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
3 |
4 | [default.extend-identifiers]
5 | ddPn08="ddPn08"
6 |
7 | [default.extend-words]
8 | NIN="NIN"
9 | parms="parms"
10 | nin="nin"
11 | extention="extention" # Intentionally left
12 | nd="nd"
13 | shs="shs"
14 | sts="sts"
15 | scs="scs"
16 | cpc="cpc"
17 | coc="coc"
18 | cic="cic"
19 | msm="msm"
20 | usu="usu"
21 | ici="ici"
22 | lvl="lvl"
23 | dii="dii"
24 | muk="muk"
25 | ori="ori"
26 | hru="hru"
27 | rik="rik"
28 | koo="koo"
29 | yos="yos"
30 | wn="wn"
31 | hime="hime"
32 |
33 |
34 | [files]
35 | extend-exclude = ["_typos.toml", "venv"]
36 |
--------------------------------------------------------------------------------
/scripts/stable/finetune/blip/med_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertModel"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "layer_norm_eps": 1e-12,
12 | "max_position_embeddings": 512,
13 | "model_type": "bert",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "pad_token_id": 0,
17 | "type_vocab_size": 2,
18 | "vocab_size": 30524,
19 | "encoder_width": 768,
20 | "add_cross_attention": true
21 | }
22 |
--------------------------------------------------------------------------------
/scripts/stable/finetune/hypernetwork_nai.py:
--------------------------------------------------------------------------------
1 | # NAI compatible
2 |
3 | import torch
4 |
5 |
6 | class HypernetworkModule(torch.nn.Module):
7 | def __init__(self, dim, multiplier=1.0):
8 | super().__init__()
9 |
10 | linear1 = torch.nn.Linear(dim, dim * 2)
11 | linear2 = torch.nn.Linear(dim * 2, dim)
12 | linear1.weight.data.normal_(mean=0.0, std=0.01)
13 | linear1.bias.data.zero_()
14 | linear2.weight.data.normal_(mean=0.0, std=0.01)
15 | linear2.bias.data.zero_()
16 | linears = [linear1, linear2]
17 |
18 | self.linear = torch.nn.Sequential(*linears)
19 | self.multiplier = multiplier
20 |
21 | def forward(self, x):
22 | return x + self.linear(x) * self.multiplier
23 |
24 |
25 | class Hypernetwork(torch.nn.Module):
26 | enable_sizes = [320, 640, 768, 1280]
27 | # return self.modules[Hypernetwork.enable_sizes.index(size)]
28 |
29 | def __init__(self, multiplier=1.0) -> None:
30 | super().__init__()
31 | self.modules = []
32 | for size in Hypernetwork.enable_sizes:
33 | self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34 | self.register_module(f"{size}_0", self.modules[-1][0])
35 | self.register_module(f"{size}_1", self.modules[-1][1])
36 |
37 | def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38 | blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39 | for block in blocks:
40 | for subblk in block:
41 | if 'SpatialTransformer' in str(type(subblk)):
42 | for tf_block in subblk.transformer_blocks:
43 | for attn in [tf_block.attn1, tf_block.attn2]:
44 | size = attn.context_dim
45 | if size in Hypernetwork.enable_sizes:
46 | attn.hypernetwork = self
47 | else:
48 | attn.hypernetwork = None
49 |
50 | def apply_to_diffusers(self, text_encoder, vae, unet):
51 | blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52 | for block in blocks:
53 | if hasattr(block, 'attentions'):
54 | for subblk in block.attentions:
55 | if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56 | for tf_block in subblk.transformer_blocks:
57 | for attn in [tf_block.attn1, tf_block.attn2]:
58 | size = attn.to_k.in_features
59 | if size in Hypernetwork.enable_sizes:
60 | attn.hypernetwork = self
61 | else:
62 | attn.hypernetwork = None
63 | return True # TODO error checking
64 |
65 | def forward(self, x, context):
66 | size = context.shape[-1]
67 | assert size in Hypernetwork.enable_sizes
68 | module = self.modules[Hypernetwork.enable_sizes.index(size)]
69 | return module[0].forward(context), module[1].forward(context)
70 |
71 | def load_from_state_dict(self, state_dict):
72 | # old ver to new ver
73 | changes = {
74 | 'linear1.bias': 'linear.0.bias',
75 | 'linear1.weight': 'linear.0.weight',
76 | 'linear2.bias': 'linear.1.bias',
77 | 'linear2.weight': 'linear.1.weight',
78 | }
79 | for key_from, key_to in changes.items():
80 | if key_from in state_dict:
81 | state_dict[key_to] = state_dict[key_from]
82 | del state_dict[key_from]
83 |
84 | for size, sd in state_dict.items():
85 | if type(size) == int:
86 | self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87 | self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88 | return True
89 |
90 | def get_state_dict(self):
91 | state_dict = {}
92 | for i, size in enumerate(Hypernetwork.enable_sizes):
93 | sd0 = self.modules[i][0].state_dict()
94 | sd1 = self.modules[i][1].state_dict()
95 | state_dict[size] = [sd0, sd1]
96 | return state_dict
97 |
--------------------------------------------------------------------------------
/scripts/stable/finetune/merge_captions_to_metadata.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import List
5 | from tqdm import tqdm
6 | import library.train_util as train_util
7 | import os
8 | from library.utils import setup_logging
9 |
10 | setup_logging()
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def main(args):
17 | assert not args.recursive or (
18 | args.recursive and args.full_path
19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
20 |
21 | train_data_dir_path = Path(args.train_data_dir)
22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
23 | logger.info(f"found {len(image_paths)} images.")
24 |
25 | if args.in_json is None and Path(args.out_json).is_file():
26 | args.in_json = args.out_json
27 |
28 | if args.in_json is not None:
29 | logger.info(f"loading existing metadata: {args.in_json}")
30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
31 | logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
32 | else:
33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
34 | metadata = {}
35 |
36 | logger.info("merge caption texts to metadata json.")
37 | for image_path in tqdm(image_paths):
38 | caption_path = image_path.with_suffix(args.caption_extension)
39 |         if not caption_path.exists():
40 |             # also accept caption files named <image file name> + extension (e.g. image.png.caption)
41 |             caption_path = Path(str(image_path) + args.caption_extension)
42 |         caption = caption_path.read_text(encoding="utf-8").strip()
43 |
44 | image_key = str(image_path) if args.full_path else image_path.stem
45 | if image_key not in metadata:
46 | metadata[image_key] = {}
47 |
48 | metadata[image_key]["caption"] = caption
49 | if args.debug:
50 | logger.info(f"{image_key} {caption}")
51 |
52 |     # write out the metadata and we're done
53 | logger.info(f"writing metadata: {args.out_json}")
54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
55 | logger.info("done!")
56 |
57 |
58 | def setup_parser() -> argparse.ArgumentParser:
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
61 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
62 | parser.add_argument(
63 | "--in_json",
64 | type=str,
65 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
66 | )
67 | parser.add_argument(
68 | "--caption_extention",
69 | type=str,
70 | default=None,
71 | help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
72 | )
73 | parser.add_argument(
74 | "--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
75 | )
76 | parser.add_argument(
77 | "--full_path",
78 | action="store_true",
79 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
80 | )
81 | parser.add_argument(
82 | "--recursive",
83 | action="store_true",
84 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
85 | )
86 | parser.add_argument("--debug", action="store_true", help="debug mode")
87 |
88 | return parser
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = setup_parser()
93 |
94 | args = parser.parse_args()
95 |
96 | # スペルミスしていたオプションを復元する
97 | if args.caption_extention is not None:
98 | args.caption_extension = args.caption_extention
99 |
100 | main(args)
101 |
--------------------------------------------------------------------------------
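For orientation, a minimal sketch of the metadata layout this script writes (file and image names are hypothetical): each image key maps to a dict with a "caption" field; merge_dd_tags_to_metadata.py below adds a "tags" field to the same entries.

# Hypothetical inspection of the out_json produced above (run with --full_path).
import json
from pathlib import Path

metadata = json.loads(Path("meta_cap.json").read_text(encoding="utf-8"))
# e.g. {"/data/train/0001.png": {"caption": "a photo of a cat on a sofa"}, ...}
for image_key, entry in metadata.items():
    print(image_key, "->", entry["caption"])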
/scripts/stable/finetune/merge_dd_tags_to_metadata.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | from typing import List
5 | from tqdm import tqdm
6 | import library.train_util as train_util
7 | import os
8 | from library.utils import setup_logging
9 |
10 | setup_logging()
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def main(args):
17 | assert not args.recursive or (
18 | args.recursive and args.full_path
19 | ), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
20 |
21 | train_data_dir_path = Path(args.train_data_dir)
22 | image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
23 | logger.info(f"found {len(image_paths)} images.")
24 |
25 | if args.in_json is None and Path(args.out_json).is_file():
26 | args.in_json = args.out_json
27 |
28 | if args.in_json is not None:
29 | logger.info(f"loading existing metadata: {args.in_json}")
30 | metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
31 | logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
32 | else:
33 | logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
34 | metadata = {}
35 |
36 | logger.info("merge tags to metadata json.")
37 | for image_path in tqdm(image_paths):
38 |         tags_path = image_path.with_suffix(args.caption_extension)
39 |         if not tags_path.exists():
40 |             # fall back to a tag file named "<image file name><extension>"
41 |             tags_path = Path(str(image_path) + args.caption_extension)
42 |         tags = tags_path.read_text(encoding="utf-8").strip()
43 |
44 | image_key = str(image_path) if args.full_path else image_path.stem
45 | if image_key not in metadata:
46 | metadata[image_key] = {}
47 |
48 | metadata[image_key]["tags"] = tags
49 | if args.debug:
50 | logger.info(f"{image_key} {tags}")
51 |
52 |     # write out the metadata and finish
53 | logger.info(f"writing metadata: {args.out_json}")
54 | Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
55 |
56 | logger.info("done!")
57 |
58 |
59 | def setup_parser() -> argparse.ArgumentParser:
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
62 | parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
63 | parser.add_argument(
64 | "--in_json",
65 | type=str,
66 | help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)",
67 | )
68 | parser.add_argument(
69 | "--full_path",
70 | action="store_true",
71 | help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
72 | )
73 | parser.add_argument(
74 | "--recursive",
75 | action="store_true",
76 | help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
77 | )
78 | parser.add_argument(
79 | "--caption_extension",
80 | type=str,
81 | default=".txt",
82 | help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
83 | )
84 | parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
85 |
86 | return parser
87 |
88 |
89 | if __name__ == "__main__":
90 | parser = setup_parser()
91 |
92 | args = parser.parse_args()
93 | main(args)
94 |
--------------------------------------------------------------------------------
/scripts/stable/library/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/scripts/stable/library/__init__.py
--------------------------------------------------------------------------------
/scripts/stable/library/adafactor_fused.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from transformers import Adafactor
4 |
5 | @torch.no_grad()
6 | def adafactor_step_param(self, p, group):
7 | if p.grad is None:
8 | return
9 | grad = p.grad
10 | if grad.dtype in {torch.float16, torch.bfloat16}:
11 | grad = grad.float()
12 | if grad.is_sparse:
13 | raise RuntimeError("Adafactor does not support sparse gradients.")
14 |
15 | state = self.state[p]
16 | grad_shape = grad.shape
17 |
18 | factored, use_first_moment = Adafactor._get_options(group, grad_shape)
19 | # State Initialization
20 | if len(state) == 0:
21 | state["step"] = 0
22 |
23 | if use_first_moment:
24 | # Exponential moving average of gradient values
25 | state["exp_avg"] = torch.zeros_like(grad)
26 | if factored:
27 | state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
28 | state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
29 | else:
30 | state["exp_avg_sq"] = torch.zeros_like(grad)
31 |
32 | state["RMS"] = 0
33 | else:
34 | if use_first_moment:
35 | state["exp_avg"] = state["exp_avg"].to(grad)
36 | if factored:
37 | state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
38 | state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
39 | else:
40 | state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
41 |
42 | p_data_fp32 = p
43 | if p.dtype in {torch.float16, torch.bfloat16}:
44 | p_data_fp32 = p_data_fp32.float()
45 |
46 | state["step"] += 1
47 | state["RMS"] = Adafactor._rms(p_data_fp32)
48 | lr = Adafactor._get_lr(group, state)
49 |
50 | beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
51 | update = (grad ** 2) + group["eps"][0]
52 | if factored:
53 | exp_avg_sq_row = state["exp_avg_sq_row"]
54 | exp_avg_sq_col = state["exp_avg_sq_col"]
55 |
56 | exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
57 | exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
58 |
59 | # Approximation of exponential moving average of square of gradient
60 | update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
61 | update.mul_(grad)
62 | else:
63 | exp_avg_sq = state["exp_avg_sq"]
64 |
65 | exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
66 | update = exp_avg_sq.rsqrt().mul_(grad)
67 |
68 | update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
69 | update.mul_(lr)
70 |
71 | if use_first_moment:
72 | exp_avg = state["exp_avg"]
73 | exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
74 | update = exp_avg
75 |
76 | if group["weight_decay"] != 0:
77 | p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
78 |
79 | p_data_fp32.add_(-update)
80 |
81 | if p.dtype in {torch.float16, torch.bfloat16}:
82 | p.copy_(p_data_fp32)
83 |
84 |
85 | @torch.no_grad()
86 | def adafactor_step(self, closure=None):
87 | """
88 | Performs a single optimization step
89 |
90 | Arguments:
91 | closure (callable, optional): A closure that reevaluates the model
92 | and returns the loss.
93 | """
94 | loss = None
95 | if closure is not None:
96 | loss = closure()
97 |
98 | for group in self.param_groups:
99 | for p in group["params"]:
100 | adafactor_step_param(self, p, group)
101 |
102 | return loss
103 |
104 | def patch_adafactor_fused(optimizer: Adafactor):
105 | optimizer.step_param = adafactor_step_param.__get__(optimizer)
106 | optimizer.step = adafactor_step.__get__(optimizer)
107 |
--------------------------------------------------------------------------------
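A minimal usage sketch for patch_adafactor_fused (the model and optimizer settings are placeholders, not taken from the training scripts): after patching, step_param performs the Adafactor update for a single parameter, so it can be called per parameter as soon as its gradient is ready instead of one monolithic step() at the end.

# Hypothetical usage of patch_adafactor_fused; model and hyperparameters are placeholders.
import torch
from transformers import Adafactor
from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(16, 16)
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)

loss = model(torch.randn(4, 16)).sum()
loss.backward()

for group in optimizer.param_groups:          # update each parameter with the fused step
    for p in group["params"]:
        if p.grad is not None:
            optimizer.step_param(p, group)
optimizer.zero_grad(set_to_none=True)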
/scripts/stable/library/deepspeed_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from accelerate import DeepSpeedPlugin, Accelerator
5 |
6 | from .utils import setup_logging
7 |
8 | setup_logging()
9 | import logging
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15 | # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16 | parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17 | parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18 | parser.add_argument(
19 | "--offload_optimizer_device",
20 | type=str,
21 | default=None,
22 | choices=[None, "cpu", "nvme"],
23 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24 | )
25 | parser.add_argument(
26 | "--offload_optimizer_nvme_path",
27 | type=str,
28 | default=None,
29 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30 | )
31 | parser.add_argument(
32 | "--offload_param_device",
33 | type=str,
34 | default=None,
35 | choices=[None, "cpu", "nvme"],
36 | help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37 | )
38 | parser.add_argument(
39 | "--offload_param_nvme_path",
40 | type=str,
41 | default=None,
42 | help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43 | )
44 | parser.add_argument(
45 | "--zero3_init_flag",
46 | action="store_true",
47 |         help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
48 | "Only applicable with ZeRO Stage-3.",
49 | )
50 | parser.add_argument(
51 | "--zero3_save_16bit_model",
52 | action="store_true",
53 | help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54 | )
55 | parser.add_argument(
56 | "--fp16_master_weights_and_gradients",
57 | action="store_true",
58 | help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
59 | )
60 |
61 |
62 | def prepare_deepspeed_args(args: argparse.Namespace):
63 | if not args.deepspeed:
64 | return
65 |
66 | # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67 | args.max_data_loader_n_workers = 1
68 |
69 |
70 | def prepare_deepspeed_plugin(args: argparse.Namespace):
71 | if not args.deepspeed:
72 | return None
73 |
74 | try:
75 | import deepspeed
76 | except ImportError as e:
77 | logger.error(
78 |             "deepspeed is not installed. please install it in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed"
79 | )
80 | exit(1)
81 |
82 | deepspeed_plugin = DeepSpeedPlugin(
83 | zero_stage=args.zero_stage,
84 | gradient_accumulation_steps=args.gradient_accumulation_steps,
85 | gradient_clipping=args.max_grad_norm,
86 | offload_optimizer_device=args.offload_optimizer_device,
87 | offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88 | offload_param_device=args.offload_param_device,
89 | offload_param_nvme_path=args.offload_param_nvme_path,
90 | zero3_init_flag=args.zero3_init_flag,
91 | zero3_save_16bit_model=args.zero3_save_16bit_model,
92 | )
93 | deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94 | deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95 | args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96 | )
97 | deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98 | if args.mixed_precision.lower() == "fp16":
99 | deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100 | if args.full_fp16 or args.fp16_master_weights_and_gradients:
101 | if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102 | deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103 |             logger.info("[DeepSpeed] full fp16 enabled.")
104 | else:
105 | logger.info(
106 |                 "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported with ZeRO-Offload using DeepSpeedCPUAdam on ZeRO stage 2."
107 | )
108 |
109 | if args.offload_optimizer_device is not None:
110 | logger.info("[DeepSpeed] start to manually build cpu_adam.")
111 | deepspeed.ops.op_builder.CPUAdamBuilder().load()
112 | logger.info("[DeepSpeed] building cpu_adam done.")
113 |
114 | return deepspeed_plugin
115 |
116 |
117 | # The Accelerate library does not support multiple models with DeepSpeed, so we wrap them into a single model.
118 | def prepare_deepspeed_model(args: argparse.Namespace, **models):
119 | # remove None from models
120 | models = {k: v for k, v in models.items() if v is not None}
121 |
122 | class DeepSpeedWrapper(torch.nn.Module):
123 | def __init__(self, **kw_models) -> None:
124 | super().__init__()
125 | self.models = torch.nn.ModuleDict()
126 |
127 | for key, model in kw_models.items():
128 | if isinstance(model, list):
129 | model = torch.nn.ModuleList(model)
130 | assert isinstance(
131 | model, torch.nn.Module
132 | ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133 | self.models.update(torch.nn.ModuleDict({key: model}))
134 |
135 | def get_models(self):
136 | return self.models
137 |
138 | ds_model = DeepSpeedWrapper(**models)
139 | return ds_model
140 |
--------------------------------------------------------------------------------
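A minimal wiring sketch for these helpers (the argument values are placeholders and the rest of the training loop is omitted; it assumes deepspeed is installed and the process was launched via accelerate/torchrun so WORLD_SIZE is set): the parser gains the DeepSpeed flags, the plugin is built from the parsed args, and the plugin is handed to accelerate.Accelerator.

# Hypothetical wiring of the DeepSpeed helpers into a training script.
import argparse
from accelerate import Accelerator
from library import deepspeed_utils

parser = argparse.ArgumentParser()
deepspeed_utils.add_deepspeed_arguments(parser)
args = parser.parse_args(["--deepspeed", "--zero_stage", "2"])

# fields normally provided by the common training arguments (placeholder values here)
args.gradient_accumulation_steps = 1
args.max_grad_norm = 1.0
args.train_batch_size = 4
args.mixed_precision = "fp16"
args.full_fp16 = False

deepspeed_utils.prepare_deepspeed_args(args)                        # forces max_data_loader_n_workers = 1
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)   # returns None when --deepspeed is absent
accelerator = Accelerator(mixed_precision=args.mixed_precision, deepspeed_plugin=deepspeed_plugin)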
/scripts/stable/library/device_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import gc
3 |
4 | import torch
5 |
6 | try:
7 | HAS_CUDA = torch.cuda.is_available()
8 | except Exception:
9 | HAS_CUDA = False
10 |
11 | try:
12 | HAS_MPS = torch.backends.mps.is_available()
13 | except Exception:
14 | HAS_MPS = False
15 |
16 | try:
17 | import intel_extension_for_pytorch as ipex # noqa
18 |
19 | HAS_XPU = torch.xpu.is_available()
20 | except Exception:
21 | HAS_XPU = False
22 |
23 |
24 | def clean_memory():
25 | gc.collect()
26 | if HAS_CUDA:
27 | torch.cuda.empty_cache()
28 | if HAS_XPU:
29 | torch.xpu.empty_cache()
30 | if HAS_MPS:
31 | torch.mps.empty_cache()
32 |
33 |
34 | def clean_memory_on_device(device: torch.device):
35 | r"""
36 | Clean memory on the specified device, will be called from training scripts.
37 | """
38 | gc.collect()
39 |
40 |     # device may be "cuda" or "cuda:0", so we check the device type
41 | if device.type == "cuda":
42 | torch.cuda.empty_cache()
43 | if device.type == "xpu":
44 | torch.xpu.empty_cache()
45 | if device.type == "mps":
46 | torch.mps.empty_cache()
47 |
48 |
49 | @functools.lru_cache(maxsize=None)
50 | def get_preferred_device() -> torch.device:
51 | r"""
52 | Do not call this function from training scripts. Use accelerator.device instead.
53 | """
54 | if HAS_CUDA:
55 | device = torch.device("cuda")
56 | elif HAS_XPU:
57 | device = torch.device("xpu")
58 | elif HAS_MPS:
59 | device = torch.device("mps")
60 | else:
61 | device = torch.device("cpu")
62 | print(f"get_preferred_device() -> {device}")
63 | return device
64 |
65 |
66 | def init_ipex():
67 | """
68 | Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
69 |
70 | This function should run right after importing torch and before doing anything else.
71 |
72 | If IPEX is not available, this function does nothing.
73 | """
74 | try:
75 | if HAS_XPU:
76 | from library.ipex import ipex_init
77 |
78 | is_initialized, error_message = ipex_init()
79 | if not is_initialized:
80 | print("failed to initialize ipex:", error_message)
81 | else:
82 | return
83 | except Exception as e:
84 | print("failed to initialize ipex:", e)
85 |
--------------------------------------------------------------------------------
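A small sketch of typical use of the device helpers (the tensor work is a placeholder): initialize IPEX if present, pick the preferred backend once, do the work there, then release cached memory on that device.

# Hypothetical use of the device helpers above; the matmul is a placeholder workload.
import torch
from library.device_utils import init_ipex, get_preferred_device, clean_memory_on_device

init_ipex()                      # no-op unless an XPU with intel_extension_for_pytorch is present

device = get_preferred_device()  # cuda > xpu > mps > cpu
x = torch.randn(1024, 1024, device=device)
y = x @ x
del x, y

clean_memory_on_device(device)   # gc.collect() + empty_cache() for the matching backend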
/scripts/stable/library/huggingface_util.py:
--------------------------------------------------------------------------------
1 | from typing import Union, BinaryIO
2 | from huggingface_hub import HfApi
3 | from pathlib import Path
4 | import argparse
5 | import os
6 | from library.utils import fire_in_thread
7 | from library.utils import setup_logging
8 | setup_logging()
9 | import logging
10 | logger = logging.getLogger(__name__)
11 |
12 | def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13 | api = HfApi(
14 | token=token,
15 | )
16 | try:
17 | api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18 | return True
19 |     except Exception:
20 | return False
21 |
22 |
23 | def upload(
24 | args: argparse.Namespace,
25 | src: Union[str, Path, bytes, BinaryIO],
26 | dest_suffix: str = "",
27 | force_sync_upload: bool = False,
28 | ):
29 | repo_id = args.huggingface_repo_id
30 | repo_type = args.huggingface_repo_type
31 | token = args.huggingface_token
32 | path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33 | private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34 | api = HfApi(token=token)
35 | if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36 | try:
37 | api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38 |         except Exception as e: # RepositoryNotFoundError has been confirmed, but catch broadly in case other errors occur
39 | logger.error("===========================================")
40 | logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41 | logger.error("===========================================")
42 |
43 | is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44 |
45 | def uploader():
46 | try:
47 | if is_folder:
48 | api.upload_folder(
49 | repo_id=repo_id,
50 | repo_type=repo_type,
51 | folder_path=src,
52 | path_in_repo=path_in_repo,
53 | )
54 | else:
55 | api.upload_file(
56 | repo_id=repo_id,
57 | repo_type=repo_type,
58 | path_or_fileobj=src,
59 | path_in_repo=path_in_repo,
60 | )
61 |         except Exception as e: # RuntimeError has been confirmed, but catch broadly in case other errors occur
62 | logger.error("===========================================")
63 | logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64 | logger.error("===========================================")
65 |
66 | if args.async_upload and not force_sync_upload:
67 | fire_in_thread(uploader)
68 | else:
69 | uploader()
70 |
71 |
72 | def list_dir(
73 | repo_id: str,
74 | subfolder: str,
75 | repo_type: str,
76 | revision: str = "main",
77 | token: str = None,
78 | ):
79 | api = HfApi(
80 | token=token,
81 | )
82 | repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83 | file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84 | return file_list
85 |
--------------------------------------------------------------------------------
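A small sketch of the read-only helpers (the repo id and subfolder are hypothetical): exists_repo reports whether a repo is reachable, and list_dir returns the repo's file list filtered by a subfolder prefix.

# Hypothetical usage of exists_repo / list_dir; repo id and subfolder are placeholders.
from library import huggingface_util

repo_id = "your-name/your-lora-repo"
if huggingface_util.exists_repo(repo_id, repo_type="model"):
    files = huggingface_util.list_dir(repo_id, subfolder="checkpoints", repo_type="model")
    for f in files:
        print(f.rfilename)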
/scripts/stable/networks/check_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from safetensors.torch import load_file
5 | from library.utils import setup_logging
6 | setup_logging()
7 | import logging
8 | logger = logging.getLogger(__name__)
9 |
10 | def main(file):
11 | logger.info(f"loading: {file}")
12 | if os.path.splitext(file)[1] == ".safetensors":
13 | sd = load_file(file)
14 | else:
15 | sd = torch.load(file, map_location="cpu")
16 |
17 | values = []
18 |
19 | keys = list(sd.keys())
20 | for key in keys:
21 | if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
22 | values.append((key, sd[key]))
23 | print(f"number of LoRA modules: {len(values)}")
24 |
25 | if args.show_all_keys:
26 |         for key in [k for k in keys if k not in [v[0] for v in values]]:
27 | values.append((key, sd[key]))
28 | print(f"number of all modules: {len(values)}")
29 |
30 | for key, value in values:
31 | value = value.to(torch.float32)
32 | print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
33 |
34 |
35 | def setup_parser() -> argparse.ArgumentParser:
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
38 | parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
39 |
40 | return parser
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = setup_parser()
45 |
46 | args = parser.parse_args()
47 |
48 | main(args.file)
49 |
--------------------------------------------------------------------------------
/scripts/stable/networks/extract_lora_from_dylora.py:
--------------------------------------------------------------------------------
1 | # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2 | # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3 | # Thanks to cloneofsimo
4 |
5 | import argparse
6 | import math
7 | import os
8 | import torch
9 | from safetensors.torch import load_file, save_file, safe_open
10 | from tqdm import tqdm
11 | from library import train_util, model_util
12 | import numpy as np
13 | from library.utils import setup_logging
14 | setup_logging()
15 | import logging
16 | logger = logging.getLogger(__name__)
17 |
18 | def load_state_dict(file_name):
19 | if model_util.is_safetensors(file_name):
20 | sd = load_file(file_name)
21 | with safe_open(file_name, framework="pt") as f:
22 | metadata = f.metadata()
23 | else:
24 | sd = torch.load(file_name, map_location="cpu")
25 | metadata = None
26 |
27 | return sd, metadata
28 |
29 |
30 | def save_to_file(file_name, model, metadata):
31 | if model_util.is_safetensors(file_name):
32 | save_file(model, file_name, metadata)
33 | else:
34 | torch.save(model, file_name)
35 |
36 |
37 | def split_lora_model(lora_sd, unit):
38 | max_rank = 0
39 |
40 | # Extract loaded lora dim and alpha
41 | for key, value in lora_sd.items():
42 | if "lora_down" in key:
43 | rank = value.size()[0]
44 | if rank > max_rank:
45 | max_rank = rank
46 | logger.info(f"Max rank: {max_rank}")
47 |
48 | rank = unit
49 | split_models = []
50 | new_alpha = None
51 | while rank < max_rank:
52 | logger.info(f"Splitting rank {rank}")
53 | new_sd = {}
54 | for key, value in lora_sd.items():
55 | if "lora_down" in key:
56 | new_sd[key] = value[:rank].contiguous()
57 | elif "lora_up" in key:
58 | new_sd[key] = value[:, :rank].contiguous()
59 | else:
60 |                 # for some reason, scaling alpha here breaks the result...
61 | # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
62 | # scale = math.sqrt(this_rank / rank) # rank is > unit
63 | # logger.info(key, value.size(), this_rank, rank, value, scale)
64 | # new_alpha = value * scale # always same
65 | # new_sd[key] = new_alpha
66 | new_sd[key] = value
67 |
68 | split_models.append((new_sd, rank, new_alpha))
69 | rank += unit
70 |
71 | return max_rank, split_models
72 |
73 |
74 | def split(args):
75 | logger.info("loading Model...")
76 | lora_sd, metadata = load_state_dict(args.model)
77 |
78 | logger.info("Splitting Model...")
79 | original_rank, split_models = split_lora_model(lora_sd, args.unit)
80 |
81 |     comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
82 | for state_dict, new_rank, new_alpha in split_models:
83 | # update metadata
84 | if metadata is None:
85 | new_metadata = {}
86 | else:
87 | new_metadata = metadata.copy()
88 |
89 | new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
90 | new_metadata["ss_network_dim"] = str(new_rank)
91 | # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
92 |
93 |         model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
94 |         new_metadata["sshs_model_hash"] = model_hash
95 |         new_metadata["sshs_legacy_hash"] = legacy_hash
96 |
97 | filename, ext = os.path.splitext(args.save_to)
98 | model_file_name = filename + f"-{new_rank:04d}{ext}"
99 |
100 | logger.info(f"saving model to: {model_file_name}")
101 | save_to_file(model_file_name, state_dict, new_metadata)
102 |
103 |
104 | def setup_parser() -> argparse.ArgumentParser:
105 | parser = argparse.ArgumentParser()
106 |
107 | parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
108 | parser.add_argument(
109 | "--save_to",
110 | type=str,
111 | default=None,
112 | help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
113 | )
114 | parser.add_argument(
115 | "--model",
116 | type=str,
117 | default=None,
118 | help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
119 | )
120 |
121 | return parser
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = setup_parser()
126 |
127 | args = parser.parse_args()
128 | split(args)
129 |
--------------------------------------------------------------------------------
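A programmatic sketch of the split helpers above (the file names are placeholders, and importing the script as a module is an assumption): a DyLoRA trained up to rank 16 with unit 4 yields rank 4, 8 and 12 variants, each saved to its own file. The CLI equivalent passes --unit, --model and --save_to, and the script appends the rank to the output name itself.

# Hypothetical programmatic use of the split helpers; file names are placeholders.
from networks.extract_lora_from_dylora import load_state_dict, save_to_file, split_lora_model

lora_sd, metadata = load_state_dict("dylora-rank16.safetensors")
max_rank, split_models = split_lora_model(lora_sd, unit=4)   # ranks 4, 8, 12 for max_rank 16
for state_dict, new_rank, _ in split_models:
    save_to_file(f"dylora-split-{new_rank:04d}.safetensors", state_dict, metadata)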
/scripts/stable/networks/lora_interrogator.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from tqdm import tqdm
4 | from library import model_util
5 | import library.train_util as train_util
6 | import argparse
7 | from transformers import CLIPTokenizer
8 |
9 | import torch
10 | from library.device_utils import init_ipex, get_preferred_device
11 | init_ipex()
12 |
13 | import library.model_util as model_util
14 | import lora
15 | from library.utils import setup_logging
16 | setup_logging()
17 | import logging
18 | logger = logging.getLogger(__name__)
19 |
20 | TOKENIZER_PATH = "openai/clip-vit-large-patch14"
21 | V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model
22 |
23 | DEVICE = get_preferred_device()
24 |
25 |
26 | def interrogate(args):
27 | weights_dtype = torch.float16
28 |
29 |     # prepare everything
30 | logger.info(f"loading SD model: {args.sd_model}")
31 | args.pretrained_model_name_or_path = args.sd_model
32 | args.vae = None
33 |     text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)
34 |
35 | logger.info(f"loading LoRA: {args.model}")
36 | network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
37 |
38 |     # check whether the weights include text encoder modules; ideally this should be done on the lora side
39 | has_te_weight = False
40 | for key in weights_sd.keys():
41 | if 'lora_te' in key:
42 | has_te_weight = True
43 | break
44 | if not has_te_weight:
45 | logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
46 | return
47 | del vae
48 |
49 | logger.info("loading tokenizer")
50 | if args.v2:
51 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
52 | else:
53 | tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
54 |
55 | text_encoder.to(DEVICE, dtype=weights_dtype)
56 | text_encoder.eval()
57 | unet.to(DEVICE, dtype=weights_dtype)
58 |     unet.eval() # not strictly necessary since the U-Net is never called
59 |
60 |     # go through the tokens one by one
61 | token_id_start = 0
62 | token_id_end = max(tokenizer.all_special_ids)
63 | logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
64 |
65 | def get_all_embeddings(text_encoder):
66 | embs = []
67 | with torch.no_grad():
68 | for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
69 | batch = []
70 | for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
71 | tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
72 |                 # tokens = [tid] # this variant gives somewhat worse results
73 | batch.append(tokens)
74 |
75 |                 # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # including bos/eos seems to bring out the differences better [:, 1]
76 |                 # handle clip skip
77 | batch = torch.tensor(batch).to(DEVICE)
78 | if args.clip_skip is None:
79 | encoder_hidden_states = text_encoder(batch)[0]
80 | else:
81 | enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
82 | encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
83 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
84 | encoder_hidden_states = encoder_hidden_states.to("cpu")
85 |
86 | embs.extend(encoder_hidden_states)
87 | return torch.stack(embs)
88 |
89 | logger.info("get original text encoder embeddings.")
90 | orig_embs = get_all_embeddings(text_encoder)
91 |
92 | network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
93 | info = network.load_state_dict(weights_sd, strict=False)
94 | logger.info(f"Loading LoRA weights: {info}")
95 |
96 | network.to(DEVICE, dtype=weights_dtype)
97 | network.eval()
98 |
99 | del unet
100 |
101 |     logger.info("You can ignore warning messages starting with '_IncompatibleKeys' (the LoRA model has no alpha because it was trained by an older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
102 | logger.info("get text encoder embeddings with lora.")
103 | lora_embs = get_all_embeddings(text_encoder)
104 |
105 |     # compare: for now, simply use the mean absolute difference
106 | logger.info("comparing...")
107 | diffs = {}
108 | for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
109 | diff = torch.mean(torch.abs(orig_emb - lora_emb))
110 |         # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # cosine similarity does not detect the differences well
111 | diff = float(diff.detach().to('cpu').numpy())
112 | diffs[token_id_start + i] = diff
113 |
114 | diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
115 |
116 |     # show the results
117 | print("top 100:")
118 | for i, (token, diff) in enumerate(diffs_sorted[:100]):
119 | # if diff < 1e-6:
120 | # break
121 | string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
122 | print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
123 |
124 |
125 | def setup_parser() -> argparse.ArgumentParser:
126 | parser = argparse.ArgumentParser()
127 |
128 | parser.add_argument("--v2", action='store_true',
129 | help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
130 | parser.add_argument("--sd_model", type=str, default=None,
131 | help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
132 | parser.add_argument("--model", type=str, default=None,
133 | help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
134 | parser.add_argument("--batch_size", type=int, default=16,
135 | help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
136 | parser.add_argument("--clip_skip", type=int, default=None,
137 | help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
138 |
139 | return parser
140 |
141 |
142 | if __name__ == '__main__':
143 | parser = setup_parser()
144 |
145 | args = parser.parse_args()
146 | interrogate(args)
147 |
--------------------------------------------------------------------------------
/scripts/stable/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.0
2 | transformers==4.44.0
3 | diffusers[torch]==0.25.0
4 | ftfy==6.1.1
5 | # albumentations==1.3.0
6 | opencv-python==4.8.1.78
7 | einops==0.7.0
8 | pytorch-lightning==1.9.0
9 | bitsandbytes==0.44.0
10 | prodigyopt==1.0
11 | lion-pytorch==0.0.6
12 | tensorboard
13 | safetensors==0.4.2
14 | # gradio==3.16.2
15 | altair==4.2.2
16 | easygui==0.98.3
17 | toml==0.10.2
18 | voluptuous==0.13.1
19 | huggingface-hub==0.24.5
20 | # for Image utils
21 | imagesize==1.4.1
22 | # for BLIP captioning
23 | # requests==2.28.2
24 | # timm==0.6.12
25 | # fairscale==0.4.13
26 | # for WD14 captioning (tensorflow)
27 | # tensorflow==2.10.1
28 | # for WD14 captioning (onnx)
29 | # onnx==1.15.0
30 | # onnxruntime-gpu==1.17.1
31 | # onnxruntime==1.17.1
32 | # for cuda 12.1(default 11.8)
33 | # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
34 |
35 | # this is for onnx:
36 | # protobuf==3.20.3
37 | # open clip for SDXL
38 | # open-clip-torch==2.20.0
39 | # For logging
40 | rich==13.7.0
41 | # for kohya_ss library
42 | -e .
43 |
--------------------------------------------------------------------------------
/scripts/stable/sdxl_train_textual_inversion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import regex
5 |
6 | import torch
7 | from library.device_utils import init_ipex
8 | init_ipex()
9 |
10 | from library import sdxl_model_util, sdxl_train_util, train_util
11 |
12 | import train_textual_inversion
13 |
14 |
15 | class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
16 | def __init__(self):
17 | super().__init__()
18 | self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
19 | self.is_sdxl = True
20 |
21 | def assert_extra_args(self, args, train_dataset_group):
22 | super().assert_extra_args(args, train_dataset_group)
23 | sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
24 |
25 | train_dataset_group.verify_bucket_reso_steps(32)
26 |
27 | def load_target_model(self, args, weight_dtype, accelerator):
28 | (
29 | load_stable_diffusion_format,
30 | text_encoder1,
31 | text_encoder2,
32 | vae,
33 | unet,
34 | logit_scale,
35 | ckpt_info,
36 | ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
37 |
38 | self.load_stable_diffusion_format = load_stable_diffusion_format
39 | self.logit_scale = logit_scale
40 | self.ckpt_info = ckpt_info
41 |
42 | return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
43 |
44 | def load_tokenizer(self, args):
45 | tokenizer = sdxl_train_util.load_tokenizers(args)
46 | return tokenizer
47 |
48 | def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
49 | input_ids1 = batch["input_ids"]
50 | input_ids2 = batch["input_ids2"]
51 | with torch.enable_grad():
52 | input_ids1 = input_ids1.to(accelerator.device)
53 | input_ids2 = input_ids2.to(accelerator.device)
54 | encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
55 | args.max_token_length,
56 | input_ids1,
57 | input_ids2,
58 | tokenizers[0],
59 | tokenizers[1],
60 | text_encoders[0],
61 | text_encoders[1],
62 | None if not args.full_fp16 else weight_dtype,
63 | accelerator=accelerator,
64 | )
65 | return encoder_hidden_states1, encoder_hidden_states2, pool2
66 |
67 | def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
68 | noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
69 |
70 | # get size embeddings
71 | orig_size = batch["original_sizes_hw"]
72 | crop_size = batch["crop_top_lefts"]
73 | target_size = batch["target_sizes_hw"]
74 | embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
75 |
76 | # concat embeddings
77 | encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
78 | vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
79 | text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
80 |
81 | noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
82 | return noise_pred
83 |
84 | def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
85 | sdxl_train_util.sample_images(
86 | accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
87 | )
88 |
89 | def save_weights(self, file, updated_embs, save_dtype, metadata):
90 | state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
91 |
92 | if save_dtype is not None:
93 | for key in list(state_dict.keys()):
94 | v = state_dict[key]
95 | v = v.detach().clone().to("cpu").to(save_dtype)
96 | state_dict[key] = v
97 |
98 | if os.path.splitext(file)[1] == ".safetensors":
99 | from safetensors.torch import save_file
100 |
101 | save_file(state_dict, file, metadata)
102 | else:
103 | torch.save(state_dict, file)
104 |
105 | def load_weights(self, file):
106 | if os.path.splitext(file)[1] == ".safetensors":
107 | from safetensors.torch import load_file
108 |
109 | data = load_file(file)
110 | else:
111 | data = torch.load(file, map_location="cpu")
112 |
113 | emb_l = data.get("clip_l", None) # ViT-L text encoder 1
114 | emb_g = data.get("clip_g", None) # BiG-G text encoder 2
115 |
116 | assert (
117 | emb_l is not None or emb_g is not None
118 |         ), f"weight file does not contain weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
119 |
120 | return [emb_l, emb_g]
121 |
122 |
123 | def setup_parser() -> argparse.ArgumentParser:
124 | parser = train_textual_inversion.setup_parser()
125 |     # don't call sdxl_train_util.add_sdxl_training_arguments(parser) here, because it only adds text encoder caching arguments
126 | # sdxl_train_util.add_sdxl_training_arguments(parser)
127 | return parser
128 |
129 |
130 | if __name__ == "__main__":
131 | parser = setup_parser()
132 |
133 | args = parser.parse_args()
134 | train_util.verify_command_line_training_args(args)
135 | args = train_util.read_config_from_file(args, parser)
136 |
137 | trainer = SdxlTextualInversionTrainer()
138 | trainer.train(args)
139 |
--------------------------------------------------------------------------------
/scripts/stable/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name = "library", packages = find_packages())
--------------------------------------------------------------------------------
/scripts/stable/tools/canny.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 |
4 | import logging
5 | from library.utils import setup_logging
6 | setup_logging()
7 | logger = logging.getLogger(__name__)
8 |
9 | def canny(args):
10 | img = cv2.imread(args.input)
11 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
12 |
13 | canny_img = cv2.Canny(img, args.thres1, args.thres2)
14 | # canny_img = 255 - canny_img
15 |
16 | cv2.imwrite(args.output, canny_img)
17 | logger.info("done!")
18 |
19 |
20 | def setup_parser() -> argparse.ArgumentParser:
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--input", type=str, default=None, help="input path")
23 | parser.add_argument("--output", type=str, default=None, help="output path")
24 | parser.add_argument("--thres1", type=int, default=32, help="thres1")
25 | parser.add_argument("--thres2", type=int, default=224, help="thres2")
26 |
27 | return parser
28 |
29 |
30 | if __name__ == '__main__':
31 | parser = setup_parser()
32 |
33 | args = parser.parse_args()
34 | canny(args)
35 |
--------------------------------------------------------------------------------
/scripts/stable/tools/show_metadata.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from safetensors import safe_open
4 | from library.utils import setup_logging
5 | setup_logging()
6 | import logging
7 | logger = logging.getLogger(__name__)
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--model", type=str, required=True)
11 | args = parser.parse_args()
12 |
13 | with safe_open(args.model, framework="pt") as f:
14 | metadata = f.metadata()
15 |
16 | if metadata is None:
17 | logger.error("No metadata found")
18 | else:
19 | # metadata is json dict, but not pretty printed
20 | # sort by key and pretty print
21 | print(json.dumps(metadata, indent=4, sort_keys=True))
22 |
23 |
24 |
--------------------------------------------------------------------------------
/sd-models/put stable diffusion model here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/sd-models/put stable diffusion model here.txt
--------------------------------------------------------------------------------
/svd_merge.ps1:
--------------------------------------------------------------------------------
1 | # LoRA svd_merge script by @bdsqlsz
2 |
3 | $save_precision = "fp16" # precision in saving, default float | 保存精度, 可选 float、fp16、bf16, 默认 和源文件相同
4 | $precision = "float" # precision in merging (float is recommended) | 合并时计算精度, 可选 float、fp16、bf16, 推荐float
5 | $new_rank = 4 # dim rank of output LoRA | dim rank等级, 默认 4
6 | $models = "./output/modelA.safetensors ./output/modelB.safetensors" # original LoRA model paths to merge, ckpt or safetensors; separate multiple paths with spaces | 需要合并的模型路径, 保存格式 cpkt 或 safetensors,多个用空格隔开
7 | $ratios = "1.0 -1.0" # ratios for each model / LoRA模型合并比例,数量等于模型数量,多个用空格隔开
8 | $save_to = "./output/lora_name_new.safetensors" # output LoRA model path, save as ckpt or safetensors | 输出路径, 保存格式 cpkt 或 safetensors
9 | $device = "cuda" # device to use, cuda for GPU | 使用 GPU跑, 默认 CPU
10 | $new_conv_rank = 0 # Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank | Conv2d 3x3输出,没有默认同new_rank
11 |
12 | # Activate python venv
13 | .\venv\Scripts\activate
14 |
15 | $Env:HF_HOME = "huggingface"
16 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
17 | $ext_args = [System.Collections.ArrayList]::new()
18 |
19 | [void]$ext_args.Add("--models")
20 | foreach ($model in $models.Split(" ")) {
21 | [void]$ext_args.Add($model)
22 | }
23 |
24 | [void]$ext_args.Add("--ratios")
25 | foreach ($ratio in $ratios.Split(" ")) {
26 | [void]$ext_args.Add([float]$ratio)
27 | }
28 |
29 | if ($new_conv_rank) {
30 | [void]$ext_args.Add("--new_conv_rank=" + $new_conv_rank)
31 | }
32 |
33 | # run svd_merge
34 | accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/networks/svd_merge_lora.py" `
35 | --save_precision=$save_precision `
36 | --precision=$precision `
37 | --new_rank=$new_rank `
38 | --save_to=$save_to `
39 | --device=$device `
40 | $ext_args
41 |
42 | Write-Output "SVD Merge finished"
43 | Read-Host | Out-Null ;
44 |
--------------------------------------------------------------------------------
/tagger.ps1:
--------------------------------------------------------------------------------
1 | # tagger script by @bdsqlsz
2 |
3 | # Train data path
4 | $train_data_dir = "./input" # input images path | 图片输入路径
5 | $repo_id = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" # model repo id from huggingface |huggingface模型repoID
6 | $model_dir = "" # model dir path | 本地模型文件夹路径
7 | $batch_size = 4 # batch size in inference 批处理大小,越大越快
8 | $max_data_loader_n_workers = 0 # enable image reading by DataLoader with this number of workers (faster) | 0最快
9 | $thresh = 0.35 # concept thresh | 最小识别阈值
10 | $general_threshold = 0.35 # general threshold | 总体识别阈值
11 | $character_threshold = 0.1 # character threshold | 人物姓名识别阈值
12 | $remove_underscore = 0 # remove_underscore | 下划线转空格,1为开,0为关
13 | $undesired_tags = "" # no need tags | 排除标签
14 | $recursive = 0 # search for images in subfolders recursively | 递归搜索下层文件夹,1为开,0为关
15 | $frequency_tags = 0 # order by frequency tags | 从大到小按识别率排序标签,1为开,0为关
16 |
17 |
18 | # Activate python venv
19 | .\venv\Scripts\activate
20 |
21 | $Env:HF_HOME = "huggingface"
22 | $Env:XFORMERS_FORCE_DISABLE_TRITON = "1"
23 | $ext_args = [System.Collections.ArrayList]::new()
24 |
25 | if ($repo_id) {
26 | [void]$ext_args.Add("--repo_id=" + $repo_id)
27 | }
28 |
29 | if ($model_dir) {
30 | [void]$ext_args.Add("--model_dir=" + $model_dir)
31 | }
32 |
33 | if ($batch_size) {
34 | [void]$ext_args.Add("--batch_size=" + $batch_size)
35 | }
36 |
37 | if ($max_data_loader_n_workers) {
38 | [void]$ext_args.Add("--max_data_loader_n_workers=" + $max_data_loader_n_workers)
39 | }
40 |
41 | if ($general_threshold) {
42 | [void]$ext_args.Add("--general_threshold=" + $general_threshold)
43 | }
44 |
45 | if ($character_threshold) {
46 | [void]$ext_args.Add("--character_threshold=" + $character_threshold)
47 | }
48 |
49 | if ($remove_underscore) {
50 | [void]$ext_args.Add("--remove_underscore")
51 | }
52 |
53 | if ($undesired_tags) {
54 | [void]$ext_args.Add("--undesired_tags=" + $undesired_tags)
55 | }
56 |
57 | if ($recursive) {
58 | [void]$ext_args.Add("--recursive")
59 | }
60 |
61 | if ($frequency_tags) {
62 | [void]$ext_args.Add("--frequency_tags")
63 | }
64 |
65 | # run tagger
66 | accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/finetune/tag_images_by_wd14_tagger.py" `
67 | $train_data_dir `
68 | --thresh=$thresh `
69 | --caption_extension .txt `
70 | $ext_args
71 |
72 | Write-Output "Tagger finished"
73 | Read-Host | Out-Null ;
74 |
--------------------------------------------------------------------------------
/tagger.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # tagger script by @bdsqlsz
3 | # Train data path
4 | train_data_dir="./input" # input images path | 图片输入路径
5 | repo_id="SmilingWolf/wd-v1-4-swinv2-tagger-v2" # model repo id from huggingface |huggingface模型repoID
6 | model_dir="" # model dir path | 本地模型文件夹路径
7 | batch_size=12 # batch size in inference 批处理大小,越大越快
8 | max_data_loader_n_workers=0 # enable image reading by DataLoader with this number of workers (faster) | 0最快
9 | thresh=0.35 # concept thresh | 最小识别阈值
10 | general_threshold=0.35 # general threshold | 总体识别阈值
11 | character_threshold=0.1 # character threshold | 人物姓名识别阈值
12 | remove_underscore=0 # remove_underscore | 下划线转空格,1为开,0为关
13 | undesired_tags="" # no need tags | 排除标签
14 | recursive=0 # search for images in subfolders recursively | 递归搜索下层文件夹,1为开,0为关
15 | frequency_tags=0 # order by frequency tags | 从大到小按识别率排序标签,1为开,0为关
16 |
17 |
18 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 =====================
19 |
20 | export HF_HOME="huggingface"
21 | export TF_CPP_MIN_LOG_LEVEL=3
22 | extArgs=()
23 |
24 | if [ -n "$repo_id" ]; then
25 | extArgs+=( "--repo_id=$repo_id" )
26 | fi
27 |
28 | if [ -n "$model_dir" ]; then
29 | extArgs+=( "--model_dir=$model_dir" )
30 | fi
31 |
32 | if [[ $batch_size -ne 0 ]]; then
33 | extArgs+=( "--batch_size=$batch_size" )
34 | fi
35 |
36 | if [ -n "$max_data_loader_n_workers" ]; then
37 | extArgs+=( "--max_data_loader_n_workers=$max_data_loader_n_workers" )
38 | fi
39 |
40 | if [ -n "$general_threshold" ]; then
41 | extArgs+=( "--general_threshold=$general_threshold" )
42 | fi
43 |
44 | if [ -n "$character_threshold" ]; then
45 | extArgs+=( "--character_threshold=$character_threshold" )
46 | fi
47 |
48 | if [ "$remove_underscore" -eq 1 ]; then
49 | extArgs+=( "--remove_underscore" )
50 | fi
51 |
52 | if [ -n "$undesired_tags" ]; then
53 | extArgs+=( "--undesired_tags=$undesired_tags" )
54 | fi
55 |
56 | if [ "$recursive" -eq 1 ]; then
57 | extArgs+=( "--recursive" )
58 | fi
59 |
60 | if [ "$frequency_tags" -eq 1 ]; then
61 | extArgs+=( "--frequency_tags" )
62 | fi
63 |
64 |
65 | # run tagger
66 | accelerate launch --num_cpu_threads_per_process=8 "./scripts/stable/finetune/tag_images_by_wd14_tagger.py" \
67 |   "$train_data_dir" \
68 |   --thresh="$thresh" \
69 |   --caption_extension .txt \
70 |   "${extArgs[@]}"
71 |
--------------------------------------------------------------------------------
/tensorboard.ps1:
--------------------------------------------------------------------------------
1 | $Env:TF_CPP_MIN_LOG_LEVEL = "3"
2 |
3 | .\venv\Scripts\activate
4 | tensorboard --logdir=logs
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "pycharm": {
8 | "name": "#%%\n"
9 | }
10 | },
11 | "outputs": [],
12 | "source": [
13 | "# Train data path | 设置训练用模型、图片\n",
14 | "pretrained_model = \"./sd-models/model.ckpt\" # base model path | 底模路径\n",
15 | "train_data_dir = \"./train/aki\" # train dataset path | 训练数据集路径\n",
16 | "\n",
17 | "# Train related params | 训练相关参数\n",
18 | "resolution = \"512,512\" # image resolution w,h. 图片分辨率,宽,高。支持非正方形,但必须是 64 倍数。\n",
19 | "batch_size = 1 # batch size\n",
20 | "max_train_epoches = 10 # max train epoches | 最大训练 epoch\n",
21 | "save_every_n_epochs = 2 # save every n epochs | 每 N 个 epoch 保存一次\n",
22 | "network_dim = 32 # network dim | 常用 4~128,不是越大越好\n",
23 | "network_alpha= 32 # network alpha | 常用与 network_dim 相同的值或者采用较小的值,如 network_dim的一半 防止下溢。默认值为 1,使用较小的 alpha 需要提升学习率。\n",
24 | "clip_skip = 2 # clip skip | 玄学 一般用 2\n",
25 | "train_unet_only = 0 # train U-Net only | 仅训练 U-Net,开启这个会牺牲效果大幅减少显存使用。6G显存可以开启\n",
26 | "train_text_encoder_only = 0 # train Text Encoder only | 仅训练 文本编码器\n",
27 | "\n",
28 | "# Learning rate | 学习率\n",
29 | "lr = \"1e-4\"\n",
30 | "unet_lr = \"1e-4\"\n",
31 | "text_encoder_lr = \"1e-5\"\n",
32 | "lr_scheduler = \"cosine_with_restarts\" # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n",
33 | "\n",
34 | "# Output settings | 输出设置\n",
35 | "output_name = \"aki\" # output model name | 模型保存名称\n",
36 | "save_model_as = \"safetensors\" # model save ext | 模型保存格式 ckpt, pt, safetensors"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "pycharm": {
44 | "name": "#%%\n"
45 | }
46 | },
47 | "outputs": [],
48 | "source": [
49 | "!accelerate launch --num_cpu_threads_per_process=8 \"./scripts/train_network.py\" \\\n",
50 | " --enable_bucket \\\n",
51 | " --pretrained_model_name_or_path=$pretrained_model \\\n",
52 | " --train_data_dir=$train_data_dir \\\n",
53 | " --output_dir=\"./output\" \\\n",
54 | " --logging_dir=\"./logs\" \\\n",
55 | " --resolution=$resolution \\\n",
56 | " --network_module=networks.lora \\\n",
57 | " --max_train_epochs=$max_train_epoches \\\n",
58 | " --learning_rate=$lr \\\n",
59 | " --unet_lr=$unet_lr \\\n",
60 | " --text_encoder_lr=$text_encoder_lr \\\n",
61 | " --network_dim=$network_dim \\\n",
62 | " --network_alpha=$network_alpha \\\n",
63 | " --output_name=$output_name \\\n",
64 | " --lr_scheduler=$lr_scheduler \\\n",
65 | " --train_batch_size=$batch_size \\\n",
66 | " --save_every_n_epochs=$save_every_n_epochs \\\n",
67 | " --mixed_precision=\"fp16\" \\\n",
68 | " --save_precision=\"fp16\" \\\n",
69 | " --seed=\"1337\" \\\n",
70 | " --cache_latents \\\n",
71 | " --clip_skip=$clip_skip \\\n",
72 | " --prior_loss_weight=1 \\\n",
73 | " --max_token_length=225 \\\n",
74 | " --caption_extension=\".txt\" \\\n",
75 | " --save_model_as=$save_model_as \\\n",
76 | " --xformers --shuffle_caption --use_8bit_adam"
77 | ]
78 | }
79 | ],
80 | "metadata": {
81 | "kernelspec": {
82 | "display_name": "Python 3",
83 | "language": "python",
84 | "name": "python3"
85 | },
86 | "language_info": {
87 | "name": "python",
88 | "version": "3.10.7 (tags/v3.10.7:6cc6b13, Sep 5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]"
89 | },
90 | "orig_nbformat": 4,
91 | "vscode": {
92 | "interpreter": {
93 | "hash": "675b13e958f0d0236d13cdfe08a1df3882cae564fa23a2e7e5eb1f2c6c632b02"
94 | }
95 | }
96 | },
97 | "nbformat": 4,
98 | "nbformat_minor": 2
99 | }
100 |
--------------------------------------------------------------------------------
/train_by_toml.ps1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Akegarasu/lora-scripts/e0f5194815203093659d6ec280b9362b9792c070/train_by_toml.ps1
--------------------------------------------------------------------------------
/train_by_toml.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # LoRA train script by @Akegarasu
3 |
4 | config_file="./config/default.toml" # config file | 使用 toml 文件指定训练参数
5 | sample_prompts="./config/sample_prompts.txt" # prompt file for sample | 采样 prompts 文件, 留空则不启用采样功能
6 |
7 | sdxl=0 # train sdxl LoRA | 训练 SDXL LoRA
8 | multi_gpu=0 # multi gpu | 多显卡训练 该参数仅限在显卡数 >= 2 使用
9 |
10 | # ============= DO NOT MODIFY CONTENTS BELOW | 请勿修改下方内容 =====================
11 |
12 | export HF_HOME="huggingface"
13 | export TF_CPP_MIN_LOG_LEVEL=3
14 | export PYTHONUTF8=1
15 |
16 | extArgs=()
17 | launchArgs=()
18 |
19 | if [[ $multi_gpu == 1 ]]; then
20 | launchArgs+=("--multi_gpu")
21 | launchArgs+=("--num_processes=2")
22 | fi
23 |
24 | # run train
25 | if [[ $sdxl == 1 ]]; then
26 | script_name="sdxl_train_network.py"
27 | else
28 | script_name="train_network.py"
29 | fi
30 |
31 | python -m accelerate.commands.launch "${launchArgs[@]}" --num_cpu_threads_per_process=8 "./scripts/$script_name" \
32 | --config_file="$config_file" \
33 | --sample_prompts="$sample_prompts" \
34 | "${extArgs[@]}"
35 |
--------------------------------------------------------------------------------