├── .github └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── configs ├── kohya_ss_lora │ ├── controlnet_sd1_5.json │ ├── lora_hunyuan1_1.json │ ├── lora_hunyuan1_2.json │ ├── lora_sd1_5.json │ └── lora_sdxl.json ├── models_config │ ├── clip-vit-large-patch14 │ │ └── tokenizer_config.json │ ├── stable-diffusion-v1-5 │ │ ├── .gitattributes │ │ ├── README.md │ │ ├── feature_extractor │ │ │ └── preprocessor_config.json │ │ ├── model_index.json │ │ ├── safety_checker │ │ │ └── config.json │ │ ├── scheduler │ │ │ └── scheduler_config.json │ │ ├── text_encoder │ │ │ └── config.json │ │ ├── tokenizer │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── unet │ │ │ └── config.json │ │ ├── v1-inference.yaml │ │ └── vae │ │ │ └── config.json │ ├── stable-diffusion-v1.5 │ │ └── v1-inference.yaml │ ├── stable-diffusion-xl-base-1.0 │ │ ├── .gitattributes │ │ ├── LICENSE.md │ │ ├── README.md │ │ ├── model_index.json │ │ ├── scheduler │ │ │ └── scheduler_config.json │ │ ├── text_encoder │ │ │ ├── config.json │ │ │ └── openvino_model.xml │ │ ├── text_encoder_2 │ │ │ ├── config.json │ │ │ └── openvino_model.xml │ │ ├── tokenizer │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── tokenizer_2 │ │ │ ├── merges.txt │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ │ ├── unet │ │ │ ├── config.json │ │ │ └── openvino_model.xml │ │ ├── vae │ │ │ └── config.json │ │ ├── vae_1_0 │ │ │ └── config.json │ │ ├── vae_decoder │ │ │ ├── config.json │ │ │ └── openvino_model.xml │ │ └── vae_encoder │ │ │ ├── config.json │ │ │ └── openvino_model.xml │ └── stable-diffusion-xl │ │ └── sd_xl_base.yaml └── sliders │ ├── config-xl.yaml │ └── prompts-xl.yaml ├── examples ├── captioner.png └── workflow.png ├── hook_HYDiT_idk_run.py ├── hook_HYDiT_main_train_deepspeed.py ├── hook_HYDiT_run.py ├── hook_HYDiT_utils.py ├── hook_kohya_ss_hunyuan_pipe.py ├── hook_kohya_ss_run.py ├── hook_kohya_ss_utils.py ├── mz_train_tools_core.py ├── mz_train_tools_core_HYDiT.py ├── mz_train_tools_utils.py └── pyproject.toml /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'MinusZoneAI' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | kohya_ss 164 | copytoww.bat 165 | exclude.txt 166 | .vscode/settings.json 167 | *.safetensors 168 | *.ckpt 169 | *.bin 170 | *.onnx 171 | *.msgpack 172 | *.onnx_data 173 | configs/models_config/**/*.png 174 | HunyuanDiT -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/3bdce469-5a49-4f59-8a88-1b20e3a75c85) 2 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/9cd7f4eb-f971-49a8-ad57-7a56b71bf022) 3 | 4 | 5 | # ComfyUI-TrainTools-MZ 6 | 在ComfyUI中进行lora微调的节点,依赖于kohya-ss/sd-scripts等训练工具(Nodes for LoRA fine-tuning in ComfyUI, built on training tools such as kohya-ss/sd-scripts) 7 | 8 | ## Recent changes 9 | [20240706] 支持混元lora训练,脚本来自https://github.com/KohakuBlueleaf/sd-scripts/tree/HunYuanDiT 分支(Added support for Hunyuan LoRA training; the training script comes from the HunYuanDiT branch of https://github.com/KohakuBlueleaf/sd-scripts) 10 | 11 | ## Installation 12 | 1. Clone this repo into the `custom_nodes` folder. 13 | 2. Restart ComfyUI. 
14 | 15 | ## Nodes 16 | ### MZ_KohyaSS_KohakuBlueleaf_HYHiDInitWorkspace | MZ_KohyaSS_KohakuBlueleaf_HYHiDLoraTrain 17 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/56d4e2cf-aaa6-4b05-95b1-003ed551c757) 18 | 19 | ### MZ_KohyaSSInitWorkspace 20 | 初始化训练文件夹,文件夹位于output目录(Initialize the training folder; it is created in the output directory) 21 | + lora_name(LoRa名称): 用于生成训练文件夹的名称(Used to generate the name of the training folder) 22 | + branch(分支): sd-scripts的分支,默认为当前代码调试时使用的分支(sd-scripts branch; defaults to the branch the current code was debugged against) 23 | + source(源): sd-scripts的源,默认为github,下载有问题的话可以切换加速源(sd-scripts source; defaults to github, switch to an accelerated mirror if the download fails) 24 | 25 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/8714d3e3-bc4f-4f99-9c0c-a5ea938b10a9) 26 | 27 | 28 | ### MZ_ImagesCopyWorkspace 29 | 复制图片到训练文件夹中和一些数据集配置(Copy images into the training folder and set some dataset options) 30 | + images(图片列表): 推荐使用 https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite 中的上传文件夹节点 (It is recommended to use the upload folder node from https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite ) 31 | + force_clear(强制清空): 复制图片前是否强制清空原有文件夹内容(Whether to clear the existing folder contents before copying the images) 32 | + force_clear_only_images(仅清空图片): 仅清空图片文件夹内容,不清空其他文件夹内容(Only clear the image folder, leaving other folders untouched) 33 | + same_caption_generate(生成相同标注): 是否生成相同的标注文件(Whether to generate the same caption file for every image) 34 | + same_caption(单一标签): 生成相同标签的内容(The caption text shared by all images) 35 | + 其他字段参考(for other fields, refer to): https://github.com/kohya-ss/sd-scripts 36 | 37 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/739804dc-d8be-4d42-8b04-f3b4a1bc5e33) 38 | 39 | 40 | ### MZ_KohyaSSUseConfig 41 | 一些基础的训练配置(Some basic training configurations) 42 | + 没什么特殊的,字段参考(Nothing special; for the fields, refer to): https://github.com/kohya-ss/sd-scripts 43 | 44 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/9cf82ed6-f3a2-4032-8032-1fd0b0ff6bdd) 45 | 46 | 47 | ### MZ_KohyaSSAdvConfig 48 | 更多的训练配置(More training configurations) 49 | + 没什么特殊的,字段参考(Nothing special; for the fields, refer to): https://github.com/kohya-ss/sd-scripts 50 | 51 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/b7a5f904-8357-408a-9c27-4bda4e4f8c85) 52 | 53 | 54 | ### MZ_KohyaSSTrain 55 | 训练主线程(Main training thread) 56 | + base_lora(基础lora): 加载一个lora模型后进行训练,和sd-scripts中的`network_weights`参数一致,启用时忽略dim/alpha/dropout(Train on top of an existing lora model; equivalent to the `network_weights` parameter in sd-scripts; dim/alpha/dropout are ignored when enabled) 57 | + sample_generate(启用样图生成): 每次保存模型时进行一次示例图片生成,并展示训练过程中每个保存epoch时的示例图片(Generate a sample image each time the model is saved, and display the sample images from each saved epoch during training) 58 | + sample_prompt(提示词): 生成示例图片时使用的提示词(Prompt used when generating sample images) 59 | 60 | ![image](https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ/assets/5035199/e32d2132-cf0e-46b7-807d-a3160aaeea7d) 61 | 62 | 63 | ## FAQ 64 | 65 | 66 | 67 | ## Credits 68 | + [https://github.com/comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI) 69 | + [https://github.com/kohya-ss/sd-scripts](https://github.com/kohya-ss/sd-scripts) 70 | 71 | ## Star History 72 | 73 | 74 | 75 | 76 | 77 | Star History Chart 78 | 79 | 80 | 81 | ## Contact 82 | - 绿泡泡: minrszone 83 | - Bilibili: [minus_zone](https://space.bilibili.com/5950992) 84 | - 小红书: 
[MinusZoneAI](https://www.xiaohongshu.com/user/profile/5f072e990000000001005472) 85 | - 爱发电: [MinusZoneAI](https://afdian.net/@MinusZoneAI) 86 | -------------------------------------------------------------------------------- /configs/kohya_ss_lora/controlnet_sd1_5.json: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "train_type": "controlnet_sd1_5" 4 | }, 5 | "train_config": { 6 | "pretrained_model_name_or_path": "", 7 | "max_train_steps": "4500", 8 | "xformers": true, 9 | "sdpa": false, 10 | "fp8_base": false, 11 | "mixed_precision": "fp16", 12 | "cache_latents": true, 13 | "cache_latents_to_disk": true, 14 | "learning_rate": "1e-5", 15 | "lr_scheduler": "cosine_with_restarts", 16 | "optimizer_type": "AdamW", 17 | "save_every_n_epochs": "20", 18 | "shuffle_caption": false, 19 | "lr_warmup_steps": "0", 20 | "save_precision": "fp16", 21 | "lr_scheduler_num_cycles": "1", 22 | "persistent_data_loader_workers": true, 23 | "noise_offset": "0.1", 24 | "output_dir": "", 25 | "output_name": "", 26 | "lowram": false 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /configs/kohya_ss_lora/lora_hunyuan1_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "train_type": "lora_hunyuan1_1", 4 | "version": "1.1" 5 | }, 6 | "train_config": { 7 | "max_train_steps": "4500", 8 | "xformers": true, 9 | "sdpa": false, 10 | "fp8_base": false, 11 | "mixed_precision": "fp16", 12 | "cache_latents": true, 13 | "cache_latents_to_disk": true, 14 | "network_dim": "16", 15 | "network_alpha": "8", 16 | "network_module": "networks.lora", 17 | "network_train_unet_only": true, 18 | "learning_rate": "1e-5", 19 | "lr_scheduler": "cosine_with_restarts", 20 | "optimizer_type": "AdamW", 21 | "save_every_n_epochs": "20", 22 | "shuffle_caption": false, 23 | "lr_warmup_steps": "0", 24 | "save_precision": "fp16", 25 | "lr_scheduler_num_cycles": "1", 26 | "persistent_data_loader_workers": true, 27 | "no_metadata": true, 28 | "noise_offset": "0.1", 29 | "output_dir": "", 30 | "output_name": "", 31 | "no_half_vae": true, 32 | "v_parameterization": true, 33 | "lowram": false, 34 | "max_token_length": "225", 35 | "use_extra_cond": true 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /configs/kohya_ss_lora/lora_hunyuan1_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "train_type": "lora_hunyuan1_2", 4 | "version": "1.2" 5 | }, 6 | "train_config": { 7 | "max_train_steps": "4500", 8 | "xformers": true, 9 | "sdpa": false, 10 | "fp8_base": false, 11 | "mixed_precision": "fp16", 12 | "cache_latents": true, 13 | "cache_latents_to_disk": true, 14 | "network_dim": "4", 15 | "network_alpha": "4", 16 | "network_module": "networks.lora", 17 | "network_train_unet_only": true, 18 | "learning_rate": "1e-5", 19 | "lr_scheduler": "cosine_with_restarts", 20 | "optimizer_type": "AdamW", 21 | "save_every_n_epochs": "20", 22 | "shuffle_caption": false, 23 | "lr_warmup_steps": "0", 24 | "save_precision": "fp16", 25 | "lr_scheduler_num_cycles": "1", 26 | "persistent_data_loader_workers": true, 27 | "no_metadata": true, 28 | "noise_offset": "0.00", 29 | "output_dir": "", 30 | "output_name": "", 31 | "no_half_vae": true, 32 | "v_parameterization": true, 33 | "lowram": false, 34 | "max_token_length": "150", 35 | "use_extra_cond": false 36 | } 37 | } 38 | 
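The `kohya_ss_lora` configs above keep every value in `train_config`, including numeric ones, as strings, so each entry maps one-to-one onto a kohya-ss/sd-scripts command-line flag. The actual conversion is presumably handled by the repo's `mz_train_tools_core*.py` / `hook_kohya_ss_run.py` modules, which are not shown in this excerpt; the sketch below only illustrates the idea, and the `config_to_args` helper name is hypothetical.

```py
import json

def config_to_args(path: str) -> list[str]:
    """Hypothetical helper: flatten a train_config block into sd-scripts style CLI flags.

    True booleans become bare flags, false booleans and empty strings are skipped,
    everything else becomes a "--key value" pair.
    """
    with open(path, "r", encoding="utf-8") as f:
        train_config = json.load(f)["train_config"]

    args: list[str] = []
    for key, value in train_config.items():
        if value is True:
            args.append(f"--{key}")
        elif value is False or value == "":
            continue  # disabled flag or unset path: omit it entirely
        else:
            args.extend([f"--{key}", str(value)])
    return args

if __name__ == "__main__":
    # e.g. ['--max_train_steps', '4500', '--xformers', '--mixed_precision', 'fp16', ...]
    print(config_to_args("configs/kohya_ss_lora/lora_hunyuan1_2.json"))
```

Storing the numbers as strings lets them be spliced into the command line verbatim, which likely avoids any float-formatting surprises when the config is handed to the training script.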
-------------------------------------------------------------------------------- /configs/kohya_ss_lora/lora_sd1_5.json: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "train_type": "lora_sd1_5" 4 | }, 5 | "train_config": { 6 | "pretrained_model_name_or_path": "", 7 | "max_train_steps": "4500", 8 | "xformers": true, 9 | "sdpa": false, 10 | "fp8_base": false, 11 | "mixed_precision": "fp16", 12 | "cache_latents": true, 13 | "cache_latents_to_disk": true, 14 | "network_dim": "16", 15 | "network_alpha": "8", 16 | "network_module": "networks.lora", 17 | "network_train_unet_only": true, 18 | "learning_rate": "1e-5", 19 | "lr_scheduler": "cosine_with_restarts", 20 | "optimizer_type": "AdamW", 21 | "save_every_n_epochs": "20", 22 | "shuffle_caption": false, 23 | "lr_warmup_steps": "0", 24 | "save_precision": "fp16", 25 | "lr_scheduler_num_cycles": "1", 26 | "persistent_data_loader_workers": true, 27 | "no_metadata": true, 28 | "noise_offset": "0.1", 29 | "output_dir": "", 30 | "output_name": "", 31 | "no_half_vae": true, 32 | "lowram": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/kohya_ss_lora/lora_sdxl.json: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | "train_type": "lora_sdxl" 4 | }, 5 | "train_config": { 6 | "pretrained_model_name_or_path": "", 7 | "max_train_steps": "4500", 8 | "xformers": true, 9 | "sdpa": false, 10 | "fp8_base": false, 11 | "mixed_precision": "fp16", 12 | "cache_latents": true, 13 | "cache_latents_to_disk": true, 14 | "network_dim": "16", 15 | "network_alpha": "8", 16 | "network_module": "networks.lora", 17 | "network_train_unet_only": true, 18 | "learning_rate": "1e-5", 19 | "lr_scheduler": "cosine_with_restarts", 20 | "optimizer_type": "AdamW", 21 | "save_every_n_epochs": "20", 22 | "shuffle_caption": false, 23 | "lr_warmup_steps": "0", 24 | "save_precision": "fp16", 25 | "lr_scheduler_num_cycles": "1", 26 | "persistent_data_loader_workers": true, 27 | "no_metadata": true, 28 | "noise_offset": "0.1", 29 | "output_dir": "", 30 | "output_name": "", 31 | "no_half_vae": true, 32 | "lowram": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/models_config/clip-vit-large-patch14/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": { 3 | "content": "<|endoftext|>", 4 | "single_word": false, 5 | "lstrip": false, 6 | "rstrip": false, 7 | "normalized": true, 8 | "__type": "AddedToken" 9 | }, 10 | "bos_token": { 11 | "content": "<|startoftext|>", 12 | "single_word": false, 13 | "lstrip": false, 14 | "rstrip": false, 15 | "normalized": true, 16 | "__type": "AddedToken" 17 | }, 18 | "eos_token": { 19 | "content": "<|endoftext|>", 20 | "single_word": false, 21 | "lstrip": false, 22 | "rstrip": false, 23 | "normalized": true, 24 | "__type": "AddedToken" 25 | }, 26 | "pad_token": "<|endoftext|>", 27 | "add_prefix_space": false, 28 | "errors": "replace", 29 | "do_lower_case": true, 30 | "name_or_path": "openai/clip-vit-base-patch32", 31 | "model_max_length": 77, 32 | "special_tokens_map_file": "./special_tokens_map.json", 33 | "tokenizer_class": "CLIPTokenizer" 34 | } 35 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/.gitattributes: 
-------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ftz filter=lfs diff=lfs merge=lfs -text 6 | *.gz filter=lfs diff=lfs merge=lfs -text 7 | *.h5 filter=lfs diff=lfs merge=lfs -text 8 | *.joblib filter=lfs diff=lfs merge=lfs -text 9 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 10 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 11 | *.model filter=lfs diff=lfs merge=lfs -text 12 | *.msgpack filter=lfs diff=lfs merge=lfs -text 13 | *.npy filter=lfs diff=lfs merge=lfs -text 14 | *.npz filter=lfs diff=lfs merge=lfs -text 15 | *.onnx filter=lfs diff=lfs merge=lfs -text 16 | *.ot filter=lfs diff=lfs merge=lfs -text 17 | *.parquet filter=lfs diff=lfs merge=lfs -text 18 | *.pb filter=lfs diff=lfs merge=lfs -text 19 | *.pickle filter=lfs diff=lfs merge=lfs -text 20 | *.pkl filter=lfs diff=lfs merge=lfs -text 21 | *.pt filter=lfs diff=lfs merge=lfs -text 22 | *.pth filter=lfs diff=lfs merge=lfs -text 23 | *.rar filter=lfs diff=lfs merge=lfs -text 24 | *.safetensors filter=lfs diff=lfs merge=lfs -text 25 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 26 | *.tar.* filter=lfs diff=lfs merge=lfs -text 27 | *.tflite filter=lfs diff=lfs merge=lfs -text 28 | *.tgz filter=lfs diff=lfs merge=lfs -text 29 | *.wasm filter=lfs diff=lfs merge=lfs -text 30 | *.xz filter=lfs diff=lfs merge=lfs -text 31 | *.zip filter=lfs diff=lfs merge=lfs -text 32 | *.zst filter=lfs diff=lfs merge=lfs -text 33 | *tfevents* filter=lfs diff=lfs merge=lfs -text 34 | v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text 35 | v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: creativeml-openrail-m 3 | tags: 4 | - stable-diffusion 5 | - stable-diffusion-diffusers 6 | - text-to-image 7 | inference: true 8 | extra_gated_prompt: |- 9 | This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage. 10 | The CreativeML OpenRAIL License specifies: 11 | 12 | 1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content 13 | 2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license 14 | 3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully) 15 | Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license 16 | 17 | extra_gated_heading: Please read the LICENSE to access this model 18 | --- 19 | 20 | # Stable Diffusion v1-5 Model Card 21 | 22 | Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. 23 | For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion). 
24 | 25 | The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https://huggingface.co/CompVis/stable-diffusion-v1-2) 26 | checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 27 | 28 | You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion). 29 | 30 | ### Diffusers 31 | ```py 32 | from diffusers import StableDiffusionPipeline 33 | import torch 34 | 35 | model_id = "runwayml/stable-diffusion-v1-5" 36 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) 37 | pipe = pipe.to("cuda") 38 | 39 | prompt = "a photo of an astronaut riding a horse on mars" 40 | image = pipe(prompt).images[0] 41 | 42 | image.save("astronaut_rides_horse.png") 43 | ``` 44 | For more detailed instructions, use-cases and examples in JAX follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion) 45 | 46 | ### Original GitHub Repository 47 | 48 | 1. Download the weights 49 | - [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference 50 | - [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning 51 | 52 | 2. Follow instructions [here](https://github.com/runwayml/stable-diffusion). 53 | 54 | ## Model Details 55 | - **Developed by:** Robin Rombach, Patrick Esser 56 | - **Model type:** Diffusion-based text-to-image generation model 57 | - **Language(s):** English 58 | - **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based. 59 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487). 60 | - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752). 61 | - **Cite as:** 62 | 63 | @InProceedings{Rombach_2022_CVPR, 64 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 65 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 66 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 67 | month = {June}, 68 | year = {2022}, 69 | pages = {10684-10695} 70 | } 71 | 72 | # Uses 73 | 74 | ## Direct Use 75 | The model is intended for research purposes only. 
Possible research areas and 76 | tasks include 77 | 78 | - Safe deployment of models which have the potential to generate harmful content. 79 | - Probing and understanding the limitations and biases of generative models. 80 | - Generation of artworks and use in design and other artistic processes. 81 | - Applications in educational or creative tools. 82 | - Research on generative models. 83 | 84 | Excluded uses are described below. 85 | 86 | ### Misuse, Malicious Use, and Out-of-Scope Use 87 | _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_. 88 | 89 | 90 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 91 | 92 | #### Out-of-Scope Use 93 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 94 | 95 | #### Misuse and Malicious Use 96 | Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to: 97 | 98 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 99 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 100 | - Impersonating individuals without their consent. 101 | - Sexual content without consent of the people who might see it. 102 | - Mis- and disinformation 103 | - Representations of egregious violence and gore 104 | - Sharing of copyrighted or licensed material in violation of its terms of use. 105 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 106 | 107 | ## Limitations and Bias 108 | 109 | ### Limitations 110 | 111 | - The model does not achieve perfect photorealism 112 | - The model cannot render legible text 113 | - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere” 114 | - Faces and people in general may not be generated properly. 115 | - The model was trained mainly with English captions and will not work as well in other languages. 116 | - The autoencoding part of the model is lossy 117 | - The model was trained on a large-scale dataset 118 | [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material 119 | and is not fit for product use without additional safety mechanisms and 120 | considerations. 121 | - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data. 122 | The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images. 123 | 124 | ### Bias 125 | 126 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 127 | Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), 128 | which consists of images that are primarily limited to English descriptions. 
129 | Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. 130 | This affects the overall output of the model, as white and western cultures are often set as the default. Further, the 131 | ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts. 132 | 133 | ### Safety Module 134 | 135 | The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers. 136 | This checker works by checking model outputs against known hard-coded NSFW concepts. 137 | The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter. 138 | Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images. 139 | The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept. 140 | 141 | 142 | ## Training 143 | 144 | **Training Data** 145 | The model developers used the following dataset for training the model: 146 | 147 | - LAION-2B (en) and subsets thereof (see next section) 148 | 149 | **Training Procedure** 150 | Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training, 151 | 152 | - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4 153 | - Text prompts are encoded through a ViT-L/14 text-encoder. 154 | - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention. 155 | - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. 156 | 157 | Currently six Stable Diffusion checkpoints are provided, which were trained as follows. 158 | - [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). 159 | 194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`). 160 | - [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`. 161 | 515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en, 162 | filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)). 163 | - [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 
164 | - [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 165 | - [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 166 | - [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and in 25% mask everything. 167 | 168 | - **Hardware:** 32 x 8 x A100 GPUs 169 | - **Optimizer:** AdamW 170 | - **Gradient Accumulations**: 2 171 | - **Batch:** 32 x 8 x 2 x 4 = 2048 172 | - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant 173 | 174 | ## Evaluation Results 175 | Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, 176 | 5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling 177 | steps show the relative improvements of the checkpoints: 178 | 179 | ![pareto](https://huggingface.co/CompVis/stable-diffusion/resolve/main/v1-1-to-v1-5.png) 180 | 181 | Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores. 182 | ## Environmental Impact 183 | 184 | **Stable Diffusion v1** **Estimated Emissions** 185 | Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact. 186 | 187 | - **Hardware Type:** A100 PCIe 40GB 188 | - **Hours used:** 150000 189 | - **Cloud Provider:** AWS 190 | - **Compute Region:** US-east 191 | - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq. 
192 | 193 | 194 | ## Citation 195 | 196 | ```bibtex 197 | @InProceedings{Rombach_2022_CVPR, 198 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 199 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 200 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 201 | month = {June}, 202 | year = {2022}, 203 | pages = {10684-10695} 204 | } 205 | ``` 206 | 207 | *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).* -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 | "do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } 21 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "feature_extractor": [ 5 | "transformers", 6 | "CLIPImageProcessor" 7 | ], 8 | "safety_checker": [ 9 | "stable_diffusion", 10 | "StableDiffusionSafetyChecker" 11 | ], 12 | "scheduler": [ 13 | "diffusers", 14 | "PNDMScheduler" 15 | ], 16 | "text_encoder": [ 17 | "transformers", 18 | "CLIPTextModel" 19 | ], 20 | "tokenizer": [ 21 | "transformers", 22 | "CLIPTokenizer" 23 | ], 24 | "unet": [ 25 | "diffusers", 26 | "UNet2DConditionModel" 27 | ], 28 | "vae": [ 29 | "diffusers", 30 | "AutoencoderKL" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/safety_checker/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b", 3 | "_name_or_path": "CompVis/stable-diffusion-safety-checker", 4 | "architectures": [ 5 | "StableDiffusionSafetyChecker" 6 | ], 7 | "initializer_factor": 1.0, 8 | "logit_scale_init_value": 2.6592, 9 | "model_type": "clip", 10 | "projection_dim": 768, 11 | "text_config": { 12 | "_name_or_path": "", 13 | "add_cross_attention": false, 14 | "architectures": null, 15 | "attention_dropout": 0.0, 16 | "bad_words_ids": null, 17 | "bos_token_id": 0, 18 | "chunk_size_feed_forward": 0, 19 | "cross_attention_hidden_size": null, 20 | "decoder_start_token_id": null, 21 | "diversity_penalty": 0.0, 22 | "do_sample": false, 23 | "dropout": 0.0, 24 | "early_stopping": false, 25 | "encoder_no_repeat_ngram_size": 0, 26 | "eos_token_id": 2, 27 | "exponential_decay_length_penalty": null, 28 | "finetuning_task": null, 29 | "forced_bos_token_id": null, 30 | "forced_eos_token_id": null, 31 | "hidden_act": "quick_gelu", 32 | "hidden_size": 768, 33 | "id2label": { 34 | "0": "LABEL_0", 35 | "1": "LABEL_1" 36 | }, 37 | "initializer_factor": 1.0, 38 | "initializer_range": 0.02, 39 | "intermediate_size": 3072, 40 | "is_decoder": 
false, 41 | "is_encoder_decoder": false, 42 | "label2id": { 43 | "LABEL_0": 0, 44 | "LABEL_1": 1 45 | }, 46 | "layer_norm_eps": 1e-05, 47 | "length_penalty": 1.0, 48 | "max_length": 20, 49 | "max_position_embeddings": 77, 50 | "min_length": 0, 51 | "model_type": "clip_text_model", 52 | "no_repeat_ngram_size": 0, 53 | "num_attention_heads": 12, 54 | "num_beam_groups": 1, 55 | "num_beams": 1, 56 | "num_hidden_layers": 12, 57 | "num_return_sequences": 1, 58 | "output_attentions": false, 59 | "output_hidden_states": false, 60 | "output_scores": false, 61 | "pad_token_id": 1, 62 | "prefix": null, 63 | "problem_type": null, 64 | "pruned_heads": {}, 65 | "remove_invalid_values": false, 66 | "repetition_penalty": 1.0, 67 | "return_dict": true, 68 | "return_dict_in_generate": false, 69 | "sep_token_id": null, 70 | "task_specific_params": null, 71 | "temperature": 1.0, 72 | "tf_legacy_loss": false, 73 | "tie_encoder_decoder": false, 74 | "tie_word_embeddings": true, 75 | "tokenizer_class": null, 76 | "top_k": 50, 77 | "top_p": 1.0, 78 | "torch_dtype": null, 79 | "torchscript": false, 80 | "transformers_version": "4.22.0.dev0", 81 | "typical_p": 1.0, 82 | "use_bfloat16": false, 83 | "vocab_size": 49408 84 | }, 85 | "text_config_dict": { 86 | "hidden_size": 768, 87 | "intermediate_size": 3072, 88 | "num_attention_heads": 12, 89 | "num_hidden_layers": 12 90 | }, 91 | "torch_dtype": "float32", 92 | "transformers_version": null, 93 | "vision_config": { 94 | "_name_or_path": "", 95 | "add_cross_attention": false, 96 | "architectures": null, 97 | "attention_dropout": 0.0, 98 | "bad_words_ids": null, 99 | "bos_token_id": null, 100 | "chunk_size_feed_forward": 0, 101 | "cross_attention_hidden_size": null, 102 | "decoder_start_token_id": null, 103 | "diversity_penalty": 0.0, 104 | "do_sample": false, 105 | "dropout": 0.0, 106 | "early_stopping": false, 107 | "encoder_no_repeat_ngram_size": 0, 108 | "eos_token_id": null, 109 | "exponential_decay_length_penalty": null, 110 | "finetuning_task": null, 111 | "forced_bos_token_id": null, 112 | "forced_eos_token_id": null, 113 | "hidden_act": "quick_gelu", 114 | "hidden_size": 1024, 115 | "id2label": { 116 | "0": "LABEL_0", 117 | "1": "LABEL_1" 118 | }, 119 | "image_size": 224, 120 | "initializer_factor": 1.0, 121 | "initializer_range": 0.02, 122 | "intermediate_size": 4096, 123 | "is_decoder": false, 124 | "is_encoder_decoder": false, 125 | "label2id": { 126 | "LABEL_0": 0, 127 | "LABEL_1": 1 128 | }, 129 | "layer_norm_eps": 1e-05, 130 | "length_penalty": 1.0, 131 | "max_length": 20, 132 | "min_length": 0, 133 | "model_type": "clip_vision_model", 134 | "no_repeat_ngram_size": 0, 135 | "num_attention_heads": 16, 136 | "num_beam_groups": 1, 137 | "num_beams": 1, 138 | "num_channels": 3, 139 | "num_hidden_layers": 24, 140 | "num_return_sequences": 1, 141 | "output_attentions": false, 142 | "output_hidden_states": false, 143 | "output_scores": false, 144 | "pad_token_id": null, 145 | "patch_size": 14, 146 | "prefix": null, 147 | "problem_type": null, 148 | "pruned_heads": {}, 149 | "remove_invalid_values": false, 150 | "repetition_penalty": 1.0, 151 | "return_dict": true, 152 | "return_dict_in_generate": false, 153 | "sep_token_id": null, 154 | "task_specific_params": null, 155 | "temperature": 1.0, 156 | "tf_legacy_loss": false, 157 | "tie_encoder_decoder": false, 158 | "tie_word_embeddings": true, 159 | "tokenizer_class": null, 160 | "top_k": 50, 161 | "top_p": 1.0, 162 | "torch_dtype": null, 163 | "torchscript": false, 164 | "transformers_version": "4.22.0.dev0", 
165 | "typical_p": 1.0, 166 | "use_bfloat16": false 167 | }, 168 | "vision_config_dict": { 169 | "hidden_size": 1024, 170 | "intermediate_size": 4096, 171 | "num_attention_heads": 16, 172 | "num_hidden_layers": 24, 173 | "patch_size": 14 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null, 12 | "clip_sample": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.22.0.dev0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | 
"normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | "CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1-5/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "in_channels": 3, 18 | "latent_channels": 4, 19 | "layers_per_block": 2, 20 | "norm_num_groups": 32, 21 | "out_channels": 3, 22 | "sample_size": 512, 23 | "up_block_types": [ 24 | "UpDecoderBlock2D", 25 | "UpDecoderBlock2D", 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D" 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-v1.5/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | 01.png filter=lfs diff=lfs merge=lfs -text 37 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Stability AI 2 | CreativeML Open RAIL++-M License dated July 26, 2023 3 | 4 | Section I: PREAMBLE 5 | Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. 
6 | Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. 7 | In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. 8 | Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI. 9 | This CreativeML Open RAIL++-M License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. 10 | NOW THEREFORE, You and Licensor agree as follows: 11 | Definitions 12 | "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. 13 | "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. 14 | "Output" means the results of operating a Model as embodied in informational content resulting therefrom. 15 | "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. 16 | "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. 17 | "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. 18 | "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. 
19 | "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. 20 | "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. 21 | "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. 22 | "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 23 | "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. 24 | 25 | Section II: INTELLECTUAL PROPERTY RIGHTS 26 | Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in 27 | 28 | Section III. 29 | Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. 30 | Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. 31 | Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION 32 | Distribution and Redistribution. 
You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; You must cause any modified files to carry prominent notices stating that You changed the files; You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. 33 | Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). 34 | The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. 35 | 36 | Section IV: OTHER PROVISIONS 37 | Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License. 38 | Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. 39 | Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. 
40 | Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 41 | Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 42 | If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. 43 | 44 | END OF TERMS AND CONDITIONS 45 | 46 | Attachment A 47 | Use Restrictions 48 | You agree not to use the Model or Derivatives of the Model: 49 | In any way that violates any applicable national, federal, state, local or international law or regulation; 50 | For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; 51 | To generate or disseminate verifiably false information and/or content with the purpose of harming others; 52 | To generate or disseminate personal identifiable information that can be used to harm an individual; 53 | To defame, disparage or otherwise harass others; 54 | For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; 55 | For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; 56 | To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; 57 | For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; 58 | To provide medical advice and medical results interpretation; 59 | To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). 
60 | 61 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: openrail++ 3 | tags: 4 | - text-to-image 5 | - stable-diffusion 6 | --- 7 | # SD-XL 1.0-base Model Card 8 | ![row01](01.png) 9 | 10 | ## Model 11 | 12 | ![pipeline](pipeline.png) 13 | 14 | [SDXL](https://arxiv.org/abs/2307.01952) consists of an [ensemble of experts](https://arxiv.org/abs/2211.01324) pipeline for latent diffusion: 15 | In a first step, the base model is used to generate (noisy) latents, 16 | which are then further processed with a refinement model (available here: https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/) specialized for the final denoising steps. 17 | Note that the base model can be used as a standalone module. 18 | 19 | Alternatively, we can use a two-stage pipeline as follows: 20 | First, the base model is used to generate latents of the desired output size. 21 | In the second step, we use a specialized high-resolution model and apply a technique called SDEdit (https://arxiv.org/abs/2108.01073, also known as "img2img") 22 | to the latents generated in the first step, using the same prompt. This technique is slightly slower than the first one, as it requires more function evaluations. 23 | 24 | Source code is available at https://github.com/Stability-AI/generative-models . 25 | 26 | ### Model Description 27 | 28 | - **Developed by:** Stability AI 29 | - **Model type:** Diffusion-based text-to-image generative model 30 | - **License:** [CreativeML Open RAIL++-M License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md) 31 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses two fixed, pretrained text encoders ([OpenCLIP-ViT/G](https://github.com/mlfoundations/open_clip) and [CLIP-ViT/L](https://github.com/openai/CLIP/tree/main)). 32 | - **Resources for more information:** Check out our [GitHub Repository](https://github.com/Stability-AI/generative-models) and the [SDXL report on arXiv](https://arxiv.org/abs/2307.01952). 33 | 34 | ### Model Sources 35 | 36 | For research purposes, we recommend our `generative-models` Github repository (https://github.com/Stability-AI/generative-models), which implements the most popular diffusion frameworks (both training and inference) and for which new functionalities like distillation will be added over time. 37 | [Clipdrop](https://clipdrop.co/stable-diffusion) provides free SDXL inference. 38 | 39 | - **Repository:** https://github.com/Stability-AI/generative-models 40 | - **Demo:** https://clipdrop.co/stable-diffusion 41 | 42 | 43 | ## Evaluation 44 | ![comparison](comparison.png) 45 | The chart above evaluates user preference for SDXL (with and without refinement) over SDXL 0.9 and Stable Diffusion 1.5 and 2.1. 46 | The SDXL base model performs significantly better than the previous variants, and the model combined with the refinement module achieves the best overall performance. 
47 | 48 | 49 | ### 🧨 Diffusers 50 | 51 | Make sure to upgrade diffusers to >= 0.19.0: 52 | ``` 53 | pip install diffusers --upgrade 54 | ``` 55 | 56 | In addition, make sure to install `transformers`, `safetensors`, `accelerate` as well as the invisible watermark: 57 | ``` 58 | pip install invisible_watermark transformers accelerate safetensors 59 | ``` 60 | 61 | To just use the base model, you can run: 62 | 63 | ```py 64 | from diffusers import DiffusionPipeline 65 | import torch 66 | 67 | pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16") 68 | pipe.to("cuda") 69 | 70 | # if using torch < 2.0 71 | # pipe.enable_xformers_memory_efficient_attention() 72 | 73 | prompt = "An astronaut riding a green horse" 74 | 75 | images = pipe(prompt=prompt).images[0] 76 | ``` 77 | 78 | To use the whole base + refiner pipeline as an ensemble of experts you can run: 79 | 80 | ```py 81 | from diffusers import DiffusionPipeline 82 | import torch 83 | 84 | # load both base & refiner 85 | base = DiffusionPipeline.from_pretrained( 86 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True 87 | ) 88 | base.to("cuda") 89 | refiner = DiffusionPipeline.from_pretrained( 90 | "stabilityai/stable-diffusion-xl-refiner-1.0", 91 | text_encoder_2=base.text_encoder_2, 92 | vae=base.vae, 93 | torch_dtype=torch.float16, 94 | use_safetensors=True, 95 | variant="fp16", 96 | ) 97 | refiner.to("cuda") 98 | 99 | # Define how many steps and what % of steps to be run on each expert (80/20) here 100 | n_steps = 40 101 | high_noise_frac = 0.8 102 | 103 | prompt = "A majestic lion jumping from a big stone at night" 104 | 105 | # run both experts 106 | image = base( 107 | prompt=prompt, 108 | num_inference_steps=n_steps, 109 | denoising_end=high_noise_frac, 110 | output_type="latent", 111 | ).images 112 | image = refiner( 113 | prompt=prompt, 114 | num_inference_steps=n_steps, 115 | denoising_start=high_noise_frac, 116 | image=image, 117 | ).images[0] 118 | ``` 119 | 120 | When using `torch >= 2.0`, you can improve the inference speed by 20-30% with torch.compile. Simply wrap the unet with `torch.compile` before running the pipeline: 121 | ```py 122 | pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) 123 | ``` 124 | 125 | If you are limited by GPU VRAM, you can enable *cpu offloading* by calling `pipe.enable_model_cpu_offload` 126 | instead of `.to("cuda")`: 127 | 128 | ```diff 129 | - pipe.to("cuda") 130 | + pipe.enable_model_cpu_offload() 131 | ``` 132 | 133 | For more information on how to use Stable Diffusion XL with `diffusers`, please have a look at [the Stable Diffusion XL Docs](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl). 134 | 135 | ### Optimum 136 | [Optimum](https://github.com/huggingface/optimum) provides a Stable Diffusion pipeline compatible with both [OpenVINO](https://docs.openvino.ai/latest/index.html) and [ONNX Runtime](https://onnxruntime.ai/). 137 | 138 | #### OpenVINO 139 | 140 | To install Optimum with the dependencies required for OpenVINO: 141 | 142 | ```bash 143 | pip install optimum[openvino] 144 | ``` 145 | 146 | To load an OpenVINO model and run inference with OpenVINO Runtime, you need to replace `StableDiffusionXLPipeline` with Optimum `OVStableDiffusionXLPipeline`. In case you want to load a PyTorch model and convert it to the OpenVINO format on-the-fly, you can set `export=True`. 
147 | 148 | ```diff 149 | - from diffusers import StableDiffusionXLPipeline 150 | + from optimum.intel import OVStableDiffusionXLPipeline 151 | 152 | model_id = "stabilityai/stable-diffusion-xl-base-1.0" 153 | - pipeline = StableDiffusionXLPipeline.from_pretrained(model_id) 154 | + pipeline = OVStableDiffusionXLPipeline.from_pretrained(model_id) 155 | prompt = "A majestic lion jumping from a big stone at night" 156 | image = pipeline(prompt).images[0] 157 | ``` 158 | 159 | You can find more examples (such as static reshaping and model compilation) in optimum [documentation](https://huggingface.co/docs/optimum/main/en/intel/inference#stable-diffusion-xl). 160 | 161 | 162 | #### ONNX 163 | 164 | To install Optimum with the dependencies required for ONNX Runtime inference : 165 | 166 | ```bash 167 | pip install optimum[onnxruntime] 168 | ``` 169 | 170 | To load an ONNX model and run inference with ONNX Runtime, you need to replace `StableDiffusionXLPipeline` with Optimum `ORTStableDiffusionXLPipeline`. In case you want to load a PyTorch model and convert it to the ONNX format on-the-fly, you can set `export=True`. 171 | 172 | ```diff 173 | - from diffusers import StableDiffusionXLPipeline 174 | + from optimum.onnxruntime import ORTStableDiffusionXLPipeline 175 | 176 | model_id = "stabilityai/stable-diffusion-xl-base-1.0" 177 | - pipeline = StableDiffusionXLPipeline.from_pretrained(model_id) 178 | + pipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id) 179 | prompt = "A majestic lion jumping from a big stone at night" 180 | image = pipeline(prompt).images[0] 181 | ``` 182 | 183 | You can find more examples in optimum [documentation](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models#stable-diffusion-xl). 184 | 185 | 186 | ## Uses 187 | 188 | ### Direct Use 189 | 190 | The model is intended for research purposes only. Possible research areas and tasks include 191 | 192 | - Generation of artworks and use in design and other artistic processes. 193 | - Applications in educational or creative tools. 194 | - Research on generative models. 195 | - Safe deployment of models which have the potential to generate harmful content. 196 | - Probing and understanding the limitations and biases of generative models. 197 | 198 | Excluded uses are described below. 199 | 200 | ### Out-of-Scope Use 201 | 202 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 203 | 204 | ## Limitations and Bias 205 | 206 | ### Limitations 207 | 208 | - The model does not achieve perfect photorealism 209 | - The model cannot render legible text 210 | - The model struggles with more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere” 211 | - Faces and people in general may not be generated properly. 212 | - The autoencoding part of the model is lossy. 213 | 214 | ### Bias 215 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 
216 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionXLPipeline", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "force_zeros_for_empty_prompt": true, 5 | "add_watermarker": null, 6 | "scheduler": [ 7 | "diffusers", 8 | "EulerDiscreteScheduler" 9 | ], 10 | "text_encoder": [ 11 | "transformers", 12 | "CLIPTextModel" 13 | ], 14 | "text_encoder_2": [ 15 | "transformers", 16 | "CLIPTextModelWithProjection" 17 | ], 18 | "tokenizer": [ 19 | "transformers", 20 | "CLIPTokenizer" 21 | ], 22 | "tokenizer_2": [ 23 | "transformers", 24 | "CLIPTokenizer" 25 | ], 26 | "unet": [ 27 | "diffusers", 28 | "UNet2DConditionModel" 29 | ], 30 | "vae": [ 31 | "diffusers", 32 | "AutoencoderKL" 33 | ] 34 | } 35 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "epsilon", 11 | "sample_max_value": 1.0, 12 | "set_alpha_to_one": false, 13 | "skip_prk_steps": true, 14 | "steps_offset": 1, 15 | "timestep_spacing": "leading", 16 | "trained_betas": null, 17 | "use_karras_sigmas": false 18 | } 19 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "CLIPTextModel" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 0, 7 | "dropout": 0.0, 8 | "eos_token_id": 2, 9 | "hidden_act": "quick_gelu", 10 | "hidden_size": 768, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "layer_norm_eps": 1e-05, 15 | "max_position_embeddings": 77, 16 | "model_type": "clip_text_model", 17 | "num_attention_heads": 12, 18 | "num_hidden_layers": 12, 19 | "pad_token_id": 1, 20 | "projection_dim": 768, 21 | "torch_dtype": "float16", 22 | "transformers_version": "4.32.0.dev0", 23 | "vocab_size": 49408 24 | } 25 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/text_encoder/openvino_model.xml: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ab5cf7327374d8c984f4e963564a329f92c9dad08dac9eee9b8dca86b912f1c9 3 | size 1057789 4 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/text_encoder_2/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "CLIPTextModelWithProjection" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 0, 7 | "dropout": 0.0, 8 | "eos_token_id": 2, 9 | "hidden_act": "gelu", 10 | "hidden_size": 1280, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 5120, 14 | "layer_norm_eps": 
1e-05, 15 | "max_position_embeddings": 77, 16 | "model_type": "clip_text_model", 17 | "num_attention_heads": 20, 18 | "num_hidden_layers": 32, 19 | "pad_token_id": 1, 20 | "projection_dim": 1280, 21 | "torch_dtype": "float16", 22 | "transformers_version": "4.32.0.dev0", 23 | "vocab_size": 49408 24 | } 25 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/text_encoder_2/openvino_model.xml: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:38f0a4ff68dd918b24908a264140c2ad0e057eca82616f75c17cbf4a099ad6ad 3 | size 2790191 4 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "clean_up_tokenization_spaces": true, 12 | "do_lower_case": true, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "<|endoftext|>", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "errors": "replace", 22 | "model_max_length": 77, 23 | "pad_token": "<|endoftext|>", 24 | "tokenizer_class": "CLIPTokenizer", 25 | "unk_token": { 26 | "__type": "AddedToken", 27 | "content": "<|endoftext|>", 28 | "lstrip": false, 29 | "normalized": true, 30 | "rstrip": false, 31 | "single_word": false 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/tokenizer_2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "!", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/tokenizer_2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 
3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "clean_up_tokenization_spaces": true, 12 | "do_lower_case": true, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "<|endoftext|>", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "errors": "replace", 22 | "model_max_length": 77, 23 | "pad_token": "!", 24 | "tokenizer_class": "CLIPTokenizer", 25 | "unk_token": { 26 | "__type": "AddedToken", 27 | "content": "<|endoftext|>", 28 | "lstrip": false, 29 | "normalized": true, 30 | "rstrip": false, 31 | "single_word": false 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "act_fn": "silu", 5 | "addition_embed_type": "text_time", 6 | "addition_embed_type_num_heads": 64, 7 | "addition_time_embed_dim": 256, 8 | "attention_head_dim": [ 9 | 5, 10 | 10, 11 | 20 12 | ], 13 | "block_out_channels": [ 14 | 320, 15 | 640, 16 | 1280 17 | ], 18 | "center_input_sample": false, 19 | "class_embed_type": null, 20 | "class_embeddings_concat": false, 21 | "conv_in_kernel": 3, 22 | "conv_out_kernel": 3, 23 | "cross_attention_dim": 2048, 24 | "cross_attention_norm": null, 25 | "down_block_types": [ 26 | "DownBlock2D", 27 | "CrossAttnDownBlock2D", 28 | "CrossAttnDownBlock2D" 29 | ], 30 | "downsample_padding": 1, 31 | "dual_cross_attention": false, 32 | "encoder_hid_dim": null, 33 | "encoder_hid_dim_type": null, 34 | "flip_sin_to_cos": true, 35 | "freq_shift": 0, 36 | "in_channels": 4, 37 | "layers_per_block": 2, 38 | "mid_block_only_cross_attention": null, 39 | "mid_block_scale_factor": 1, 40 | "mid_block_type": "UNetMidBlock2DCrossAttn", 41 | "norm_eps": 1e-05, 42 | "norm_num_groups": 32, 43 | "num_attention_heads": null, 44 | "num_class_embeds": null, 45 | "only_cross_attention": false, 46 | "out_channels": 4, 47 | "projection_class_embeddings_input_dim": 2816, 48 | "resnet_out_scale_factor": 1.0, 49 | "resnet_skip_time_act": false, 50 | "resnet_time_scale_shift": "default", 51 | "sample_size": 128, 52 | "time_cond_proj_dim": null, 53 | "time_embedding_act_fn": null, 54 | "time_embedding_dim": null, 55 | "time_embedding_type": "positional", 56 | "timestep_post_act": null, 57 | "transformer_layers_per_block": [ 58 | 1, 59 | 2, 60 | 10 61 | ], 62 | "up_block_types": [ 63 | "CrossAttnUpBlock2D", 64 | "CrossAttnUpBlock2D", 65 | "UpBlock2D" 66 | ], 67 | "upcast_attention": null, 68 | "use_linear_projection": true 69 | } 70 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/unet/openvino_model.xml: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:18955f96dffdba5612c4b554451f4ccc947c93e46df010173791654af4d0d7f6 3 | size 22577438 4 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.20.0.dev0", 4 | 
"_name_or_path": "../sdxl-vae/", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "force_upcast": true, 19 | "in_channels": 3, 20 | "latent_channels": 4, 21 | "layers_per_block": 2, 22 | "norm_num_groups": 32, 23 | "out_channels": 3, 24 | "sample_size": 1024, 25 | "scaling_factor": 0.13025, 26 | "up_block_types": [ 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D", 30 | "UpDecoderBlock2D" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae_1_0/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "norm_num_groups": 32, 22 | "out_channels": 3, 23 | "sample_size": 1024, 24 | "scaling_factor": 0.13025, 25 | "up_block_types": [ 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae_decoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "norm_num_groups": 32, 22 | "out_channels": 3, 23 | "sample_size": 1024, 24 | "scaling_factor": 0.13025, 25 | "up_block_types": [ 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae_decoder/openvino_model.xml: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dd61f43e981282b77ecaecf5fc5c842d504932bae78ac99ec581cee50978b423 3 | size 992181 4 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "force_upcast": true, 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "norm_num_groups": 32, 22 | 
"out_channels": 3, 23 | "sample_size": 1024, 24 | "scaling_factor": 0.13025, 25 | "up_block_types": [ 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D" 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl-base-1.0/vae_encoder/openvino_model.xml: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a3ec36b6f3f74d0cb2b005b7c0a1e5426c5ef1e7163b33e463ea57fa049c5996 3 | size 849965 4 | -------------------------------------------------------------------------------- /configs/models_config/stable-diffusion-xl/sd_xl_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: sgm.models.diffusion.DiffusionEngine 3 | params: 4 | scale_factor: 0.13025 5 | disable_first_stage_autocast: True 6 | 7 | denoiser_config: 8 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 9 | params: 10 | num_idx: 1000 11 | 12 | scaling_config: 13 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 14 | discretization_config: 15 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 16 | 17 | network_config: 18 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | adm_in_channels: 2816 21 | num_classes: sequential 22 | use_checkpoint: True 23 | in_channels: 4 24 | out_channels: 4 25 | model_channels: 320 26 | attention_resolutions: [4, 2] 27 | num_res_blocks: 2 28 | channel_mult: [1, 2, 4] 29 | num_head_channels: 64 30 | use_linear_in_transformer: True 31 | transformer_depth: [1, 2, 10] 32 | context_dim: 2048 33 | spatial_transformer_attn_type: softmax-xformers 34 | 35 | conditioner_config: 36 | target: sgm.modules.GeneralConditioner 37 | params: 38 | emb_models: 39 | - is_trainable: False 40 | input_key: txt 41 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 42 | params: 43 | layer: hidden 44 | layer_idx: 11 45 | 46 | - is_trainable: False 47 | input_key: txt 48 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 49 | params: 50 | arch: ViT-bigG-14 51 | version: laion2b_s39b_b160k 52 | freeze: True 53 | layer: penultimate 54 | always_return_pooled: True 55 | legacy: False 56 | 57 | - is_trainable: False 58 | input_key: original_size_as_tuple 59 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 60 | params: 61 | outdim: 256 62 | 63 | - is_trainable: False 64 | input_key: crop_coords_top_left 65 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 66 | params: 67 | outdim: 256 68 | 69 | - is_trainable: False 70 | input_key: target_size_as_tuple 71 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 72 | params: 73 | outdim: 256 74 | 75 | first_stage_config: 76 | target: sgm.models.autoencoder.AutoencoderKL 77 | params: 78 | embed_dim: 4 79 | monitor: val/rec_loss 80 | ddconfig: 81 | attn_type: vanilla-xformers 82 | double_z: true 83 | z_channels: 4 84 | resolution: 256 85 | in_channels: 3 86 | out_ch: 3 87 | ch: 128 88 | ch_mult: [1, 2, 4, 4] 89 | num_res_blocks: 2 90 | attn_resolutions: [] 91 | dropout: 0.0 92 | lossconfig: 93 | target: torch.nn.Identity 94 | -------------------------------------------------------------------------------- /configs/sliders/config-xl.yaml: -------------------------------------------------------------------------------- 1 | prompts_file: "" 2 | pretrained_model: 3 | name_or_path: "" 
# you can also use .ckpt or .safetensors models 4 | v2: false # true if model is v2.x 5 | v_pred: false # true if model uses v-prediction 6 | network: 7 | type: "c3lier" # "c3lier" or "lierla" 8 | rank: 4 9 | alpha: 1.0 10 | training_method: "noxattn" 11 | train: 12 | precision: "float16" 13 | noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a" 14 | iterations: 1000 15 | lr: 0.0002 16 | optimizer: "AdamW" 17 | lr_scheduler: "constant" 18 | max_denoising_steps: 50 19 | save: 20 | name: "temp" 21 | path: "./models" 22 | per_steps: 500 23 | precision: "float16" 24 | logging: 25 | use_wandb: false 26 | verbose: false 27 | other: 28 | use_xformers: true -------------------------------------------------------------------------------- /configs/sliders/prompts-xl.yaml: -------------------------------------------------------------------------------- 1 | - target: "" # the word from which the positive concept is erased 2 | positive: "" # the concept to erase 3 | unconditional: "" # the word used to contrast with the positive concept 4 | neutral: "" # the starting point for adjusting the target 5 | action: "enhance" # erase or enhance 6 | guidance_scale: 4 7 | resolution: 768 8 | dynamic_resolution: true 9 | batch_size: 1 -------------------------------------------------------------------------------- /examples/captioner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinusZoneAI/ComfyUI-TrainTools-MZ/cc2faae052d80a51914fefdcaca82f2612af8fcf/examples/captioner.png -------------------------------------------------------------------------------- /examples/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinusZoneAI/ComfyUI-TrainTools-MZ/cc2faae052d80a51914fefdcaca82f2612af8fcf/examples/workflow.png -------------------------------------------------------------------------------- /hook_HYDiT_idk_run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from index_kits.dataset.make_dataset_core import startup, make_multireso 4 | from index_kits.common import show_index_info 5 | from index_kits import __version__ 6 | 7 | 8 | def common_args(parser): 9 | parser.add_argument('-t', '--target', type=str, 10 | required=True, help='Save path') 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description=""" 15 | IndexKits is a tool to build and manage index files for large-scale datasets. 16 | It supports both base index and multi-resolution index. 17 | 18 | Introduction 19 | ------------ 20 | This command line tool provides the following functionalities: 21 | 1. Show index v2 information 22 | 2. Build base index v2 23 | 3. Build multi-resolution index v2 24 | 25 | Examples 26 | -------- 27 | 1. Show index v2 information 28 | index_kits show /path/to/index.json 29 | 30 | 2. Build base index v2 31 | Default usage: 32 | index_kits base -c /path/to/config.yaml -t /path/to/index.json 33 | 34 | Use multiple processes: 35 | index_kits base -c /path/to/config.yaml -t /path/to/index.json -w 40 36 | 37 | 3. 
Build multi-resolution index v2 38 | 39 | Build with a configuration file: 40 | index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json 41 | 42 | Build by specifying arguments without a configuration file: 43 | index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json 44 | 45 | Build by specifying target-ratios: 46 | index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json 47 | 48 | Build with multiple source index files. 49 | index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json 50 | """, formatter_class=argparse.RawTextHelpFormatter) 51 | sub_parsers = parser.add_subparsers(dest='task', required=True) 52 | 53 | # Show index message 54 | show_parser = sub_parsers.add_parser('show', description=""" 55 | Show base/multireso index v2 information. 56 | 57 | Example 58 | ------- 59 | index_kits show /path/to/index.json 60 | """, formatter_class=argparse.RawTextHelpFormatter) 61 | show_parser.add_argument( 62 | 'src', type=str, help='Path to a base/multireso index file.') 63 | show_parser.add_argument( 64 | '--arrow-files', action='store_true', help='Show arrow files only.') 65 | show_parser.add_argument('--depth', type=int, default=1, 66 | help='Arrow file depth. Default is 1, the level of last folder in the arrow file path. ' 67 | 'Set it to 0 to show the full path including `xxx/last_folder/*.arrow`.') 68 | 69 | # Single resolution bucket 70 | base_parser = sub_parsers.add_parser('base', description=""" 71 | Build base index v2. 72 | 73 | Example 74 | ------- 75 | index_kits base -c /path/to/config.yaml -t /path/to/index.json 76 | """, formatter_class=argparse.RawTextHelpFormatter) 77 | base_parser.add_argument('-c', '--config', type=str, 78 | required=True, help='Configuration file path') 79 | common_args(base_parser) 80 | base_parser.add_argument('-w', '--world-size', type=int, default=1) 81 | base_parser.add_argument('--work-dir', type=str, 82 | default='.', help='Work directory') 83 | base_parser.add_argument('--use-cache', action='store_true', help='Use cache to avoid reprocessing. ' 84 | 'Perform merge pkl results directly.') 85 | 86 | # Multi-resolution bucket 87 | mo_parser = sub_parsers.add_parser('multireso', description=""" 88 | Build multi-resolution index v2 89 | 90 | Example 91 | ------- 92 | Build with a configuration file: 93 | index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json 94 | 95 | Build by specifying arguments without a configuration file: 96 | index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json 97 | 98 | Build by specifying target-ratios: 99 | index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json 100 | 101 | Build with multiple source index files. 102 | index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json 103 | """, formatter_class=argparse.RawTextHelpFormatter) 104 | mo_parser.add_argument('-c', '--config', type=str, default=None, 105 | help='Configuration file path in a yaml format. 
Either --config or --src must be provided.') 106 | mo_parser.add_argument('-s', '--src', type=str, nargs='+', default=None, 107 | help='Source index files. Either --config or --src must be provided.') 108 | common_args(mo_parser) 109 | mo_parser.add_argument('--base-size', type=int, default=None, 110 | help="Base size. Typically set as 256/512/1024 according to image size you train model.") 111 | mo_parser.add_argument('--reso-step', type=int, default=None, 112 | help="Resolution step. Either reso_step or target_ratios must be provided.") 113 | mo_parser.add_argument('--target-ratios', type=str, nargs='+', default=None, 114 | help="Target ratios. Either reso_step or target_ratios must be provided.") 115 | mo_parser.add_argument('--md5-file', type=str, default=None, 116 | help='You can provide an md5 to height and width file to accelerate the process. ' 117 | 'It is a pickle file that contains a dict, which maps md5 to (height, width) tuple.') 118 | mo_parser.add_argument('--align', type=int, default=16, 119 | help="Used when --target-ratios is provided. Align size of source image height and width.") 120 | mo_parser.add_argument('--min-size', type=int, default=0, 121 | help="Minimum size. Images smaller than this size will be ignored.") 122 | 123 | # Common 124 | parser.add_argument('-v', '--version', action='version', 125 | version=f'%(prog)s {__version__}') 126 | 127 | args = parser.parse_args() 128 | return args 129 | 130 | 131 | if __name__ == '__main__': 132 | args = get_args() 133 | if args.task == 'show': 134 | show_index_info(args.src, 135 | args.arrow_files, 136 | args.depth, 137 | ) 138 | elif args.task == 'base': 139 | startup(args.config, 140 | args.target, 141 | args.world_size, 142 | args.work_dir, 143 | use_cache=args.use_cache, 144 | ) 145 | elif args.task == 'multireso': 146 | make_multireso(args.target, 147 | args.config, 148 | args.src, 149 | args.base_size, 150 | args.reso_step, 151 | args.target_ratios, 152 | args.align, 153 | args.min_size, 154 | args.md5_file, 155 | ) 156 | -------------------------------------------------------------------------------- /hook_HYDiT_main_train_deepspeed.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | import random 5 | import sys 6 | import time 7 | from functools import partial 8 | from glob import glob 9 | from pathlib import Path 10 | import numpy as np 11 | import safetensors.torch 12 | 13 | import deepspeed 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.distributed as dist 18 | from torch.utils.data import DataLoader 19 | from torch.distributed.optim import ZeroRedundancyOptimizer 20 | from torchvision.transforms import functional as TF 21 | from diffusers.models import AutoencoderKL 22 | from transformers import BertModel, BertTokenizer, logging as tf_logging 23 | 24 | from hydit.config import get_args 25 | from hydit.lr_scheduler import WarmupLR 26 | from hydit.data_loader.arrow_load_stream import TextImageArrowStream 27 | from hydit.diffusion import create_diffusion 28 | from hydit.ds_config import deepspeed_config_from_args 29 | from hydit.modules.ema import EMA 30 | from hydit.modules.fp16_layers import Float16Module 31 | from hydit.modules.models import HUNYUAN_DIT_MODELS 32 | from hydit.modules.posemb_layers import init_image_posemb 33 | from hydit.utils.tools import create_logger, set_seeds, create_exp_folder, get_trainable_params 34 | from IndexKits.index_kits import ResolutionGroup 35 | 
from IndexKits.index_kits.sampler import DistributedSamplerWithStartIndex, BlockDistributedSampler 36 | from peft import LoraConfig, get_peft_model 37 | 38 | 39 | def deepspeed_initialize(args, logger, model, opt, deepspeed_config): 40 | logger.info(f"Initialize deepspeed...") 41 | logger.info(f" Using deepspeed optimizer") 42 | 43 | def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt): 44 | return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps) 45 | 46 | logger.info( 47 | f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}") 48 | logger.info( 49 | f" deepspeed_config={deepspeed_config}") 50 | 51 | model, opt, _, scheduler = deepspeed.initialize( 52 | model=model, 53 | model_parameters=get_trainable_params(model), 54 | config_params=deepspeed_config, 55 | args=args, 56 | lr_scheduler=partial(get_learning_rate_scheduler, args.warmup_min_lr, 57 | args.lr, args.warmup_num_steps) if args.warmup_num_steps > 0 else None, 58 | ) 59 | return model, opt, scheduler 60 | 61 | 62 | def save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir): 63 | def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"): 64 | cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}" 65 | if rank == 0: 66 | if args.use_fp16: 67 | model.module.module.save_pretrained(cur_ckpt_save_dir) 68 | else: 69 | model.module.save_pretrained(cur_ckpt_save_dir) 70 | 71 | checkpoint_path = "[Not rank 0. Disabled output.]" 72 | 73 | client_state = { 74 | "steps": train_steps, 75 | "epoch": epoch, 76 | "args": args 77 | } 78 | if ema is not None: 79 | client_state['ema'] = ema.state_dict() 80 | 81 | dst_paths = [] 82 | if train_steps % args.ckpt_every == 0: 83 | checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" 84 | try: 85 | if args.training_parts == "lora": 86 | save_lora_weight(checkpoint_dir, client_state, 87 | tag=f"{train_steps:07d}.pt") 88 | else: 89 | model.save_checkpoint( 90 | checkpoint_dir, client_state=client_state, tag=f"{train_steps:07d}.pt") 91 | dst_paths.append(checkpoint_path) 92 | logger.info(f"Saved checkpoint to {checkpoint_path}") 93 | except: 94 | logger.error(f"Saved failed to {checkpoint_path}") 95 | 96 | if train_steps % args.ckpt_latest_every == 0 or train_steps == args.max_training_steps: 97 | save_name = "latest.pt" 98 | checkpoint_path = f"{checkpoint_dir}/{save_name}" 99 | try: 100 | if args.training_parts == "lora": 101 | save_lora_weight(checkpoint_dir, client_state, 102 | tag=f"{save_name}") 103 | else: 104 | model.save_checkpoint( 105 | checkpoint_dir, client_state=client_state, tag=f"{save_name}") 106 | dst_paths.append(checkpoint_path) 107 | logger.info(f"Saved checkpoint to {checkpoint_path}") 108 | except: 109 | logger.error(f"Saved failed to {checkpoint_path}") 110 | 111 | dist.barrier() 112 | if rank == 0 and len(dst_paths) > 0: 113 | # Delete optimizer states to avoid occupying too much disk space. 
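# DeepSpeed writes one ZeRO optimizer-state shard per data-parallel rank (the zero_dp_rank_*_tp_rank_00_pp_rank_00_optim_states.pt files globbed below); only rank 0 removes them, so the saved checkpoints keep the model weights without the bulky optimizer states.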
114 | for dst_path in dst_paths: 115 | for opt_state_path in glob(f"{dst_path}/zero_dp_rank_*_tp_rank_00_pp_rank_00_optim_states.pt"): 116 | os.remove(opt_state_path) 117 | 118 | return checkpoint_path 119 | 120 | 121 | @torch.no_grad() 122 | def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img): 123 | try: 124 | from .hook_HYDiT_utils import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER, easy_sample_images, model_resume, PBar 125 | except: 126 | from hook_HYDiT_utils import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER, easy_sample_images, model_resume, PBar 127 | 128 | image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch 129 | 130 | # clip & mT5 text embedding 131 | text_embedding = text_embedding.to(device) 132 | text_embedding_mask = text_embedding_mask.to(device) 133 | encoder_hidden_states = text_encoder( 134 | text_embedding.to(device), 135 | attention_mask=text_embedding_mask.to(device), 136 | )[0] 137 | text_embedding_t5 = text_embedding_t5.to(device).squeeze(1) 138 | text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1) 139 | with torch.no_grad(): 140 | output_t5 = text_encoder_t5( 141 | input_ids=text_embedding_t5, 142 | attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None, 143 | output_hidden_states=True 144 | ) 145 | encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach( 146 | ) 147 | 148 | # additional condition 149 | image_meta_size = kwargs['image_meta_size'].to(device) 150 | style = kwargs['style'].to(device) 151 | 152 | if args.extra_fp16: 153 | image = image.half() 154 | image_meta_size = image_meta_size.half() if image_meta_size is not None else None 155 | 156 | # Map input images to latent space + normalize latents: 157 | image = image.to(device) 158 | vae_scaling_factor = vae.config.scaling_factor 159 | latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor) 160 | 161 | # positional embedding 162 | _, _, height, width = image.shape 163 | reso = f"{height}x{width}" 164 | cos_cis_img, sin_cis_img = freqs_cis_img[reso] 165 | 166 | # Model conditions 167 | model_kwargs = dict( 168 | encoder_hidden_states=encoder_hidden_states, 169 | text_embedding_mask=text_embedding_mask, 170 | encoder_hidden_states_t5=encoder_hidden_states_t5, 171 | text_embedding_mask_t5=text_embedding_mask_t5, 172 | image_meta_size=image_meta_size, 173 | style=style, 174 | cos_cis_img=cos_cis_img, 175 | sin_cis_img=sin_cis_img, 176 | ) 177 | 178 | return latents, model_kwargs 179 | 180 | 181 | def Core(args): 182 | if args.training_parts == "lora": 183 | args.use_ema = False 184 | 185 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 
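# Distributed setup below: each rank maps to device = rank % device_count and seeds with global_seed * world_size + rank, and the effective global batch size is world_size * batch_size * grad_accu_steps.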
186 | 187 | dist.init_process_group("nccl") 188 | world_size = dist.get_world_size() 189 | batch_size = args.batch_size 190 | grad_accu_steps = args.grad_accu_steps 191 | global_batch_size = world_size * batch_size * grad_accu_steps 192 | 193 | rank = dist.get_rank() 194 | device = rank % torch.cuda.device_count() 195 | seed = args.global_seed * world_size + rank 196 | random.seed(seed) 197 | np.random.seed(seed) 198 | torch.manual_seed(seed) 199 | torch.cuda.manual_seed_all(seed) 200 | torch.cuda.set_device(device) 201 | print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.") 202 | deepspeed_config = deepspeed_config_from_args(args, global_batch_size) 203 | 204 | # Setup an experiment folder 205 | experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank) 206 | 207 | # Log all the arguments 208 | logger.info(sys.argv) 209 | logger.info(str(args)) 210 | # Save to a json file 211 | args_dict = vars(args) 212 | args_dict['world_size'] = world_size 213 | with open(f"{experiment_dir}/args.json", 'w') as f: 214 | json.dump(args_dict, f, indent=4) 215 | 216 | # Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel." 217 | # If needed, just comment the following line. 218 | tf_logging.set_verbosity_error() 219 | 220 | # =========================================================================== 221 | # Building HYDIT 222 | # =========================================================================== 223 | 224 | logger.info("Building HYDIT Model.") 225 | 226 | # --------------------------------------------------------------------------- 227 | # Training sample base size, such as 256/512/1024. Notice that this size is 228 | # just a base size, not necessary the actual size of training samples. Actual 229 | # size of the training samples are correlated with `resolutions` when enabling 230 | # multi-resolution training. 231 | # --------------------------------------------------------------------------- 232 | image_size = args.image_size 233 | if len(image_size) == 1: 234 | image_size = [image_size[0], image_size[0]] 235 | if len(image_size) != 2: 236 | raise ValueError(f"Invalid image size: {args.image_size}") 237 | assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder). " \ 238 | f"got {image_size}" 239 | latent_size = [image_size[0] // 8, image_size[1] // 8] 240 | 241 | # initialize model by deepspeed 242 | assert args.deepspeed, f"Must enable deepspeed in this script: train_deepspeed.py" 243 | with deepspeed.zero.Init(data_parallel_group=torch.distributed.group.WORLD, 244 | remote_device=None if args.remote_device == 'none' else args.remote_device, 245 | config_dict_or_path=deepspeed_config, 246 | mpu=None, 247 | enabled=args.zero_stage == 3): 248 | model = HUNYUAN_DIT_MODELS[args.model](args, 249 | input_size=latent_size, 250 | log_fn=logger.info, 251 | ) 252 | # Multi-resolution / Single-resolution training. 
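# With args.multireso, ResolutionGroup builds aspect-ratio buckets around image_size[0] (aligned to 16, driven by reso_step or target_ratios); otherwise only the 1:1 bucket is used.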
253 | if args.multireso: 254 | resolutions = ResolutionGroup(image_size[0], 255 | align=16, 256 | step=args.reso_step, 257 | target_ratios=args.target_ratios).data 258 | else: 259 | resolutions = ResolutionGroup(image_size[0], 260 | align=16, 261 | target_ratios=['1:1']).data 262 | 263 | freqs_cis_img = init_image_posemb(args.rope_img, 264 | resolutions=resolutions, 265 | patch_size=model.patch_size, 266 | hidden_size=model.hidden_size, 267 | num_heads=model.num_heads, 268 | log_fn=logger.info, 269 | rope_real=args.rope_real, 270 | ) 271 | 272 | # Create EMA model and convert to fp16 if needed. 273 | ema = None 274 | if args.use_ema: 275 | ema = EMA(args, model, device, logger) 276 | 277 | # Setup FP16 main model: 278 | if args.use_fp16: 279 | model = Float16Module(model, args) 280 | logger.info( 281 | f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}") 282 | 283 | diffusion = create_diffusion( 284 | noise_schedule=args.noise_schedule, 285 | predict_type=args.predict_type, 286 | learn_sigma=args.learn_sigma, 287 | mse_loss_weight_type=args.mse_loss_weight_type, 288 | beta_start=args.beta_start, 289 | beta_end=args.beta_end, 290 | noise_offset=args.noise_offset, 291 | ) 292 | 293 | try: 294 | from .hook_HYDiT_utils import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER, easy_sample_images, model_resume, PBar, CustomizeEmbeds 295 | except: 296 | from hook_HYDiT_utils import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER, easy_sample_images, model_resume, PBar, CustomizeEmbeds 297 | 298 | # Setup VAE 299 | logger.info(f" Loading vae from {VAE_EMA_PATH}") 300 | vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH) 301 | # Setup BERT text encoder 302 | logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}") 303 | text_encoder = BertModel.from_pretrained( 304 | TEXT_ENCODER, False, revision=None) 305 | # Setup BERT tokenizer: 306 | logger.info(f" Loading Bert tokenizer from {TOKENIZER}") 307 | tokenizer = BertTokenizer.from_pretrained(TOKENIZER) 308 | # Setup T5 text encoder 309 | from hydit.modules.text_encoder import MT5Embedder 310 | mt5_path = T5_ENCODER['MT5'] 311 | if mt5_path is None: 312 | embedder_t5 = CustomizeEmbeds() 313 | else: 314 | embedder_t5 = MT5Embedder( 315 | mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5) 316 | tokenizer_t5 = embedder_t5.tokenizer 317 | text_encoder_t5 = embedder_t5.model 318 | 319 | if args.extra_fp16: 320 | logger.info(f" Using fp16 for extra modules: vae, text_encoder") 321 | vae = vae.half().to(device) 322 | text_encoder = text_encoder.half().to(device) 323 | text_encoder_t5 = text_encoder_t5.half().to(device) 324 | else: 325 | vae = vae.to(device) 326 | text_encoder = text_encoder.to(device) 327 | text_encoder_t5 = text_encoder_t5.to(device) 328 | 329 | logger.info( 330 | f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}") 331 | logger.info(" Using deepspeed optimizer") 332 | opt = None 333 | 334 | # =========================================================================== 335 | # Building Dataset 336 | # =========================================================================== 337 | 338 | logger.info(f"Building Streaming Dataset.") 339 | logger.info(f" Loading index file {args.index_file} (v2)") 340 | 341 | dataset = TextImageArrowStream(args=args, 342 | resolution=image_size[0], 343 | random_flip=args.random_flip, 344 | log_fn=logger.info, 345 | index_file=args.index_file, 346 | multireso=args.multireso, 347 | batch_size=batch_size, 348 | world_size=world_size, 349 
| random_shrink_size_cond=args.random_shrink_size_cond, 350 | merge_src_cond=args.merge_src_cond, 351 | uncond_p=args.uncond_p, 352 | text_ctx_len=args.text_len, 353 | tokenizer=tokenizer, 354 | uncond_p_t5=args.uncond_p_t5, 355 | text_ctx_len_t5=args.text_len_t5, 356 | tokenizer_t5=tokenizer_t5, 357 | ) 358 | if args.multireso: 359 | sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed, 360 | shuffle=False, drop_last=True, batch_size=batch_size) 361 | else: 362 | sampler = DistributedSamplerWithStartIndex(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed, 363 | shuffle=False, drop_last=True) 364 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler, 365 | num_workers=args.num_workers, pin_memory=True, drop_last=True) 366 | logger.info(f" Dataset contains {len(dataset):,} images.") 367 | logger.info(f" Index file: {args.index_file}.") 368 | if args.multireso: 369 | logger.info(f' Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} ' 370 | f'and base size {dataset.index_manager.base_size}') 371 | logger.info(f'\n {dataset.index_manager.resolutions}') 372 | 373 | # =========================================================================== 374 | # Loading parameter 375 | # =========================================================================== 376 | 377 | logger.info(f"Loading parameter") 378 | start_epoch = 0 379 | start_epoch_step = 0 380 | train_steps = 0 381 | # Resume checkpoint if needed 382 | # if args.resume is not None or len(args.resume) > 0: 383 | if True: 384 | model, ema, start_epoch, start_epoch_step, train_steps = model_resume( 385 | args, model, ema, logger) 386 | 387 | if args.training_parts == "lora": 388 | lora_ckpt = args.lora_ckpt 389 | if lora_ckpt is not None: 390 | lastest_checkpoint = lora_ckpt 391 | from peft.peft_model import PeftModel 392 | print(f"Loading lora model from {lastest_checkpoint}") 393 | if args.use_fp16: 394 | model.module = PeftModel.from_pretrained( 395 | model.module, lastest_checkpoint, is_trainable=True) 396 | else: 397 | model = PeftModel.from_pretrained( 398 | model, lastest_checkpoint, is_trainable=True) 399 | else: 400 | loraconfig = LoraConfig( 401 | r=args.rank, 402 | lora_alpha=args.rank, 403 | target_modules=args.target_modules 404 | ) 405 | 406 | if args.use_fp16: 407 | model.module = get_peft_model(model.module, loraconfig) 408 | else: 409 | model = get_peft_model(model, loraconfig) 410 | 411 | logger.info(f" Training parts: {args.training_parts}") 412 | 413 | model, opt, scheduler = deepspeed_initialize( 414 | args, logger, model, opt, deepspeed_config) 415 | 416 | # =========================================================================== 417 | # Training 418 | # =========================================================================== 419 | 420 | # print model structure 421 | # for name, param in model.named_parameters(): 422 | # print(name, param.size()) 423 | # raise Exception("stop") 424 | 425 | model.train() 426 | if args.use_ema: 427 | ema.eval() 428 | 429 | print(f" Worker {rank} ready.") 430 | dist.barrier() 431 | 432 | iters_per_epoch = len(loader) 433 | logger.info( 434 | " ****************************** Running training ******************************") 435 | logger.info(f" Number GPUs: {world_size}") 436 | logger.info(f" Number training samples: {len(dataset):,}") 437 | logger.info( 438 | f" Number parameters: {sum(p.numel() for p in model.parameters()):,}") 439 | logger.info( 440 | f" Number 
trainable params: {sum(p.numel() for p in get_trainable_params(model)):,}") 441 | logger.info( 442 | " ------------------------------------------------------------------------------") 443 | logger.info(f" Iters per epoch: {iters_per_epoch:,}") 444 | logger.info(f" Batch size per device: {batch_size}") 445 | logger.info( 446 | f" Batch size (all devices): {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)") 447 | logger.info(f" Gradient accumulation steps: {args.grad_accu_steps}") 448 | logger.info( 449 | f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}") 450 | 451 | logger.info( 452 | f" Training epochs: {start_epoch}/{args.epochs}") 453 | logger.info( 454 | f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}") 455 | logger.info( 456 | f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}") 457 | logger.info( 458 | " ------------------------------------------------------------------------------") 459 | logger.info(f" Noise schedule: {args.noise_schedule}") 460 | logger.info( 461 | f" Beta limits: ({args.beta_start}, {args.beta_end})") 462 | logger.info(f" Learn sigma: {args.learn_sigma}") 463 | logger.info(f" Prediction type: {args.predict_type}") 464 | logger.info(f" Noise offset: {args.noise_offset}") 465 | 466 | logger.info( 467 | " ------------------------------------------------------------------------------") 468 | logger.info( 469 | f" Using EMA model: {args.use_ema} ({args.ema_dtype})") 470 | if args.use_ema: 471 | logger.info( 472 | f" Using EMA decay: {ema.max_value if args.use_ema else None}") 473 | logger.info( 474 | f" Using EMA warmup power: {ema.power if args.use_ema else None}") 475 | logger.info(f" Using main model fp16: {args.use_fp16}") 476 | logger.info(f" Using extra modules fp16: {args.extra_fp16}") 477 | logger.info( 478 | " ------------------------------------------------------------------------------") 479 | logger.info(f" Experiment directory: {experiment_dir}") 480 | logger.info( 481 | " *******************************************************************************") 482 | 483 | if args.gc_interval > 0: 484 | gc.disable() 485 | gc.collect() 486 | 487 | # Variables for monitoring/logging purposes: 488 | log_steps = 0 489 | running_loss = 0 490 | start_time = time.time() 491 | 492 | if args.async_ema: 493 | ema_stream = torch.cuda.Stream() 494 | 495 | easy_sample_images(args, vae, text_encoder, tokenizer, model, embedder_t5, 496 | target_height=768, target_width=1280, train_steps=0) 497 | pbar = PBar(args.epochs * len(loader)) 498 | 499 | # Training loop 500 | for epoch in range(start_epoch, args.epochs): 501 | logger.info(f" Start random shuffle with seed={seed}") 502 | # Make sure all processes use the same seed to shuffle the dataset. 503 | dataset.shuffle(seed=args.global_seed + epoch, fast=True) 504 | logger.info(f" End of random shuffle") 505 | 506 | # Move the sampler to start_index 507 | if not args.multireso: 508 | start_index = start_epoch_step * world_size * batch_size 509 | if start_index != sampler.start_index: 510 | sampler.start_index = start_index 511 | # Reset start_epoch_step to zero so that the next epoch starts from the beginning.
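# Illustrative arithmetic for the resume logic above: with start_epoch_step=100, world_size=2 and
# batch_size=4, start_index = 100 * 2 * 4 = 800, i.e. the sampler skips the 800 samples that were
# already consumed earlier in this epoch before training resumes.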
512 | start_epoch_step = 0 513 | logger.info(f" Iters left this epoch: {len(loader):,}") 514 | 515 | logger.info(f" Beginning epoch {epoch}...") 516 | step = 0 517 | for batch in loader: 518 | step += 1 519 | 520 | latents, model_kwargs = prepare_model_inputs( 521 | args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img) 522 | 523 | # training model by deepspeed while use fp16 524 | if args.use_fp16: 525 | if args.use_ema and args.async_ema: 526 | with torch.cuda.stream(ema_stream): 527 | ema.update(model.module.module, step=step) 528 | torch.cuda.current_stream().wait_stream(ema_stream) 529 | 530 | loss_dict = diffusion.training_losses( 531 | model=model, x_start=latents, model_kwargs=model_kwargs) 532 | loss = loss_dict["loss"].mean() 533 | model.backward(loss) 534 | last_batch_iteration = ( 535 | train_steps + 1) // (global_batch_size // (batch_size * world_size)) 536 | model.step( 537 | lr_kwargs={'last_batch_iteration': last_batch_iteration}) 538 | 539 | if args.use_ema and not args.async_ema or (args.async_ema and step == len(loader) - 1): 540 | if args.use_fp16: 541 | ema.update(model.module.module, step=step) 542 | else: 543 | ema.update(model.module, step=step) 544 | 545 | # =========================================================================== 546 | # Log loss values: 547 | # =========================================================================== 548 | running_loss += loss.item() 549 | log_steps += 1 550 | train_steps += 1 551 | if train_steps % args.log_every == 0: 552 | # Measure training speed: 553 | torch.cuda.synchronize() 554 | end_time = time.time() 555 | steps_per_sec = log_steps / (end_time - start_time) 556 | # Reduce loss history over all processes: 557 | avg_loss = torch.tensor( 558 | running_loss / log_steps, device=device) 559 | dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) 560 | avg_loss = avg_loss.item() / world_size 561 | # get lr from deepspeed fused optimizer 562 | logger.info(f"(step={train_steps:07d}) " + 563 | (f"(update_step={train_steps // args.grad_accu_steps:07d}) " if args.grad_accu_steps > 1 else "") + 564 | f"Train Loss: {avg_loss:.4f}, " 565 | f"Lr: {opt.param_groups[0]['lr']:.6g}, " 566 | f"Steps/Sec: {steps_per_sec:.2f}, " 567 | f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}") 568 | # Reset monitoring variables: 569 | running_loss = 0 570 | log_steps = 0 571 | start_time = time.time() 572 | 573 | # collect gc: 574 | if args.gc_interval > 0 and (step % args.gc_interval == 0): 575 | gc.collect() 576 | 577 | pbar.step( 578 | f"Epoch {epoch}, step {step}, loss {loss.item():.4f}", args.epochs * len(loader), train_steps) 579 | 580 | if (train_steps % args.ckpt_every == 0 or train_steps % args.ckpt_latest_every == 0 # or train_steps == args.max_training_steps 581 | ) and train_steps > 0: 582 | easy_sample_images(args, vae, text_encoder, tokenizer, model, embedder_t5, 583 | target_height=768, target_width=1280, train_steps=train_steps) 584 | save_checkpoint(args, rank, logger, model, ema, 585 | epoch, train_steps, checkpoint_dir) 586 | 587 | if train_steps >= args.max_training_steps: 588 | logger.info(f"Breaking step loop at {train_steps}.") 589 | break 590 | 591 | if train_steps >= args.max_training_steps: 592 | logger.info(f"Breaking epoch loop at {epoch}.") 593 | break 594 | 595 | dist.destroy_process_group() 596 | -------------------------------------------------------------------------------- /hook_HYDiT_run.py: -------------------------------------------------------------------------------- 1 | 2 | import 
argparse 3 | import json 4 | import os 5 | import sys 6 | from types import SimpleNamespace 7 | 8 | import torch 9 | 10 | 11 | class SimpleNamespaceCNWarrper(SimpleNamespace): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.__dict__.update(kwargs) 15 | self.__iter__ = lambda: iter(kwargs.keys()) 16 | # is not iterable 17 | 18 | def __iter__(self): 19 | return iter(self.__dict__.keys()) 20 | # object has no attribute 'num_attention_heads' 21 | 22 | def __getattr__(self, name): 23 | return self.__dict__.get(name, None) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser( 28 | conflict_handler='resolve', 29 | ) 30 | parser.add_argument("--sys_path", type=str, default="") 31 | parser.add_argument("--train_config_file", type=str, default="") 32 | parser.add_argument("--mz_master_port", type=int, default=0) 33 | args = parser.parse_args() 34 | 35 | master_port = args.mz_master_port 36 | 37 | print(f"master_port = {master_port}") 38 | 39 | try: 40 | from . import hook_HYDiT_utils 41 | except Exception as e: 42 | import hook_HYDiT_utils 43 | 44 | hook_HYDiT_utils.set_master_port(master_port) 45 | 46 | sys_path = args.sys_path 47 | if sys_path != "": 48 | sys.path.append(sys_path) 49 | 50 | print("HYDi run hook") 51 | 52 | try: 53 | from . import hook_HYDiT_main_train_deepspeed 54 | except Exception as e: 55 | import hook_HYDiT_main_train_deepspeed 56 | 57 | import hydit.config 58 | 59 | def _handle_conflict_error(self, *args, **kwargs): 60 | pass 61 | 62 | def parse_args(self, args=None, namespace=None): 63 | args, argv = self.parse_known_args(args, namespace) 64 | return args 65 | 66 | train_config_file = args.train_config_file 67 | 68 | if train_config_file == "": 69 | raise ValueError("train_config_file is empty") 70 | 71 | train_config = {} 72 | with open(train_config_file, "r") as f: 73 | train_config = json.load(f) 74 | 75 | argparse.ArgumentParser._handle_conflict_error = _handle_conflict_error 76 | argparse.ArgumentParser._handle_conflict_resolve = _handle_conflict_error 77 | argparse.ArgumentParser.parse_args = parse_args 78 | margs = hydit.config.get_args() 79 | margs.model = train_config.get("model", "DiT-g/2") 80 | 81 | margs.task_flag = train_config.get("task_flag") 82 | 83 | margs.resume_split = train_config.get("resume_split", True) 84 | 85 | margs.ema_to_module = train_config.get("ema_to_module", True) 86 | 87 | margs.deepspeed = False 88 | 89 | margs.predict_type = train_config.get("predict_type", "v_prediction") 90 | 91 | margs.training_parts = train_config.get("training_parts", "lora") 92 | 93 | margs.batch_size = train_config.get("batch_size", 1) 94 | 95 | margs.grad_accu_steps = train_config.get("grad_accu_steps", 1) 96 | 97 | margs.global_seed = train_config.get("global_seed", 0) 98 | 99 | margs.use_flash_attn = train_config.get("use_flash_attn", False) 100 | 101 | margs.use_fp16 = train_config.get("use_fp16", True) 102 | 103 | margs.qk_norm = train_config.get("qk_norm", True) 104 | 105 | margs.ema_dtype = train_config.get("ema_dtype", "fp32") 106 | 107 | margs.async_ema = False 108 | 109 | margs.ckpt_latest_every = 0x7fffffff 110 | 111 | margs.multireso = train_config.get("multireso", True) 112 | 113 | margs.epochs = train_config.get("epochs", 50) 114 | 115 | margs.target_ratios = train_config.get( 116 | "target_ratios", ['1:1', '3:4', '4:3', '16:9', '9:16']) 117 | 118 | margs.rope_img = train_config.get("rope_img", "base1024") 119 | 120 | margs.image_size = train_config.get("image_size", 1024) 121 | 122 
| margs.rope_real = train_config.get("rope_real", True) 123 | 124 | margs.index_file = train_config.get("index_file", None) 125 | 126 | margs.lr = train_config.get("lr", 1e-5) 127 | 128 | margs.rank = train_config.get("rank", 8) 129 | 130 | margs.noise_offset = train_config.get("noise_offset", 0.0) 131 | 132 | margs.log_every = train_config.get("log_every", 99999999999999) 133 | 134 | margs.use_zero_stage = train_config.get("use_zero_stage", 2) 135 | 136 | margs.global_batch_size = train_config.get("global_batch_size", 1) 137 | 138 | margs.deepspeed = True 139 | 140 | margs.results_dir = train_config.get("results_dir") 141 | 142 | margs.mse_loss_weight_type = train_config.get( 143 | "mse_loss_weight_type", "constant") 144 | 145 | for k, v in train_config.items(): 146 | if hasattr(margs, k): 147 | setattr(margs, k, v) 148 | 149 | hook_HYDiT_utils.set_unet_path( 150 | train_config.get("unet_path")) 151 | 152 | hook_HYDiT_utils.set_vae_ema_path( 153 | train_config.get("vae_ema_path")) 154 | 155 | hook_HYDiT_utils.set_text_encoder_path( 156 | train_config.get("text_encoder_path")) 157 | 158 | hook_HYDiT_utils.set_tokenizer_path( 159 | train_config.get("tokenizer_path")) 160 | 161 | hook_HYDiT_utils.set_t5_encoder_path( 162 | train_config.get("t5_encoder_path")) 163 | 164 | hook_HYDiT_utils.set_train_config(train_config) 165 | 166 | try: 167 | # deepspeed/runtime/engine 168 | import deepspeed.runtime.engine 169 | deepspeed.runtime.engine.DeepSpeedEngine._do_sanity_check = lambda x: None 170 | except Exception as e: 171 | pass 172 | 173 | if type(margs.image_size) == int: 174 | margs.image_size = [margs.image_size, margs.image_size] 175 | hook_HYDiT_main_train_deepspeed.Core(margs) 176 | -------------------------------------------------------------------------------- /hook_HYDiT_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import json 4 | import os 5 | import time 6 | import torch 7 | 8 | 9 | UNET_PATH = "ckpts/t2i/model/pytorch_model_ema.pt" 10 | 11 | 12 | def set_unet_path(path): 13 | global UNET_PATH 14 | UNET_PATH = path 15 | 16 | 17 | VAE_EMA_PATH = "ckpts/t2i/sdxl-vae-fp16-fix" 18 | 19 | 20 | def set_vae_ema_path(path): 21 | global VAE_EMA_PATH 22 | VAE_EMA_PATH = path 23 | 24 | 25 | TOKENIZER = "ckpts/t2i/tokenizer" 26 | 27 | 28 | def set_tokenizer_path(path): 29 | global TOKENIZER 30 | TOKENIZER = path 31 | 32 | 33 | TEXT_ENCODER = 'ckpts/t2i/clip_text_encoder' 34 | 35 | 36 | def set_text_encoder_path(path): 37 | global TEXT_ENCODER 38 | TEXT_ENCODER = path 39 | 40 | 41 | T5_ENCODER = { 42 | 'MT5': None, 43 | 'attention_mask': True, 44 | 'layer_index': -1, 45 | 'attention_pool': True, 46 | 'torch_dtype': torch.float16, 47 | 'learnable_replace': True 48 | } 49 | 50 | 51 | def set_t5_encoder_path(path): 52 | global T5_ENCODER 53 | T5_ENCODER['MT5'] = path 54 | 55 | 56 | global TRAIN_CONFIG 57 | 58 | 59 | def set_train_config(train_config): 60 | global TRAIN_CONFIG 61 | TRAIN_CONFIG = train_config 62 | 63 | 64 | def easy_sample_images( 65 | args, 66 | vae=None, 67 | text_encoder=None, 68 | tokenizer=None, 69 | model=None, 70 | embedder_t5=None, 71 | target_height=768, 72 | target_width=1280, 73 | prompt="A photo of a girl with a hat on a sunny day", 74 | negative_prompt="", 75 | batch_size=1, 76 | guidance_scale=2.0, 77 | infer_steps=20, 78 | sampler='dpmpp_2m_karras', 79 | train_steps=0, 80 | seed=0, 81 | ): 82 | from hydit.diffusion.pipeline import StableDiffusionPipeline 83 | from diffusers import schedulers 84 | from 
hydit.constants import SAMPLER_FACTORY 85 | from hydit.modules.posemb_layers import get_fill_resize_and_crop, get_2d_rotary_pos_embed 86 | from hydit.modules.models import HUNYUAN_DIT_CONFIG 87 | 88 | import traceback 89 | with torch.cuda.amp.autocast(): 90 | 91 | workspace_dir = TRAIN_CONFIG.get("workspace_dir") 92 | sample_config_file = TRAIN_CONFIG.get("sample_config_file", None) 93 | if sample_config_file is None: 94 | print("sample_config_file is not set.") 95 | return 96 | try: 97 | sample_config = json.load(open(sample_config_file, "r")) 98 | except Exception as e: 99 | print(f"Failed to load sample_config_file: {sample_config_file}") 100 | return 101 | sample_images_dir = os.path.join(workspace_dir, "sample_images") 102 | os.makedirs(sample_images_dir, exist_ok=True) 103 | 104 | sampler_factory = SAMPLER_FACTORY.copy() 105 | 106 | sampler_factory["uni_pc"] = { 107 | 'scheduler': 'UniPCMultistepScheduler', 108 | 'name': 'UniPCMultistepScheduler', 109 | 'kwargs': { 110 | 'beta_schedule': 'scaled_linear', 111 | 'beta_start': 0.00085, 112 | 'beta_end': 0.03, 113 | 'prediction_type': 'v_prediction', 114 | 'trained_betas': None, 115 | 'solver_order': 2, 116 | } 117 | } 118 | sampler_factory["dpmpp_2m_karras"] = { 119 | 'scheduler': 'DPMSolverMultistepScheduler', 120 | 'name': 'DPMSolverMultistepScheduler', 121 | 'kwargs': { 122 | 'beta_schedule': 'scaled_linear', 123 | 'beta_start': 0.00085, 124 | 'beta_end': 0.03, 125 | 'prediction_type': 'v_prediction', 126 | 'trained_betas': None, 127 | 'solver_order': 2, 128 | 'algorithm_type': 'dpmsolver++', 129 | "use_karras_sigmas": True, 130 | } 131 | } 132 | 133 | # Load sampler from factory 134 | kwargs = sampler_factory[sampler]['kwargs'] 135 | scheduler = sampler_factory[sampler]['scheduler'] 136 | 137 | # Build scheduler according to the sampler. 138 | scheduler_class = getattr(schedulers, scheduler) 139 | scheduler = scheduler_class(**kwargs) 140 | 141 | # Set timesteps for inference steps. 
142 | scheduler.set_timesteps(infer_steps, "cuda") 143 | 144 | def calc_rope(height, width): 145 | model_config = HUNYUAN_DIT_CONFIG["DiT-g/2"] 146 | patch_size = model_config['patch_size'] 147 | head_size = model_config['hidden_size'] // model_config['num_heads'] 148 | th = height // 8 // patch_size 149 | tw = width // 8 // patch_size 150 | base_size = 512 // 8 // patch_size 151 | start, stop = get_fill_resize_and_crop((th, tw), base_size) 152 | sub_args = [start, stop, (th, tw)] 153 | rope = get_2d_rotary_pos_embed(head_size, *sub_args) 154 | return rope 155 | 156 | pipeline = StableDiffusionPipeline(vae=vae, 157 | text_encoder=text_encoder, 158 | tokenizer=tokenizer, 159 | unet=model.module, 160 | scheduler=scheduler, 161 | feature_extractor=None, 162 | safety_checker=None, 163 | requires_safety_checker=False, 164 | embedder_t5=embedder_t5, 165 | ) 166 | pipeline = pipeline.to("cuda") 167 | # Workaround: without moving the pipeline to CUDA, its `_execution_device` attribute is not defined 168 | 169 | style = torch.as_tensor([0, 0] * batch_size, device="cuda") 170 | 171 | src_size_cond = (target_width, target_height) 172 | size_cond = list(src_size_cond) + [target_width, target_height, 0, 0] 173 | image_meta_size = torch.as_tensor( 174 | [size_cond] * 2 * batch_size, device="cuda",) 175 | 176 | if type(sample_config) != list: 177 | sample_config = [sample_config] 178 | 179 | for i, sample in enumerate(sample_config): 180 | prompt = sample.get("prompt", "") 181 | negative_prompt = sample.get("negative_prompt", "") 182 | guidance_scale = sample.get("cfg", guidance_scale) 183 | infer_steps = sample.get("steps", infer_steps) 184 | width = sample.get("width", target_width) 185 | height = sample.get("height", target_height) 186 | 187 | freqs_cis_img = calc_rope(height, width) 188 | 189 | try: 190 | generator = torch.Generator(device="cuda") 191 | generator.manual_seed(seed) 192 | samples = pipeline( 193 | height=height, 194 | width=width, 195 | prompt=prompt, 196 | negative_prompt=negative_prompt, 197 | num_images_per_prompt=batch_size, 198 | guidance_scale=guidance_scale, 199 | num_inference_steps=infer_steps, 200 | style=style, 201 | return_dict=False, 202 | use_fp16=True, 203 | learn_sigma=args.learn_sigma, 204 | freqs_cis_img=freqs_cis_img, 205 | image_meta_size=image_meta_size, 206 | generator=generator, 207 | )[0] 208 | 209 | pass 210 | except Exception as e: 211 | print(f"Failed to sample images: {e} ") 212 | # Print the stack trace 213 | traceback.print_exc() 214 | print(f"Failed to sample pipeline: {pipeline} ") 215 | continue # skip this sample if generation failed, so `samples` is never used while undefined 216 | # print("samples:",type(samples),) 217 | # input("Press Enter to continue...") 218 | # print("samples:",samples,) 219 | 220 | if type(samples) == list: 221 | pil_image = samples[0] 222 | else: 223 | pil_image = samples 224 | 225 | sample_filename = f"{args.task_flag}_train_steps_{train_steps:07d}.png" 226 | sample_filename_path = os.path.join( 227 | sample_images_dir, sample_filename) 228 | pil_image.save(sample_filename_path) 229 | 230 | return None 231 | 232 | 233 | def model_resume(args, model, ema, logger): 234 | """ 235 | Load pretrained weights.
236 | """ 237 | start_epoch = 0 238 | start_epoch_step = 0 239 | train_steps = 0 240 | resume_path = UNET_PATH 241 | 242 | logger.info(f"Resume from checkpoint {resume_path}") 243 | 244 | if args.resume_split: 245 | # Resume main model 246 | 247 | resume_ckpt_module = torch.load( 248 | resume_path, map_location=lambda storage, loc: storage) 249 | model.load_state_dict(resume_ckpt_module, strict=False) 250 | 251 | # Resume ema model 252 | if args.use_ema: 253 | if args.module_to_ema: 254 | if "resume_ckpt_module" in locals(): 255 | logger.info(f" Resume ema model from main states.") 256 | ema.load_state_dict(resume_ckpt_module, strict=args.strict) 257 | else: 258 | logger.info(f" Resume ema model from module states.") 259 | resume_ckpt_module = torch.load( 260 | resume_path, map_location=lambda storage, loc: storage) 261 | ema.load_state_dict(resume_ckpt_module, strict=args.strict) 262 | else: 263 | if "resume_ckpt_ema" in locals(): 264 | logger.info(f" Resume ema model from EMA states.") 265 | ema.load_state_dict(resume_ckpt_ema, strict=args.strict) 266 | else: 267 | logger.info(f" Resume ema model from EMA states.") 268 | resume_ckpt_ema = torch.load(resume_path, 269 | map_location=lambda storage, loc: storage) 270 | ema.load_state_dict(resume_ckpt_ema, strict=args.strict) 271 | else: 272 | raise ValueError( 273 | "If `resume` is True, `resume_split` must also be True.") 274 | 275 | return model, ema, start_epoch, start_epoch_step, train_steps 276 | 277 | 278 | import tqdm 279 | 280 | import requests 281 | 282 | 283 | def set_master_port(port): 284 | global master_port 285 | master_port = port 286 | 287 | 288 | master_port = 0 289 | 290 | 291 | def LOG(log): 292 | if master_port == 0: 293 | raise Exception("master_port is 0") 294 | # Send the log via HTTP 295 | try: 296 | resp = requests.request("post", f"http://127.0.0.1:{master_port}/log", data=json.dumps(log), headers={ 297 | "Content-Type": "application/json"}) 298 | except Exception as e: 299 | return 300 | if resp.status_code != 200: 301 | raise Exception(f"LOG failed: {resp.text}") 302 | 303 | 304 | # with tqdm(total=total_steps, initial=train_steps) as pbar: 305 | # pbar.update(1) 306 | # pbar.set_description( 307 | # f"Epoch {epoch}, step {step}, loss {loss.item():.4f}, mean_loss {mean_loss / step:.4f}") 308 | class PBar: 309 | def __init__(self, total): 310 | self.pbar = tqdm.tqdm(total=total) 311 | 312 | def step(self, desc, total_steps, train_steps): 313 | self.pbar.update(1) 314 | self.pbar.set_description(desc) 315 | 316 | LOG({ 317 | "type": "sample_images", 318 | "global_step": train_steps, 319 | "total_steps": total_steps, 320 | # "latent": noise_pred_latent_path, 321 | }) 322 | 323 | 324 | from torch import nn 325 | 326 | 327 | class CustomizeEmbedsModel(nn.Module): 328 | dtype = torch.float16 329 | # x = torch.zeros(1, 1, 256, 2048) 330 | x = None 331 | 332 | def __init__(self, *args, **kwargs): 333 | super().__init__() 334 | 335 | def to(self, *args, **kwargs): 336 | self.dtype = torch.float16 337 | return self 338 | 339 | def forward(self, *args, **kwargs): 340 | input_ids = kwargs.get("input_ids", None) 341 | if self.x is None: 342 | if input_ids is None: 343 | batch_size = 1 344 | else: 345 | batch_size = input_ids.shape[0] 346 | self.x = torch.zeros(1, batch_size, 256, 2048, dtype=self.dtype) 347 | 348 | if kwargs.get("output_hidden_states", False): 349 | return { 350 | "hidden_states": self.x.to("cuda"), 351 | "input_ids": torch.zeros(1, 1), 352 | } 353 | return self.x 354 | 355 | 356 | class CustomizeTokenizer(dict):
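    # CustomizeTokenizer pairs with CustomizeEmbedsModel above as a stand-in for the mT5
    # tokenizer/encoder: when no mT5 checkpoint path is configured, CustomizeEmbeds wires the two
    # together so the text-encoder interface still exists, with the "encoder" returning zero
    # hidden states of shape (1, batch_size, 256, 2048).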
357 | 358 | added_tokens_encoder = [] 359 | input_ids = torch.zeros(1, 256) 360 | attention_mask = torch.zeros(1, 256) 361 | 362 | def __init__(self, *args, **kwargs): 363 | self['added_tokens_encoder'] = self.added_tokens_encoder 364 | self['input_ids'] = self.input_ids 365 | self['attention_mask'] = self.attention_mask 366 | 367 | def tokenize(self, text): 368 | return text 369 | 370 | def __call__(self, *args, **kwargs): 371 | return self 372 | 373 | 374 | class CustomizeEmbeds(): 375 | def __init__(self): 376 | super().__init__() 377 | self.tokenizer = CustomizeTokenizer() 378 | self.model = CustomizeEmbedsModel().to("cuda") 379 | self.max_length = 256 380 | -------------------------------------------------------------------------------- /hook_kohya_ss_hunyuan_pipe.py: -------------------------------------------------------------------------------- 1 | 2 | import inspect 3 | import re 4 | from typing import Callable, List, Optional, Union 5 | 6 | import numpy as np 7 | import PIL.Image 8 | import torch 9 | from packaging import version 10 | from tqdm import tqdm 11 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 12 | 13 | from diffusers import SchedulerMixin, StableDiffusionPipeline 14 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 15 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker 16 | from diffusers.utils import logging 17 | from PIL import Image 18 | from library import sdxl_model_util, sdxl_train_util, train_util 19 | 20 | from library.sdxl_lpw_stable_diffusion import * 21 | 22 | 23 | from library.hunyuan_models import * 24 | 25 | 26 | def load_scheduler_sigmas(beta_start=0.00085, beta_end=0.018, num_train_timesteps=1000): 27 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, 28 | num_train_timesteps, dtype=torch.float32) ** 2 29 | alphas = 1.0 - betas 30 | alphas_cumprod = torch.cumprod(alphas, dim=0) 31 | 32 | sigmas = np.array(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5) 33 | sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) 34 | sigmas = torch.from_numpy(sigmas) 35 | return alphas_cumprod, sigmas 36 | 37 | 38 | ATTN_MODE = "xformers" 39 | CLIP_TOKENS = 75 * 2 + 2 40 | DEVICE = "cuda" 41 | try: 42 | from k_diffusion.external import DiscreteVDDPMDenoiser 43 | from k_diffusion.sampling import sample_euler_ancestral, get_sigmas_exponential, sample_dpmpp_2m_sde 44 | except ImportError: 45 | import subprocess 46 | import sys 47 | 48 | subprocess.check_call( 49 | [sys.executable, "-m", "pip", "install", "k-diffusion"]) 50 | 51 | from k_diffusion.external import DiscreteVDDPMDenoiser 52 | from k_diffusion.sampling import sample_euler_ancestral, get_sigmas_exponential, sample_dpmpp_2m_sde 53 | 54 | 55 | from library.hunyuan_utils import get_cond, calc_rope 56 | 57 | 58 | class HuanYuanDiffusionLongPromptWeightingPipeline: 59 | 60 | def __init__( 61 | self, 62 | vae: AutoencoderKL, 63 | text_encoder, 64 | tokenizer, 65 | unet, 66 | scheduler, 67 | # clip_skip: int, 68 | safety_checker, 69 | feature_extractor, 70 | requires_safety_checker=False, 71 | clip_skip=0, 72 | ): 73 | # clip skip is ignored currently 74 | # print("tokenizer: ", tokenizer) 75 | # print("text_encoder: ", text_encoder) 76 | self.unet = unet 77 | self.scheduler = scheduler 78 | self.safety_checker = safety_checker 79 | self.feature_extractor = feature_extractor 80 | self.requires_safety_checker = requires_safety_checker 81 | self.vae = vae 82 | self.vae_scale_factor = 2 ** ( 83 | 
len(self.vae.config.block_out_channels) - 1) 84 | self.progress_bar = lambda x: tqdm(x, leave=False) 85 | 86 | self.clip_skip = clip_skip 87 | self.tokenizers = tokenizer 88 | self.text_encoders = text_encoder 89 | 90 | def to(self, device=None, dtype=None): 91 | if device is not None: 92 | self.device = device 93 | # self.vae.to(device=self.device) 94 | if dtype is not None: 95 | self.dtype = dtype 96 | 97 | @property 98 | def _execution_device(self): 99 | r""" 100 | Returns the device on which the pipeline's models will be executed. After calling 101 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 102 | hooks. 103 | """ 104 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 105 | return self.device 106 | for module in self.unet.modules(): 107 | if ( 108 | hasattr(module, "_hf_hook") 109 | and hasattr(module._hf_hook, "execution_device") 110 | and module._hf_hook.execution_device is not None 111 | ): 112 | return torch.device(module._hf_hook.execution_device) 113 | return self.device 114 | 115 | def check_inputs(self, prompt, height, width, strength, callback_steps): 116 | pass 117 | 118 | def get_timesteps(self, num_inference_steps, strength, device, is_text2img): 119 | if is_text2img: 120 | return self.scheduler.timesteps.to(device), num_inference_steps 121 | else: 122 | # get the original timestep using init_timestep 123 | offset = self.scheduler.config.get("steps_offset", 0) 124 | init_timestep = int(num_inference_steps * strength) + offset 125 | init_timestep = min(init_timestep, num_inference_steps) 126 | 127 | t_start = max(num_inference_steps - init_timestep + offset, 0) 128 | timesteps = self.scheduler.timesteps[t_start:].to(device) 129 | return timesteps, num_inference_steps - t_start 130 | 131 | def run_safety_checker(self, image, device, dtype): 132 | return image, None 133 | 134 | def decode_latents(self, latents): 135 | with torch.no_grad(): 136 | latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents 137 | 138 | image = self.vae.decode(latents.to(self.vae.dtype)).sample 139 | image = (image / 2 + 0.5).clamp(0, 1) 140 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 141 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 142 | return image 143 | 144 | def prepare_extra_step_kwargs(self, generator, eta): 145 | return {} 146 | 147 | @torch.no_grad() 148 | def __call__( 149 | self, 150 | prompt: Union[str, List[str]], 151 | negative_prompt: Optional[Union[str, List[str]]] = None, 152 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 153 | mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 154 | height: int = 512, 155 | width: int = 512, 156 | num_inference_steps: int = 50, 157 | guidance_scale: float = 7.5, 158 | strength: float = 0.8, 159 | num_images_per_prompt: Optional[int] = 1, 160 | eta: float = 0.0, 161 | generator: Optional[torch.Generator] = None, 162 | latents: Optional[torch.FloatTensor] = None, 163 | max_embeddings_multiples: Optional[int] = 3, 164 | output_type: Optional[str] = "pil", 165 | return_dict: bool = True, 166 | controlnet=None, 167 | controlnet_image=None, 168 | callback: Optional[Callable[[ 169 | int, int, torch.FloatTensor], None]] = None, 170 | is_cancelled_callback: Optional[Callable[[], bool]] = None, 171 | callback_steps: int = 1, 172 | ): 173 | 174 | BETA_END = 0.018 175 | CFG_SCALE = guidance_scale 176 | STEPS = num_inference_steps 177 | patch_size = 2 178 | num_heads = 
88 179 | 180 | alphas, sigmas = load_scheduler_sigmas(beta_end=BETA_END) 181 | denoiser = self.unet 182 | clip_tokenizer = self.tokenizers[0] 183 | clip_encoder = self.text_encoders[0] 184 | 185 | mt5_embedder = self.text_encoders[1] 186 | 187 | vae = self.vae 188 | denoiser.eval() 189 | denoiser.disable_fp32_silu() 190 | denoiser.disable_fp32_layer_norm() 191 | denoiser.set_attn_mode(ATTN_MODE) 192 | vae.requires_grad_(False) 193 | mt5_embedder.to(torch.float16) 194 | 195 | with torch.autocast("cuda"): 196 | clip_h, clip_m, mt5_h, mt5_m = get_cond( 197 | prompt, 198 | mt5_embedder, 199 | clip_tokenizer, 200 | clip_encoder, 201 | # Should be same as original implementation with max_length_clip=77 202 | # Support 75*n + 2 203 | max_length_clip=CLIP_TOKENS, 204 | ) 205 | neg_clip_h, neg_clip_m, neg_mt5_h, neg_mt5_m = get_cond( 206 | negative_prompt, 207 | mt5_embedder, 208 | clip_tokenizer, 209 | clip_encoder, 210 | max_length_clip=CLIP_TOKENS, 211 | ) 212 | clip_h = torch.concat([clip_h, neg_clip_h], dim=0) 213 | clip_m = torch.concat([clip_m, neg_clip_m], dim=0) 214 | mt5_h = torch.concat([mt5_h, neg_mt5_h], dim=0) 215 | mt5_m = torch.concat([mt5_m, neg_mt5_m], dim=0) 216 | torch.cuda.empty_cache() 217 | 218 | style = torch.as_tensor([0] * 2, device=DEVICE) 219 | # src hw, dst hw, 0, 0 220 | size_cond = [height, width, height, width, 0, 0] 221 | image_meta_size = torch.as_tensor([size_cond] * 2, device=DEVICE) 222 | 223 | freqs_cis_img = calc_rope(height, width, patch_size, num_heads) 224 | 225 | denoiser_wrapper = DiscreteVDDPMDenoiser( 226 | # A quick patch for learn_sigma 227 | lambda *args, **kwargs: denoiser(* \ 228 | args, **kwargs).chunk(2, dim=1)[0], 229 | alphas, 230 | False, 231 | ).to(DEVICE) 232 | 233 | def cfg_denoise_func(x, sigma): 234 | cond, uncond = denoiser_wrapper( 235 | x.repeat(2, 1, 1, 1), 236 | sigma.repeat(2), 237 | encoder_hidden_states=clip_h, 238 | text_embedding_mask=clip_m, 239 | encoder_hidden_states_t5=mt5_h, 240 | text_embedding_mask_t5=mt5_m, 241 | image_meta_size=image_meta_size, 242 | style=style, 243 | cos_cis_img=freqs_cis_img[0], 244 | sin_cis_img=freqs_cis_img[1], 245 | ).chunk(2, dim=0) 246 | return uncond + (cond - uncond) * CFG_SCALE 247 | 248 | sigmas = denoiser_wrapper.get_sigmas(STEPS).to(DEVICE) 249 | sigmas = get_sigmas_exponential( 250 | STEPS, denoiser_wrapper.sigma_min, denoiser_wrapper.sigma_max, DEVICE 251 | ) 252 | x1 = torch.randn(1, 4, height // 8, width // 8, 253 | dtype=torch.float16, device=DEVICE) 254 | 255 | with torch.autocast("cuda"): 256 | sample = sample_dpmpp_2m_sde( 257 | cfg_denoise_func, 258 | x1 * sigmas[0], 259 | sigmas, 260 | ) 261 | torch.cuda.empty_cache() 262 | latents = sample 263 | return latents 264 | 265 | def text2img( 266 | self, 267 | prompt: Union[str, List[str]], 268 | negative_prompt: Optional[Union[str, List[str]]] = None, 269 | height: int = 512, 270 | width: int = 512, 271 | num_inference_steps: int = 50, 272 | guidance_scale: float = 7.5, 273 | num_images_per_prompt: Optional[int] = 1, 274 | eta: float = 0.0, 275 | generator: Optional[torch.Generator] = None, 276 | latents: Optional[torch.FloatTensor] = None, 277 | max_embeddings_multiples: Optional[int] = 3, 278 | output_type: Optional[str] = "pil", 279 | return_dict: bool = True, 280 | callback: Optional[Callable[[ 281 | int, int, torch.FloatTensor], None]] = None, 282 | is_cancelled_callback: Optional[Callable[[], bool]] = None, 283 | callback_steps: int = 1, 284 | ): 285 | 286 | return self.__call__( 287 | prompt=prompt, 288 | 
negative_prompt=negative_prompt, 289 | height=height, 290 | width=width, 291 | num_inference_steps=num_inference_steps, 292 | guidance_scale=guidance_scale, 293 | num_images_per_prompt=num_images_per_prompt, 294 | eta=eta, 295 | generator=generator, 296 | latents=latents, 297 | max_embeddings_multiples=max_embeddings_multiples, 298 | output_type=output_type, 299 | return_dict=return_dict, 300 | callback=callback, 301 | is_cancelled_callback=is_cancelled_callback, 302 | callback_steps=callback_steps, 303 | ) 304 | 305 | def latents_to_image(self, latents): 306 | # 9. Post-processing 307 | image = self.decode_latents(latents.to(self.vae.dtype)) 308 | image = self.numpy_to_pil(image) 309 | return image 310 | 311 | # copy from pil_utils.py 312 | def numpy_to_pil(self, images: np.ndarray) -> Image.Image: 313 | """ 314 | Convert a numpy image or a batch of images to a PIL image. 315 | """ 316 | if images.ndim == 3: 317 | images = images[None, ...] 318 | images = (images * 255).round().astype("uint8") 319 | if images.shape[-1] == 1: 320 | # special case for grayscale (single channel) images 321 | pil_images = [Image.fromarray( 322 | image.squeeze(), mode="L") for image in images] 323 | else: 324 | pil_images = [Image.fromarray(image) for image in images] 325 | 326 | return pil_images 327 | -------------------------------------------------------------------------------- /hook_kohya_ss_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system("title hook_kohya_ss_run") 4 | import random 5 | import time 6 | 7 | import torch 8 | import logging 9 | import sys 10 | import json 11 | import importlib 12 | import argparse 13 | import toml 14 | 15 | 16 | def config2args(train_parser: argparse.ArgumentParser, config): 17 | 18 | config_args_list = [] 19 | for key, value in config.items(): 20 | if type(value) == bool: 21 | if value: 22 | config_args_list.append(f"--{key}") 23 | else: 24 | config_args_list.append(f"--{key}") 25 | config_args_list.append(str(value)) 26 | args = train_parser.parse_args(config_args_list) 27 | return args 28 | 29 | 30 | from PIL import Image 31 | 32 | 33 | import numpy as np 34 | import tempfile 35 | import safetensors.torch 36 | 37 | 38 | import sys 39 | sys.path.append(os.path.dirname(__file__)) 40 | try: 41 | import hook_kohya_ss_utils 42 | except: 43 | from . 
import hook_kohya_ss_utils 44 | 45 | other_config = {} 46 | original_save_model = None 47 | 48 | 49 | train_config = {} 50 | 51 | sample_images_pipe_class = None 52 | 53 | 54 | def utils_sample_images(*args, **kwargs): 55 | return sample_images(None, *args, **kwargs) 56 | 57 | 58 | def get_datasets(): 59 | import library.config_util 60 | user_config = library.config_util.load_user_config( 61 | train_config.get("dataset_config", None)) 62 | datasets = user_config.get("datasets", []) 63 | if len(datasets) == 0: 64 | return None 65 | return datasets[0] 66 | 67 | 68 | def sample_images(self, *args, **kwargs): 69 | # accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet 70 | accelerator = args[0] 71 | cmd_args = args[1] 72 | epoch = args[2] 73 | global_step = args[3] 74 | device = args[4] 75 | vae = args[5] 76 | tokenizer = args[6] 77 | text_encoder = args[7] 78 | unet = args[8] 79 | 80 | # print(f"sample_images: args = {args}") 81 | # print(f"sample_images: kwargs = {kwargs}") 82 | 83 | controlnet = kwargs.get("controlnet", None) 84 | 85 | if epoch is not None and cmd_args.save_every_n_epochs is not None and epoch % cmd_args.save_every_n_epochs == 0: 86 | 87 | datasets = get_datasets() 88 | resolution = datasets.get("resolution", (512, 512)) 89 | if isinstance(resolution, int): 90 | resolution = (resolution, resolution) 91 | height, width = resolution 92 | print(f"sample_images: height = {height}, width = {width}") 93 | 94 | prompt_dict_list = other_config.get("prompt_dict_list", []) 95 | if len(prompt_dict_list) == 0: 96 | sample_prompt = other_config.get("sample_prompt", None) 97 | if sample_prompt is not None: 98 | seed = other_config.get("seed", 0) 99 | prompt_dict = { 100 | "controlnet_image": other_config.get("controlnet_image", None), 101 | "prompt": other_config.get("sample_prompt", ""), 102 | "seed": seed, 103 | "negative_prompt": "", 104 | "enum": 0, 105 | "sample_sampler": "euler_a", 106 | "sample_steps": 20, 107 | "scale": 5.0, 108 | "height": height, 109 | "width": width, 110 | } 111 | # 112 | prompt_dict_list.append(prompt_dict) 113 | else: 114 | for i, prompt_dict in enumerate(prompt_dict_list): 115 | if prompt_dict.get("controlnet_image", None) is None: 116 | prompt_dict["controlnet_image"] = None 117 | if prompt_dict.get("seed", None) is None: 118 | prompt_dict["seed"] = 0 119 | if prompt_dict.get("negative_prompt", None) is None: 120 | prompt_dict["negative_prompt"] = "" 121 | if prompt_dict.get("enum", None) is None: 122 | prompt_dict["enum"] = i 123 | 124 | if prompt_dict_list is not None and len(prompt_dict_list) > 0: 125 | hook_kohya_ss_utils.generate_image( 126 | pipe_class=sample_images_pipe_class, 127 | cmd_args=cmd_args, 128 | accelerator=accelerator, 129 | epoch=epoch, 130 | text_encoder=text_encoder, 131 | tokenizer=tokenizer, 132 | unet=unet, 133 | vae=vae, 134 | prompt_dict_list=prompt_dict_list, 135 | controlnet=controlnet, 136 | ) 137 | 138 | LOG({ 139 | "type": "sample_images", 140 | "global_step": global_step, 141 | "total_steps": cmd_args.max_train_steps, 142 | # "latent": noise_pred_latent_path, 143 | }) 144 | 145 | 146 | def run_lora_sd1_5(): 147 | hook_kohya_ss_utils.hook_kohya_ss() 148 | 149 | # Override the sample_images function (progress bar and image generation) 150 | import train_network 151 | train_network.NetworkTrainer.sample_images = sample_images 152 | 153 | # Set up the matching pipeline 154 | import library.train_util 155 | global sample_images_pipe_class 156 | sample_images_pipe_class = library.train_util.StableDiffusionLongPromptWeightingPipeline 157 | 158 | trainer
= train_network.NetworkTrainer() 159 | train_args = config2args(train_network.setup_parser(), train_config) 160 | 161 | LOG({ 162 | "type": "start_train", 163 | }) 164 | trainer.train(train_args) 165 | 166 | 167 | def run_lora_sdxl(): 168 | hook_kohya_ss_utils.hook_kohya_ss() 169 | 170 | # Override the sample_images function (progress bar and image generation) 171 | import sdxl_train_network 172 | sdxl_train_network.SdxlNetworkTrainer.sample_images = sample_images 173 | 174 | # Set up the matching pipeline 175 | import library.sdxl_train_util 176 | global sample_images_pipe_class 177 | sample_images_pipe_class = library.sdxl_train_util.SdxlStableDiffusionLongPromptWeightingPipeline 178 | 179 | trainer = sdxl_train_network.SdxlNetworkTrainer() 180 | train_args = config2args(sdxl_train_network.setup_parser(), train_config) 181 | 182 | LOG({ 183 | "type": "start_train", 184 | }) 185 | trainer.train(train_args) 186 | 187 | 188 | from types import SimpleNamespace 189 | 190 | 191 | class SimpleNamespaceCNWarrper(SimpleNamespace): 192 | def __init__(self, *args, **kwargs): 193 | super().__init__(*args, **kwargs) 194 | self.__dict__.update(kwargs) # or self.__dict__ = kwargs 195 | self.__dict__["mid_block_type"] = "UNetMidBlock2DCrossAttn" 196 | self.__dict__["_diffusers_version"] = "0.6.0" 197 | self.__iter__ = lambda: iter(kwargs.keys()) 198 | # (a plain SimpleNamespace is not iterable) 199 | 200 | def __iter__(self): 201 | return iter(self.__dict__.keys()) 202 | # returning None from __getattr__ avoids "object has no attribute 'num_attention_heads'"-style errors 203 | 204 | def __getattr__(self, name): 205 | return self.__dict__.get(name, None) 206 | 207 | 208 | def run_controlnet_sd1_5(): 209 | import types 210 | types.SimpleNamespace = SimpleNamespaceCNWarrper 211 | hook_kohya_ss_utils.hook_kohya_ss() 212 | # Override the sample_images function (progress bar and image generation) 213 | 214 | import train_controlnet 215 | 216 | # Set up the matching pipeline 217 | import library.train_util 218 | library.train_util.sample_images = utils_sample_images 219 | 220 | global sample_images_pipe_class 221 | sample_images_pipe_class = library.train_util.StableDiffusionLongPromptWeightingPipeline 222 | 223 | train_args = config2args(train_controlnet.setup_parser(), train_config) 224 | 225 | LOG({ 226 | "type": "start_train", 227 | }) 228 | 229 | train_controlnet.train(train_args) 230 | 231 | 232 | def run_lora_hunyuan1_2(): 233 | hook_kohya_ss_utils.hook_kohya_ss() 234 | 235 | # Override the sample_images function (progress bar and image generation) 236 | import hunyuan_train_network 237 | 238 | # Not implemented yet 239 | hunyuan_train_network.HunYuanNetworkTrainer.sample_images = sample_images 240 | 241 | # def empty_sample_images(*args, **kwargs): 242 | # pass 243 | # sample_images_pipe_class = empty_sample_images 244 | 245 | import hook_kohya_ss_hunyuan_pipe 246 | global sample_images_pipe_class 247 | sample_images_pipe_class = hook_kohya_ss_hunyuan_pipe.HuanYuanDiffusionLongPromptWeightingPipeline 248 | # Set up the matching pipeline 249 | import library.hunyuan_utils 250 | 251 | print(json.dumps(other_config, indent=4)) 252 | hunyuan_models_config = other_config.get("hunyuan_models_config", None) 253 | 254 | from transformers import ( 255 | AutoTokenizer, 256 | T5Tokenizer, 257 | BertModel, 258 | BertTokenizer, 259 | ) 260 | from diffusers import AutoencoderKL, LMSDiscreteScheduler 261 | 262 | def hunyuan_load_tokenizers(): 263 | tokenizer = AutoTokenizer.from_pretrained( 264 | hunyuan_models_config["tokenizer_path"], 265 | local_files_only=True, 266 | ) 267 | tokenizer.eos_token_id = tokenizer.sep_token_id 268 | t5_encoder_path = hunyuan_models_config.get("t5_encoder_path", None) 269 | if t5_encoder_path == "none": 270 | t5_encoder_path = None 271 |
tokenizer2 = None 272 | if t5_encoder_path is not None: 273 | tokenizer2 = T5Tokenizer.from_pretrained( 274 | t5_encoder_path, 275 | local_files_only=True, 276 | ) 277 | 278 | return [tokenizer, tokenizer2] 279 | 280 | library.hunyuan_utils.load_tokenizers = hunyuan_load_tokenizers 281 | 282 | def hunyuan_load_model(model_path: str, dtype=torch.float16, device="cuda", use_extra_cond=False, dit_path=None): 283 | 284 | dit_path = hunyuan_models_config.get("unet_path", None) 285 | 286 | import library.hunyuan_models 287 | # from hunyuan_models import MT5Embedder, HunYuanDiT, BertModel, DiT_g_2 288 | MT5Embedder = library.hunyuan_models.MT5Embedder 289 | HunYuanDiT = library.hunyuan_models.HunYuanDiT 290 | BertModel = library.hunyuan_models.BertModel 291 | DiT_g_2 = library.hunyuan_models.DiT_g_2 292 | 293 | denoiser, patch_size, head_dim = DiT_g_2( 294 | input_size=(128, 128), use_extra_cond=use_extra_cond) 295 | if dit_path is not None: 296 | state_dict = torch.load(dit_path) 297 | if 'state_dict' in state_dict: 298 | state_dict = state_dict['state_dict'] 299 | else: 300 | state_dict = torch.load(os.path.join( 301 | model_path, "denoiser/pytorch_model_module.pt")) 302 | denoiser.load_state_dict(state_dict) 303 | denoiser.to(device).to(dtype) 304 | 305 | clip_tokenizer = AutoTokenizer.from_pretrained( 306 | hunyuan_models_config["tokenizer_path"], 307 | local_files_only=True, 308 | ) 309 | clip_tokenizer.eos_token_id = 2 310 | clip_encoder = ( 311 | BertModel.from_pretrained( 312 | hunyuan_models_config["text_encoder_path"], 313 | local_files_only=True, 314 | ).to(device).to(dtype) 315 | ) 316 | 317 | t5_encoder_path = hunyuan_models_config.get("t5_encoder_path", None) 318 | if t5_encoder_path == "none": 319 | t5_encoder_path = None 320 | mt5_embedder = None 321 | if t5_encoder_path is not None: 322 | mt5_embedder = ( 323 | MT5Embedder( 324 | model_dir=hunyuan_models_config["t5_encoder_path"], 325 | torch_dtype=dtype, 326 | max_length=256) 327 | .to(device) 328 | .to(dtype) 329 | ) 330 | else: 331 | 332 | batch_size = train_args.train_batch_size 333 | import library.config_util 334 | user_config = library.config_util.load_user_config( 335 | train_args.dataset_config) 336 | datasets = user_config.get("datasets", []) 337 | if len(datasets) > 0: 338 | batch_size = datasets[0].get("batch_size", batch_size) 339 | mt5_embedder = ( 340 | hook_kohya_ss_utils.CustomizeMT5Embedder( 341 | batch_size=batch_size, 342 | ) 343 | .to(device) 344 | .to(dtype) 345 | ) 346 | 347 | vae = ( 348 | AutoencoderKL.from_pretrained( 349 | hunyuan_models_config["vae_ema_path"], 350 | local_files_only=True, 351 | ) 352 | .to(device) 353 | .to(dtype) 354 | ) 355 | vae.requires_grad_(False) 356 | return ( 357 | denoiser, 358 | patch_size, 359 | head_dim, 360 | clip_tokenizer, 361 | clip_encoder, 362 | mt5_embedder, 363 | vae, 364 | ) 365 | 366 | library.hunyuan_utils.load_model = hunyuan_load_model 367 | 368 | trainer = hunyuan_train_network.HunYuanNetworkTrainer() 369 | train_args = config2args( 370 | hunyuan_train_network.setup_parser(), train_config) 371 | print(f"train_args = {train_args}") 372 | 373 | LOG({ 374 | "type": "start_train", 375 | }) 376 | trainer.train(train_args) 377 | 378 | 379 | func_map = { 380 | "run_lora_sd1_5": run_lora_sd1_5, 381 | "run_lora_sdxl": run_lora_sdxl, 382 | "run_controlnet_sd1_5": run_controlnet_sd1_5, 383 | "run_lora_hunyuan1_2": run_lora_hunyuan1_2, 384 | } 385 | 386 | 387 | import requests 388 | 389 | 390 | def LOG(log): 391 | try: 392 | # 发送http 393 | resp = requests.request("post", 
f"http://127.0.0.1:{master_port}/log", data=json.dumps(log), headers={ 394 | "Content-Type": "application/json"}) 395 | if resp.status_code != 200: 396 | # raise Exception(f"LOG failed: {resp.text}") 397 | print(f"LOG failed: {resp.text}") 398 | except Exception as e: 399 | print(f"LOG failed: {e}") 400 | 401 | 402 | if __name__ == "__main__": 403 | try: 404 | parser = argparse.ArgumentParser() 405 | parser.add_argument("--sys_path", type=str, default="") 406 | parser.add_argument("--config", type=str, default="") 407 | parser.add_argument("--train_func", type=str, default="") 408 | parser.add_argument("--master_port", type=int, default=0) 409 | args = parser.parse_args() 410 | 411 | master_port = args.master_port 412 | 413 | print(f"master_port = {master_port}") 414 | 415 | sys_path = args.sys_path 416 | if sys_path != "": 417 | sys.path.append(sys_path) 418 | 419 | config_file = args.config 420 | if config_file == "": 421 | raise Exception("train_config is empty") 422 | 423 | global_config = {} 424 | with open(config_file, "r") as f: 425 | _global_config = f.read() 426 | global_config = json.loads(_global_config) 427 | 428 | train_config = global_config.get("train_config") 429 | print(f"""=======================train_config======================= 430 | {json.dumps(train_config, indent=4, ensure_ascii=False)} 431 | """) 432 | 433 | other_config = global_config.get("other_config", {}) 434 | print(f"""=======================other_config======================= 435 | {json.dumps(other_config, indent=4, ensure_ascii=False)} 436 | """) 437 | 438 | train_func = args.train_func 439 | if train_func == "": 440 | raise Exception("train_func is empty") 441 | 442 | print(f"train_func = {train_func}") 443 | 444 | time.sleep(2) 445 | LOG({ 446 | "type": "Read configuration completed!", 447 | }) 448 | 449 | func_map[train_func]() 450 | except Exception as e: 451 | print(f"Exception: {e}") 452 | if sys.platform == "win32": 453 | input("Press Enter to continue...") -------------------------------------------------------------------------------- /hook_kohya_ss_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import json 5 | import os 6 | from typing import * 7 | import torch 8 | from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline 9 | from transformers import CLIPTokenizer 10 | import requests 11 | 12 | _requests_get = requests.get 13 | 14 | source_replacement_table = { 15 | "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml": os.path.join( 16 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-v1.5", "v1-inference.yaml"), 17 | "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml": os.path.join( 18 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-xl", "sd_xl_base.yaml"), 19 | "https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer_config.json": os.path.join( 20 | os.path.dirname(__file__), "configs", "models_config", "clip-vit-large-patch14", "tokenizer_config.json"), 21 | "https://huggingface.co/api/models/stabilityai/stable-diffusion-3-medium-diffusers/revision/main": os.path.join( 22 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-3-medium-diffusers", "revision.json"), 23 | } 24 | 25 | source_replacement_dir = { 26 | 
"https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/resolve/b1148b4028b9ec56ebd36444c193d56aeff7ab56": os.path.join( 27 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-3-medium-diffusers"), 28 | } 29 | 30 | 31 | class DictWrapper: 32 | def __init__(self, d): 33 | self.d = d 34 | 35 | def __getattribute__(self, name: str): 36 | if name == "content": 37 | return self.d["content"] 38 | if name == "raise_for_status": 39 | return lambda: None 40 | if name == "json": 41 | return lambda: json.loads(self.d["content"]) 42 | if name == "status_code": 43 | return self.d["status_code"] 44 | if name == "headers": 45 | return { 46 | "Location": self.d["Location"], 47 | "Content-Length": len(self.d["content"]), 48 | } 49 | if name == "request": 50 | return None 51 | return super().__getattribute__(name) 52 | 53 | 54 | def request_wrapper(*args, **kwargs): 55 | url = args[1] 56 | print(f"request_wrapper requesting {url}") 57 | if url in source_replacement_table: 58 | with open(source_replacement_table[url], "rb") as f: 59 | return DictWrapper({ 60 | "Location": url, 61 | "content": f.read(), 62 | "status_code": 200, 63 | }) 64 | 65 | print(f"request_wrapper requesting {url} from original requests") 66 | return _requests_get(*args, **kwargs) 67 | 68 | 69 | from requests import api 70 | from requests import Session 71 | last_request = api.request 72 | original_session_request = Session.request 73 | api.request = request_wrapper 74 | 75 | 76 | def Session_request_wrapper(cls, method, url, **kwargs): 77 | if url.startswith("http://127.0.0.1"): 78 | return original_session_request(cls, method, url, **kwargs) 79 | # print(f"Session_request_wrapper requesting {url}") 80 | # print(f"Session_request_wrapper requesting kwargs: {kwargs}") 81 | if url in source_replacement_table: 82 | with open(source_replacement_table[url], "rb") as f: 83 | return DictWrapper({ 84 | "Location": url, 85 | "content": f.read(), 86 | "status_code": 200, 87 | }) 88 | 89 | for k, v in source_replacement_dir.items(): 90 | if url.startswith(k): 91 | file_path = source_replacement_dir[k] + url[len(k):] 92 | # print( 93 | # f"source_replacement_dir:{k}||||||||||||||||||||| {source_replacement_dir[k]} ||||||||||||||||||| {file_path}") 94 | with open(file_path, "rb") as f: 95 | return DictWrapper({ 96 | "Location": url, 97 | "content": f.read(), 98 | "status_code": 200, 99 | }) 100 | raise NotImplementedError("Session.request is not supported") 101 | 102 | 103 | Session.request = Session_request_wrapper 104 | 105 | import huggingface_hub.file_download 106 | 107 | 108 | def _hf_hub_download_to_cache_dir(repo_id, filename, *args, **kwargs): 109 | print(f"_hf_hub_download_to_cache_dir: {args}") 110 | print(f"_hf_hub_download_to_cache_dir: {kwargs}") 111 | if repo_id == "stabilityai/stable-diffusion-3-medium-diffusers": 112 | return os.path.join( 113 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-3-medium-diffusers", filename) 114 | raise NotImplementedError("_hf_hub_download_to_cache_dir is not supported") 115 | 116 | 117 | huggingface_hub.file_download._hf_hub_download_to_cache_dir = _hf_hub_download_to_cache_dir 118 | 119 | import diffusers.loaders.single_file 120 | original_snapshot_download = diffusers.loaders.single_file.snapshot_download 121 | 122 | 123 | def _snapshot_download(repo_id, *args, **kwargs): 124 | print(f"_snapshot_download: {repo_id}") 125 | if repo_id == "stabilityai/stable-diffusion-3-medium-diffusers": 126 | return os.path.join( 127 | 
os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-3-medium-diffusers",) 128 | if repo_id == "runwayml/stable-diffusion-v1-5": 129 | return os.path.join( 130 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-v1-5",) 131 | if repo_id == "stabilityai/stable-diffusion-xl-base-1.0": 132 | return os.path.join( 133 | os.path.dirname(__file__), "configs", "models_config", "stable-diffusion-xl-base-1.0",) 134 | # return original_snapshot_download(repo_id, *args, **kwargs) 135 | raise NotImplementedError("_snapshot_download is not supported") 136 | 137 | 138 | diffusers.loaders.single_file.snapshot_download = _snapshot_download 139 | 140 | original_load_target_model = None 141 | 142 | 143 | def setup_logging(*args, **kwargs): 144 | pass 145 | 146 | 147 | clip_large_tokenizer = None 148 | clip_big_tokenizer = None 149 | 150 | 151 | class TokenizersWrapper: 152 | typed = None 153 | model_max_length = 77 154 | 155 | def __init__(self, t): 156 | self.model_max_length = 77 157 | self.typed = t 158 | 159 | def __getattribute__(self, name: str): 160 | # print(f"TokenizersWrapper.__getattribute__ {name}") 161 | if name == "model_max_length": 162 | return 77 163 | try: 164 | typed = object.__getattribute__(self, "typed") 165 | if typed == "clip_large" and clip_large_tokenizer is not None: 166 | return clip_large_tokenizer.__getattribute__(name) 167 | if typed == "clip_big" and clip_big_tokenizer is not None: 168 | return clip_big_tokenizer.__getattribute__(name) 169 | except: 170 | pass 171 | 172 | return object.__getattribute__(self, name) 173 | 174 | def __call__(self, *args, **kargs): 175 | if self.typed == "clip_large": 176 | return clip_large_tokenizer(*args, **kargs) 177 | if self.typed == "clip_big": 178 | return clip_big_tokenizer(*args, **kargs) 179 | 180 | raise NotImplementedError( 181 | f"TokenizersWrapper: {self.typed} is not supported") 182 | 183 | 184 | from transformers import AutoTokenizer, MT5EncoderModel 185 | from torch import nn 186 | 187 | 188 | class CustomizeEmbedsModel(nn.Module): 189 | dtype = torch.float16 190 | shared = None 191 | # x = torch.zeros(1, 1, 256, 2048) 192 | x = None 193 | 194 | def __init__(self, *args, **kwargs): 195 | super().__init__() 196 | 197 | def to(self, *args, **kwargs): 198 | return self 199 | 200 | def forward(self, *args, **kwargs): 201 | # print("CustomizeEmbedsModel forward: args:", args) 202 | # print("CustomizeEmbedsModel forward: kwargs:", kwargs) 203 | input_ids = kwargs.get("input_ids", None) 204 | # if self.x is None: 205 | if True: 206 | if input_ids is None: 207 | batch_size = 1 208 | else: 209 | batch_size = input_ids.shape[0] 210 | 211 | attention_mask = kwargs.get("attention_mask") 212 | attention_mask_dim = attention_mask.shape[1] 213 | self.x = torch.zeros(1, batch_size, 256, 2048, dtype=self.dtype) 214 | 215 | if kwargs.get("output_hidden_states", False): 216 | return { 217 | "hidden_states": self.x.to("cuda"), 218 | "input_ids": torch.zeros(1, 1), 219 | } 220 | return self.x 221 | 222 | 223 | class CustomizeTokenizer(dict): 224 | added_tokens_encoder = [] 225 | input_ids = None 226 | attention_mask = None 227 | batch_size = 1 228 | 229 | def __init__(self, *args, **kwargs): 230 | self['added_tokens_encoder'] = self.added_tokens_encoder 231 | self['input_ids'] = self.input_ids 232 | self['attention_mask'] = self.attention_mask 233 | self.batch_size = kwargs.get("batch_size", 1) 234 | 235 | def tokenize(self, text): 236 | return text 237 | 238 | def __call__(self, *args, **kwargs): 239 | # 
print("CustomizeTokenizer args:", args) 240 | # print("CustomizeTokenizer kwargs:", kwargs) 241 | 242 | value = args[0] 243 | if isinstance(value, str): 244 | batch_size = 1 245 | else: 246 | batch_size = value.shape[0] 247 | 248 | # print(f"CustomizeTokenizer batch_size: {batch_size}") 249 | # if self.input_ids is not None: 250 | # return self 251 | 252 | self.input_ids = torch.zeros(batch_size, 256) 253 | self.attention_mask = torch.zeros(batch_size, 256) 254 | self['input_ids'] = self.input_ids 255 | self['attention_mask'] = self.attention_mask 256 | 257 | # print("CustomizeTokenizer input_ids:", self.input_ids.shape) 258 | # print("CustomizeTokenizer attention_mask:", self.attention_mask.shape) 259 | 260 | return self 261 | 262 | 263 | class CustomizeEmbeds(): 264 | def __init__(self): 265 | super().__init__() 266 | self.tokenizer = CustomizeTokenizer() 267 | self.model = CustomizeEmbedsModel().to("cuda") 268 | self.max_length = 256 269 | 270 | 271 | class CustomizeMT5Embedder(nn.Module): 272 | device = torch.device("cuda") 273 | 274 | def __init__( 275 | self, 276 | model_dir="t5-v1_1-xxl", 277 | model_kwargs=None, 278 | torch_dtype=None, 279 | use_tokenizer_only=False, 280 | max_length=128, 281 | batch_size=1, 282 | ): 283 | super().__init__() 284 | self.torch_dtype = torch_dtype or torch.bfloat16 285 | self.max_length = max_length 286 | self.tokenizer = CustomizeTokenizer( 287 | batch_size=batch_size 288 | ) 289 | self.model = CustomizeEmbedsModel().to("cuda") 290 | 291 | def gradient_checkpointing_enable(self): 292 | pass 293 | 294 | def gradient_checkpointing_disable(self): 295 | pass 296 | 297 | def get_tokens_and_mask(self, texts): 298 | text_tokens_and_mask = self.tokenizer( 299 | texts, 300 | max_length=self.max_length, 301 | padding="max_length", 302 | truncation=True, 303 | return_attention_mask=True, 304 | add_special_tokens=True, 305 | return_tensors="pt", 306 | ) 307 | tokens = text_tokens_and_mask["input_ids"][0] 308 | mask = text_tokens_and_mask["attention_mask"][0] 309 | return tokens, mask 310 | 311 | def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1): 312 | text_tokens_and_mask = self.tokenizer( 313 | texts, 314 | max_length=self.max_length, 315 | padding="max_length", 316 | truncation=True, 317 | return_attention_mask=True, 318 | add_special_tokens=True, 319 | return_tensors="pt", 320 | ) 321 | 322 | outputs = self.model( 323 | input_ids=text_tokens_and_mask["input_ids"], 324 | attention_mask=( 325 | text_tokens_and_mask["attention_mask"] 326 | if attention_mask 327 | else None 328 | ), 329 | output_hidden_states=True, 330 | ) 331 | text_encoder_embs = outputs["hidden_states"][layer_index].detach() 332 | 333 | return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device) 334 | 335 | def get_input_ids(self, caption): 336 | return self.tokenizer( 337 | caption, 338 | padding="max_length", 339 | truncation=True, 340 | max_length=self.max_length, 341 | return_tensors="pt", 342 | ).input_ids 343 | 344 | def get_hidden_states(self, input_ids, layer_index=-1): 345 | return self.get_text_embeddings(input_ids, layer_index=layer_index) 346 | 347 | 348 | def load_tokenizers(*args, **kwargs): 349 | return TokenizersWrapper("clip_large") 350 | 351 | 352 | def load_sdxl_tokenizers(*args, **kwargs): 353 | return [TokenizersWrapper("clip_large"), TokenizersWrapper("clip_big")] 354 | 355 | 356 | original_conditional_loss = None 357 | 358 | 359 | running_info = {} 360 | 361 | 362 | def conditional_loss(*args, **kwargs): 363 | 
running_info["last_noise_pred"] = args[0] 364 | return original_conditional_loss(*args, **kwargs) 365 | 366 | 367 | def decode_latents(vae, latents): 368 | device = "cuda" if torch.cuda.is_available() else "cpu" 369 | latents = latents.to(dtype=vae.dtype).to(device) 370 | vae = vae.to(device) 371 | latents = 1 / 0.18215 * latents 372 | image = vae.decode(latents).sample 373 | image = (image / 2 + 0.5).clamp(0, 1) 374 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 375 | image = image.cpu().permute(0, 2, 3, 1).float().detach().numpy() 376 | return image 377 | 378 | 379 | def hook_kohya_ss(): 380 | import library.utils 381 | import library.train_util 382 | import library.sdxl_train_util 383 | library.utils.setup_logging = setup_logging 384 | 385 | library.train_util.load_tokenizer = load_tokenizers 386 | library.sdxl_train_util.load_tokenizers = load_sdxl_tokenizers 387 | 388 | global original_load_target_model 389 | if original_load_target_model is None: 390 | original_load_target_model = library.train_util._load_target_model 391 | 392 | library.train_util._load_target_model = _load_target_model 393 | 394 | library.sdxl_train_util._load_target_model = _sdxl_load_target_model 395 | 396 | global original_conditional_loss 397 | if original_conditional_loss is None: 398 | original_conditional_loss = library.train_util.conditional_loss 399 | library.train_util.conditional_loss = conditional_loss 400 | 401 | 402 | def _sdxl_load_target_model( 403 | name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, *args, **kwargs 404 | ): 405 | import library.sdxl_model_util as sdxl_model_util 406 | import library.model_util as model_util 407 | import library.sdxl_original_unet as sdxl_original_unet 408 | import library.sdxl_train_util 409 | init_empty_weights = library.sdxl_train_util.init_empty_weights 410 | 411 | # model_dtype only work with full fp16/bf16 412 | name_or_path = os.readlink(name_or_path) if os.path.islink( 413 | name_or_path) else name_or_path 414 | load_stable_diffusion_format = False 415 | 416 | if True: 417 | # Diffusers model is loaded to CPU 418 | variant = "fp16" if weight_dtype == torch.float16 else None 419 | print( 420 | f"load Diffusers pretrained models: {name_or_path}, variant={variant}") 421 | try: 422 | try: 423 | pipe = StableDiffusionXLPipeline.from_single_file( 424 | name_or_path, local_files_only=True, safety_checker=None) 425 | except EnvironmentError as ex: 426 | raise ex 427 | except EnvironmentError as ex: 428 | print( 429 | f"model is not found as a file or in Hugging Face, perhaps file name is wrong? 
/ the specified model file or Hugging Face model could not be found, the file name may be wrong: {name_or_path}" 430 | ) 431 | raise ex 432 | 433 | text_encoder1 = pipe.text_encoder 434 | text_encoder2 = pipe.text_encoder_2 435 | 436 | # convert to fp32 for caching the text_encoder outputs 437 | if text_encoder1.dtype != torch.float32: 438 | text_encoder1 = text_encoder1.to(dtype=torch.float32) 439 | if text_encoder2.dtype != torch.float32: 440 | text_encoder2 = text_encoder2.to(dtype=torch.float32) 441 | 442 | vae = pipe.vae 443 | unet = pipe.unet 444 | global clip_large_tokenizer, clip_big_tokenizer 445 | clip_large_tokenizer = pipe.tokenizer 446 | clip_big_tokenizer = pipe.tokenizer_2 447 | del pipe 448 | 449 | # Diffusers U-Net to original U-Net 450 | state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl( 451 | unet.state_dict()) 452 | with init_empty_weights(): 453 | unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet 454 | sdxl_model_util._load_state_dict_on_device( 455 | unet, state_dict, device=device, dtype=model_dtype) 456 | print("U-Net converted to original U-Net") 457 | 458 | logit_scale = None 459 | ckpt_info = None 460 | 461 | # Load the VAE 462 | if vae_path is not None: 463 | vae = model_util.load_vae(vae_path, weight_dtype) 464 | print("additional VAE loaded") 465 | 466 | return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info 467 | 468 | 469 | def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False): 470 | import library.model_util as model_util 471 | from library.original_unet import UNet2DConditionModel 472 | 473 | name_or_path = args.pretrained_model_name_or_path 474 | name_or_path = os.path.realpath(name_or_path) if os.path.islink( 475 | name_or_path) else name_or_path 476 | load_stable_diffusion_format = False 477 | if True: 478 | # Diffusers model is loaded to CPU 479 | try: 480 | pipe = StableDiffusionPipeline.from_single_file( 481 | name_or_path, local_files_only=True, safety_checker=None) 482 | except EnvironmentError as ex: 483 | print( 484 | f"model is not found as a file or on Hugging Face, perhaps the file name is wrong: {name_or_path}" 485 | ) 486 | raise ex 487 | 488 | text_encoder = pipe.text_encoder 489 | vae = pipe.vae 490 | unet = pipe.unet 491 | global clip_large_tokenizer 492 | clip_large_tokenizer = pipe.tokenizer 493 | del pipe 494 | 495 | # Diffusers U-Net to original U-Net 496 | # TODO it would be better to convert here to the same format as v2 *.ckpt/*.safetensors 497 | # print(f"unet config: {unet.config}") 498 | original_unet = UNet2DConditionModel( 499 | unet.config.sample_size, 500 | unet.config.attention_head_dim, 501 | unet.config.cross_attention_dim, 502 | unet.config.use_linear_projection, 503 | unet.config.upcast_attention, 504 | ) 505 | original_unet.load_state_dict(unet.state_dict()) 506 | unet = original_unet 507 | print("U-Net converted to original U-Net") 508 | 509 | # Load the VAE 510 | if args.vae is not None: 511 | vae = model_util.load_vae(args.vae, weight_dtype) 512 | print("additional VAE loaded") 513 | 514 | return text_encoder, vae, unet, load_stable_diffusion_format 515 | 516 | 517 | def generate_image(pipe_class, cmd_args, accelerator, vae, tokenizer, text_encoder, unet, epoch, prompt_dict_list, **kwargs): 518 | if pipe_class is None: 519 | print("pipe_class is None") 520 | return 521 | import library.train_util 522 | 523 | # for multi gpu distributed inference. 
this is a singleton, so it's safe to use it here 524 | distributed_state = library.train_util.PartialState() 525 | org_vae_device = vae.device # should be on the CPU 526 | vae.to(distributed_state.device) 527 | unet = accelerator.unwrap_model(unet) 528 | 529 | if isinstance(text_encoder, (list, tuple)): 530 | text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] 531 | else: 532 | text_encoder = accelerator.unwrap_model(text_encoder) 533 | 534 | default_scheduler = library.train_util.get_my_scheduler( 535 | sample_sampler="k_euler", 536 | v_parameterization=cmd_args.v_parameterization, 537 | ) 538 | 539 | pipeline = pipe_class( 540 | text_encoder=text_encoder, 541 | vae=vae, 542 | unet=unet, 543 | tokenizer=tokenizer, 544 | scheduler=default_scheduler, 545 | safety_checker=None, 546 | feature_extractor=None, 547 | requires_safety_checker=False, 548 | clip_skip=cmd_args.clip_skip, 549 | ) 550 | pipeline.to(distributed_state.device) 551 | 552 | workspaces_dir = os.path.dirname(cmd_args.dataset_config) 553 | sample_images_path = os.path.join( 554 | workspaces_dir, "sample_images") 555 | 556 | os.makedirs(sample_images_path, exist_ok=True) 557 | 558 | lora_output_name = cmd_args.output_name 559 | 560 | # generate sample images 561 | save_dir = sample_images_path 562 | prompt_replacement = None 563 | steps = 0 564 | controlnet = None 565 | 566 | # save random state to restore later 567 | rng_state = torch.get_rng_state() 568 | cuda_rng_state = None 569 | try: 570 | cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None 571 | except Exception: 572 | pass 573 | 574 | with torch.no_grad(): 575 | for prompt_dict in prompt_dict_list: 576 | library.train_util.sample_image_inference( 577 | accelerator, cmd_args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet 578 | ) 579 | 580 | del pipeline 581 | library.train_util.clean_memory_on_device(accelerator.device) 582 | 583 | torch.set_rng_state(rng_state) 584 | if cuda_rng_state is not None: 585 | torch.cuda.set_rng_state(cuda_rng_state) 586 | vae.to(org_vae_device) 587 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-traintools-mz" 3 | description = "Nodes for fine-tuning LoRA models in ComfyUI, built on training tools such as kohya-ss/sd-scripts" 4 | version = "1.0.0" 5 | license = "LICENSE" 6 | 7 | [project.urls] 8 | Repository = "https://github.com/MinusZoneAI/ComfyUI-TrainTools-MZ" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "wailovet" 13 | DisplayName = "ComfyUI-TrainTools-MZ" 14 | Icon = "" 15 | --------------------------------------------------------------------------------
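
Usage note: the snippet below is a minimal sketch of how the offline request hooks defined in hook_kohya_ss_utils.py are expected to behave once the module is imported. It assumes the module can be imported on its own (its top-level imports pull in torch, diffusers and transformers) and that the bundled configs/models_config files are present; the variable names here are illustrative and not part of the project's API.

import requests
import hook_kohya_ss_utils  # importing the module installs the requests/huggingface_hub patches

# A URL listed in source_replacement_table: it is answered from the local
# configs/models_config copy instead of the network.
url = ("https://raw.githubusercontent.com/CompVis/stable-diffusion/main/"
       "configs/stable-diffusion/v1-inference.yaml")

resp = requests.get(url)    # intercepted by request_wrapper
print(resp.status_code)     # 200, reported by the DictWrapper stand-in response
print(resp.content[:40])    # first bytes of the bundled v1-inference.yaml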