├── module ├── __init__.py └── kohya_config_webui.py ├── requirements.txt ├── update.bat ├── run_webui.ps1 ├── install_webui.ps1 ├── install.py ├── scripts └── ui.py ├── .gitignore ├── README.md ├── LICENSE ├── kohya_train_webui.ipynb └── kohya_config_webui.ipynb /module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gradio>=3.24.1 2 | toml>=0.10.2 -------------------------------------------------------------------------------- /update.bat: -------------------------------------------------------------------------------- 1 | git reset --hard 2 | git pull 3 | pause -------------------------------------------------------------------------------- /run_webui.ps1: -------------------------------------------------------------------------------- 1 | .\venv\Scripts\activate 2 | python .\module\kohya_config_webui.py 3 | pause -------------------------------------------------------------------------------- /install_webui.ps1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WSH032/kohya-config-webui/HEAD/install_webui.ps1 -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import launch 2 | 3 | # TODO: early init 4 | 5 | # TODO: add deps here 6 | if not launch.is_installed("toml"): 7 | launch.run_pip("install toml") 8 | print("Installing toml...") 9 | 10 | -------------------------------------------------------------------------------- /scripts/ui.py: -------------------------------------------------------------------------------- 1 | from module.kohya_config_webui import create_demo 2 | from modules import script_callbacks 3 | 4 | def ui_tab(): 5 | return [(create_demo(), "kohya-config", "kohya_config_maker")] 6 | 7 | script_callbacks.on_ui_tabs(ui_tab) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | module/kohya_config_webui_save/ 131 | module/kohya_config/ 132 | module/kohya_config/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kohya-config-webui 2 | A WebUI for making config files used by kohya_sd_script 3 | ![9@`$94}$U~DCAQCEA }20AX](https://user-images.githubusercontent.com/126865849/232077304-cb04f8c4-e815-4de8-a5ec-e9116413c5e2.png) 4 | 5 | ## 现在我们有什么? 
6 | 目前,我使用gradio和toml库,为kohya-lora-dreambooth训练中较为常用、实用的训练参数编写了一个交互式的WebUI配置生成工具,可以在带有Python环境的Windows和Colab环境中快速部署。 7 | 8 | 使用本项目,你可以快速指定训练参数,并生成config_file.toml和sample_prompts.txt。 9 | 10 | 11 | 如果你觉得此项目有用,可以给我一颗小星星,非常感谢你⭐ 12 | 13 | --- 14 | 15 | | Notebook Name | Description | Link | Old-Version | 16 | | --- | --- | --- | --- | 17 | | [Colab_Lora_train](https://github.com/WSH032/lora-scripts/) | 基于[Akegarasu/lora-scripts](https://github.com/Akegarasu/lora-scripts)的定制化Colab notebook | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/WSH032/lora-scripts/blob/main/Colab_Lora_train.ipynb) | [![](https://img.shields.io/static/v1?message=Older%20Version&logo=googlecolab&labelColor=5c5c5c&color=e74c3c&label=%20&style=flat)](https://colab.research.google.com/drive/1_f0qJdM43BSssNJWtgjIlk9DkIzLPadx) | 18 | | [kohya_train_webui](https://github.com/WSH032/kohya-config-webui) `NEW` | 基于[WSH032/kohya-config-webui](https://github.com/WSH032/kohya-config-webui)的WebUI版Colab notebook | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/WSH032/kohya-config-webui/blob/main/kohya_train_webui.ipynb) | | 19 | 20 | 21 | ## 使用方法 22 | 23 | ### (一)Colab版本(带有完整训练环境): 24 | 25 | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/WSH032/kohya-config-webui/blob/main/kohya_train_webui.ipynb) 26 | 27 | ### (二)AUTOMATIC1111/stable-diffusion-webui插件: 28 | 29 | `https://github.com/WSH032/kohya-config-webui.git` 30 | 31 | 将这个仓库链接复制到SD-WebUI的`扩展`->`从网址安装`界面下载即可 32 | 33 | ### (三)直接下载: 34 | 35 | 运行以下代码,或直接从GitHub上下载zip并解压(直接下载将无法使用[update.bat](update.bat)): 36 | ```Shell 37 | git clone https://github.com/WSH032/kohya-config-webui.git 38 | ``` 39 | 运行目录下的[install_webui.ps1](install_webui.ps1)在虚拟环境中安装gradio和toml,或者 40 | ```Shell 41 | pip install "gradio>=3.24.1" 42 | pip install "toml>=0.10.2" 43 | ``` 44 | 运行[run_webui.ps1](run_webui.ps1),或者 45 | ```Shell 46 | python .\module\kohya_config_webui.py 47 | ``` 48 | 49 | 50 | # Todo 51 | - [x] 增加上一次参数保存功能 52 | - [ ] 增加dataset_config.toml生成功能 53 | 54 | # Credit 55 | 56 | **Based on the work of [kohya-ss](https://github.com/kohya-ss/sd-scripts), [Linaqruf](https://github.com/Linaqruf/kohya-trainer). Thanks to them.** 57 | - 这个项目使用的训练脚本来自[kohya-ss](https://github.com/kohya-ss/sd-scripts) 58 | - notebook中部分代码(如下载模块)来自[Linaqruf](https://github.com/Linaqruf/kohya-trainer) 59 | 60 | **Attention: It's called kohya-config-webui, but I don't have a license for kohya.
It just creates config files for kohya-ss.** 61 | - 这个插件叫kohya-config-webui,但是我没有kohya的授权,我只是为了说明它的作用是生成一个用于kohya-ss训练的config文件 62 | 63 | --- 64 | 65 | 上述两位作者和我目前采取的是Apache-2.0 license 66 | 67 | 如果你基于此项目进行了修改、引用等用途,请注意原作者的协议。 68 | 69 | 请在你使用的部分标明代码来源。 70 | 71 | # 其他部分展示 72 | ![0S2_@5B6638VJ 4%Y@TQDXK](https://user-images.githubusercontent.com/126865849/232079134-15154ccf-06ac-45a0-984f-244a6e8983f3.png) 73 | 74 | ![7$K@5833YHY(8T1 `3RE03L](https://user-images.githubusercontent.com/126865849/232079434-d471da6e-9e1d-457b-b635-4c37a838bf15.png) 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /kohya_train_webui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "private_outputs": true, 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "Led8SgZ0YnGB", 10 | "pYZtXvtmes2I", 11 | "bSug2SNMq9Hg" 12 | ] 13 | }, 14 | "kernelspec": { 15 | "name": "python3", 16 | "display_name": "Python 3" 17 | }, 18 | "language_info": { 19 | "name": "python" 20 | }, 21 | "gpuClass": "standard", 22 | "accelerator": "GPU" 23 | }, 24 | "cells": [ 25 | { 26 | "cell_type": "markdown", 27 | "source": [ 28 | "![visitors](https://visitor-badge.glitch.me/badge?page_id=wsh.kohya_train_webui) \n", 29 | "[![Visitors](https://api.visitorbadge.io/api/combined?path=wsh.kohya_train_webui&countColor=%232ccce4&style=flat&labelStyle=none)](https://visitorbadge.io/status?path=wsh.kohya_train_webui)\n", 30 | "[![GitHub Repo stars](https://img.shields.io/github/stars/WSH032/kohya-config-webui?style=social)](https://github.com/WSH032/kohya-config-webui)\n", 31 | "\n", 32 | "| Notebook Name | Description | Link | Old-Version |\n", 33 | "| --- | --- | --- | --- |\n", 34 | "| [Colab_Lora_train](https://github.com/WSH032/lora-scripts/) | 基于[Akegarasu/lora-scripts](https://github.com/Akegarasu/lora-scripts)的定制化Colab notebook | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/WSH032/lora-scripts/blob/main/Colab_Lora_train.ipynb) | [![](https://img.shields.io/static/v1?message=Older%20Version&logo=googlecolab&labelColor=5c5c5c&color=e74c3c&label=%20&style=flat)](https://colab.research.google.com/drive/1_f0qJdM43BSssNJWtgjIlk9DkIzLPadx) | \n", 35 | "| [kohya_train_webui](https://github.com/WSH032/kohya-config-webui) `NEW` | 基于[WSH032/kohya-config-webui](https://github.com/WSH032/kohya-config-webui)的WebUI版Colab notebook | [![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=flat)](https://colab.research.google.com/github/WSH032/kohya-config-webui/blob/main/kohya_train_webui.ipynb) |\n", 36 | "\n", 37 | "如果你觉得此项目有用,可以去 [![GitHub Repo stars](https://img.shields.io/github/stars/WSH032/kohya-config-webui?style=social)](https://github.com/WSH032/kohya-config-webui) 点一颗小星星,非常感谢你⭐\n", 38 | "\n", 39 | "---\n", 40 | "\n", 41 | "- [📚notebook的操作手册](https://www.bilibili.com/read/cv23401664)\n", 42 | "\n", 43 | "- 参数:\n", 44 | "\n", 45 | " - [🥶冷门而有用的参数](https://www.bilibili.com/video/BV1mo4y1t7Zu/)\n", 46 | " - [🆕新版参数](https://www.bilibili.com/video/BV13s4y1377X/)\n", 47 | "\n", 48 | "---\n", 49 | "\n", 50 | "Based on the work of [kohya-ss](https://github.com/kohya-ss/sd-scripts) and [Linaqruf](https://github.com/Linaqruf/kohya-trainer)\n", 51 | "\n", 52 | "WebUI from [WSH032](https://github.com/WSH032/kohya-config-webui)\n" 53 | ], 54 | "metadata": { 55 | "id": "ll6PRAEKIfjT" 56 | } 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "source": [ 61 | "# 更新日志\n", 62 | "\n", 63 | "> 2023年5月24日:适配colab的torch==2.0.1\n", 64 | "> \n", 65 | "> 内容:适合torch==2.0.1的xformer以发布,不再强制安装torch==2.0.0,依赖安装时间恢复至2分40秒\n" 66 | ], 67 | "metadata": { 68 | "id": "TO3SVJYKmQhT" 69 | } 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "source": [ 74 | "# (一)环境配置" 75 | ], 76 | "metadata": { 77 | "id": "Led8SgZ0YnGB" 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 
82 | "source": [ 83 | "#@title ##初始化常量与挂载谷歌硬盘(只要重启过colab就要再运行一次)\n", 84 | "\n", 85 | "#@markdown 是否挂载谷歌硬盘(推荐)\n", 86 | "use_google_drive = True #@param {type:\"boolean\"}\n", 87 | "\n", 88 | "import os\n", 89 | "import shutil\n", 90 | "import sys\n", 91 | "from google.colab import drive\n", 92 | "\n", 93 | "\n", 94 | "ROOT_DIR = os.getcwd() #获取根目录\n", 95 | "\n", 96 | "SD_SCRIPTS_DIR = os.path.join( ROOT_DIR, \"sd-scripts\" ) #kohya库克隆路径\n", 97 | "WEBUI_DIR = os.path.join( ROOT_DIR, \"kohya-config-webui\" ) #webui库克隆路径\n", 98 | "\n", 99 | "#TRAIN_DATA_DIR = os.path.join( ROOT_DIR, \"Lora\", \"input\" ) #拷贝后训练材料路径\n", 100 | "#REG_DATA_DIR = os.path.join( ROOT_DIR, \"Lora\", \"reg\" ) #拷贝后正则化材料路径\n", 101 | "\n", 102 | "SD_MODEL_DIR = os.path.join( ROOT_DIR, \"Lora\", \"sd_model\" ) #SD模型下载地址\n", 103 | "VAE_MODEL_DIR = os.path.join( ROOT_DIR, \"Lora\", \"vae_model\" ) #VAE模型下载地址\n", 104 | "\n", 105 | "DEFAULT_COLAB_INPUT_DIR = os.path.normpath(\"/content/drive/MyDrive/Lora/input\") #默认Colab训练集地址\n", 106 | "DEFAULT_COLAB_REG_DIR = os.path.normpath(\"/content/drive/MyDrive/Lora/reg\") #默认Colab正则化地址\n", 107 | "DEFAULT_COLAB_OUPUT_DIR = os.path.normpath(\"/content/drive/MyDrive/Lora/output\") #默认Colab模型输出地址\n", 108 | "DEFAULT_COLAB_WEBUI_SAVE_DIR = os.path.normpath(\"/content/drive/MyDrive/Lora/kohya_config_webui_save\") #默认Colab保存webui参数文件地址\n", 109 | "\n", 110 | "ACCELERATE_CONFIG_PATH = os.path.join( ROOT_DIR, \"accelerate_config.yaml\" ) #accelerate库config文件写入地址\n", 111 | "\n", 112 | "\n", 113 | "#@title ##挂载谷歌硬盘\n", 114 | "\n", 115 | "if use_google_drive:\n", 116 | " if not os.path.exists(\"/content/drive\"):\n", 117 | " drive.mount(\"/content/drive\")\n", 118 | "\n", 119 | "!nvidia-smi\n", 120 | "\n", 121 | "#训练用环境变量\n", 122 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", 123 | "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\" \n", 124 | "os.environ[\"SAFETENSORS_FAST_GPU\"] = \"1\"" 125 | ], 126 | "metadata": { 127 | "id": "lcFMoxnjwCDV", 128 | "cellView": "form" 129 | }, 130 | "execution_count": null, 131 | "outputs": [] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "source": [ 136 | "#@title ##克隆github的库、安装依赖\n", 137 | "os.chdir( ROOT_DIR )\n", 138 | "!git clone https://github.com/kohya-ss/sd-scripts.git {SD_SCRIPTS_DIR}\n", 139 | "#@title 克隆我的库 \n", 140 | "!git clone https://github.com/WSH032/kohya-config-webui.git {WEBUI_DIR}\n", 141 | "\n", 142 | "#安装torch\n", 143 | "print(f\"torch安装中\")\n", 144 | "!pip -q install torch torchvision xformers triton\n", 145 | "print(f\"torch安装完成\")\n", 146 | "\n", 147 | "#安装kohya依赖\n", 148 | "print(f\"kohya依赖安装中\")\n", 149 | "os.chdir(SD_SCRIPTS_DIR)\n", 150 | "!pip -q install -r requirements.txt\n", 151 | "os.chdir(ROOT_DIR)\n", 152 | "print(f\"kohya依赖安装完成\")\n", 153 | "\n", 154 | "#安装lion优化器、Dadaption优化器、lycoris\n", 155 | "print(f\"lion优化器、Dadaption优化器、lycoris安装中\")\n", 156 | "!pip -q install --upgrade lion-pytorch dadaptation lycoris-lora\n", 157 | "print(f\"lion优化器、Dadaption优化器、lycoris安装完成\")\n", 158 | "\n", 159 | "#安装wandb\n", 160 | "print(f\"wandb安装中\")\n", 161 | "!pip -q install wandb\n", 162 | "print(f\"wandb安装中\")\n", 163 | "\n", 164 | "#安装webui依赖\n", 165 | "print(f\"webui依赖安装中\")\n", 166 | "os.chdir(WEBUI_DIR)\n", 167 | "!pip -q install -r requirements.txt\n", 168 | "os.chdir(ROOT_DIR)\n", 169 | "print(f\"webui依赖安装完成\")\n", 170 | "\n", 171 | "#安装功能性依赖\n", 172 | "!apt -q install aria2\n", 173 | "!pip -q install portpicker\n", 174 | "\n", 175 | "\n", 176 | "import torch\n", 177 | "print(\"当前torch版本\",torch.__version__)\n", 178 | "import 
torchvision\n", 179 | "print(\"当前torchvision版本\",torchvision.__version__)\n", 180 | "import triton\n", 181 | "print(\"当前triton版本\", triton.__version__)\n", 182 | "\n", 183 | "!python -V" 184 | ], 185 | "metadata": { 186 | "cellView": "form", 187 | "id": "4wYnLUrYY6Ut" 188 | }, 189 | "execution_count": null, 190 | "outputs": [] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "source": [ 195 | "#@title ## 下载模型, 可以同时选多个模型下载,到时候是在WebUI里选(原始代码来源于:[Linaqruf](https://github.com/Linaqruf/kohya-trainer))\n", 196 | "installModels = []\n", 197 | "installv2Models = []\n", 198 | "\n", 199 | "#@markdown **预设底模**\n", 200 | "\n", 201 | "#@markdown SD1.x model\n", 202 | "modelName = \"Animefull-final-pruned.ckpt\" # @param [\"\", \"Animefull-final-pruned.ckpt\", \"Anything-v3-1.safetensors\", \"AnyLoRA.safetensors\", \"AnimePastelDream.safetensors\", \"Chillout-mix.safetensors\", \"OpenJourney-v4.ckpt\", \"Stable-Diffusion-v1-5.safetensors\"]\n", 203 | "#@markdown SD2.x model `这些为SD2.x模型,训练时请开启v2选项`\n", 204 | "v2ModelName = \"\" # @param [\"\", \"stable-diffusion-2-1-base.safetensors\", \"stable-diffusion-2-1-768v.safetensors\", \"plat-diffusion-v1-3-1.safetensors\", \"replicant-v1.safetensors\", \"illuminati-diffusion-v1-0.safetensors\", \"illuminati-diffusion-v1-1.safetensors\", \"waifu-diffusion-1-4-anime-e2.ckpt\", \"waifu-diffusion-1-5-e2.safetensors\", \"waifu-diffusion-1-5-e2-aesthetic.safetensors\"]\n", 205 | "\n", 206 | "#@markdown **自定义模型(不能超过5G)URL例如**`https://huggingface.co/a1079602570/animefull-final-pruned/resolve/main/novelailatest-pruned.ckpt`\n", 207 | "\n", 208 | "base_model_url = \"\" #@param {type:\"string\"}\n", 209 | "\n", 210 | "#@markdown **或者自定义模型(不能超过5G)路径例如**`/content/drive/MyDrive/Lora/model/your_model.ckpt`\n", 211 | "\n", 212 | "base_model_self_path = \"\" #@param {type:\"string\"}\n", 213 | "\n", 214 | "\n", 215 | "def get_sd_model():\n", 216 | " modelUrl = [\n", 217 | " \"\",\n", 218 | " \"https://huggingface.co/Linaqruf/personal-backup/resolve/main/models/animefull-final-pruned.ckpt\",\n", 219 | " \"https://huggingface.co/cag/anything-v3-1/resolve/main/anything-v3-1.safetensors\",\n", 220 | " \"https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16.safetensors\",\n", 221 | " \"https://huggingface.co/Lykon/AnimePastelDream/resolve/main/AnimePastelDream_Soft_noVae_fp16.safetensors\",\n", 222 | " \"https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/chillout_mix-pruned.safetensors\",\n", 223 | " \"https://huggingface.co/prompthero/openjourney-v4/resolve/main/openjourney-v4.ckpt\",\n", 224 | " \"https://huggingface.co/Linaqruf/stolen/resolve/main/pruned-models/stable_diffusion_1_5-pruned.safetensors\",\n", 225 | " ]\n", 226 | " modelList = [\n", 227 | " \"\",\n", 228 | " \"Animefull-final-pruned.ckpt\",\n", 229 | " \"Anything-v3-1.safetensors\",\n", 230 | " \"AnyLoRA.safetensors\",\n", 231 | " \"AnimePastelDream.safetensors\", \n", 232 | " \"Chillout-mix.safetensors\",\n", 233 | " \"OpenJourney-v4.ckpt\",\n", 234 | " \"Stable-Diffusion-v1-5.safetensors\",\n", 235 | " ]\n", 236 | " v2ModelUrl = [\n", 237 | " \"\",\n", 238 | " \"https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors\",\n", 239 | " \"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors\",\n", 240 | " \"https://huggingface.co/p1atdev/pd-archive/resolve/main/plat-v1-3-1.safetensors\",\n", 241 | " \"https://huggingface.co/gsdf/Replicant-V1.0/resolve/main/Replicant-V1.0.safetensors\",\n", 
242 | " \"https://huggingface.co/IlluminatiAI/Illuminati_Diffusion_v1.0/resolve/main/illuminati_diffusion_v1.0.safetensors\",\n", 243 | " \"https://huggingface.co/4eJIoBek/Illuminati-Diffusion-v1-1/resolve/main/illuminatiDiffusionV1_v11.safetensors\",\n", 244 | " \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e2.ckpt\",\n", 245 | " \"https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-fp32.safetensors\",\n", 246 | " \"https://huggingface.co/waifu-diffusion/wd-1-5-beta2/resolve/main/checkpoints/wd-1-5-beta2-aesthetic-fp32.safetensors\",\n", 247 | " ]\n", 248 | " v2ModelList = [\n", 249 | " \"\",\n", 250 | " \"stable-diffusion-2-1-base.safetensors\",\n", 251 | " \"stable-diffusion-2-1-768v.safetensors\",\n", 252 | " \"plat-diffusion-v1-3-1.safetensors\",\n", 253 | " \"replicant-v1.safetensors\",\n", 254 | " \"illuminati-diffusion-v1-0.safetensors\",\n", 255 | " \"illuminati-diffusion-v1-1.safetensors\",\n", 256 | " \"waifu-diffusion-1-4-anime-e2.ckpt\",\n", 257 | " \"waifu-diffusion-1-5-e2.safetensors\",\n", 258 | " \"waifu-diffusion-1-5-e2-aesthetic.safetensors\",\n", 259 | " ]\n", 260 | " if modelName:\n", 261 | " installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n", 262 | " if v2ModelName:\n", 263 | " installv2Models.append((v2ModelName, v2ModelUrl[v2ModelList.index(v2ModelName)]))\n", 264 | "\n", 265 | "\n", 266 | " #下载模型\n", 267 | " def install(checkpoint_name, url):\n", 268 | " hf_token = \"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\"\n", 269 | " user_header = f'\"Authorization: Bearer {hf_token}\"'\n", 270 | " print(checkpoint_name)\n", 271 | " print(url)\n", 272 | " !aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {SD_MODEL_DIR} -o {checkpoint_name} {url}\n", 273 | " def install_checkpoint():\n", 274 | " for model in installModels:\n", 275 | " install(model[0], model[1])\n", 276 | " for v2model in installv2Models:\n", 277 | " install(v2model[0], v2model[1])\n", 278 | "\n", 279 | " #下载预设模型\n", 280 | " install_checkpoint()\n", 281 | "\n", 282 | " #自定义链接不留空,则尝试下载\n", 283 | " if base_model_url:\n", 284 | " #!aria2c --content-disposition-default-utf8=true --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {SD_MODEL_DIR} {base_model_url}\n", 285 | " !wget {base_model_url} -P {SD_MODEL_DIR} -N\n", 286 | "\n", 287 | " #自定义路径不留空,则尝试拷贝\n", 288 | " if base_model_self_path:\n", 289 | " try:\n", 290 | " base_model_copy_path = os.path.join( SD_MODEL_DIR, os.path.basename(base_model_self_path) )\n", 291 | " shutil.copyfile(base_model_self_path, base_model_copy_path)\n", 292 | " print(f\"拷贝自定义底模成功, {base_model_self_path}被拷贝至{base_model_copy_path}\")\n", 293 | " except Exception as e:\n", 294 | " print(f\"拷贝自定义底模时发生错误, Error: {e}\")\n", 295 | "\n", 296 | "get_sd_model()\n", 297 | "\n", 298 | "\n", 299 | "#@markdown **(可选)选择一个Vae下载**`\"animevae.pt\", \"kl-f8-anime.ckpt\", \"vae-ft-mse-840000-ema-pruned.ckpt\"`\n", 300 | "\n", 301 | "vaeName = \"\" # @param [\"\", \"anime.vae.pt\", \"waifudiffusion.vae.pt\", \"stablediffusion.vae.pt\"]\n", 302 | "\n", 303 | "def get_vae_model():\n", 304 | "\n", 305 | " installVae = []\n", 306 | "\n", 307 | " vaeUrl = [\n", 308 | " \"\",\n", 309 | " \"https://huggingface.co/Linaqruf/personal-backup/resolve/main/vae/animevae.pt\",\n", 310 | " \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n", 311 | " 
\"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n", 312 | " ]\n", 313 | " vaeList = [\"\", \"anime.vae.pt\", \"waifudiffusion.vae.pt\", \"stablediffusion.vae.pt\"]\n", 314 | "\n", 315 | " installVae.append((vaeName, vaeUrl[vaeList.index(vaeName)]))\n", 316 | "\n", 317 | " #开始下载\n", 318 | " def install(vae_name, url):\n", 319 | " hf_token = \"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\"\n", 320 | " user_header = f'\"Authorization: Bearer {hf_token}\"'\n", 321 | " print(vae_name)\n", 322 | " print(url)\n", 323 | " !aria2c --console-log-level=error --allow-overwrite --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 -d {VAE_MODEL_DIR} -o {vae_name} \"{url}\"\n", 324 | "\n", 325 | " def install_vae():\n", 326 | " if vaeName:\n", 327 | " for vae in installVae:\n", 328 | " install(vae[0], vae[1])\n", 329 | " else:\n", 330 | " pass\n", 331 | " install_vae()\n", 332 | "\n", 333 | "get_vae_model()" 334 | ], 335 | "metadata": { 336 | "cellView": "form", 337 | "id": "OYGUN309MuUJ" 338 | }, 339 | "execution_count": null, 340 | "outputs": [] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "source": [ 345 | "# (二)训练参数" 346 | ], 347 | "metadata": { 348 | "id": "pYZtXvtmes2I" 349 | } 350 | }, 351 | { 352 | "cell_type": "code", 353 | "source": [ 354 | "#@title ##启动WebUI来设置参数\n", 355 | "\n", 356 | "#@markdown - 在谷歌硬盘的`/content/drive/MyDrive/Lora/kohya_config_webui_save`会生成一个`colab.toml`,在WebUI里读取它,会帮你完成默认参数设置。\n", 357 | "#@markdown - 读取的时候会提示参数找不到,这是正常的\n", 358 | "#@markdown - 设置好参数后可以保存`(默认会保存到你的谷歌硬盘)`,以后读取你保存的配置文件就行\n", 359 | "#@markdown - 保存toml配置文件时候不要用`colab.toml`这个名字,会被覆盖掉\n", 360 | "\n", 361 | "#@markdown - 在colab里要开`lowram`,不然很多模型载入不了,读取`colab.toml`的时候会自动帮你开启\n", 362 | "\n", 363 | "#@markdown ---\n", 364 | "\n", 365 | "#@markdown 是否在colab里打开webui`不勾选就输出一个链接,点击后在另一个网页操作,反正我喜欢不勾选`\n", 366 | "in_colab = False #@param {type:\"boolean\"}\n", 367 | "\n", 368 | "#@markdown 是否使用gradio的远程分享及队列功能\n", 369 | "use_queue = False #@param {type:\"boolean\"}\n", 370 | "\n", 371 | "#生成一个colab默认toml文件\n", 372 | "def creat_save_toml(save_dir):\n", 373 | " \"\"\"生成适用于Colab的webui参数保存文件colab.toml\"\"\"\n", 374 | " import toml\n", 375 | " #写入路径\n", 376 | " other={\"write_files_dir\":SD_SCRIPTS_DIR}\n", 377 | " #材料、模型、输出路径\n", 378 | " param={\n", 379 | " \"train_data_dir\":DEFAULT_COLAB_INPUT_DIR,\n", 380 | " \"reg_data_dir\":DEFAULT_COLAB_REG_DIR,\n", 381 | " \"base_model_dir\":SD_MODEL_DIR,\n", 382 | " \"vae_model_dir\":VAE_MODEL_DIR,\n", 383 | " \"output_dir\":DEFAULT_COLAB_OUPUT_DIR,\n", 384 | " \"lowram\":True,\n", 385 | " }\n", 386 | "\n", 387 | " save_dict = {\"other\":other, \"param\":param}\n", 388 | " #写入文件\n", 389 | " save_name = \"colab.toml\"\n", 390 | " save_path = os.path.join( save_dir, save_name )\n", 391 | " os.makedirs(save_dir, exist_ok=True)\n", 392 | " with open(save_path, \"w\", encoding=\"utf-8\") as f:\n", 393 | " f.write( toml.dumps(save_dict) )\n", 394 | "\n", 395 | "creat_save_toml(DEFAULT_COLAB_WEBUI_SAVE_DIR)\n", 396 | "\n", 397 | "#导入并生成demo\n", 398 | "launch_param = [f\"--save_dir={DEFAULT_COLAB_WEBUI_SAVE_DIR}\",\n", 399 | " f\"--save_name=kohya_config_webui_save.toml\",\n", 400 | " f\"--read_dir={DEFAULT_COLAB_WEBUI_SAVE_DIR}\"\n", 401 | "]\n", 402 | "os.chdir( os.path.join(WEBUI_DIR, \"module\") )\n", 403 | "from kohya_config_webui import create_demo\n", 404 | "os.chdir(ROOT_DIR)\n", 405 | "demo = create_demo(launch_param)\n", 406 | "\n", 407 | "#找一个空闲端口\n", 408 | "import portpicker\n", 409 | "port = 
portpicker.pick_unused_port()\n", 410 | "#启动\n", 411 | "if not use_queue:\n", 412 | " demo.launch(server_port=port, inbrowser=False, inline=False)\n", 413 | " #暴露端口\n", 414 | " from google.colab import output\n", 415 | " output.serve_kernel_port_as_window(port)\n", 416 | " #是否在Colab里打开\n", 417 | " if in_colab:\n", 418 | " output.serve_kernel_port_as_iframe(port)\n", 419 | "else:\n", 420 | " demo.queue().launch(server_port=port, inline=in_colab)\n", 421 | "\n" 422 | ], 423 | "metadata": { 424 | "cellView": "form", 425 | "id": "AlRW5ufPM-0a" 426 | }, 427 | "execution_count": null, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "source": [ 433 | "#@title ### 开始训练\n", 434 | "\n", 435 | "#@markdown 若正确运行,训练完成后,模型会自动保存至你在WebUI里设置的地址\n", 436 | "\n", 437 | "#@markdown 默认训练配置文件在 `/content/sd-scripts/config_file.toml`\n", 438 | "\n", 439 | "#@markdown 默认采样参数文件在 `/content/sd-scripts/sample_prompts.txt`\n", 440 | "\n", 441 | "#@markdown ---\n", 442 | "\n", 443 | "#@markdown 如果你想用自己的配置文件,或者采样文件,请填入下方 `填入意味着启用`\n", 444 | "\n", 445 | "config_file_self_path = \"\" #@param {type:\"string\"}\n", 446 | "\n", 447 | "sample_prompts_self_path = \"\" #@param {type:\"string\"}\n", 448 | "\n", 449 | "os.chdir(ROOT_DIR)\n", 450 | "\n", 451 | "from accelerate.utils import write_basic_config\n", 452 | "if not os.path.exists(ACCELERATE_CONFIG_PATH):\n", 453 | " write_basic_config(save_location=ACCELERATE_CONFIG_PATH)\n", 454 | "\n", 455 | "\n", 456 | "\n", 457 | "os.chdir(SD_SCRIPTS_DIR)\n", 458 | "\n", 459 | "#开始训练!\n", 460 | "!accelerate launch --config_file={ACCELERATE_CONFIG_PATH} --num_cpu_threads_per_process=8 train_network.py\\\n", 461 | " --config_file={config_file_self_path if config_file_self_path else \"config_file.toml\"}\\\n", 462 | " --sample_prompts={sample_prompts_self_path if sample_prompts_self_path else \"sample_prompts.txt\"}\n", 463 | "\n", 464 | "os.chdir(ROOT_DIR)" 465 | ], 466 | "metadata": { 467 | "cellView": "form", 468 | "id": "NTRgMI7jR3DY" 469 | }, 470 | "execution_count": null, 471 | "outputs": [] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "source": [ 476 | "# (三)开发代码`别碰`" 477 | ], 478 | "metadata": { 479 | "id": "bSug2SNMq9Hg" 480 | } 481 | }, 482 | { 483 | "cell_type": "code", 484 | "source": [ 485 | "#@title linaqfuf优化代码\n", 486 | "\n", 487 | "!sed -i \"s@cpu@cuda@\" /content/sd-scripts/library/model_util.py\n", 488 | "\n", 489 | "import zipfile\n", 490 | "def ubuntu_deps(url, name, dst):\n", 491 | " !wget --show-progress {url}\n", 492 | " with zipfile.ZipFile(name, \"r\") as deps:\n", 493 | " deps.extractall(dst)\n", 494 | " !dpkg -i {dst}/*\n", 495 | " os.remove(name)\n", 496 | " shutil.rmtree(dst)\n", 497 | "deps_dir = \"/conent/dep\"\n", 498 | "ubuntu_deps(\n", 499 | " \"https://huggingface.co/Linaqruf/fast-repo/resolve/main/deb-libs.zip\",\n", 500 | " \"deb-libs.zip\",\n", 501 | " deps_dir,\n", 502 | ")\n", 503 | "\n", 504 | "!apt -y update\n", 505 | "!apt install libunwind8-dev\n", 506 | "\n", 507 | "os.environ[\"LD_PRELOAD\"] = \"libtcmalloc.so\"\n", 508 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", 509 | "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\" \n", 510 | "os.environ[\"SAFETENSORS_FAST_GPU\"] = \"1\"\n", 511 | "\n", 512 | "cuda_path = \"/usr/local/cuda-11.8/targets/x86_64-linux/lib/\"\n", 513 | "ld_library_path = os.environ.get(\"LD_LIBRARY_PATH\", \"\")\n", 514 | "os.environ[\"LD_LIBRARY_PATH\"] = f\"{ld_library_path}:{cuda_path}\"" 515 | ], 516 | "metadata": { 517 | "cellView": "form", 518 | "id": "ps6GgFwVaqXx" 519 | }, 520 | 
"execution_count": null, 521 | "outputs": [] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "source": [ 526 | "#@title ##拷贝材料(支持重复训练时选择新的路径)\n", 527 | "\n", 528 | "#@markdown 训练集路径,正则化集路径(正则化留空则不拷贝)\n", 529 | "\n", 530 | "#@markdown `教程默认路径:`\n", 531 | "\n", 532 | "#@markdown `训练集:/content/drive/MyDrive/Lora/input/`\n", 533 | "\n", 534 | "#@markdown `正则化:/content/drive/MyDrive/Lora/reg/`\n", 535 | "\n", 536 | "train_data_dir_self = \"/content/drive/MyDrive/Lora/input/blue_archive\" #@param {type:'string'}\n", 537 | "reg_data_dir_self = \"\" #@param {type:'string'}\n", 538 | "\n", 539 | "\n", 540 | "def copy_data_and_reg(data_dir: str, reg_dir: str = \"\"):\n", 541 | " \"\"\"\n", 542 | " 将材料拷贝至TRAIN_DATA_DIR和REG_DATA_DIR\n", 543 | " 拷贝前会删除之前材料\n", 544 | " data_dir为训练集,必填; reg_dir,默认为空,不填则不拷贝\n", 545 | " \"\"\"\n", 546 | " #训练集路径为空直接退出\n", 547 | " if not data_dir:\n", 548 | " print(f\"训练集路径为空\")\n", 549 | " return\n", 550 | "\n", 551 | " #已经存在拷贝材料则删除\n", 552 | " def rm_dir(dir):\n", 553 | " if os.path.exists(dir):\n", 554 | " shutil.rmtree(dir)\n", 555 | " rm_dir(TRAIN_DATA_DIR)\n", 556 | " rm_dir(REG_DATA_DIR)\n", 557 | "\n", 558 | " #拷贝材料\n", 559 | " def cp_dir(from_dir, to_dir, name):\n", 560 | " print(f\"拷贝{name}中\")\n", 561 | " try:\n", 562 | " shutil.copytree(from_dir, to_dir, dirs_exist_ok=True)\n", 563 | " print(f\"{name}拷贝成功, {from_dir}被拷贝至{to_dir}\")\n", 564 | " except Exception as e:\n", 565 | " print(f\"拷贝{name}时发生错误, Error: {e}\")\n", 566 | "\n", 567 | " cp_dir(data_dir, TRAIN_DATA_DIR, \"训练集\")\n", 568 | " if reg_dir:\n", 569 | " cp_dir(reg_dir, REG_DATA_DIR, \"训练集\")\n", 570 | " else:\n", 571 | " print(f\"不拷贝正则化\")\n", 572 | "\n", 573 | "copy_data_and_reg(train_data_dir_self, reg_data_dir_self)\n" 574 | ], 575 | "metadata": { 576 | "id": "EzHADOPZMlN4", 577 | "cellView": "form" 578 | }, 579 | "execution_count": null, 580 | "outputs": [] 581 | } 582 | ] 583 | } 584 | -------------------------------------------------------------------------------- /kohya_config_webui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "source": [ 16 | "| ![visitors](https://visitor-badge.glitch.me/badge?page_id=wsh.kohya_config_webui) | [![GitHub Repo stars](https://img.shields.io/github/stars/WSH032/kohya-config-webui?style=social)](https://github.com/WSH032/kohya-config-webui) |\n", 17 | "\n", 18 | "#A WebUI for making config files used by kohya_sd_script\n", 19 | "\n", 20 | "Created by [WSH](https://space.bilibili.com/8417436)" 21 | ], 22 | "metadata": { 23 | "id": "7aje0w7w2qsc" 24 | } 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": { 30 | "cellView": "form", 31 | "id": "ec1xuZuQmvTg" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "#@title 安装依赖\n", 36 | "!pip install gradio > /dev/null 2>&1" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 39, 42 | "metadata": { 43 | "id": "oa6oEre6KC_B", 44 | "cellView": "form" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "#@title 函数部分\n", 49 | "\n", 50 | "#A WebUI for making config files used by kohya_sd_script\n", 51 | "\n", 52 | "#Created by [WSH](https://space.bilibili.com/8417436)\n", 53 | "#[![GitHub Repo 
stars](https://img.shields.io/github/stars/WSH032/kohya-config-webui?style=social)](https://github.com/WSH032/kohya-config-webui)\n", 54 | "\n", 55 | "import os\n", 56 | "import toml\n", 57 | "import warnings\n", 58 | "import gradio as gr\n", 59 | "\n", 60 | "common_parameter_dict_key_list=[]\n", 61 | "sample_parameter_dict_key_list=[]\n", 62 | "plus_parameter_dict_key_list=[]\n", 63 | "all_parameter_dict_key_list=[] #后面会有一次all_parameter_dict_key_list = common_parameter_dict_key_list + sample_parameter_dict_key_list + plus_parameter_dict_key_list\n", 64 | " \n", 65 | "\n", 66 | "common_parameter_dict=({})\n", 67 | "sample_parameter_dict=({})\n", 68 | "plus_parameter_dict=({})\n", 69 | "\n", 70 | "common_confirm_flag = False #必须要确认常规参数一次才允许写入toml\n", 71 | "\n", 72 | "parameter_len_dict={\"common\":0, \"sample\":0, \"plus\":0}\n", 73 | "\n", 74 | "random_symbol = '\\U0001f3b2\\ufe0f' # 🎲️\n", 75 | "reuse_symbol = '\\u267b\\ufe0f' # ♻️\n", 76 | "paste_symbol = '\\u2199\\ufe0f' # ↙\n", 77 | "refresh_symbol = '\\U0001f504' # 🔄\n", 78 | "save_style_symbol = '\\U0001f4be' # 💾\n", 79 | "apply_style_symbol = '\\U0001f4cb' # 📋\n", 80 | "clear_prompt_symbol = '\\U0001f5d1\\ufe0f' # 🗑️\n", 81 | "extra_networks_symbol = '\\U0001F3B4' # 🎴\n", 82 | "switch_values_symbol = '\\U000021C5' # ⇅\n", 83 | "folder_symbol = '\\U0001f4c2' # 📂\n", 84 | "\n", 85 | "\n", 86 | "def check_len_and_2dict(args, parameter_len_dict_value, parameter_dict_key_list, func_name=\"\"):\n", 87 | " if len(args) != parameter_len_dict_value:\n", 88 | " warnings.warn(f\"传入{func_name}的参数长度不匹配\", UserWarning)\n", 89 | " if len(parameter_dict_key_list) != parameter_len_dict_value:\n", 90 | " warnings.warn(f\" {func_name}内部字典赋值关键字列表的长度不匹配\", UserWarning)\n", 91 | " parameter_dict = dict(zip(parameter_dict_key_list, args))\n", 92 | " return parameter_dict\n", 93 | "\n", 94 | "def common_parameter_get(*args):\n", 95 | " global common_parameter_dict, common_confirm_flag\n", 96 | " common_confirm_flag = True #必须要确认常规参数一次才允许写入toml\n", 97 | " common_parameter_dict = check_len_and_2dict(args, parameter_len_dict[\"common\"], common_parameter_dict_key_list, func_name=\"common_parameter_get\")\n", 98 | " common_parameter_toml = toml.dumps(common_parameter_dict)\n", 99 | " common_parameter_title = \"基础参数配置确认\"\n", 100 | " return common_parameter_toml, common_parameter_title\n", 101 | "\n", 102 | "def sample_parameter_get(*args):\n", 103 | " global sample_parameter_dict\n", 104 | " sample_parameter_dict = check_len_and_2dict(args, parameter_len_dict[\"sample\"], sample_parameter_dict_key_list, func_name=\"sample_parameter_get\")\n", 105 | " sample_parameter_toml = toml.dumps(sample_parameter_dict)\n", 106 | " sample_parameter_title = \"采样配置确认\"\n", 107 | " return sample_parameter_toml, sample_parameter_title\n", 108 | "\n", 109 | "\n", 110 | "def plus_parameter_get(*args):\n", 111 | " global plus_parameter_dict\n", 112 | " plus_parameter_dict = check_len_and_2dict(args, parameter_len_dict[\"plus\"], plus_parameter_dict_key_list, func_name=\"plus_parameter_get\")\n", 113 | " plus_parameter_toml = toml.dumps(plus_parameter_dict)\n", 114 | " plus_parameter_title = \"进阶参数配置确认\"\n", 115 | " return plus_parameter_toml, plus_parameter_title\n", 116 | "\n", 117 | "\n", 118 | "def all_parameter_get(*args):\n", 119 | " if len(args) != sum( parameter_len_dict.values() ):\n", 120 | " warnings.warn(f\"传入all_parameter_get的参数长度不匹配\", UserWarning)\n", 121 | " common_parameter_toml, common_parameter_title = common_parameter_get( *args[ : 
parameter_len_dict[\"common\"] ] )\n", 122 | " sample_parameter_toml, sample_parameter_title = sample_parameter_get( *args[ parameter_len_dict[\"common\"] : parameter_len_dict[\"common\"] + parameter_len_dict[\"sample\"] ] )\n", 123 | " plus_parameter_toml, plus_parameter_title = plus_parameter_get( *args[ -parameter_len_dict[\"plus\"] : ] )\n", 124 | " return common_parameter_toml, sample_parameter_toml, plus_parameter_toml, \"全部参数确认\"\n", 125 | "\n", 126 | " \n", 127 | "def save_webui_config(save_webui_config_dir, save_webui_config_name, write_files_dir):\n", 128 | " os.makedirs(save_webui_config_dir, exist_ok=True)\n", 129 | " \n", 130 | " other = {\"write_files_dir\":write_files_dir}\n", 131 | " param = {**common_parameter_dict, **sample_parameter_dict, **plus_parameter_dict}\n", 132 | " dict = { \"other\":other, \"param\":param }\n", 133 | "\n", 134 | " save_webui_config_path = os.path.join(save_webui_config_dir, save_webui_config_name)\n", 135 | " with open(save_webui_config_path, \"w\", encoding=\"utf-8\") as f:\n", 136 | " webui_config_str = toml.dumps( dict )\n", 137 | " f.write(webui_config_str)\n", 138 | " return f\"保存webui配置成功,文件在{save_webui_config_path}\"\n", 139 | "\n", 140 | "def read_webui_config_get(read_webui_config_dir):\n", 141 | " try:\n", 142 | " files = [f for f in os.listdir(read_webui_config_dir) if f.endswith(\".toml\") ]\n", 143 | " if files:\n", 144 | " return gr.update( choices=files,value=files[0] )\n", 145 | " else:\n", 146 | " return gr.update( choices=[],value=\"没有找到webui配置文件\" )\n", 147 | " except Exception as e:\n", 148 | " return gr.update( choices=[], value=f\"错误的文件夹路径:{e}\" )\n", 149 | "\n", 150 | "def read_webui_config(read_webui_config_dir, read_webui_config_name, write_files_dir, *args):\n", 151 | " dir_change_flag = False\n", 152 | " param_len = sum( parameter_len_dict.values() )\n", 153 | " if len(args) != param_len:\n", 154 | " warnings.warn(f\"传入read_webui_config的*args长度不匹配\", UserWarning)\n", 155 | " \n", 156 | " read_webui_config_path = os.path.join(read_webui_config_dir, read_webui_config_name)\n", 157 | " #能打开就正常操作\n", 158 | " try:\n", 159 | " with open(read_webui_config_path, \"r\", encoding=\"utf-8\") as f:\n", 160 | " config_dict = toml.loads( f.read() )\n", 161 | " \n", 162 | " #能读到[\"other\"].[\"write_files_dir\"]就改,读不到就用原写入地址\n", 163 | " try:\n", 164 | " if config_dict[\"other\"][\"write_files_dir\"] != write_files_dir:\n", 165 | " write_files_dir = config_dict[\"other\"][\"write_files_dir\"]\n", 166 | " dir_change_flag = True\n", 167 | " except KeyError:\n", 168 | " pass\n", 169 | " \n", 170 | " param_dict_key_list = list( config_dict.get(\"param\",{}).keys() )\n", 171 | " #找出共有的key进行赋值,非共有的报错\n", 172 | " both_key = set(all_parameter_dict_key_list) & set(param_dict_key_list)\n", 173 | " parameter_unique_key = set(all_parameter_dict_key_list) - set(both_key)\n", 174 | " config_unique_key = set(param_dict_key_list) - set(both_key)\n", 175 | " #赋值\n", 176 | " count = 0\n", 177 | " if both_key:\n", 178 | " args = list(args)\n", 179 | " for key in both_key:\n", 180 | " index = all_parameter_dict_key_list.index(key)\n", 181 | " args[ index ] = config_dict[\"param\"][key]\n", 182 | " count += 1\n", 183 | " args = tuple(args)\n", 184 | " read_done = f\"\\n读取完成,WebUI中共有{param_len}项参数,更新了其中{count}项\\n\" + f\"写入文件夹发生改变:{write_files_dir}\" if dir_change_flag else \"\"\n", 185 | " config_warning = f\"\\nwebui-config文件中以下参数可能已经失效或错误:\\n{config_unique_key}\\n\" if config_unique_key else \"\"\n", 186 | " parameter_warning = 
f\"\\nWebUI中以下参数在webui-config文件中未找到,不发生修改:\\n{parameter_unique_key}\\n\" if parameter_unique_key else \"\"\n", 187 | " str = read_done + config_warning + parameter_warning\n", 188 | " return str, write_files_dir, *args\n", 189 | "\n", 190 | " #打不开就返回原值\n", 191 | " except FileNotFoundError:\n", 192 | " return \"文件或目录不存在\", write_files_dir, *args\n", 193 | " except PermissionError:\n", 194 | " return \"没有权限访问文件或目录\", write_files_dir, *args\n", 195 | " except OSError as e:\n", 196 | " return f\"something wrong:{e}\", write_files_dir, *args\n", 197 | " \n", 198 | " \n", 199 | "\n", 200 | "def model_get(model_dir):\n", 201 | " try:\n", 202 | " files = [f for f in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, f))]\n", 203 | " if files:\n", 204 | " return gr.update( choices=files,value=files[0] )\n", 205 | " else:\n", 206 | " return gr.update( choices=[],value=\"没有找到模型\" )\n", 207 | " except Exception as e:\n", 208 | " return gr.update( choices=[], value=f\"错误的文件夹路径:{e}\" )\n", 209 | "\n", 210 | "\n", 211 | "def write_files(write_files_dir):\n", 212 | "\n", 213 | " if not common_confirm_flag:\n", 214 | " return \"必须要确认常规参数一次才允许写入toml\"\n", 215 | "\n", 216 | " write_files_dir = write_files_dir if write_files_dir else os.getcwd()\n", 217 | " os.makedirs(write_files_dir, exist_ok=True)\n", 218 | " config_file_toml_path = os.path.join(write_files_dir, \"config_file.toml\")\n", 219 | " sample_prompts_txt_path = os.path.join(write_files_dir, \"sample_prompts.txt\")\n", 220 | "\n", 221 | " all = {**common_parameter_dict, **sample_parameter_dict, **plus_parameter_dict}\n", 222 | "\n", 223 | " def parameter2toml():\n", 224 | "\n", 225 | " #生成config_file.toml的字典\n", 226 | "\n", 227 | " #model_arguments部分\n", 228 | " model_arguments = { key: all.get(key) for key in [\"v2\", \"v_parameterization\"] }\n", 229 | " \"\"\" 生成底模路径 \"\"\"\n", 230 | " base_model_path = os.path.join( all.get(\"base_model_dir\"), all.get(\"base_model_name\") )\n", 231 | " model_arguments.update( {\"pretrained_model_name_or_path\": base_model_path} )\n", 232 | " \"\"\" 生成vae路径 \"\"\"\n", 233 | " if all.get(\"use_vae\"):\n", 234 | " vae_model_path = os.path.join( all.get(\"vae_model_dir\"), all.get(\"vae_model_name\") )\n", 235 | " model_arguments.update( {\"vae\": vae_model_path} )\n", 236 | "\n", 237 | " #additional_network_arguments部分\n", 238 | " additional_network_arguments = { key: all.get(key) for key in [\"unet_lr\", \"text_encoder_lr\", \"network_dim\",\\\n", 239 | " \"network_alpha\", \"network_train_unet_only\",\\\n", 240 | " \"network_train_text_encoder_only\"] }\n", 241 | " \"\"\" 生成如network_module = \"locon.locon_kohya\" \"\"\"\n", 242 | " #[\"LoRA-LierLa\", \"LoRA-C3Lier\", \"LoCon_Lycoris\", \"LoHa_Lycoris\", \"DyLoRa-LierLa\", \"DyLoRa-C3Lier\"]\n", 243 | " #主要负责network_module的参数生成\n", 244 | " def network_module_param(train_method):\n", 245 | " conv_dim = all.get(\"conv_dim\") if train_method != \"DyLoRa-C3Lier\" else all.get(\"network_dim\")\n", 246 | " conv_alpha = all.get(\"conv_alpha\")\n", 247 | " algo = \"lora\" if train_method == \"LoCon_Lycoris\" else \"loha\"\n", 248 | " unit = all.get(\"unit\")\n", 249 | " if train_method in [\"LoRA-LierLa\", \"LoRA-C3Lier\"]:\n", 250 | " network_module = \"networks.lora\"\n", 251 | " if train_method == \"LoRA-C3Lier\":\n", 252 | " network_module_args = [f\"conv_dim={conv_dim}\", f\"conv_alpha={conv_alpha}\"]\n", 253 | " else:\n", 254 | " network_module_args = []\n", 255 | " elif train_method in [\"LoCon_Lycoris\", \"LoHa_Lycoris\"]:\n", 256 | " 
network_module = \"lycoris.kohya\"\n", 257 | " network_module_args = [f\"conv_dim={conv_dim}\", f\"conv_alpha={conv_alpha}\", f\"algo={algo}\"]\n", 258 | " elif train_method in [\"DyLoRa-LierLa\", \"DyLoRa-C3Lier\"]:\n", 259 | " network_module = \"networks.dylora\"\n", 260 | " if train_method == \"DyLoRa-C3Lier\":\n", 261 | " network_module_args = [f\"conv_dim={conv_dim}\", f\"conv_alpha={conv_alpha}\", f\"unit={unit}\"]\n", 262 | " else:\n", 263 | " network_module_args = [f\"unit={unit}\"]\n", 264 | " else: \n", 265 | " warnings.warn(f\"训练方法参数生成出错\", UserWarning)\n", 266 | " return network_module, network_module_args\n", 267 | " network_module, network_module_args = network_module_param( all.get(\"train_method\") )\n", 268 | " #更多network_args部分(主要为分层训练)\n", 269 | " network_lr_weight_args = [ f\"{name}={all.get(name)}\" for name in [\"up_lr_weight\", \"mid_lr_weight\", \"down_lr_weight\"] if all.get(name) ]\n", 270 | "\n", 271 | " def network_block_param(train_method):\n", 272 | " lst = [\"block_dims\", \"block_alphas\", \"conv_block_dims\", \"conv_block_alphas\"]\n", 273 | " if train_method == \"LoRA-LierLa\":\n", 274 | " return [ f\"{name}={all.get(name)}\" for name in lst[0:1] if all.get(name) ]\n", 275 | " if train_method in [\"LoRA-C3Lier\", \"LoCon_Lycoris\", \"LoHa_Lycoris\"]:\n", 276 | " return [ f\"{name}={all.get(name)}\" for name in lst if all.get(name) ]\n", 277 | " else:\n", 278 | " return []\n", 279 | " network_block_args = network_block_param( all.get(\"train_method\") )\n", 280 | " \n", 281 | "\n", 282 | " network_args = []\n", 283 | " network_args.extend(network_module_args)\n", 284 | " network_args.extend(network_lr_weight_args)\n", 285 | " network_args.extend(network_block_args)\n", 286 | "\n", 287 | " additional_network_arguments.update( { \"network_module\":network_module } )\n", 288 | " additional_network_arguments.update( {\"network_args\":network_args} ) \n", 289 | "\n", 290 | " #optimizer_arguments部分\n", 291 | " optimizer_arguments = { key: all.get(key) for key in [\"optimizer_type\", \"lr_scheduler\", \"lr_warmup_steps\"] }\n", 292 | " \"\"\"只有余弦重启调度器指定重启次数\"\"\"\n", 293 | " if all.get(\"lr_scheduler\") == \"cosine_with_restarts\":\n", 294 | " optimizer_arguments.update( {\"lr_restart_cycles\":all.get(\"lr_restart_cycles\")} )\n", 295 | " \"\"\"学习率lr指定=unet_lr\"\"\"\n", 296 | " optimizer_arguments.update( {\"learning_rate\":all.get(\"unet_lr\")} )\n", 297 | " #optimizer_args(待添加)\n", 298 | "\n", 299 | " #dataset_arguments部分\n", 300 | " dataset_arguments = { key: all.get(key) for key in [\"cache_latents\", \"shuffle_caption\", \"enable_bucket\"] }\n", 301 | " \n", 302 | " #training_arguments部分\n", 303 | " training_arguments = { key: all.get(key) for key in [\"batch_size\", \"noise_offset\", \"keep_tokens\",\\\n", 304 | " \"min_bucket_reso\", \"max_bucket_reso\",\\\n", 305 | " \"caption_extension\", \"max_token_length\", \"seed\",\\\n", 306 | " \"xformers\", \"lowram\"]\n", 307 | " }\n", 308 | " \"\"\"min_snr_gamma大于零才生效\"\"\"\n", 309 | " if all.get(\"min_snr_gamma\") > 0:\n", 310 | " training_arguments.update( { \"min_snr_gamma\":all.get(\"min_snr_gamma\") } )\n", 311 | " \"\"\" 最大训练时间 \"\"\"\n", 312 | " training_arguments.update( { all.get(\"max_train_method\"):all.get(\"max_train_value\") } )\n", 313 | " \"\"\" 训练分辨率 \"\"\"\n", 314 | " training_arguments.update( { \"resolution\":f\"{all.get('width')},{all.get('height')}\" } )\n", 315 | " \"\"\" 如果v2开启,则不指定clip_skip \"\"\"\n", 316 | " if not all.get(\"v2\"):\n", 317 | " training_arguments.update( { 
\"clip_skip\":all.get(\"clip_skip\") } )\n", 318 | " \"\"\" 重训练模块 \"\"\"\n", 319 | " if all.get(\"use_retrain\") == \"model\":\n", 320 | " training_arguments.update( { \"network_weights\":all.get(\"retrain_dir\") } )\n", 321 | " elif all.get(\"use_retrain\") == \"state\":\n", 322 | " training_arguments.update( { \"resume\":all.get(\"retrain_dir\") } )\n", 323 | " \"\"\" 训练精度、保存精度 \"\"\"\n", 324 | " training_arguments.update( { \"mixed_precision\":\"fp16\" } )\n", 325 | " training_arguments.update( { \"save_precision\":\"fp16\" } )\n", 326 | " \n", 327 | "\n", 328 | "\n", 329 | " #sample_prompt_arguments部分(采样间隔,采样文件地址待添加)\n", 330 | " sample_prompt_arguments = { key: all.get(key) for key in [\"sample_sampler\"] }\n", 331 | " if all.get(\"sample_every_n_type\"): #如果采样部分没确认过一次,会出现all.get(\"sample_every_n_type\")=None:None的字典造成报错\n", 332 | " sample_prompt_arguments.update( {all.get(\"sample_every_n_type\"):all.get(\"sample_every_n_type_value\")} )\n", 333 | "\n", 334 | " #dreambooth_arguments部分\n", 335 | " dreambooth_arguments = { key: all.get(key) for key in [\"train_data_dir\", \"reg_data_dir\", \"prior_loss_weight\"] }\n", 336 | "\n", 337 | " #saving_arguments部分\n", 338 | " saving_arguments = { key: all.get(key) for key in [\"output_dir\",\\\n", 339 | " \"output_name\", \"save_every_n_epochs\", \"save_n_epoch_ratio\",\\\n", 340 | " \"save_last_n_epochs\", \"save_state\", \"save_model_as\" ]\n", 341 | " }\n", 342 | " \"\"\" 指定log输出目录与output相同 \"\"\"\n", 343 | " saving_arguments.update( { \"logging_dir\":os.path.join( all.get(\"output_dir\"), \"logs\" ) } )\n", 344 | " \"\"\" 指定log前缀和输出名字相同 \"\"\"\n", 345 | " saving_arguments.update( { \"log_prefix\":all.get(\"output_name\") } )\n", 346 | " \n", 347 | "\n", 348 | " toml_dict = {\"model_arguments\":model_arguments,\n", 349 | " \"additional_network_arguments\":additional_network_arguments,\n", 350 | " \"optimizer_arguments\":optimizer_arguments,\n", 351 | " \"dataset_arguments\":dataset_arguments,\n", 352 | " \"training_arguments\":training_arguments,\n", 353 | " \"sample_prompt_arguments\":sample_prompt_arguments,\n", 354 | " \"dreambooth_arguments\":dreambooth_arguments,\n", 355 | " \"saving_arguments\":saving_arguments,\n", 356 | " }\n", 357 | " toml_str = toml.dumps(toml_dict)\n", 358 | " return toml_str\n", 359 | " def sample_parameter2txt():\n", 360 | " #key_list = [\"prompt\", \"negative\", \"sample_width\", \"sample_height\", \"sample_scale\", \"sample_steps\", \"sample_seed\"]\n", 361 | "\n", 362 | " if not all.get('sample_seed'): #如果采样部分没确认过,会出现all.get('sample_seed')=None > 0造成报错\n", 363 | " return \"\"\n", 364 | " sample_str = f\"\"\"{all.get(\"prompt\")} \\\n", 365 | "--n {all.get(\"negative\")} \\\n", 366 | "--w {all.get(\"sample_width\")} \\\n", 367 | "--h {all.get(\"sample_height\")} \\\n", 368 | "--l {all.get(\"sample_scale\")} \\\n", 369 | "--s {all.get(\"sample_steps\")} \\\n", 370 | "{f\"--d {all.get('sample_seed')}\" if all.get('sample_seed') > 0 else \"\"}\"\"\"\n", 371 | " return sample_str\n", 372 | "\n", 373 | " def write(content, path):\n", 374 | " with open(path, \"w\", encoding=\"utf-8\") as f:\n", 375 | " f.write(content)\n", 376 | "\n", 377 | " write(parameter2toml(), config_file_toml_path)\n", 378 | " write(sample_parameter2txt(), sample_prompts_txt_path)\n", 379 | " write_files_title = f\"写入成功, 训练配置文件在{config_file_toml_path}, 采样参数文件在{sample_prompts_txt_path}\"\n", 380 | " return write_files_title\n" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": { 387 | "id": 
"LT7NaOl8Jfgf", 388 | "cellView": "form" 389 | }, 390 | "outputs": [], 391 | "source": [ 392 | "#@title WebUI部分\n", 393 | "\n", 394 | "\n", 395 | "with gr.Blocks() as demo:\n", 396 | " with gr.Accordion(\"保存、读取\\nwebui配置\", open=False):\n", 397 | " save_read_webui_config_title = gr.Markdown(\"保存或读取\")\n", 398 | " with gr.Row():\n", 399 | " save_webui_config_button = gr.Button(\"保存\")\n", 400 | " with gr.Row():\n", 401 | " save_webui_config_dir = gr.Textbox(lines=1, label=\"保存目录\", value=os.path.join(os.getcwd(),\"kohya_config_webui_save\") )\n", 402 | " save_webui_config_name = gr.Textbox(lines=1, label=\"保存名字(以toml为扩展名,否则不会被读取)\", value=\"kohya_config_webui_save.toml\" )\n", 403 | " with gr.Row():\n", 404 | " read_webui_config_get_button = gr.Button(refresh_symbol)\n", 405 | " read_webui_config_button = gr.Button(\"读取\")\n", 406 | " with gr.Row():\n", 407 | " read_webui_config_dir = gr.Textbox(lines=1, label=\"读取目录\", value=os.path.join(os.getcwd(),\"kohya_config_webui_save\") ) \n", 408 | " read_webui_config_name = gr.Dropdown(choices=[], label=\"读取文件\", value=\"\" ) \n", 409 | " with gr.Row():\n", 410 | " write_files_button = gr.Button(\"生成toml参数与采样配置文件\")\n", 411 | " all_parameter_get_button = gr.Button(\"全部参数确认\")\n", 412 | " write_files_dir = gr.Textbox( lines=1, label=\"写入文件夹\", placeholder=\"一般填kohya_script目录,留空就默认根目录\", value=\"\" )\n", 413 | " write_files_title = gr.Markdown(\"生成适用于kohya/train_network.py的配置文件\")\n", 414 | " with gr.Tabs():\n", 415 | " with gr.TabItem(\"基础参数\"):\n", 416 | " common_parameter_get_button = gr.Button(\"确定\")\n", 417 | " common_parameter_title = gr.Markdown(\"\")\n", 418 | " with gr.Accordion(\"当前基础参数配置\", open=False):\n", 419 | " common_parameter_toml = gr.Textbox(label=\"toml形式\", placeholder=\"基础参数\", value=\"\")\n", 420 | " with gr.Row():\n", 421 | " train_data_dir = gr.Textbox(lines=1, label=\"train_data_dir\", placeholder=\"训练集路径\", value=\"\")\n", 422 | " with gr.Accordion(\"使用正则化(可选)\", open=False):\n", 423 | " with gr.Row():\n", 424 | " reg_data_dir = gr.Textbox(lines=1, label=\"reg_data_dir\", placeholder=\"正则化集路径(填入意味着启用正则化)\", value=\"\")\n", 425 | " prior_loss_weight = gr.Slider(0, 1, step=0.01, value=0.3, label=\"正则化权重\")\n", 426 | " with gr.Row():\n", 427 | " base_model_dir = gr.Textbox(label=\"底模文件夹地址\", placeholder=\"文件夹路径\", value=\"\")\n", 428 | " base_model_name = gr.Dropdown(choices=[],label=\"底模\",value=\"\")\n", 429 | " base_model_get_button = gr.Button(refresh_symbol)\n", 430 | " with gr.Accordion(\"使用vae(可选)\", open=False):\n", 431 | " with gr.Row():\n", 432 | " use_vae = gr.Checkbox(label=\"是否使用vae\",value=False)\n", 433 | " with gr.Row():\n", 434 | " vae_model_dir = gr.Textbox(label=\"vae文件夹地址\", placeholder=\"文件夹路径\", value=\"\")\n", 435 | " vae_model_name = gr.Dropdown(choices=[],label=\"vae\", value=\"\")\n", 436 | " vae_model_get_button = gr.Button(refresh_symbol)\n", 437 | " with gr.Row():\n", 438 | " width = gr.Slider(64, 1920, step=64, value=512, label=\"训练分辨率(宽)width\")\n", 439 | " height = gr.Slider(64, 1920, step=64, value=512, label=\"训练分辨率(高)height\")\n", 440 | " batch_size = gr.Slider(1, 24, step=1, value=1, label=\"batch大小\")\n", 441 | " with gr.Row():\n", 442 | " noise_offset = gr.Slider(0, 1, step=0.01, value=0.05, label=\"noise_offset\")\n", 443 | " keep_tokens = gr.Slider(0, 225, step=1, value=0, label=\"keep_tokens\")\n", 444 | " min_snr_gamma = gr.Slider(0, 100, step=0.1, value=5, label=\"min_snr_gamma(设置为0则不生效)\")\n", 445 | " \"\"\"\n", 446 | " with gr.Row():\n", 447 | " gr.Markdown(\"repeat * 图片数 = 
每个epoch的steps数\")\n", 448 | " \"\"\"\n", 449 | " with gr.Row():\n", 450 | " max_train_method = gr.Dropdown([\"max_train_epochs\",\"max_train_steps\"], label=\"以epochs或steps来指定最大训练时间\", value=\"max_train_epochs\")\n", 451 | " max_train_value = gr.Number(label=\"最大训练epochs\\steps数\", value=10, precision=0)\n", 452 | " with gr.Accordion(\"输出设置\", open=True):\n", 453 | " with gr.Row():\n", 454 | " output_dir = gr.Textbox( label=\"模型、log日志输出地址(自行修改)\", placeholder=\"文件夹路径\",value=os.path.join(os.getcwd(),\"output\") )\n", 455 | " output_name = gr.Textbox(label=\"输出模型名称(自行修改)\", placeholder=\"名称\",value=\"output_name\")\n", 456 | " save_model_as = gr.Dropdown([\"safetensors\",\"ckpt\",\"pt\"], label=\"保存模型格式\", value=\"safetensors\")\n", 457 | " with gr.Row():\n", 458 | " save_every_n_epochs = gr.Slider(1, 499, step=1, value=1, label=\"每n个epoch保存一次\")\n", 459 | " save_n_epoch_ratio = gr.Slider(1, 499, step=1, value=0, label=\"等间隔保存n个(如不为0,会覆盖每n个epoch保存一次)\")\n", 460 | " save_last_n_epochs = gr.Slider(1, 499, step=1, value=499, label=\"最多保存n个(后面的出来就会把前面删了,优先级最高)\")\n", 461 | " with gr.Row(): \n", 462 | " save_state = gr.Checkbox(label=\"保存学习状态\",value=False)\n", 463 | " with gr.Row():\n", 464 | " optimizer_type = gr.Dropdown([\"AdamW8bit\", \"Lion\", \"DAdaptation\", \"AdamW\", \"SGDNesterov\", \"SGDNesterov8bit\", \"AdaFactor\"],\\\n", 465 | " label=\"optimizer_type优化器类型\", value=\"AdamW8bit\")\n", 466 | " unet_lr = gr.Number(label=\"unet学习率\", value=1e-4)\n", 467 | " text_encoder_lr = gr.Number(label=\"text_encoder学习率\", value=1e-5)\n", 468 | " with gr.Row():\n", 469 | " lr_scheduler = gr.Dropdown([\"cosine_with_restarts\",\"cosine\",\"polynomial\",\"linear\",\"constant_with_warmup\",\"constant\"],\\\n", 470 | " label=\"lr_scheduler学习率调度器\", value=\"cosine_with_restarts\")\n", 471 | " lr_warmup_steps = gr.Number(label=\"升温步数\", value=0, precision=0)\n", 472 | " lr_restart_cycles = gr.Number(label=\"退火重启次数\", value=1, precision=0)\n", 473 | " with gr.Row():\n", 474 | " train_method = gr.Dropdown([\"LoRA-LierLa\", \"LoRA-C3Lier\",\\\n", 475 | " \"LoCon_Lycoris\",\"LoHa_Lycoris\",\\\n", 476 | " \"DyLoRa-LierLa\", \"DyLoRa-C3Lier\"],\\\n", 477 | " label=\"train_method训练方法\", value=\"LoRA-LierLa\")\n", 478 | " network_dim = gr.Number(label=\"线性dim\", value=32, precision=0)\n", 479 | " network_alpha = gr.Number(label=\"线性alpha(可以为小数)\", value=16)\n", 480 | " with gr.Accordion(\"额外网络参数(LoRA-C3Lier、LoCon、LoHa、DyLoRa-C3Lier都属于卷积,unit为DyLoRa专用)\", open=True):\n", 481 | " with gr.Row():\n", 482 | " with gr.Column():\n", 483 | " conv_dim = gr.Number(label=\"卷积dim\", info=\"使用DyLoRa-C3Lier时会被设置为等于基础dim\", value=8, precision=0)\n", 484 | " with gr.Column():\n", 485 | " conv_alpha = gr.Number(label=\"卷积alpha\", info=\"可以为小数\", value=1)\n", 486 | " with gr.Column():\n", 487 | " unit = gr.Number(label=\"分割单位unit(整数)\", info=\"使用DyLoRa时,请让dim为unit的倍数\", value=1, precision=0)\n", 488 | " with gr.Row(): \n", 489 | " v2 = gr.Checkbox(label=\"v2\")\n", 490 | " v_parameterization = gr.Checkbox(label=\"v_parameterization\")\n", 491 | " lowram = gr.Checkbox(label=\"lowram\")\n", 492 | " xformers = gr.Checkbox(label=\"xformers\",value=True)\n", 493 | " cache_latents = gr.Checkbox(label=\"cache_latents\",value=True)\n", 494 | " shuffle_caption = gr.Checkbox(label=\"shuffle_caption\",value=True)\n", 495 | " enable_bucket = gr.Checkbox(label=\"enable_bucket\",value=True)\n", 496 | " with gr.TabItem(\"采样参数\"):\n", 497 | " sample_parameter_get_button = gr.Button(\"确定\")\n", 498 | " sample_parameter_title = 
gr.Markdown(\"\")\n", 499 | " with gr.Accordion(\"当前采样配置\", open=False):\n", 500 | " sample_parameter_toml = gr.Textbox(label=\"toml形式\", placeholder=\"采样配置\", value=\"\")\n", 501 | " with gr.Row():\n", 502 | " #enable_sample = gr.Checkbox(label=\"是否启用采样功能\")\n", 503 | " sample_every_n_type = gr.Dropdown([\"sample_every_n_epochs\", \"sample_every_n_steps\"], label=\"sample_every_n_type\", value=\"sample_every_n_epochs\")\n", 504 | " sample_every_n_type_value = gr.Number(label=\"sample_every_n_type_value\", value=1, precision=0)\n", 505 | " with gr.Row():\n", 506 | " sample_sampler = gr.Dropdown([\"ddim\", \"pndm\", \"lms\", \"euler\", \"euler_a\", \"heun\",\\\n", 507 | " \"dpm_2\", \"dpm_2_a\", \"dpmsolver\",\"dpmsolver++\", \"dpmsingle\",\\\n", 508 | " \"k_lms\", \"k_euler\", \"k_euler_a\", \"k_dpm_2\", \"k_dpm_2_a\"],\\\n", 509 | " label=\"采样器\", value=\"euler_a\")\n", 510 | " sample_seed = gr.Number(label=\"采样种子(-1不是随机,大于0才生效)\", value=-1, precision=0)\n", 511 | " with gr.Row():\n", 512 | " sample_width = gr.Slider(64, 1920, step=64, value=512, label=\"采样图片宽\")\n", 513 | " sample_height = gr.Slider(64, 1920, step=64, value=768, label=\"采样图片高\")\n", 514 | " sample_scale = gr.Slider(1, 30, step=0.5, value=7, label=\"提示词相关性\")\n", 515 | " sample_steps = gr.Slider(1, 150, step=1, value=24, label=\"采样迭代步数\")\n", 516 | " with gr.Row():\n", 517 | " prompt = gr.Textbox(lines=10, label=\"prompt\", placeholder=\"正面提示词\", value=\"(masterpiece, best quality, hires:1.2), 1girl, solo,\")\n", 518 | " default_negative = (\"(worst quality, bad quality:1.4), \"\n", 519 | " \"lowres, bad anatomy, bad hands, text, error, \"\n", 520 | " \"missing fingers, extra digit, fewer digits, \"\n", 521 | " \"cropped, worst quality, low quality, normal quality, \"\n", 522 | " \"jpeg artifacts,signature, watermark, username, blurry,\")\n", 523 | " negative = gr.Textbox(lines=10, label=\"negative\", placeholder=\"负面提示词\", value=default_negative)\n", 524 | " with gr.TabItem(\"进阶参数\"):\n", 525 | " plus_parameter_get_button = gr.Button(\"确定\")\n", 526 | " plus_parameter_title = gr.Markdown(\"\")\n", 527 | " with gr.Accordion(\"当前进阶参数配置\", open=False):\n", 528 | " plus_parameter_toml = gr.Textbox(label=\"toml形式\", placeholder=\"进阶参数\", value=\"\")\n", 529 | " with gr.Row():\n", 530 | " use_retrain = gr.Dropdown([\"no\",\"model\",\"state\"], label=\"是否使用重训练\", value=\"no\")\n", 531 | " retrain_dir = gr.Textbox(lines=1, label=\"重训练路径\", placeholder=\"模型或者状态路径\", value=\"\")\n", 532 | " with gr.Row():\n", 533 | " min_bucket_reso = gr.Slider(64, 1920, step=64, value=256, label=\"最低桶分辨率\")\n", 534 | " max_bucket_reso = gr.Slider(64, 1920, step=64, value=1024, label=\"最高桶分辨率\")\n", 535 | " clip_skip = gr.Slider(0, 25, step=1, value=2, label=\"跳过层数\")\n", 536 | " max_token_length = gr.Slider(75, 225, step=75, value=225, label=\"训练最大token数\")\n", 537 | " caption_extension = gr.Textbox(lines=1, label=\"标签文件扩展名\", placeholder=\"一般填.txt或.cap\", value=\".txt\")\n", 538 | " seed = gr.Number(label=\"种子\", value=1337, precision=0)\n", 539 | " with gr.Row():\n", 540 | " network_train_unet_only= gr.Checkbox(label=\"仅训练unet网络\",value=False)\n", 541 | " network_train_text_encoder_only = gr.Checkbox(label=\"仅训练text_encoder网络\",value=False)\n", 542 | " with gr.Accordion(\"分层学习模块\", open=True):\n", 543 | " gr.Markdown(\"学习率分层,为不同层的结构指定不同学习率倍数; 如果某一层权重为0,那该层不会被创建\")\n", 544 | " with gr.Row():\n", 545 | " with gr.Column(scale=15):\n", 546 | " up_lr_weight = gr.Textbox(lines=1, label=\"上层学习率权重\", placeholder=\"留空则不启用\",\\\n", 547 | " 
info=\"15层,例如1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5\", value=\"\")\n", 548 | " with gr.Column(scale=1):\n", 549 | " mid_lr_weight = gr.Textbox(lines=1, label=\"中层学习率权重\", placeholder=\"留空则不启用\",\\\n", 550 | " info=\"1层,例如2.0\", value=\"\")\n", 551 | " with gr.Column(scale=15):\n", 552 | " down_lr_weight = gr.Textbox(lines=1, label=\"下层学习率权重\", placeholder=\"留空则不启用\",\\\n", 553 | " info=\"15层,例如0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5\", value=\"\")\n", 554 | " gr.Markdown(\"dim和alpha分层,为不同层的结构指定不同的dim和alpha(`DyLoRa`无法使用,卷积分层只有`LoRa-C3Lier、LoCon、LoHa`可以使用)\")\n", 555 | " with gr.Row():\n", 556 | " block_dims = gr.Textbox(lines=1, label=\"线性dim分层\", placeholder=\"留空则不启用\",\\\n", 557 | " info=\"25层(上中下),例如2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2\", value=\"\")\n", 558 | " block_alphas = gr.Textbox(lines=1, label=\"线性alpha分层\", placeholder=\"留空则不启用\",\\\n", 559 | " info=\"25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2\", value=\"\")\n", 560 | " with gr.Row():\n", 561 | " conv_block_dims = gr.Textbox(lines=1, label=\"卷积dim分层\", placeholder=\"留空则不启用\",\\\n", 562 | " info=\"25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2\", value=\"\")\n", 563 | " conv_block_alphas = gr.Textbox(lines=1, label=\"卷积alpha分层\", placeholder=\"留空则不启用\",\\\n", 564 | " info=\"25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2\", value=\"\")\n", 565 | "\n", 566 | "\n", 567 | " def dict_key_list_2_list(dict_key_list):\n", 568 | " list = []\n", 569 | " for key in dict_key_list:\n", 570 | " try:\n", 571 | " list.append(globals()[key])\n", 572 | " except KeyError:\n", 573 | " print(f\"Error: parameter_dict_key_list中{key}不存在\")\n", 574 | " list_len = len(list)\n", 575 | " return list, list_len\n", 576 | "\n", 577 | " common_parameter_dict_key_list = [\"train_data_dir\",\n", 578 | " \"reg_data_dir\",\n", 579 | " \"prior_loss_weight\",\n", 580 | " \"base_model_dir\",\n", 581 | " \"base_model_name\",\n", 582 | " \"use_vae\",\n", 583 | " \"vae_model_dir\",\n", 584 | " \"vae_model_name\",\n", 585 | " \"width\",\n", 586 | " \"height\",\n", 587 | " \"batch_size\",\n", 588 | " \"noise_offset\",\n", 589 | " \"keep_tokens\",\n", 590 | " \"min_snr_gamma\",\n", 591 | " \"max_train_method\",\n", 592 | " \"max_train_value\",\n", 593 | " \"output_dir\",\n", 594 | " \"output_name\",\n", 595 | " \"save_model_as\",\n", 596 | " \"save_every_n_epochs\",\n", 597 | " \"save_n_epoch_ratio\",\n", 598 | " \"save_last_n_epochs\",\n", 599 | " \"save_state\",\n", 600 | " \"optimizer_type\",\n", 601 | " \"unet_lr\",\n", 602 | " \"text_encoder_lr\",\n", 603 | " \"lr_scheduler\",\n", 604 | " \"lr_warmup_steps\",\n", 605 | " \"lr_restart_cycles\",\n", 606 | " \"train_method\",\n", 607 | " \"network_dim\",\n", 608 | " \"network_alpha\",\n", 609 | " \"conv_dim\",\n", 610 | " \"conv_alpha\",\n", 611 | " \"unit\",\n", 612 | " \"v2\",\n", 613 | " \"v_parameterization\",\n", 614 | " \"lowram\",\n", 615 | " \"xformers\",\n", 616 | " \"cache_latents\",\n", 617 | " \"shuffle_caption\",\n", 618 | " \"enable_bucket\"]\n", 619 | " common_parameter_list, parameter_len_dict[\"common\"] = dict_key_list_2_list(common_parameter_dict_key_list)\n", 620 | " sample_parameter_dict_key_list = [\"sample_every_n_type\",\n", 621 | " \"sample_every_n_type_value\",\n", 622 | " \"sample_sampler\",\n", 623 | " \"sample_seed\",\n", 624 | " \"sample_width\",\n", 625 | " \"sample_height\",\n", 626 | " \"sample_scale\",\n", 627 | " \"sample_steps\",\n", 628 | " \"prompt\",\n", 629 | " \"negative\"]\n", 630 | " 
sample_parameter_list, parameter_len_dict[\"sample\"] = dict_key_list_2_list(sample_parameter_dict_key_list)\n", 631 | " plus_parameter_dict_key_list = [\"use_retrain\",\n", 632 | " \"retrain_dir\",\n", 633 | " \"min_bucket_reso\",\n", 634 | " \"max_bucket_reso\",\n", 635 | " \"clip_skip\",\n", 636 | " \"max_token_length\",\n", 637 | " \"caption_extension\",\n", 638 | " \"seed\",\n", 639 | " \"network_train_unet_only\",\n", 640 | " \"network_train_text_encoder_only\",\n", 641 | " \"up_lr_weight\",\n", 642 | " \"mid_lr_weight\",\n", 643 | " \"down_lr_weight\",\n", 644 | " \"block_dims\",\n", 645 | " \"block_alphas\",\n", 646 | " \"conv_block_dims\",\n", 647 | " \"conv_block_alphas\"]\n", 648 | " plus_parameter_list, parameter_len_dict[\"plus\"] = dict_key_list_2_list(plus_parameter_dict_key_list)\n", 649 | "\n", 650 | " #注意,这几个list相加的顺序不能错\n", 651 | " all_parameter_list = common_parameter_list + sample_parameter_list + plus_parameter_list\n", 652 | " all_parameter_dict_key_list = common_parameter_dict_key_list + sample_parameter_dict_key_list + plus_parameter_dict_key_list\n", 653 | "\n", 654 | " save_webui_config_button.click(fn=save_webui_config,\n", 655 | " inputs=[save_webui_config_dir, save_webui_config_name, write_files_dir],\n", 656 | " outputs=save_read_webui_config_title \n", 657 | " )\n", 658 | " read_webui_config_get_button.click(fn=read_webui_config_get,\n", 659 | " inputs=[read_webui_config_dir],\n", 660 | " outputs=[read_webui_config_name] \n", 661 | " )\n", 662 | " read_webui_config_button.click(fn=read_webui_config,\n", 663 | " inputs=[read_webui_config_dir, read_webui_config_name, write_files_dir] + all_parameter_list,\n", 664 | " outputs=[save_read_webui_config_title, write_files_dir] + all_parameter_list\n", 665 | " )\n", 666 | " common_parameter_get_button.click(fn=common_parameter_get,\n", 667 | " inputs=common_parameter_list,\n", 668 | " outputs=[common_parameter_toml, common_parameter_title]\n", 669 | " )\n", 670 | " sample_parameter_get_button.click(fn=sample_parameter_get,\n", 671 | " inputs=sample_parameter_list,\n", 672 | " outputs=[sample_parameter_toml, sample_parameter_title]\n", 673 | " )\n", 674 | " plus_parameter_get_button.click(fn=plus_parameter_get,\n", 675 | " inputs=plus_parameter_list,\n", 676 | " outputs=[plus_parameter_toml, plus_parameter_title]\n", 677 | " )\n", 678 | " all_parameter_get_button.click(fn=all_parameter_get,\n", 679 | " inputs=all_parameter_list,\n", 680 | " outputs=[common_parameter_toml, sample_parameter_toml, plus_parameter_toml, write_files_title]\n", 681 | " )\n", 682 | " base_model_get_button.click(fn=model_get,\n", 683 | " inputs=[base_model_dir],\n", 684 | " outputs=[base_model_name]\n", 685 | " )\n", 686 | " vae_model_get_button.click(fn=model_get,\n", 687 | " inputs=[vae_model_dir],\n", 688 | " outputs=[vae_model_name]\n", 689 | " )\n", 690 | " write_files_button.click(fn=write_files,\n", 691 | " inputs=[write_files_dir],\n", 692 | " outputs=[write_files_title]\n", 693 | " )\n", 694 | "\n", 695 | "\n", 696 | "if __name__ == \"__main__\":\n", 697 | " demo.launch(share=False,inbrowser=False,inline=True,debug=True)" 698 | ] 699 | } 700 | ], 701 | "metadata": { 702 | "colab": { 703 | "provenance": [], 704 | "authorship_tag": "ABX9TyNv6whrNjHPpxmgtHuC8ypu", 705 | "include_colab_link": true 706 | }, 707 | "gpuClass": "standard", 708 | "kernelspec": { 709 | "display_name": "Python 3", 710 | "name": "python3" 711 | }, 712 | "language_info": { 713 | "name": "python" 714 | } 715 | }, 716 | "nbformat": 4, 717 | "nbformat_minor": 0 718 
| } -------------------------------------------------------------------------------- /module/kohya_config_webui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """kohya_config_webui.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/github/WSH032/kohya-config-webui/blob/main/kohya_config_webui.ipynb 8 | 9 | | ![visitors](https://visitor-badge.glitch.me/badge?page_id=wsh.kohya_config_webui) | [![GitHub Repo stars](https://img.shields.io/github/stars/WSH032/kohya-config-webui?style=social)](https://github.com/WSH032/kohya-config-webui) | 10 | 11 | #A WebUI for making config files used by kohya_sd_script 12 | 13 | Created by [WSH](https://space.bilibili.com/8417436) 14 | """ 15 | 16 | 17 | """ 18 | 《维护指南》: 19 | 如果你想为webui三个选项卡中任意一个选项卡添加一个新的组件,并以其值指定toml文件相应的参数 20 | 1、 21 | 去webui部分添加相应的组件代码,并给新的组件命名。 22 | **如:在common参数部分添加 common_gr_dict["new_param"]= gr.Number(value=3) 23 | 2、 24 | 找到parameter2toml()函数, 确认你要写入toml文件的位置,比如说你想在[model_arguments]键下面添加一个 new_param=3 25 | 那么请用model_arguments.update( "new_param":all.get("new_param") ) 26 | 注意这里用的是all这个字典,因为在parameter2toml()一开始就会把三个全局参数字典合成一个all字典 27 | 3、 28 | 完成,检查生成的toml是否正确 29 | """ 30 | 31 | #@title 函数部分 32 | 33 | 34 | import os 35 | import toml 36 | import warnings 37 | import gradio as gr 38 | import argparse 39 | 40 | 41 | """ 以下的变量,标注gloab的是要在函数和webui部分同时使用的 """ 42 | 43 | ROOT_DIR = os.getcwd() #global 44 | 45 | #用于储存适三个选项卡中输入组件的名字; global 46 | common_parameter_dict_key_list=[] 47 | sample_parameter_dict_key_list=[] 48 | plus_parameter_dict_key_list=[] 49 | all_parameter_dict_key_list=[] #后面会有一次all_parameter_dict_key_list = common_parameter_dict_key_list + sample_parameter_dict_key_list + plus_parameter_dict_key_list 50 | 51 | #用于储存所有确认的组件值,如训练集地址; (只会在函数部分使用) 52 | common_parameter_dict=({}) 53 | sample_parameter_dict=({}) 54 | plus_parameter_dict=({}) 55 | 56 | common_confirm_flag = False #必须要确认常规参数一次才允许写入toml; (只会在函数部分使用) 57 | 58 | #用于确定每个选项卡中按键的数量,方便all_parameter_get; global 59 | parameter_len_dict={"common":0, "sample":0, "plus":0} 60 | 61 | 62 | def check_len_and_2dict(args, parameter_len_dict_value, parameter_dict_key_list, func_name=""): 63 | """ 三个parameter_get()会把gradio组件的值传进来,检查传入值数量和parameter_len_dict_value是否相等 """ 64 | """ 相等则返回以parameter_dict_key_list中字符串做为key名字的字典 """ 65 | """ 不相等就给一个wanring """ 66 | if len(args) != parameter_len_dict_value: 67 | warnings.warn(f"传入{func_name}的参数长度不匹配", UserWarning) 68 | if len(parameter_dict_key_list) != parameter_len_dict_value: 69 | warnings.warn(f" {func_name}内部字典赋值关键字列表的长度不匹配", UserWarning) 70 | parameter_dict = dict(zip(parameter_dict_key_list, args)) 71 | return parameter_dict 72 | 73 | 74 | """ 下面三个函数会获取各自选项卡中的输入值,然后使用将其转为字典并赋值给各自的全局变量parameter_dict """ 75 | """ 最后返回webui相应的输出信息 """ 76 | def common_parameter_get(*args): 77 | global common_parameter_dict, common_confirm_flag 78 | common_confirm_flag = True #必须要确认常规参数一次才允许写入toml 79 | common_parameter_dict = check_len_and_2dict(args, parameter_len_dict["common"], common_parameter_dict_key_list, func_name="common_parameter_get") 80 | common_parameter_toml = toml.dumps(common_parameter_dict) 81 | common_parameter_title = "基础参数配置确认" 82 | return common_parameter_toml, common_parameter_title 83 | 84 | def sample_parameter_get(*args): 85 | global sample_parameter_dict 86 | sample_parameter_dict = check_len_and_2dict(args, parameter_len_dict["sample"], sample_parameter_dict_key_list, 
func_name="sample_parameter_get") 87 | sample_parameter_toml = toml.dumps(sample_parameter_dict) 88 | sample_parameter_title = "采样配置确认" 89 | return sample_parameter_toml, sample_parameter_title 90 | 91 | 92 | def plus_parameter_get(*args): 93 | global plus_parameter_dict 94 | plus_parameter_dict = check_len_and_2dict(args, parameter_len_dict["plus"], plus_parameter_dict_key_list, func_name="plus_parameter_get") 95 | plus_parameter_toml = toml.dumps(plus_parameter_dict) 96 | plus_parameter_title = "进阶参数配置确认" 97 | return plus_parameter_toml, plus_parameter_title 98 | 99 | 100 | """ 调用上面上个函数来确认全部参数值,并赋值给三个全局字典parameter_dict """ 101 | def all_parameter_get(*args): 102 | if len(args) != sum( parameter_len_dict.values() ): 103 | warnings.warn("传入all_parameter_get的参数长度不匹配", UserWarning) 104 | """ 通过parameter_len_dict字典中记录的各个选项卡输入组件数量来分配传入三个子函数的参数 """ 105 | common_parameter_toml, common_parameter_title = common_parameter_get( *args[ : parameter_len_dict["common"] ] ) 106 | sample_parameter_toml, sample_parameter_title = sample_parameter_get( *args[ parameter_len_dict["common"] : parameter_len_dict["common"] + parameter_len_dict["sample"] ] ) 107 | plus_parameter_toml, plus_parameter_title = plus_parameter_get( *args[ -parameter_len_dict["plus"] : ] ) 108 | return common_parameter_toml, sample_parameter_toml, plus_parameter_toml, "全部参数确认" 109 | 110 | 111 | def save_webui_config(save_webui_config_dir, save_webui_config_name, write_files_dir): 112 | """ 保存当前已经确认(在三个参数字典中,而不是webui中组件值)的配置参数 """ 113 | 114 | if not common_confirm_flag: 115 | return "必须要确认常规参数一次才允许保存!" 116 | 117 | os.makedirs(save_webui_config_dir, exist_ok=True) 118 | 119 | other = {"write_files_dir":write_files_dir} 120 | param = {**common_parameter_dict, **sample_parameter_dict, **plus_parameter_dict} 121 | dict = { "other":other, "param":param } 122 | 123 | save_webui_config_path = os.path.join(save_webui_config_dir, save_webui_config_name) 124 | with open(save_webui_config_path, "w", encoding="utf-8") as f: 125 | webui_config_str = toml.dumps( dict ) 126 | f.write(webui_config_str) 127 | return f"保存webui配置成功,文件在{save_webui_config_path}" 128 | 129 | def read_webui_config_get(read_webui_config_dir): 130 | """ 读取目录下以.toml结尾的文件,返回一个读取到的文件list来更新gradio组件 """ 131 | try: 132 | files = [f for f in os.listdir(read_webui_config_dir) if os.path.isfile(os.path.join(read_webui_config_dir, f)) and f.endswith(".toml") ] 133 | if files: 134 | return gr.update( choices=files,value=files[0] ) 135 | else: 136 | return gr.update( choices=[],value="没有找到webui配置文件" ) 137 | except Exception as e: 138 | return gr.update( choices=[], value=f"错误的文件夹路径:{e}" ) 139 | 140 | def read_webui_config(read_webui_config_dir, read_webui_config_name, write_files_dir, *args): 141 | """ 读取预先保存的config文件,来更新webui界面(注意,不更新参数字典,需要用户手动确认) """ 142 | 143 | dir_change_flag = False #读取完保存的config文件后会修改写入文件夹这一组件,这里会判断是否修改,如修改给webui一个提示 144 | #检查传入的更新组件数量是否和webui中一致 145 | param_len = sum( parameter_len_dict.values() ) 146 | if len(args) != param_len: 147 | warnings.warn("传入read_webui_config的*args长度不匹配", UserWarning) 148 | 149 | read_webui_config_path = os.path.join(read_webui_config_dir, read_webui_config_name) 150 | #能打开就正常操作 151 | try: 152 | with open(read_webui_config_path, "r", encoding="utf-8") as f: 153 | config_dict = toml.loads( f.read() ) 154 | 155 | #能读到["other"].["write_files_dir"]就改,读不到就用原写入地址 156 | try: 157 | if config_dict["other"]["write_files_dir"] != write_files_dir: 158 | write_files_dir = config_dict["other"]["write_files_dir"] 159 | dir_change_flag = True 160 | 
except KeyError: 161 | pass 162 | 163 | param_dict_key_list = list( config_dict.get("param",{}).keys() ) 164 | #找出共有的key进行赋值,非共有的报错 165 | both_key = set(all_parameter_dict_key_list) & set(param_dict_key_list) 166 | parameter_unique_key = set(all_parameter_dict_key_list) - set(both_key) 167 | config_unique_key = set(param_dict_key_list) - set(both_key) 168 | 169 | 170 | 171 | #对共有的组件进行赋值 172 | count = 0 173 | if both_key: 174 | args = list(args) 175 | for key in both_key: 176 | index = all_parameter_dict_key_list.index(key) 177 | args[ index ] = config_dict["param"][key] 178 | count += 1 179 | 180 | def get_files_name_list(files_dir:str) -> list: 181 | """ 182 | 读取dir路径下的全部文件,返回一个字符串列表 183 | 用于更新base_model_name和vae_model_name组件 184 | """ 185 | try: 186 | #尝试读取 187 | files = [ f for f in os.listdir(files_dir) if os.path.isfile(os.path.join(files_dir, f)) ] 188 | #读取到了就返回 189 | if files: 190 | return files 191 | #读取不到就返回空字符串列表 192 | else: 193 | return [""] 194 | except Exception: 195 | #读取不到就返回空字符串列表 196 | return [""] 197 | 198 | def update_gr_model_list(model_dir_gr_name:str, model_name_gr_name:str): 199 | #如果保存的webui组件config文件中有模型路径的记录 200 | if model_dir_gr_name in both_key: 201 | #就尝试读取该路径下的所有文件 202 | model_name_list = get_files_name_list( config_dict["param"][model_dir_gr_name] ) 203 | #如果model_name在这个列表中,就保持;如果不在,就为列表第一个元素 204 | model_name = config_dict.get("param",{}).get(model_name_gr_name,"") 205 | model_name = model_name if model_name in model_name_list else model_name_list[0] 206 | #找到该gr组件在args中的索引 207 | index = all_parameter_dict_key_list.index(model_name_gr_name) 208 | args[ index ] = gr.update( choices=model_name_list, value=model_name ) 209 | 210 | update_gr_model_list("base_model_dir", "base_model_name") 211 | update_gr_model_list("vae_model_dir", "vae_model_name") 212 | 213 | args = tuple(args) 214 | 215 | 216 | read_done = f"\n读取完成,WebUI中共有{param_len}项参数,更新了其中{count}项\n" + (f"写入文件夹发生改变:{write_files_dir}" if dir_change_flag else "") 217 | config_warning = f"\nwebui-config文件中以下参数可能已经失效或错误:\n{config_unique_key}\n" if config_unique_key else "" 218 | parameter_warning = f"\nWebUI中以下参数在webui-config文件中未找到,不发生修改:\n{parameter_unique_key}\n" if parameter_unique_key else "" 219 | read_str = f"{read_done}{config_warning}{parameter_warning}" 220 | return read_str, write_files_dir, *args 221 | 222 | #打不开就返回原值 223 | except FileNotFoundError: 224 | return "文件或目录不存在", write_files_dir, *args 225 | except PermissionError: 226 | return "没有权限访问文件或目录", write_files_dir, *args 227 | except OSError as e: 228 | return f"something wrong:{e}", write_files_dir, *args 229 | 230 | 231 | 232 | def model_get(model_dir): 233 | """ 读取文件夹目录下的所有文件 """ 234 | try: 235 | #files = [f for f in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, f)) and f.endswith(("ckpt", "pt", "safetensors")) ] 236 | files = [f for f in os.listdir(model_dir) if os.path.isfile(os.path.join(model_dir, f))] 237 | if files: 238 | return gr.update( choices=files,value=files[0] ) 239 | else: 240 | return gr.update( choices=[],value="没有找到模型" ) 241 | except Exception as e: 242 | return gr.update( choices=[], value=f"错误的文件夹路径:{e}" ) 243 | 244 | 245 | def write_files(write_files_dir): 246 | """ 用参数字典生成为kohya训练脚本要求的toml格式文件 """ 247 | 248 | if not common_confirm_flag: 249 | return "必须要确认常规参数一次才允许写入toml" 250 | 251 | write_files_dir = write_files_dir if write_files_dir else os.path.join(ROOT_DIR, "kohya_config") 252 | os.makedirs(write_files_dir, exist_ok=True) 253 | config_file_toml_path = os.path.join(write_files_dir, 
"config_file.toml") 254 | sample_prompts_txt_path = os.path.join(write_files_dir, "sample_prompts.txt") 255 | 256 | all = {**common_parameter_dict, **sample_parameter_dict, **plus_parameter_dict} 257 | 258 | def parameter2toml(): 259 | 260 | #生成config_file.toml的字典 261 | 262 | #model_arguments部分 263 | model_arguments = { key: all.get(key) for key in ["v2", "v_parameterization"] } 264 | """ 生成底模路径 """ 265 | base_model_path = os.path.join( all.get("base_model_dir"), all.get("base_model_name") ) 266 | model_arguments.update( {"pretrained_model_name_or_path": base_model_path} ) 267 | """ 生成vae路径 """ 268 | if all.get("use_vae"): 269 | vae_model_path = os.path.join( all.get("vae_model_dir"), all.get("vae_model_name") ) 270 | model_arguments.update( {"vae": vae_model_path} ) 271 | 272 | #additional_network_arguments部分 273 | additional_network_arguments = { key: all.get(key) for key in ["unet_lr", "text_encoder_lr", "network_dim",\ 274 | "network_alpha", "network_train_unet_only",\ 275 | "network_train_text_encoder_only"] } 276 | """ 生成如network_module = "locon.locon_kohya" """ 277 | #主要负责network_module的参数生成 278 | 279 | kohya_list = ["LoRA-LierLa", "LoRA-C3Lier"] 280 | lycoris_list = ["LoCon_Lycoris", "LoHa_Lycoris", "IA3_Lycoris", "LoKR_Lycoris"] 281 | dylora_list = ["DyLoRa-LierLa", "DyLoRa-C3Lier"] 282 | 283 | algo_list = ["lora", "loha", "ia3", "lokr"] 284 | 285 | def network_module_param(train_method): 286 | 287 | #train_method可能值如下 288 | #["LoRA-LierLa", "LoRA-C3Lier", "LoCon_Lycoris", "LoHa_Lycoris", "IA3_Lycoris", "LoKR_Lycoris", "DyLoRa-LierLa", "DyLoRa-C3Lier"] 289 | 290 | #卷积DyLoRa专门指定dim相同 291 | conv_dim = all.get("conv_dim") if train_method != "DyLoRa-C3Lier" else all.get("network_dim") 292 | conv_alpha = all.get("conv_alpha") 293 | 294 | unit = all.get("unit") 295 | 296 | #kohya网络 297 | if train_method in kohya_list: 298 | network_module = "networks.lora" 299 | if train_method == "LoRA-C3Lier": 300 | network_module_args = [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}"] 301 | else: 302 | network_module_args = [] 303 | #lycoris网络 304 | elif train_method in lycoris_list: 305 | algo = algo_list[ lycoris_list.index(train_method) ] 306 | network_module = "lycoris.kohya" 307 | network_module_args = [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}", f"algo={algo}"] 308 | #dylora网络 309 | elif train_method in dylora_list: 310 | network_module = "networks.dylora" 311 | if train_method == "DyLoRa-C3Lier": 312 | network_module_args = [f"conv_dim={conv_dim}", f"conv_alpha={conv_alpha}", f"unit={unit}"] 313 | else: 314 | network_module_args = [f"unit={unit}"] 315 | else: 316 | warnings.warn("训练方法参数生成出错", UserWarning) 317 | 318 | return network_module, network_module_args 319 | 320 | network_module, network_module_args = network_module_param( all.get("train_method") ) 321 | #更多network_args部分(主要为分层训练) 322 | network_lr_weight_args = [ f"{name}={all.get(name)}" for name in ["down_lr_weight", "mid_lr_weight", "up_lr_weight"] if all.get(name) ] 323 | 324 | def network_block_param(train_method): 325 | #dylora不允许分层 326 | lst = ["block_dims", "block_alphas", "conv_block_dims", "conv_block_alphas"] 327 | if train_method == "LoRA-LierLa": 328 | return [ f"{name}={all.get(name)}" for name in lst[0:1] if all.get(name) ] 329 | if train_method in ["LoRA-C3Lier"] + lycoris_list: 330 | return [ f"{name}={all.get(name)}" for name in lst if all.get(name) ] 331 | else: 332 | return [] 333 | 334 | network_block_args = network_block_param( all.get("train_method") ) 335 | 336 | #合成网络参数、分层参数 337 | network_args = [] 
338 | network_args.extend(network_module_args) 339 | network_args.extend(network_lr_weight_args) 340 | network_args.extend(network_block_args) 341 | 342 | additional_network_arguments.update( { "network_module":network_module } ) 343 | additional_network_arguments.update( {"network_args":network_args} ) 344 | 345 | 346 | #optimizer_arguments部分 347 | optimizer_arguments = { key: all.get(key) for key in ["optimizer_type", "lr_scheduler", "lr_warmup_steps"] } 348 | """只有余弦重启调度器指定重启次数""" 349 | if all.get("lr_scheduler") == "cosine_with_restarts": 350 | optimizer_arguments.update( {"lr_restart_cycles":all.get("lr_restart_cycles")} ) 351 | """学习率lr指定=unet_lr""" 352 | optimizer_arguments.update( {"learning_rate":all.get("unet_lr")} ) 353 | """ optimizer_args """ 354 | optimizer_args_str = all.get("optimizer_args", "") 355 | if optimizer_args_str: 356 | optimizer_arguments.update( { "optimizer_args": [ x.strip() for x in optimizer_args_str.rstrip(", ").split(",") ] } ) 357 | 358 | 359 | #dataset_arguments部分 360 | dataset_arguments = { key: all.get(key) for key in ["cache_latents", "cache_latents_to_disk", "shuffle_caption",\ 361 | "enable_bucket", "weighted_captions"] } 362 | 363 | 364 | #training_arguments部分 365 | training_arguments = { key: all.get(key) for key in ["train_batch_size", "noise_offset", "keep_tokens",\ 366 | "min_bucket_reso", "max_bucket_reso",\ 367 | "caption_extension", "max_token_length", "seed",\ 368 | "xformers", "lowram",\ 369 | "gradient_checkpointing","gradient_accumulation_steps"] 370 | } 371 | """min_snr_gamma大于零才生效""" 372 | if all.get("min_snr_gamma") > 0: 373 | training_arguments.update( { "min_snr_gamma":all.get("min_snr_gamma") } ) 374 | """ 最大训练时间 """ 375 | training_arguments.update( { all.get("max_train_method"):all.get("max_train_value") } ) 376 | """ 训练分辨率 """ 377 | training_arguments.update( { "resolution":f"{all.get('width')},{all.get('height')}" } ) 378 | """ 如果v2开启,则不指定clip_skip """ 379 | if not all.get("v2"): 380 | training_arguments.update( { "clip_skip":all.get("clip_skip") } ) 381 | """ 重训练模块 """ 382 | if all.get("use_retrain") == "model": 383 | training_arguments.update( { "network_weights":all.get("retrain_dir") } ) 384 | elif all.get("use_retrain") == "state": 385 | training_arguments.update( { "resume":all.get("retrain_dir") } ) 386 | """ 训练精度、保存精度 """ 387 | training_arguments.update( { "mixed_precision":"fp16" } ) 388 | training_arguments.update( { "save_precision":"fp16" } ) 389 | 390 | 391 | 392 | #sample_prompt_arguments部分(采样间隔,采样文件地址待添加) 393 | sample_prompt_arguments = { key: all.get(key) for key in ["sample_sampler"] } 394 | if all.get("sample_every_n_type"): #如果采样部分没确认过一次,会出现all.get("sample_every_n_type")=None:None的字典造成报错 395 | sample_prompt_arguments.update( {all.get("sample_every_n_type"):all.get("sample_every_n_type_value")} ) 396 | 397 | 398 | #dreambooth_arguments部分 399 | def creat_dreambooth_arguments_list(use_reg:bool) -> list: 400 | if use_reg: 401 | return ["train_data_dir", "reg_data_dir", "prior_loss_weight"] 402 | else: 403 | return ["train_data_dir"] 404 | 405 | dreambooth_arguments = { key: all.get(key) for key in creat_dreambooth_arguments_list( all.get("use_reg") ) } 406 | 407 | 408 | #saving_arguments部分 409 | saving_arguments = { key: all.get(key) for key in ["output_name", "save_every_n_epochs", "save_n_epoch_ratio",\ 410 | "save_last_n_epochs", "save_state", "save_model_as" ] 411 | } 412 | """在输出文件夹output_dir后面加上output_name""" 413 | output_dir = os.path.join( all.get("output_dir"), all.get("output_name") ) 414 | 
saving_arguments.update( {"output_dir":output_dir } ) 415 | """ 指定log输出目录与output相同 """ 416 | saving_arguments.update( { "logging_dir":os.path.join( output_dir, "logs" ) } ) 417 | """ 指定log前缀和输出名字相同 """ 418 | saving_arguments.update( { "log_prefix":all.get("output_name") } ) 419 | """ 启用wandb""" 420 | #决定log记录方式 421 | if all.get("use_wandb"): 422 | saving_arguments.update( { "log_with":"all" } ) 423 | else: 424 | saving_arguments.update( { "log_with":"tensorboard" } ) 425 | #api_key和log_tracker_name的指定 426 | if all.get("wandb_api_key"): 427 | saving_arguments.update( { "wandb_api_key":all.get("wandb_api_key") } ) 428 | if all.get("log_tracker_name"): 429 | saving_arguments.update( { "log_tracker_name":all.get("log_tracker_name") } ) 430 | 431 | 432 | #self_arguments部分 433 | try: 434 | self_arguments = toml.loads( all.get("self_arguments") ) 435 | except Exception: 436 | self_arguments = {} 437 | 438 | 439 | ##合成总字典 440 | toml_dict = {"model_arguments":model_arguments, 441 | "additional_network_arguments":additional_network_arguments, 442 | "optimizer_arguments":optimizer_arguments, 443 | "dataset_arguments":dataset_arguments, 444 | "training_arguments":training_arguments, 445 | "sample_prompt_arguments":sample_prompt_arguments, 446 | "dreambooth_arguments":dreambooth_arguments, 447 | "saving_arguments":saving_arguments, 448 | "self_arguments":self_arguments, 449 | } 450 | toml_str = toml.dumps(toml_dict) 451 | return toml_str 452 | 453 | def sample_parameter2txt(): 454 | #key_list = ["prompt", "negative", "sample_width", "sample_height", "sample_scale", "sample_steps", "sample_seed"] 455 | 456 | #如果采样部分没确认过,这个值=None;或者没写任何prompt;就直接退出,返回一个空字符串 457 | if not all.get("prompt"): 458 | return "" 459 | #允许分行 460 | prompt = all.get("prompt").replace("\n", "") 461 | negative = all.get("negative").replace("\n", "") 462 | #生成采样文件str 463 | sample_str = f"""{prompt} \ 464 | --n {negative} \ 465 | --w {all.get("sample_width")} \ 466 | --h {all.get("sample_height")} \ 467 | --l {all.get("sample_scale")} \ 468 | --s {all.get("sample_steps")} \ 469 | {f"--d {all.get('sample_seed')}" if all.get('sample_seed') > 0 else ""}""" 470 | return sample_str 471 | 472 | def write(content, path): 473 | with open(path, "w", encoding="utf-8") as f: 474 | f.write(content) 475 | 476 | write(parameter2toml(), config_file_toml_path) 477 | write(sample_parameter2txt(), sample_prompts_txt_path) 478 | write_files_title = f"写入成功, 训练配置文件在{config_file_toml_path} , 采样参数文件在{sample_prompts_txt_path}" 479 | return write_files_title 480 | 481 | #@title WebUI部分 482 | def create_demo(parser_input:list=[]): 483 | 484 | #用命令行参数指定默认webui组件config保存、读写路径和名字 485 | parser = argparse.ArgumentParser() 486 | 487 | DEFAULT_SAVE_AND_READ_DIR = os.path.join(ROOT_DIR, "kohya_config_webui_save") 488 | 489 | parser.add_argument("--save_dir", type=str, default=DEFAULT_SAVE_AND_READ_DIR, help="webui组件config默认保存路径") 490 | parser.add_argument("--save_name", type=str, default="kohya_config_webui_save.toml", help="webui组件config默认保存名字") 491 | parser.add_argument("--read_dir", type=str, default=DEFAULT_SAVE_AND_READ_DIR, help="webui组件config默认读取路径") 492 | #parser.add_argument("--read_name", type=str, default="kohya_config_webui_save.toml", help="webui组件config默认读取名字") 493 | 494 | #如果直接调用就获取命令行参数 495 | if __name__ == "__main__": 496 | cmd_param, unknown = parser.parse_known_args() 497 | #如果是被导入,就由create_demo(parser_input:list=[])来指定参数 498 | else: 499 | cmd_param, unknown = parser.parse_known_args(parser_input) 500 | 501 | #图标常量 502 | """ 503 | random_symbol = 
'\U0001f3b2\ufe0f' # 🎲️ 504 | reuse_symbol = '\u267b\ufe0f' # ♻️ 505 | paste_symbol = '\u2199\ufe0f' # ↙ 506 | save_style_symbol = '\U0001f4be' # 💾 507 | apply_style_symbol = '\U0001f4cb' # 📋 508 | clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️ 509 | extra_networks_symbol = '\U0001F3B4' # 🎴 510 | switch_values_symbol = '\U000021C5' # ⇅ 511 | folder_symbol = '\U0001f4c2' # 📂 512 | """ 513 | refresh_symbol = '\U0001f504' # 🔄 514 | 515 | #全局变量 516 | global common_parameter_dict_key_list 517 | global sample_parameter_dict_key_list 518 | global plus_parameter_dict_key_list 519 | global all_parameter_dict_key_list 520 | global parameter_len_dict 521 | 522 | #用于储存组件变量,方便向botton.click函数传递 523 | common_gr_dict = {} 524 | sample_gr_dict = {} 525 | plus_gr_dict = {} 526 | all_gr_dict ={} #后面会有一次把三个字典合起来 527 | 528 | def init_gr_read_name(dir:str) -> list: 529 | """ 530 | 读取dir路径下以.toml的文件,返回一个字符串列表 531 | 用于初始化read_webui_config_name组件 532 | """ 533 | try: 534 | #尝试读取 535 | files = [ f for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f)) and f.endswith(".toml") ] 536 | #读取到了就返回 537 | if files: 538 | return files 539 | #读取不到就返回空字符串列表 540 | else: 541 | return [""] 542 | except Exception: 543 | #读取不到就返回空字符串列表 544 | return [""] 545 | 546 | with gr.Blocks() as demo: 547 | gr.Markdown("更新时间2023年4月26,如果新参数用不了,请保证kohya版本在更新时间后") 548 | with gr.Accordion("保存、读取\nwebui配置", open=False): 549 | save_read_webui_config_title = gr.Markdown("保存或读取") 550 | with gr.Row(): 551 | save_webui_config_button = gr.Button("保存(记得先确认!)") 552 | with gr.Row(): 553 | save_webui_config_dir = gr.Textbox(lines=1, label="保存目录", value=cmd_param.save_dir ) 554 | save_webui_config_name = gr.Textbox(lines=1, label="保存名字(以toml为扩展名,否则不会被读取)", value=cmd_param.save_name ) 555 | with gr.Row(): 556 | read_webui_config_get_button = gr.Button(refresh_symbol) 557 | read_webui_config_button = gr.Button("读取") 558 | with gr.Row(): 559 | read_webui_config_dir = gr.Textbox(lines=1, label="读取目录", value=cmd_param.read_dir ) 560 | read_webui_config_name = gr.Dropdown(choices=init_gr_read_name(cmd_param.read_dir), 561 | label="读取文件", 562 | value=init_gr_read_name(cmd_param.read_dir)[0] ) 563 | with gr.Row(): 564 | write_files_button = gr.Button("生成toml参数与采样配置文件") 565 | all_parameter_get_button = gr.Button("全部参数确认") 566 | write_files_dir = gr.Textbox( lines=1, label="写入文件夹", placeholder="一般填kohya_script目录,留空就默认根目录下的kohya_config文件夹", value="" ) 567 | write_files_title = gr.Markdown("生成适用于kohya/train_network.py的配置文件") 568 | with gr.Tabs(): 569 | with gr.TabItem("基础参数"): 570 | common_parameter_get_button = gr.Button("确定") 571 | common_parameter_title = gr.Markdown("") 572 | with gr.Accordion("当前基础参数配置", open=False): 573 | common_parameter_toml = gr.Textbox(label="toml形式", placeholder="基础参数", value="") 574 | with gr.Row(): 575 | common_gr_dict["train_data_dir"] = gr.Textbox(lines=1, label="train_data_dir", placeholder="训练集路径", value="") 576 | with gr.Accordion("使用正则化(可选)", open=False): 577 | with gr.Row(): 578 | common_gr_dict["use_reg"] = gr.Checkbox(label="是否使用正则化",value=False) 579 | with gr.Row(): 580 | common_gr_dict["reg_data_dir"] = gr.Textbox(lines=1, label="reg_data_dir", placeholder="正则化集路径(开启才有效)", value="") 581 | common_gr_dict["prior_loss_weight"] = gr.Slider(0, 1, step=0.01, value=0.3, label="正则化权重") 582 | with gr.Row(): 583 | common_gr_dict["base_model_dir"] = gr.Textbox(label="底模文件夹地址", placeholder="文件夹路径", value="") 584 | common_gr_dict["base_model_name"] = gr.Dropdown(choices=[],label="底模",value="") 585 | base_model_get_button = 
gr.Button(refresh_symbol) 586 | with gr.Accordion("使用vae(可选)", open=False): 587 | with gr.Row(): 588 | common_gr_dict["use_vae"] = gr.Checkbox(label="是否使用vae",value=False) 589 | with gr.Row(): 590 | common_gr_dict["vae_model_dir"] = gr.Textbox(label="vae文件夹地址", placeholder="文件夹路径", value="") 591 | common_gr_dict["vae_model_name"] = gr.Dropdown(choices=[],label="vae", value="") 592 | vae_model_get_button = gr.Button(refresh_symbol) 593 | with gr.Row(): 594 | common_gr_dict["width"]= gr.Slider(64, 1920, step=64, value=512, label="训练分辨率(宽)width") 595 | common_gr_dict["height"] = gr.Slider(64, 1920, step=64, value=512, label="训练分辨率(高)height") 596 | common_gr_dict["train_batch_size"] = gr.Slider(1, 24, step=1, value=1, label="batch大小") 597 | with gr.Row(): 598 | common_gr_dict["noise_offset"] = gr.Slider(0, 1, step=0.01, value=0.05, label="noise_offset") 599 | common_gr_dict["keep_tokens"] = gr.Slider(0, 225, step=1, value=0, label="keep_tokens") 600 | common_gr_dict["min_snr_gamma"] = gr.Slider(0, 100, step=0.1, value=5, label="min_snr_gamma(设置为0则不生效)") 601 | """ 602 | with gr.Row(): 603 | gr.Markdown("repeat * 图片数 = 每个epoch的steps数") 604 | """ 605 | with gr.Row(): 606 | common_gr_dict["max_train_method"] = gr.Dropdown(["max_train_epochs","max_train_steps"], label="以epochs或steps来指定最大训练时间", value="max_train_epochs") 607 | common_gr_dict["max_train_value"] = gr.Number(label="最大训练epochs\steps数", value=10, precision=0) 608 | with gr.Accordion("输出设置", open=True): 609 | with gr.Row(): 610 | common_gr_dict["output_dir"] = gr.Textbox( label="模型、log日志输出地址(自行修改)", placeholder="文件夹路径",value=os.path.join(ROOT_DIR,"output") ) 611 | common_gr_dict["output_name"] = gr.Textbox(label="输出模型名称(自行修改)", placeholder="名称",value="output_name") 612 | common_gr_dict["save_model_as"] = gr.Dropdown(["safetensors","ckpt","pt"], label="保存模型格式", value="safetensors") 613 | with gr.Row(): 614 | common_gr_dict["save_every_n_epochs"] = gr.Slider(1, 499, step=1, value=1, label="每n个epoch保存一次") 615 | common_gr_dict["save_n_epoch_ratio"] = gr.Slider(1, 499, step=1, value=0, label="等间隔保存n个(如不为0,会覆盖每n个epoch保存一次)") 616 | common_gr_dict["save_last_n_epochs"] = gr.Slider(1, 499, step=1, value=499, label="最多保存n个(后面的出来就会把前面删了,优先级最高)") 617 | with gr.Row(): 618 | common_gr_dict["save_state"] = gr.Checkbox(label="保存学习状态",value=False) 619 | with gr.Accordion(" * 启用远程记录", open=False): 620 | with gr.Row(): 621 | gr.Markdown( "[你可以在这里找到api_key](https://wandb.ai/authorize)") 622 | with gr.Row(): 623 | common_gr_dict["use_wandb"] = gr.Checkbox(label="是否使用wandb远程记录", value= False) 624 | common_gr_dict["wandb_api_key"] = gr.Textbox(label="wandb_api_key", placeholder="第一次使用,或者需要切换新API的时候,请填入", value="") 625 | common_gr_dict["log_tracker_name"] = gr.Textbox(label="log_tracker_name项目名称", placeholder="留空则指定为network_train",value="") 626 | with gr.Row(): 627 | common_gr_dict["optimizer_type"] = gr.Dropdown(["AdamW8bit", "Lion", "DAdaptation", "AdamW", "SGDNesterov", "SGDNesterov8bit", "AdaFactor"],\ 628 | label="optimizer_type优化器类型(DA优化器两个学习率要一样)", value="AdamW8bit") 629 | common_gr_dict["unet_lr"] = gr.Number(label="unet学习率", value=1e-4) 630 | common_gr_dict["text_encoder_lr"] = gr.Number(label="text_encoder学习率", value=1e-5) 631 | with gr.Accordion("optimizer_args优化器参数(不会就不要填)", open=False): 632 | common_gr_dict["optimizer_args"] = gr.Textbox(label="optimizer_args", placeholder="如果你要填,就这样: decouple=True, weight_decay=0.5", value="") 633 | 634 | 635 | """ 下拉框触发 """ 636 | def __optimizer_arg(optimizer_type:str, optimizer_args:str) -> str: 637 | if 
optimizer_type == "DAdaptation": 638 | #return "decouple=True, weight_decay=0.5" 639 | return "decouple=True" 640 | else: 641 | return optimizer_args 642 | optimizer_change_inputs = [ common_gr_dict["optimizer_type"], common_gr_dict["optimizer_args"] ] 643 | common_gr_dict["optimizer_type"].change(fn=__optimizer_arg, inputs=optimizer_change_inputs, outputs=common_gr_dict["optimizer_args"] ) 644 | 645 | 646 | with gr.Row(): 647 | common_gr_dict["lr_scheduler"] = gr.Dropdown(["cosine_with_restarts","cosine","polynomial","linear","constant_with_warmup","constant"],\ 648 | label="lr_scheduler学习率调度器", value="cosine_with_restarts") 649 | common_gr_dict["lr_warmup_steps"] = gr.Number(label="升温步数", value=0, precision=0) 650 | common_gr_dict["lr_restart_cycles"] = gr.Number(label="退火重启次数", value=1, precision=0) 651 | with gr.Row(): 652 | common_gr_dict["train_method"] = gr.Dropdown(["LoRA-LierLa", "LoRA-C3Lier",\ 653 | "LoCon_Lycoris", "LoHa_Lycoris",\ 654 | "IA3_Lycoris", "LoKR_Lycoris",\ 655 | "DyLoRa-LierLa", "DyLoRa-C3Lier"],\ 656 | label="train_method训练方法", value="LoRA-LierLa") 657 | common_gr_dict["network_dim"] = gr.Number(label="线性dim", value=32, precision=0) 658 | common_gr_dict["network_alpha"] = gr.Number(label="线性alpha(可以为小数)", value=16) 659 | with gr.Accordion("额外网络参数(LoRA-C3Lier、LoCon、LoHa、DyLoRa-C3Lier都属于卷积,unit为DyLoRa专用)", open=True): 660 | with gr.Row(): 661 | with gr.Column(): 662 | common_gr_dict["conv_dim"] = gr.Number(label="卷积dim", info="使用DyLoRa-C3Lier时会被设置为等于基础dim", value=8, precision=0) 663 | with gr.Column(): 664 | common_gr_dict["conv_alpha"] = gr.Number(label="卷积alpha", info="可以为小数", value=1) 665 | with gr.Column(): 666 | common_gr_dict["unit"] = gr.Number(label="分割单位unit(整数)", info="使用DyLoRa时,请让dim为unit的倍数", value=4, precision=0) 667 | with gr.Row(): 668 | common_gr_dict["v2"] = gr.Checkbox(label="v2") 669 | common_gr_dict["v_parameterization"] = gr.Checkbox(label="v_parameterization") 670 | common_gr_dict["lowram"] = gr.Checkbox(label="lowram") 671 | common_gr_dict["xformers"] = gr.Checkbox(label="xformers",value=True) 672 | common_gr_dict["shuffle_caption"] = gr.Checkbox(label="shuffle_caption",value=True) 673 | common_gr_dict["enable_bucket"] = gr.Checkbox(label="enable_bucket",value=True) 674 | with gr.Row(): 675 | common_gr_dict["cache_latents"] = gr.Checkbox(label="cache_latents", value=True, info="缓存latents,注意开启这个功能时,训练过程中图像不能发生改变(比如增强AUG)") 676 | common_gr_dict["cache_latents_to_disk"] = gr.Checkbox(label="cache_latents_to_disk", value=False, info="缓存latents到硬盘,下次直接载入(需要把cache_latents启用,缓存后不能改变图像)") 677 | with gr.TabItem("采样参数"): 678 | sample_parameter_get_button = gr.Button("确定") 679 | sample_parameter_title = gr.Markdown("") 680 | with gr.Accordion("当前采样配置", open=False): 681 | sample_parameter_toml = gr.Textbox(label="toml形式", placeholder="采样配置", value="") 682 | with gr.Row(): 683 | #enable_sample = gr.Checkbox(label="是否启用采样功能") 684 | sample_gr_dict["sample_every_n_type"] = gr.Dropdown(["sample_every_n_epochs", "sample_every_n_steps"], label="sample_every_n_type", value="sample_every_n_epochs") 685 | sample_gr_dict["sample_every_n_type_value"] = gr.Number(label="sample_every_n_type_value", value=1, precision=0) 686 | with gr.Row(): 687 | sample_gr_dict["sample_sampler"] = gr.Dropdown(["ddim", "pndm", "lms", "euler", "euler_a", "heun",\ 688 | "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle",\ 689 | "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"],\ 690 | label="采样器", value="euler_a") 691 | sample_gr_dict["sample_seed"] = 
gr.Number(label="采样种子(-1不是随机,大于0才生效)", value=-1, precision=0) 692 | with gr.Row(): 693 | sample_gr_dict["sample_width"] = gr.Slider(64, 1920, step=64, value=512, label="采样图片宽") 694 | sample_gr_dict["sample_height"] = gr.Slider(64, 1920, step=64, value=768, label="采样图片高") 695 | sample_gr_dict["sample_scale"] = gr.Slider(1, 30, step=0.5, value=7, label="提示词相关性") 696 | sample_gr_dict["sample_steps"] = gr.Slider(1, 150, step=1, value=24, label="采样迭代步数") 697 | with gr.Row(): 698 | sample_gr_dict["prompt"] = gr.Textbox(lines=10, label="prompt", placeholder="正面提示词", value="(masterpiece, best quality, hires:1.2), 1girl, solo,") 699 | default_negative_str = ("(worst quality, bad quality:1.4), " 700 | "lowres, bad anatomy, bad hands, text, error, " 701 | "missing fingers, extra digit, fewer digits, " 702 | "cropped, worst quality, low quality, normal quality, " 703 | "jpeg artifacts,signature, watermark, username, blurry,") 704 | sample_gr_dict["negative"] = gr.Textbox(lines=10, label="negative", placeholder="负面提示词", value=default_negative_str) 705 | with gr.TabItem("进阶参数"): 706 | plus_parameter_get_button = gr.Button("确定") 707 | plus_parameter_title = gr.Markdown("") 708 | with gr.Accordion("当前进阶参数配置", open=False): 709 | plus_parameter_toml = gr.Textbox(label="toml形式", placeholder="进阶参数", value="") 710 | with gr.Row(): 711 | plus_gr_dict["use_retrain"] = gr.Dropdown(["no","model","state"], label="是否使用重训练", value="no") 712 | plus_gr_dict["retrain_dir"] = gr.Textbox(lines=1, label="重训练路径", placeholder="模型或者状态路径", value="") 713 | with gr.Row(): 714 | plus_gr_dict["weighted_captions"] = gr.Checkbox(label="开启权重标",value=False,info="你开启了,最大token就能75个") 715 | plus_gr_dict["min_bucket_reso"] = gr.Slider(64, 1920, step=64, value=256, label="最低桶分辨率") 716 | plus_gr_dict["max_bucket_reso"] = gr.Slider(64, 1920, step=64, value=1024, label="最高桶分辨率") 717 | plus_gr_dict["clip_skip"] = gr.Slider(0, 25, step=1, value=2, label="跳过层数") 718 | plus_gr_dict["max_token_length"] = gr.Slider(75, 225, step=75, value=225, label="训练最大token数") 719 | plus_gr_dict["caption_extension"] = gr.Textbox(lines=1, label="标签文件扩展名", placeholder="一般填.txt或.cap", value=".txt") 720 | plus_gr_dict["seed"] = gr.Number(label="种子", value=1337, precision=0) 721 | with gr.Row(): 722 | plus_gr_dict["network_train_unet_only"] = gr.Checkbox(label="仅训练unet网络",value=False) 723 | plus_gr_dict["network_train_text_encoder_only"] = gr.Checkbox(label="仅训练text_encoder网络",value=False) 724 | with gr.Row(): 725 | plus_gr_dict["gradient_checkpointing"] = gr.Checkbox(label="gradient_checkpointing", value=False, info="逐步计算权重(会使速度变慢,但可以用更大batch)") 726 | plus_gr_dict["gradient_accumulation_steps"]= gr.Number(label="梯度累积步数(累积n步更新一次权重)", value=1, precision=0, info="等效batch = batch * gradient_accumulation_steps") 727 | with gr.Accordion("分层学习模块", open=True): 728 | gr.Markdown("学习率分层,为不同层的结构指定不同学习率倍数; 如果某一层权重为0,那该层不会被创建") 729 | with gr.Row(): 730 | with gr.Column(scale=15): 731 | plus_gr_dict["down_lr_weight"] = gr.Textbox(lines=1, label="IN层学习率权重", placeholder="留空则不启用",\ 732 | info="12层,例如0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", value="") 733 | with gr.Column(scale=1): 734 | plus_gr_dict["mid_lr_weight"] = gr.Textbox(lines=1, label="MID层学习率权重", placeholder="留空则不启用",\ 735 | info="1层,例如2.0", value="") 736 | with gr.Column(scale=15): 737 | plus_gr_dict["up_lr_weight"] = gr.Textbox(lines=1, label="OUT层学习率权重", placeholder="留空则不启用",\ 738 | info="12层,例如1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5", value="") 739 | with gr.Accordion(" * 分层示例", open=False): 740 | 
gr.Examples(examples = [ ["MIDD", "0,0,0,0,0,1,1,1,1,1,1,1", "1", "1,1,1,1,1,1,1,0,0,0,0,0"],\ 741 | ["OUTALL","0,0,0,0,0,0,0,0,0,0,0,0", "0", "1,1,1,1,1,1,1,1,1,1,1,1"],\ 742 | ["OUTD", "0,0,0,0,0,0,0,0,0,0,0,0", "0", "1,1,1,1,1,1,1,0,0,0,0,0"] 743 | ], 744 | inputs = [ gr.Textbox(label="预设",visible=False), plus_gr_dict["down_lr_weight"], plus_gr_dict["mid_lr_weight"], plus_gr_dict["up_lr_weight"] ] 745 | ) 746 | gr.Markdown("dim和alpha分层,为不同层的结构指定不同的dim和alpha(`DyLoRa`无法使用,卷积分层只有`LoRa-C3Lier、LoCon、LoHa`可以使用)") 747 | with gr.Row(): 748 | plus_gr_dict["block_dims"] = gr.Textbox(lines=1, label="线性dim分层", placeholder="留空则不启用",\ 749 | info="25层(上中下),例如2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", value="") 750 | plus_gr_dict["block_alphas"] = gr.Textbox(lines=1, label="线性alpha分层", placeholder="留空则不启用",\ 751 | info="25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", value="") 752 | with gr.Row(): 753 | plus_gr_dict["conv_block_dims"] = gr.Textbox(lines=1, label="卷积dim分层", placeholder="留空则不启用",\ 754 | info="25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", value="") 755 | plus_gr_dict["conv_block_alphas"] = gr.Textbox(lines=1, label="卷积alpha分层", placeholder="留空则不启用",\ 756 | info="25层(上中下),例如2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", value="") 757 | with gr.Row(): 758 | with gr.Accordion("额外参数(toml格式)", open=False): 759 | check_self_arguments_title = gr.Markdown("") 760 | chek_self_arguments_botton = gr.Button("格式检查,如果你格式错了,到时候写入文件时整个自定义参数都无效") 761 | self_arguments_placeholder_str = "如果你要用,按下面那样的toml格式填,以行为分隔:" 762 | plus_gr_dict["self_arguments"] = gr.Textbox(lines=10, label="你设置的参数会写被到toml的最下面,如果出现与预设参数重名,理论上你的参数优先级高,但尽量别这么做,可能会有意想不到的后果",\ 763 | placeholder=f"{self_arguments_placeholder_str}\nmax_grad_norm = 1\noutput_config = true") 764 | #检查格式 765 | def check_self_arguments(self_arguments:str) -> str: 766 | try: 767 | toml.loads(self_arguments) 768 | return "格式正确" 769 | except Exception as e: 770 | return f"格式错误 error:{e}" 771 | chek_self_arguments_botton.click(fn=check_self_arguments, 772 | inputs=plus_gr_dict["self_arguments"], 773 | outputs=check_self_arguments_title 774 | ) 775 | 776 | 777 | all_gr_dict = {**common_gr_dict, **sample_gr_dict, **plus_gr_dict} 778 | 779 | def dict_key_list_2_list(dict_key_list:list, gr_dict:dict): 780 | """ 输入一个指定key顺序的字符串list,和一个gr组件变量的字典""" 781 | """ 将gr_dict中键名与list中字符串相等的值变成一个gr组件列表 """ 782 | """ 同时返回parameter_list的长度,方便确认各标签页中组件数 """ 783 | list = [] 784 | for key in dict_key_list: 785 | try: 786 | list.append(gr_dict[key]) 787 | except KeyError: 788 | print(f"Error: parameter_dict_key_list中{key}不存在") 789 | list_len = len(list) 790 | return list, list_len 791 | 792 | """ 获取三个参数的字典键名成为一个list """ 793 | common_parameter_dict_key_list = list( common_gr_dict.keys() ) 794 | common_parameter_list, parameter_len_dict["common"] = dict_key_list_2_list(common_parameter_dict_key_list, common_gr_dict) 795 | 796 | sample_parameter_dict_key_list = list( sample_gr_dict.keys() ) 797 | sample_parameter_list, parameter_len_dict["sample"] = dict_key_list_2_list(sample_parameter_dict_key_list, sample_gr_dict) 798 | 799 | plus_parameter_dict_key_list = list( plus_gr_dict.keys() ) 800 | plus_parameter_list, parameter_len_dict["plus"] = dict_key_list_2_list(plus_parameter_dict_key_list, plus_gr_dict) 801 | 802 | all_parameter_list = common_parameter_list + sample_parameter_list + plus_parameter_list 803 | all_parameter_dict_key_list = common_parameter_dict_key_list + sample_parameter_dict_key_list + plus_parameter_dict_key_list 
804 | 805 | 806 | 807 | """ 按钮部分 """ 808 | #在指定路径保存webui组件config文件 809 | save_webui_config_button.click(fn=save_webui_config, 810 | inputs=[save_webui_config_dir, save_webui_config_name, write_files_dir], 811 | outputs=save_read_webui_config_title 812 | ) 813 | #获取指定路径下的所有以.toml扩展名的文件列表 814 | read_webui_config_get_button.click(fn=read_webui_config_get, 815 | inputs=[read_webui_config_dir], 816 | outputs=[read_webui_config_name] 817 | ) 818 | #读取指定路径webui组件config文件 819 | read_webui_config_button.click(fn=read_webui_config, 820 | inputs=[read_webui_config_dir, read_webui_config_name, write_files_dir] + all_parameter_list, 821 | outputs=[save_read_webui_config_title, write_files_dir] + all_parameter_list 822 | ) 823 | #在指定路径下写入kohya_toml 824 | write_files_button.click(fn=write_files, 825 | inputs=[write_files_dir], 826 | outputs=[write_files_title] 827 | ) 828 | #确定常规参数 829 | common_parameter_get_button.click(fn=common_parameter_get, 830 | inputs=common_parameter_list, 831 | outputs=[common_parameter_toml, common_parameter_title] 832 | ) 833 | #确定采样参数 834 | sample_parameter_get_button.click(fn=sample_parameter_get, 835 | inputs=sample_parameter_list, 836 | outputs=[sample_parameter_toml, sample_parameter_title] 837 | ) 838 | #确定进阶参数 839 | plus_parameter_get_button.click(fn=plus_parameter_get, 840 | inputs=plus_parameter_list, 841 | outputs=[plus_parameter_toml, plus_parameter_title] 842 | ) 843 | #确定全部参数 844 | all_parameter_get_button.click(fn=all_parameter_get, 845 | inputs=all_parameter_list, 846 | outputs=[common_parameter_toml, sample_parameter_toml, plus_parameter_toml, write_files_title] 847 | ) 848 | #读取路径下的所有文件 849 | base_model_get_button.click(fn=model_get, 850 | inputs=all_gr_dict["base_model_dir"], 851 | outputs=all_gr_dict["base_model_name"] 852 | ) 853 | #读取路径下的所有文件 854 | vae_model_get_button.click(fn=model_get, 855 | inputs=all_gr_dict["vae_model_dir"], 856 | outputs=all_gr_dict["vae_model_name"] 857 | ) 858 | return demo 859 | 860 | 861 | if __name__ == "__main__": 862 | demo = create_demo() 863 | demo.launch(share=False,inbrowser=True,inline=True,debug=True) 864 | --------------------------------------------------------------------------------