├── .github └── workflows │ └── typecheck.yaml ├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── assets ├── CFG-Zero │ ├── image.webp │ ├── image_CFG.webp │ └── image_CFG_zero_star.webp ├── example1.jpeg ├── example2.jpeg ├── example3.jpeg ├── example4.jpeg ├── img.md ├── method.jpg ├── result_canny.png ├── result_ghibli.png ├── result_subject.png ├── result_subject_inpainting.png └── teaser.jpg ├── infer.ipynb ├── infer.py ├── infer_multi.py ├── requirements.txt ├── src ├── __init__.py ├── layers_cache.py ├── lora_helper.py ├── pipeline.py └── transformer_flux.py ├── test_imgs ├── canny.png ├── depth.png ├── ghibli.png ├── inpainting.png ├── openpose.png ├── seg.png ├── subject_0.png └── subject_1.png └── train ├── default_config.yaml ├── examples ├── openpose_data │ ├── 1.png │ └── 2.png ├── pose.jsonl ├── style.jsonl ├── style_data │ ├── 5.png │ └── 6.png ├── subject.jsonl └── subject_data │ ├── 3.png │ └── 4.png ├── readme.md ├── src ├── __init__.py ├── jsonl_datasets.py ├── layers.py ├── lora_helper.py ├── pipeline.py ├── prompt_helper.py └── transformer_flux.py ├── train.py ├── train_spatial.sh ├── train_style.sh └── train_subject.sh /.github/workflows/typecheck.yaml: -------------------------------------------------------------------------------- 1 | name: Typecheck 2 | 3 | # These checks will run if at least one file is outside of the `paths-ignore` 4 | # list, but will be skipped if *all* files are in the `paths-ignore` list. 5 | # 6 | # Fore more info, see: 7 | # https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#example-excluding-paths 8 | 9 | on: 10 | push: 11 | branches: 12 | - 'main' 13 | paths-ignore: 14 | - '**.jpeg' 15 | - '**.jpg' 16 | - '**.md' 17 | - '**.png' 18 | - '**.webp' 19 | 20 | pull_request: 21 | branches: 22 | - 'main' 23 | paths-ignore: 24 | - '**.jpeg' 25 | - '**.jpg' 26 | - '**.md' 27 | - '**.png' 28 | - '**.webp' 29 | 30 | jobs: 31 | test: 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | os: [ 'ubuntu-24.04' ] 36 | python: [ '3.10' ] 37 | 38 | runs-on: ${{ matrix.os }} 39 | name: Python ${{ matrix.python }} on ${{ matrix.os }} 40 | 41 | steps: 42 | - name: Checkout the repo 43 | uses: actions/checkout@v4 44 | 45 | - name: Setup Python 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ matrix.python }} 49 | cache: 'pip' 50 | 51 | - name: Update pip 52 | run: python -m pip install --upgrade pip 53 | 54 | - name: Install Python deps 55 | run: python -m pip install -r requirements.txt 56 | 57 | - name: Install Mypy 58 | run: python -m pip install mypy 59 | 60 | - name: Check types with Mypy 61 | run: python -m mypy --python-version=${{ matrix.python }} . 62 | # TODO: fix the type checking errors and remove this line to make errors 63 | # obvious by failing the test. 64 | continue-on-error: true 65 | 66 | - name: Install PyType 67 | run: python -m pip install pytype 68 | 69 | - name: Check types with PyType 70 | run: python -m pytype --python-version=${{ matrix.python }} -k . 71 | # TODO: fix the type checking errors and remove this line to make errors 72 | # obvious by failing the test. 
73 | continue-on-error: true 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.safetensors 3 | 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation of EasyControl 2 | 3 | EasyControl: Adding Efficient and Flexible Control for Diffusion Transformer 4 | 5 | 6 | 7 | HuggingFace 8 | 9 | 10 | 11 | > *[Yuxuan Zhang](https://xiaojiu-z.github.io/YuxuanZhang.github.io/), [Yirui Yuan](https://github.com/Reynoldyy), [Yiren Song](https://scholar.google.com.hk/citations?user=L2YS0jgAAAAJ), [Haofan Wang](https://haofanwang.github.io/), [Jiaming Liu](https://scholar.google.com/citations?user=SmL7oMQAAAAJ&hl=en)* 12 | >
13 | > Tiamat AI, ShanghaiTech University, National University of Singapore, Liblib AI 14 | 15 | 16 | 17 | ## Features 18 | * **Motivation:** The architecture of diffusion models is transitioning from Unet-based to DiT (Diffusion Transformer). However, the DiT ecosystem lacks mature plugin support and faces challenges such as efficiency bottlenecks, conflicts in multi-condition coordination, and insufficient model adaptability. 19 | * **Contribution:** We propose EasyControl, an efficient and flexible unified conditional DiT framework. By incorporating a lightweight Condition Injection LoRA module, a Position-Aware Training Paradigm, and a combination of Causal Attention mechanisms with KV Cache technology, we significantly enhance **model compatibility** (enabling plug-and-play functionality and style lossless control), **generation flexibility** (supporting multiple resolutions, aspect ratios, and multi-condition combinations), and **inference efficiency**. 20 | 21 | 22 | ## News 23 | - **2025-04-11**: 🔥🔥🔥 Training code has been released. Recommended hardware: at least 1x NVIDIA H100/H800/A100 with ~80GB of GPU memory. 24 | - **2025-04-09**: ⭐️ The code for the simple API has been released. If you wish to run the models on your personal machine, head over to the simple_api branch to access the relevant resources. 25 | 26 | - **2025-04-07**: 🔥 Thanks to the great work by the [CFG-Zero*](https://github.com/WeichenFan/CFG-Zero-star) team, EasyControl is now integrated with CFG-Zero*! With just a few lines of code, you can boost image fidelity and controllability! You can download the modified code from [this link](https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/easycontrol/infer.py) and try it; a short sketch of the guidance rule is shown after the comparison below. 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 |
Source Image | CFG | CFG-Zero*
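The repository points to the modified `infer.py` in the CFG-Zero* repo for the actual integration. Purely as an illustration of the idea, the sketch below shows the CFG-Zero* guidance rule (an optimized rescaling of the unconditional prediction plus zero-initialized early steps); the function name, the `zero_init_steps` default, and how you would wire it into a sampler are assumptions, not EasyControl code.

```python
import torch

def cfg_zero_star(pred_cond: torch.Tensor,
                  pred_uncond: torch.Tensor,
                  guidance_scale: float,
                  step_index: int,
                  zero_init_steps: int = 1) -> torch.Tensor:
    """Hedged sketch of the CFG-Zero* guidance rule (not the repo's exact code)."""
    if step_index < zero_init_steps:
        # "Zero-init": return a zero prediction for the earliest step(s).
        return torch.zeros_like(pred_cond)

    b = pred_cond.shape[0]
    cond = pred_cond.reshape(b, -1).float()
    uncond = pred_uncond.reshape(b, -1).float()

    # Optimized per-sample scale s* = <cond, uncond> / <uncond, uncond>.
    s_star = (cond * uncond).sum(dim=1, keepdim=True) / (
        (uncond * uncond).sum(dim=1, keepdim=True) + 1e-8
    )
    s_star = s_star.view(b, *([1] * (pred_cond.dim() - 1))).to(pred_cond.dtype)

    rescaled_uncond = s_star * pred_uncond
    # Standard CFG extrapolation, but against the rescaled unconditional branch.
    return rescaled_uncond + guidance_scale * (pred_cond - rescaled_uncond)
```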
40 | 41 | - **2025-04-03**: Thanks to [jax-explorer](https://github.com/jax-explorer), a [Ghibli Img2Img Control ComfyUI Node](https://github.com/jax-explorer/ComfyUI-easycontrol) is now supported! 42 | 43 | - **2025-04-01**: 🔥 A new [Stylized Img2Img Control Model](https://huggingface.co/spaces/jamesliu1217/EasyControl_Ghibli) has been released! Transform portraits into Studio Ghibli-style artwork using this LoRA model. Trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, it preserves facial features while applying the iconic anime aesthetic. 44 | 45 |
46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 |
Example 3 | Example 4
56 |
57 | 58 | - **2025-03-19**: 🔥 We have released a [Hugging Face demo](https://huggingface.co/spaces/jamesliu1217/EasyControl)! You can now try out EasyControl in the Hugging Face Space, enjoy it! 59 |
60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 |
Example 1 | Example 2
70 |
71 | 72 | - **2025-03-18**: 🔥 We have released our [pre-trained checkpoints](https://huggingface.co/Xiaojiu-Z/EasyControl/) on Hugging Face! You can now try out EasyControl with the official weights. 73 | - **2025-03-12**: ⭐️ Inference code are released. Once we have ensured that everything is functioning correctly, the new model will be merged into this repository. Stay tuned for updates! 😊 74 | 75 | ## Installation 76 | 77 | We recommend using Python 3.10 and PyTorch with CUDA support. To set up the environment: 78 | 79 | ```bash 80 | # Create a new conda environment 81 | conda create -n easycontrol python=3.10 82 | conda activate easycontrol 83 | 84 | # Install other dependencies 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | ## Download 89 | 90 | You can download the model directly from [Hugging Face](https://huggingface.co/EasyControl/EasyControl). 91 | Or download using Python script: 92 | 93 | ```python 94 | from huggingface_hub import hf_hub_download 95 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./") 96 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/depth.safetensors", local_dir="./") 97 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/hedsketch.safetensors", local_dir="./") 98 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/inpainting.safetensors", local_dir="./") 99 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/pose.safetensors", local_dir="./") 100 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/seg.safetensors", local_dir="./") 101 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/subject.safetensors", local_dir="./") 102 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/Ghibli.safetensors", local_dir="./") 103 | ``` 104 | 105 | If you cannot access Hugging Face, you can use [hf-mirror](https://hf-mirror.com/) to download the models: 106 | ```python 107 | export HF_ENDPOINT=https://hf-mirror.com 108 | huggingface-cli download --resume-download Xiaojiu-Z/EasyControl --local-dir checkpoints --local-dir-use-symlinks False 109 | ``` 110 | 111 | ## Usage 112 | Here's a basic example of using EasyControl: 113 | 114 | ### Model Initialization 115 | 116 | ```python 117 | import torch 118 | from PIL import Image 119 | from src.pipeline import FluxPipeline 120 | from src.transformer_flux import FluxTransformer2DModel 121 | from src.lora_helper import set_single_lora, set_multi_lora 122 | 123 | def clear_cache(transformer): 124 | for name, attn_processor in transformer.attn_processors.items(): 125 | attn_processor.bank_kv.clear() 126 | 127 | # Initialize model 128 | device = "cuda" 129 | base_path = "FLUX.1-dev" # Path to your base model 130 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device) 131 | transformer = FluxTransformer2DModel.from_pretrained( 132 | base_path, 133 | subfolder="transformer", 134 | torch_dtype=torch.bfloat16, 135 | device=device 136 | ) 137 | pipe.transformer = transformer 138 | pipe.to(device) 139 | 140 | # Load control models 141 | lora_path = "./checkpoints/models" 142 | control_models = { 143 | "canny": f"{lora_path}/canny.safetensors", 144 | "depth": f"{lora_path}/depth.safetensors", 145 | "hedsketch": f"{lora_path}/hedsketch.safetensors", 146 | "pose": f"{lora_path}/pose.safetensors", 147 | "seg": f"{lora_path}/seg.safetensors", 148 | "inpainting": f"{lora_path}/inpainting.safetensors", 149 | "subject": 
f"{lora_path}/subject.safetensors", 150 | } 151 | ``` 152 | 153 | ### Single Condition Control 154 | 155 | ```python 156 | # Single spatial condition control example 157 | path = control_models["canny"] 158 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512) 159 | 160 | # Generate image 161 | prompt = "A nice car on the beach" 162 | spatial_image = Image.open("./test_imgs/canny.png").convert("RGB") 163 | 164 | image = pipe( 165 | prompt, 166 | height=720, 167 | width=992, 168 | guidance_scale=3.5, 169 | num_inference_steps=25, 170 | max_sequence_length=512, 171 | generator=torch.Generator("cpu").manual_seed(5), 172 | spatial_images=[spatial_image], 173 | cond_size=512, 174 | ).images[0] 175 | 176 | # Clear cache after generation 177 | clear_cache(pipe.transformer) 178 | ``` 179 | 180 |
181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 |
Canny Condition | Generated Result
191 |
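The spatial LoRAs expect a preprocessed control map (here, the bundled `test_imgs/canny.png`) passed as `spatial_images`. If you want to build an edge map from your own photo, a minimal sketch using OpenCV is shown below; OpenCV is not listed in `requirements.txt`, and the thresholds and file name are illustrative assumptions rather than the preprocessing used to train the official checkpoints.

```python
import cv2
import numpy as np
from PIL import Image

def make_canny_condition(image_path: str, low: int = 100, high: int = 200) -> Image.Image:
    """Turn a photo into a 3-channel canny edge map usable as a spatial condition."""
    bgr = cv2.imread(image_path)                 # OpenCV loads images as BGR
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, low, high)           # single-channel edge map
    edges_rgb = np.stack([edges] * 3, axis=-1)   # replicate to RGB
    return Image.fromarray(edges_rgb)

# spatial_image = make_canny_condition("my_photo.jpg")  # then pass as spatial_images=[...]
```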
192 | 193 | ```python 194 | # Single subject condition control example 195 | path = control_models["subject"] 196 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512) 197 | 198 | # Generate image 199 | prompt = "A SKS in the library" 200 | subject_image = Image.open("./test_imgs/subject_0.png").convert("RGB") 201 | 202 | image = pipe( 203 | prompt, 204 | height=1024, 205 | width=1024, 206 | guidance_scale=3.5, 207 | num_inference_steps=25, 208 | max_sequence_length=512, 209 | generator=torch.Generator("cpu").manual_seed(5), 210 | subject_images=[subject_image], 211 | cond_size=512, 212 | ).images[0] 213 | 214 | # Clear cache after generation 215 | clear_cache(pipe.transformer) 216 | ``` 217 | 218 |
219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 |
Subject Condition | Generated Result
229 |
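Because each generation caches the condition key/value tensors inside the attention processors, call `clear_cache` after every call when reusing the pipeline in a loop. A small sketch that reuses `pipe`, `subject_image`, and `clear_cache` from the snippets above (the extra prompt and output file names are made up):

```python
prompts = [
    "A SKS in the library",
    "A SKS on a sunny beach",   # illustrative second prompt
]

for i, p in enumerate(prompts):
    result = pipe(
        p,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=25,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(5),
        subject_images=[subject_image],
        cond_size=512,
    ).images[0]
    result.save(f"subject_result_{i}.png")
    clear_cache(pipe.transformer)   # the cached condition K/V belongs to this run only
```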
230 | 231 | ### Multi-Condition Control 232 | 233 | ```python 234 | # Multi-condition control example 235 | paths = [control_models["subject"], control_models["inpainting"]] 236 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512) 237 | 238 | prompt = "A SKS on the car" 239 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")] 240 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")] 241 | 242 | image = pipe( 243 | prompt, 244 | height=1024, 245 | width=1024, 246 | guidance_scale=3.5, 247 | num_inference_steps=25, 248 | max_sequence_length=512, 249 | generator=torch.Generator("cpu").manual_seed(42), 250 | subject_images=subject_images, 251 | spatial_images=spatial_images, 252 | cond_size=512, 253 | ).images[0] 254 | 255 | # Clear cache after generation 256 | clear_cache(pipe.transformer) 257 | ``` 258 | 259 |
260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 |
Subject Condition | Inpainting Condition | Generated Result
272 |
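As noted in the Usage Tips below, `set_multi_lora` expects subject LoRA paths to come before spatial LoRA paths. A tiny helper that enforces this ordering is sketched here; the helper itself is illustrative and not part of the EasyControl API.

```python
def ordered_multi_lora(subject_paths, spatial_paths):
    """Return (paths, lora_weights) with subject LoRAs listed before spatial LoRAs."""
    paths = list(subject_paths) + list(spatial_paths)
    weights = [[1] for _ in paths]   # one weight list per LoRA, as in the examples above
    return paths, weights

paths, weights = ordered_multi_lora(
    subject_paths=[control_models["subject"]],
    spatial_paths=[control_models["inpainting"]],
)
set_multi_lora(pipe.transformer, paths, lora_weights=weights, cond_size=512)
```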
273 | 274 | ### Ghibli-Style Portrait Generation 275 | 276 | ```python 277 | import spaces 278 | import os 279 | import json 280 | import time 281 | import torch 282 | from PIL import Image 283 | from tqdm import tqdm 284 | import gradio as gr 285 | 286 | from safetensors.torch import save_file 287 | from src.pipeline import FluxPipeline 288 | from src.transformer_flux import FluxTransformer2DModel 289 | from src.lora_helper import set_single_lora, set_multi_lora, unset_lora 290 | 291 | # Initialize the image processor 292 | base_path = "black-forest-labs/FLUX.1-dev" 293 | lora_base_path = "./checkpoints/models" 294 | 295 | 296 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16) 297 | transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16) 298 | pipe.transformer = transformer 299 | pipe.to("cuda") 300 | 301 | def clear_cache(transformer): 302 | for name, attn_processor in transformer.attn_processors.items(): 303 | attn_processor.bank_kv.clear() 304 | 305 | # Define the Gradio interface 306 | @spaces.GPU() 307 | def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type): 308 | # Set the control type 309 | if control_type == "Ghibli": 310 | lora_path = os.path.join(lora_base_path, "Ghibli.safetensors") 311 | set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512) 312 | 313 | # Process the image 314 | spatial_imgs = [spatial_img] if spatial_img else [] 315 | image = pipe( 316 | prompt, 317 | height=int(height), 318 | width=int(width), 319 | guidance_scale=3.5, 320 | num_inference_steps=25, 321 | max_sequence_length=512, 322 | generator=torch.Generator("cpu").manual_seed(seed), 323 | subject_images=[], 324 | spatial_images=spatial_imgs, 325 | cond_size=512, 326 | ).images[0] 327 | clear_cache(pipe.transformer) 328 | return image 329 | 330 | # Define the Gradio interface components 331 | control_types = ["Ghibli"] 332 | 333 | 334 | # Create the Gradio Blocks interface 335 | with gr.Blocks() as demo: 336 | gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl") 337 | gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.") 338 | gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. 
For high-resolution (1024+), please set up your own environment.)") 339 | 340 | gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`") 341 | gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))") 342 | 343 | with gr.Tab("Ghibli Condition Generation"): 344 | with gr.Row(): 345 | with gr.Column(): 346 | prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration") 347 | spatial_img = gr.Image(label="Ghibli Image", type="pil") # 上传图像文件 348 | height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768) 349 | width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768) 350 | seed = gr.Number(label="Seed", value=42) 351 | control_type = gr.Dropdown(choices=control_types, label="Control Type") 352 | single_generate_btn = gr.Button("Generate Image") 353 | with gr.Column(): 354 | single_output_image = gr.Image(label="Generated Image") 355 | 356 | 357 | # Link the buttons to the functions 358 | single_generate_btn.click( 359 | single_condition_generate_image, 360 | inputs=[prompt, spatial_img, height, width, seed, control_type], 361 | outputs=single_output_image 362 | ) 363 | 364 | # Launch the Gradio app 365 | demo.queue().launch() 366 | ``` 367 | 368 |
369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 |
Input Image | Generated Result
379 |
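If you prefer to run the Ghibli LoRA without the Gradio UI, the sketch below mirrors the demo function above while reusing `pipe`, `set_single_lora`, and `clear_cache` from the Model Initialization section; the download call follows the Download section, and the input portrait path is a placeholder.

```python
from huggingface_hub import hf_hub_download
from PIL import Image

hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/Ghibli.safetensors", local_dir="./checkpoints")
set_single_lora(pipe.transformer, "./checkpoints/models/Ghibli.safetensors", lora_weights=[1], cond_size=512)

portrait = Image.open("my_portrait.jpg").convert("RGB")   # placeholder input photo
prompt = "Ghibli Studio style, Charming hand-drawn anime-style illustration"

image = pipe(
    prompt,
    height=768,
    width=768,
    guidance_scale=3.5,
    num_inference_steps=25,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(42),
    subject_images=[],
    spatial_images=[portrait],
    cond_size=512,
).images[0]
clear_cache(pipe.transformer)
image.save("ghibli_portrait.png")
```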
380 | 381 | ## Usage Tips 382 | 383 | - Clear cache after each generation using `clear_cache(pipe.transformer)` 384 | - For optimal performance: 385 | - Start with `guidance_scale=3.5` and adjust based on results 386 | - Use `num_inference_steps=25` for a good balance of quality and speed 387 | - When using set_multi_lora api, make sure the subject lora path(subject) is before the spatial lora path(canny, depth, hedsketch, etc.). 388 | 389 | ## Todo List 390 | 1. - [x] Inference code 391 | 2. - [x] Spatial Pre-trained weights 392 | 3. - [x] Subject Pre-trained weights 393 | 4. - [x] Training code 394 | 395 | 396 | ## Star History 397 | 398 | [![Star History Chart](https://api.star-history.com/svg?repos=Xiaojiu-z/EasyControl&type=Date)](https://star-history.com/#Xiaojiu-z/EasyControl&Date) 399 | 400 | ## Disclaimer 401 | The code of EasyControl is released under [Apache License](https://github.com/Xiaojiu-Z/EasyControl?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. Our released checkpoints are for research purposes only. Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users. 402 | 403 | ## Hiring/Cooperation 404 | - **2025-04-03**: If you are interested in EasyControl and beyond, or if you are interested in building 4o-like capacity (in a low-budget way, of course), we can collaborate offline in Shanghai/Beijing/Hong Kong/Singapore or online. 405 | Reach out: **jmliu1217@gmail.com (wechat: jiaming068870)** 406 | 407 | ## Citation 408 | ```bibtex 409 | @article{zhang2025easycontrol, 410 | title={EasyControl: Adding Efficient and Flexible Control for Diffusion Transformer}, 411 | author={Zhang, Yuxuan and Yuan, Yirui and Song, Yiren and Wang, Haofan and Liu, Jiaming}, 412 | journal={arXiv preprint arXiv:2503.07027}, 413 | year={2025} 414 | } 415 | ``` 416 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import spaces 2 | import os 3 | import json 4 | import time 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import gradio as gr 9 | 10 | from safetensors.torch import save_file 11 | from src.pipeline import FluxPipeline 12 | from src.transformer_flux import FluxTransformer2DModel 13 | from src.lora_helper import set_single_lora, set_multi_lora, unset_lora 14 | 15 | class ImageProcessor: 16 | def __init__(self, path): 17 | device = "cuda" 18 | self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device) 19 | transformer = FluxTransformer2DModel.from_pretrained(path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device) 20 | self.pipe.transformer = transformer 21 | self.pipe.to(device) 22 | 23 | def clear_cache(self, transformer): 24 | for name, attn_processor in transformer.attn_processors.items(): 25 | attn_processor.bank_kv.clear() 26 | 27 | @spaces.GPU() 28 | def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height=768, width=768, output_path=None, seed=42): 29 | image = self.pipe( 30 | prompt, 31 | height=int(height), 32 | width=int(width), 33 | guidance_scale=3.5, 34 | num_inference_steps=25, 35 | max_sequence_length=512, 36 | generator=torch.Generator("cpu").manual_seed(seed), 37 | subject_images=subject_imgs, 38 | spatial_images=spatial_imgs, 39 | cond_size=512, 40 | ).images[0] 41 | 
self.clear_cache(self.pipe.transformer) 42 | if output_path: 43 | image.save(output_path) 44 | return image 45 | 46 | # Initialize the image processor 47 | base_path = "black-forest-labs/FLUX.1-dev" 48 | lora_base_path = "EasyControl/models" 49 | style_lora_base_path = "Shakker-Labs" 50 | processor = ImageProcessor(base_path) 51 | 52 | # Define the Gradio interface 53 | def single_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora=None): 54 | # Set the control type 55 | if control_type == "subject": 56 | lora_path = os.path.join(lora_base_path, "subject.safetensors") 57 | elif control_type == "depth": 58 | lora_path = os.path.join(lora_base_path, "depth.safetensors") 59 | elif control_type == "seg": 60 | lora_path = os.path.join(lora_base_path, "seg.safetensors") 61 | elif control_type == "pose": 62 | lora_path = os.path.join(lora_base_path, "pose.safetensors") 63 | elif control_type == "inpainting": 64 | lora_path = os.path.join(lora_base_path, "inpainting.safetensors") 65 | elif control_type == "hedsketch": 66 | lora_path = os.path.join(lora_base_path, "hedsketch.safetensors") 67 | elif control_type == "canny": 68 | lora_path = os.path.join(lora_base_path, "canny.safetensors") 69 | set_single_lora(processor.pipe.transformer, lora_path, lora_weights=[1], cond_size=512) 70 | 71 | # Set the style LoRA 72 | if style_lora=="None": 73 | pass 74 | else: 75 | if style_lora == "Simple_Sketch": 76 | processor.pipe.unload_lora_weights() 77 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch") 78 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors") 79 | if style_lora == "Text_Poster": 80 | processor.pipe.unload_lora_weights() 81 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster") 82 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors") 83 | if style_lora == "Vector_Style": 84 | processor.pipe.unload_lora_weights() 85 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey") 86 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors") 87 | 88 | # Process the image 89 | subject_imgs = [subject_img] if subject_img else [] 90 | spatial_imgs = [spatial_img] if spatial_img else [] 91 | image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed) 92 | return image 93 | 94 | # Define the Gradio interface 95 | def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed): 96 | subject_path = os.path.join(lora_base_path, "subject.safetensors") 97 | inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors") 98 | set_multi_lora(processor.pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512) 99 | 100 | # Process the image 101 | subject_imgs = [subject_img] if subject_img else [] 102 | spatial_imgs = [spatial_img] if spatial_img else [] 103 | image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed) 104 | return image 105 | 106 | # Define the Gradio interface components 107 | control_types = ["subject", "depth", "pose", "inpainting", "hedsketch", "seg", "canny"] 108 | style_loras = ["Simple_Sketch", "Text_Poster", "Vector_Style", "None"] 109 | 110 | # Example 
data 111 | single_examples = [ 112 | ["A SKS in the library", Image.open("./test_imgs/subject1.png"), None, 1024, 1024, 5, "subject", None], 113 | ["In a picturesque village, a narrow cobblestone street with rustic stone buildings, colorful blinds, and lush green spaces, a cartoon man drawn with simple lines and solid colors stands in the foreground, wearing a red shirt, beige work pants, and brown shoes, carrying a strap on his shoulder. The scene features warm and enticing colors, a pleasant fusion of nature and architecture, and the camera's perspective on the street clearly shows the charming and quaint environment., Integrating elements of reality and cartoon.", None, Image.open("./test_imgs/spatial1.png"), 1024, 1024, 1, "pose", "Vector_Style"], 114 | ] 115 | multi_examples = [ 116 | ["A SKS on the car", Image.open("./test_imgs/subject2.png"), Image.open("./test_imgs/spatial2.png"), 1024, 1024, 7], 117 | ] 118 | 119 | 120 | # Create the Gradio Blocks interface 121 | with gr.Blocks() as demo: 122 | gr.Markdown("# Image Generation with EasyControl") 123 | gr.Markdown("Generate images using EasyControl with different control types and style LoRAs.") 124 | 125 | with gr.Tab("Single Condition Generation"): 126 | with gr.Row(): 127 | with gr.Column(): 128 | prompt = gr.Textbox(label="Prompt") 129 | subject_img = gr.Image(label="Subject Image", type="pil") # 上传图像文件 130 | spatial_img = gr.Image(label="Spatial Image", type="pil") # 上传图像文件 131 | height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768) 132 | width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768) 133 | seed = gr.Number(label="Seed", value=42) 134 | control_type = gr.Dropdown(choices=control_types, label="Control Type") 135 | style_lora = gr.Dropdown(choices=style_loras, label="Style LoRA") 136 | single_generate_btn = gr.Button("Generate Image") 137 | with gr.Column(): 138 | single_output_image = gr.Image(label="Generated Image") 139 | 140 | # Add examples for Single Condition Generation 141 | gr.Examples( 142 | examples=single_examples, 143 | inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora], 144 | outputs=single_output_image, 145 | fn=single_condition_generate_image, 146 | cache_examples=False, # 缓存示例结果以加快加载速度 147 | label="Single Condition Examples" 148 | ) 149 | 150 | 151 | with gr.Tab("Multi-Condition Generation"): 152 | with gr.Row(): 153 | with gr.Column(): 154 | multi_prompt = gr.Textbox(label="Prompt") 155 | multi_subject_img = gr.Image(label="Subject Image", type="pil") # 上传图像文件 156 | multi_spatial_img = gr.Image(label="Spatial Image", type="pil") # 上传图像文件 157 | multi_height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768) 158 | multi_width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768) 159 | multi_seed = gr.Number(label="Seed", value=42) 160 | multi_generate_btn = gr.Button("Generate Image") 161 | with gr.Column(): 162 | multi_output_image = gr.Image(label="Generated Image") 163 | 164 | # Add examples for Multi-Condition Generation 165 | gr.Examples( 166 | examples=multi_examples, 167 | inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed], 168 | outputs=multi_output_image, 169 | fn=multi_condition_generate_image, 170 | cache_examples=False, # 缓存示例结果以加快加载速度 171 | label="Multi-Condition Examples" 172 | ) 173 | 174 | 175 | # Link the buttons to the functions 176 | single_generate_btn.click( 177 | single_condition_generate_image, 178 | 
inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora], 179 | outputs=single_output_image 180 | ) 181 | multi_generate_btn.click( 182 | multi_condition_generate_image, 183 | inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed], 184 | outputs=multi_output_image 185 | ) 186 | 187 | # Launch the Gradio app 188 | demo.queue().launch() 189 | -------------------------------------------------------------------------------- /assets/CFG-Zero/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image.webp -------------------------------------------------------------------------------- /assets/CFG-Zero/image_CFG.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image_CFG.webp -------------------------------------------------------------------------------- /assets/CFG-Zero/image_CFG_zero_star.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image_CFG_zero_star.webp -------------------------------------------------------------------------------- /assets/example1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example1.jpeg -------------------------------------------------------------------------------- /assets/example2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example2.jpeg -------------------------------------------------------------------------------- /assets/example3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example3.jpeg -------------------------------------------------------------------------------- /assets/example4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example4.jpeg -------------------------------------------------------------------------------- /assets/img.md: -------------------------------------------------------------------------------- 1 | put imgs here! 
2 | -------------------------------------------------------------------------------- /assets/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/method.jpg -------------------------------------------------------------------------------- /assets/result_canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_canny.png -------------------------------------------------------------------------------- /assets/result_ghibli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_ghibli.png -------------------------------------------------------------------------------- /assets/result_subject.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_subject.png -------------------------------------------------------------------------------- /assets/result_subject_inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_subject_inpainting.png -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/teaser.jpg -------------------------------------------------------------------------------- /infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!nvidia-smi" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import json\n", 20 | "import time\n", 21 | "import torch\n", 22 | "from PIL import Image\n", 23 | "from tqdm import tqdm\n", 24 | "\n", 25 | "from safetensors.torch import save_file\n", 26 | "from src.pipeline import FluxPipeline\n", 27 | "from src.transformer_flux import FluxTransformer2DModel\n", 28 | "from src.lora_helper import set_single_lora, set_multi_lora, unset_lora\n", 29 | "\n", 30 | "torch.cuda.set_device(1)\n", 31 | "\n", 32 | "class ImageProcessor:\n", 33 | " def __init__(self, path):\n", 34 | " device = \"cuda\"\n", 35 | " self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)\n", 36 | " transformer = FluxTransformer2DModel.from_pretrained(path, subfolder=\"transformer\",torch_dtype=torch.bfloat16, device=device)\n", 37 | " self.pipe.transformer = transformer\n", 38 | " self.pipe.to(device)\n", 39 | " \n", 40 | " def clear_cache(self, transformer):\n", 41 | " for name, attn_processor in transformer.attn_processors.items():\n", 42 | " attn_processor.bank_kv.clear()\n", 43 | " \n", 44 | " def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height = 768, width = 768, output_path=None, seed=42):\n", 45 | " if 
len(spatial_imgs)>0:\n", 46 | " spatial_ls = [Image.open(image_path).convert(\"RGB\") for image_path in spatial_imgs]\n", 47 | " else:\n", 48 | " spatial_ls = []\n", 49 | " if len(subject_imgs)>0:\n", 50 | " subject_ls = [Image.open(image_path).convert(\"RGB\") for image_path in subject_imgs]\n", 51 | " else:\n", 52 | " subject_ls = []\n", 53 | "\n", 54 | " prompt = prompt\n", 55 | " image = self.pipe(\n", 56 | " prompt,\n", 57 | " height=int(height),\n", 58 | " width=int(width),\n", 59 | " guidance_scale=3.5,\n", 60 | " num_inference_steps=25,\n", 61 | " max_sequence_length=512,\n", 62 | " generator=torch.Generator(\"cpu\").manual_seed(seed), \n", 63 | " subject_images=subject_ls,\n", 64 | " spatial_images=spatial_ls,\n", 65 | " cond_size=512,\n", 66 | " ).images[0]\n", 67 | " self.clear_cache(self.pipe.transformer)\n", 68 | " image.show()\n", 69 | " if output_path:\n", 70 | " image.save(output_path)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "### models path ###\n", 80 | "# spatial model\n", 81 | "base_path = \"FLUX.1-dev\" # your flux model path\n", 82 | "lora_path = \"./models\" # your lora folder path\n", 83 | "canny_path = lora_path + \"/canny.safetensors\"\n", 84 | "depth_path = lora_path + \"/depth.safetensors\"\n", 85 | "openpose_path = lora_path + \"/pose.safetensors\"\n", 86 | "inpainting_path = lora_path + \"/inpainting.safetensors\"\n", 87 | "hedsketch_path = lora_path + \"/hedsketch.safetensors\"\n", 88 | "seg_path = lora_path + \"/seg.safetensors\"\n", 89 | "# subject model\n", 90 | "subject_path = lora_path + \"/subject.safetensors\"\n", 91 | "\n", 92 | "# init image processor\n", 93 | "processor = ImageProcessor(base_path)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "for single condition" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# set lora\n", 110 | "path = depth_path # single control model path\n", 111 | "lora_weights=[1] # lora weights for each control model\n", 112 | "set_single_lora(processor.pipe.transformer, path, lora_weights=lora_weights,cond_size=512)\n", 113 | "\n", 114 | "# infer\n", 115 | "prompt='a cafe bar'\n", 116 | "spatial_imgs=[\"./test_imgs/depth.png\"]\n", 117 | "height = 1024\n", 118 | "width = 1024\n", 119 | "processor.process_image(prompt=prompt, subject_imgs=[], spatial_imgs=spatial_imgs, height=height, width=width, seed=11)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "for multi condition" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# set lora\n", 136 | "paths = [subject_path, inpainting_path] # multi control model paths\n", 137 | "lora_weights=[[1],[1]] # lora weights for each control model\n", 138 | "set_multi_lora(processor.pipe.transformer, paths, lora_weights=lora_weights, cond_size=512)\n", 139 | "\n", 140 | "# infer\n", 141 | "prompt='A SKS on the car'\n", 142 | "spatial_imgs=[\"./test_imgs/subject_1.png\"]\n", 143 | "subject_imgs=[\"./test_imgs/inpainting.png\"]\n", 144 | "height = 1024\n", 145 | "width = 1024\n", 146 | "processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=42)" 147 | ] 148 | } 149 | ], 150 | "metadata": { 151 | "kernelspec": { 152 | 
"display_name": "zyxdit", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | "file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.10.16" 167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 2 171 | } 172 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from src.pipeline import FluxPipeline 4 | from src.transformer_flux import FluxTransformer2DModel 5 | from src.lora_helper import set_single_lora, set_multi_lora 6 | 7 | from huggingface_hub import hf_hub_download 8 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./") 9 | 10 | def clear_cache(transformer): 11 | for name, attn_processor in transformer.attn_processors.items(): 12 | attn_processor.bank_kv.clear() 13 | 14 | # Initialize model 15 | device = "cuda" 16 | base_path = "black-forest-labs/FLUX.1-dev" # Path to your base model 17 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device) 18 | transformer = FluxTransformer2DModel.from_pretrained( 19 | base_path, 20 | subfolder="transformer", 21 | torch_dtype=torch.bfloat16, 22 | device=device 23 | ) 24 | pipe.transformer = transformer 25 | pipe.to(device) 26 | 27 | # Load control models 28 | lora_path = "./models" 29 | control_models = { 30 | "canny": f"{lora_path}/canny.safetensors", 31 | "depth": f"{lora_path}/depth.safetensors", 32 | "hedsketch": f"{lora_path}/hedsketch.safetensors", 33 | "pose": f"{lora_path}/pose.safetensors", 34 | "seg": f"{lora_path}/seg.safetensors", 35 | "inpainting": f"{lora_path}/inpainting.safetensors", 36 | "subject": f"{lora_path}/subject.safetensors", 37 | } 38 | 39 | # Single spatial condition control example 40 | path = control_models["canny"] 41 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512) 42 | 43 | # Generate image 44 | prompt = "A nice car on the beach" 45 | 46 | spatial_image = Image.open("./test_imgs/canny.png") 47 | 48 | image = pipe( 49 | prompt, 50 | height=768, 51 | width=1024, 52 | guidance_scale=3.5, 53 | num_inference_steps=25, 54 | max_sequence_length=512, 55 | generator=torch.Generator("cpu").manual_seed(5), 56 | spatial_images=[spatial_image], 57 | subject_images=[], 58 | cond_size=512, 59 | ).images[0] 60 | 61 | # Clear cache after generation 62 | clear_cache(pipe.transformer) 63 | 64 | image.save("output.png") 65 | -------------------------------------------------------------------------------- /infer_multi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from src.pipeline import FluxPipeline 4 | from src.transformer_flux import FluxTransformer2DModel 5 | from src.lora_helper import set_single_lora, set_multi_lora 6 | 7 | from huggingface_hub import hf_hub_download 8 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./") 9 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/inpainting.safetensors", local_dir="./") 10 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/subject.safetensors", local_dir="./") 11 | 12 | def clear_cache(transformer): 
13 | for name, attn_processor in transformer.attn_processors.items(): 14 | attn_processor.bank_kv.clear() 15 | 16 | # Initialize model 17 | device = "cuda" 18 | base_path = "black-forest-labs/FLUX.1-dev" # Path to your base model 19 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device) 20 | transformer = FluxTransformer2DModel.from_pretrained( 21 | base_path, 22 | subfolder="transformer", 23 | torch_dtype=torch.bfloat16, 24 | device=device 25 | ) 26 | pipe.transformer = transformer 27 | pipe.to(device) 28 | 29 | # Load control models 30 | lora_path = "./models" 31 | control_models = { 32 | "canny": f"{lora_path}/canny.safetensors", 33 | "depth": f"{lora_path}/depth.safetensors", 34 | "hedsketch": f"{lora_path}/hedsketch.safetensors", 35 | "pose": f"{lora_path}/pose.safetensors", 36 | "seg": f"{lora_path}/seg.safetensors", 37 | "inpainting": f"{lora_path}/inpainting.safetensors", 38 | "subject": f"{lora_path}/subject.safetensors", 39 | } 40 | 41 | # Single spatial condition control example 42 | path = control_models["canny"] 43 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512) 44 | # Multi-condition control example 45 | paths = [control_models["subject"], control_models["inpainting"]] 46 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512) 47 | 48 | prompt = "A SKS on the car" 49 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")] 50 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")] 51 | 52 | image = pipe( 53 | prompt, 54 | height=1024, 55 | width=1024, 56 | guidance_scale=3.5, 57 | num_inference_steps=25, 58 | max_sequence_length=512, 59 | generator=torch.Generator("cpu").manual_seed(42), 60 | subject_images=subject_images, 61 | spatial_images=spatial_images, 62 | cond_size=512, 63 | ).images[0] 64 | 65 | image.save("output_multi.png") 66 | 67 | # Clear cache after generation 68 | clear_cache(pipe.transformer) 69 | 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu124 2 | torch 3 | torchvision 4 | torchaudio==2.3.1 5 | diffusers==0.32.2 6 | easydict==1.13 7 | einops==0.8.1 8 | peft==0.14.0 9 | pillow==11.0.0 10 | protobuf==5.29.3 11 | requests==2.32.3 12 | safetensors==0.5.2 13 | sentencepiece==0.2.0 14 | spaces==0.34.1 15 | transformers==4.49.0 16 | datasets 17 | wandb 18 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/src/__init__.py -------------------------------------------------------------------------------- /src/layers_cache.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from typing import Callable, List, Optional, Tuple, Union 4 | from einops import rearrange 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from diffusers.models.attention_processor import Attention 10 | 11 | class LoRALinearLayer(nn.Module): 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | rank: int = 4, 17 | network_alpha: Optional[float] = None, 18 | device: Optional[Union[torch.device, 
str]] = None, 19 | dtype: Optional[torch.dtype] = None, 20 | cond_width=512, 21 | cond_height=512, 22 | number=0, 23 | n_loras=1 24 | ): 25 | super().__init__() 26 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 27 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 28 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 29 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 30 | self.network_alpha = network_alpha 31 | self.rank = rank 32 | self.out_features = out_features 33 | self.in_features = in_features 34 | 35 | nn.init.normal_(self.down.weight, std=1 / rank) 36 | nn.init.zeros_(self.up.weight) 37 | 38 | self.cond_height = cond_height 39 | self.cond_width = cond_width 40 | self.number = number 41 | self.n_loras = n_loras 42 | 43 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 44 | orig_dtype = hidden_states.dtype 45 | dtype = self.down.weight.dtype 46 | 47 | #### 48 | batch_size = hidden_states.shape[0] 49 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 50 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 51 | shape = (batch_size, hidden_states.shape[1], 3072) 52 | mask = torch.ones(shape, device=hidden_states.device, dtype=dtype) 53 | mask[:, :block_size+self.number*cond_size, :] = 0 54 | mask[:, block_size+(self.number+1)*cond_size:, :] = 0 55 | hidden_states = mask * hidden_states 56 | #### 57 | 58 | down_hidden_states = self.down(hidden_states.to(dtype)) 59 | up_hidden_states = self.up(down_hidden_states) 60 | 61 | if self.network_alpha is not None: 62 | up_hidden_states *= self.network_alpha / self.rank 63 | 64 | return up_hidden_states.to(orig_dtype) 65 | 66 | 67 | class MultiSingleStreamBlockLoraProcessor(nn.Module): 68 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 69 | super().__init__() 70 | # Initialize a list to store the LoRA layers 71 | self.n_loras = n_loras 72 | self.cond_width = cond_width 73 | self.cond_height = cond_height 74 | 75 | self.q_loras = nn.ModuleList([ 76 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 77 | for i in range(n_loras) 78 | ]) 79 | self.k_loras = nn.ModuleList([ 80 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 81 | for i in range(n_loras) 82 | ]) 83 | self.v_loras = nn.ModuleList([ 84 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 85 | for i in range(n_loras) 86 | ]) 87 | self.lora_weights = lora_weights 88 | self.bank_attn = None 89 | self.bank_kv = [] 90 | 91 | 92 | def __call__(self, 93 | attn: Attention, 94 | hidden_states: torch.FloatTensor, 95 | encoder_hidden_states: torch.FloatTensor = None, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | image_rotary_emb: Optional[torch.Tensor] = None, 98 | use_cond = False 99 | ) -> torch.FloatTensor: 100 | 101 | batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 102 | scaled_seq_len = hidden_states.shape[1] 103 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 
64 104 | block_size = scaled_seq_len - cond_size * self.n_loras 105 | scaled_cond_size = cond_size 106 | scaled_block_size = block_size 107 | 108 | if len(self.bank_kv)== 0: 109 | cache = True 110 | else: 111 | cache = False 112 | 113 | if cache: 114 | query = attn.to_q(hidden_states) 115 | key = attn.to_k(hidden_states) 116 | value = attn.to_v(hidden_states) 117 | for i in range(self.n_loras): 118 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 119 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 120 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 121 | 122 | inner_dim = key.shape[-1] 123 | head_dim = inner_dim // attn.heads 124 | 125 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 126 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 127 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 128 | 129 | self.bank_kv.append(key[:, :, scaled_block_size:, :]) 130 | self.bank_kv.append(value[:, :, scaled_block_size:, :]) 131 | 132 | if attn.norm_q is not None: 133 | query = attn.norm_q(query) 134 | if attn.norm_k is not None: 135 | key = attn.norm_k(key) 136 | 137 | if image_rotary_emb is not None: 138 | from diffusers.models.embeddings import apply_rotary_emb 139 | query = apply_rotary_emb(query, image_rotary_emb) 140 | key = apply_rotary_emb(key, image_rotary_emb) 141 | 142 | num_cond_blocks = self.n_loras 143 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 144 | mask[ :scaled_block_size, :] = 0 # First block_size row 145 | for i in range(num_cond_blocks): 146 | start = i * scaled_cond_size + scaled_block_size 147 | end = (i + 1) * scaled_cond_size + scaled_block_size 148 | mask[start:end, start:end] = 0 # Diagonal blocks 149 | mask = mask * -1e20 150 | mask = mask.to(query.dtype) 151 | 152 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 153 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :] 154 | 155 | else: 156 | query = attn.to_q(hidden_states) 157 | key = attn.to_k(hidden_states) 158 | value = attn.to_v(hidden_states) 159 | 160 | inner_dim = query.shape[-1] 161 | head_dim = inner_dim // attn.heads 162 | 163 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 164 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 165 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 166 | 167 | key = torch.concat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2) 168 | value = torch.concat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2) 169 | 170 | if attn.norm_q is not None: 171 | query = attn.norm_q(query) 172 | if attn.norm_k is not None: 173 | key = attn.norm_k(key) 174 | 175 | if image_rotary_emb is not None: 176 | from diffusers.models.embeddings import apply_rotary_emb 177 | query = apply_rotary_emb(query, image_rotary_emb) 178 | key = apply_rotary_emb(key, image_rotary_emb) 179 | 180 | query = query[:, :, :scaled_block_size, :] 181 | 182 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None) 183 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2) 184 | 185 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 186 | hidden_states = hidden_states.to(query.dtype) 187 | 188 | cond_hidden_states = hidden_states[:, block_size:,:] 189 | 
hidden_states = hidden_states[:, : block_size,:] 190 | 191 | return hidden_states if not use_cond else (hidden_states, cond_hidden_states) 192 | 193 | 194 | class MultiDoubleStreamBlockLoraProcessor(nn.Module): 195 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 196 | super().__init__() 197 | 198 | # Initialize a list to store the LoRA layers 199 | self.n_loras = n_loras 200 | self.cond_width = cond_width 201 | self.cond_height = cond_height 202 | self.q_loras = nn.ModuleList([ 203 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 204 | for i in range(n_loras) 205 | ]) 206 | self.k_loras = nn.ModuleList([ 207 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 208 | for i in range(n_loras) 209 | ]) 210 | self.v_loras = nn.ModuleList([ 211 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 212 | for i in range(n_loras) 213 | ]) 214 | self.proj_loras = nn.ModuleList([ 215 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 216 | for i in range(n_loras) 217 | ]) 218 | self.lora_weights = lora_weights 219 | self.bank_attn = None 220 | self.bank_kv = [] 221 | 222 | 223 | def __call__(self, 224 | attn: Attention, 225 | hidden_states: torch.FloatTensor, 226 | encoder_hidden_states: torch.FloatTensor = None, 227 | attention_mask: Optional[torch.FloatTensor] = None, 228 | image_rotary_emb: Optional[torch.Tensor] = None, 229 | use_cond=False, 230 | ) -> torch.FloatTensor: 231 | 232 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 233 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 234 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 235 | scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1] 236 | scaled_cond_size = cond_size 237 | scaled_block_size = scaled_seq_len - scaled_cond_size * self.n_loras 238 | 239 | # `context` projections. 
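        # The text (encoder) stream is projected separately below and concatenated in front of the
        # image + condition tokens before attention, so the sequence is [text | image | conditions];
        # `scaled_block_size` covers the text + image part, and the trailing condition tokens are the
        # ones whose keys/values get cached in `bank_kv` on the first call and reused afterwards.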
240 | inner_dim = 3072 241 | head_dim = inner_dim // attn.heads 242 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 243 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 244 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 245 | 246 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 247 | batch_size, -1, attn.heads, head_dim 248 | ).transpose(1, 2) 249 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 250 | batch_size, -1, attn.heads, head_dim 251 | ).transpose(1, 2) 252 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 253 | batch_size, -1, attn.heads, head_dim 254 | ).transpose(1, 2) 255 | 256 | if attn.norm_added_q is not None: 257 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 258 | if attn.norm_added_k is not None: 259 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 260 | 261 | if len(self.bank_kv)== 0: 262 | cache = True 263 | else: 264 | cache = False 265 | 266 | if cache: 267 | 268 | query = attn.to_q(hidden_states) 269 | key = attn.to_k(hidden_states) 270 | value = attn.to_v(hidden_states) 271 | for i in range(self.n_loras): 272 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 273 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 274 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 275 | 276 | inner_dim = key.shape[-1] 277 | head_dim = inner_dim // attn.heads 278 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 279 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 280 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 281 | 282 | 283 | self.bank_kv.append(key[:, :, block_size:, :]) 284 | self.bank_kv.append(value[:, :, block_size:, :]) 285 | 286 | if attn.norm_q is not None: 287 | query = attn.norm_q(query) 288 | if attn.norm_k is not None: 289 | key = attn.norm_k(key) 290 | 291 | # attention 292 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 293 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 294 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 295 | 296 | if image_rotary_emb is not None: 297 | from diffusers.models.embeddings import apply_rotary_emb 298 | query = apply_rotary_emb(query, image_rotary_emb) 299 | key = apply_rotary_emb(key, image_rotary_emb) 300 | 301 | num_cond_blocks = self.n_loras 302 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 303 | mask[ :scaled_block_size, :] = 0 # First block_size row 304 | for i in range(num_cond_blocks): 305 | start = i * scaled_cond_size + scaled_block_size 306 | end = (i + 1) * scaled_cond_size + scaled_block_size 307 | mask[start:end, start:end] = 0 # Diagonal blocks 308 | mask = mask * -1e20 309 | mask = mask.to(query.dtype) 310 | 311 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 312 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :] 313 | 314 | else: 315 | query = attn.to_q(hidden_states) 316 | key = attn.to_k(hidden_states) 317 | value = attn.to_v(hidden_states) 318 | 319 | inner_dim = query.shape[-1] 320 | head_dim = inner_dim // attn.heads 321 | 322 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 323 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 
324 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 325 | 326 | key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2) 327 | value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2) 328 | 329 | if attn.norm_q is not None: 330 | query = attn.norm_q(query) 331 | if attn.norm_k is not None: 332 | key = attn.norm_k(key) 333 | 334 | # attention 335 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 336 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 337 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 338 | 339 | if image_rotary_emb is not None: 340 | from diffusers.models.embeddings import apply_rotary_emb 341 | query = apply_rotary_emb(query, image_rotary_emb) 342 | key = apply_rotary_emb(key, image_rotary_emb) 343 | 344 | query = query[:, :, :scaled_block_size, :] 345 | 346 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None) 347 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2) 348 | 349 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 350 | hidden_states = hidden_states.to(query.dtype) 351 | 352 | encoder_hidden_states, hidden_states = ( 353 | hidden_states[:, : encoder_hidden_states.shape[1]], 354 | hidden_states[:, encoder_hidden_states.shape[1] :], 355 | ) 356 | 357 | # Linear projection (with LoRA weight applied to each proj layer) 358 | hidden_states = attn.to_out[0](hidden_states) 359 | for i in range(self.n_loras): 360 | hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states) 361 | # dropout 362 | hidden_states = attn.to_out[1](hidden_states) 363 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 364 | 365 | cond_hidden_states = hidden_states[:, block_size:,:] 366 | hidden_states = hidden_states[:, :block_size,:] 367 | 368 | return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states) -------------------------------------------------------------------------------- /src/lora_helper.py: -------------------------------------------------------------------------------- 1 | from diffusers.models.attention_processor import FluxAttnProcessor2_0 2 | from safetensors import safe_open 3 | import re 4 | import torch 5 | from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor 6 | 7 | device = "cuda" 8 | 9 | def load_safetensors(path): 10 | tensors = {} 11 | with safe_open(path, framework="pt", device="cpu") as f: 12 | for key in f.keys(): 13 | tensors[key] = f.get_tensor(key) 14 | return tensors 15 | 16 | def get_lora_rank(checkpoint): 17 | for k in checkpoint.keys(): 18 | if k.endswith(".down.weight"): 19 | return checkpoint[k].shape[0] 20 | 21 | def load_checkpoint(local_path): 22 | if local_path is not None: 23 | if '.safetensors' in local_path: 24 | print(f"Loading .safetensors checkpoint from {local_path}") 25 | checkpoint = load_safetensors(local_path) 26 | else: 27 | print(f"Loading checkpoint from {local_path}") 28 | checkpoint = torch.load(local_path, map_location='cpu') 29 | return checkpoint 30 | 31 | def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size): 32 | number = len(lora_weights) 33 | ranks = [get_lora_rank(checkpoint) for _ in range(number)] 34 | lora_attn_procs = {} 35 | double_blocks_idx = list(range(19)) 36 | single_blocks_idx = list(range(38)) 37 | for 
name, attn_processor in transformer.attn_processors.items(): 38 | match = re.search(r'\.(\d+)\.', name) 39 | if match: 40 | layer_index = int(match.group(1)) 41 | 42 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 43 | 44 | lora_state_dicts = {} 45 | for key, value in checkpoint.items(): 46 | # Match based on the layer index in the key (assuming the key contains layer index) 47 | if re.search(r'\.(\d+)\.', key): 48 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 49 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 50 | lora_state_dicts[key] = value 51 | 52 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 53 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 54 | ) 55 | 56 | # Load the weights from the checkpoint dictionary into the corresponding layers 57 | for n in range(number): 58 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 59 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 60 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 61 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 62 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 63 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 64 | lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None) 65 | lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None) 66 | lora_attn_procs[name].to(device) 67 | 68 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 69 | 70 | lora_state_dicts = {} 71 | for key, value in checkpoint.items(): 72 | # Match based on the layer index in the key (assuming the key contains layer index) 73 | if re.search(r'\.(\d+)\.', key): 74 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 75 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"): 76 | lora_state_dicts[key] = value 77 | 78 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor( 79 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 80 | ) 81 | # Load the weights from the checkpoint dictionary into the corresponding layers 82 | for n in range(number): 83 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 84 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 85 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 86 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 87 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 88 | lora_attn_procs[name].v_loras[n].up.weight.data = 
lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 89 | lora_attn_procs[name].to(device) 90 | else: 91 | lora_attn_procs[name] = FluxAttnProcessor2_0() 92 | 93 | transformer.set_attn_processor(lora_attn_procs) 94 | 95 | 96 | def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size): 97 | ck_number = len(checkpoints) 98 | cond_lora_number = [len(ls) for ls in lora_weights] 99 | cond_number = sum(cond_lora_number) 100 | ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints] 101 | multi_lora_weight = [] 102 | for ls in lora_weights: 103 | for n in ls: 104 | multi_lora_weight.append(n) 105 | 106 | lora_attn_procs = {} 107 | double_blocks_idx = list(range(19)) 108 | single_blocks_idx = list(range(38)) 109 | for name, attn_processor in transformer.attn_processors.items(): 110 | match = re.search(r'\.(\d+)\.', name) 111 | if match: 112 | layer_index = int(match.group(1)) 113 | 114 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 115 | lora_state_dicts = [{} for _ in range(ck_number)] 116 | for idx, checkpoint in enumerate(checkpoints): 117 | for key, value in checkpoint.items(): 118 | # Match based on the layer index in the key (assuming the key contains layer index) 119 | if re.search(r'\.(\d+)\.', key): 120 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 121 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 122 | lora_state_dicts[idx][key] = value 123 | 124 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 125 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number 126 | ) 127 | 128 | # Load the weights from the checkpoint dictionary into the corresponding layers 129 | num = 0 130 | for idx in range(ck_number): 131 | for n in range(cond_lora_number[idx]): 132 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None) 133 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None) 134 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None) 135 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None) 136 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None) 137 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None) 138 | lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None) 139 | lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None) 140 | lora_attn_procs[name].to(device) 141 | num += 1 142 | 143 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 144 | 145 | lora_state_dicts = [{} for _ in range(ck_number)] 146 | for idx, checkpoint in enumerate(checkpoints): 147 | for key, value in checkpoint.items(): 148 | # Match based on the layer index in the key (assuming the key contains layer index) 149 | if re.search(r'\.(\d+)\.', key): 150 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 151 | if checkpoint_layer_index == layer_index and 
key.startswith("single_transformer_blocks"): 152 | lora_state_dicts[idx][key] = value 153 | 154 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor( 155 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number 156 | ) 157 | # Load the weights from the checkpoint dictionary into the corresponding layers 158 | num = 0 159 | for idx in range(ck_number): 160 | for n in range(cond_lora_number[idx]): 161 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None) 162 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None) 163 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None) 164 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None) 165 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None) 166 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None) 167 | lora_attn_procs[name].to(device) 168 | num += 1 169 | 170 | else: 171 | lora_attn_procs[name] = FluxAttnProcessor2_0() 172 | 173 | transformer.set_attn_processor(lora_attn_procs) 174 | 175 | 176 | def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512): 177 | checkpoint = load_checkpoint(local_path) 178 | update_model_with_lora(checkpoint, lora_weights, transformer, cond_size) 179 | 180 | def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512): 181 | checkpoints = [load_checkpoint(local_path) for local_path in local_paths] 182 | update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size) 183 | 184 | def unset_lora(transformer): 185 | lora_attn_procs = {} 186 | for name, attn_processor in transformer.attn_processors.items(): 187 | lora_attn_procs[name] = FluxAttnProcessor2_0() 188 | transformer.set_attn_processor(lora_attn_procs) 189 | 190 | 191 | ''' 192 | unset_lora(pipe.transformer) 193 | lora_path = "./lora.safetensors" 194 | lora_weights = [1, 1] 195 | set_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512) 196 | ''' -------------------------------------------------------------------------------- /src/transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.attention import FeedForward 11 | from diffusers.models.attention_processor import ( 12 | Attention, 13 | AttentionProcessor, 14 | FluxAttnProcessor2_0, 15 | FluxAttnProcessor2_0_NPU, 16 | FusedFluxAttnProcessor2_0, 17 | ) 18 | from diffusers.models.modeling_utils import ModelMixin 19 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 20 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 21 | from diffusers.utils.import_utils import 
is_torch_npu_available 22 | from diffusers.utils.torch_utils import maybe_allow_in_graph 23 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 24 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | @maybe_allow_in_graph 29 | class FluxSingleTransformerBlock(nn.Module): 30 | 31 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 32 | super().__init__() 33 | self.mlp_hidden_dim = int(dim * mlp_ratio) 34 | 35 | self.norm = AdaLayerNormZeroSingle(dim) 36 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 37 | self.act_mlp = nn.GELU(approximate="tanh") 38 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 39 | 40 | if is_torch_npu_available(): 41 | processor = FluxAttnProcessor2_0_NPU() 42 | else: 43 | processor = FluxAttnProcessor2_0() 44 | self.attn = Attention( 45 | query_dim=dim, 46 | cross_attention_dim=None, 47 | dim_head=attention_head_dim, 48 | heads=num_attention_heads, 49 | out_dim=dim, 50 | bias=True, 51 | processor=processor, 52 | qk_norm="rms_norm", 53 | eps=1e-6, 54 | pre_only=True, 55 | ) 56 | 57 | def forward( 58 | self, 59 | hidden_states: torch.Tensor, 60 | cond_hidden_states: torch.Tensor, 61 | temb: torch.Tensor, 62 | cond_temb: torch.Tensor, 63 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 64 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 65 | ) -> torch.Tensor: 66 | use_cond = cond_hidden_states is not None 67 | 68 | residual = hidden_states 69 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 70 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 71 | 72 | if use_cond: 73 | residual_cond = cond_hidden_states 74 | norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb) 75 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states)) 76 | 77 | norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 78 | 79 | joint_attention_kwargs = joint_attention_kwargs or {} 80 | attn_output = self.attn( 81 | hidden_states=norm_hidden_states_concat, 82 | image_rotary_emb=image_rotary_emb, 83 | use_cond=use_cond, 84 | **joint_attention_kwargs, 85 | ) 86 | if use_cond: 87 | attn_output, cond_attn_output = attn_output 88 | 89 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 90 | gate = gate.unsqueeze(1) 91 | hidden_states = gate * self.proj_out(hidden_states) 92 | hidden_states = residual + hidden_states 93 | 94 | if use_cond: 95 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 96 | cond_gate = cond_gate.unsqueeze(1) 97 | condition_latents = cond_gate * self.proj_out(condition_latents) 98 | condition_latents = residual_cond + condition_latents 99 | 100 | if hidden_states.dtype == torch.float16: 101 | hidden_states = hidden_states.clip(-65504, 65504) 102 | 103 | return hidden_states, condition_latents if use_cond else None 104 | 105 | 106 | @maybe_allow_in_graph 107 | class FluxTransformerBlock(nn.Module): 108 | def __init__( 109 | self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 110 | ): 111 | super().__init__() 112 | 113 | self.norm1 = AdaLayerNormZero(dim) 114 | 115 | self.norm1_context = AdaLayerNormZero(dim) 116 | 117 | if hasattr(F, "scaled_dot_product_attention"): 118 | processor = 
FluxAttnProcessor2_0() 119 | else: 120 | raise ValueError( 121 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 122 | ) 123 | self.attn = Attention( 124 | query_dim=dim, 125 | cross_attention_dim=None, 126 | added_kv_proj_dim=dim, 127 | dim_head=attention_head_dim, 128 | heads=num_attention_heads, 129 | out_dim=dim, 130 | context_pre_only=False, 131 | bias=True, 132 | processor=processor, 133 | qk_norm=qk_norm, 134 | eps=eps, 135 | ) 136 | 137 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 138 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 139 | 140 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 141 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 142 | 143 | # let chunk size default to None 144 | self._chunk_size = None 145 | self._chunk_dim = 0 146 | 147 | def forward( 148 | self, 149 | hidden_states: torch.Tensor, 150 | cond_hidden_states: torch.Tensor, 151 | encoder_hidden_states: torch.Tensor, 152 | temb: torch.Tensor, 153 | cond_temb: torch.Tensor, 154 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 155 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 156 | ) -> Tuple[torch.Tensor, torch.Tensor]: 157 | use_cond = cond_hidden_states is not None 158 | 159 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 160 | if use_cond: 161 | ( 162 | norm_cond_hidden_states, 163 | cond_gate_msa, 164 | cond_shift_mlp, 165 | cond_scale_mlp, 166 | cond_gate_mlp, 167 | ) = self.norm1(cond_hidden_states, emb=cond_temb) 168 | 169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 170 | encoder_hidden_states, emb=temb 171 | ) 172 | 173 | norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 174 | 175 | joint_attention_kwargs = joint_attention_kwargs or {} 176 | # Attention. 177 | attention_outputs = self.attn( 178 | hidden_states=norm_hidden_states, 179 | encoder_hidden_states=norm_encoder_hidden_states, 180 | image_rotary_emb=image_rotary_emb, 181 | use_cond=use_cond, 182 | **joint_attention_kwargs, 183 | ) 184 | 185 | attn_output, context_attn_output = attention_outputs[:2] 186 | cond_attn_output = attention_outputs[2] if use_cond else None 187 | 188 | # Process attention outputs for the `hidden_states`. 189 | attn_output = gate_msa.unsqueeze(1) * attn_output 190 | hidden_states = hidden_states + attn_output 191 | 192 | if use_cond: 193 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 194 | cond_hidden_states = cond_hidden_states + cond_attn_output 195 | 196 | norm_hidden_states = self.norm2(hidden_states) 197 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 198 | 199 | if use_cond: 200 | norm_cond_hidden_states = self.norm2(cond_hidden_states) 201 | norm_cond_hidden_states = ( 202 | norm_cond_hidden_states * (1 + cond_scale_mlp[:, None]) 203 | + cond_shift_mlp[:, None] 204 | ) 205 | 206 | ff_output = self.ff(norm_hidden_states) 207 | ff_output = gate_mlp.unsqueeze(1) * ff_output 208 | hidden_states = hidden_states + ff_output 209 | 210 | if use_cond: 211 | cond_ff_output = self.ff(norm_cond_hidden_states) 212 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 213 | cond_hidden_states = cond_hidden_states + cond_ff_output 214 | 215 | # Process attention outputs for the `encoder_hidden_states`. 
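        # The context (text) stream below mirrors the image stream above: gated attention residual,
        # modulated LayerNorm, then a gated feed-forward update.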
216 | 217 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 218 | encoder_hidden_states = encoder_hidden_states + context_attn_output 219 | 220 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 221 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 222 | 223 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 224 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 225 | if encoder_hidden_states.dtype == torch.float16: 226 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 227 | 228 | return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None 229 | 230 | 231 | class FluxTransformer2DModel( 232 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin 233 | ): 234 | _supports_gradient_checkpointing = True 235 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 236 | 237 | @register_to_config 238 | def __init__( 239 | self, 240 | patch_size: int = 1, 241 | in_channels: int = 64, 242 | out_channels: Optional[int] = None, 243 | num_layers: int = 19, 244 | num_single_layers: int = 38, 245 | attention_head_dim: int = 128, 246 | num_attention_heads: int = 24, 247 | joint_attention_dim: int = 4096, 248 | pooled_projection_dim: int = 768, 249 | guidance_embeds: bool = False, 250 | axes_dims_rope: Tuple[int] = (16, 56, 56), 251 | ): 252 | super().__init__() 253 | self.out_channels = out_channels or in_channels 254 | self.inner_dim = num_attention_heads * attention_head_dim 255 | 256 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 257 | 258 | text_time_guidance_cls = ( 259 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 260 | ) 261 | self.time_text_embed = text_time_guidance_cls( 262 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim 263 | ) 264 | 265 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) 266 | self.x_embedder = nn.Linear(in_channels, self.inner_dim) 267 | 268 | self.transformer_blocks = nn.ModuleList( 269 | [ 270 | FluxTransformerBlock( 271 | dim=self.inner_dim, 272 | num_attention_heads=num_attention_heads, 273 | attention_head_dim=attention_head_dim, 274 | ) 275 | for _ in range(num_layers) 276 | ] 277 | ) 278 | 279 | self.single_transformer_blocks = nn.ModuleList( 280 | [ 281 | FluxSingleTransformerBlock( 282 | dim=self.inner_dim, 283 | num_attention_heads=num_attention_heads, 284 | attention_head_dim=attention_head_dim, 285 | ) 286 | for _ in range(num_single_layers) 287 | ] 288 | ) 289 | 290 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 291 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 292 | 293 | self.gradient_checkpointing = False 294 | 295 | @property 296 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 297 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 298 | r""" 299 | Returns: 300 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 301 | indexed by its weight name. 
302 | """ 303 | # set recursively 304 | processors = {} 305 | 306 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 307 | if hasattr(module, "get_processor"): 308 | processors[f"{name}.processor"] = module.get_processor() 309 | 310 | for sub_name, child in module.named_children(): 311 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 312 | 313 | return processors 314 | 315 | for name, module in self.named_children(): 316 | fn_recursive_add_processors(name, module, processors) 317 | 318 | return processors 319 | 320 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 321 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 322 | r""" 323 | Sets the attention processor to use to compute attention. 324 | 325 | Parameters: 326 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 327 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 328 | for **all** `Attention` layers. 329 | 330 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 331 | processor. This is strongly recommended when setting trainable attention processors. 332 | 333 | """ 334 | count = len(self.attn_processors.keys()) 335 | 336 | if isinstance(processor, dict) and len(processor) != count: 337 | raise ValueError( 338 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 339 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 340 | ) 341 | 342 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 343 | if hasattr(module, "set_processor"): 344 | if not isinstance(processor, dict): 345 | module.set_processor(processor) 346 | else: 347 | module.set_processor(processor.pop(f"{name}.processor")) 348 | 349 | for sub_name, child in module.named_children(): 350 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 351 | 352 | for name, module in self.named_children(): 353 | fn_recursive_attn_processor(name, module, processor) 354 | 355 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 356 | def fuse_qkv_projections(self): 357 | """ 358 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 359 | are fused. For cross-attention modules, key and value projection matrices are fused. 360 | 361 | 362 | 363 | This API is 🧪 experimental. 364 | 365 | 366 | """ 367 | self.original_attn_processors = None 368 | 369 | for _, attn_processor in self.attn_processors.items(): 370 | if "Added" in str(attn_processor.__class__.__name__): 371 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 372 | 373 | self.original_attn_processors = self.attn_processors 374 | 375 | for module in self.modules(): 376 | if isinstance(module, Attention): 377 | module.fuse_projections(fuse=True) 378 | 379 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 380 | 381 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 382 | def unfuse_qkv_projections(self): 383 | """Disables the fused QKV projection if enabled. 384 | 385 | 386 | 387 | This API is 🧪 experimental. 
388 | 389 | 390 | 391 | """ 392 | if self.original_attn_processors is not None: 393 | self.set_attn_processor(self.original_attn_processors) 394 | 395 | def _set_gradient_checkpointing(self, module, value=False): 396 | if hasattr(module, "gradient_checkpointing"): 397 | module.gradient_checkpointing = value 398 | 399 | def forward( 400 | self, 401 | hidden_states: torch.Tensor, 402 | cond_hidden_states: torch.Tensor = None, 403 | encoder_hidden_states: torch.Tensor = None, 404 | pooled_projections: torch.Tensor = None, 405 | timestep: torch.LongTensor = None, 406 | img_ids: torch.Tensor = None, 407 | txt_ids: torch.Tensor = None, 408 | guidance: torch.Tensor = None, 409 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 410 | controlnet_block_samples=None, 411 | controlnet_single_block_samples=None, 412 | return_dict: bool = True, 413 | controlnet_blocks_repeat: bool = False, 414 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 415 | if cond_hidden_states is not None: 416 | use_condition = True 417 | else: 418 | use_condition = False 419 | 420 | if joint_attention_kwargs is not None: 421 | joint_attention_kwargs = joint_attention_kwargs.copy() 422 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 423 | else: 424 | lora_scale = 1.0 425 | 426 | if USE_PEFT_BACKEND: 427 | # weight the lora layers by setting `lora_scale` for each PEFT layer 428 | scale_lora_layers(self, lora_scale) 429 | else: 430 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 431 | logger.warning( 432 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 433 | ) 434 | 435 | hidden_states = self.x_embedder(hidden_states) 436 | cond_hidden_states = self.x_embedder(cond_hidden_states) 437 | 438 | timestep = timestep.to(hidden_states.dtype) * 1000 439 | if guidance is not None: 440 | guidance = guidance.to(hidden_states.dtype) * 1000 441 | else: 442 | guidance = None 443 | 444 | temb = ( 445 | self.time_text_embed(timestep, pooled_projections) 446 | if guidance is None 447 | else self.time_text_embed(timestep, guidance, pooled_projections) 448 | ) 449 | 450 | cond_temb = ( 451 | self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections) 452 | if guidance is None 453 | else self.time_text_embed( 454 | torch.ones_like(timestep) * 0, guidance, pooled_projections 455 | ) 456 | ) 457 | 458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 459 | 460 | if txt_ids.ndim == 3: 461 | logger.warning( 462 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 463 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 464 | ) 465 | txt_ids = txt_ids[0] 466 | if img_ids.ndim == 3: 467 | logger.warning( 468 | "Passing `img_ids` 3d torch.Tensor is deprecated." 
469 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 470 | ) 471 | img_ids = img_ids[0] 472 | 473 | ids = torch.cat((txt_ids, img_ids), dim=0) 474 | image_rotary_emb = self.pos_embed(ids) 475 | 476 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: 477 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") 478 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) 479 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) 480 | 481 | for index_block, block in enumerate(self.transformer_blocks): 482 | if torch.is_grad_enabled() and self.gradient_checkpointing: 483 | 484 | def create_custom_forward(module, return_dict=None): 485 | def custom_forward(*inputs): 486 | if return_dict is not None: 487 | return module(*inputs, return_dict=return_dict) 488 | else: 489 | return module(*inputs) 490 | 491 | return custom_forward 492 | 493 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 494 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 495 | create_custom_forward(block), 496 | hidden_states, 497 | encoder_hidden_states, 498 | temb, 499 | image_rotary_emb, 500 | cond_temb=cond_temb if use_condition else None, 501 | cond_hidden_states=cond_hidden_states if use_condition else None, 502 | **ckpt_kwargs, 503 | ) 504 | 505 | else: 506 | encoder_hidden_states, hidden_states, cond_hidden_states = block( 507 | hidden_states=hidden_states, 508 | encoder_hidden_states=encoder_hidden_states, 509 | cond_hidden_states=cond_hidden_states if use_condition else None, 510 | temb=temb, 511 | cond_temb=cond_temb if use_condition else None, 512 | image_rotary_emb=image_rotary_emb, 513 | joint_attention_kwargs=joint_attention_kwargs, 514 | ) 515 | 516 | # controlnet residual 517 | if controlnet_block_samples is not None: 518 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 519 | interval_control = int(np.ceil(interval_control)) 520 | # For Xlabs ControlNet. 
521 | if controlnet_blocks_repeat: 522 | hidden_states = ( 523 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 524 | ) 525 | else: 526 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 527 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 528 | 529 | for index_block, block in enumerate(self.single_transformer_blocks): 530 | if torch.is_grad_enabled() and self.gradient_checkpointing: 531 | 532 | def create_custom_forward(module, return_dict=None): 533 | def custom_forward(*inputs): 534 | if return_dict is not None: 535 | return module(*inputs, return_dict=return_dict) 536 | else: 537 | return module(*inputs) 538 | 539 | return custom_forward 540 | 541 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 542 | hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint( 543 | create_custom_forward(block), 544 | hidden_states, 545 | temb, 546 | image_rotary_emb, 547 | cond_temb=cond_temb if use_condition else None, 548 | cond_hidden_states=cond_hidden_states if use_condition else None, 549 | **ckpt_kwargs, 550 | ) 551 | 552 | else: 553 | hidden_states, cond_hidden_states = block( 554 | hidden_states=hidden_states, 555 | cond_hidden_states=cond_hidden_states if use_condition else None, 556 | temb=temb, 557 | cond_temb=cond_temb if use_condition else None, 558 | image_rotary_emb=image_rotary_emb, 559 | joint_attention_kwargs=joint_attention_kwargs, 560 | ) 561 | 562 | # controlnet residual 563 | if controlnet_single_block_samples is not None: 564 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 565 | interval_control = int(np.ceil(interval_control)) 566 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 567 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 568 | + controlnet_single_block_samples[index_block // interval_control] 569 | ) 570 | 571 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
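        # Only the image tokens are kept from here on; the text tokens concatenated in before the
        # single-stream blocks are dropped ahead of the final norm and output projection.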
572 | 573 | hidden_states = self.norm_out(hidden_states, temb) 574 | output = self.proj_out(hidden_states) 575 | 576 | if USE_PEFT_BACKEND: 577 | # remove `lora_scale` from each PEFT layer 578 | unscale_lora_layers(self, lora_scale) 579 | 580 | if not return_dict: 581 | return (output,) 582 | 583 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /test_imgs/canny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/canny.png -------------------------------------------------------------------------------- /test_imgs/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/depth.png -------------------------------------------------------------------------------- /test_imgs/ghibli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/ghibli.png -------------------------------------------------------------------------------- /test_imgs/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/inpainting.png -------------------------------------------------------------------------------- /test_imgs/openpose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/openpose.png -------------------------------------------------------------------------------- /test_imgs/seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/seg.png -------------------------------------------------------------------------------- /test_imgs/subject_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/subject_0.png -------------------------------------------------------------------------------- /test_imgs/subject_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/subject_1.png -------------------------------------------------------------------------------- /train/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | main_process_port: 14121 5 | downcast_bf16: 'no' 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: fp16 10 | num_machines: 1 11 | num_processes: 4 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /train/examples/openpose_data/1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/openpose_data/1.png -------------------------------------------------------------------------------- /train/examples/openpose_data/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/openpose_data/2.png -------------------------------------------------------------------------------- /train/examples/pose.jsonl: -------------------------------------------------------------------------------- 1 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"} 2 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"} 3 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"} -------------------------------------------------------------------------------- /train/examples/style.jsonl: -------------------------------------------------------------------------------- 1 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"} 2 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"} 3 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"} -------------------------------------------------------------------------------- /train/examples/style_data/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/style_data/5.png -------------------------------------------------------------------------------- /train/examples/style_data/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/style_data/6.png -------------------------------------------------------------------------------- /train/examples/subject.jsonl: -------------------------------------------------------------------------------- 1 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"} 2 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"} 3 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"} -------------------------------------------------------------------------------- /train/examples/subject_data/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/subject_data/3.png 
-------------------------------------------------------------------------------- /train/examples/subject_data/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/subject_data/4.png -------------------------------------------------------------------------------- /train/readme.md: -------------------------------------------------------------------------------- 1 | # Model Training Guide 2 | 3 | This document provides a step-by-step guide for training the model in this project. 4 | 5 | ## Environment Setup 6 | 7 | 1. Ensure the following dependencies are installed: 8 | - Python 3.10.16 9 | - PyTorch 2.5.1+cu121 10 | - Required libraries (install via `requirements.txt`) 11 | 12 | ```bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Data Preparation 17 | 18 | - Ensure the data format matches the requirements of the training dataset (e.g., `examples/pose.jsonl`, `examples/subject.jsonl`, `examples/style.jsonl`). 19 | 20 | ## Start Training 21 | 22 | 1. Use the following commands to start training: 23 | 24 | - For spatial control: 25 | ```bash 26 | bash ./train_spatial.sh 27 | ``` 28 | - For subject control: 29 | ```bash 30 | bash ./train_subject.sh 31 | ``` 32 | - For style control: 33 | ```bash 34 | bash ./train_style.sh 35 | ``` 36 | 37 | 2. Example training configuration: 38 | 39 | ```bash 40 | --pretrained_model_name_or_path $MODEL_DIR \ # Path to the FLUX model 41 | --cond_size=512 \ # Source image size (recommended: 384-512 or higher for better detail control) 42 | --noise_size=1024 \ # Target image's longest side size (recommended: 1024 for better resolution) 43 | --subject_column="None" \ # JSONL key for subject; set to "None" if using spatial condition 44 | --spatial_column="source" \ # JSONL key for spatial; set to "None" if using subject condition 45 | --target_column="target" \ # JSONL key for the target image 46 | --caption_column="caption" \ # JSONL key for the caption 47 | --ranks 128 \ # LoRA rank (recommended: 128) 48 | --network_alphas 128 \ # LoRA network alphas (recommended: 128) 49 | --output_dir=$OUTPUT_DIR \ # Directory for model and validation outputs 50 | --logging_dir=$LOG_PATH \ # Directory for logs 51 | --mixed_precision="bf16" \ # Recommended: bf16 52 | --train_data_dir=$TRAIN_DATA \ # Path to the training data JSONL file 53 | --learning_rate=1e-4 \ # Recommended: 1e-4 54 | --train_batch_size=1 \ # Only supports 1 due to multi-resolution target images 55 | --validation_prompt "Ghibli Studio style, Charming hand-drawn anime-style illustration" \ # Validation prompt 56 | --num_train_epochs=1000 \ # Total number of epochs 57 | --validation_steps=20 \ # Validate every n steps 58 | --checkpointing_steps=20 \ # Save model every n steps 59 | --spatial_test_images "./examples/style_data/5.png" \ # Validation images for spatial condition 60 | --subject_test_images None \ # Validation images for subject condition 61 | --test_h 1024 \ # Height of validation images 62 | --test_w 1024 \ # Width of validation images 63 | --num_validation_images=2 # Number of validation images 64 | ``` 65 | 66 | ## Model Inference 67 | 68 | 1. After training, use the following script for inference: 69 | 70 | ```bash 71 | # Navigate to the original repository to use KV cache 72 | cd .. 
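# Run the inference code below from the repository root (one level above train/) so that
# the `src.*` imports resolve.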
73 | ```
74 | 
75 | ```python
76 | import torch
77 | from PIL import Image
78 | from src.pipeline import FluxPipeline
79 | from src.transformer_flux import FluxTransformer2DModel
80 | from src.lora_helper import set_single_lora, set_multi_lora
81 | 
82 | def clear_cache(transformer):
83 |     for name, attn_processor in transformer.attn_processors.items():
84 |         attn_processor.bank_kv.clear()
85 | 
86 | # Initialize the model
87 | device = "cuda"
88 | base_path = "black-forest-labs/FLUX.1-dev" # Path to the base model
89 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
90 | transformer = FluxTransformer2DModel.from_pretrained(
91 |     base_path,
92 |     subfolder="transformer",
93 |     torch_dtype=torch.bfloat16,
94 |     device=device
95 | )
96 | pipe.transformer = transformer
97 | pipe.to(device)
98 | 
99 | # Path to your trained EasyControl model
100 | lora_path = " "
101 | 
102 | # Single condition control example
103 | set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
104 | 
105 | # Set your control image path
106 | spatial_image_path = ""
107 | subject_image_path = ""
108 | style_image_path = ""
109 | 
110 | control_image = Image.open("fill in spatial_image_path or subject_image_path !!")
111 | prompt = "fill in your prompt!!"
112 | 
113 | # For spatial or style control
114 | image = pipe(
115 |     prompt,
116 |     height=768, # Generated image height
117 |     width=1024, # Generated image width
118 |     guidance_scale=3.5,
119 |     num_inference_steps=25, # Number of steps
120 |     max_sequence_length=512,
121 |     generator=torch.Generator("cpu").manual_seed(5),
122 |     spatial_images=[control_image],
123 |     subject_images=[],
124 |     cond_size=512, # Training setting for cond_size
125 | ).images[0]
126 | # Clear cache after generation
127 | clear_cache(pipe.transformer)
128 | image.save("output.png")
129 | ```
130 | 
131 | 2. For subject control:
132 | 
133 | ```python
134 | image = pipe(
135 |     prompt,
136 |     height=768,
137 |     width=1024,
138 |     guidance_scale=3.5,
139 |     num_inference_steps=25,
140 |     max_sequence_length=512,
141 |     generator=torch.Generator("cpu").manual_seed(5),
142 |     spatial_images=[],
143 |     subject_images=[control_image],
144 |     cond_size=512,
145 | ).images[0]
146 | # Clear cache after generation
147 | clear_cache(pipe.transformer)
148 | image.save("output.png")
149 | ```
150 | 
151 | 3. For multi-condition control:
152 | 
153 | ```python
154 | import torch
155 | from PIL import Image
156 | from src.pipeline import FluxPipeline
157 | from src.transformer_flux import FluxTransformer2DModel
158 | from src.lora_helper import set_single_lora, set_multi_lora
159 | 
160 | def clear_cache(transformer):
161 |     for name, attn_processor in transformer.attn_processors.items():
162 |         attn_processor.bank_kv.clear()
163 | 
164 | # Initialize the model
165 | device = "cuda"
166 | base_path = "black-forest-labs/FLUX.1-dev" # Path to the base model
167 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
168 | transformer = FluxTransformer2DModel.from_pretrained(
169 |     base_path,
170 |     subfolder="transformer",
171 |     torch_dtype=torch.bfloat16,
172 |     device=device
173 | )
174 | pipe.transformer = transformer
175 | pipe.to(device)
176 | 
177 | # Change to your EasyControl Model path!!!
178 | lora_path = "./models" 179 | control_models = { 180 | "canny": f"{lora_path}/canny.safetensors", 181 | "depth": f"{lora_path}/depth.safetensors", 182 | "hedsketch": f"{lora_path}/hedsketch.safetensors", 183 | "pose": f"{lora_path}/pose.safetensors", 184 | "seg": f"{lora_path}/seg.safetensors", 185 | "inpainting": f"{lora_path}/inpainting.safetensors", 186 | "subject": f"{lora_path}/subject.safetensors", 187 | } 188 | paths = [control_models["subject"], control_models["inpainting"]] 189 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512) 190 | 191 | # Subject + spatial control 192 | prompt = "A SKS on the car" 193 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")] 194 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")] 195 | image = pipe( 196 | prompt, 197 | height=1024, 198 | width=1024, 199 | guidance_scale=3.5, 200 | num_inference_steps=25, 201 | max_sequence_length=512, 202 | generator=torch.Generator("cpu").manual_seed(42), 203 | subject_images=subject_images, 204 | spatial_images=spatial_images, 205 | cond_size=512, 206 | ).images[0] 207 | # Clear cache after generation 208 | clear_cache(pipe.transformer) 209 | image.save("output_multi.png") 210 | ``` 211 | 212 | 4. For spatial + spatial control: 213 | 214 | ```python 215 | prompt = "A car" 216 | subject_images = [] 217 | spatial_images = [Image.open("image1_path").convert("RGB"), Image.open("image2_path").convert("RGB")] 218 | image = pipe( 219 | prompt, 220 | height=1024, 221 | width=1024, 222 | guidance_scale=3.5, 223 | num_inference_steps=25, 224 | max_sequence_length=512, 225 | generator=torch.Generator("cpu").manual_seed(42), 226 | subject_images=subject_images, 227 | spatial_images=spatial_images, 228 | cond_size=512, 229 | ).images[0] 230 | # Clear cache after generation 231 | clear_cache(pipe.transformer) 232 | image.save("output_multi.png") 233 | ``` 234 | 235 | ## Notes 236 | 237 | - Adjust `noise_size` and `cond_size` based on your VRAM. 238 | - Batch size is limited to 1. 
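
## Example Training Record (JSONL)

For reference, the dataset loader (`train/src/jsonl_datasets.py`) reads the training JSONL file one record per line, and each record must expose the keys named by `--spatial_column`/`--subject_column`, `--target_column`, and `--caption_column`. Below is a minimal sketch of how such a file could be generated; the image paths, captions, and output filename are illustrative placeholders, not files shipped with this repository.

```python
import json

# Hypothetical spatial-control records: the keys mirror the column arguments used in
# train_spatial.sh ("source", "target", "caption"). Replace the paths with your own data.
records = [
    {
        "source": "./my_data/pose_0000.png",   # condition image (e.g. an OpenPose map)
        "target": "./my_data/image_0000.png",  # ground-truth image the model should produce
        "caption": "A girl in the city.",      # text prompt; a list of captions is also accepted
    },
]

with open("./examples/my_pose.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```

Several condition columns can be supplied as a comma-separated list; the loader splits the column argument on `,` and loads one image per column. As a rough VRAM guide, each 512×512 condition adds about (512/16)² = 1024 tokens to every attention call, which is why lowering `cond_size` (and `noise_size`) is the main way to reduce memory use.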
239 | -------------------------------------------------------------------------------- /train/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/src/__init__.py -------------------------------------------------------------------------------- /train/src/jsonl_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from datasets import load_dataset 3 | from torchvision import transforms 4 | import random 5 | import torch 6 | 7 | Image.MAX_IMAGE_PIXELS = None 8 | 9 | def multiple_16(num: float): 10 | return int(round(num / 16) * 16) 11 | 12 | def get_random_resolution(min_size=512, max_size=1280, multiple=16): 13 | resolution = random.randint(min_size // multiple, max_size // multiple) * multiple 14 | return resolution 15 | 16 | def load_image_safely(image_path, size): 17 | try: 18 | image = Image.open(image_path).convert("RGB") 19 | return image 20 | except Exception as e: 21 | print("file error: "+image_path) 22 | with open("failed_images.txt", "a") as f: 23 | f.write(f"{image_path}\n") 24 | return Image.new("RGB", (size, size), (255, 255, 255)) 25 | 26 | def make_train_dataset(args, tokenizer, accelerator=None): 27 | if args.train_data_dir is not None: 28 | print("load_data") 29 | dataset = load_dataset('json', data_files=args.train_data_dir) 30 | 31 | column_names = dataset["train"].column_names 32 | 33 | # 6. Get the column names for input/target. 34 | caption_column = args.caption_column 35 | target_column = args.target_column 36 | if args.subject_column is not None: 37 | subject_columns = args.subject_column.split(",") 38 | if args.spatial_column is not None: 39 | spatial_columns= args.spatial_column.split(",") 40 | 41 | size = args.cond_size 42 | noise_size = get_random_resolution(max_size=args.noise_size) # maybe 768 or higher 43 | subject_cond_train_transforms = transforms.Compose( 44 | [ 45 | transforms.Lambda(lambda img: img.resize(( 46 | multiple_16(size * img.size[0] / max(img.size)), 47 | multiple_16(size * img.size[1] / max(img.size)) 48 | ), resample=Image.BILINEAR)), 49 | transforms.RandomHorizontalFlip(p=0.7), 50 | transforms.RandomRotation(degrees=20), 51 | transforms.Lambda(lambda img: transforms.Pad( 52 | padding=( 53 | int((size - img.size[0]) / 2), 54 | int((size - img.size[1]) / 2), 55 | int((size - img.size[0]) / 2), 56 | int((size - img.size[1]) / 2) 57 | ), 58 | fill=0 59 | )(img)), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.5], [0.5]), 62 | ] 63 | ) 64 | cond_train_transforms = transforms.Compose( 65 | [ 66 | transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR), 67 | transforms.CenterCrop((size, size)), 68 | transforms.ToTensor(), 69 | transforms.Normalize([0.5], [0.5]), 70 | ] 71 | ) 72 | 73 | def train_transforms(image, noise_size): 74 | train_transforms_ = transforms.Compose( 75 | [ 76 | transforms.Lambda(lambda img: img.resize(( 77 | multiple_16(noise_size * img.size[0] / max(img.size)), 78 | multiple_16(noise_size * img.size[1] / max(img.size)) 79 | ), resample=Image.BILINEAR)), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.5], [0.5]), 82 | ] 83 | ) 84 | transformed_image = train_transforms_(image) 85 | return transformed_image 86 | 87 | def load_and_transform_cond_images(images): 88 | transformed_images = [cond_train_transforms(image) for image in images] 89 | concatenated_image = 
torch.cat(transformed_images, dim=1)
90 |         return concatenated_image
91 | 
92 |     def load_and_transform_subject_images(images):
93 |         transformed_images = [subject_cond_train_transforms(image) for image in images]
94 |         concatenated_image = torch.cat(transformed_images, dim=1)
95 |         return concatenated_image
96 | 
97 |     tokenizer_clip = tokenizer[0]
98 |     tokenizer_t5 = tokenizer[1]
99 | 
100 |     def tokenize_prompt_clip_t5(examples):
101 |         captions = []
102 |         for caption in examples[caption_column]:
103 |             if isinstance(caption, str):
104 |                 if random.random() < 0.1:
105 |                     captions.append(" ")  # set the caption to an empty string (caption dropout)
106 |                 else:
107 |                     captions.append(caption)
108 |             elif isinstance(caption, list):
109 |                 # take a random caption if there are multiple
110 |                 if random.random() < 0.1:
111 |                     captions.append(" ")
112 |                 else:
113 |                     captions.append(random.choice(caption))
114 |             else:
115 |                 raise ValueError(
116 |                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
117 |                 )
118 |         text_inputs = tokenizer_clip(
119 |             captions,
120 |             padding="max_length",
121 |             max_length=77,
122 |             truncation=True,
123 |             return_length=False,
124 |             return_overflowing_tokens=False,
125 |             return_tensors="pt",
126 |         )
127 |         text_input_ids_1 = text_inputs.input_ids
128 | 
129 |         text_inputs = tokenizer_t5(
130 |             captions,
131 |             padding="max_length",
132 |             max_length=512,
133 |             truncation=True,
134 |             return_length=False,
135 |             return_overflowing_tokens=False,
136 |             return_tensors="pt",
137 |         )
138 |         text_input_ids_2 = text_inputs.input_ids
139 |         return text_input_ids_1, text_input_ids_2
140 | 
141 |     def preprocess_train(examples):
142 |         _examples = {}
143 |         if args.subject_column is not None:
144 |             subject_images = [[load_image_safely(examples[column][i], args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
145 |             _examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
146 |         if args.spatial_column is not None:
147 |             spatial_images = [[load_image_safely(examples[column][i], args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
148 |             _examples["cond_pixel_values"] = [load_and_transform_cond_images(spatial) for spatial in spatial_images]
149 |         target_images = [load_image_safely(image_path, args.cond_size) for image_path in examples[target_column]]
150 |         _examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
151 |         _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
152 |         return _examples
153 | 
154 |     if accelerator is not None:
155 |         with accelerator.main_process_first():
156 |             train_dataset = dataset["train"].with_transform(preprocess_train)
157 |     else:
158 |         train_dataset = dataset["train"].with_transform(preprocess_train)
159 | 
160 |     return train_dataset
161 | 
162 | 
163 | def collate_fn(examples):
164 |     if examples[0].get("cond_pixel_values") is not None:
165 |         cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
166 |         cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
167 |     else:
168 |         cond_pixel_values = None
169 |     if examples[0].get("subject_pixel_values") is not None:
170 |         subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
171 |         subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
172 |     else:
173 |         subject_pixel_values = None
174 | 
175 |     target_pixel_values = 
torch.stack([example["pixel_values"] for example in examples]) 176 | target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float() 177 | token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples]) 178 | token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples]) 179 | 180 | return { 181 | "cond_pixel_values": cond_pixel_values, 182 | "subject_pixel_values": subject_pixel_values, 183 | "pixel_values": target_pixel_values, 184 | "text_ids_1": token_ids_clip, 185 | "text_ids_2": token_ids_t5, 186 | } -------------------------------------------------------------------------------- /train/src/layers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from typing import Callable, List, Optional, Tuple, Union 4 | from einops import rearrange 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch import Tensor 9 | from diffusers.models.attention_processor import Attention 10 | 11 | class LoRALinearLayer(nn.Module): 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | rank: int = 4, 17 | network_alpha: Optional[float] = None, 18 | device: Optional[Union[torch.device, str]] = None, 19 | dtype: Optional[torch.dtype] = None, 20 | cond_width=512, 21 | cond_height=512, 22 | number=0, 23 | n_loras=1 24 | ): 25 | super().__init__() 26 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 27 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 28 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 29 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 30 | self.network_alpha = network_alpha 31 | self.rank = rank 32 | self.out_features = out_features 33 | self.in_features = in_features 34 | 35 | nn.init.normal_(self.down.weight, std=1 / rank) 36 | nn.init.zeros_(self.up.weight) 37 | 38 | self.cond_height = cond_height 39 | self.cond_width = cond_width 40 | self.number = number 41 | self.n_loras = n_loras 42 | 43 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 44 | orig_dtype = hidden_states.dtype 45 | dtype = self.down.weight.dtype 46 | 47 | #### img condition 48 | batch_size = hidden_states.shape[0] 49 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 50 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 51 | shape = (batch_size, hidden_states.shape[1], 3072) 52 | mask = torch.ones(shape, device=hidden_states.device, dtype=dtype) 53 | mask[:, :block_size+self.number*cond_size, :] = 0 54 | mask[:, block_size+(self.number+1)*cond_size:, :] = 0 55 | hidden_states = mask * hidden_states 56 | #### 57 | 58 | down_hidden_states = self.down(hidden_states.to(dtype)) 59 | up_hidden_states = self.up(down_hidden_states) 60 | 61 | if self.network_alpha is not None: 62 | up_hidden_states *= self.network_alpha / self.rank 63 | 64 | return up_hidden_states.to(orig_dtype) 65 | 66 | 67 | class MultiSingleStreamBlockLoraProcessor(nn.Module): 68 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 69 | super().__init__() 70 | # Initialize a list to store the LoRA layers 71 | self.n_loras = n_loras 72 | self.cond_width = cond_width 73 | self.cond_height = cond_height 74 | 75 | self.q_loras 
= nn.ModuleList([ 76 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 77 | for i in range(n_loras) 78 | ]) 79 | self.k_loras = nn.ModuleList([ 80 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 81 | for i in range(n_loras) 82 | ]) 83 | self.v_loras = nn.ModuleList([ 84 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 85 | for i in range(n_loras) 86 | ]) 87 | self.lora_weights = lora_weights 88 | 89 | 90 | def __call__(self, 91 | attn: Attention, 92 | hidden_states: torch.FloatTensor, 93 | encoder_hidden_states: torch.FloatTensor = None, 94 | attention_mask: Optional[torch.FloatTensor] = None, 95 | image_rotary_emb: Optional[torch.Tensor] = None, 96 | use_cond = False, 97 | ) -> torch.FloatTensor: 98 | 99 | batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 100 | query = attn.to_q(hidden_states) 101 | key = attn.to_k(hidden_states) 102 | value = attn.to_v(hidden_states) 103 | 104 | for i in range(self.n_loras): 105 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 106 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 107 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 108 | 109 | inner_dim = key.shape[-1] 110 | head_dim = inner_dim // attn.heads 111 | 112 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 113 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 114 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 115 | 116 | if attn.norm_q is not None: 117 | query = attn.norm_q(query) 118 | if attn.norm_k is not None: 119 | key = attn.norm_k(key) 120 | 121 | if image_rotary_emb is not None: 122 | from diffusers.models.embeddings import apply_rotary_emb 123 | query = apply_rotary_emb(query, image_rotary_emb) 124 | key = apply_rotary_emb(key, image_rotary_emb) 125 | 126 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 127 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 128 | scaled_cond_size = cond_size 129 | scaled_block_size = block_size 130 | scaled_seq_len = query.shape[2] 131 | 132 | num_cond_blocks = self.n_loras 133 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 134 | mask[ :scaled_block_size, :] = 0 # First block_size row 135 | for i in range(num_cond_blocks): 136 | start = i * scaled_cond_size + scaled_block_size 137 | end = (i + 1) * scaled_cond_size + scaled_block_size 138 | mask[start:end, start:end] = 0 # Diagonal blocks 139 | mask = mask * -1e20 140 | mask = mask.to(query.dtype) 141 | 142 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 143 | 144 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 145 | hidden_states = hidden_states.to(query.dtype) 146 | 147 | cond_hidden_states = hidden_states[:, block_size:,:] 148 | hidden_states = hidden_states[:, : block_size,:] 149 | 150 | return hidden_states if not use_cond else (hidden_states, cond_hidden_states) 151 | 152 | 153 | class MultiDoubleStreamBlockLoraProcessor(nn.Module): 154 | def __init__(self, dim: int, ranks=[], lora_weights=[], 
network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 155 | super().__init__() 156 | 157 | # Initialize a list to store the LoRA layers 158 | self.n_loras = n_loras 159 | self.cond_width = cond_width 160 | self.cond_height = cond_height 161 | self.q_loras = nn.ModuleList([ 162 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 163 | for i in range(n_loras) 164 | ]) 165 | self.k_loras = nn.ModuleList([ 166 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 167 | for i in range(n_loras) 168 | ]) 169 | self.v_loras = nn.ModuleList([ 170 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 171 | for i in range(n_loras) 172 | ]) 173 | self.proj_loras = nn.ModuleList([ 174 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 175 | for i in range(n_loras) 176 | ]) 177 | self.lora_weights = lora_weights 178 | 179 | 180 | def __call__(self, 181 | attn: Attention, 182 | hidden_states: torch.FloatTensor, 183 | encoder_hidden_states: torch.FloatTensor = None, 184 | attention_mask: Optional[torch.FloatTensor] = None, 185 | image_rotary_emb: Optional[torch.Tensor] = None, 186 | use_cond=False, 187 | ) -> torch.FloatTensor: 188 | 189 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 190 | 191 | # `context` projections. 192 | inner_dim = 3072 193 | head_dim = inner_dim // attn.heads 194 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 195 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 196 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 197 | 198 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 199 | batch_size, -1, attn.heads, head_dim 200 | ).transpose(1, 2) 201 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 202 | batch_size, -1, attn.heads, head_dim 203 | ).transpose(1, 2) 204 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 205 | batch_size, -1, attn.heads, head_dim 206 | ).transpose(1, 2) 207 | 208 | if attn.norm_added_q is not None: 209 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 210 | if attn.norm_added_k is not None: 211 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 212 | 213 | query = attn.to_q(hidden_states) 214 | key = attn.to_k(hidden_states) 215 | value = attn.to_v(hidden_states) 216 | for i in range(self.n_loras): 217 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 218 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 219 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 220 | 221 | inner_dim = key.shape[-1] 222 | head_dim = inner_dim // attn.heads 223 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 224 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 225 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 226 | 227 | if attn.norm_q is not None: 228 | query = attn.norm_q(query) 229 | if attn.norm_k is not 
None: 230 | key = attn.norm_k(key) 231 | 232 | # attention 233 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 234 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 235 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 236 | 237 | if image_rotary_emb is not None: 238 | from diffusers.models.embeddings import apply_rotary_emb 239 | query = apply_rotary_emb(query, image_rotary_emb) 240 | key = apply_rotary_emb(key, image_rotary_emb) 241 | 242 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 243 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 244 | scaled_cond_size = cond_size 245 | scaled_seq_len = query.shape[2] 246 | scaled_block_size = scaled_seq_len - cond_size * self.n_loras 247 | 248 | num_cond_blocks = self.n_loras 249 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 250 | mask[ :scaled_block_size, :] = 0 # First block_size row 251 | for i in range(num_cond_blocks): 252 | start = i * scaled_cond_size + scaled_block_size 253 | end = (i + 1) * scaled_cond_size + scaled_block_size 254 | mask[start:end, start:end] = 0 # Diagonal blocks 255 | mask = mask * -1e20 256 | mask = mask.to(query.dtype) 257 | 258 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 259 | 260 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 261 | hidden_states = hidden_states.to(query.dtype) 262 | 263 | encoder_hidden_states, hidden_states = ( 264 | hidden_states[:, : encoder_hidden_states.shape[1]], 265 | hidden_states[:, encoder_hidden_states.shape[1] :], 266 | ) 267 | 268 | # Linear projection (with LoRA weight applied to each proj layer) 269 | hidden_states = attn.to_out[0](hidden_states) 270 | for i in range(self.n_loras): 271 | hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states) 272 | # dropout 273 | hidden_states = attn.to_out[1](hidden_states) 274 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 275 | 276 | cond_hidden_states = hidden_states[:, block_size:,:] 277 | hidden_states = hidden_states[:, :block_size,:] 278 | 279 | return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states) -------------------------------------------------------------------------------- /train/src/lora_helper.py: -------------------------------------------------------------------------------- 1 | from diffusers.models.attention_processor import FluxAttnProcessor2_0 2 | from safetensors import safe_open 3 | import re 4 | import torch 5 | from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor 6 | 7 | device = "cuda" 8 | 9 | def load_safetensors(path): 10 | tensors = {} 11 | with safe_open(path, framework="pt", device="cpu") as f: 12 | for key in f.keys(): 13 | tensors[key] = f.get_tensor(key) 14 | return tensors 15 | 16 | def get_lora_rank(checkpoint): 17 | for k in checkpoint.keys(): 18 | if k.endswith(".down.weight"): 19 | return checkpoint[k].shape[0] 20 | 21 | def load_checkpoint(local_path): 22 | if local_path is not None: 23 | if '.safetensors' in local_path: 24 | print(f"Loading .safetensors checkpoint from {local_path}") 25 | checkpoint = load_safetensors(local_path) 26 | else: 27 | print(f"Loading checkpoint from {local_path}") 28 | checkpoint = torch.load(local_path, map_location='cpu') 29 | return checkpoint 30 | 31 | def 
update_model_with_lora(checkpoint, lora_weights, transformer, cond_size): 32 | number = len(lora_weights) 33 | ranks = [get_lora_rank(checkpoint) for _ in range(number)] 34 | lora_attn_procs = {} 35 | double_blocks_idx = list(range(19)) 36 | single_blocks_idx = list(range(38)) 37 | for name, attn_processor in transformer.attn_processors.items(): 38 | match = re.search(r'\.(\d+)\.', name) 39 | if match: 40 | layer_index = int(match.group(1)) 41 | 42 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 43 | 44 | lora_state_dicts = {} 45 | for key, value in checkpoint.items(): 46 | # Match based on the layer index in the key (assuming the key contains layer index) 47 | if re.search(r'\.(\d+)\.', key): 48 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 49 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 50 | lora_state_dicts[key] = value 51 | 52 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 53 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 54 | ) 55 | 56 | # Load the weights from the checkpoint dictionary into the corresponding layers 57 | for n in range(number): 58 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 59 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 60 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 61 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 62 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 63 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 64 | lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None) 65 | lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None) 66 | lora_attn_procs[name].to(device) 67 | 68 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 69 | 70 | lora_state_dicts = {} 71 | for key, value in checkpoint.items(): 72 | # Match based on the layer index in the key (assuming the key contains layer index) 73 | if re.search(r'\.(\d+)\.', key): 74 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 75 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"): 76 | lora_state_dicts[key] = value 77 | 78 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor( 79 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 80 | ) 81 | # Load the weights from the checkpoint dictionary into the corresponding layers 82 | for n in range(number): 83 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 84 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 85 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 86 | 
lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 87 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 88 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 89 | lora_attn_procs[name].to(device) 90 | else: 91 | lora_attn_procs[name] = FluxAttnProcessor2_0() 92 | 93 | transformer.set_attn_processor(lora_attn_procs) 94 | 95 | 96 | def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size): 97 | ck_number = len(checkpoints) 98 | cond_lora_number = [len(ls) for ls in lora_weights] 99 | cond_number = sum(cond_lora_number) 100 | ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints] 101 | multi_lora_weight = [] 102 | for ls in lora_weights: 103 | for n in ls: 104 | multi_lora_weight.append(n) 105 | 106 | lora_attn_procs = {} 107 | double_blocks_idx = list(range(19)) 108 | single_blocks_idx = list(range(38)) 109 | for name, attn_processor in transformer.attn_processors.items(): 110 | match = re.search(r'\.(\d+)\.', name) 111 | if match: 112 | layer_index = int(match.group(1)) 113 | 114 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 115 | lora_state_dicts = [{} for _ in range(ck_number)] 116 | for idx, checkpoint in enumerate(checkpoints): 117 | for key, value in checkpoint.items(): 118 | # Match based on the layer index in the key (assuming the key contains layer index) 119 | if re.search(r'\.(\d+)\.', key): 120 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 121 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 122 | lora_state_dicts[idx][key] = value 123 | 124 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 125 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number 126 | ) 127 | 128 | # Load the weights from the checkpoint dictionary into the corresponding layers 129 | num = 0 130 | for idx in range(ck_number): 131 | for n in range(cond_lora_number[idx]): 132 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None) 133 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None) 134 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None) 135 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None) 136 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None) 137 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None) 138 | lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None) 139 | lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None) 140 | lora_attn_procs[name].to(device) 141 | num += 1 142 | 143 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 144 | 145 | lora_state_dicts = [{} for _ in range(ck_number)] 146 | for idx, checkpoint in enumerate(checkpoints): 147 | for key, value in 
checkpoint.items():
148 |                         # Match based on the layer index in the key (assuming the key contains layer index)
149 |                         if re.search(r'\.(\d+)\.', key):
150 |                             checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
151 |                             if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
152 |                                 lora_state_dicts[idx][key] = value
153 | 
154 |                 lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
155 |                     dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
156 |                 )
157 |                 # Load the weights from the checkpoint dictionary into the corresponding layers
158 |                 num = 0
159 |                 for idx in range(ck_number):
160 |                     for n in range(cond_lora_number[idx]):
161 |                         lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
162 |                         lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
163 |                         lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
164 |                         lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
165 |                         lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
166 |                         lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
167 |                         lora_attn_procs[name].to(device)
168 |                         num += 1
169 | 
170 |             else:
171 |                 lora_attn_procs[name] = FluxAttnProcessor2_0()
172 | 
173 |     transformer.set_attn_processor(lora_attn_procs)
174 | 
175 | 
176 | def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
177 |     checkpoint = load_checkpoint(local_path)
178 |     update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
179 | 
180 | def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
181 |     checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
182 |     update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
183 | 
184 | def unset_lora(transformer):
185 |     lora_attn_procs = {}
186 |     for name, attn_processor in transformer.attn_processors.items():
187 |         lora_attn_procs[name] = FluxAttnProcessor2_0()
188 |     transformer.set_attn_processor(lora_attn_procs)
189 | 
190 | 
191 | '''
192 | unset_lora(pipe.transformer)
193 | lora_path = "./lora.safetensors"
194 | lora_weights = [1]
195 | set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
196 | '''
--------------------------------------------------------------------------------
/train/src/prompt_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def load_text_encoders(args, class_one, class_two):
5 |     text_encoder_one = class_one.from_pretrained(
6 |         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
7 |     )
8 |     text_encoder_two = class_two.from_pretrained(
9 |         args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
10 |     )
11 |     return text_encoder_one, text_encoder_two
12 | 
13 | 
14 | def tokenize_prompt(tokenizer, prompt, max_sequence_length):
15 |     text_inputs = tokenizer(
16 |         prompt,
17 |         padding="max_length",
18 |         max_length=max_sequence_length,
19 | 
truncation=True, 20 | return_length=False, 21 | return_overflowing_tokens=False, 22 | return_tensors="pt", 23 | ) 24 | text_input_ids = text_inputs.input_ids 25 | return text_input_ids 26 | 27 | 28 | def tokenize_prompt_clip(tokenizer, prompt): 29 | text_inputs = tokenizer( 30 | prompt, 31 | padding="max_length", 32 | max_length=77, 33 | truncation=True, 34 | return_length=False, 35 | return_overflowing_tokens=False, 36 | return_tensors="pt", 37 | ) 38 | text_input_ids = text_inputs.input_ids 39 | return text_input_ids 40 | 41 | 42 | def tokenize_prompt_t5(tokenizer, prompt): 43 | text_inputs = tokenizer( 44 | prompt, 45 | padding="max_length", 46 | max_length=512, 47 | truncation=True, 48 | return_length=False, 49 | return_overflowing_tokens=False, 50 | return_tensors="pt", 51 | ) 52 | text_input_ids = text_inputs.input_ids 53 | return text_input_ids 54 | 55 | 56 | def _encode_prompt_with_t5( 57 | text_encoder, 58 | tokenizer, 59 | max_sequence_length=512, 60 | prompt=None, 61 | num_images_per_prompt=1, 62 | device=None, 63 | text_input_ids=None, 64 | ): 65 | prompt = [prompt] if isinstance(prompt, str) else prompt 66 | batch_size = len(prompt) 67 | 68 | if tokenizer is not None: 69 | text_inputs = tokenizer( 70 | prompt, 71 | padding="max_length", 72 | max_length=max_sequence_length, 73 | truncation=True, 74 | return_length=False, 75 | return_overflowing_tokens=False, 76 | return_tensors="pt", 77 | ) 78 | text_input_ids = text_inputs.input_ids 79 | else: 80 | if text_input_ids is None: 81 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 82 | 83 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 84 | 85 | dtype = text_encoder.dtype 86 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 87 | 88 | _, seq_len, _ = prompt_embeds.shape 89 | 90 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 91 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 92 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 93 | 94 | return prompt_embeds 95 | 96 | 97 | def _encode_prompt_with_clip( 98 | text_encoder, 99 | tokenizer, 100 | prompt: str, 101 | device=None, 102 | text_input_ids=None, 103 | num_images_per_prompt: int = 1, 104 | ): 105 | prompt = [prompt] if isinstance(prompt, str) else prompt 106 | batch_size = len(prompt) 107 | 108 | if tokenizer is not None: 109 | text_inputs = tokenizer( 110 | prompt, 111 | padding="max_length", 112 | max_length=77, 113 | truncation=True, 114 | return_overflowing_tokens=False, 115 | return_length=False, 116 | return_tensors="pt", 117 | ) 118 | 119 | text_input_ids = text_inputs.input_ids 120 | else: 121 | if text_input_ids is None: 122 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 123 | 124 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 125 | 126 | # Use pooled output of CLIPTextModel 127 | prompt_embeds = prompt_embeds.pooler_output 128 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 129 | 130 | # duplicate text embeddings for each generation per prompt, using mps friendly method 131 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 132 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 133 | 134 | return prompt_embeds 135 | 136 | 137 | def encode_prompt( 138 | text_encoders, 139 | tokenizers, 140 | prompt: str, 141 | max_sequence_length, 142 | 
device=None, 143 | num_images_per_prompt: int = 1, 144 | text_input_ids_list=None, 145 | ): 146 | prompt = [prompt] if isinstance(prompt, str) else prompt 147 | dtype = text_encoders[0].dtype 148 | 149 | pooled_prompt_embeds = _encode_prompt_with_clip( 150 | text_encoder=text_encoders[0], 151 | tokenizer=tokenizers[0], 152 | prompt=prompt, 153 | device=device if device is not None else text_encoders[0].device, 154 | num_images_per_prompt=num_images_per_prompt, 155 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, 156 | ) 157 | 158 | prompt_embeds = _encode_prompt_with_t5( 159 | text_encoder=text_encoders[1], 160 | tokenizer=tokenizers[1], 161 | max_sequence_length=max_sequence_length, 162 | prompt=prompt, 163 | num_images_per_prompt=num_images_per_prompt, 164 | device=device if device is not None else text_encoders[1].device, 165 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, 166 | ) 167 | 168 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 169 | 170 | return prompt_embeds, pooled_prompt_embeds, text_ids 171 | 172 | 173 | def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None): 174 | text_encoder_clip = text_encoders[0] 175 | text_encoder_t5 = text_encoders[1] 176 | tokens_clip, tokens_t5 = tokens[0], tokens[1] 177 | batch_size = tokens_clip.shape[0] 178 | 179 | if device == "cpu": 180 | device = "cpu" 181 | else: 182 | device = accelerator.device 183 | 184 | # clip 185 | prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False) 186 | # Use pooled output of CLIPTextModel 187 | prompt_embeds = prompt_embeds.pooler_output 188 | prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device) 189 | # duplicate text embeddings for each generation per prompt, using mps friendly method 190 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 191 | pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 192 | pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device) 193 | 194 | # t5 195 | prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0] 196 | dtype = text_encoder_t5.dtype 197 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device) 198 | _, seq_len, _ = prompt_embeds.shape 199 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 200 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 201 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 202 | 203 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype) 204 | 205 | return prompt_embeds, pooled_prompt_embeds, text_ids -------------------------------------------------------------------------------- /train/src/transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.attention import FeedForward 11 | from diffusers.models.attention_processor import ( 12 | Attention, 13 | AttentionProcessor, 14 | 
FluxAttnProcessor2_0, 15 | FluxAttnProcessor2_0_NPU, 16 | FusedFluxAttnProcessor2_0, 17 | ) 18 | from diffusers.models.modeling_utils import ModelMixin 19 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 20 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 21 | from diffusers.utils.import_utils import is_torch_npu_available 22 | from diffusers.utils.torch_utils import maybe_allow_in_graph 23 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 24 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | @maybe_allow_in_graph 29 | class FluxSingleTransformerBlock(nn.Module): 30 | 31 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 32 | super().__init__() 33 | self.mlp_hidden_dim = int(dim * mlp_ratio) 34 | 35 | self.norm = AdaLayerNormZeroSingle(dim) 36 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 37 | self.act_mlp = nn.GELU(approximate="tanh") 38 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 39 | 40 | if is_torch_npu_available(): 41 | processor = FluxAttnProcessor2_0_NPU() 42 | else: 43 | processor = FluxAttnProcessor2_0() 44 | self.attn = Attention( 45 | query_dim=dim, 46 | cross_attention_dim=None, 47 | dim_head=attention_head_dim, 48 | heads=num_attention_heads, 49 | out_dim=dim, 50 | bias=True, 51 | processor=processor, 52 | qk_norm="rms_norm", 53 | eps=1e-6, 54 | pre_only=True, 55 | ) 56 | 57 | def forward( 58 | self, 59 | hidden_states: torch.Tensor, 60 | cond_hidden_states: torch.Tensor, 61 | temb: torch.Tensor, 62 | cond_temb: torch.Tensor, 63 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 64 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 65 | ) -> torch.Tensor: 66 | use_cond = cond_hidden_states is not None 67 | 68 | residual = hidden_states 69 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 70 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 71 | 72 | if use_cond: 73 | residual_cond = cond_hidden_states 74 | norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb) 75 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states)) 76 | 77 | norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 78 | 79 | joint_attention_kwargs = joint_attention_kwargs or {} 80 | attn_output = self.attn( 81 | hidden_states=norm_hidden_states_concat, 82 | image_rotary_emb=image_rotary_emb, 83 | use_cond=use_cond, 84 | **joint_attention_kwargs, 85 | ) 86 | if use_cond: 87 | attn_output, cond_attn_output = attn_output 88 | 89 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 90 | gate = gate.unsqueeze(1) 91 | hidden_states = gate * self.proj_out(hidden_states) 92 | hidden_states = residual + hidden_states 93 | 94 | if use_cond: 95 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 96 | cond_gate = cond_gate.unsqueeze(1) 97 | condition_latents = cond_gate * self.proj_out(condition_latents) 98 | condition_latents = residual_cond + condition_latents 99 | 100 | if hidden_states.dtype == torch.float16: 101 | hidden_states = hidden_states.clip(-65504, 65504) 102 | 103 | return hidden_states, condition_latents if use_cond else None 104 | 105 | 
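# NOTE: like FluxSingleTransformerBlock above, the double-stream block below concatenates the
# normalized condition tokens onto the image tokens before attention. The LoRA attention
# processors in train/src/layers.py mask the attention so that each condition attends only to
# its own tokens while the text/image tokens may attend to every condition, and then split the
# condition branch back out into cond_hidden_states.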
106 | @maybe_allow_in_graph 107 | class FluxTransformerBlock(nn.Module): 108 | def __init__( 109 | self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 110 | ): 111 | super().__init__() 112 | 113 | self.norm1 = AdaLayerNormZero(dim) 114 | 115 | self.norm1_context = AdaLayerNormZero(dim) 116 | 117 | if hasattr(F, "scaled_dot_product_attention"): 118 | processor = FluxAttnProcessor2_0() 119 | else: 120 | raise ValueError( 121 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 122 | ) 123 | self.attn = Attention( 124 | query_dim=dim, 125 | cross_attention_dim=None, 126 | added_kv_proj_dim=dim, 127 | dim_head=attention_head_dim, 128 | heads=num_attention_heads, 129 | out_dim=dim, 130 | context_pre_only=False, 131 | bias=True, 132 | processor=processor, 133 | qk_norm=qk_norm, 134 | eps=eps, 135 | ) 136 | 137 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 138 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 139 | 140 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 141 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 142 | 143 | # let chunk size default to None 144 | self._chunk_size = None 145 | self._chunk_dim = 0 146 | 147 | def forward( 148 | self, 149 | hidden_states: torch.Tensor, 150 | cond_hidden_states: torch.Tensor, 151 | encoder_hidden_states: torch.Tensor, 152 | temb: torch.Tensor, 153 | cond_temb: torch.Tensor, 154 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 155 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 156 | ) -> Tuple[torch.Tensor, torch.Tensor]: 157 | use_cond = cond_hidden_states is not None 158 | 159 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 160 | if use_cond: 161 | ( 162 | norm_cond_hidden_states, 163 | cond_gate_msa, 164 | cond_shift_mlp, 165 | cond_scale_mlp, 166 | cond_gate_mlp, 167 | ) = self.norm1(cond_hidden_states, emb=cond_temb) 168 | 169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 170 | encoder_hidden_states, emb=temb 171 | ) 172 | 173 | norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 174 | 175 | joint_attention_kwargs = joint_attention_kwargs or {} 176 | # Attention. 177 | attention_outputs = self.attn( 178 | hidden_states=norm_hidden_states, 179 | encoder_hidden_states=norm_encoder_hidden_states, 180 | image_rotary_emb=image_rotary_emb, 181 | use_cond=use_cond, 182 | **joint_attention_kwargs, 183 | ) 184 | 185 | attn_output, context_attn_output = attention_outputs[:2] 186 | cond_attn_output = attention_outputs[2] if use_cond else None 187 | 188 | # Process attention outputs for the `hidden_states`. 
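        # (gate_msa and cond_gate_msa come from the AdaLayerNormZero modulation above; the image
        # branch and the condition branch are gated and added back to their residuals independently.)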
189 | attn_output = gate_msa.unsqueeze(1) * attn_output 190 | hidden_states = hidden_states + attn_output 191 | 192 | if use_cond: 193 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 194 | cond_hidden_states = cond_hidden_states + cond_attn_output 195 | 196 | norm_hidden_states = self.norm2(hidden_states) 197 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 198 | 199 | if use_cond: 200 | norm_cond_hidden_states = self.norm2(cond_hidden_states) 201 | norm_cond_hidden_states = ( 202 | norm_cond_hidden_states * (1 + cond_scale_mlp[:, None]) 203 | + cond_shift_mlp[:, None] 204 | ) 205 | 206 | ff_output = self.ff(norm_hidden_states) 207 | ff_output = gate_mlp.unsqueeze(1) * ff_output 208 | hidden_states = hidden_states + ff_output 209 | 210 | if use_cond: 211 | cond_ff_output = self.ff(norm_cond_hidden_states) 212 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 213 | cond_hidden_states = cond_hidden_states + cond_ff_output 214 | 215 | # Process attention outputs for the `encoder_hidden_states`. 216 | 217 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 218 | encoder_hidden_states = encoder_hidden_states + context_attn_output 219 | 220 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 221 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 222 | 223 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 224 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 225 | if encoder_hidden_states.dtype == torch.float16: 226 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 227 | 228 | return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None 229 | 230 | 231 | class FluxTransformer2DModel( 232 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin 233 | ): 234 | _supports_gradient_checkpointing = True 235 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 236 | 237 | @register_to_config 238 | def __init__( 239 | self, 240 | patch_size: int = 1, 241 | in_channels: int = 64, 242 | out_channels: Optional[int] = None, 243 | num_layers: int = 19, 244 | num_single_layers: int = 38, 245 | attention_head_dim: int = 128, 246 | num_attention_heads: int = 24, 247 | joint_attention_dim: int = 4096, 248 | pooled_projection_dim: int = 768, 249 | guidance_embeds: bool = False, 250 | axes_dims_rope: Tuple[int] = (16, 56, 56), 251 | ): 252 | super().__init__() 253 | self.out_channels = out_channels or in_channels 254 | self.inner_dim = num_attention_heads * attention_head_dim 255 | 256 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 257 | 258 | text_time_guidance_cls = ( 259 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 260 | ) 261 | self.time_text_embed = text_time_guidance_cls( 262 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim 263 | ) 264 | 265 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) 266 | self.x_embedder = nn.Linear(in_channels, self.inner_dim) 267 | 268 | self.transformer_blocks = nn.ModuleList( 269 | [ 270 | FluxTransformerBlock( 271 | dim=self.inner_dim, 272 | num_attention_heads=num_attention_heads, 273 | attention_head_dim=attention_head_dim, 274 | ) 275 | for _ in range(num_layers) 276 | ] 277 | ) 278 | 279 | 
self.single_transformer_blocks = nn.ModuleList( 280 | [ 281 | FluxSingleTransformerBlock( 282 | dim=self.inner_dim, 283 | num_attention_heads=num_attention_heads, 284 | attention_head_dim=attention_head_dim, 285 | ) 286 | for _ in range(num_single_layers) 287 | ] 288 | ) 289 | 290 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 291 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 292 | 293 | self.gradient_checkpointing = False 294 | 295 | @property 296 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 297 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 298 | r""" 299 | Returns: 300 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 301 | indexed by its weight name. 302 | """ 303 | # set recursively 304 | processors = {} 305 | 306 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 307 | if hasattr(module, "get_processor"): 308 | processors[f"{name}.processor"] = module.get_processor() 309 | 310 | for sub_name, child in module.named_children(): 311 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 312 | 313 | return processors 314 | 315 | for name, module in self.named_children(): 316 | fn_recursive_add_processors(name, module, processors) 317 | 318 | return processors 319 | 320 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 321 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 322 | r""" 323 | Sets the attention processor to use to compute attention. 324 | 325 | Parameters: 326 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 327 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 328 | for **all** `Attention` layers. 329 | 330 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 331 | processor. This is strongly recommended when setting trainable attention processors. 332 | 333 | """ 334 | count = len(self.attn_processors.keys()) 335 | 336 | if isinstance(processor, dict) and len(processor) != count: 337 | raise ValueError( 338 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 339 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 340 | ) 341 | 342 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 343 | if hasattr(module, "set_processor"): 344 | if not isinstance(processor, dict): 345 | module.set_processor(processor) 346 | else: 347 | module.set_processor(processor.pop(f"{name}.processor")) 348 | 349 | for sub_name, child in module.named_children(): 350 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 351 | 352 | for name, module in self.named_children(): 353 | fn_recursive_attn_processor(name, module, processor) 354 | 355 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 356 | def fuse_qkv_projections(self): 357 | """ 358 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 359 | are fused. 
For cross-attention modules, key and value projection matrices are fused. 360 | 361 | 362 | 363 | This API is 🧪 experimental. 364 | 365 | 366 | """ 367 | self.original_attn_processors = None 368 | 369 | for _, attn_processor in self.attn_processors.items(): 370 | if "Added" in str(attn_processor.__class__.__name__): 371 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 372 | 373 | self.original_attn_processors = self.attn_processors 374 | 375 | for module in self.modules(): 376 | if isinstance(module, Attention): 377 | module.fuse_projections(fuse=True) 378 | 379 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 380 | 381 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 382 | def unfuse_qkv_projections(self): 383 | """Disables the fused QKV projection if enabled. 384 | 385 | 386 | 387 | This API is 🧪 experimental. 388 | 389 | 390 | 391 | """ 392 | if self.original_attn_processors is not None: 393 | self.set_attn_processor(self.original_attn_processors) 394 | 395 | def _set_gradient_checkpointing(self, module, value=False): 396 | if hasattr(module, "gradient_checkpointing"): 397 | module.gradient_checkpointing = value 398 | 399 | def forward( 400 | self, 401 | hidden_states: torch.Tensor, 402 | cond_hidden_states: torch.Tensor = None, 403 | encoder_hidden_states: torch.Tensor = None, 404 | pooled_projections: torch.Tensor = None, 405 | timestep: torch.LongTensor = None, 406 | img_ids: torch.Tensor = None, 407 | txt_ids: torch.Tensor = None, 408 | guidance: torch.Tensor = None, 409 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 410 | controlnet_block_samples=None, 411 | controlnet_single_block_samples=None, 412 | return_dict: bool = True, 413 | controlnet_blocks_repeat: bool = False, 414 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 415 | if cond_hidden_states is not None: 416 | use_condition = True 417 | else: 418 | use_condition = False 419 | 420 | if joint_attention_kwargs is not None: 421 | joint_attention_kwargs = joint_attention_kwargs.copy() 422 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 423 | else: 424 | lora_scale = 1.0 425 | 426 | if USE_PEFT_BACKEND: 427 | # weight the lora layers by setting `lora_scale` for each PEFT layer 428 | scale_lora_layers(self, lora_scale) 429 | else: 430 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 431 | logger.warning( 432 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 433 | ) 434 | 435 | hidden_states = self.x_embedder(hidden_states) 436 | cond_hidden_states = self.x_embedder(cond_hidden_states) 437 | 438 | timestep = timestep.to(hidden_states.dtype) * 1000 439 | if guidance is not None: 440 | guidance = guidance.to(hidden_states.dtype) * 1000 441 | else: 442 | guidance = None 443 | 444 | temb = ( 445 | self.time_text_embed(timestep, pooled_projections) 446 | if guidance is None 447 | else self.time_text_embed(timestep, guidance, pooled_projections) 448 | ) 449 | 450 | cond_temb = ( 451 | self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections) 452 | if guidance is None 453 | else self.time_text_embed( 454 | torch.ones_like(timestep) * 0, guidance, pooled_projections 455 | ) 456 | ) 457 | 458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 459 | 460 | if txt_ids.ndim == 3: 461 | logger.warning( 462 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 
463 |                 "Please remove the batch dimension and pass it as a 2d torch Tensor"
464 |             )
465 |             txt_ids = txt_ids[0]
466 |         if img_ids.ndim == 3:
467 |             logger.warning(
468 |                 "Passing `img_ids` 3d torch.Tensor is deprecated. "
469 |                 "Please remove the batch dimension and pass it as a 2d torch Tensor"
470 |             )
471 |             img_ids = img_ids[0]
472 | 
473 |         ids = torch.cat((txt_ids, img_ids), dim=0)
474 |         image_rotary_emb = self.pos_embed(ids)
475 | 
476 |         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
477 |             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
478 |             ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
479 |             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
480 | 
481 |         for index_block, block in enumerate(self.transformer_blocks):
482 |             if torch.is_grad_enabled() and self.gradient_checkpointing:
483 | 
484 |                 def create_custom_forward(module, return_dict=None):
485 |                     def custom_forward(*inputs):
486 |                         if return_dict is not None:
487 |                             return module(*inputs, return_dict=return_dict)
488 |                         else:
489 |                             return module(*inputs)
490 | 
491 |                     return custom_forward
492 | 
493 |                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
494 |                 encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
495 |                     create_custom_forward(block),
496 |                     hidden_states,
497 |                     encoder_hidden_states,
498 |                     temb,
499 |                     image_rotary_emb,
500 |                     cond_temb=cond_temb if use_condition else None,
501 |                     cond_hidden_states=cond_hidden_states if use_condition else None,
502 |                     **ckpt_kwargs,
503 |                 )  # unpack the cond stream too, matching the non-checkpointed branch below
504 | 
505 |             else:
506 |                 encoder_hidden_states, hidden_states, cond_hidden_states = block(
507 |                     hidden_states=hidden_states,
508 |                     encoder_hidden_states=encoder_hidden_states,
509 |                     cond_hidden_states=cond_hidden_states if use_condition else None,
510 |                     temb=temb,
511 |                     cond_temb=cond_temb if use_condition else None,
512 |                     image_rotary_emb=image_rotary_emb,
513 |                     joint_attention_kwargs=joint_attention_kwargs,
514 |                 )
515 | 
516 |             # controlnet residual
517 |             if controlnet_block_samples is not None:
518 |                 interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
519 |                 interval_control = int(np.ceil(interval_control))
520 |                 # For Xlabs ControlNet.
521 | if controlnet_blocks_repeat: 522 | hidden_states = ( 523 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 524 | ) 525 | else: 526 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 527 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 528 | 529 | for index_block, block in enumerate(self.single_transformer_blocks): 530 | if torch.is_grad_enabled() and self.gradient_checkpointing: 531 | 532 | def create_custom_forward(module, return_dict=None): 533 | def custom_forward(*inputs): 534 | if return_dict is not None: 535 | return module(*inputs, return_dict=return_dict) 536 | else: 537 | return module(*inputs) 538 | 539 | return custom_forward 540 | 541 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 542 | hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint( 543 | create_custom_forward(block), 544 | hidden_states, 545 | temb, 546 | image_rotary_emb, 547 | cond_temb=cond_temb if use_condition else None, 548 | cond_hidden_states=cond_hidden_states if use_condition else None, 549 | **ckpt_kwargs, 550 | ) 551 | 552 | else: 553 | hidden_states, cond_hidden_states = block( 554 | hidden_states=hidden_states, 555 | cond_hidden_states=cond_hidden_states if use_condition else None, 556 | temb=temb, 557 | cond_temb=cond_temb if use_condition else None, 558 | image_rotary_emb=image_rotary_emb, 559 | joint_attention_kwargs=joint_attention_kwargs, 560 | ) 561 | 562 | # controlnet residual 563 | if controlnet_single_block_samples is not None: 564 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 565 | interval_control = int(np.ceil(interval_control)) 566 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 567 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 568 | + controlnet_single_block_samples[index_block // interval_control] 569 | ) 570 | 571 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 572 | 573 | hidden_states = self.norm_out(hidden_states, temb) 574 | output = self.proj_out(hidden_states) 575 | 576 | if USE_PEFT_BACKEND: 577 | # remove `lora_scale` from each PEFT layer 578 | unscale_lora_layers(self, lora_scale) 579 | 580 | if not return_dict: 581 | return (output,) 582 | 583 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /train/train_spatial.sh: -------------------------------------------------------------------------------- 1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path 2 | export OUTPUT_DIR="./models/pose_model" # your save path 3 | export CONFIG="./default_config.yaml" 4 | export TRAIN_DATA="./examples/pose.jsonl" # your data jsonl file 5 | export LOG_PATH="$OUTPUT_DIR/log" 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \ 8 | --pretrained_model_name_or_path $MODEL_DIR \ 9 | --cond_size=512 \ 10 | --noise_size=1024 \ 11 | --subject_column="None" \ 12 | --spatial_column="source" \ 13 | --target_column="target" \ 14 | --caption_column="caption" \ 15 | --ranks 128 \ 16 | --network_alphas 128 \ 17 | --output_dir=$OUTPUT_DIR \ 18 | --logging_dir=$LOG_PATH \ 19 | --mixed_precision="bf16" \ 20 | --train_data_dir=$TRAIN_DATA \ 21 | --learning_rate=1e-4 \ 22 | --train_batch_size=1 \ 23 | --validation_prompt "A girl in the city." 
\ 24 | --num_train_epochs=1000 \ 25 | --validation_steps=20 \ 26 | --checkpointing_steps=20 \ 27 | --spatial_test_images "./examples/openpose_data/1.png" \ 28 | --subject_test_images None \ 29 | --test_h 1024 \ 30 | --test_w 1024 \ 31 | --num_validation_images=2 32 | -------------------------------------------------------------------------------- /train/train_style.sh: -------------------------------------------------------------------------------- 1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path 2 | export OUTPUT_DIR="./models/style_model" # your save path 3 | export CONFIG="./default_config.yaml" 4 | export TRAIN_DATA="./examples/style.jsonl" # your data jsonl file 5 | export LOG_PATH="$OUTPUT_DIR/log" 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \ 8 | --pretrained_model_name_or_path $MODEL_DIR \ 9 | --cond_size=512 \ 10 | --noise_size=1024 \ 11 | --subject_column="None" \ 12 | --spatial_column="source" \ 13 | --target_column="target" \ 14 | --caption_column="caption" \ 15 | --ranks 128 \ 16 | --network_alphas 128 \ 17 | --output_dir=$OUTPUT_DIR \ 18 | --logging_dir=$LOG_PATH \ 19 | --mixed_precision="bf16" \ 20 | --train_data_dir=$TRAIN_DATA \ 21 | --learning_rate=1e-4 \ 22 | --train_batch_size=1 \ 23 | --validation_prompt "Ghibli Studio style, Charming hand-drawn anime-style illustration" \ 24 | --num_train_epochs=1000 \ 25 | --validation_steps=20 \ 26 | --checkpointing_steps=20 \ 27 | --spatial_test_images "./examples/style_data/5.png" \ 28 | --subject_test_images None \ 29 | --test_h 1024 \ 30 | --test_w 1024 \ 31 | --num_validation_images=2 32 | # 33 | -------------------------------------------------------------------------------- /train/train_subject.sh: -------------------------------------------------------------------------------- 1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path 2 | export OUTPUT_DIR="./models/subject_model" # your save path 3 | export CONFIG="./default_config.yaml" 4 | export TRAIN_DATA="./examples/subject.jsonl" # your data jsonl file 5 | export LOG_PATH="$OUTPUT_DIR/log" 6 | 7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \ 8 | --pretrained_model_name_or_path $MODEL_DIR \ 9 | --cond_size=512 \ 10 | --noise_size=1024 \ 11 | --subject_column="source" \ 12 | --spatial_column="None" \ 13 | --target_column="target" \ 14 | --caption_column="caption" \ 15 | --ranks 128 \ 16 | --network_alphas 128 \ 17 | --output_dir=$OUTPUT_DIR \ 18 | --logging_dir=$LOG_PATH \ 19 | --mixed_precision="bf16" \ 20 | --train_data_dir=$TRAIN_DATA \ 21 | --learning_rate=1e-4 \ 22 | --train_batch_size=1 \ 23 | --validation_prompt "An SKS in the city." \ 24 | --num_train_epochs=1000 \ 25 | --validation_steps=20 \ 26 | --checkpointing_steps=20 \ 27 | --spatial_test_images None \ 28 | --subject_test_images "./examples/subject_data/3.png" \ 29 | --test_h 1024 \ 30 | --test_w 1024 \ 31 | --num_validation_images=2 32 | --------------------------------------------------------------------------------
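A note on the data files referenced by the three training scripts above: each script points --train_data_dir at a JSONL file and names its columns through --spatial_column or --subject_column, --target_column, and --caption_column. Below is a minimal sketch of writing such a file in Python, assuming each line is a JSON object keyed by exactly those column names ("source", "target", "caption"); the authoritative schema is whatever train/src/jsonl_datasets.py expects, and the file paths used here are placeholders.

import json

# Hypothetical records: the keys mirror the --spatial_column / --target_column /
# --caption_column flags in train_spatial.sh; the paths below are placeholders only.
records = [
    {
        "source": "./examples/openpose_data/1.png",         # conditioning image (e.g. an OpenPose map)
        "target": "./examples/openpose_data/1_target.png",  # hypothetical ground-truth image
        "caption": "A girl in the city.",
    },
]

# Write one JSON object per line, the format --train_data_dir points at.
with open("./examples/my_pose.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

Each record produced this way pairs one conditioning image with one target image and a caption, which is the pairing the example files in train/examples appear to encode.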
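On the ControlNet residuals in src/transformer_flux.py above: the forward pass spreads controlnet_block_samples (and controlnet_single_block_samples) across the transformer blocks by computing interval_control = ceil(num_blocks / num_samples) and indexing the samples with index_block // interval_control. The stand-alone sketch below reproduces only that index arithmetic; the block and sample counts are made-up example values, not the model's actual configuration.

import numpy as np

num_transformer_blocks = 19  # example value only, standing in for len(self.transformer_blocks)
num_controlnet_samples = 4   # example value only, standing in for len(controlnet_block_samples)

# Same arithmetic as in forward(): each ControlNet sample covers a contiguous run of blocks.
interval_control = int(np.ceil(num_transformer_blocks / num_controlnet_samples))

for index_block in range(num_transformer_blocks):
    sample_index = index_block // interval_control
    print(f"transformer block {index_block:2d} -> controlnet_block_samples[{sample_index}]")

With these example counts, interval_control is 5, so blocks 0-4 reuse sample 0, blocks 5-9 sample 1, blocks 10-14 sample 2, and blocks 15-18 sample 3.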