├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_EN.md ├── __init__.py ├── doc └── image.png ├── example └── workflow_base.png ├── nodes ├── HairNode.py ├── __init__.py └── libs │ ├── __init__.py │ ├── configs │ └── sd15 │ │ ├── feature_extractor │ │ └── preprocessor_config.json │ │ ├── model_index.json │ │ ├── safety_checker │ │ └── config.json │ │ ├── scheduler │ │ └── scheduler_config.json │ │ ├── text_encoder │ │ └── config.json │ │ ├── tokenizer │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json │ │ ├── unet │ │ └── config.json │ │ ├── v1-inference.yaml │ │ └── vae │ │ └── config.json │ ├── ref_encoder │ ├── __init__.py │ ├── adapter.py │ ├── attention_processor.py │ ├── latent_controlnet.py │ ├── reference_control.py │ └── reference_unet.py │ └── utils │ ├── __init__.py │ ├── pipeline.py │ └── pipeline_cn.py └── pyproject.toml /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'lldacing' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [ENGLISH](README_EN.md) 2 | 3 | 头发迁移 4 | 5 | ## 预览 6 | ![save api extended](doc/image.png) 7 | 8 | ## [示例工作流](example/workflow_base.png) 9 | 工作流中是最简单的替换,没有贴回原图,完整的流程是:裁剪包含头发的头像区域->生成光头->迁移参考图像头发->把结果图贴回原图 10 | 11 | ### 说明: 12 | - 生成光头和迁移头发的图像宽高需要是8的倍数,两个裁剪后的图像尺寸要一致,需要是正面照 13 | - 大模型要选sd1.5模型 14 | 15 | ## 安装 16 | 17 | - 手动安装 18 | ```shell 19 | cd custom_nodes 20 | git clone https://github.com/lldacing/ComfyUI_StableHair_ll.git 21 | cd ComfyUI_StableHair_ll 22 | # 重启comfyUI 23 | ``` 24 | 25 | 26 | ## 模型 27 | 从[HuggingFace](https://huggingface.co/lldacing/StableHair/tree/main)下载所有文件放到目录`ComfyUI/models/diffusers/StableHair` 28 | 29 | 建议使用huggingface-cli下载 30 | ``` 31 | # 设置代理,按需设置,也可开全局代理 32 | set https_proxy=http://127.0.0.1:7890 33 | # 在ComfyUI/models/diffusers/目录下启动命令行执行下面的命令,如果找不到huggingface-cli,huggingface-cli在${python_home}/Scripts目录下,使用全路径 34 | huggingface-cli download lldacing/StableHair --local-dir StableHair 35 | ``` 36 | 目录结构如下: 37 | ``` 38 | ComfyUI 39 | └─models 40 | └─diffusers 41 | └─StableHair 42 | └─hair_encoder_model.bin 43 | └─hair_adapter_model.bin 44 | └─hair_controlnet_model.bin 45 | └─hair_bald_model.bin 46 | ``` 47 | 48 | ## 感谢 49 | 50 | 原项目 [Xiaojiu-z/Stable-Hair](https://github.com/Xiaojiu-z/Stable-Hair) 51 | 52 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | [中文文档](README.md) 2 | 3 | Stable-Hair: Real-World Hair Transfer via Diffusion Model 4 | 5 | ## Preview 6 | ![save api extended](doc/image.png) 7 | 8 | ## [Workflow example](example/workflow_base.png) 9 | In the workflow, it is the simplest replacement without pasting back the original image. 
The complete process is: crop the head region that contains the hair -> generate a bald head -> transfer the hair from the reference image -> paste the result back onto the original image
10 | 
11 | ### Tips:
12 | - The width and height of the images used for bald-head generation and hair transfer must be multiples of 8, the two cropped images must be the same size, and both should be front-facing photos
13 | - The base checkpoint must be an SD1.5 model
14 | 
15 | ## Install
16 | 
17 | - Manual
18 | ```shell
19 | cd custom_nodes
20 | git clone https://github.com/lldacing/ComfyUI_StableHair_ll.git
21 | cd ComfyUI_StableHair_ll
22 | # restart ComfyUI
23 | ```
24 | 
25 | 
26 | ## Model
27 | Download all files from [HuggingFace](https://huggingface.co/lldacing/StableHair/tree/main) into the `ComfyUI/models/diffusers/StableHair` directory.
28 | 
29 | Downloading with huggingface-cli is recommended (set an HTTP proxy first if you need one):
30 | ```
31 | # Run this from the ComfyUI/models/diffusers directory; if huggingface-cli is not on your PATH, use its full path under ${python_home}/Scripts
32 | huggingface-cli download lldacing/StableHair --local-dir StableHair
33 | ```
34 | The directory structure is as follows:
35 | ```
36 | ComfyUI
37 | └─models
38 |   └─diffusers
39 |     └─StableHair
40 |       └─hair_encoder_model.bin
41 |       └─hair_adapter_model.bin
42 |       └─hair_controlnet_model.bin
43 |       └─hair_bald_model.bin
44 | ```
45 | 
46 | 
47 | ## Thanks
48 | 
49 | Original Project [Xiaojiu-z/Stable-Hair](https://github.com/Xiaojiu-z/Stable-Hair)
50 | 
51 | 
52 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | # Get the directory that contains this package
5 | parent_dir = os.path.dirname(os.path.abspath(__file__))
6 | 
7 | # Add it to the system path so the bundled modules resolve
8 | sys.path.insert(0, parent_dir)
9 | 
10 | from .nodes import HairNode
11 | 
12 | NODE_CLASS_MAPPINGS = {**HairNode.NODE_CLASS_MAPPINGS}
13 | NODE_DISPLAY_NAME_MAPPINGS = {**HairNode.NODE_DISPLAY_NAME_MAPPINGS}
14 | 
--------------------------------------------------------------------------------
/doc/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_StableHair_ll/d00491937bc9f0c4fd96f010bf0c63fb5eabcd30/doc/image.png
--------------------------------------------------------------------------------
/example/workflow_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_StableHair_ll/d00491937bc9f0c4fd96f010bf0c63fb5eabcd30/example/workflow_base.png
--------------------------------------------------------------------------------
/nodes/HairNode.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy
4 | 
5 | from comfy import model_management
6 | import folder_paths
7 | from diffusers import UniPCMultistepScheduler
8 | 
9 | from comfy.utils import ProgressBar
10 | from folder_paths import supported_pt_extensions
11 | from .libs.ref_encoder.adapter import *
12 | from .libs.ref_encoder.latent_controlnet import ControlNetModel
13 | from .libs.ref_encoder.reference_unet import RefHairUnet
14 | from .libs.utils.pipeline import StableHairPipeline
15 | from .libs.utils.pipeline_cn import StableDiffusionControlNetPipeline
16 | 
17 | deviceType = model_management.get_torch_device().type
18 | current_dir = os.path.dirname(os.path.abspath(__file__))
19 | sd_config_dir = os.path.join(current_dir, "libs/configs/sd15")
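# NOTE: StableHair weight files are expected under ComfyUI/models/diffusers/StableHair/
# (see the README for the download layout). hair_model_path_format below turns a bare
# filename picked in the node UI into that relative "diffusers" path, and sd_config_dir
# points at the bundled SD1.5 config tree so the pipelines can be built fully offline.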
hair_model_path_format = 'StableHair/{}' 21 | 22 | 23 | class LoadStableHairRemoverModel: 24 | 25 | @classmethod 26 | def INPUT_TYPES(cls): 27 | model_paths = [] 28 | for search_path in folder_paths.get_folder_paths("diffusers"): 29 | stable_hair_path = os.path.join(search_path, "StableHair") 30 | if os.path.exists(stable_hair_path): 31 | for root, subdir, files in os.walk(stable_hair_path, followlinks=True): 32 | for file in files: 33 | file_name_ext = file.split(".") 34 | if len(file_name_ext) > 1 and '.{}'.format(file_name_ext[-1]) in supported_pt_extensions: 35 | model_paths.append(file) 36 | return { 37 | "required": { 38 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"), 39 | {"tooltip": "The name of the checkpoint (model) to load."}), 40 | "bald_model": (model_paths, {}), 41 | "device": (["AUTO", "CPU"],) 42 | } 43 | } 44 | 45 | RETURN_TYPES = ("BALD_MODEL",) 46 | RETURN_NAMES = ("bald_model",) 47 | FUNCTION = "load_model" 48 | CATEGORY = "hair/transfer" 49 | 50 | def load_model(self, ckpt_name, bald_model, device): 51 | model_management.soft_empty_cache() 52 | sd15_model_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) 53 | bald_model_path = folder_paths.get_full_path_or_raise("diffusers", hair_model_path_format.format(bald_model)) 54 | if device == "AUTO": 55 | device_type = deviceType 56 | else: 57 | device_type = "cpu" 58 | 59 | weight_dtype = torch.float16 if device_type == "cuda" else torch.float32 60 | 61 | remove_hair_pipeline = StableDiffusionControlNetPipeline.from_single_file(sd15_model_path, 62 | config=sd_config_dir, 63 | torch_dtype=weight_dtype, 64 | local_files_only=True, 65 | safety_checker=None, 66 | requires_safety_checker=False, 67 | ) 68 | bald_converter = ControlNetModel.from_unet(remove_hair_pipeline.unet) 69 | _state_dict = torch.load(bald_model_path) 70 | 71 | bald_converter.load_state_dict(_state_dict, strict=False) 72 | bald_converter.to(device_type, dtype=weight_dtype) 73 | remove_hair_pipeline.register_modules(controlnet=bald_converter) 74 | 75 | remove_hair_pipeline.scheduler = UniPCMultistepScheduler.from_config(remove_hair_pipeline.scheduler.config) 76 | remove_hair_pipeline.to(device_type) 77 | 78 | return remove_hair_pipeline, 79 | 80 | 81 | class LoadStableHairTransferModel: 82 | 83 | @classmethod 84 | def INPUT_TYPES(cls): 85 | model_paths = [] 86 | for search_path in folder_paths.get_folder_paths("diffusers"): 87 | stable_hair_path = os.path.join(search_path, "StableHair") 88 | if os.path.exists(stable_hair_path): 89 | for root, subdir, files in os.walk(stable_hair_path, followlinks=True): 90 | for file in files: 91 | file_name_ext = file.split(".") 92 | if len(file_name_ext) >1 and '.{}'.format(file_name_ext[-1]) in supported_pt_extensions: 93 | model_paths.append(file) 94 | return { 95 | "required": { 96 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"), 97 | {"tooltip": "The name of the checkpoint (model) to load."}), 98 | "encoder_model": (model_paths, {}), 99 | "adapter_model": (model_paths, {}), 100 | "control_model": (model_paths, {}), 101 | "device": (["AUTO", "CPU"],) 102 | } 103 | } 104 | 105 | RETURN_TYPES = ("HAIR_MODEL",) 106 | RETURN_NAMES = ("model",) 107 | FUNCTION = "load_model" 108 | CATEGORY = "hair/transfer" 109 | 110 | def load_model(self, ckpt_name, encoder_model, adapter_model, control_model, device): 111 | model_management.soft_empty_cache() 112 | sd15_model_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) 113 | encoder_model_path = 
folder_paths.get_full_path_or_raise("diffusers", 114 | hair_model_path_format.format(encoder_model)) 115 | adapter_model_path = folder_paths.get_full_path_or_raise("diffusers", 116 | hair_model_path_format.format(adapter_model)) 117 | control_model_path = folder_paths.get_full_path_or_raise("diffusers", 118 | hair_model_path_format.format(control_model)) 119 | if device == "AUTO": 120 | device_type = deviceType 121 | else: 122 | device_type = "cpu" 123 | 124 | weight_dtype = torch.float16 if device_type == "cuda" else torch.float32 125 | 126 | pipeline = StableHairPipeline.from_single_file(sd15_model_path, 127 | config=sd_config_dir, 128 | torch_dtype=weight_dtype, 129 | local_files_only=True, 130 | safety_checker=None, 131 | requires_safety_checker=False, 132 | ).to(device_type) 133 | 134 | controlnet = ControlNetModel.from_unet(pipeline.unet).to(device_type) 135 | _state_dict = torch.load(control_model_path) 136 | controlnet.load_state_dict(_state_dict, strict=False) 137 | controlnet.to(device_type, dtype=weight_dtype) 138 | pipeline.register_modules(controlnet=controlnet) 139 | 140 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 141 | 142 | hair_encoder = RefHairUnet.from_config(pipeline.unet.config) 143 | _state_dict = torch.load(encoder_model_path) 144 | hair_encoder.load_state_dict(_state_dict, strict=False) 145 | hair_encoder.to(device_type, dtype=weight_dtype) 146 | pipeline.register_modules(reference_encoder=hair_encoder) 147 | 148 | hair_adapter = adapter_injection(pipeline.unet, device=device_type, dtype=weight_dtype, use_resampler=False) 149 | _state_dict = torch.load(adapter_model_path) 150 | 151 | hair_adapter.load_state_dict(_state_dict, strict=False) 152 | 153 | return pipeline, 154 | 155 | 156 | class ApplyHairRemover: 157 | 158 | @classmethod 159 | def INPUT_TYPES(cls): 160 | return { 161 | "required": { 162 | "bald_model": ("BALD_MODEL",), 163 | "images": ("IMAGE",), 164 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 165 | "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), 166 | "strength": ("FLOAT", {"default": 1.5, "min": 0.0, "max": 5.0, "step": 0.01}), 167 | }, 168 | "optional": { 169 | "cfg": ("FLOAT", {"default": 1.5, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), 170 | } 171 | } 172 | 173 | RETURN_TYPES = ("IMAGE",) 174 | RETURN_NAMES = ("image",) 175 | FUNCTION = "apply" 176 | CATEGORY = "hair/transfer" 177 | 178 | def apply(self, bald_model, images, seed, steps, strength, cfg=1.5): 179 | _images = [] 180 | _masks = [] 181 | 182 | for image in images: 183 | # h, w, c -> c, h, w 184 | H, W, c = image.shape 185 | im_tensor = image.permute(2, 0, 1) 186 | # 随记种子 187 | generator = torch.Generator(device=bald_model.device) 188 | generator.manual_seed(seed) 189 | comfy_pbar = ProgressBar(steps) 190 | 191 | def callback_bar(step, timestep, latents): 192 | comfy_pbar.update(1) 193 | 194 | with torch.no_grad(): 195 | # 采样,变光头 196 | result_image = bald_model( 197 | prompt="", 198 | negative_prompt="", 199 | num_inference_steps=steps, 200 | guidance_scale=cfg, 201 | width=W, 202 | height=H, 203 | image=im_tensor.unsqueeze(0), 204 | controlnet_conditioning_scale=strength, 205 | generator=None, 206 | return_dict=False, 207 | output_type="pt", 208 | callback=callback_bar, 209 | )[0] 210 | 211 | # b, c, h, w -> b, h, w, c 212 | result_image = result_image.permute(0, 2, 3, 1) 213 | 214 | _images.append(result_image) 215 | 216 | out_images = torch.cat(_images, dim=0) 217 | 218 | return out_images, 219 
220 | 
221 | class ApplyHairTransfer:
222 | 
223 |     @classmethod
224 |     def INPUT_TYPES(cls):
225 |         return {
226 |             "required": {
227 |                 "model": ("HAIR_MODEL",),
228 |                 "images": ("IMAGE",),
229 |                 "bald_image": ("IMAGE",),
230 |                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
231 |                 "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
232 |                 "cfg": ("FLOAT", {"default": 1.5, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
233 |                 "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}),
234 |                 "adapter_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}),
235 |             }
236 |         }
237 | 
238 |     RETURN_TYPES = ("IMAGE",)
239 |     RETURN_NAMES = ("image",)
240 |     FUNCTION = "apply"
241 |     CATEGORY = "hair/transfer"
242 | 
243 |     def apply(self, model, images, bald_image, seed, steps, cfg, control_strength, adapter_strength):
244 |         _images = []
245 |         _masks = []
246 | 
247 |         for image in images:
248 |             h, w, c = image.shape
249 | 
250 |             prompt = ""
251 |             n_prompt = ""
252 | 
253 |             # Set the adapter (reference hair) strength
254 |             set_scale(model.unet, adapter_strength)
255 |             # Seed the random generator
256 |             generator = torch.Generator(device=model.device)
257 |             generator.manual_seed(seed)
258 |             comfy_pbar = ProgressBar(steps)
259 | 
260 |             def callback_bar(step, timestep, latents):
261 |                 comfy_pbar.update(1)
262 | 
263 |             ref_image_np = (image.cpu().numpy() * 255).astype(numpy.uint8)
264 |             bald_image_np = (bald_image.squeeze(0).cpu().numpy() * 255).astype(numpy.uint8)
265 |             with torch.no_grad():
266 |                 # Sample: transfer the reference hairstyle onto the bald image
267 |                 result_image = model(
268 |                     prompt,
269 |                     negative_prompt=n_prompt,
270 |                     num_inference_steps=steps,
271 |                     guidance_scale=cfg,
272 |                     width=w,
273 |                     height=h,
274 |                     controlnet_condition=bald_image_np,
275 |                     controlnet_conditioning_scale=control_strength,
276 |                     generator=generator,
277 |                     ref_image=ref_image_np,
278 |                     output_type="tensor",
279 |                     callback=callback_bar,
280 |                     return_dict=False
281 |                 )
282 | 
283 |             # b, h, w, c
284 |             result_image = result_image.unsqueeze(0)
285 | 
286 |             _images.append(result_image)
287 | 
288 |         out_images = torch.cat(_images, dim=0)
289 | 
290 |         return out_images,
291 | 
292 | 
293 | NODE_CLASS_MAPPINGS = {
294 |     "LoadStableHairRemoverModel": LoadStableHairRemoverModel,
295 |     "LoadStableHairTransferModel": LoadStableHairTransferModel,
296 |     "ApplyHairRemover": ApplyHairRemover,
297 |     "ApplyHairTransfer": ApplyHairTransfer,
298 | }
299 | 
300 | NODE_DISPLAY_NAME_MAPPINGS = {
301 |     "LoadStableHairRemoverModel": "LoadStableHairRemoverModel",
302 |     "LoadStableHairTransferModel": "LoadStableHairTransferModel",
303 |     "ApplyHairRemover": "ApplyHairRemover",
304 |     "ApplyHairTransfer": "ApplyHairTransfer",
305 | }
306 | 
--------------------------------------------------------------------------------
/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/nodes/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_StableHair_ll/d00491937bc9f0c4fd96f010bf0c63fb5eabcd30/nodes/libs/__init__.py
--------------------------------------------------------------------------------
/nodes/libs/configs/sd15/feature_extractor/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 |
"do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } 21 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "feature_extractor": [ 5 | "transformers", 6 | "CLIPImageProcessor" 7 | ], 8 | "safety_checker": [ 9 | "stable_diffusion", 10 | "StableDiffusionSafetyChecker" 11 | ], 12 | "scheduler": [ 13 | "diffusers", 14 | "PNDMScheduler" 15 | ], 16 | "text_encoder": [ 17 | "transformers", 18 | "CLIPTextModel" 19 | ], 20 | "tokenizer": [ 21 | "transformers", 22 | "CLIPTokenizer" 23 | ], 24 | "unet": [ 25 | "diffusers", 26 | "UNet2DConditionModel" 27 | ], 28 | "vae": [ 29 | "diffusers", 30 | "AutoencoderKL" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/safety_checker/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b", 3 | "_name_or_path": "CompVis/stable-diffusion-safety-checker", 4 | "architectures": [ 5 | "StableDiffusionSafetyChecker" 6 | ], 7 | "initializer_factor": 1.0, 8 | "logit_scale_init_value": 2.6592, 9 | "model_type": "clip", 10 | "projection_dim": 768, 11 | "text_config": { 12 | "_name_or_path": "", 13 | "add_cross_attention": false, 14 | "architectures": null, 15 | "attention_dropout": 0.0, 16 | "bad_words_ids": null, 17 | "bos_token_id": 0, 18 | "chunk_size_feed_forward": 0, 19 | "cross_attention_hidden_size": null, 20 | "decoder_start_token_id": null, 21 | "diversity_penalty": 0.0, 22 | "do_sample": false, 23 | "dropout": 0.0, 24 | "early_stopping": false, 25 | "encoder_no_repeat_ngram_size": 0, 26 | "eos_token_id": 2, 27 | "exponential_decay_length_penalty": null, 28 | "finetuning_task": null, 29 | "forced_bos_token_id": null, 30 | "forced_eos_token_id": null, 31 | "hidden_act": "quick_gelu", 32 | "hidden_size": 768, 33 | "id2label": { 34 | "0": "LABEL_0", 35 | "1": "LABEL_1" 36 | }, 37 | "initializer_factor": 1.0, 38 | "initializer_range": 0.02, 39 | "intermediate_size": 3072, 40 | "is_decoder": false, 41 | "is_encoder_decoder": false, 42 | "label2id": { 43 | "LABEL_0": 0, 44 | "LABEL_1": 1 45 | }, 46 | "layer_norm_eps": 1e-05, 47 | "length_penalty": 1.0, 48 | "max_length": 20, 49 | "max_position_embeddings": 77, 50 | "min_length": 0, 51 | "model_type": "clip_text_model", 52 | "no_repeat_ngram_size": 0, 53 | "num_attention_heads": 12, 54 | "num_beam_groups": 1, 55 | "num_beams": 1, 56 | "num_hidden_layers": 12, 57 | "num_return_sequences": 1, 58 | "output_attentions": false, 59 | "output_hidden_states": false, 60 | "output_scores": false, 61 | "pad_token_id": 1, 62 | "prefix": null, 63 | "problem_type": null, 64 | "pruned_heads": {}, 65 | "remove_invalid_values": false, 66 | "repetition_penalty": 1.0, 67 | "return_dict": true, 68 | "return_dict_in_generate": false, 69 | "sep_token_id": null, 70 | "task_specific_params": null, 71 | "temperature": 1.0, 72 | "tf_legacy_loss": false, 73 | "tie_encoder_decoder": false, 74 | "tie_word_embeddings": true, 75 | "tokenizer_class": null, 76 | "top_k": 50, 77 | "top_p": 1.0, 78 | "torch_dtype": 
null, 79 | "torchscript": false, 80 | "transformers_version": "4.22.0.dev0", 81 | "typical_p": 1.0, 82 | "use_bfloat16": false, 83 | "vocab_size": 49408 84 | }, 85 | "text_config_dict": { 86 | "hidden_size": 768, 87 | "intermediate_size": 3072, 88 | "num_attention_heads": 12, 89 | "num_hidden_layers": 12 90 | }, 91 | "torch_dtype": "float32", 92 | "transformers_version": null, 93 | "vision_config": { 94 | "_name_or_path": "", 95 | "add_cross_attention": false, 96 | "architectures": null, 97 | "attention_dropout": 0.0, 98 | "bad_words_ids": null, 99 | "bos_token_id": null, 100 | "chunk_size_feed_forward": 0, 101 | "cross_attention_hidden_size": null, 102 | "decoder_start_token_id": null, 103 | "diversity_penalty": 0.0, 104 | "do_sample": false, 105 | "dropout": 0.0, 106 | "early_stopping": false, 107 | "encoder_no_repeat_ngram_size": 0, 108 | "eos_token_id": null, 109 | "exponential_decay_length_penalty": null, 110 | "finetuning_task": null, 111 | "forced_bos_token_id": null, 112 | "forced_eos_token_id": null, 113 | "hidden_act": "quick_gelu", 114 | "hidden_size": 1024, 115 | "id2label": { 116 | "0": "LABEL_0", 117 | "1": "LABEL_1" 118 | }, 119 | "image_size": 224, 120 | "initializer_factor": 1.0, 121 | "initializer_range": 0.02, 122 | "intermediate_size": 4096, 123 | "is_decoder": false, 124 | "is_encoder_decoder": false, 125 | "label2id": { 126 | "LABEL_0": 0, 127 | "LABEL_1": 1 128 | }, 129 | "layer_norm_eps": 1e-05, 130 | "length_penalty": 1.0, 131 | "max_length": 20, 132 | "min_length": 0, 133 | "model_type": "clip_vision_model", 134 | "no_repeat_ngram_size": 0, 135 | "num_attention_heads": 16, 136 | "num_beam_groups": 1, 137 | "num_beams": 1, 138 | "num_channels": 3, 139 | "num_hidden_layers": 24, 140 | "num_return_sequences": 1, 141 | "output_attentions": false, 142 | "output_hidden_states": false, 143 | "output_scores": false, 144 | "pad_token_id": null, 145 | "patch_size": 14, 146 | "prefix": null, 147 | "problem_type": null, 148 | "pruned_heads": {}, 149 | "remove_invalid_values": false, 150 | "repetition_penalty": 1.0, 151 | "return_dict": true, 152 | "return_dict_in_generate": false, 153 | "sep_token_id": null, 154 | "task_specific_params": null, 155 | "temperature": 1.0, 156 | "tf_legacy_loss": false, 157 | "tie_encoder_decoder": false, 158 | "tie_word_embeddings": true, 159 | "tokenizer_class": null, 160 | "top_k": 50, 161 | "top_p": 1.0, 162 | "torch_dtype": null, 163 | "torchscript": false, 164 | "transformers_version": "4.22.0.dev0", 165 | "typical_p": 1.0, 166 | "use_bfloat16": false 167 | }, 168 | "vision_config_dict": { 169 | "hidden_size": 1024, 170 | "intermediate_size": 4096, 171 | "num_attention_heads": 16, 172 | "num_hidden_layers": 24, 173 | "patch_size": 14 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null, 12 | "clip_sample": false 13 | } 14 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"architectures": [ 3 | "CLIPTextModel" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 0, 7 | "dropout": 0.0, 8 | "eos_token_id": 2, 9 | "hidden_act": "quick_gelu", 10 | "hidden_size": 768, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "layer_norm_eps": 1e-05, 15 | "max_position_embeddings": 77, 16 | "model_type": "clip_text_model", 17 | "num_attention_heads": 12, 18 | "num_hidden_layers": 12, 19 | "pad_token_id": 1, 20 | "projection_dim": 768, 21 | "torch_dtype": "float32", 22 | "transformers_version": "4.22.0.dev0", 23 | "vocab_size": 49408 24 | } 25 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | "CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } 37 | 
-------------------------------------------------------------------------------- /nodes/libs/configs/sd15/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /nodes/libs/configs/sd15/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "in_channels": 3, 18 | "latent_channels": 4, 19 | "layers_per_block": 2, 20 | "norm_num_groups": 32, 21 | "out_channels": 3, 22 | "sample_size": 512, 23 | "up_block_types": [ 24 | "UpDecoderBlock2D", 25 | "UpDecoderBlock2D", 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D" 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_StableHair_ll/d00491937bc9f0c4fd96f010bf0c63fb5eabcd30/nodes/libs/ref_encoder/__init__.py -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | def is_torch2_available(): 4 | return hasattr(F, "scaled_dot_product_attention") 5 | if is_torch2_available(): 6 | from .attention_processor import HairAttnProcessor2_0 as 
HairAttnProcessor, AttnProcessor2_0 as AttnProcessor 7 | else: 8 | from .attention_processor import HairAttnProcessor, AttnProcessor 9 | 10 | def adapter_injection(unet, device="cuda", dtype=torch.float32, use_resampler=False): 11 | device = device 12 | dtype = dtype 13 | # load Hair attention layers 14 | attn_procs = {} 15 | for name in unet.attn_processors.keys(): 16 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 17 | if name.startswith("mid_block"): 18 | hidden_size = unet.config.block_out_channels[-1] 19 | elif name.startswith("up_blocks"): 20 | block_id = int(name[len("up_blocks.")]) 21 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 22 | elif name.startswith("down_blocks"): 23 | block_id = int(name[len("down_blocks.")]) 24 | hidden_size = unet.config.block_out_channels[block_id] 25 | if cross_attention_dim is None: 26 | attn_procs[name] = HairAttnProcessor(hidden_size=hidden_size, cross_attention_dim=hidden_size, scale=1, use_resampler=use_resampler).to(device, dtype=dtype) 27 | else: 28 | attn_procs[name] = AttnProcessor() 29 | unet.set_attn_processor(attn_procs) 30 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 31 | adapter_layers = adapter_modules 32 | adapter_layers.to(device, dtype=dtype) 33 | return adapter_layers 34 | 35 | def set_scale(unet, scale): 36 | for attn_processor in unet.attn_processors.values(): 37 | if isinstance(attn_processor, HairAttnProcessor): 38 | attn_processor.scale = scale -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from diffusers.utils.import_utils import is_xformers_available 6 | if is_xformers_available(): 7 | import xformers 8 | import xformers.ops 9 | else: 10 | xformers = None 11 | 12 | class HairAttnProcessor(nn.Module): 13 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, use_resampler=False): 14 | super().__init__() 15 | 16 | self.hidden_size = hidden_size 17 | self.cross_attention_dim = cross_attention_dim 18 | self.scale = scale 19 | self.use_resampler = use_resampler 20 | if self.use_resampler: 21 | self.resampler = Resampler(query_dim=hidden_size) 22 | self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 23 | self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 24 | 25 | def __call__( 26 | self, 27 | attn, 28 | hidden_states, 29 | encoder_hidden_states=None, 30 | attention_mask=None, 31 | temb=None, 32 | ): 33 | residual = hidden_states 34 | 35 | if attn.spatial_norm is not None: 36 | hidden_states = attn.spatial_norm(hidden_states, temb) 37 | 38 | input_ndim = hidden_states.ndim 39 | 40 | if input_ndim == 4: 41 | batch_size, channel, height, width = hidden_states.shape 42 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 43 | 44 | batch_size, sequence_length, _ = ( 45 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 46 | ) 47 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 48 | 49 | if attn.group_norm is not None: 50 | hidden_states = 
attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 51 | 52 | query = attn.to_q(hidden_states) 53 | 54 | if encoder_hidden_states is None: 55 | encoder_hidden_states = hidden_states 56 | elif attn.norm_cross: 57 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 58 | 59 | # split hidden states 60 | split_num = encoder_hidden_states.shape[1] // 2 61 | encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :split_num, 62 | :], encoder_hidden_states[:, split_num:, :] 63 | 64 | if self.use_resampler: 65 | _hidden_states = self.resampler(_hidden_states) 66 | 67 | key = attn.to_k(encoder_hidden_states) 68 | value = attn.to_v(encoder_hidden_states) 69 | 70 | query = attn.head_to_batch_dim(query) 71 | key = attn.head_to_batch_dim(key) 72 | value = attn.head_to_batch_dim(value) 73 | 74 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 75 | hidden_states = torch.bmm(attention_probs, value) 76 | hidden_states = attn.batch_to_head_dim(hidden_states) 77 | 78 | _key = self.to_k_SSR(_hidden_states) 79 | _value = self.to_v_SSR(_hidden_states) 80 | 81 | _key = attn.head_to_batch_dim(_key) 82 | _value = attn.head_to_batch_dim(_value) 83 | 84 | _attention_probs = attn.get_attention_scores(query, _key, None) 85 | _hidden_states = torch.bmm(_attention_probs, _value) 86 | _hidden_states = attn.batch_to_head_dim(_hidden_states) 87 | 88 | # # assume _hidden_states is a tensor of shape (batch_size, num_patches, hidden_size) 89 | # batch_size, num_patches, hidden_size = _hidden_states.shape 90 | # # create a mask tensor of shape (batch_size, num_patches) 91 | # mask = torch.zeros((batch_size, num_patches), device="cuda", dtype=torch.float16) 92 | # mask[:, 0:num_patches // 2] = 1 93 | # # reshape the mask tensor to match the shape of _hidden_states 94 | # mask = mask.unsqueeze(-1).expand(-1, -1, hidden_size) 95 | # # apply the mask to _hidden_states 96 | # _hidden_states = _hidden_states * mask 97 | 98 | hidden_states = hidden_states + self.scale * _hidden_states 99 | 100 | # linear proj 101 | hidden_states = attn.to_out[0](hidden_states) 102 | # dropout 103 | hidden_states = attn.to_out[1](hidden_states) 104 | 105 | if input_ndim == 4: 106 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 107 | 108 | if attn.residual_connection: 109 | hidden_states = hidden_states + residual 110 | 111 | hidden_states = hidden_states / attn.rescale_output_factor 112 | 113 | return hidden_states 114 | 115 | 116 | class HairAttnProcessor2_0(torch.nn.Module): 117 | 118 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, use_resampler=False): 119 | super().__init__() 120 | 121 | if not hasattr(F, "scaled_dot_product_attention"): 122 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 123 | 124 | self.hidden_size = hidden_size 125 | self.cross_attention_dim = cross_attention_dim 126 | self.scale = scale 127 | self.use_resampler = use_resampler 128 | if self.use_resampler: 129 | self.resampler = Resampler(query_dim=hidden_size) 130 | self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 131 | self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 132 | 133 | def __call__( 134 | self, 135 | attn, 136 | hidden_states, 137 | encoder_hidden_states=None, 138 | attention_mask=None, 139 | temb=None, 140 | ): 141 | residual = hidden_states 142 | 143 | if attn.spatial_norm is not None: 144 | 
hidden_states = attn.spatial_norm(hidden_states, temb) 145 | 146 | input_ndim = hidden_states.ndim 147 | 148 | if input_ndim == 4: 149 | batch_size, channel, height, width = hidden_states.shape 150 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 151 | 152 | batch_size, sequence_length, _ = ( 153 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 154 | ) 155 | 156 | if attention_mask is not None: 157 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 158 | # scaled_dot_product_attention expects attention_mask shape to be 159 | # (batch, heads, source_length, target_length) 160 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 161 | 162 | if attn.group_norm is not None: 163 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 164 | 165 | query = attn.to_q(hidden_states) 166 | 167 | if encoder_hidden_states is None: 168 | encoder_hidden_states = hidden_states 169 | elif attn.norm_cross: 170 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 171 | 172 | # split hidden states 173 | split_num = encoder_hidden_states.shape[1] // 2 174 | encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :split_num, 175 | :], encoder_hidden_states[:, split_num:, :] 176 | 177 | if self.use_resampler: 178 | _hidden_states = self.resampler(_hidden_states) 179 | 180 | key = attn.to_k(encoder_hidden_states) 181 | value = attn.to_v(encoder_hidden_states) 182 | 183 | inner_dim = key.shape[-1] 184 | head_dim = inner_dim // attn.heads 185 | 186 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 187 | 188 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 189 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 190 | 191 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 192 | # TODO: add support for attn.scale when we move to Torch 2.1 193 | hidden_states = F.scaled_dot_product_attention( 194 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 195 | ) 196 | 197 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 198 | hidden_states = hidden_states.to(query.dtype) 199 | 200 | _hidden_states = _hidden_states.to(self.to_k_SSR.weight.dtype) 201 | _key = self.to_k_SSR(_hidden_states) 202 | _value = self.to_v_SSR(_hidden_states) 203 | 204 | _key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 205 | _value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 206 | 207 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 208 | # TODO: add support for attn.scale when we move to Torch 2.1 209 | _hidden_states = F.scaled_dot_product_attention( 210 | query.to(self.to_k_SSR.weight.dtype), _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False 211 | ) 212 | 213 | _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 214 | _hidden_states = _hidden_states.to(query.dtype) 215 | 216 | hidden_states = hidden_states + self.scale * _hidden_states 217 | 218 | # linear proj 219 | hidden_states = attn.to_out[0](hidden_states) 220 | # dropout 221 | hidden_states = attn.to_out[1](hidden_states) 222 | 223 | if input_ndim == 4: 224 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 225 | 226 | if attn.residual_connection: 227 | hidden_states = 
hidden_states + residual 228 | 229 | hidden_states = hidden_states / attn.rescale_output_factor 230 | 231 | return hidden_states 232 | 233 | 234 | class AttnProcessor(nn.Module): 235 | r""" 236 | Default processor for performing attention-related computations. 237 | """ 238 | def __init__( 239 | self, 240 | hidden_size=None, 241 | cross_attention_dim=None, 242 | ): 243 | super().__init__() 244 | 245 | def __call__( 246 | self, 247 | attn, 248 | hidden_states, 249 | encoder_hidden_states=None, 250 | attention_mask=None, 251 | temb=None, 252 | ): 253 | residual = hidden_states 254 | 255 | if attn.spatial_norm is not None: 256 | hidden_states = attn.spatial_norm(hidden_states, temb) 257 | 258 | input_ndim = hidden_states.ndim 259 | 260 | if input_ndim == 4: 261 | batch_size, channel, height, width = hidden_states.shape 262 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 263 | 264 | batch_size, sequence_length, _ = ( 265 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 266 | ) 267 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 268 | 269 | if attn.group_norm is not None: 270 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 271 | 272 | query = attn.to_q(hidden_states) 273 | 274 | if encoder_hidden_states is None: 275 | encoder_hidden_states = hidden_states 276 | elif attn.norm_cross: 277 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 278 | 279 | key = attn.to_k(encoder_hidden_states) 280 | value = attn.to_v(encoder_hidden_states) 281 | 282 | query = attn.head_to_batch_dim(query) 283 | key = attn.head_to_batch_dim(key) 284 | value = attn.head_to_batch_dim(value) 285 | 286 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 287 | hidden_states = torch.bmm(attention_probs, value) 288 | hidden_states = attn.batch_to_head_dim(hidden_states) 289 | 290 | # linear proj 291 | hidden_states = attn.to_out[0](hidden_states) 292 | # dropout 293 | hidden_states = attn.to_out[1](hidden_states) 294 | 295 | if input_ndim == 4: 296 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 297 | 298 | if attn.residual_connection: 299 | hidden_states = hidden_states + residual 300 | 301 | hidden_states = hidden_states / attn.rescale_output_factor 302 | 303 | return hidden_states 304 | 305 | class AttnProcessor2_0(torch.nn.Module): 306 | r""" 307 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
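Unlike `HairAttnProcessor2_0` above, this processor expects no extra reference tokens appended to `encoder_hidden_states`; it runs plain scaled dot-product attention. Minimal usage sketch (assuming a diffusers `UNet2DConditionModel` instance named `unet`, which is not defined in this file):

    unet.set_attn_processor(AttnProcessor2_0())

Passing a single processor instance installs it on every attention layer of the model.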
308 | """ 309 | 310 | def __init__( 311 | self, 312 | hidden_size=None, 313 | cross_attention_dim=None, 314 | ): 315 | super().__init__() 316 | if not hasattr(F, "scaled_dot_product_attention"): 317 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 318 | 319 | def __call__( 320 | self, 321 | attn, 322 | hidden_states, 323 | encoder_hidden_states=None, 324 | attention_mask=None, 325 | temb=None, 326 | ): 327 | residual = hidden_states 328 | 329 | if attn.spatial_norm is not None: 330 | hidden_states = attn.spatial_norm(hidden_states, temb) 331 | 332 | input_ndim = hidden_states.ndim 333 | 334 | if input_ndim == 4: 335 | batch_size, channel, height, width = hidden_states.shape 336 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 337 | 338 | batch_size, sequence_length, _ = ( 339 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 340 | ) 341 | 342 | if attention_mask is not None: 343 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 344 | # scaled_dot_product_attention expects attention_mask shape to be 345 | # (batch, heads, source_length, target_length) 346 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 347 | 348 | if attn.group_norm is not None: 349 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 350 | 351 | query = attn.to_q(hidden_states) 352 | 353 | if encoder_hidden_states is None: 354 | encoder_hidden_states = hidden_states 355 | elif attn.norm_cross: 356 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 357 | 358 | key = attn.to_k(encoder_hidden_states) 359 | value = attn.to_v(encoder_hidden_states) 360 | 361 | inner_dim = key.shape[-1] 362 | head_dim = inner_dim // attn.heads 363 | 364 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 365 | 366 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 367 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 368 | 369 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 370 | # TODO: add support for attn.scale when we move to Torch 2.1 371 | hidden_states = F.scaled_dot_product_attention( 372 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 373 | ) 374 | 375 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 376 | hidden_states = hidden_states.to(query.dtype) 377 | 378 | # linear proj 379 | hidden_states = attn.to_out[0](hidden_states) 380 | # dropout 381 | hidden_states = attn.to_out[1](hidden_states) 382 | 383 | if input_ndim == 4: 384 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 385 | 386 | if attn.residual_connection: 387 | hidden_states = hidden_states + residual 388 | 389 | hidden_states = hidden_states / attn.rescale_output_factor 390 | 391 | return hidden_states -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/latent_controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from packaging import version 22 | import diffusers 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | if version.parse(diffusers.__version__) < version.parse("0.28.0"): 25 | from diffusers.loaders import FromOriginalControlnetMixin 26 | else: 27 | from diffusers.loaders import FromOriginalModelMixin as FromOriginalControlnetMixin 28 | 29 | from diffusers.utils import BaseOutput, logging 30 | from diffusers.models.attention_processor import ( 31 | ADDED_KV_ATTENTION_PROCESSORS, 32 | CROSS_ATTENTION_PROCESSORS, 33 | AttentionProcessor, 34 | AttnAddedKVProcessor, 35 | AttnProcessor, 36 | ) 37 | from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps 38 | from diffusers.models.modeling_utils import ModelMixin 39 | if version.parse(diffusers.__version__) < version.parse("0.26.0"): 40 | from diffusers.models.unet_2d_blocks import ( 41 | CrossAttnDownBlock2D, 42 | DownBlock2D, 43 | UNetMidBlock2DCrossAttn, 44 | get_down_block, 45 | ) 46 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 47 | else: 48 | from diffusers.models.unets.unet_2d_blocks import ( 49 | CrossAttnDownBlock2D, 50 | DownBlock2D, 51 | UNetMidBlock2DCrossAttn, 52 | get_down_block, 53 | ) 54 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 55 | 56 | 57 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 58 | 59 | 60 | @dataclass 61 | class ControlNetOutput(BaseOutput): 62 | """ 63 | The output of [`ControlNetModel`]. 64 | 65 | Args: 66 | down_block_res_samples (`tuple[torch.Tensor]`): 67 | A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should 68 | be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be 69 | used to condition the original UNet's downsampling activations. 70 | mid_down_block_re_sample (`torch.Tensor`): 71 | The activation of the midde block (the lowest sample resolution). Each tensor should be of shape 72 | `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. 73 | Output can be used to condition the original UNet's middle block activation. 74 | """ 75 | 76 | down_block_res_samples: Tuple[torch.Tensor] 77 | mid_block_res_sample: torch.Tensor 78 | 79 | 80 | class ControlNetConditioningEmbedding(nn.Module): 81 | """ 82 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 83 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 84 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 85 | convolution size. 
We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 86 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 87 | model) to encode image-space conditions ... into feature maps ..." 88 | """ 89 | 90 | def __init__( 91 | self, 92 | conditioning_embedding_channels: int, 93 | conditioning_channels: int = 4, 94 | block_out_channels: Tuple[int] = (16, 32, 96, 256), 95 | ): 96 | super().__init__() 97 | 98 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) 99 | 100 | self.blocks = nn.ModuleList([]) 101 | 102 | for i in range(len(block_out_channels) - 1): 103 | channel_in = block_out_channels[i] 104 | channel_out = block_out_channels[i + 1] 105 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 106 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=1)) 107 | 108 | self.conv_out = zero_module( 109 | nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 110 | ) 111 | 112 | def forward(self, conditioning): 113 | embedding = self.conv_in(conditioning) 114 | embedding = F.silu(embedding) 115 | 116 | for block in self.blocks: 117 | embedding = block(embedding) 118 | embedding = F.silu(embedding) 119 | 120 | embedding = self.conv_out(embedding) 121 | 122 | return embedding 123 | 124 | 125 | 126 | class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): 127 | """ 128 | A ControlNet model. 129 | 130 | Args: 131 | in_channels (`int`, defaults to 4): 132 | The number of channels in the input sample. 133 | flip_sin_to_cos (`bool`, defaults to `True`): 134 | Whether to flip the sin to cos in the time embedding. 135 | freq_shift (`int`, defaults to 0): 136 | The frequency shift to apply to the time embedding. 137 | down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 138 | The tuple of downsample blocks to use. 139 | only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): 140 | block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): 141 | The tuple of output channels for each block. 142 | layers_per_block (`int`, defaults to 2): 143 | The number of layers per block. 144 | downsample_padding (`int`, defaults to 1): 145 | The padding to use for the downsampling convolution. 146 | mid_block_scale_factor (`float`, defaults to 1): 147 | The scale factor to use for the mid block. 148 | act_fn (`str`, defaults to "silu"): 149 | The activation function to use. 150 | norm_num_groups (`int`, *optional*, defaults to 32): 151 | The number of groups to use for the normalization. If None, normalization and activation layers is skipped 152 | in post-processing. 153 | norm_eps (`float`, defaults to 1e-5): 154 | The epsilon to use for the normalization. 155 | cross_attention_dim (`int`, defaults to 1280): 156 | The dimension of the cross attention features. 157 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): 158 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 159 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 160 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 
161 | encoder_hid_dim (`int`, *optional*, defaults to None): 162 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 163 | dimension to `cross_attention_dim`. 164 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 165 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 166 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 167 | attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): 168 | The dimension of the attention heads. 169 | use_linear_projection (`bool`, defaults to `False`): 170 | class_embed_type (`str`, *optional*, defaults to `None`): 171 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, 172 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 173 | addition_embed_type (`str`, *optional*, defaults to `None`): 174 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 175 | "text". "text" will use the `TextTimeEmbedding` layer. 176 | num_class_embeds (`int`, *optional*, defaults to 0): 177 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 178 | class conditioning with `class_embed_type` equal to `None`. 179 | upcast_attention (`bool`, defaults to `False`): 180 | resnet_time_scale_shift (`str`, defaults to `"default"`): 181 | Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. 182 | projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): 183 | The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when 184 | `class_embed_type="projection"`. 185 | controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): 186 | The channel order of conditional image. Will convert to `rgb` if it's `bgr`. 187 | conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): 188 | The tuple of output channel for each block in the `conditioning_embedding` layer. 
189 | global_pool_conditions (`bool`, defaults to `False`): 190 | """ 191 | 192 | _supports_gradient_checkpointing = True 193 | 194 | @register_to_config 195 | def __init__( 196 | self, 197 | in_channels: int = 4, 198 | conditioning_channels: int = 4, 199 | flip_sin_to_cos: bool = True, 200 | freq_shift: int = 0, 201 | down_block_types: Tuple[str] = ( 202 | "CrossAttnDownBlock2D", 203 | "CrossAttnDownBlock2D", 204 | "CrossAttnDownBlock2D", 205 | "DownBlock2D", 206 | ), 207 | only_cross_attention: Union[bool, Tuple[bool]] = False, 208 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 209 | layers_per_block: int = 2, 210 | downsample_padding: int = 1, 211 | mid_block_scale_factor: float = 1, 212 | act_fn: str = "silu", 213 | norm_num_groups: Optional[int] = 32, 214 | norm_eps: float = 1e-5, 215 | cross_attention_dim: int = 1280, 216 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 217 | encoder_hid_dim: Optional[int] = None, 218 | encoder_hid_dim_type: Optional[str] = None, 219 | attention_head_dim: Union[int, Tuple[int]] = 8, 220 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 221 | use_linear_projection: bool = False, 222 | class_embed_type: Optional[str] = None, 223 | addition_embed_type: Optional[str] = None, 224 | addition_time_embed_dim: Optional[int] = None, 225 | num_class_embeds: Optional[int] = None, 226 | upcast_attention: bool = False, 227 | resnet_time_scale_shift: str = "default", 228 | projection_class_embeddings_input_dim: Optional[int] = None, 229 | controlnet_conditioning_channel_order: str = "rgb", 230 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 231 | global_pool_conditions: bool = False, 232 | addition_embed_type_num_heads=64, 233 | ): 234 | super().__init__() 235 | 236 | # If `num_attention_heads` is not defined (which is the case for most models) 237 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 238 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 239 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 240 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 241 | # which is why we correct for the naming here. 242 | num_attention_heads = num_attention_heads or attention_head_dim 243 | 244 | # Check inputs 245 | if len(block_out_channels) != len(down_block_types): 246 | raise ValueError( 247 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 248 | ) 249 | 250 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 251 | raise ValueError( 252 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 253 | ) 254 | 255 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 256 | raise ValueError( 257 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
258 | ) 259 | 260 | if isinstance(transformer_layers_per_block, int): 261 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 262 | 263 | # input 264 | conv_in_kernel = 3 265 | conv_in_padding = (conv_in_kernel - 1) // 2 266 | self.conv_in = nn.Conv2d( 267 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 268 | ) 269 | 270 | self.conv_in_2 = nn.Conv2d( 271 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 272 | ) 273 | 274 | # input 275 | # conv_in_kernel = 3 276 | # conv_in_padding = (conv_in_kernel - 1) // 2 277 | # self.conv_in = nn.Conv2d( 278 | # in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, 279 | # padding=conv_in_padding 280 | # ) 281 | 282 | # time 283 | time_embed_dim = block_out_channels[0] * 4 284 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 285 | timestep_input_dim = block_out_channels[0] 286 | self.time_embedding = TimestepEmbedding( 287 | timestep_input_dim, 288 | time_embed_dim, 289 | act_fn=act_fn, 290 | ) 291 | 292 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 293 | encoder_hid_dim_type = "text_proj" 294 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 295 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") 296 | 297 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 298 | raise ValueError( 299 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 300 | ) 301 | 302 | if encoder_hid_dim_type == "text_proj": 303 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 304 | elif encoder_hid_dim_type == "text_image_proj": 305 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 306 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 307 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 308 | self.encoder_hid_proj = TextImageProjection( 309 | text_embed_dim=encoder_hid_dim, 310 | image_embed_dim=cross_attention_dim, 311 | cross_attention_dim=cross_attention_dim, 312 | ) 313 | 314 | elif encoder_hid_dim_type is not None: 315 | raise ValueError( 316 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 317 | ) 318 | else: 319 | self.encoder_hid_proj = None 320 | 321 | # class embedding 322 | if class_embed_type is None and num_class_embeds is not None: 323 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 324 | elif class_embed_type == "timestep": 325 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 326 | elif class_embed_type == "identity": 327 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 328 | elif class_embed_type == "projection": 329 | if projection_class_embeddings_input_dim is None: 330 | raise ValueError( 331 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 332 | ) 333 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 334 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 335 | # 2. it projects from an arbitrary input dimension. 336 | # 337 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 
338 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 339 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 340 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 341 | else: 342 | self.class_embedding = None 343 | 344 | if addition_embed_type == "text": 345 | if encoder_hid_dim is not None: 346 | text_time_embedding_from_dim = encoder_hid_dim 347 | else: 348 | text_time_embedding_from_dim = cross_attention_dim 349 | 350 | self.add_embedding = TextTimeEmbedding( 351 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads 352 | ) 353 | elif addition_embed_type == "text_image": 354 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much 355 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 356 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 357 | self.add_embedding = TextImageTimeEmbedding( 358 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim 359 | ) 360 | elif addition_embed_type == "text_time": 361 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 362 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 363 | 364 | elif addition_embed_type is not None: 365 | raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") 366 | 367 | # control net conditioning embedding 368 | # self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 369 | # conditioning_embedding_channels=block_out_channels[0], 370 | # block_out_channels=conditioning_embedding_out_channels, 371 | # conditioning_channels=conditioning_channels, 372 | # ) 373 | 374 | self.down_blocks = nn.ModuleList([]) 375 | self.controlnet_down_blocks = nn.ModuleList([]) 376 | 377 | if isinstance(only_cross_attention, bool): 378 | only_cross_attention = [only_cross_attention] * len(down_block_types) 379 | 380 | if isinstance(attention_head_dim, int): 381 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 382 | 383 | if isinstance(num_attention_heads, int): 384 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 385 | 386 | # down 387 | output_channel = block_out_channels[0] 388 | 389 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 390 | controlnet_block = zero_module(controlnet_block) 391 | self.controlnet_down_blocks.append(controlnet_block) 392 | 393 | for i, down_block_type in enumerate(down_block_types): 394 | input_channel = output_channel 395 | output_channel = block_out_channels[i] 396 | is_final_block = i == len(block_out_channels) - 1 397 | 398 | down_block = get_down_block( 399 | down_block_type, 400 | num_layers=layers_per_block, 401 | transformer_layers_per_block=transformer_layers_per_block[i], 402 | in_channels=input_channel, 403 | out_channels=output_channel, 404 | temb_channels=time_embed_dim, 405 | add_downsample=not is_final_block, 406 | resnet_eps=norm_eps, 407 | resnet_act_fn=act_fn, 408 | resnet_groups=norm_num_groups, 409 | cross_attention_dim=cross_attention_dim, 410 | num_attention_heads=num_attention_heads[i], 411 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 412 | downsample_padding=downsample_padding, 413 | 
use_linear_projection=use_linear_projection, 414 | only_cross_attention=only_cross_attention[i], 415 | upcast_attention=upcast_attention, 416 | resnet_time_scale_shift=resnet_time_scale_shift, 417 | ) 418 | self.down_blocks.append(down_block) 419 | 420 | for _ in range(layers_per_block): 421 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 422 | controlnet_block = zero_module(controlnet_block) 423 | self.controlnet_down_blocks.append(controlnet_block) 424 | 425 | if not is_final_block: 426 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 427 | controlnet_block = zero_module(controlnet_block) 428 | self.controlnet_down_blocks.append(controlnet_block) 429 | 430 | # mid 431 | mid_block_channel = block_out_channels[-1] 432 | 433 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 434 | controlnet_block = zero_module(controlnet_block) 435 | self.controlnet_mid_block = controlnet_block 436 | 437 | self.mid_block = UNetMidBlock2DCrossAttn( 438 | transformer_layers_per_block=transformer_layers_per_block[-1], 439 | in_channels=mid_block_channel, 440 | temb_channels=time_embed_dim, 441 | resnet_eps=norm_eps, 442 | resnet_act_fn=act_fn, 443 | output_scale_factor=mid_block_scale_factor, 444 | resnet_time_scale_shift=resnet_time_scale_shift, 445 | cross_attention_dim=cross_attention_dim, 446 | num_attention_heads=num_attention_heads[-1], 447 | resnet_groups=norm_num_groups, 448 | use_linear_projection=use_linear_projection, 449 | upcast_attention=upcast_attention, 450 | ) 451 | 452 | @classmethod 453 | def from_unet( 454 | cls, 455 | unet: UNet2DConditionModel, 456 | controlnet_conditioning_channel_order: str = "rgb", 457 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 458 | load_weights_from_unet: bool = True, 459 | ): 460 | r""" 461 | Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. 462 | 463 | Parameters: 464 | unet (`UNet2DConditionModel`): 465 | The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied 466 | where applicable. 
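Example (a minimal sketch; the checkpoint id is only a placeholder and is not part of this repository):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    controlnet = ControlNetModel.from_unet(unet)

With `load_weights_from_unet=True` (the default), the UNet's `conv_in` weights are copied into both `conv_in` and `conv_in_2`, and its time embedding, down blocks and mid block are copied as well; the zero-initialized `controlnet_down_blocks` and `controlnet_mid_block` projections are left untouched.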
467 | """ 468 | transformer_layers_per_block = ( 469 | unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 470 | ) 471 | encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None 472 | encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None 473 | addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None 474 | addition_time_embed_dim = ( 475 | unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None 476 | ) 477 | 478 | controlnet = cls( 479 | encoder_hid_dim=encoder_hid_dim, 480 | encoder_hid_dim_type=encoder_hid_dim_type, 481 | addition_embed_type=addition_embed_type, 482 | addition_time_embed_dim=addition_time_embed_dim, 483 | transformer_layers_per_block=transformer_layers_per_block, 484 | in_channels=unet.config.in_channels, 485 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 486 | freq_shift=unet.config.freq_shift, 487 | down_block_types=unet.config.down_block_types, 488 | only_cross_attention=unet.config.only_cross_attention, 489 | block_out_channels=unet.config.block_out_channels, 490 | layers_per_block=unet.config.layers_per_block, 491 | downsample_padding=unet.config.downsample_padding, 492 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 493 | act_fn=unet.config.act_fn, 494 | norm_num_groups=unet.config.norm_num_groups, 495 | norm_eps=unet.config.norm_eps, 496 | cross_attention_dim=unet.config.cross_attention_dim, 497 | attention_head_dim=unet.config.attention_head_dim, 498 | num_attention_heads=unet.config.num_attention_heads, 499 | use_linear_projection=unet.config.use_linear_projection, 500 | class_embed_type=unet.config.class_embed_type, 501 | num_class_embeds=unet.config.num_class_embeds, 502 | upcast_attention=unet.config.upcast_attention, 503 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 504 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 505 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 506 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 507 | ) 508 | 509 | if load_weights_from_unet: 510 | # conv_in_condition_weight = torch.zeros_like(controlnet.conv_in.weight) 511 | # conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight 512 | # conv_in_condition_weight[:, 4:8, ...] 
= unet.conv_in.weight 513 | # controlnet.conv_in.weight = torch.nn.Parameter(conv_in_condition_weight) 514 | # controlnet.conv_in.bias = unet.conv_in.bias 515 | 516 | controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) 517 | controlnet.conv_in_2.load_state_dict(unet.conv_in.state_dict()) 518 | 519 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 520 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 521 | 522 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 523 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 524 | 525 | 526 | if controlnet.class_embedding: 527 | controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) 528 | 529 | controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) 530 | controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) 531 | 532 | return controlnet 533 | 534 | @property 535 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 536 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 537 | r""" 538 | Returns: 539 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 540 | indexed by its weight name. 541 | """ 542 | # set recursively 543 | processors = {} 544 | 545 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 546 | if hasattr(module, "get_processor"): 547 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 548 | 549 | for sub_name, child in module.named_children(): 550 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 551 | 552 | return processors 553 | 554 | for name, module in self.named_children(): 555 | fn_recursive_add_processors(name, module, processors) 556 | 557 | return processors 558 | 559 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 560 | def set_attn_processor( 561 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False 562 | ): 563 | r""" 564 | Sets the attention processor to use to compute attention. 565 | 566 | Parameters: 567 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 568 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 569 | for **all** `Attention` layers. 570 | 571 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 572 | processor. This is strongly recommended when setting trainable attention processors. 573 | 574 | """ 575 | count = len(self.attn_processors.keys()) 576 | 577 | if isinstance(processor, dict) and len(processor) != count: 578 | raise ValueError( 579 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 580 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
581 | ) 582 | 583 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 584 | if hasattr(module, "set_processor"): 585 | if not isinstance(processor, dict): 586 | module.set_processor(processor, _remove_lora=_remove_lora) 587 | else: 588 | module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) 589 | 590 | for sub_name, child in module.named_children(): 591 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 592 | 593 | for name, module in self.named_children(): 594 | fn_recursive_attn_processor(name, module, processor) 595 | 596 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 597 | def set_default_attn_processor(self): 598 | """ 599 | Disables custom attention processors and sets the default attention implementation. 600 | """ 601 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 602 | processor = AttnAddedKVProcessor() 603 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 604 | processor = AttnProcessor() 605 | else: 606 | raise ValueError( 607 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 608 | ) 609 | 610 | self.set_attn_processor(processor, _remove_lora=True) 611 | 612 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 613 | def set_attention_slice(self, slice_size): 614 | r""" 615 | Enable sliced attention computation. 616 | 617 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 618 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 619 | 620 | Args: 621 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 622 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 623 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 624 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 625 | must be a multiple of `slice_size`. 626 | """ 627 | sliceable_head_dims = [] 628 | 629 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 630 | if hasattr(module, "set_attention_slice"): 631 | sliceable_head_dims.append(module.sliceable_head_dim) 632 | 633 | for child in module.children(): 634 | fn_recursive_retrieve_sliceable_dims(child) 635 | 636 | # retrieve number of attention layers 637 | for module in self.children(): 638 | fn_recursive_retrieve_sliceable_dims(module) 639 | 640 | num_sliceable_layers = len(sliceable_head_dims) 641 | 642 | if slice_size == "auto": 643 | # half the attention head size is usually a good trade-off between 644 | # speed and memory 645 | slice_size = [dim // 2 for dim in sliceable_head_dims] 646 | elif slice_size == "max": 647 | # make smallest slice possible 648 | slice_size = num_sliceable_layers * [1] 649 | 650 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 651 | 652 | if len(slice_size) != len(sliceable_head_dims): 653 | raise ValueError( 654 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 655 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
656 | ) 657 | 658 | for i in range(len(slice_size)): 659 | size = slice_size[i] 660 | dim = sliceable_head_dims[i] 661 | if size is not None and size > dim: 662 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 663 | 664 | # Recursively walk through all the children. 665 | # Any children which exposes the set_attention_slice method 666 | # gets the message 667 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 668 | if hasattr(module, "set_attention_slice"): 669 | module.set_attention_slice(slice_size.pop()) 670 | 671 | for child in module.children(): 672 | fn_recursive_set_attention_slice(child, slice_size) 673 | 674 | reversed_slice_size = list(reversed(slice_size)) 675 | for module in self.children(): 676 | fn_recursive_set_attention_slice(module, reversed_slice_size) 677 | 678 | def _set_gradient_checkpointing(self, module, value=False): 679 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 680 | module.gradient_checkpointing = value 681 | 682 | def forward( 683 | self, 684 | sample: torch.FloatTensor, 685 | timestep: Union[torch.Tensor, float, int], 686 | encoder_hidden_states: torch.Tensor, 687 | controlnet_cond: torch.FloatTensor, 688 | conditioning_scale: float = 1.0, 689 | class_labels: Optional[torch.Tensor] = None, 690 | timestep_cond: Optional[torch.Tensor] = None, 691 | attention_mask: Optional[torch.Tensor] = None, 692 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 693 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 694 | guess_mode: bool = False, 695 | return_dict: bool = True, 696 | ) -> Union[ControlNetOutput, Tuple]: 697 | """ 698 | The [`ControlNetModel`] forward method. 699 | 700 | Args: 701 | sample (`torch.FloatTensor`): 702 | The noisy input tensor. 703 | timestep (`Union[torch.Tensor, float, int]`): 704 | The number of timesteps to denoise an input. 705 | encoder_hidden_states (`torch.Tensor`): 706 | The encoder hidden states. 707 | controlnet_cond (`torch.FloatTensor`): 708 | The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. 709 | conditioning_scale (`float`, defaults to `1.0`): 710 | The scale factor for ControlNet outputs. 711 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 712 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 713 | timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): 714 | Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the 715 | timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep 716 | embeddings. 717 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 718 | An attention face_hair_mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the face_hair_mask 719 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 720 | negative values to the attention scores corresponding to "discard" tokens. 721 | added_cond_kwargs (`dict`): 722 | Additional conditions for the Stable Diffusion XL UNet. 723 | cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): 724 | A kwargs dictionary that if specified is passed along to the `AttnProcessor`. 725 | guess_mode (`bool`, defaults to `False`): 726 | In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if 727 | you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended. 728 | return_dict (`bool`, defaults to `True`): 729 | Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. 730 | 731 | Returns: 732 | [`~models.controlnet.ControlNetOutput`] **or** `tuple`: 733 | If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is 734 | returned where the first element is the sample tensor. 735 | """ 736 | # check channel order 737 | channel_order = self.config.controlnet_conditioning_channel_order 738 | 739 | if channel_order == "rgb": 740 | # in rgb order by default 741 | ... 742 | elif channel_order == "bgr": 743 | controlnet_cond = torch.flip(controlnet_cond, dims=[1]) 744 | else: 745 | raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") 746 | 747 | # prepare attention_mask 748 | if attention_mask is not None: 749 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 750 | attention_mask = attention_mask.unsqueeze(1) 751 | 752 | # 1. time 753 | timesteps = timestep 754 | if not torch.is_tensor(timesteps): 755 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 756 | # This would be a good case for the `match` statement (Python 3.10+) 757 | is_mps = sample.device.type == "mps" 758 | if isinstance(timestep, float): 759 | dtype = torch.float32 if is_mps else torch.float64 760 | else: 761 | dtype = torch.int32 if is_mps else torch.int64 762 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 763 | elif len(timesteps.shape) == 0: 764 | timesteps = timesteps[None].to(sample.device) 765 | 766 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 767 | timesteps = timesteps.expand(sample.shape[0]) 768 | 769 | t_emb = self.time_proj(timesteps) 770 | 771 | # timesteps does not contain any weights and will always return f32 tensors 772 | # but time_embedding might actually be running in fp16. so we need to cast here. 773 | # there might be better ways to encapsulate this. 
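# (Descriptive note: with the default config, `self.time_proj` produces 320-dim
# float32 sinusoidal features for the `(batch,)` timestep tensor, and after the
# cast below `self.time_embedding` projects them to `time_embed_dim = 4 * 320 = 1280`.)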
774 | t_emb = t_emb.to(dtype=sample.dtype) 775 | 776 | emb = self.time_embedding(t_emb, timestep_cond) 777 | aug_emb = None 778 | 779 | if self.class_embedding is not None: 780 | if class_labels is None: 781 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 782 | 783 | if self.config.class_embed_type == "timestep": 784 | class_labels = self.time_proj(class_labels) 785 | 786 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 787 | emb = emb + class_emb 788 | 789 | if self.config.addition_embed_type is not None: 790 | if self.config.addition_embed_type == "text": 791 | aug_emb = self.add_embedding(encoder_hidden_states) 792 | 793 | elif self.config.addition_embed_type == "text_time": 794 | if "text_embeds" not in added_cond_kwargs: 795 | raise ValueError( 796 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 797 | ) 798 | text_embeds = added_cond_kwargs.get("text_embeds") 799 | if "time_ids" not in added_cond_kwargs: 800 | raise ValueError( 801 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 802 | ) 803 | time_ids = added_cond_kwargs.get("time_ids") 804 | time_embeds = self.add_time_proj(time_ids.flatten()) 805 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 806 | 807 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 808 | add_embeds = add_embeds.to(emb.dtype) 809 | aug_emb = self.add_embedding(add_embeds) 810 | 811 | emb = emb + aug_emb if aug_emb is not None else emb 812 | 813 | # 2. pre-process 814 | 815 | ## v1 816 | # controlnet_cond = torch.concat([sample, controlnet_cond], 1) 817 | # sample = self.conv_in(controlnet_cond) 818 | 819 | ## v2 820 | sample = self.conv_in(sample) 821 | controlnet_cond = self.conv_in_2(controlnet_cond) 822 | sample = sample + controlnet_cond 823 | 824 | # 3. down 825 | down_block_res_samples = (sample,) 826 | for downsample_block in self.down_blocks: 827 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 828 | sample, res_samples = downsample_block( 829 | hidden_states=sample, 830 | temb=emb, 831 | encoder_hidden_states=encoder_hidden_states, 832 | attention_mask=attention_mask, 833 | cross_attention_kwargs=cross_attention_kwargs, 834 | ) 835 | else: 836 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 837 | 838 | down_block_res_samples += res_samples 839 | 840 | # 4. mid 841 | if self.mid_block is not None: 842 | sample = self.mid_block( 843 | sample, 844 | emb, 845 | encoder_hidden_states=encoder_hidden_states, 846 | attention_mask=attention_mask, 847 | cross_attention_kwargs=cross_attention_kwargs, 848 | ) 849 | 850 | # 5. Control net blocks 851 | 852 | controlnet_down_block_res_samples = () 853 | 854 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 855 | down_block_res_sample = controlnet_block(down_block_res_sample) 856 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 857 | 858 | down_block_res_samples = controlnet_down_block_res_samples 859 | 860 | mid_block_res_sample = self.controlnet_mid_block(sample) 861 | 862 | # 6. 
scaling 863 | if guess_mode and not self.config.global_pool_conditions: 864 | scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 865 | scales = scales * conditioning_scale 866 | down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] 867 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 868 | else: 869 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 870 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 871 | 872 | if self.config.global_pool_conditions: 873 | down_block_res_samples = [ 874 | torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples 875 | ] 876 | mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) 877 | 878 | if not return_dict: 879 | return (down_block_res_samples, mid_block_res_sample) 880 | 881 | return ControlNetOutput( 882 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 883 | ) 884 | 885 | 886 | def zero_module(module): 887 | for p in module.parameters(): 888 | nn.init.zeros_(p) 889 | return module 890 | -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/reference_control.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | from diffusers.models.attention import BasicTransformerBlock 4 | from packaging import version 5 | import diffusers 6 | 7 | if version.parse(diffusers.__version__) < version.parse("0.26.0"): 8 | from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D 9 | else: 10 | from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D 11 | 12 | def torch_dfs(model: torch.nn.Module): 13 | result = [model] 14 | for child in model.children(): 15 | result += torch_dfs(child) 16 | return result 17 | 18 | class ReferenceAttentionControl(): 19 | 20 | def __init__(self, 21 | unet, 22 | mode="write", 23 | do_classifier_free_guidance=False, 24 | attention_auto_machine_weight=float('inf'), 25 | gn_auto_machine_weight=1.0, 26 | style_fidelity=1.0, 27 | reference_attn=True, 28 | reference_adain=False, 29 | fusion_blocks="full", 30 | batch_size=1, 31 | ) -> None: 32 | # 10. 
Modify self attention and group norm 33 | self.unet = unet 34 | assert mode in ["read", "write"] 35 | assert fusion_blocks in ["midup", "full"] 36 | self.reference_attn = reference_attn 37 | self.reference_adain = reference_adain 38 | self.fusion_blocks = fusion_blocks 39 | self.register_reference_hooks( 40 | mode, 41 | do_classifier_free_guidance, 42 | attention_auto_machine_weight, 43 | gn_auto_machine_weight, 44 | style_fidelity, 45 | reference_attn, 46 | reference_adain, 47 | fusion_blocks, 48 | batch_size=batch_size, 49 | ) 50 | 51 | def register_reference_hooks( 52 | self, 53 | mode, 54 | do_classifier_free_guidance, 55 | attention_auto_machine_weight, 56 | gn_auto_machine_weight, 57 | style_fidelity, 58 | reference_attn, 59 | reference_adain, 60 | dtype=torch.float16, 61 | batch_size=1, 62 | num_images_per_prompt=1, 63 | device=torch.device("cpu"), 64 | fusion_blocks='midup', 65 | ): 66 | MODE = mode 67 | do_classifier_free_guidance = do_classifier_free_guidance 68 | attention_auto_machine_weight = attention_auto_machine_weight 69 | gn_auto_machine_weight = gn_auto_machine_weight 70 | style_fidelity = style_fidelity 71 | reference_attn = reference_attn 72 | reference_adain = reference_adain 73 | fusion_blocks = fusion_blocks 74 | num_images_per_prompt = num_images_per_prompt 75 | dtype = dtype 76 | if do_classifier_free_guidance: 77 | uc_mask = ( 78 | torch.Tensor( 79 | [1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16) 80 | .to(device) 81 | .bool() 82 | ) 83 | else: 84 | uc_mask = ( 85 | torch.Tensor([0] * batch_size * num_images_per_prompt * 2) 86 | .to(device) 87 | .bool() 88 | ) 89 | 90 | def hacked_basic_transformer_inner_forward( 91 | self, 92 | hidden_states: torch.FloatTensor, 93 | attention_mask: Optional[torch.FloatTensor] = None, 94 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 95 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 96 | timestep: Optional[torch.LongTensor] = None, 97 | cross_attention_kwargs: Dict[str, Any] = None, 98 | class_labels: Optional[torch.LongTensor] = None, 99 | ): 100 | if self.use_ada_layer_norm: 101 | norm_hidden_states = self.norm1(hidden_states, timestep) 102 | elif self.use_ada_layer_norm_zero: 103 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 104 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 105 | ) 106 | else: 107 | norm_hidden_states = self.norm1(hidden_states) 108 | 109 | # 1. 
Self-Attention 110 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 111 | if self.only_cross_attention: 112 | attn_output = self.attn1( 113 | norm_hidden_states, 114 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 115 | attention_mask=attention_mask, 116 | **cross_attention_kwargs, 117 | ) 118 | else: 119 | if MODE == "write": 120 | self.bank.append(norm_hidden_states.clone()) 121 | attn_output = self.attn1( 122 | norm_hidden_states, 123 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 124 | attention_mask=attention_mask, 125 | **cross_attention_kwargs, 126 | ) 127 | if MODE == "read": 128 | hidden_states_uc = self.attn1(norm_hidden_states, 129 | encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, 130 | dim=1), 131 | attention_mask=attention_mask) + hidden_states 132 | hidden_states_c = hidden_states_uc.clone() 133 | _uc_mask = uc_mask.clone() 134 | if do_classifier_free_guidance: 135 | if hidden_states.shape[0] != _uc_mask.shape[0]: 136 | _uc_mask = ( 137 | torch.Tensor([1] * (hidden_states.shape[0] // 2) + [0] * (hidden_states.shape[0] // 2)) 138 | .to(device) 139 | .bool() 140 | ) 141 | 142 | hidden_states_c[_uc_mask] = self.attn1( 143 | norm_hidden_states[_uc_mask], 144 | encoder_hidden_states=norm_hidden_states[_uc_mask], 145 | attention_mask=attention_mask, 146 | ) + hidden_states[_uc_mask] 147 | hidden_states = hidden_states_c.clone() 148 | 149 | self.bank.clear() 150 | if self.attn2 is not None: 151 | # Cross-Attention 152 | norm_hidden_states = ( 153 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2( 154 | hidden_states) 155 | ) 156 | hidden_states = ( 157 | self.attn2( 158 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, 159 | attention_mask=attention_mask 160 | ) 161 | + hidden_states 162 | ) 163 | 164 | # Feed-forward 165 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 166 | 167 | return hidden_states 168 | 169 | if self.use_ada_layer_norm_zero: 170 | attn_output = gate_msa.unsqueeze(1) * attn_output 171 | hidden_states = attn_output + hidden_states 172 | 173 | if self.attn2 is not None: 174 | norm_hidden_states = ( 175 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 176 | ) 177 | 178 | # 2. Cross-Attention 179 | attn_output = self.attn2( 180 | norm_hidden_states, 181 | encoder_hidden_states=encoder_hidden_states, 182 | attention_mask=encoder_attention_mask, 183 | **cross_attention_kwargs, 184 | ) 185 | hidden_states = attn_output + hidden_states 186 | 187 | # 3. 
Feed-forward 188 | norm_hidden_states = self.norm3(hidden_states) 189 | 190 | if self.use_ada_layer_norm_zero: 191 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 192 | 193 | ff_output = self.ff(norm_hidden_states) 194 | 195 | if self.use_ada_layer_norm_zero: 196 | ff_output = gate_mlp.unsqueeze(1) * ff_output 197 | 198 | hidden_states = ff_output + hidden_states 199 | 200 | return hidden_states 201 | 202 | def hacked_mid_forward(self, *args, **kwargs): 203 | eps = 1e-6 204 | x = self.original_forward(*args, **kwargs) 205 | if MODE == "write": 206 | if gn_auto_machine_weight >= self.gn_weight: 207 | var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) 208 | self.mean_bank.append(mean) 209 | self.var_bank.append(var) 210 | if MODE == "read": 211 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 212 | var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) 213 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 214 | mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) 215 | var_acc = sum(self.var_bank) / float(len(self.var_bank)) 216 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 217 | x_uc = (((x - mean) / std) * std_acc) + mean_acc 218 | x_c = x_uc.clone() 219 | if do_classifier_free_guidance and style_fidelity > 0: 220 | x_c[uc_mask] = x[uc_mask] 221 | x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc 222 | self.mean_bank = [] 223 | self.var_bank = [] 224 | return x 225 | 226 | def hack_CrossAttnDownBlock2D_forward( 227 | self, 228 | hidden_states: torch.FloatTensor, 229 | temb: Optional[torch.FloatTensor] = None, 230 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 231 | attention_mask: Optional[torch.FloatTensor] = None, 232 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 233 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 234 | ): 235 | eps = 1e-6 236 | 237 | # TODO(Patrick, William) - attention face_hair_mask is not used 238 | output_states = () 239 | 240 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 241 | hidden_states = resnet(hidden_states, temb) 242 | hidden_states = attn( 243 | hidden_states, 244 | encoder_hidden_states=encoder_hidden_states, 245 | cross_attention_kwargs=cross_attention_kwargs, 246 | attention_mask=attention_mask, 247 | encoder_attention_mask=encoder_attention_mask, 248 | return_dict=False, 249 | )[0] 250 | if MODE == "write": 251 | if gn_auto_machine_weight >= self.gn_weight: 252 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 253 | self.mean_bank.append([mean]) 254 | self.var_bank.append([var]) 255 | if MODE == "read": 256 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 257 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 258 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 259 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 260 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 261 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 262 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 263 | hidden_states_c = hidden_states_uc.clone() 264 | if do_classifier_free_guidance and style_fidelity > 0: 265 | hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) 266 | hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc 267 | 268 | 
output_states = output_states + (hidden_states,) 269 | 270 | if MODE == "read": 271 | self.mean_bank = [] 272 | self.var_bank = [] 273 | 274 | if self.downsamplers is not None: 275 | for downsampler in self.downsamplers: 276 | hidden_states = downsampler(hidden_states) 277 | 278 | output_states = output_states + (hidden_states,) 279 | 280 | return hidden_states, output_states 281 | 282 | def hacked_DownBlock2D_forward(self, hidden_states, temb=None): 283 | eps = 1e-6 284 | 285 | output_states = () 286 | 287 | for i, resnet in enumerate(self.resnets): 288 | hidden_states = resnet(hidden_states, temb) 289 | 290 | if MODE == "write": 291 | if gn_auto_machine_weight >= self.gn_weight: 292 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 293 | self.mean_bank.append([mean]) 294 | self.var_bank.append([var]) 295 | if MODE == "read": 296 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 297 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 298 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 299 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 300 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 301 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 302 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 303 | hidden_states_c = hidden_states_uc.clone() 304 | if do_classifier_free_guidance and style_fidelity > 0: 305 | hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) 306 | hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc 307 | 308 | output_states = output_states + (hidden_states,) 309 | 310 | if MODE == "read": 311 | self.mean_bank = [] 312 | self.var_bank = [] 313 | 314 | if self.downsamplers is not None: 315 | for downsampler in self.downsamplers: 316 | hidden_states = downsampler(hidden_states) 317 | 318 | output_states = output_states + (hidden_states,) 319 | 320 | return hidden_states, output_states 321 | 322 | def hacked_CrossAttnUpBlock2D_forward( 323 | self, 324 | hidden_states: torch.FloatTensor, 325 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 326 | temb: Optional[torch.FloatTensor] = None, 327 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 328 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 329 | upsample_size: Optional[int] = None, 330 | attention_mask: Optional[torch.FloatTensor] = None, 331 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 332 | ): 333 | eps = 1e-6 334 | # TODO(Patrick, William) - attention face_hair_mask is not used 335 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 336 | # pop res hidden states 337 | res_hidden_states = res_hidden_states_tuple[-1] 338 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 339 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 340 | hidden_states = resnet(hidden_states, temb) 341 | hidden_states = attn( 342 | hidden_states, 343 | encoder_hidden_states=encoder_hidden_states, 344 | cross_attention_kwargs=cross_attention_kwargs, 345 | attention_mask=attention_mask, 346 | encoder_attention_mask=encoder_attention_mask, 347 | return_dict=False, 348 | )[0] 349 | 350 | if MODE == "write": 351 | if gn_auto_machine_weight >= self.gn_weight: 352 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 353 | self.mean_bank.append([mean]) 354 | self.var_bank.append([var]) 355 | if 
MODE == "read": 356 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 357 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 358 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 359 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 360 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 361 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 362 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 363 | hidden_states_c = hidden_states_uc.clone() 364 | if do_classifier_free_guidance and style_fidelity > 0: 365 | hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) 366 | hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc 367 | 368 | if MODE == "read": 369 | self.mean_bank = [] 370 | self.var_bank = [] 371 | 372 | if self.upsamplers is not None: 373 | for upsampler in self.upsamplers: 374 | hidden_states = upsampler(hidden_states, upsample_size) 375 | 376 | return hidden_states 377 | 378 | def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 379 | eps = 1e-6 380 | for i, resnet in enumerate(self.resnets): 381 | # pop res hidden states 382 | res_hidden_states = res_hidden_states_tuple[-1] 383 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 384 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 385 | hidden_states = resnet(hidden_states, temb) 386 | 387 | if MODE == "write": 388 | if gn_auto_machine_weight >= self.gn_weight: 389 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 390 | self.mean_bank.append([mean]) 391 | self.var_bank.append([var]) 392 | if MODE == "read": 393 | if len(self.mean_bank) > 0 and len(self.var_bank) > 0: 394 | var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) 395 | std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 396 | mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) 397 | var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) 398 | std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 399 | hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc 400 | hidden_states_c = hidden_states_uc.clone() 401 | if do_classifier_free_guidance and style_fidelity > 0: 402 | hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype) 403 | hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc 404 | 405 | if MODE == "read": 406 | self.mean_bank = [] 407 | self.var_bank = [] 408 | 409 | if self.upsamplers is not None: 410 | for upsampler in self.upsamplers: 411 | hidden_states = upsampler(hidden_states, upsample_size) 412 | 413 | return hidden_states 414 | 415 | if self.reference_attn: 416 | if self.fusion_blocks == "midup": 417 | attn_modules = [module for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) 418 | if isinstance(module, BasicTransformerBlock)] 419 | elif self.fusion_blocks == "full": 420 | attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] 421 | attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) 422 | 423 | for i, module in enumerate(attn_modules): 424 | module._original_inner_forward = module.forward 425 | module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) 
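            # The patched forward above implements reference-only self-attention: in "write" mode each
            # hooked BasicTransformerBlock appends its normalized hidden states to `module.bank`; in
            # "read" mode self-attention is re-run with keys/values drawn from
            # torch.cat([norm_hidden_states] + self.bank, dim=1), so the denoising UNet attends to
            # features captured from the reference UNet. A rough usage sketch follows; the constructor
            # argument names are assumed here and are not shown in this excerpt:
            #
            #   writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
            #   reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
            #   reference_unet(ref_latents, t, encoder_hidden_states=prompt_embeds)  # fills the banks
            #   reader.update(writer)   # copy the writer's banks into the reader's hooked modules
            #   denoising_unet(latents, t, encoder_hidden_states=prompt_embeds)      # reads the banks
            #   reader.clear()          # drop any remaining banks before the next step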
426 | module.bank = [] 427 | module.attn_weight = float(i) / float(len(attn_modules)) 428 | 429 | if self.reference_adain: 430 | gn_modules = [self.unet.mid_block] 431 | self.unet.mid_block.gn_weight = 0 432 | 433 | down_blocks = self.unet.down_blocks 434 | for w, module in enumerate(down_blocks): 435 | module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) 436 | gn_modules.append(module) 437 | 438 | up_blocks = self.unet.up_blocks 439 | for w, module in enumerate(up_blocks): 440 | module.gn_weight = float(w) / float(len(up_blocks)) 441 | gn_modules.append(module) 442 | 443 | for i, module in enumerate(gn_modules): 444 | if getattr(module, "original_forward", None) is None: 445 | module.original_forward = module.forward 446 | if i == 0: 447 | # mid_block 448 | module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) 449 | elif isinstance(module, CrossAttnDownBlock2D): 450 | module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) 451 | elif isinstance(module, DownBlock2D): 452 | module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) 453 | elif isinstance(module, CrossAttnUpBlock2D): 454 | module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) 455 | elif isinstance(module, UpBlock2D): 456 | module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) 457 | module.mean_bank = [] 458 | module.var_bank = [] 459 | module.gn_weight *= 2 460 | 461 | def update(self, writer, dtype=torch.float16): 462 | if self.reference_attn: 463 | if self.fusion_blocks == "midup": 464 | reader_attn_modules = [module for module in 465 | (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if 466 | isinstance(module, BasicTransformerBlock)] 467 | writer_attn_modules = [module for module in 468 | (torch_dfs(writer.unet.mid_block) + torch_dfs(writer.unet.up_blocks)) if 469 | isinstance(module, BasicTransformerBlock)] 470 | elif self.fusion_blocks == "full": 471 | reader_attn_modules = [module for module in torch_dfs(self.unet) if 472 | isinstance(module, BasicTransformerBlock)] 473 | writer_attn_modules = [module for module in torch_dfs(writer.unet) if 474 | isinstance(module, BasicTransformerBlock)] 475 | reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) 476 | writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) 477 | for r, w in zip(reader_attn_modules, writer_attn_modules): 478 | r.bank = [v.clone().to(dtype) for v in w.bank] 479 | 480 | if self.reference_adain: 481 | reader_gn_modules = [self.unet.mid_block] 482 | 483 | down_blocks = self.unet.down_blocks 484 | for w, module in enumerate(down_blocks): 485 | reader_gn_modules.append(module) 486 | 487 | up_blocks = self.unet.up_blocks 488 | for w, module in enumerate(up_blocks): 489 | reader_gn_modules.append(module) 490 | 491 | writer_gn_modules = [writer.unet.mid_block] 492 | 493 | down_blocks = writer.unet.down_blocks 494 | for w, module in enumerate(down_blocks): 495 | writer_gn_modules.append(module) 496 | 497 | up_blocks = writer.unet.up_blocks 498 | for w, module in enumerate(up_blocks): 499 | writer_gn_modules.append(module) 500 | 501 | for r, w in zip(reader_gn_modules, writer_gn_modules): 502 | if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list): 503 | r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank] 504 | r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank] 505 | else: 506 | r.mean_bank = 
[v.clone().to(dtype) for v in w.mean_bank] 507 | r.var_bank = [v.clone().to(dtype) for v in w.var_bank] 508 | 509 | def clear(self): 510 | if self.reference_attn: 511 | if self.fusion_blocks == "midup": 512 | reader_attn_modules = [module for module in 513 | (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if 514 | isinstance(module, BasicTransformerBlock)] 515 | elif self.fusion_blocks == "full": 516 | reader_attn_modules = [module for module in torch_dfs(self.unet) if 517 | isinstance(module, BasicTransformerBlock)] 518 | reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) 519 | for r in reader_attn_modules: 520 | r.bank.clear() 521 | if self.reference_adain: 522 | reader_gn_modules = [self.unet.mid_block] 523 | 524 | down_blocks = self.unet.down_blocks 525 | for w, module in enumerate(down_blocks): 526 | reader_gn_modules.append(module) 527 | 528 | up_blocks = self.unet.up_blocks 529 | for w, module in enumerate(up_blocks): 530 | reader_gn_modules.append(module) 531 | 532 | for r in reader_gn_modules: 533 | r.mean_bank.clear() 534 | r.var_bank.clear() 535 | -------------------------------------------------------------------------------- /nodes/libs/ref_encoder/reference_unet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.checkpoint 7 | from packaging import version 8 | import diffusers 9 | from diffusers.configuration_utils import ConfigMixin, register_to_config 10 | from diffusers.loaders import UNet2DConditionLoadersMixin 11 | from diffusers.utils import BaseOutput, logging 12 | from diffusers.models.activations import get_activation 13 | from diffusers.models.attention_processor import ( 14 | ADDED_KV_ATTENTION_PROCESSORS, 15 | CROSS_ATTENTION_PROCESSORS, 16 | AttentionProcessor, 17 | AttnAddedKVProcessor, 18 | AttnProcessor, 19 | ) 20 | from diffusers.models.lora import LoRALinearLayer 21 | from diffusers.models.embeddings import ( 22 | GaussianFourierProjection, 23 | ImageHintTimeEmbedding, 24 | ImageProjection, 25 | ImageTimeEmbedding, 26 | # PositionNet, 27 | TextImageProjection, 28 | TextImageTimeEmbedding, 29 | TextTimeEmbedding, 30 | TimestepEmbedding, 31 | Timesteps, 32 | ) 33 | from diffusers.models.modeling_utils import ModelMixin 34 | if version.parse(diffusers.__version__) < version.parse("0.26.0"): 35 | from diffusers.models.embeddings import PositionNet 36 | from diffusers.models.unet_2d_blocks import ( 37 | UNetMidBlock2DCrossAttn, 38 | UNetMidBlock2DSimpleCrossAttn, 39 | get_down_block, 40 | get_up_block, 41 | ) 42 | else: 43 | from diffusers.models.unets.unet_2d_blocks import ( 44 | UNetMidBlock2DCrossAttn, 45 | UNetMidBlock2DSimpleCrossAttn, 46 | get_down_block, 47 | get_up_block, 48 | ) 49 | from diffusers.models.embeddings import GLIGENTextBoundingboxProjection as PositionNet 50 | 51 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 52 | 53 | 54 | class Identity(torch.nn.Module): 55 | def __init__(self, scale=None, *args, **kwargs) -> None: 56 | super(Identity, self).__init__() 57 | def forward(self, input, *args, **kwargs): 58 | return input 59 | 60 | 61 | class _LoRACompatibleLinear(nn.Module): 62 | """ 63 | A Linear layer that can be used with LoRA. 
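    In this reference UNet it is effectively a stub: its ``forward`` returns ``hidden_states``
    unchanged, and instances of it replace the q/k/v projections of the final up-block's last
    transformer block at the end of ``__init__`` below.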
64 | """ 65 | 66 | def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): 67 | super().__init__(*args, **kwargs) 68 | self.lora_layer = lora_layer 69 | 70 | def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): 71 | self.lora_layer = lora_layer 72 | 73 | def _fuse_lora(self): 74 | pass 75 | 76 | def _unfuse_lora(self): 77 | pass 78 | 79 | def forward(self, hidden_states, scale=None, lora_scale: int = 1): 80 | return hidden_states 81 | 82 | 83 | @dataclass 84 | class UNet2DConditionOutput(BaseOutput): 85 | """ 86 | The output of [`UNet2DConditionModel`]. 87 | 88 | Args: 89 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 90 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 91 | """ 92 | 93 | sample: torch.FloatTensor = None 94 | 95 | 96 | class RefHairUnet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 97 | r""" 98 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample 99 | shaped output. 100 | 101 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 102 | for all models (such as downloading or saving). 103 | 104 | Parameters: 105 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 106 | Height and width of input/output sample. 107 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. 108 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 109 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 110 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 111 | Whether to flip the sin to cos in the time embedding. 112 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 113 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 114 | The tuple of downsample blocks to use. 115 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 116 | Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or 117 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 118 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): 119 | The tuple of upsample blocks to use. 120 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 121 | Whether to include self-attention in the basic transformer blocks, see 122 | [`~models.attention.BasicTransformerBlock`]. 123 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 124 | The tuple of output channels for each block. 125 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 126 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 127 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 128 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 129 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 130 | If `None`, normalization and activation layers is skipped in post-processing. 
131 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 132 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 133 | The dimension of the cross attention features. 134 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): 135 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 136 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 137 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 138 | encoder_hid_dim (`int`, *optional*, defaults to None): 139 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 140 | dimension to `cross_attention_dim`. 141 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 142 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 143 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 144 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 145 | num_attention_heads (`int`, *optional*): 146 | The number of attention heads. If not defined, defaults to `attention_head_dim` 147 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 148 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 149 | class_embed_type (`str`, *optional*, defaults to `None`): 150 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 151 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 152 | addition_embed_type (`str`, *optional*, defaults to `None`): 153 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 154 | "text". "text" will use the `TextTimeEmbedding` layer. 155 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`): 156 | Dimension for the timestep embeddings. 157 | num_class_embeds (`int`, *optional*, defaults to `None`): 158 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 159 | class conditioning with `class_embed_type` equal to `None`. 160 | time_embedding_type (`str`, *optional*, defaults to `positional`): 161 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 162 | time_embedding_dim (`int`, *optional*, defaults to `None`): 163 | An optional override for the dimension of the projected time embedding. 164 | time_embedding_act_fn (`str`, *optional*, defaults to `None`): 165 | Optional activation function to use only once on the time embeddings before they are passed to the rest of 166 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. 167 | timestep_post_act (`str`, *optional*, defaults to `None`): 168 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 169 | time_cond_proj_dim (`int`, *optional*, defaults to `None`): 170 | The dimension of `cond_proj` layer in the timestep embedding. 171 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 172 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 173 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when 174 | `class_embed_type="projection"`. 
Required when `class_embed_type="projection"`. 175 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time 176 | embeddings with the class embeddings. 177 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): 178 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If 179 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the 180 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` 181 | otherwise. 182 | """ 183 | 184 | _supports_gradient_checkpointing = True 185 | 186 | @register_to_config 187 | def __init__( 188 | self, 189 | sample_size: Optional[int] = None, 190 | in_channels: int = 4, 191 | out_channels: int = 4, 192 | center_input_sample: bool = False, 193 | flip_sin_to_cos: bool = True, 194 | freq_shift: int = 0, 195 | down_block_types: Tuple[str] = ( 196 | "CrossAttnDownBlock2D", 197 | "CrossAttnDownBlock2D", 198 | "CrossAttnDownBlock2D", 199 | "DownBlock2D", 200 | ), 201 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 202 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 203 | only_cross_attention: Union[bool, Tuple[bool]] = False, 204 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 205 | layers_per_block: Union[int, Tuple[int]] = 2, 206 | downsample_padding: int = 1, 207 | mid_block_scale_factor: float = 1, 208 | act_fn: str = "silu", 209 | norm_num_groups: Optional[int] = 32, 210 | norm_eps: float = 1e-5, 211 | cross_attention_dim: Union[int, Tuple[int]] = 1280, 212 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 213 | encoder_hid_dim: Optional[int] = None, 214 | encoder_hid_dim_type: Optional[str] = None, 215 | attention_head_dim: Union[int, Tuple[int]] = 8, 216 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 217 | dual_cross_attention: bool = False, 218 | use_linear_projection: bool = False, 219 | class_embed_type: Optional[str] = None, 220 | addition_embed_type: Optional[str] = None, 221 | addition_time_embed_dim: Optional[int] = None, 222 | num_class_embeds: Optional[int] = None, 223 | upcast_attention: bool = False, 224 | resnet_time_scale_shift: str = "default", 225 | resnet_skip_time_act: bool = False, 226 | resnet_out_scale_factor: int = 1.0, 227 | time_embedding_type: str = "positional", 228 | time_embedding_dim: Optional[int] = None, 229 | time_embedding_act_fn: Optional[str] = None, 230 | timestep_post_act: Optional[str] = None, 231 | time_cond_proj_dim: Optional[int] = None, 232 | conv_in_kernel: int = 3, 233 | conv_out_kernel: int = 3, 234 | projection_class_embeddings_input_dim: Optional[int] = None, 235 | attention_type: str = "default", 236 | class_embeddings_concat: bool = False, 237 | mid_block_only_cross_attention: Optional[bool] = None, 238 | cross_attention_norm: Optional[str] = None, 239 | addition_embed_type_num_heads=64, 240 | ): 241 | super().__init__() 242 | 243 | self.sample_size = sample_size 244 | 245 | if num_attention_heads is not None: 246 | raise ValueError( 247 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." 
248 | ) 249 | 250 | # If `num_attention_heads` is not defined (which is the case for most models) 251 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 252 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 253 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 254 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 255 | # which is why we correct for the naming here. 256 | num_attention_heads = num_attention_heads or attention_head_dim 257 | 258 | # Check inputs 259 | if len(down_block_types) != len(up_block_types): 260 | raise ValueError( 261 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 262 | ) 263 | 264 | if len(block_out_channels) != len(down_block_types): 265 | raise ValueError( 266 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 267 | ) 268 | 269 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 270 | raise ValueError( 271 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 272 | ) 273 | 274 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 275 | raise ValueError( 276 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 277 | ) 278 | 279 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 280 | raise ValueError( 281 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 282 | ) 283 | 284 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 285 | raise ValueError( 286 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 287 | ) 288 | 289 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 290 | raise ValueError( 291 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
292 | ) 293 | 294 | # input 295 | conv_in_padding = (conv_in_kernel - 1) // 2 296 | self.conv_in = nn.Conv2d( 297 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 298 | ) 299 | 300 | # time 301 | if time_embedding_type == "fourier": 302 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 303 | if time_embed_dim % 2 != 0: 304 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") 305 | self.time_proj = GaussianFourierProjection( 306 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 307 | ) 308 | timestep_input_dim = time_embed_dim 309 | elif time_embedding_type == "positional": 310 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 311 | 312 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 313 | timestep_input_dim = block_out_channels[0] 314 | else: 315 | raise ValueError( 316 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 317 | ) 318 | 319 | self.time_embedding = TimestepEmbedding( 320 | timestep_input_dim, 321 | time_embed_dim, 322 | act_fn=act_fn, 323 | post_act_fn=timestep_post_act, 324 | cond_proj_dim=time_cond_proj_dim, 325 | ) 326 | 327 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 328 | encoder_hid_dim_type = "text_proj" 329 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 330 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") 331 | 332 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 333 | raise ValueError( 334 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 335 | ) 336 | 337 | if encoder_hid_dim_type == "text_proj": 338 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 339 | elif encoder_hid_dim_type == "text_image_proj": 340 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 341 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 342 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 343 | self.encoder_hid_proj = TextImageProjection( 344 | text_embed_dim=encoder_hid_dim, 345 | image_embed_dim=cross_attention_dim, 346 | cross_attention_dim=cross_attention_dim, 347 | ) 348 | elif encoder_hid_dim_type == "image_proj": 349 | # Kandinsky 2.2 350 | self.encoder_hid_proj = ImageProjection( 351 | image_embed_dim=encoder_hid_dim, 352 | cross_attention_dim=cross_attention_dim, 353 | ) 354 | elif encoder_hid_dim_type is not None: 355 | raise ValueError( 356 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
357 | ) 358 | else: 359 | self.encoder_hid_proj = None 360 | 361 | # class embedding 362 | if class_embed_type is None and num_class_embeds is not None: 363 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 364 | elif class_embed_type == "timestep": 365 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) 366 | elif class_embed_type == "identity": 367 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 368 | elif class_embed_type == "projection": 369 | if projection_class_embeddings_input_dim is None: 370 | raise ValueError( 371 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 372 | ) 373 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 374 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 375 | # 2. it projects from an arbitrary input dimension. 376 | # 377 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 378 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 379 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 380 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 381 | elif class_embed_type == "simple_projection": 382 | if projection_class_embeddings_input_dim is None: 383 | raise ValueError( 384 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" 385 | ) 386 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) 387 | else: 388 | self.class_embedding = None 389 | 390 | if addition_embed_type == "text": 391 | if encoder_hid_dim is not None: 392 | text_time_embedding_from_dim = encoder_hid_dim 393 | else: 394 | text_time_embedding_from_dim = cross_attention_dim 395 | 396 | self.add_embedding = TextTimeEmbedding( 397 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads 398 | ) 399 | elif addition_embed_type == "text_image": 400 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much 401 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 402 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 403 | self.add_embedding = TextImageTimeEmbedding( 404 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim 405 | ) 406 | elif addition_embed_type == "text_time": 407 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 408 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 409 | elif addition_embed_type == "image": 410 | # Kandinsky 2.2 411 | self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 412 | elif addition_embed_type == "image_hint": 413 | # Kandinsky 2.2 ControlNet 414 | self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 415 | elif addition_embed_type is not None: 416 | raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") 417 | 418 | if time_embedding_act_fn is None: 419 | self.time_embed_act = None 420 | else: 421 | self.time_embed_act = get_activation(time_embedding_act_fn) 422 | 423 | self.down_blocks = nn.ModuleList([]) 424 | self.up_blocks = nn.ModuleList([]) 425 | 426 | if isinstance(only_cross_attention, bool): 427 | if mid_block_only_cross_attention is None: 428 | mid_block_only_cross_attention = only_cross_attention 429 | 430 | only_cross_attention = [only_cross_attention] * len(down_block_types) 431 | 432 | if mid_block_only_cross_attention is None: 433 | mid_block_only_cross_attention = False 434 | 435 | if isinstance(num_attention_heads, int): 436 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 437 | 438 | if isinstance(attention_head_dim, int): 439 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 440 | 441 | if isinstance(cross_attention_dim, int): 442 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 443 | 444 | if isinstance(layers_per_block, int): 445 | layers_per_block = [layers_per_block] * len(down_block_types) 446 | 447 | if isinstance(transformer_layers_per_block, int): 448 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 449 | 450 | if class_embeddings_concat: 451 | # The time embeddings are concatenated with the class embeddings. 
The dimension of the 452 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the 453 | # regular time embeddings 454 | blocks_time_embed_dim = time_embed_dim * 2 455 | else: 456 | blocks_time_embed_dim = time_embed_dim 457 | 458 | # down 459 | output_channel = block_out_channels[0] 460 | for i, down_block_type in enumerate(down_block_types): 461 | input_channel = output_channel 462 | output_channel = block_out_channels[i] 463 | is_final_block = i == len(block_out_channels) - 1 464 | 465 | down_block = get_down_block( 466 | down_block_type, 467 | num_layers=layers_per_block[i], 468 | transformer_layers_per_block=transformer_layers_per_block[i], 469 | in_channels=input_channel, 470 | out_channels=output_channel, 471 | temb_channels=blocks_time_embed_dim, 472 | add_downsample=not is_final_block, 473 | resnet_eps=norm_eps, 474 | resnet_act_fn=act_fn, 475 | resnet_groups=norm_num_groups, 476 | cross_attention_dim=cross_attention_dim[i], 477 | num_attention_heads=num_attention_heads[i], 478 | downsample_padding=downsample_padding, 479 | dual_cross_attention=dual_cross_attention, 480 | use_linear_projection=use_linear_projection, 481 | only_cross_attention=only_cross_attention[i], 482 | upcast_attention=upcast_attention, 483 | resnet_time_scale_shift=resnet_time_scale_shift, 484 | attention_type=attention_type, 485 | resnet_skip_time_act=resnet_skip_time_act, 486 | resnet_out_scale_factor=resnet_out_scale_factor, 487 | cross_attention_norm=cross_attention_norm, 488 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 489 | ) 490 | self.down_blocks.append(down_block) 491 | 492 | # mid 493 | if mid_block_type == "UNetMidBlock2DCrossAttn": 494 | self.mid_block = UNetMidBlock2DCrossAttn( 495 | transformer_layers_per_block=transformer_layers_per_block[-1], 496 | in_channels=block_out_channels[-1], 497 | temb_channels=blocks_time_embed_dim, 498 | resnet_eps=norm_eps, 499 | resnet_act_fn=act_fn, 500 | output_scale_factor=mid_block_scale_factor, 501 | resnet_time_scale_shift=resnet_time_scale_shift, 502 | cross_attention_dim=cross_attention_dim[-1], 503 | num_attention_heads=num_attention_heads[-1], 504 | resnet_groups=norm_num_groups, 505 | dual_cross_attention=dual_cross_attention, 506 | use_linear_projection=use_linear_projection, 507 | upcast_attention=upcast_attention, 508 | attention_type=attention_type, 509 | ) 510 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": 511 | self.mid_block = UNetMidBlock2DSimpleCrossAttn( 512 | in_channels=block_out_channels[-1], 513 | temb_channels=blocks_time_embed_dim, 514 | resnet_eps=norm_eps, 515 | resnet_act_fn=act_fn, 516 | output_scale_factor=mid_block_scale_factor, 517 | cross_attention_dim=cross_attention_dim[-1], 518 | attention_head_dim=attention_head_dim[-1], 519 | resnet_groups=norm_num_groups, 520 | resnet_time_scale_shift=resnet_time_scale_shift, 521 | skip_time_act=resnet_skip_time_act, 522 | only_cross_attention=mid_block_only_cross_attention, 523 | cross_attention_norm=cross_attention_norm, 524 | ) 525 | elif mid_block_type is None: 526 | self.mid_block = None 527 | else: 528 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 529 | 530 | # count how many layers upsample the images 531 | self.num_upsamplers = 0 532 | 533 | # up 534 | reversed_block_out_channels = list(reversed(block_out_channels)) 535 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 536 | reversed_layers_per_block = list(reversed(layers_per_block)) 537 | 
reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 538 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 539 | only_cross_attention = list(reversed(only_cross_attention)) 540 | 541 | output_channel = reversed_block_out_channels[0] 542 | for i, up_block_type in enumerate(up_block_types): 543 | is_final_block = i == len(block_out_channels) - 1 544 | 545 | prev_output_channel = output_channel 546 | output_channel = reversed_block_out_channels[i] 547 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 548 | 549 | # add upsample block for all BUT final layer 550 | if not is_final_block: 551 | add_upsample = True 552 | self.num_upsamplers += 1 553 | else: 554 | add_upsample = False 555 | 556 | up_block = get_up_block( 557 | up_block_type, 558 | num_layers=reversed_layers_per_block[i] + 1, 559 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 560 | in_channels=input_channel, 561 | out_channels=output_channel, 562 | prev_output_channel=prev_output_channel, 563 | temb_channels=blocks_time_embed_dim, 564 | add_upsample=add_upsample, 565 | resnet_eps=norm_eps, 566 | resnet_act_fn=act_fn, 567 | resnet_groups=norm_num_groups, 568 | cross_attention_dim=reversed_cross_attention_dim[i], 569 | num_attention_heads=reversed_num_attention_heads[i], 570 | dual_cross_attention=dual_cross_attention, 571 | use_linear_projection=use_linear_projection, 572 | only_cross_attention=only_cross_attention[i], 573 | upcast_attention=upcast_attention, 574 | resnet_time_scale_shift=resnet_time_scale_shift, 575 | attention_type=attention_type, 576 | resnet_skip_time_act=resnet_skip_time_act, 577 | resnet_out_scale_factor=resnet_out_scale_factor, 578 | cross_attention_norm=cross_attention_norm, 579 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 580 | ) 581 | self.up_blocks.append(up_block) 582 | prev_output_channel = output_channel 583 | self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear() 584 | self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear() 585 | self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear() 586 | self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()]) 587 | self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity() 588 | self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None 589 | self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity() 590 | self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity() 591 | self.up_blocks[3].attentions[2].proj_out = Identity() 592 | 593 | if attention_type in ["gated", "gated-text-image"]: 594 | positive_len = 768 595 | if isinstance(cross_attention_dim, int): 596 | positive_len = cross_attention_dim 597 | elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): 598 | positive_len = cross_attention_dim[0] 599 | 600 | feature_type = "text-only" if attention_type == "gated" else "text-image" 601 | self.position_net = PositionNet( 602 | positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type 603 | ) 604 | 605 | @property 606 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 607 | r""" 608 | Returns: 609 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 610 | indexed by its weight 
name. 611 | """ 612 | # set recursively 613 | processors = {} 614 | 615 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 616 | if hasattr(module, "get_processor"): 617 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 618 | 619 | for sub_name, child in module.named_children(): 620 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 621 | 622 | return processors 623 | 624 | for name, module in self.named_children(): 625 | fn_recursive_add_processors(name, module, processors) 626 | 627 | return processors 628 | 629 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 630 | r""" 631 | Sets the attention processor to use to compute attention. 632 | 633 | Parameters: 634 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 635 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 636 | for **all** `Attention` layers. 637 | 638 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 639 | processor. This is strongly recommended when setting trainable attention processors. 640 | 641 | """ 642 | count = len(self.attn_processors.keys()) 643 | 644 | if isinstance(processor, dict) and len(processor) != count: 645 | raise ValueError( 646 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 647 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 648 | ) 649 | 650 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 651 | if hasattr(module, "set_processor"): 652 | if not isinstance(processor, dict): 653 | module.set_processor(processor) 654 | else: 655 | module.set_processor(processor.pop(f"{name}.processor")) 656 | 657 | for sub_name, child in module.named_children(): 658 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 659 | 660 | for name, module in self.named_children(): 661 | fn_recursive_attn_processor(name, module, processor) 662 | 663 | def set_default_attn_processor(self): 664 | """ 665 | Disables custom attention processors and sets the default attention implementation. 666 | """ 667 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 668 | processor = AttnAddedKVProcessor() 669 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 670 | processor = AttnProcessor() 671 | else: 672 | raise ValueError( 673 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 674 | ) 675 | 676 | self.set_attn_processor(processor) 677 | 678 | def set_attention_slice(self, slice_size): 679 | r""" 680 | Enable sliced attention computation. 681 | 682 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 683 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 684 | 685 | Args: 686 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 687 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 688 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 689 | provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` 690 | must be a multiple of `slice_size`. 691 | """ 692 | sliceable_head_dims = [] 693 | 694 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 695 | if hasattr(module, "set_attention_slice"): 696 | sliceable_head_dims.append(module.sliceable_head_dim) 697 | 698 | for child in module.children(): 699 | fn_recursive_retrieve_sliceable_dims(child) 700 | 701 | # retrieve number of attention layers 702 | for module in self.children(): 703 | fn_recursive_retrieve_sliceable_dims(module) 704 | 705 | num_sliceable_layers = len(sliceable_head_dims) 706 | 707 | if slice_size == "auto": 708 | # half the attention head size is usually a good trade-off between 709 | # speed and memory 710 | slice_size = [dim // 2 for dim in sliceable_head_dims] 711 | elif slice_size == "max": 712 | # make smallest slice possible 713 | slice_size = num_sliceable_layers * [1] 714 | 715 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 716 | 717 | if len(slice_size) != len(sliceable_head_dims): 718 | raise ValueError( 719 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 720 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 721 | ) 722 | 723 | for i in range(len(slice_size)): 724 | size = slice_size[i] 725 | dim = sliceable_head_dims[i] 726 | if size is not None and size > dim: 727 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 728 | 729 | # Recursively walk through all the children. 730 | # Any children which exposes the set_attention_slice method 731 | # gets the message 732 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 733 | if hasattr(module, "set_attention_slice"): 734 | module.set_attention_slice(slice_size.pop()) 735 | 736 | for child in module.children(): 737 | fn_recursive_set_attention_slice(child, slice_size) 738 | 739 | reversed_slice_size = list(reversed(slice_size)) 740 | for module in self.children(): 741 | fn_recursive_set_attention_slice(module, reversed_slice_size) 742 | 743 | def _set_gradient_checkpointing(self, module, value=False): 744 | if hasattr(module, "gradient_checkpointing"): 745 | module.gradient_checkpointing = value 746 | 747 | def forward( 748 | self, 749 | sample: torch.FloatTensor, 750 | timestep: Union[torch.Tensor, float, int], 751 | encoder_hidden_states: torch.Tensor, 752 | class_labels: Optional[torch.Tensor] = None, 753 | timestep_cond: Optional[torch.Tensor] = None, 754 | attention_mask: Optional[torch.Tensor] = None, 755 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 756 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 757 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 758 | mid_block_additional_residual: Optional[torch.Tensor] = None, 759 | encoder_attention_mask: Optional[torch.Tensor] = None, 760 | return_dict: bool = True, 761 | ) -> Union[UNet2DConditionOutput, Tuple]: 762 | r""" 763 | The [`UNet2DConditionModel`] forward method. 764 | 765 | Args: 766 | sample (`torch.FloatTensor`): 767 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 768 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 769 | encoder_hidden_states (`torch.FloatTensor`): 770 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 
771 | encoder_attention_mask (`torch.Tensor`): 772 | A cross-attention face_hair_mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 773 | `True` the face_hair_mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 774 | which adds large negative values to the attention scores corresponding to "discard" tokens. 775 | return_dict (`bool`, *optional*, defaults to `True`): 776 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 777 | tuple. 778 | cross_attention_kwargs (`dict`, *optional*): 779 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. 780 | added_cond_kwargs: (`dict`, *optional*): 781 | A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that 782 | are passed along to the UNet blocks. 783 | 784 | Returns: 785 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 786 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise 787 | a `tuple` is returned where the first element is the sample tensor. 788 | """ 789 | # By default samples have to be AT least a multiple of the overall upsampling factor. 790 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 791 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 792 | # on the fly if necessary. 793 | default_overall_up_factor = 2**self.num_upsamplers 794 | 795 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 796 | forward_upsample_size = False 797 | upsample_size = None 798 | 799 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 800 | logger.info("Forward upsample size to force interpolation output size.") 801 | forward_upsample_size = True 802 | 803 | if attention_mask is not None: 804 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 805 | attention_mask = attention_mask.unsqueeze(1) 806 | 807 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 808 | if encoder_attention_mask is not None: 809 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 810 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 811 | 812 | # 0. center input if necessary 813 | if self.config.center_input_sample: 814 | sample = 2 * sample - 1.0 815 | 816 | # 1. time 817 | timesteps = timestep 818 | if not torch.is_tensor(timesteps): 819 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 820 | # This would be a good case for the `match` statement (Python 3.10+) 821 | is_mps = sample.device.type == "mps" 822 | if isinstance(timestep, float): 823 | dtype = torch.float32 if is_mps else torch.float64 824 | else: 825 | dtype = torch.int32 if is_mps else torch.int64 826 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 827 | elif len(timesteps.shape) == 0: 828 | timesteps = timesteps[None].to(sample.device) 829 | 830 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 831 | timesteps = timesteps.expand(sample.shape[0]) 832 | 833 | t_emb = self.time_proj(timesteps) 834 | 835 | # `Timesteps` does not contain any weights and will always return f32 tensors 836 | # but time_embedding might actually be running in fp16. so we need to cast here. 837 | # there might be better ways to encapsulate this. 
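        # With the default "positional" setup in __init__ (and no `time_embedding_dim` override),
        # `self.time_proj` maps the (batch,)-shaped timestep tensor to sinusoidal features of size
        # block_out_channels[0], and `self.time_embedding` then projects them to
        # time_embed_dim = block_out_channels[0] * 4 before class/addition embeddings are added on top.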
838 | t_emb = t_emb.to(sample.device, dtype=sample.dtype) 839 | 840 | emb = self.time_embedding(t_emb, timestep_cond) 841 | aug_emb = None 842 | 843 | if self.class_embedding is not None: 844 | if class_labels is None: 845 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 846 | 847 | if self.config.class_embed_type == "timestep": 848 | class_labels = self.time_proj(class_labels) 849 | 850 | # `Timesteps` does not contain any weights and will always return f32 tensors 851 | # there might be better ways to encapsulate this. 852 | class_labels = class_labels.to(dtype=sample.dtype) 853 | 854 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 855 | 856 | if self.config.class_embeddings_concat: 857 | emb = torch.cat([emb, class_emb], dim=-1) 858 | else: 859 | emb = emb + class_emb 860 | 861 | if self.config.addition_embed_type == "text": 862 | aug_emb = self.add_embedding(encoder_hidden_states) 863 | elif self.config.addition_embed_type == "text_image": 864 | # Kandinsky 2.1 - style 865 | if "image_embeds" not in added_cond_kwargs: 866 | raise ValueError( 867 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 868 | ) 869 | 870 | image_embs = added_cond_kwargs.get("image_embeds") 871 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) 872 | aug_emb = self.add_embedding(text_embs, image_embs) 873 | elif self.config.addition_embed_type == "text_time": 874 | # SDXL - style 875 | if "text_embeds" not in added_cond_kwargs: 876 | raise ValueError( 877 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 878 | ) 879 | text_embeds = added_cond_kwargs.get("text_embeds") 880 | if "time_ids" not in added_cond_kwargs: 881 | raise ValueError( 882 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 883 | ) 884 | time_ids = added_cond_kwargs.get("time_ids") 885 | time_embeds = self.add_time_proj(time_ids.flatten()) 886 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 887 | 888 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 889 | add_embeds = add_embeds.to(emb.dtype) 890 | aug_emb = self.add_embedding(add_embeds) 891 | elif self.config.addition_embed_type == "image": 892 | # Kandinsky 2.2 - style 893 | if "image_embeds" not in added_cond_kwargs: 894 | raise ValueError( 895 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 896 | ) 897 | image_embs = added_cond_kwargs.get("image_embeds") 898 | aug_emb = self.add_embedding(image_embs) 899 | elif self.config.addition_embed_type == "image_hint": 900 | # Kandinsky 2.2 - style 901 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: 902 | raise ValueError( 903 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" 904 | ) 905 | image_embs = added_cond_kwargs.get("image_embeds") 906 | hint = added_cond_kwargs.get("hint") 907 | aug_emb, hint = self.add_embedding(image_embs, hint) 908 | sample = torch.cat([sample, hint], dim=1) 909 | 910 | emb = emb + 
aug_emb if aug_emb is not None else emb 911 | 912 | if self.time_embed_act is not None: 913 | emb = self.time_embed_act(emb) 914 | 915 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": 916 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 917 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": 918 | # Kadinsky 2.1 - style 919 | if "image_embeds" not in added_cond_kwargs: 920 | raise ValueError( 921 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 922 | ) 923 | 924 | image_embeds = added_cond_kwargs.get("image_embeds") 925 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) 926 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": 927 | # Kandinsky 2.2 - style 928 | if "image_embeds" not in added_cond_kwargs: 929 | raise ValueError( 930 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 931 | ) 932 | image_embeds = added_cond_kwargs.get("image_embeds") 933 | encoder_hidden_states = self.encoder_hid_proj(image_embeds) 934 | # 2. pre-process 935 | sample = self.conv_in(sample) 936 | 937 | # 2.5 GLIGEN position net 938 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 939 | cross_attention_kwargs = cross_attention_kwargs.copy() 940 | gligen_args = cross_attention_kwargs.pop("gligen") 941 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 942 | 943 | # 3. down 944 | 945 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 946 | is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None 947 | 948 | down_block_res_samples = (sample,) 949 | for downsample_block in self.down_blocks: 950 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 951 | # For t2i-adapter CrossAttnDownBlock2D 952 | additional_residuals = {} 953 | if is_adapter and len(down_block_additional_residuals) > 0: 954 | additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) 955 | 956 | sample, res_samples = downsample_block( 957 | hidden_states=sample, 958 | temb=emb, 959 | encoder_hidden_states=encoder_hidden_states, 960 | attention_mask=attention_mask, 961 | cross_attention_kwargs=cross_attention_kwargs, 962 | encoder_attention_mask=encoder_attention_mask, 963 | **additional_residuals, 964 | ) 965 | else: 966 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 967 | 968 | if is_adapter and len(down_block_additional_residuals) > 0: 969 | sample += down_block_additional_residuals.pop(0) 970 | 971 | down_block_res_samples += res_samples 972 | 973 | if is_controlnet: 974 | new_down_block_res_samples = () 975 | 976 | for down_block_res_sample, down_block_additional_residual in zip( 977 | down_block_res_samples, down_block_additional_residuals 978 | ): 979 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 980 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 981 | 982 | down_block_res_samples = new_down_block_res_samples 983 | 984 | # 4. 
mid 985 | if self.mid_block is not None: 986 | sample = self.mid_block( 987 | sample, 988 | emb, 989 | encoder_hidden_states=encoder_hidden_states, 990 | attention_mask=attention_mask, 991 | cross_attention_kwargs=cross_attention_kwargs, 992 | encoder_attention_mask=encoder_attention_mask, 993 | ) 994 | # To support T2I-Adapter-XL 995 | if ( 996 | is_adapter 997 | and len(down_block_additional_residuals) > 0 998 | and sample.shape == down_block_additional_residuals[0].shape 999 | ): 1000 | sample += down_block_additional_residuals.pop(0) 1001 | 1002 | if is_controlnet: 1003 | sample = sample + mid_block_additional_residual 1004 | 1005 | # 5. up 1006 | for i, upsample_block in enumerate(self.up_blocks): 1007 | is_final_block = i == len(self.up_blocks) - 1 1008 | 1009 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 1010 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 1011 | 1012 | # if we have not reached the final block and need to forward the 1013 | # upsample size, we do it here 1014 | if not is_final_block and forward_upsample_size: 1015 | upsample_size = down_block_res_samples[-1].shape[2:] 1016 | 1017 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 1018 | sample = upsample_block( 1019 | hidden_states=sample, 1020 | temb=emb, 1021 | res_hidden_states_tuple=res_samples, 1022 | encoder_hidden_states=encoder_hidden_states, 1023 | cross_attention_kwargs=cross_attention_kwargs, 1024 | upsample_size=upsample_size, 1025 | attention_mask=attention_mask, 1026 | encoder_attention_mask=encoder_attention_mask, 1027 | ) 1028 | else: 1029 | sample = upsample_block( 1030 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 1031 | ) 1032 | 1033 | if not return_dict: 1034 | return (sample,) 1035 | 1036 | return UNet2DConditionOutput(sample=sample) -------------------------------------------------------------------------------- /nodes/libs/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_StableHair_ll/d00491937bc9f0c4fd96f010bf0c63fb5eabcd30/nodes/libs/utils/__init__.py -------------------------------------------------------------------------------- /nodes/libs/utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect, math 2 | from typing import Callable, List, Optional, Union 3 | from dataclasses import dataclass 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | from diffusers.loaders import FromSingleFileMixin 8 | from einops import rearrange 9 | import torch.distributed as dist 10 | from tqdm import tqdm 11 | from diffusers.utils import is_accelerate_available 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 15 | from diffusers import DiffusionPipeline 16 | from diffusers.schedulers import ( 17 | DDIMScheduler, 18 | DPMSolverMultistepScheduler, 19 | EulerAncestralDiscreteScheduler, 20 | EulerDiscreteScheduler, 21 | LMSDiscreteScheduler, 22 | PNDMScheduler, 23 | ) 24 | from diffusers.utils import logging, BaseOutput 25 | 26 | from ..ref_encoder.latent_controlnet import ControlNetModel 27 | from ..ref_encoder.reference_control import ReferenceAttentionControl 28 | import torch.nn.functional as F 29 | from ..ref_encoder.reference_unet import RefHairUnet 30 | 31 | logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | @dataclass 35 | class PipelineOutput(BaseOutput): 36 | samples: Union[torch.Tensor, np.ndarray] 37 | 38 | 39 | class StableHairPipeline(DiffusionPipeline, FromSingleFileMixin): 40 | _optional_components = [] 41 | 42 | def __init__( 43 | self, 44 | vae: AutoencoderKL, 45 | text_encoder: CLIPTextModel, 46 | tokenizer: CLIPTokenizer, 47 | unet: UNet2DConditionModel, 48 | scheduler: Union[ 49 | DDIMScheduler, 50 | PNDMScheduler, 51 | LMSDiscreteScheduler, 52 | EulerDiscreteScheduler, 53 | EulerAncestralDiscreteScheduler, 54 | DPMSolverMultistepScheduler, 55 | ], 56 | controlnet: ControlNetModel = None, 57 | reference_encoder: RefHairUnet = None 58 | ): 59 | super().__init__() 60 | self.register_modules( 61 | vae=vae, 62 | text_encoder=text_encoder, 63 | tokenizer=tokenizer, 64 | unet=unet, 65 | controlnet=controlnet, 66 | scheduler=scheduler, 67 | reference_encoder=reference_encoder 68 | ) 69 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 70 | 71 | def enable_vae_slicing(self): 72 | self.vae.enable_slicing() 73 | 74 | def disable_vae_slicing(self): 75 | self.vae.disable_slicing() 76 | 77 | def enable_sequential_cpu_offload(self, gpu_id=0): 78 | if is_accelerate_available(): 79 | from accelerate import cpu_offload 80 | else: 81 | raise ImportError("Please install accelerate via `pip install accelerate`") 82 | 83 | device = torch.device(f"cuda:{gpu_id}") 84 | 85 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 86 | if cpu_offloaded_model is not None: 87 | cpu_offload(cpu_offloaded_model, device) 88 | 89 | @property 90 | def _execution_device(self): 91 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 92 | return self.device 93 | for module in self.unet.modules(): 94 | if ( 95 | hasattr(module, "_hf_hook") 96 | and hasattr(module._hf_hook, "execution_device") 97 | and module._hf_hook.execution_device is not None 98 | ): 99 | return torch.device(module._hf_hook.execution_device) 100 | return self.device 101 | 102 | def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt): 103 | if isinstance(prompt, torch.Tensor): 104 | batch_size = prompt.shape[0] 105 | text_input_ids = prompt 106 | else: 107 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 108 | text_inputs = self.tokenizer( 109 | prompt, 110 | padding="max_length", 111 | max_length=self.tokenizer.model_max_length, 112 | truncation=True, 113 | return_tensors="pt", 114 | ) 115 | text_input_ids = text_inputs.input_ids 116 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 117 | 118 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, 119 | untruncated_ids): 120 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]) 121 | logger.warning( 122 | "The following part of your input was truncated because CLIP can only handle sequences up to" 123 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 124 | ) 125 | 126 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 127 | attention_mask = text_inputs.attention_mask.to(device) 128 | else: 129 | attention_mask = None 130 | 131 | text_embeddings = self.text_encoder( 132 | text_input_ids.to(device), 133 | attention_mask=attention_mask, 134 | ) 135 | text_embeddings = text_embeddings[0] 136 | 137 | # duplicate text 
embeddings for each generation per prompt, using mps friendly method 138 | bs_embed, seq_len, _ = text_embeddings.shape 139 | 140 | # get unconditional embeddings for classifier free guidance 141 | if do_classifier_free_guidance: 142 | uncond_tokens: List[str] 143 | if negative_prompt is None: 144 | uncond_tokens = [""] * batch_size 145 | elif type(prompt) is not type(negative_prompt): 146 | raise TypeError( 147 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 148 | f" {type(prompt)}." 149 | ) 150 | elif isinstance(negative_prompt, str): 151 | uncond_tokens = [negative_prompt] 152 | elif batch_size != len(negative_prompt): 153 | raise ValueError( 154 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 155 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 156 | " the batch size of `prompt`." 157 | ) 158 | else: 159 | uncond_tokens = negative_prompt 160 | 161 | max_length = text_input_ids.shape[-1] 162 | uncond_input = self.tokenizer( 163 | uncond_tokens, 164 | padding="max_length", 165 | max_length=max_length, 166 | truncation=True, 167 | return_tensors="pt", 168 | ) 169 | 170 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 171 | attention_mask = uncond_input.attention_mask.to(device) 172 | else: 173 | attention_mask = None 174 | 175 | uncond_embeddings = self.text_encoder( 176 | uncond_input.input_ids.to(device), 177 | attention_mask=attention_mask, 178 | ) 179 | uncond_embeddings = uncond_embeddings[0] 180 | 181 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 182 | seq_len = uncond_embeddings.shape[1] 183 | 184 | # For classifier free guidance, we need to do two forward passes. 185 | # Here we concatenate the unconditional and text embeddings into a single batch 186 | # to avoid doing two forward passes 187 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 188 | 189 | return text_embeddings 190 | 191 | def decode_latents(self, latents): 192 | latents = 1 / 0.18215 * latents 193 | image = self.vae.decode(latents).sample 194 | image = (image / 2 + 0.5).clamp(0, 1).permute(0, 2, 3, 1) 195 | image = image.cpu().squeeze(0).float().numpy() 196 | return image 197 | 198 | def prepare_extra_step_kwargs(self, generator, eta): 199 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 200 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
201 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 202 | # and should be between [0, 1] 203 | 204 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 205 | extra_step_kwargs = {} 206 | if accepts_eta: 207 | extra_step_kwargs["eta"] = eta 208 | 209 | # check if the scheduler accepts generator 210 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 211 | if accepts_generator: 212 | extra_step_kwargs["generator"] = generator 213 | return extra_step_kwargs 214 | 215 | def check_inputs(self, prompt, height, width, callback_steps): 216 | if not isinstance(prompt, str) and not isinstance(prompt, list) and not isinstance(prompt, torch.Tensor): 217 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 218 | 219 | if height % 8 != 0 or width % 8 != 0: 220 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 221 | 222 | if (callback_steps is None) or ( 223 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 224 | ): 225 | raise ValueError( 226 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 227 | f" {type(callback_steps)}." 228 | ) 229 | 230 | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): 231 | if isinstance(generator, list): 232 | image_latents = [ 233 | self.vae.encode(image[i: i + 1]).latent_dist.sample(generator=generator[i]) 234 | for i in range(image.shape[0]) 235 | ] 236 | image_latents = torch.cat(image_latents, dim=0) 237 | else: 238 | image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) 239 | 240 | image_latents = self.vae.config.scaling_factor * image_latents 241 | 242 | return image_latents 243 | 244 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, 245 | clip_length=16): 246 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 247 | if isinstance(generator, list) and len(generator) != batch_size: 248 | raise ValueError( 249 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 250 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
251 | ) 252 | if latents is None: 253 | rand_device = "cpu" if device.type == "mps" else device 254 | 255 | if isinstance(generator, list): 256 | latents = [ 257 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 258 | for i in range(batch_size) 259 | ] 260 | latents = torch.cat(latents, dim=0).to(device) 261 | else: 262 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 263 | 264 | else: 265 | if latents.shape != shape: 266 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 267 | latents = latents.to(device) 268 | 269 | # scale the initial noise by the standard deviation required by the scheduler 270 | noise = latents.clone() 271 | latents = latents * self.scheduler.init_noise_sigma 272 | return latents, noise 273 | 274 | def prepare_condition(self, condition, device, dtype, do_classifier_free_guidance): 275 | if isinstance(condition, torch.Tensor): 276 | # suppose input is [-1, 1] 277 | # condition = condition 278 | condition = self.images2latents(condition, dtype).to(device) 279 | elif isinstance(condition, np.ndarray): 280 | # suppose input is [0, 255] 281 | condition = self.images2latents(condition, dtype).to(device) 282 | if do_classifier_free_guidance: 283 | condition_pad = torch.ones_like(condition) * -1 284 | condition = torch.cat([condition_pad, condition]).to(device) 285 | return condition 286 | 287 | @torch.no_grad() 288 | def images2latents(self, images, dtype): 289 | """ 290 | Convert RGB image to VAE latents 291 | """ 292 | device = self._execution_device 293 | if isinstance(images, torch.Tensor): 294 | # suppose input is [-1, 1] 295 | images = images.to(dtype) 296 | if images.ndim == 3: 297 | images = images.unsqueeze(0) 298 | elif isinstance(images, np.ndarray): 299 | # suppose input is [0, 255] 300 | images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1 301 | images = rearrange(images, "h w c -> c h w").to(device)[None, :] 302 | latents = self.vae.encode(images)['latent_dist'].mean * self.vae.config.scaling_factor 303 | return latents 304 | 305 | @torch.no_grad() 306 | def encode_single_image_latents(self, images, mask, dtype): 307 | device = self._execution_device 308 | images = torch.from_numpy(images).float().to(dtype) / 127.5 - 1 309 | images = rearrange(images, "h w c -> c h w").to(device) 310 | latents = self.vae.encode(images[None, :])['latent_dist'].mean * 0.18215 311 | 312 | images = images.unsqueeze(0) 313 | 314 | mask = torch.from_numpy(mask).float().to(dtype).to(device) / 255.0 315 | if mask.ndim == 2: 316 | mask = mask[None, None, :] 317 | elif mask.ndim == 3: 318 | mask = mask[:, None, :, :] 319 | 320 | mask = F.interpolate(mask, size=latents.shape[-2:], mode='nearest') 321 | return latents, images, mask 322 | 323 | @torch.no_grad() 324 | def __call__( 325 | self, 326 | prompt: Union[str, List[str]], 327 | height: Optional[int] = None, 328 | width: Optional[int] = None, 329 | num_inference_steps: int = 50, 330 | guidance_scale: float = 7.5, 331 | negative_prompt: Optional[Union[str, List[str]]] = None, 332 | eta: float = 0.0, 333 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 334 | latents: Optional[torch.FloatTensor] = None, 335 | output_type: Optional[str] = "np", 336 | return_dict: bool = True, 337 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 338 | callback_steps: Optional[int] = 1, 339 | controlnet_condition: list = None, 340 | controlnet_conditioning_scale: Optional[float] = 
1.0, 341 | init_latents: Optional[torch.FloatTensor] = None, 342 | num_actual_inference_steps: Optional[int] = None, 343 | # reference_encoder=None, 344 | ref_image=None, 345 | t2i=False, 346 | style_fidelity=1.0, 347 | **kwargs, 348 | ): 349 | controlnet = self.controlnet 350 | 351 | # Default height and width to unet 352 | height = height or self.unet.config.sample_size * self.vae_scale_factor 353 | width = width or self.unet.config.sample_size * self.vae_scale_factor 354 | 355 | # Check inputs. Raise error if not correct 356 | self.check_inputs(prompt, height, width, callback_steps) 357 | 358 | # Define call parameters 359 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 360 | batch_size = 1 361 | if latents is not None: 362 | batch_size = latents.shape[0] 363 | if isinstance(prompt, list): 364 | batch_size = len(prompt) 365 | 366 | device = self._execution_device 367 | do_classifier_free_guidance = guidance_scale > 1.0 368 | 369 | # Encode input prompt 370 | if not isinstance(prompt, torch.Tensor): 371 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 372 | if negative_prompt is not None: 373 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 374 | text_embeddings = self._encode_prompt( 375 | prompt, device, do_classifier_free_guidance, negative_prompt 376 | ) 377 | text_embeddings = torch.cat([text_embeddings]) 378 | 379 | reference_control_writer = ReferenceAttentionControl(self.reference_encoder, do_classifier_free_guidance=True, 380 | style_fidelity=style_fidelity, 381 | mode='write', fusion_blocks='full') 382 | reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode='read', 383 | style_fidelity=style_fidelity, 384 | fusion_blocks='full') 385 | 386 | is_dist_initialized = kwargs.get("dist", False) 387 | rank = kwargs.get("rank", 0) 388 | 389 | # Prepare control_img 390 | control = self.prepare_condition( 391 | condition=controlnet_condition, 392 | device=device, 393 | dtype=controlnet.dtype, 394 | do_classifier_free_guidance=do_classifier_free_guidance, 395 | ) 396 | # for b in range(control.size(0)): 397 | # max_value = torch.max(control[b]) 398 | # min_value = torch.min(control[b]) 399 | # control[b] = (control[b] - min_value) / (max_value - min_value) 400 | 401 | # Prepare timesteps 402 | self.scheduler.set_timesteps(num_inference_steps, device=device) 403 | timesteps = self.scheduler.timesteps 404 | 405 | num_channels_latents = self.unet.in_channels 406 | latents = self.prepare_latents( 407 | batch_size, 408 | num_channels_latents, 409 | height, 410 | width, 411 | text_embeddings.dtype, 412 | device, 413 | generator, 414 | latents, 415 | ) 416 | if isinstance(latents, tuple): 417 | latents, noise = latents 418 | 419 | latents_dtype = latents.dtype 420 | 421 | # Prepare extra step kwargs. 
422 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 423 | 424 | # For img2img setting 425 | if num_actual_inference_steps is None: 426 | num_actual_inference_steps = num_inference_steps 427 | 428 | if isinstance(ref_image, str): 429 | ref_image_latents = self.images2latents(np.array(Image.open(ref_image).resize((width, height))), 430 | latents_dtype) 431 | elif isinstance(ref_image, np.ndarray): 432 | ref_image_latents = self.images2latents(ref_image, latents_dtype) 433 | elif isinstance(ref_image, torch.Tensor): 434 | ref_image_latents = self.images2latents(ref_image, latents_dtype) 435 | 436 | ref_padding_latents = torch.ones_like(ref_image_latents) * -1 437 | ref_image_latents = torch.cat([ref_padding_latents, ref_image_latents]) if do_classifier_free_guidance else ref_image_latents 438 | ref_image_latents.to(device) 439 | 440 | # Denoising loop 441 | for i, t in tqdm(enumerate(timesteps), total=len(timesteps), disable=(rank != 0)): 442 | if num_actual_inference_steps is not None and i < num_inference_steps - num_actual_inference_steps: 443 | continue 444 | 445 | # writer 446 | ref_latents_input = ref_image_latents 447 | self.reference_encoder( 448 | ref_latents_input, 449 | t, 450 | encoder_hidden_states=text_embeddings, 451 | return_dict=False, 452 | ) 453 | reference_control_reader.update(reference_control_writer) 454 | 455 | # prepare latents 456 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 457 | 458 | if t2i: 459 | pass 460 | 461 | else: 462 | # controlnet 463 | down_block_res_samples, mid_block_res_sample = self.controlnet( 464 | latent_model_input, 465 | t, 466 | encoder_hidden_states=text_embeddings, 467 | controlnet_cond=control, 468 | return_dict=False, 469 | ) 470 | down_block_res_samples = [sample * controlnet_conditioning_scale for sample in down_block_res_samples] 471 | mid_block_res_sample = mid_block_res_sample * controlnet_conditioning_scale 472 | 473 | if t2i: 474 | # predict the noise residual 475 | noise_pred = self.unet( 476 | latent_model_input, 477 | t, 478 | encoder_hidden_states=text_embeddings, 479 | return_dict=False, 480 | )[0] 481 | 482 | else: 483 | # predict the noise residual 484 | noise_pred = self.unet( 485 | latent_model_input, 486 | t, 487 | encoder_hidden_states=text_embeddings, 488 | down_block_additional_residuals=down_block_res_samples, 489 | mid_block_additional_residual=mid_block_res_sample, 490 | return_dict=False, 491 | )[0] 492 | 493 | # clean the reader 494 | reference_control_reader.clear() 495 | 496 | # perform guidance 497 | if do_classifier_free_guidance: 498 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 499 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 500 | 501 | # compute the previous noisy sample x_t -> x_t-1 502 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 503 | 504 | if callback is not None: 505 | callback(i, t, latents) 506 | 507 | if is_dist_initialized: 508 | dist.broadcast(latents, 0) 509 | dist.barrier() 510 | 511 | reference_control_writer.clear() 512 | 513 | samples = self.decode_latents(latents) 514 | if is_dist_initialized: 515 | dist.barrier() 516 | 517 | # Convert to tensor 518 | if output_type == "tensor": 519 | samples = torch.from_numpy(samples) 520 | 521 | if not return_dict: 522 | return samples 523 | 524 | return PipelineOutput(samples=samples) 525 | -------------------------------------------------------------------------------- /pyproject.toml: 
--------------------------------------------------------------------------------
 1 | [project]
 2 | name = "comfyui_stablehair_ll"
 3 | description = "Hair transfer"
 4 | version = "1.0.1"
 5 | license = {file = "LICENSE"}
 6 | dependencies = ["numpy"]
 7 | 
 8 | [project.urls]
 9 | Repository = "https://github.com/lldacing/ComfyUI_StableHair_ll"
10 | # Used by Comfy Registry https://comfyregistry.org
11 | 
12 | [tool.comfy]
13 | PublisherId = "lldacing"
14 | DisplayName = "ComfyUI_StableHair_ll"
15 | Icon = ""
16 | 
--------------------------------------------------------------------------------
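
The repository itself wires StableHairPipeline up inside nodes/HairNode.py (not shown in this section), presumably using the sd15 configs under nodes/libs/configs. The following is only a minimal, hypothetical sketch of how the pipeline's constructor and __call__ signature could be exercised outside ComfyUI: the SD1.5 repo id, the image paths, the import paths (which assume the repository root is on the Python path), and the construction of the custom ControlNetModel and RefHairUnet via from_unet / from_pretrained are all assumptions, not the repository's actual loading code, and the trained StableHair weights would still have to be loaded separately.

# Hypothetical usage sketch for StableHairPipeline (not the repository's own loader).
import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from nodes.libs.ref_encoder.latent_controlnet import ControlNetModel
from nodes.libs.ref_encoder.reference_unet import RefHairUnet
from nodes.libs.utils.pipeline import StableHairPipeline

base = "runwayml/stable-diffusion-v1-5"  # assumed SD1.5 base checkpoint

vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
scheduler = DDIMScheduler.from_pretrained(base, subfolder="scheduler")

# Assumption: the custom ControlNetModel mirrors diffusers' ControlNetModel.from_unet,
# and RefHairUnet accepts the SD1.5 UNet weights/config; loading the trained
# StableHair checkpoints into both modules is omitted here.
controlnet = ControlNetModel.from_unet(unet)
reference_encoder = RefHairUnet.from_pretrained(base, subfolder="unet")

pipe = StableHairPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    controlnet=controlnet,
    reference_encoder=reference_encoder,
).to("cuda")

# Inputs: a bald/base face (ControlNet condition) and a hair reference image,
# both as uint8 HxWx3 arrays in [0, 255], which images2latents() expects for
# numpy inputs. File names are hypothetical.
bald = np.array(Image.open("bald_face.png").resize((512, 512)))
ref_hair = np.array(Image.open("reference_hair.png").resize((512, 512)))

out = pipe(
    prompt="",
    negative_prompt="",
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=1.5,               # > 1.0 enables classifier-free guidance
    controlnet_condition=bald,        # encoded to latents via images2latents()
    controlnet_conditioning_scale=1.0,
    ref_image=ref_hair,               # drives the reference-attention "write" pass
    output_type="np",
)

# decode_latents() returns a float array in [0, 1]; for batch size 1 the batch
# dimension is squeezed away, so out.samples is HxWx3.
Image.fromarray((out.samples * 255).astype(np.uint8)).save("hair_transfer.png")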