├── .github
│   └── workflows
│       └── typecheck.yaml
├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── CFG-Zero
│   │   ├── image.webp
│   │   ├── image_CFG.webp
│   │   └── image_CFG_zero_star.webp
│   ├── example1.jpeg
│   ├── example2.jpeg
│   ├── example3.jpeg
│   ├── example4.jpeg
│   ├── img.md
│   ├── method.jpg
│   ├── result_canny.png
│   ├── result_ghibli.png
│   ├── result_subject.png
│   ├── result_subject_inpainting.png
│   └── teaser.jpg
├── infer.ipynb
├── infer.py
├── infer_multi.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── layers_cache.py
│   ├── lora_helper.py
│   ├── pipeline.py
│   └── transformer_flux.py
├── test_imgs
│   ├── canny.png
│   ├── depth.png
│   ├── ghibli.png
│   ├── inpainting.png
│   ├── openpose.png
│   ├── seg.png
│   ├── subject_0.png
│   └── subject_1.png
└── train
    ├── default_config.yaml
    ├── examples
    │   ├── openpose_data
    │   │   ├── 1.png
    │   │   └── 2.png
    │   ├── pose.jsonl
    │   ├── style.jsonl
    │   ├── style_data
    │   │   ├── 5.png
    │   │   └── 6.png
    │   ├── subject.jsonl
    │   └── subject_data
    │       ├── 3.png
    │       └── 4.png
    ├── readme.md
    ├── src
    │   ├── __init__.py
    │   ├── jsonl_datasets.py
    │   ├── layers.py
    │   ├── lora_helper.py
    │   ├── pipeline.py
    │   ├── prompt_helper.py
    │   └── transformer_flux.py
    ├── train.py
    ├── train_spatial.sh
    ├── train_style.sh
    └── train_subject.sh
/.github/workflows/typecheck.yaml:
--------------------------------------------------------------------------------
1 | name: Typecheck
2 |
3 | # These checks will run if at least one file is outside of the `paths-ignore`
4 | # list, but will be skipped if *all* files are in the `paths-ignore` list.
5 | #
6 | # For more info, see:
7 | # https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#example-excluding-paths
8 |
9 | on:
10 | push:
11 | branches:
12 | - 'main'
13 | paths-ignore:
14 | - '**.jpeg'
15 | - '**.jpg'
16 | - '**.md'
17 | - '**.png'
18 | - '**.webp'
19 |
20 | pull_request:
21 | branches:
22 | - 'main'
23 | paths-ignore:
24 | - '**.jpeg'
25 | - '**.jpg'
26 | - '**.md'
27 | - '**.png'
28 | - '**.webp'
29 |
30 | jobs:
31 | test:
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | os: [ 'ubuntu-24.04' ]
36 | python: [ '3.10' ]
37 |
38 | runs-on: ${{ matrix.os }}
39 | name: Python ${{ matrix.python }} on ${{ matrix.os }}
40 |
41 | steps:
42 | - name: Checkout the repo
43 | uses: actions/checkout@v4
44 |
45 | - name: Setup Python
46 | uses: actions/setup-python@v5
47 | with:
48 | python-version: ${{ matrix.python }}
49 | cache: 'pip'
50 |
51 | - name: Update pip
52 | run: python -m pip install --upgrade pip
53 |
54 | - name: Install Python deps
55 | run: python -m pip install -r requirements.txt
56 |
57 | - name: Install Mypy
58 | run: python -m pip install mypy
59 |
60 | - name: Check types with Mypy
61 | run: python -m mypy --python-version=${{ matrix.python }} .
62 | # TODO: fix the type checking errors and remove this line to make errors
63 | # obvious by failing the test.
64 | continue-on-error: true
65 |
66 | - name: Install PyType
67 | run: python -m pip install pytype
68 |
69 | - name: Check types with PyType
70 | run: python -m pytype --python-version=${{ matrix.python }} -k .
71 | # TODO: fix the type checking errors and remove this line to make errors
72 | # obvious by failing the test.
73 | continue-on-error: true
74 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.safetensors
3 |
4 | .DS_Store
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implementation of EasyControl
2 |
3 | EasyControl: Adding Efficient and Flexible Control for Diffusion Transformer
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | > *[Yuxuan Zhang](https://xiaojiu-z.github.io/YuxuanZhang.github.io/), [Yirui Yuan](https://github.com/Reynoldyy), [Yiren Song](https://scholar.google.com.hk/citations?user=L2YS0jgAAAAJ), [Haofan Wang](https://haofanwang.github.io/), [Jiaming Liu](https://scholar.google.com/citations?user=SmL7oMQAAAAJ&hl=en)*
12 | >
13 | > Tiamat AI, ShanghaiTech University, National University of Singapore, Liblib AI
14 |
15 |
16 |
17 | ## Features
18 | * **Motivation:** The architecture of diffusion models is transitioning from UNet-based designs to DiT (Diffusion Transformer). However, the DiT ecosystem lacks mature plugin support and faces challenges such as efficiency bottlenecks, conflicts in multi-condition coordination, and insufficient model adaptability.
19 | * **Contribution:** We propose EasyControl, an efficient and flexible unified conditional DiT framework. By incorporating a lightweight Condition Injection LoRA module, a Position-Aware Training Paradigm, and a combination of Causal Attention mechanisms with KV Cache technology, we significantly enhance **model compatibility** (enabling plug-and-play functionality and lossless style control), **generation flexibility** (supporting multiple resolutions, aspect ratios, and multi-condition combinations), and **inference efficiency**. A toy sketch of the KV-cache idea follows this list.
20 |
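To give a feel for the KV-cache component mentioned in the Contribution bullet, here is a minimal, generic sketch of the idea: compute the condition tokens' keys/values once, reuse them on later denoising steps, and clear the bank between generations. The class, shapes, and names below are illustrative only; the actual implementation lives in `src/layers_cache.py`.

```python
import torch
import torch.nn.functional as F

class CondKVCache:
    """Toy illustration of caching condition-token K/V across denoising steps."""
    def __init__(self, dim=64):
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.bank_kv = []  # cleared between generations, like clear_cache() in the usage examples

    def attend(self, img_tokens, cond_tokens):
        q = self.to_q(img_tokens)
        if not self.bank_kv:  # first step: compute and cache the condition K/V
            self.bank_kv = [self.to_k(cond_tokens), self.to_v(cond_tokens)]
        cond_k, cond_v = self.bank_kv  # later steps: reuse the cached tensors
        k = torch.cat([self.to_k(img_tokens), cond_k], dim=1)
        v = torch.cat([self.to_v(img_tokens), cond_v], dim=1)
        return F.scaled_dot_product_attention(q, k, v)

cache = CondKVCache()
img, cond = torch.randn(1, 16, 64), torch.randn(1, 4, 64)
for _ in range(3):        # stand-in for denoising steps
    out = cache.attend(img, cond)
cache.bank_kv.clear()     # analogous to clear_cache(pipe.transformer)
```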
21 |
22 | ## News
23 | - **2025-04-11**: 🔥🔥🔥 Training code has been released. Recommended hardware: at least 1x NVIDIA H100/H800/A100 (~80GB GPU memory).
24 | - **2025-04-09**: ⭐️ The code for the simple API has been released. If you wish to run the models on your personal machine, head over to the simple_api branch to access the relevant resources.
25 |
26 | - **2025-04-07**: 🔥 Thanks to the great work by [CFG-Zero*](https://github.com/WeichenFan/CFG-Zero-star) team, EasyControl is now integrated with CFG-Zero*!! With just a few lines of code, you can boost image fidelity and controllability!! You can download the modified code from [this link](https://github.com/WeichenFan/CFG-Zero-star/blob/main/models/easycontrol/infer.py) and try it.
27 |
28 |
29 |
30 | [Images: Source Image | CFG | CFG-Zero*]
38 |
39 |
40 |
41 | - **2025-04-03**: Thanks to [jax-explorer](https://github.com/jax-explorer), [Ghibli Img2Img Control ComfyUI Node](https://github.com/jax-explorer/ComfyUI-easycontrol) is supported!
42 |
43 | - **2025-04-01**: 🔥 New [Stylized Img2Img Control Model](https://huggingface.co/spaces/jamesliu1217/EasyControl_Ghibli) is now released!! Transform portraits into Studio Ghibli-style artwork using this LoRA model. Trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, it preserves facial features while applying the iconic anime aesthetic.
44 |
45 |
46 |
47 |
48 | [Images: Example 3 | Example 4]
54 |
55 |
56 |
57 |
58 | - **2025-03-19**: 🔥 We have released [huggingface demo](https://huggingface.co/spaces/jamesliu1217/EasyControl)! You can now try out EasyControl with the huggingface space, enjoy it!
59 |
60 |
61 |
62 | [Images: Example 1 | Example 2]
68 |
69 |
70 |
71 |
72 | - **2025-03-18**: 🔥 We have released our [pre-trained checkpoints](https://huggingface.co/Xiaojiu-Z/EasyControl/) on Hugging Face! You can now try out EasyControl with the official weights.
73 | - **2025-03-12**: ⭐️ Inference code is released. Once we have ensured that everything is functioning correctly, the new model will be merged into this repository. Stay tuned for updates! 😊
74 |
75 | ## Installation
76 |
77 | We recommend using Python 3.10 and PyTorch with CUDA support. To set up the environment:
78 |
79 | ```bash
80 | # Create a new conda environment
81 | conda create -n easycontrol python=3.10
82 | conda activate easycontrol
83 |
84 | # Install other dependencies
85 | pip install -r requirements.txt
86 | ```
87 |
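Optionally, verify the installation with a tiny sanity check (a sketch only; it assumes a CUDA-capable GPU is present):

```python
import torch

# Print the installed PyTorch version and whether a CUDA device is visible.
print(torch.__version__, torch.cuda.is_available())
```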
88 | ## Download
89 |
90 | You can download the models directly from [Hugging Face](https://huggingface.co/EasyControl/EasyControl),
91 | or download them with a Python script:
92 |
93 | ```python
94 | from huggingface_hub import hf_hub_download
95 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./")
96 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/depth.safetensors", local_dir="./")
97 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/hedsketch.safetensors", local_dir="./")
98 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/inpainting.safetensors", local_dir="./")
99 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/pose.safetensors", local_dir="./")
100 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/seg.safetensors", local_dir="./")
101 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/subject.safetensors", local_dir="./")
102 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/Ghibli.safetensors", local_dir="./")
103 | ```
104 |
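Equivalently, a short loop over the same `hf_hub_download` call and filenames (a convenience sketch, not a different API) fetches every checkpoint listed above:

```python
from huggingface_hub import hf_hub_download

for name in ["canny", "depth", "hedsketch", "inpainting", "pose", "seg", "subject", "Ghibli"]:
    hf_hub_download(
        repo_id="Xiaojiu-Z/EasyControl",
        filename=f"models/{name}.safetensors",  # same filenames as the explicit calls above
        local_dir="./",
    )
```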
105 | If you cannot access Hugging Face, you can use [hf-mirror](https://hf-mirror.com/) to download the models:
106 | ```bash
107 | export HF_ENDPOINT=https://hf-mirror.com
108 | huggingface-cli download --resume-download Xiaojiu-Z/EasyControl --local-dir checkpoints --local-dir-use-symlinks False
109 | ```
110 |
111 | ## Usage
112 | Here's a basic example of using EasyControl:
113 |
114 | ### Model Initialization
115 |
116 | ```python
117 | import torch
118 | from PIL import Image
119 | from src.pipeline import FluxPipeline
120 | from src.transformer_flux import FluxTransformer2DModel
121 | from src.lora_helper import set_single_lora, set_multi_lora
122 |
123 | def clear_cache(transformer):
124 | for name, attn_processor in transformer.attn_processors.items():
125 | attn_processor.bank_kv.clear()
126 |
127 | # Initialize model
128 | device = "cuda"
129 | base_path = "FLUX.1-dev" # Path to your base model
130 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
131 | transformer = FluxTransformer2DModel.from_pretrained(
132 | base_path,
133 | subfolder="transformer",
134 | torch_dtype=torch.bfloat16,
135 | device=device
136 | )
137 | pipe.transformer = transformer
138 | pipe.to(device)
139 |
140 | # Load control models
141 | lora_path = "./checkpoints/models"
142 | control_models = {
143 | "canny": f"{lora_path}/canny.safetensors",
144 | "depth": f"{lora_path}/depth.safetensors",
145 | "hedsketch": f"{lora_path}/hedsketch.safetensors",
146 | "pose": f"{lora_path}/pose.safetensors",
147 | "seg": f"{lora_path}/seg.safetensors",
148 | "inpainting": f"{lora_path}/inpainting.safetensors",
149 | "subject": f"{lora_path}/subject.safetensors",
150 | }
151 | ```
152 |
153 | ### Single Condition Control
154 |
155 | ```python
156 | # Single spatial condition control example
157 | path = control_models["canny"]
158 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512)
159 |
160 | # Generate image
161 | prompt = "A nice car on the beach"
162 | spatial_image = Image.open("./test_imgs/canny.png").convert("RGB")
163 |
164 | image = pipe(
165 | prompt,
166 | height=720,
167 | width=992,
168 | guidance_scale=3.5,
169 | num_inference_steps=25,
170 | max_sequence_length=512,
171 | generator=torch.Generator("cpu").manual_seed(5),
172 | spatial_images=[spatial_image],
173 | cond_size=512,
174 | ).images[0]
175 |
176 | # Clear cache after generation
177 | clear_cache(pipe.transformer)
178 | ```
179 |
180 |
181 |
182 |
183 | [Images: Canny Condition | Generated Result]
189 |
190 |
191 |
192 |
193 | ```python
194 | # Single subject condition control example
195 | path = control_models["subject"]
196 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512)
197 |
198 | # Generate image
199 | prompt = "A SKS in the library"
200 | subject_image = Image.open("./test_imgs/subject_0.png").convert("RGB")
201 |
202 | image = pipe(
203 | prompt,
204 | height=1024,
205 | width=1024,
206 | guidance_scale=3.5,
207 | num_inference_steps=25,
208 | max_sequence_length=512,
209 | generator=torch.Generator("cpu").manual_seed(5),
210 | subject_images=[subject_image],
211 | cond_size=512,
212 | ).images[0]
213 |
214 | # Clear cache after generation
215 | clear_cache(pipe.transformer)
216 | ```
217 |
218 |
219 |
220 |
221 | [Images: Subject Condition | Generated Result]
227 |
228 |
229 |
230 |
231 | ### Multi-Condition Control
232 |
233 | ```python
234 | # Multi-condition control example
235 | paths = [control_models["subject"], control_models["inpainting"]]
236 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512)
237 |
238 | prompt = "A SKS on the car"
239 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")]
240 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")]
241 |
242 | image = pipe(
243 | prompt,
244 | height=1024,
245 | width=1024,
246 | guidance_scale=3.5,
247 | num_inference_steps=25,
248 | max_sequence_length=512,
249 | generator=torch.Generator("cpu").manual_seed(42),
250 | subject_images=subject_images,
251 | spatial_images=spatial_images,
252 | cond_size=512,
253 | ).images[0]
254 |
255 | # Clear cache after generation
256 | clear_cache(pipe.transformer)
257 | ```
258 |
259 |
260 |
261 |
262 | [Images: Subject Condition | Inpainting Condition | Generated Result]
270 |
271 |
272 |
273 |
274 | ### Ghibli-Style Portrait Generation
275 |
276 | ```python
277 | import spaces
278 | import os
279 | import json
280 | import time
281 | import torch
282 | from PIL import Image
283 | from tqdm import tqdm
284 | import gradio as gr
285 |
286 | from safetensors.torch import save_file
287 | from src.pipeline import FluxPipeline
288 | from src.transformer_flux import FluxTransformer2DModel
289 | from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
290 |
291 | # Initialize the image processor
292 | base_path = "black-forest-labs/FLUX.1-dev"
293 | lora_base_path = "./checkpoints/models"
294 |
295 |
296 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
297 | transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
298 | pipe.transformer = transformer
299 | pipe.to("cuda")
300 |
301 | def clear_cache(transformer):
302 | for name, attn_processor in transformer.attn_processors.items():
303 | attn_processor.bank_kv.clear()
304 |
305 | # Define the Gradio interface
306 | @spaces.GPU()
307 | def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
308 | # Set the control type
309 | if control_type == "Ghibli":
310 | lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
311 | set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
312 |
313 | # Process the image
314 | spatial_imgs = [spatial_img] if spatial_img else []
315 | image = pipe(
316 | prompt,
317 | height=int(height),
318 | width=int(width),
319 | guidance_scale=3.5,
320 | num_inference_steps=25,
321 | max_sequence_length=512,
322 | generator=torch.Generator("cpu").manual_seed(seed),
323 | subject_images=[],
324 | spatial_images=spatial_imgs,
325 | cond_size=512,
326 | ).images[0]
327 | clear_cache(pipe.transformer)
328 | return image
329 |
330 | # Define the Gradio interface components
331 | control_types = ["Ghibli"]
332 |
333 |
334 | # Create the Gradio Blocks interface
335 | with gr.Blocks() as demo:
336 | gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
337 | gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
338 | gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs. (Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+) generation, please set up your own environment.)")
339 |
340 | gr.Markdown("**[Attention!!]**: The recommended prompts for using the Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
341 | gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
342 |
343 | with gr.Tab("Ghibli Condition Generation"):
344 | with gr.Row():
345 | with gr.Column():
346 | prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
347 | spatial_img = gr.Image(label="Ghibli Image", type="pil") # Upload an image file
348 | height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768)
349 | width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768)
350 | seed = gr.Number(label="Seed", value=42)
351 | control_type = gr.Dropdown(choices=control_types, label="Control Type")
352 | single_generate_btn = gr.Button("Generate Image")
353 | with gr.Column():
354 | single_output_image = gr.Image(label="Generated Image")
355 |
356 |
357 | # Link the buttons to the functions
358 | single_generate_btn.click(
359 | single_condition_generate_image,
360 | inputs=[prompt, spatial_img, height, width, seed, control_type],
361 | outputs=single_output_image
362 | )
363 |
364 | # Launch the Gradio app
365 | demo.queue().launch()
366 | ```
367 |
368 |
369 |
370 |
371 | [Images: Input Image | Generated Result]
377 |
378 |
379 |
380 |
381 | ## Usage Tips
382 |
383 | - Clear cache after each generation using `clear_cache(pipe.transformer)`
384 | - For optimal performance:
385 | - Start with `guidance_scale=3.5` and adjust based on results
386 | - Use `num_inference_steps=25` for a good balance of quality and speed
387 | - When using the `set_multi_lora` API, make sure the subject LoRA path (subject) comes before the spatial LoRA path (canny, depth, hedsketch, etc.); see the sketch below.
388 |
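A hedged sketch of the recommended call pattern (it reuses `pipe`, `control_models`, `subject_image`, `spatial_image`, and `clear_cache` from the snippets above; the prompt and seed are arbitrary):

```python
# Sketch only: the subject LoRA must come first, the spatial LoRA second.
paths = [control_models["subject"], control_models["canny"]]
set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512)

image = pipe(
    "A SKS on the beach",                 # arbitrary example prompt
    height=1024,
    width=1024,
    guidance_scale=3.5,                   # starting point; adjust per the tips above
    num_inference_steps=25,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
    subject_images=[subject_image],       # subject condition(s) first
    spatial_images=[spatial_image],       # then spatial condition(s)
    cond_size=512,
).images[0]

clear_cache(pipe.transformer)             # clear the KV cache before the next generation
```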
389 | ## Todo List
390 | 1. - [x] Inference code
391 | 2. - [x] Spatial Pre-trained weights
392 | 3. - [x] Subject Pre-trained weights
393 | 4. - [x] Training code
394 |
395 |
396 | ## Star History
397 |
398 | [](https://star-history.com/#Xiaojiu-z/EasyControl&Date)
399 |
400 | ## Disclaimer
401 | The code of EasyControl is released under [Apache License](https://github.com/Xiaojiu-Z/EasyControl?tab=Apache-2.0-1-ov-file#readme) for both academic and commercial usage. Our released checkpoints are for research purposes only. Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and utilize it responsibly. The developers will not assume any responsibility for potential misuse by users.
402 |
403 | ## Hiring/Cooperation
404 | - **2025-04-03**: If you are interested in EasyControl and beyond, or in building 4o-like capability (in a low-budget way, of course), we can collaborate offline in Shanghai/Beijing/Hong Kong/Singapore or online.
405 | Reach out: **jmliu1217@gmail.com (wechat: jiaming068870)**
406 |
407 | ## Citation
408 | ```bibtex
409 | @article{zhang2025easycontrol,
410 | title={EasyControl: Adding Efficient and Flexible Control for Diffusion Transformer},
411 | author={Zhang, Yuxuan and Yuan, Yirui and Song, Yiren and Wang, Haofan and Liu, Jiaming},
412 | journal={arXiv preprint arXiv:2503.07027},
413 | year={2025}
414 | }
415 | ```
416 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import spaces
2 | import os
3 | import json
4 | import time
5 | import torch
6 | from PIL import Image
7 | from tqdm import tqdm
8 | import gradio as gr
9 |
10 | from safetensors.torch import save_file
11 | from src.pipeline import FluxPipeline
12 | from src.transformer_flux import FluxTransformer2DModel
13 | from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14 |
15 | class ImageProcessor:
16 | def __init__(self, path):
17 | device = "cuda"
18 | self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)
19 | transformer = FluxTransformer2DModel.from_pretrained(path, subfolder="transformer", torch_dtype=torch.bfloat16, device=device)
20 | self.pipe.transformer = transformer
21 | self.pipe.to(device)
22 |
23 | def clear_cache(self, transformer):
24 | for name, attn_processor in transformer.attn_processors.items():
25 | attn_processor.bank_kv.clear()
26 |
27 | @spaces.GPU()
28 | def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height=768, width=768, output_path=None, seed=42):
29 | image = self.pipe(
30 | prompt,
31 | height=int(height),
32 | width=int(width),
33 | guidance_scale=3.5,
34 | num_inference_steps=25,
35 | max_sequence_length=512,
36 | generator=torch.Generator("cpu").manual_seed(seed),
37 | subject_images=subject_imgs,
38 | spatial_images=spatial_imgs,
39 | cond_size=512,
40 | ).images[0]
41 | self.clear_cache(self.pipe.transformer)
42 | if output_path:
43 | image.save(output_path)
44 | return image
45 |
46 | # Initialize the image processor
47 | base_path = "black-forest-labs/FLUX.1-dev"
48 | lora_base_path = "EasyControl/models"
49 | style_lora_base_path = "Shakker-Labs"
50 | processor = ImageProcessor(base_path)
51 |
52 | # Define the Gradio interface
53 | def single_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora=None):
54 | # Set the control type
55 | if control_type == "subject":
56 | lora_path = os.path.join(lora_base_path, "subject.safetensors")
57 | elif control_type == "depth":
58 | lora_path = os.path.join(lora_base_path, "depth.safetensors")
59 | elif control_type == "seg":
60 | lora_path = os.path.join(lora_base_path, "seg.safetensors")
61 | elif control_type == "pose":
62 | lora_path = os.path.join(lora_base_path, "pose.safetensors")
63 | elif control_type == "inpainting":
64 | lora_path = os.path.join(lora_base_path, "inpainting.safetensors")
65 | elif control_type == "hedsketch":
66 | lora_path = os.path.join(lora_base_path, "hedsketch.safetensors")
67 | elif control_type == "canny":
68 | lora_path = os.path.join(lora_base_path, "canny.safetensors")
69 | set_single_lora(processor.pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
70 |
71 | # Set the style LoRA
72 | if style_lora=="None":
73 | pass
74 | else:
75 | if style_lora == "Simple_Sketch":
76 | processor.pipe.unload_lora_weights()
77 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Children-Simple-Sketch")
78 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
79 | if style_lora == "Text_Poster":
80 | processor.pipe.unload_lora_weights()
81 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Text-Poster")
82 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Text-Poster.safetensors")
83 | if style_lora == "Vector_Style":
84 | processor.pipe.unload_lora_weights()
85 | style_lora_path = os.path.join(style_lora_base_path, "FLUX.1-dev-LoRA-Vector-Journey")
86 | processor.pipe.load_lora_weights(style_lora_path, weight_name="FLUX-dev-lora-Vector-Journey.safetensors")
87 |
88 | # Process the image
89 | subject_imgs = [subject_img] if subject_img else []
90 | spatial_imgs = [spatial_img] if spatial_img else []
91 | image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
92 | return image
93 |
94 | # Define the Gradio interface
95 | def multi_condition_generate_image(prompt, subject_img, spatial_img, height, width, seed):
96 | subject_path = os.path.join(lora_base_path, "subject.safetensors")
97 | inpainting_path = os.path.join(lora_base_path, "inpainting.safetensors")
98 | set_multi_lora(processor.pipe.transformer, [subject_path, inpainting_path], lora_weights=[[1],[1]],cond_size=512)
99 |
100 | # Process the image
101 | subject_imgs = [subject_img] if subject_img else []
102 | spatial_imgs = [spatial_img] if spatial_img else []
103 | image = processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=seed)
104 | return image
105 |
106 | # Define the Gradio interface components
107 | control_types = ["subject", "depth", "pose", "inpainting", "hedsketch", "seg", "canny"]
108 | style_loras = ["Simple_Sketch", "Text_Poster", "Vector_Style", "None"]
109 |
110 | # Example data
111 | single_examples = [
112 | ["A SKS in the library", Image.open("./test_imgs/subject1.png"), None, 1024, 1024, 5, "subject", None],
113 | ["In a picturesque village, a narrow cobblestone street with rustic stone buildings, colorful blinds, and lush green spaces, a cartoon man drawn with simple lines and solid colors stands in the foreground, wearing a red shirt, beige work pants, and brown shoes, carrying a strap on his shoulder. The scene features warm and enticing colors, a pleasant fusion of nature and architecture, and the camera's perspective on the street clearly shows the charming and quaint environment., Integrating elements of reality and cartoon.", None, Image.open("./test_imgs/spatial1.png"), 1024, 1024, 1, "pose", "Vector_Style"],
114 | ]
115 | multi_examples = [
116 | ["A SKS on the car", Image.open("./test_imgs/subject2.png"), Image.open("./test_imgs/spatial2.png"), 1024, 1024, 7],
117 | ]
118 |
119 |
120 | # Create the Gradio Blocks interface
121 | with gr.Blocks() as demo:
122 | gr.Markdown("# Image Generation with EasyControl")
123 | gr.Markdown("Generate images using EasyControl with different control types and style LoRAs.")
124 |
125 | with gr.Tab("Single Condition Generation"):
126 | with gr.Row():
127 | with gr.Column():
128 | prompt = gr.Textbox(label="Prompt")
129 | subject_img = gr.Image(label="Subject Image", type="pil") # Upload an image file
130 | spatial_img = gr.Image(label="Spatial Image", type="pil") # Upload an image file
131 | height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768)
132 | width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768)
133 | seed = gr.Number(label="Seed", value=42)
134 | control_type = gr.Dropdown(choices=control_types, label="Control Type")
135 | style_lora = gr.Dropdown(choices=style_loras, label="Style LoRA")
136 | single_generate_btn = gr.Button("Generate Image")
137 | with gr.Column():
138 | single_output_image = gr.Image(label="Generated Image")
139 |
140 | # Add examples for Single Condition Generation
141 | gr.Examples(
142 | examples=single_examples,
143 | inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
144 | outputs=single_output_image,
145 | fn=single_condition_generate_image,
146 | cache_examples=False, # whether to cache example results (caching speeds up loading)
147 | label="Single Condition Examples"
148 | )
149 |
150 |
151 | with gr.Tab("Multi-Condition Generation"):
152 | with gr.Row():
153 | with gr.Column():
154 | multi_prompt = gr.Textbox(label="Prompt")
155 | multi_subject_img = gr.Image(label="Subject Image", type="pil") # Upload an image file
156 | multi_spatial_img = gr.Image(label="Spatial Image", type="pil") # Upload an image file
157 | multi_height = gr.Slider(minimum=256, maximum=1536, step=64, label="Height", value=768)
158 | multi_width = gr.Slider(minimum=256, maximum=1536, step=64, label="Width", value=768)
159 | multi_seed = gr.Number(label="Seed", value=42)
160 | multi_generate_btn = gr.Button("Generate Image")
161 | with gr.Column():
162 | multi_output_image = gr.Image(label="Generated Image")
163 |
164 | # Add examples for Multi-Condition Generation
165 | gr.Examples(
166 | examples=multi_examples,
167 | inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
168 | outputs=multi_output_image,
169 | fn=multi_condition_generate_image,
170 | cache_examples=False, # whether to cache example results (caching speeds up loading)
171 | label="Multi-Condition Examples"
172 | )
173 |
174 |
175 | # Link the buttons to the functions
176 | single_generate_btn.click(
177 | single_condition_generate_image,
178 | inputs=[prompt, subject_img, spatial_img, height, width, seed, control_type, style_lora],
179 | outputs=single_output_image
180 | )
181 | multi_generate_btn.click(
182 | multi_condition_generate_image,
183 | inputs=[multi_prompt, multi_subject_img, multi_spatial_img, multi_height, multi_width, multi_seed],
184 | outputs=multi_output_image
185 | )
186 |
187 | # Launch the Gradio app
188 | demo.queue().launch()
189 |
--------------------------------------------------------------------------------
/assets/CFG-Zero/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image.webp
--------------------------------------------------------------------------------
/assets/CFG-Zero/image_CFG.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image_CFG.webp
--------------------------------------------------------------------------------
/assets/CFG-Zero/image_CFG_zero_star.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/CFG-Zero/image_CFG_zero_star.webp
--------------------------------------------------------------------------------
/assets/example1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example1.jpeg
--------------------------------------------------------------------------------
/assets/example2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example2.jpeg
--------------------------------------------------------------------------------
/assets/example3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example3.jpeg
--------------------------------------------------------------------------------
/assets/example4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/example4.jpeg
--------------------------------------------------------------------------------
/assets/img.md:
--------------------------------------------------------------------------------
1 | put imgs here!
2 |
--------------------------------------------------------------------------------
/assets/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/method.jpg
--------------------------------------------------------------------------------
/assets/result_canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_canny.png
--------------------------------------------------------------------------------
/assets/result_ghibli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_ghibli.png
--------------------------------------------------------------------------------
/assets/result_subject.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_subject.png
--------------------------------------------------------------------------------
/assets/result_subject_inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/result_subject_inpainting.png
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/assets/teaser.jpg
--------------------------------------------------------------------------------
/infer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "!nvidia-smi"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import os\n",
19 | "import json\n",
20 | "import time\n",
21 | "import torch\n",
22 | "from PIL import Image\n",
23 | "from tqdm import tqdm\n",
24 | "\n",
25 | "from safetensors.torch import save_file\n",
26 | "from src.pipeline import FluxPipeline\n",
27 | "from src.transformer_flux import FluxTransformer2DModel\n",
28 | "from src.lora_helper import set_single_lora, set_multi_lora, unset_lora\n",
29 | "\n",
30 | "torch.cuda.set_device(1)\n",
31 | "\n",
32 | "class ImageProcessor:\n",
33 | " def __init__(self, path):\n",
34 | " device = \"cuda\"\n",
35 | " self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)\n",
36 | " transformer = FluxTransformer2DModel.from_pretrained(path, subfolder=\"transformer\",torch_dtype=torch.bfloat16, device=device)\n",
37 | " self.pipe.transformer = transformer\n",
38 | " self.pipe.to(device)\n",
39 | " \n",
40 | " def clear_cache(self, transformer):\n",
41 | " for name, attn_processor in transformer.attn_processors.items():\n",
42 | " attn_processor.bank_kv.clear()\n",
43 | " \n",
44 | " def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height = 768, width = 768, output_path=None, seed=42):\n",
45 | " if len(spatial_imgs)>0:\n",
46 | " spatial_ls = [Image.open(image_path).convert(\"RGB\") for image_path in spatial_imgs]\n",
47 | " else:\n",
48 | " spatial_ls = []\n",
49 | " if len(subject_imgs)>0:\n",
50 | " subject_ls = [Image.open(image_path).convert(\"RGB\") for image_path in subject_imgs]\n",
51 | " else:\n",
52 | " subject_ls = []\n",
53 | "\n",
54 | " prompt = prompt\n",
55 | " image = self.pipe(\n",
56 | " prompt,\n",
57 | " height=int(height),\n",
58 | " width=int(width),\n",
59 | " guidance_scale=3.5,\n",
60 | " num_inference_steps=25,\n",
61 | " max_sequence_length=512,\n",
62 | " generator=torch.Generator(\"cpu\").manual_seed(seed), \n",
63 | " subject_images=subject_ls,\n",
64 | " spatial_images=spatial_ls,\n",
65 | " cond_size=512,\n",
66 | " ).images[0]\n",
67 | " self.clear_cache(self.pipe.transformer)\n",
68 | " image.show()\n",
69 | " if output_path:\n",
70 | " image.save(output_path)"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "### models path ###\n",
80 | "# spatial model\n",
81 | "base_path = \"FLUX.1-dev\" # your flux model path\n",
82 | "lora_path = \"./models\" # your lora folder path\n",
83 | "canny_path = lora_path + \"/canny.safetensors\"\n",
84 | "depth_path = lora_path + \"/depth.safetensors\"\n",
85 | "openpose_path = lora_path + \"/pose.safetensors\"\n",
86 | "inpainting_path = lora_path + \"/inpainting.safetensors\"\n",
87 | "hedsketch_path = lora_path + \"/hedsketch.safetensors\"\n",
88 | "seg_path = lora_path + \"/seg.safetensors\"\n",
89 | "# subject model\n",
90 | "subject_path = lora_path + \"/subject.safetensors\"\n",
91 | "\n",
92 | "# init image processor\n",
93 | "processor = ImageProcessor(base_path)"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "for single condition"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "# set lora\n",
110 | "path = depth_path # single control model path\n",
111 | "lora_weights=[1] # lora weights for each control model\n",
112 | "set_single_lora(processor.pipe.transformer, path, lora_weights=lora_weights,cond_size=512)\n",
113 | "\n",
114 | "# infer\n",
115 | "prompt='a cafe bar'\n",
116 | "spatial_imgs=[\"./test_imgs/depth.png\"]\n",
117 | "height = 1024\n",
118 | "width = 1024\n",
119 | "processor.process_image(prompt=prompt, subject_imgs=[], spatial_imgs=spatial_imgs, height=height, width=width, seed=11)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "for multi condition"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "# set lora\n",
136 | "paths = [subject_path, inpainting_path] # multi control model paths\n",
137 | "lora_weights=[[1],[1]] # lora weights for each control model\n",
138 | "set_multi_lora(processor.pipe.transformer, paths, lora_weights=lora_weights, cond_size=512)\n",
139 | "\n",
140 | "# infer\n",
141 | "prompt='A SKS on the car'\n",
142 | "spatial_imgs=[\"./test_imgs/inpainting.png\"]\n",
143 | "subject_imgs=[\"./test_imgs/subject_1.png\"]\n",
144 | "height = 1024\n",
145 | "width = 1024\n",
146 | "processor.process_image(prompt=prompt, subject_imgs=subject_imgs, spatial_imgs=spatial_imgs, height=height, width=width, seed=42)"
147 | ]
148 | }
149 | ],
150 | "metadata": {
151 | "kernelspec": {
152 | "display_name": "zyxdit",
153 | "language": "python",
154 | "name": "python3"
155 | },
156 | "language_info": {
157 | "codemirror_mode": {
158 | "name": "ipython",
159 | "version": 3
160 | },
161 | "file_extension": ".py",
162 | "mimetype": "text/x-python",
163 | "name": "python",
164 | "nbconvert_exporter": "python",
165 | "pygments_lexer": "ipython3",
166 | "version": "3.10.16"
167 | }
168 | },
169 | "nbformat": 4,
170 | "nbformat_minor": 2
171 | }
172 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from src.pipeline import FluxPipeline
4 | from src.transformer_flux import FluxTransformer2DModel
5 | from src.lora_helper import set_single_lora, set_multi_lora
6 |
7 | from huggingface_hub import hf_hub_download
8 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./")
9 |
10 | def clear_cache(transformer):
11 | for name, attn_processor in transformer.attn_processors.items():
12 | attn_processor.bank_kv.clear()
13 |
14 | # Initialize model
15 | device = "cuda"
16 | base_path = "black-forest-labs/FLUX.1-dev" # Path to your base model
17 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
18 | transformer = FluxTransformer2DModel.from_pretrained(
19 | base_path,
20 | subfolder="transformer",
21 | torch_dtype=torch.bfloat16,
22 | device=device
23 | )
24 | pipe.transformer = transformer
25 | pipe.to(device)
26 |
27 | # Load control models
28 | lora_path = "./models"
29 | control_models = {
30 | "canny": f"{lora_path}/canny.safetensors",
31 | "depth": f"{lora_path}/depth.safetensors",
32 | "hedsketch": f"{lora_path}/hedsketch.safetensors",
33 | "pose": f"{lora_path}/pose.safetensors",
34 | "seg": f"{lora_path}/seg.safetensors",
35 | "inpainting": f"{lora_path}/inpainting.safetensors",
36 | "subject": f"{lora_path}/subject.safetensors",
37 | }
38 |
39 | # Single spatial condition control example
40 | path = control_models["canny"]
41 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512)
42 |
43 | # Generate image
44 | prompt = "A nice car on the beach"
45 |
46 | spatial_image = Image.open("./test_imgs/canny.png")
47 |
48 | image = pipe(
49 | prompt,
50 | height=768,
51 | width=1024,
52 | guidance_scale=3.5,
53 | num_inference_steps=25,
54 | max_sequence_length=512,
55 | generator=torch.Generator("cpu").manual_seed(5),
56 | spatial_images=[spatial_image],
57 | subject_images=[],
58 | cond_size=512,
59 | ).images[0]
60 |
61 | # Clear cache after generation
62 | clear_cache(pipe.transformer)
63 |
64 | image.save("output.png")
65 |
--------------------------------------------------------------------------------
/infer_multi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from src.pipeline import FluxPipeline
4 | from src.transformer_flux import FluxTransformer2DModel
5 | from src.lora_helper import set_single_lora, set_multi_lora
6 |
7 | from huggingface_hub import hf_hub_download
8 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/canny.safetensors", local_dir="./")
9 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/inpainting.safetensors", local_dir="./")
10 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/subject.safetensors", local_dir="./")
11 |
12 | def clear_cache(transformer):
13 | for name, attn_processor in transformer.attn_processors.items():
14 | attn_processor.bank_kv.clear()
15 |
16 | # Initialize model
17 | device = "cuda"
18 | base_path = "black-forest-labs/FLUX.1-dev" # Path to your base model
19 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
20 | transformer = FluxTransformer2DModel.from_pretrained(
21 | base_path,
22 | subfolder="transformer",
23 | torch_dtype=torch.bfloat16,
24 | device=device
25 | )
26 | pipe.transformer = transformer
27 | pipe.to(device)
28 |
29 | # Load control models
30 | lora_path = "./models"
31 | control_models = {
32 | "canny": f"{lora_path}/canny.safetensors",
33 | "depth": f"{lora_path}/depth.safetensors",
34 | "hedsketch": f"{lora_path}/hedsketch.safetensors",
35 | "pose": f"{lora_path}/pose.safetensors",
36 | "seg": f"{lora_path}/seg.safetensors",
37 | "inpainting": f"{lora_path}/inpainting.safetensors",
38 | "subject": f"{lora_path}/subject.safetensors",
39 | }
40 |
41 | # Single spatial condition control example
42 | path = control_models["canny"]
43 | set_single_lora(pipe.transformer, path, lora_weights=[1], cond_size=512)
44 | # Multi-condition control example
45 | paths = [control_models["subject"], control_models["inpainting"]]
46 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512)
47 |
48 | prompt = "A SKS on the car"
49 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")]
50 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")]
51 |
52 | image = pipe(
53 | prompt,
54 | height=1024,
55 | width=1024,
56 | guidance_scale=3.5,
57 | num_inference_steps=25,
58 | max_sequence_length=512,
59 | generator=torch.Generator("cpu").manual_seed(42),
60 | subject_images=subject_images,
61 | spatial_images=spatial_images,
62 | cond_size=512,
63 | ).images[0]
64 |
65 | image.save("output_multi.png")
66 |
67 | # Clear cache after generation
68 | clear_cache(pipe.transformer)
69 |
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu124
2 | torch
3 | torchvision
4 | torchaudio==2.3.1
5 | diffusers==0.32.2
6 | easydict==1.13
7 | einops==0.8.1
8 | peft==0.14.0
9 | pillow==11.0.0
10 | protobuf==5.29.3
11 | requests==2.32.3
12 | safetensors==0.5.2
13 | sentencepiece==0.2.0
14 | spaces==0.34.1
15 | transformers==4.49.0
16 | datasets
17 | wandb
18 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/src/__init__.py
--------------------------------------------------------------------------------
/src/layers_cache.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import math
3 | from typing import Callable, List, Optional, Tuple, Union
4 | from einops import rearrange
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from diffusers.models.attention_processor import Attention
10 |
11 | class LoRALinearLayer(nn.Module):
12 | def __init__(
13 | self,
14 | in_features: int,
15 | out_features: int,
16 | rank: int = 4,
17 | network_alpha: Optional[float] = None,
18 | device: Optional[Union[torch.device, str]] = None,
19 | dtype: Optional[torch.dtype] = None,
20 | cond_width=512,
21 | cond_height=512,
22 | number=0,
23 | n_loras=1
24 | ):
25 | super().__init__()
26 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
27 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
28 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
29 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
30 | self.network_alpha = network_alpha
31 | self.rank = rank
32 | self.out_features = out_features
33 | self.in_features = in_features
34 |
35 | nn.init.normal_(self.down.weight, std=1 / rank)
36 | nn.init.zeros_(self.up.weight)
37 |
38 | self.cond_height = cond_height
39 | self.cond_width = cond_width
40 | self.number = number
41 | self.n_loras = n_loras
42 |
43 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
44 | orig_dtype = hidden_states.dtype
45 | dtype = self.down.weight.dtype
46 |
47 | ####
48 | batch_size = hidden_states.shape[0]
49 |         cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64  # latent tokens per condition image (1024 for 512x512)
50 |         block_size = hidden_states.shape[1] - cond_size * self.n_loras  # tokens that precede the condition blocks
51 |         shape = (batch_size, hidden_states.shape[1], 3072)
52 |         mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
53 |         mask[:, :block_size+self.number*cond_size, :] = 0  # zero everything before this LoRA's condition block
54 |         mask[:, block_size+(self.number+1)*cond_size:, :] = 0  # zero everything after this LoRA's condition block
55 |         hidden_states = mask * hidden_states  # each LoRA only acts on its own condition tokens
56 | ####
57 |
58 | down_hidden_states = self.down(hidden_states.to(dtype))
59 | up_hidden_states = self.up(down_hidden_states)
60 |
61 | if self.network_alpha is not None:
62 | up_hidden_states *= self.network_alpha / self.rank
63 |
64 | return up_hidden_states.to(orig_dtype)
65 |
66 |
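# How the two processors below work: on the first denoising step (`bank_kv` is empty)
# the condition tokens' keys/values and their attention output are computed once and
# stored in `bank_kv` / `bank_attn`. On later steps only the text/image tokens are
# recomputed; the cached condition K/V and attention output are concatenated back in.
# `clear_cache()` in the inference scripts empties `bank_kv`, so the next generation
# re-encodes its condition images instead of reusing stale cached values.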
67 | class MultiSingleStreamBlockLoraProcessor(nn.Module):
68 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
69 | super().__init__()
70 | # Initialize a list to store the LoRA layers
71 | self.n_loras = n_loras
72 | self.cond_width = cond_width
73 | self.cond_height = cond_height
74 |
75 | self.q_loras = nn.ModuleList([
76 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
77 | for i in range(n_loras)
78 | ])
79 | self.k_loras = nn.ModuleList([
80 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
81 | for i in range(n_loras)
82 | ])
83 | self.v_loras = nn.ModuleList([
84 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
85 | for i in range(n_loras)
86 | ])
87 | self.lora_weights = lora_weights
88 | self.bank_attn = None
89 | self.bank_kv = []
90 |
91 |
92 | def __call__(self,
93 | attn: Attention,
94 | hidden_states: torch.FloatTensor,
95 | encoder_hidden_states: torch.FloatTensor = None,
96 | attention_mask: Optional[torch.FloatTensor] = None,
97 | image_rotary_emb: Optional[torch.Tensor] = None,
98 | use_cond = False
99 | ) -> torch.FloatTensor:
100 |
101 | batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
102 | scaled_seq_len = hidden_states.shape[1]
103 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
104 | block_size = scaled_seq_len - cond_size * self.n_loras
105 | scaled_cond_size = cond_size
106 | scaled_block_size = block_size
107 |
108 | if len(self.bank_kv)== 0:
109 | cache = True
110 | else:
111 | cache = False
112 |
113 | if cache:
114 | query = attn.to_q(hidden_states)
115 | key = attn.to_k(hidden_states)
116 | value = attn.to_v(hidden_states)
117 | for i in range(self.n_loras):
118 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
119 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
120 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
121 |
122 | inner_dim = key.shape[-1]
123 | head_dim = inner_dim // attn.heads
124 |
125 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
126 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
127 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
128 |
129 | self.bank_kv.append(key[:, :, scaled_block_size:, :])
130 | self.bank_kv.append(value[:, :, scaled_block_size:, :])
131 |
132 | if attn.norm_q is not None:
133 | query = attn.norm_q(query)
134 | if attn.norm_k is not None:
135 | key = attn.norm_k(key)
136 |
137 | if image_rotary_emb is not None:
138 | from diffusers.models.embeddings import apply_rotary_emb
139 | query = apply_rotary_emb(query, image_rotary_emb)
140 | key = apply_rotary_emb(key, image_rotary_emb)
141 |
142 | num_cond_blocks = self.n_loras
143 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
144 |         mask[ :scaled_block_size, :] = 0  # Text/image rows attend to every token
145 |         for i in range(num_cond_blocks):
146 |             start = i * scaled_cond_size + scaled_block_size
147 |             end = (i + 1) * scaled_cond_size + scaled_block_size
148 |             mask[start:end, start:end] = 0  # Each condition block attends only to itself
149 | mask = mask * -1e20
150 | mask = mask.to(query.dtype)
151 |
152 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
153 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
154 |
155 | else:
156 | query = attn.to_q(hidden_states)
157 | key = attn.to_k(hidden_states)
158 | value = attn.to_v(hidden_states)
159 |
160 | inner_dim = query.shape[-1]
161 | head_dim = inner_dim // attn.heads
162 |
163 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
164 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
165 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
166 |
167 | key = torch.concat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2)
168 | value = torch.concat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2)
169 |
170 | if attn.norm_q is not None:
171 | query = attn.norm_q(query)
172 | if attn.norm_k is not None:
173 | key = attn.norm_k(key)
174 |
175 | if image_rotary_emb is not None:
176 | from diffusers.models.embeddings import apply_rotary_emb
177 | query = apply_rotary_emb(query, image_rotary_emb)
178 | key = apply_rotary_emb(key, image_rotary_emb)
179 |
180 | query = query[:, :, :scaled_block_size, :]
181 |
182 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
183 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)
184 |
185 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
186 | hidden_states = hidden_states.to(query.dtype)
187 |
188 | cond_hidden_states = hidden_states[:, block_size:,:]
189 | hidden_states = hidden_states[:, : block_size,:]
190 |
191 | return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
192 |
193 |
194 | class MultiDoubleStreamBlockLoraProcessor(nn.Module):
195 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
196 | super().__init__()
197 |
198 | # Initialize a list to store the LoRA layers
199 | self.n_loras = n_loras
200 | self.cond_width = cond_width
201 | self.cond_height = cond_height
202 | self.q_loras = nn.ModuleList([
203 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
204 | for i in range(n_loras)
205 | ])
206 | self.k_loras = nn.ModuleList([
207 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
208 | for i in range(n_loras)
209 | ])
210 | self.v_loras = nn.ModuleList([
211 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
212 | for i in range(n_loras)
213 | ])
214 | self.proj_loras = nn.ModuleList([
215 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
216 | for i in range(n_loras)
217 | ])
218 | self.lora_weights = lora_weights
219 | self.bank_attn = None
220 | self.bank_kv = []
221 |
222 |
223 | def __call__(self,
224 | attn: Attention,
225 | hidden_states: torch.FloatTensor,
226 | encoder_hidden_states: torch.FloatTensor = None,
227 | attention_mask: Optional[torch.FloatTensor] = None,
228 | image_rotary_emb: Optional[torch.Tensor] = None,
229 | use_cond=False,
230 | ) -> torch.FloatTensor:
231 |
232 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
233 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
234 | block_size = hidden_states.shape[1] - cond_size * self.n_loras
235 | scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1]
236 | scaled_cond_size = cond_size
237 | scaled_block_size = scaled_seq_len - scaled_cond_size * self.n_loras
238 |
239 | # `context` projections.
240 | inner_dim = 3072
241 | head_dim = inner_dim // attn.heads
242 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
243 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
244 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
245 |
246 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
247 | batch_size, -1, attn.heads, head_dim
248 | ).transpose(1, 2)
249 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
250 | batch_size, -1, attn.heads, head_dim
251 | ).transpose(1, 2)
252 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
253 | batch_size, -1, attn.heads, head_dim
254 | ).transpose(1, 2)
255 |
256 | if attn.norm_added_q is not None:
257 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
258 | if attn.norm_added_k is not None:
259 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
260 |
261 | if len(self.bank_kv)== 0:
262 | cache = True
263 | else:
264 | cache = False
265 |
266 | if cache:
267 |
268 | query = attn.to_q(hidden_states)
269 | key = attn.to_k(hidden_states)
270 | value = attn.to_v(hidden_states)
271 | for i in range(self.n_loras):
272 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
273 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
274 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
275 |
276 | inner_dim = key.shape[-1]
277 | head_dim = inner_dim // attn.heads
278 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
279 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
280 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
281 |
282 |
283 | self.bank_kv.append(key[:, :, block_size:, :])
284 | self.bank_kv.append(value[:, :, block_size:, :])
285 |
286 | if attn.norm_q is not None:
287 | query = attn.norm_q(query)
288 | if attn.norm_k is not None:
289 | key = attn.norm_k(key)
290 |
291 | # attention
292 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
293 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
294 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
295 |
296 | if image_rotary_emb is not None:
297 | from diffusers.models.embeddings import apply_rotary_emb
298 | query = apply_rotary_emb(query, image_rotary_emb)
299 | key = apply_rotary_emb(key, image_rotary_emb)
300 |
301 | num_cond_blocks = self.n_loras
302 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
303 |         mask[ :scaled_block_size, :] = 0  # Text/image rows attend to every token
304 |         for i in range(num_cond_blocks):
305 |             start = i * scaled_cond_size + scaled_block_size
306 |             end = (i + 1) * scaled_cond_size + scaled_block_size
307 |             mask[start:end, start:end] = 0  # Each condition block attends only to itself
308 | mask = mask * -1e20
309 | mask = mask.to(query.dtype)
310 |
311 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
312 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :]
313 |
314 | else:
315 | query = attn.to_q(hidden_states)
316 | key = attn.to_k(hidden_states)
317 | value = attn.to_v(hidden_states)
318 |
319 | inner_dim = query.shape[-1]
320 | head_dim = inner_dim // attn.heads
321 |
322 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
323 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
324 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
325 |
326 | key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2)
327 | value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)
328 |
329 | if attn.norm_q is not None:
330 | query = attn.norm_q(query)
331 | if attn.norm_k is not None:
332 | key = attn.norm_k(key)
333 |
334 | # attention
335 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
336 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
337 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
338 |
339 | if image_rotary_emb is not None:
340 | from diffusers.models.embeddings import apply_rotary_emb
341 | query = apply_rotary_emb(query, image_rotary_emb)
342 | key = apply_rotary_emb(key, image_rotary_emb)
343 |
344 | query = query[:, :, :scaled_block_size, :]
345 |
346 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
347 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)
348 |
349 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
350 | hidden_states = hidden_states.to(query.dtype)
351 |
352 | encoder_hidden_states, hidden_states = (
353 | hidden_states[:, : encoder_hidden_states.shape[1]],
354 | hidden_states[:, encoder_hidden_states.shape[1] :],
355 | )
356 |
357 | # Linear projection (with LoRA weight applied to each proj layer)
358 | hidden_states = attn.to_out[0](hidden_states)
359 | for i in range(self.n_loras):
360 | hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
361 | # dropout
362 | hidden_states = attn.to_out[1](hidden_states)
363 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
364 |
365 | cond_hidden_states = hidden_states[:, block_size:,:]
366 | hidden_states = hidden_states[:, :block_size,:]
367 |
368 | return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
--------------------------------------------------------------------------------
/src/lora_helper.py:
--------------------------------------------------------------------------------
1 | from diffusers.models.attention_processor import FluxAttnProcessor2_0
2 | from safetensors import safe_open
3 | import re
4 | import torch
5 | from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
6 |
7 | device = "cuda"
8 |
9 | def load_safetensors(path):
10 | tensors = {}
11 | with safe_open(path, framework="pt", device="cpu") as f:
12 | for key in f.keys():
13 | tensors[key] = f.get_tensor(key)
14 | return tensors
15 |
16 | def get_lora_rank(checkpoint):
17 | for k in checkpoint.keys():
18 | if k.endswith(".down.weight"):
19 | return checkpoint[k].shape[0]
20 |
21 | def load_checkpoint(local_path):
22 | if local_path is not None:
23 | if '.safetensors' in local_path:
24 | print(f"Loading .safetensors checkpoint from {local_path}")
25 | checkpoint = load_safetensors(local_path)
26 | else:
27 | print(f"Loading checkpoint from {local_path}")
28 | checkpoint = torch.load(local_path, map_location='cpu')
29 | return checkpoint
30 |
31 | def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
32 | number = len(lora_weights)
33 | ranks = [get_lora_rank(checkpoint) for _ in range(number)]
34 | lora_attn_procs = {}
35 | double_blocks_idx = list(range(19))
36 | single_blocks_idx = list(range(38))
37 | for name, attn_processor in transformer.attn_processors.items():
38 | match = re.search(r'\.(\d+)\.', name)
39 | if match:
40 | layer_index = int(match.group(1))
41 |
42 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
43 |
44 | lora_state_dicts = {}
45 | for key, value in checkpoint.items():
46 | # Match based on the layer index in the key (assuming the key contains layer index)
47 | if re.search(r'\.(\d+)\.', key):
48 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
49 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
50 | lora_state_dicts[key] = value
51 |
52 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
53 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
54 | )
55 |
56 | # Load the weights from the checkpoint dictionary into the corresponding layers
57 | for n in range(number):
58 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
59 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
60 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
61 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
62 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
63 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
64 | lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
65 | lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
66 | lora_attn_procs[name].to(device)
67 |
68 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
69 |
70 | lora_state_dicts = {}
71 | for key, value in checkpoint.items():
72 | # Match based on the layer index in the key (assuming the key contains layer index)
73 | if re.search(r'\.(\d+)\.', key):
74 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
75 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
76 | lora_state_dicts[key] = value
77 |
78 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
79 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
80 | )
81 | # Load the weights from the checkpoint dictionary into the corresponding layers
82 | for n in range(number):
83 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
84 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
85 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
86 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
87 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
88 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
89 | lora_attn_procs[name].to(device)
90 | else:
91 | lora_attn_procs[name] = FluxAttnProcessor2_0()
92 |
93 | transformer.set_attn_processor(lora_attn_procs)
94 |
95 |
96 | def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
97 | ck_number = len(checkpoints)
98 | cond_lora_number = [len(ls) for ls in lora_weights]
99 | cond_number = sum(cond_lora_number)
100 | ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
101 | multi_lora_weight = []
102 | for ls in lora_weights:
103 | for n in ls:
104 | multi_lora_weight.append(n)
105 |
106 | lora_attn_procs = {}
107 | double_blocks_idx = list(range(19))
108 | single_blocks_idx = list(range(38))
109 | for name, attn_processor in transformer.attn_processors.items():
110 | match = re.search(r'\.(\d+)\.', name)
111 | if match:
112 | layer_index = int(match.group(1))
113 |
114 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
115 | lora_state_dicts = [{} for _ in range(ck_number)]
116 | for idx, checkpoint in enumerate(checkpoints):
117 | for key, value in checkpoint.items():
118 | # Match based on the layer index in the key (assuming the key contains layer index)
119 | if re.search(r'\.(\d+)\.', key):
120 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
121 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
122 | lora_state_dicts[idx][key] = value
123 |
124 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
125 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
126 | )
127 |
128 | # Load the weights from the checkpoint dictionary into the corresponding layers
129 | num = 0
130 | for idx in range(ck_number):
131 | for n in range(cond_lora_number[idx]):
132 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
133 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
134 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
135 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
136 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
137 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
138 | lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
139 | lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
140 | lora_attn_procs[name].to(device)
141 | num += 1
142 |
143 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
144 |
145 | lora_state_dicts = [{} for _ in range(ck_number)]
146 | for idx, checkpoint in enumerate(checkpoints):
147 | for key, value in checkpoint.items():
148 | # Match based on the layer index in the key (assuming the key contains layer index)
149 | if re.search(r'\.(\d+)\.', key):
150 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
151 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
152 | lora_state_dicts[idx][key] = value
153 |
154 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
155 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
156 | )
157 | # Load the weights from the checkpoint dictionary into the corresponding layers
158 | num = 0
159 | for idx in range(ck_number):
160 | for n in range(cond_lora_number[idx]):
161 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
162 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
163 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
164 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
165 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
166 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
167 | lora_attn_procs[name].to(device)
168 | num += 1
169 |
170 | else:
171 | lora_attn_procs[name] = FluxAttnProcessor2_0()
172 |
173 | transformer.set_attn_processor(lora_attn_procs)
174 |
175 |
176 | def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
177 | checkpoint = load_checkpoint(local_path)
178 | update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
179 |
180 | def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
181 | checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
182 | update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
183 |
184 | def unset_lora(transformer):
185 | lora_attn_procs = {}
186 | for name, attn_processor in transformer.attn_processors.items():
187 | lora_attn_procs[name] = FluxAttnProcessor2_0()
188 | transformer.set_attn_processor(lora_attn_procs)
189 |
190 |
191 | '''
192 | unset_lora(pipe.transformer)
193 | lora_path = "./lora.safetensors"
194 | lora_weights = [1, 1]
195 | set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
196 | '''
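'''
Multi-LoRA usage (a sketch; the paths below are placeholders):
lora_paths = ["./subject.safetensors", "./canny.safetensors"]
lora_weights = [[1], [1]]  # one weight list per checkpoint
set_multi_lora(pipe.transformer, lora_paths, lora_weights=lora_weights, cond_size=512)
'''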
--------------------------------------------------------------------------------
/src/transformer_flux.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import (
12 | Attention,
13 | AttentionProcessor,
14 | FluxAttnProcessor2_0,
15 | FluxAttnProcessor2_0_NPU,
16 | FusedFluxAttnProcessor2_0,
17 | )
18 | from diffusers.models.modeling_utils import ModelMixin
19 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
21 | from diffusers.utils.import_utils import is_torch_npu_available
22 | from diffusers.utils.torch_utils import maybe_allow_in_graph
23 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 | @maybe_allow_in_graph
29 | class FluxSingleTransformerBlock(nn.Module):
30 |
31 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
32 | super().__init__()
33 | self.mlp_hidden_dim = int(dim * mlp_ratio)
34 |
35 | self.norm = AdaLayerNormZeroSingle(dim)
36 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
37 | self.act_mlp = nn.GELU(approximate="tanh")
38 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
39 |
40 | if is_torch_npu_available():
41 | processor = FluxAttnProcessor2_0_NPU()
42 | else:
43 | processor = FluxAttnProcessor2_0()
44 | self.attn = Attention(
45 | query_dim=dim,
46 | cross_attention_dim=None,
47 | dim_head=attention_head_dim,
48 | heads=num_attention_heads,
49 | out_dim=dim,
50 | bias=True,
51 | processor=processor,
52 | qk_norm="rms_norm",
53 | eps=1e-6,
54 | pre_only=True,
55 | )
56 |
57 | def forward(
58 | self,
59 | hidden_states: torch.Tensor,
60 | cond_hidden_states: torch.Tensor,
61 | temb: torch.Tensor,
62 | cond_temb: torch.Tensor,
63 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
64 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65 | ) -> torch.Tensor:
66 | use_cond = cond_hidden_states is not None
67 |
68 | residual = hidden_states
69 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
70 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
71 |
72 | if use_cond:
73 | residual_cond = cond_hidden_states
74 | norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
75 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
76 |
77 | norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
78 |
79 | joint_attention_kwargs = joint_attention_kwargs or {}
80 | attn_output = self.attn(
81 | hidden_states=norm_hidden_states_concat,
82 | image_rotary_emb=image_rotary_emb,
83 | use_cond=use_cond,
84 | **joint_attention_kwargs,
85 | )
86 | if use_cond:
87 | attn_output, cond_attn_output = attn_output
88 |
89 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
90 | gate = gate.unsqueeze(1)
91 | hidden_states = gate * self.proj_out(hidden_states)
92 | hidden_states = residual + hidden_states
93 |
94 | if use_cond:
95 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
96 | cond_gate = cond_gate.unsqueeze(1)
97 | condition_latents = cond_gate * self.proj_out(condition_latents)
98 | condition_latents = residual_cond + condition_latents
99 |
100 | if hidden_states.dtype == torch.float16:
101 | hidden_states = hidden_states.clip(-65504, 65504)
102 |
103 |         return hidden_states, (condition_latents if use_cond else None)
104 |
105 |
106 | @maybe_allow_in_graph
107 | class FluxTransformerBlock(nn.Module):
108 | def __init__(
109 | self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
110 | ):
111 | super().__init__()
112 |
113 | self.norm1 = AdaLayerNormZero(dim)
114 |
115 | self.norm1_context = AdaLayerNormZero(dim)
116 |
117 | if hasattr(F, "scaled_dot_product_attention"):
118 | processor = FluxAttnProcessor2_0()
119 | else:
120 | raise ValueError(
121 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
122 | )
123 | self.attn = Attention(
124 | query_dim=dim,
125 | cross_attention_dim=None,
126 | added_kv_proj_dim=dim,
127 | dim_head=attention_head_dim,
128 | heads=num_attention_heads,
129 | out_dim=dim,
130 | context_pre_only=False,
131 | bias=True,
132 | processor=processor,
133 | qk_norm=qk_norm,
134 | eps=eps,
135 | )
136 |
137 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139 |
140 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
141 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
142 |
143 | # let chunk size default to None
144 | self._chunk_size = None
145 | self._chunk_dim = 0
146 |
147 | def forward(
148 | self,
149 | hidden_states: torch.Tensor,
150 | cond_hidden_states: torch.Tensor,
151 | encoder_hidden_states: torch.Tensor,
152 | temb: torch.Tensor,
153 | cond_temb: torch.Tensor,
154 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
156 | ) -> Tuple[torch.Tensor, torch.Tensor]:
157 | use_cond = cond_hidden_states is not None
158 |
159 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
160 | if use_cond:
161 | (
162 | norm_cond_hidden_states,
163 | cond_gate_msa,
164 | cond_shift_mlp,
165 | cond_scale_mlp,
166 | cond_gate_mlp,
167 | ) = self.norm1(cond_hidden_states, emb=cond_temb)
168 |
169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
170 | encoder_hidden_states, emb=temb
171 | )
172 |
173 | norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2)
174 |
175 | joint_attention_kwargs = joint_attention_kwargs or {}
176 | # Attention.
177 | attention_outputs = self.attn(
178 | hidden_states=norm_hidden_states,
179 | encoder_hidden_states=norm_encoder_hidden_states,
180 | image_rotary_emb=image_rotary_emb,
181 | use_cond=use_cond,
182 | **joint_attention_kwargs,
183 | )
184 |
185 | attn_output, context_attn_output = attention_outputs[:2]
186 | cond_attn_output = attention_outputs[2] if use_cond else None
187 |
188 | # Process attention outputs for the `hidden_states`.
189 | attn_output = gate_msa.unsqueeze(1) * attn_output
190 | hidden_states = hidden_states + attn_output
191 |
192 | if use_cond:
193 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
194 | cond_hidden_states = cond_hidden_states + cond_attn_output
195 |
196 | norm_hidden_states = self.norm2(hidden_states)
197 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
198 |
199 | if use_cond:
200 | norm_cond_hidden_states = self.norm2(cond_hidden_states)
201 | norm_cond_hidden_states = (
202 | norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
203 | + cond_shift_mlp[:, None]
204 | )
205 |
206 | ff_output = self.ff(norm_hidden_states)
207 | ff_output = gate_mlp.unsqueeze(1) * ff_output
208 | hidden_states = hidden_states + ff_output
209 |
210 | if use_cond:
211 | cond_ff_output = self.ff(norm_cond_hidden_states)
212 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
213 | cond_hidden_states = cond_hidden_states + cond_ff_output
214 |
215 | # Process attention outputs for the `encoder_hidden_states`.
216 |
217 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
218 | encoder_hidden_states = encoder_hidden_states + context_attn_output
219 |
220 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
221 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
222 |
223 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
224 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
225 | if encoder_hidden_states.dtype == torch.float16:
226 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
227 |
228 |         return encoder_hidden_states, hidden_states, (cond_hidden_states if use_cond else None)
229 |
230 |
231 | class FluxTransformer2DModel(
232 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
233 | ):
234 | _supports_gradient_checkpointing = True
235 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
236 |
237 | @register_to_config
238 | def __init__(
239 | self,
240 | patch_size: int = 1,
241 | in_channels: int = 64,
242 | out_channels: Optional[int] = None,
243 | num_layers: int = 19,
244 | num_single_layers: int = 38,
245 | attention_head_dim: int = 128,
246 | num_attention_heads: int = 24,
247 | joint_attention_dim: int = 4096,
248 | pooled_projection_dim: int = 768,
249 | guidance_embeds: bool = False,
250 | axes_dims_rope: Tuple[int] = (16, 56, 56),
251 | ):
252 | super().__init__()
253 | self.out_channels = out_channels or in_channels
254 | self.inner_dim = num_attention_heads * attention_head_dim
255 |
256 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
257 |
258 | text_time_guidance_cls = (
259 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
260 | )
261 | self.time_text_embed = text_time_guidance_cls(
262 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
263 | )
264 |
265 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
266 | self.x_embedder = nn.Linear(in_channels, self.inner_dim)
267 |
268 | self.transformer_blocks = nn.ModuleList(
269 | [
270 | FluxTransformerBlock(
271 | dim=self.inner_dim,
272 | num_attention_heads=num_attention_heads,
273 | attention_head_dim=attention_head_dim,
274 | )
275 | for _ in range(num_layers)
276 | ]
277 | )
278 |
279 | self.single_transformer_blocks = nn.ModuleList(
280 | [
281 | FluxSingleTransformerBlock(
282 | dim=self.inner_dim,
283 | num_attention_heads=num_attention_heads,
284 | attention_head_dim=attention_head_dim,
285 | )
286 | for _ in range(num_single_layers)
287 | ]
288 | )
289 |
290 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
291 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
292 |
293 | self.gradient_checkpointing = False
294 |
295 | @property
296 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
297 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
298 | r"""
299 | Returns:
300 | `dict` of attention processors: A dictionary containing all attention processors used in the model with
301 | indexed by its weight name.
302 | """
303 | # set recursively
304 | processors = {}
305 |
306 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
307 | if hasattr(module, "get_processor"):
308 | processors[f"{name}.processor"] = module.get_processor()
309 |
310 | for sub_name, child in module.named_children():
311 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
312 |
313 | return processors
314 |
315 | for name, module in self.named_children():
316 | fn_recursive_add_processors(name, module, processors)
317 |
318 | return processors
319 |
320 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
321 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
322 | r"""
323 | Sets the attention processor to use to compute attention.
324 |
325 | Parameters:
326 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
327 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
328 | for **all** `Attention` layers.
329 |
330 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
331 | processor. This is strongly recommended when setting trainable attention processors.
332 |
333 | """
334 | count = len(self.attn_processors.keys())
335 |
336 | if isinstance(processor, dict) and len(processor) != count:
337 | raise ValueError(
338 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
339 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
340 | )
341 |
342 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
343 | if hasattr(module, "set_processor"):
344 | if not isinstance(processor, dict):
345 | module.set_processor(processor)
346 | else:
347 | module.set_processor(processor.pop(f"{name}.processor"))
348 |
349 | for sub_name, child in module.named_children():
350 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
351 |
352 | for name, module in self.named_children():
353 | fn_recursive_attn_processor(name, module, processor)
354 |
355 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
356 | def fuse_qkv_projections(self):
357 | """
358 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
359 | are fused. For cross-attention modules, key and value projection matrices are fused.
360 |
361 |
362 |
363 | This API is 🧪 experimental.
364 |
365 |
366 | """
367 | self.original_attn_processors = None
368 |
369 | for _, attn_processor in self.attn_processors.items():
370 | if "Added" in str(attn_processor.__class__.__name__):
371 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
372 |
373 | self.original_attn_processors = self.attn_processors
374 |
375 | for module in self.modules():
376 | if isinstance(module, Attention):
377 | module.fuse_projections(fuse=True)
378 |
379 | self.set_attn_processor(FusedFluxAttnProcessor2_0())
380 |
381 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
382 | def unfuse_qkv_projections(self):
383 | """Disables the fused QKV projection if enabled.
384 |
385 |
386 |
387 | This API is 🧪 experimental.
388 |
389 |
390 |
391 | """
392 | if self.original_attn_processors is not None:
393 | self.set_attn_processor(self.original_attn_processors)
394 |
395 | def _set_gradient_checkpointing(self, module, value=False):
396 | if hasattr(module, "gradient_checkpointing"):
397 | module.gradient_checkpointing = value
398 |
399 | def forward(
400 | self,
401 | hidden_states: torch.Tensor,
402 | cond_hidden_states: torch.Tensor = None,
403 | encoder_hidden_states: torch.Tensor = None,
404 | pooled_projections: torch.Tensor = None,
405 | timestep: torch.LongTensor = None,
406 | img_ids: torch.Tensor = None,
407 | txt_ids: torch.Tensor = None,
408 | guidance: torch.Tensor = None,
409 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
410 | controlnet_block_samples=None,
411 | controlnet_single_block_samples=None,
412 | return_dict: bool = True,
413 | controlnet_blocks_repeat: bool = False,
414 | ) -> Union[torch.Tensor, Transformer2DModelOutput]:
415 | if cond_hidden_states is not None:
416 | use_condition = True
417 | else:
418 | use_condition = False
419 |
420 | if joint_attention_kwargs is not None:
421 | joint_attention_kwargs = joint_attention_kwargs.copy()
422 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
423 | else:
424 | lora_scale = 1.0
425 |
426 | if USE_PEFT_BACKEND:
427 | # weight the lora layers by setting `lora_scale` for each PEFT layer
428 | scale_lora_layers(self, lora_scale)
429 | else:
430 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
431 | logger.warning(
432 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
433 | )
434 |
435 | hidden_states = self.x_embedder(hidden_states)
436 |         cond_hidden_states = self.x_embedder(cond_hidden_states) if use_condition else None
437 |
438 | timestep = timestep.to(hidden_states.dtype) * 1000
439 | if guidance is not None:
440 | guidance = guidance.to(hidden_states.dtype) * 1000
441 | else:
442 | guidance = None
443 |
444 | temb = (
445 | self.time_text_embed(timestep, pooled_projections)
446 | if guidance is None
447 | else self.time_text_embed(timestep, guidance, pooled_projections)
448 | )
449 |
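        # The condition branch is embedded with timestep 0 (with the same guidance and
        # pooled text projections), so condition tokens are treated as clean, un-noised inputs.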
450 | cond_temb = (
451 | self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
452 | if guidance is None
453 | else self.time_text_embed(
454 | torch.ones_like(timestep) * 0, guidance, pooled_projections
455 | )
456 | )
457 |
458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459 |
460 | if txt_ids.ndim == 3:
461 | logger.warning(
462 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
463 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
464 | )
465 | txt_ids = txt_ids[0]
466 | if img_ids.ndim == 3:
467 | logger.warning(
468 | "Passing `img_ids` 3d torch.Tensor is deprecated."
469 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
470 | )
471 | img_ids = img_ids[0]
472 |
473 | ids = torch.cat((txt_ids, img_ids), dim=0)
474 | image_rotary_emb = self.pos_embed(ids)
475 |
476 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
477 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
478 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
479 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
480 |
481 | for index_block, block in enumerate(self.transformer_blocks):
482 | if torch.is_grad_enabled() and self.gradient_checkpointing:
483 |
484 | def create_custom_forward(module, return_dict=None):
485 | def custom_forward(*inputs):
486 | if return_dict is not None:
487 | return module(*inputs, return_dict=return_dict)
488 | else:
489 | return module(*inputs)
490 |
491 | return custom_forward
492 |
493 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
494 |                 encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
495 |                     create_custom_forward(block),
496 |                     hidden_states,
497 |                     cond_hidden_states if use_condition else None,
498 |                     encoder_hidden_states,
499 |                     temb,
500 |                     cond_temb if use_condition else None,
501 |                     image_rotary_emb,
502 |                     **ckpt_kwargs,
503 |                 )
504 |
505 | else:
506 | encoder_hidden_states, hidden_states, cond_hidden_states = block(
507 | hidden_states=hidden_states,
508 | encoder_hidden_states=encoder_hidden_states,
509 | cond_hidden_states=cond_hidden_states if use_condition else None,
510 | temb=temb,
511 | cond_temb=cond_temb if use_condition else None,
512 | image_rotary_emb=image_rotary_emb,
513 | joint_attention_kwargs=joint_attention_kwargs,
514 | )
515 |
516 | # controlnet residual
517 | if controlnet_block_samples is not None:
518 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
519 | interval_control = int(np.ceil(interval_control))
520 | # For Xlabs ControlNet.
521 | if controlnet_blocks_repeat:
522 | hidden_states = (
523 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
524 | )
525 | else:
526 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
527 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
528 |
529 | for index_block, block in enumerate(self.single_transformer_blocks):
530 | if torch.is_grad_enabled() and self.gradient_checkpointing:
531 |
532 | def create_custom_forward(module, return_dict=None):
533 | def custom_forward(*inputs):
534 | if return_dict is not None:
535 | return module(*inputs, return_dict=return_dict)
536 | else:
537 | return module(*inputs)
538 |
539 | return custom_forward
540 |
541 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
542 |                 hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
543 |                     create_custom_forward(block),
544 |                     hidden_states,
545 |                     cond_hidden_states if use_condition else None,
546 |                     temb,
547 |                     cond_temb if use_condition else None,
548 |                     image_rotary_emb,
549 |                     **ckpt_kwargs,
550 |                 )
551 |
552 | else:
553 | hidden_states, cond_hidden_states = block(
554 | hidden_states=hidden_states,
555 | cond_hidden_states=cond_hidden_states if use_condition else None,
556 | temb=temb,
557 | cond_temb=cond_temb if use_condition else None,
558 | image_rotary_emb=image_rotary_emb,
559 | joint_attention_kwargs=joint_attention_kwargs,
560 | )
561 |
562 | # controlnet residual
563 | if controlnet_single_block_samples is not None:
564 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
565 | interval_control = int(np.ceil(interval_control))
566 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
567 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
568 | + controlnet_single_block_samples[index_block // interval_control]
569 | )
570 |
571 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
572 |
573 | hidden_states = self.norm_out(hidden_states, temb)
574 | output = self.proj_out(hidden_states)
575 |
576 | if USE_PEFT_BACKEND:
577 | # remove `lora_scale` from each PEFT layer
578 | unscale_lora_layers(self, lora_scale)
579 |
580 | if not return_dict:
581 | return (output,)
582 |
583 | return Transformer2DModelOutput(sample=output)
--------------------------------------------------------------------------------
/test_imgs/canny.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/canny.png
--------------------------------------------------------------------------------
/test_imgs/depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/depth.png
--------------------------------------------------------------------------------
/test_imgs/ghibli.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/ghibli.png
--------------------------------------------------------------------------------
/test_imgs/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/inpainting.png
--------------------------------------------------------------------------------
/test_imgs/openpose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/openpose.png
--------------------------------------------------------------------------------
/test_imgs/seg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/seg.png
--------------------------------------------------------------------------------
/test_imgs/subject_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/subject_0.png
--------------------------------------------------------------------------------
/test_imgs/subject_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/test_imgs/subject_1.png
--------------------------------------------------------------------------------
/train/default_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | main_process_port: 14121
5 | downcast_bf16: 'no'
6 | gpu_ids: all
7 | machine_rank: 0
8 | main_training_function: main
9 | mixed_precision: fp16
10 | num_machines: 1
11 | num_processes: 4
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/train/examples/openpose_data/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/openpose_data/1.png
--------------------------------------------------------------------------------
/train/examples/openpose_data/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/openpose_data/2.png
--------------------------------------------------------------------------------
/train/examples/pose.jsonl:
--------------------------------------------------------------------------------
1 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"}
2 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"}
3 | {"source": "./examples/openpose_data/2.png", "caption": "A girl wearing a green coat.", "target": "./examples/openpose_data/1.png"}
--------------------------------------------------------------------------------
/train/examples/style.jsonl:
--------------------------------------------------------------------------------
1 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"}
2 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"}
3 | {"source": "./examples/style_data/5.png", "caption": "Ghibli Studio style, A digital illustration of an elderly couple standing on a grassy field, holding oranges.", "target": "./examples/style_data/6.png"}
--------------------------------------------------------------------------------
/train/examples/style_data/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/style_data/5.png
--------------------------------------------------------------------------------
/train/examples/style_data/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/style_data/6.png
--------------------------------------------------------------------------------
/train/examples/subject.jsonl:
--------------------------------------------------------------------------------
1 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"}
2 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"}
3 | {"source": "./examples/subject_data/3.png", "caption": "A SKS float on the water.", "target": "./examples/subject_data/4.png"}
--------------------------------------------------------------------------------
/train/examples/subject_data/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/subject_data/3.png
--------------------------------------------------------------------------------
/train/examples/subject_data/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/examples/subject_data/4.png
--------------------------------------------------------------------------------
/train/readme.md:
--------------------------------------------------------------------------------
1 | # Model Training Guide
2 |
3 | This document provides a step-by-step guide for training the model in this project.
4 |
5 | ## Environment Setup
6 |
7 | 1. Ensure the following dependencies are installed:
8 | - Python 3.10.16
9 | - PyTorch 2.5.1+cu121
10 | - Required libraries (install via `requirements.txt`)
11 |
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
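Optionally, verify the installed versions before launching training (a minimal check, not part of this repository's scripts):

```python
import torch, torchvision, diffusers, transformers

# Confirm the CUDA build of PyTorch is installed and a GPU is visible.
print("torch:", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("diffusers:", diffusers.__version__, "| transformers:", transformers.__version__)
```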
16 | ## Data Preparation
17 |
18 | - Ensure your data matches the format of the example training datasets (e.g., `examples/pose.jsonl`, `examples/subject.jsonl`, `examples/style.jsonl`); the expected record format and a quick check are sketched below.
19 |
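Each line in these files is a single JSON object with `source` (condition image), `caption`, and `target` (ground-truth image) keys. A quick sanity check for a custom dataset might look like the following (a sketch; point `jsonl_path` at your own file):

```python
import json
from pathlib import Path

jsonl_path = Path("./examples/pose.jsonl")  # replace with your own dataset file

for line_number, line in enumerate(jsonl_path.read_text().splitlines(), start=1):
    if not line.strip():
        continue
    record = json.loads(line)
    # Every record needs a condition image, a caption, and a target image.
    for key in ("source", "caption", "target"):
        assert key in record, f"line {line_number}: missing key '{key}'"
    # Image paths are resolved relative to the train/ directory.
    for key in ("source", "target"):
        assert Path(record[key]).exists(), f"line {line_number}: missing file {record[key]}"
print("dataset looks well-formed")
```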
20 | ## Start Training
21 |
22 | 1. Use the following commands to start training:
23 |
24 | - For spatial control:
25 | ```bash
26 | bash ./train_spatial.sh
27 | ```
28 | - For subject control:
29 | ```bash
30 | bash ./train_subject.sh
31 | ```
32 | - For style control:
33 | ```bash
34 | bash ./train_style.sh
35 | ```
36 |
37 | 2. Example training configuration:
38 |
39 | ```bash
40 | --pretrained_model_name_or_path $MODEL_DIR \ # Path to the FLUX model
41 | --cond_size=512 \ # Source image size (recommended: 384-512 or higher for better detail control)
42 | --noise_size=1024 \ # Target image's longest side size (recommended: 1024 for better resolution)
43 | --subject_column="None" \ # JSONL key for subject; set to "None" if using spatial condition
44 | --spatial_column="source" \ # JSONL key for spatial; set to "None" if using subject condition
45 | --target_column="target" \ # JSONL key for the target image
46 | --caption_column="caption" \ # JSONL key for the caption
47 | --ranks 128 \ # LoRA rank (recommended: 128)
48 | --network_alphas 128 \ # LoRA network alphas (recommended: 128)
49 | --output_dir=$OUTPUT_DIR \ # Directory for model and validation outputs
50 | --logging_dir=$LOG_PATH \ # Directory for logs
51 | --mixed_precision="bf16" \ # Recommended: bf16
52 | --train_data_dir=$TRAIN_DATA \ # Path to the training data JSONL file
53 | --learning_rate=1e-4 \ # Recommended: 1e-4
54 | --train_batch_size=1 \ # Only supports 1 due to multi-resolution target images
55 | --validation_prompt "Ghibli Studio style, Charming hand-drawn anime-style illustration" \ # Validation prompt
56 | --num_train_epochs=1000 \ # Total number of epochs
57 | --validation_steps=20 \ # Validate every n steps
58 | --checkpointing_steps=20 \ # Save model every n steps
59 | --spatial_test_images "./examples/style_data/5.png" \ # Validation images for spatial condition
60 | --subject_test_images None \ # Validation images for subject condition
61 | --test_h 1024 \ # Height of validation images
62 | --test_w 1024 \ # Width of validation images
63 | --num_validation_images=2 # Number of validation images
64 | ```
65 |
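For reference, here is a minimal launch sketch built from the flags above. The variable values are hypothetical and the `accelerate launch` wrapper is assumed; check the provided `train_subject.sh` / `train_spatial.sh` scripts for the exact invocation used in this repository.

```bash
# Hypothetical subject-control run; values are placeholders, flags mirror the list above.
export MODEL_DIR="black-forest-labs/FLUX.1-dev"
export OUTPUT_DIR="./output/subject_lora"
export LOG_PATH="$OUTPUT_DIR/logs"
export TRAIN_DATA="./examples/subject.jsonl"

accelerate launch train.py \
  --pretrained_model_name_or_path $MODEL_DIR \
  --cond_size=512 \
  --noise_size=1024 \
  --subject_column="source" \
  --spatial_column="None" \
  --target_column="target" \
  --caption_column="caption" \
  --ranks 128 \
  --network_alphas 128 \
  --output_dir=$OUTPUT_DIR \
  --logging_dir=$LOG_PATH \
  --mixed_precision="bf16" \
  --train_data_dir=$TRAIN_DATA \
  --learning_rate=1e-4 \
  --train_batch_size=1 \
  --num_train_epochs=1000 \
  --validation_steps=20 \
  --checkpointing_steps=20 \
  --validation_prompt "A SKS float on the water" \
  --subject_test_images "./examples/subject_data/3.png" \
  --spatial_test_images None \
  --test_h 1024 \
  --test_w 1024 \
  --num_validation_images=2
```
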
66 | ## Model Inference
67 |
68 | 1. After training, use the following script for inference:
69 |
70 | ```bash
71 | # Return to the repository root so that inference uses the KV-cache-enabled modules in the top-level src/
72 | cd ..
73 | ```
74 |
75 | ```python
76 | import torch
77 | from PIL import Image
78 | from src.pipeline import FluxPipeline
79 | from src.transformer_flux import FluxTransformer2DModel
80 | from src.lora_helper import set_single_lora, set_multi_lora
81 |
82 | def clear_cache(transformer):
83 | for name, attn_processor in transformer.attn_processors.items():
84 | attn_processor.bank_kv.clear()
85 |
86 | # Initialize the model
87 | device = "cuda"
88 | base_path = "black-forest-labs/FLUX.1-dev" # Path to the base model
89 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
90 | transformer = FluxTransformer2DModel.from_pretrained(
91 | base_path,
92 | subfolder="transformer",
93 | torch_dtype=torch.bfloat16,
94 | device=device
95 | )
96 | pipe.transformer = transformer
97 | pipe.to(device)
98 |
99 | # Path to your trained EasyControl model
100 | lora_path = " "
101 |
102 | # Single condition control example
103 | set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
104 |
105 | # Set your control image path
106 | spatial_image_path = ""
107 | subject_image_path = ""
108 | style_image_path = ""
109 |
110 | control_image = Image.open(spatial_image_path)  # or subject_image_path / style_image_path, depending on the condition
111 | prompt = "your prompt here"
112 |
113 | # For spatial or style control
114 | image = pipe(
115 | prompt,
116 | height=768, # Generated image height
117 | width=1024, # Generated image width
118 | guidance_scale=3.5,
119 | num_inference_steps=25, # Number of steps
120 | max_sequence_length=512,
121 | generator=torch.Generator("cpu").manual_seed(5),
122 | spatial_images=[control_image],
123 | subject_images=[],
124 | cond_size=512, # Training setting for cond_size
125 | ).images[0]
126 | # Clear cache after generation
127 | clear_cache(pipe.transformer)
128 | image.save("output.png")
129 | ```
130 |
131 | 2. For subject control:
132 |
133 | ```python
134 | image = pipe(
135 | prompt,
136 | height=768,
137 | width=1024,
138 | guidance_scale=3.5,
139 | num_inference_steps=25,
140 | max_sequence_length=512,
141 | generator=torch.Generator("cpu").manual_seed(5),
142 | spatial_images=[],
143 | subject_images=[control_image],
144 | cond_size=512,
145 | ).images[0]
146 | # Clear cache after generation
147 | clear_cache(pipe.transformer)
148 | image.save("output.png")
149 | ```
150 |
151 | 3. For multi-condition control:
152 |
153 | ```python
154 | import torch
155 | from PIL import Image
156 | from src.pipeline import FluxPipeline
157 | from src.transformer_flux import FluxTransformer2DModel
158 | from src.lora_helper import set_single_lora, set_multi_lora
159 |
160 | def clear_cache(transformer):
161 | for name, attn_processor in transformer.attn_processors.items():
162 | attn_processor.bank_kv.clear()
163 |
164 | # Initialize the model
165 | device = "cuda"
166 | base_path = "black-forest-labs/FLUX.1-dev" # Path to the base model
167 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16, device=device)
168 | transformer = FluxTransformer2DModel.from_pretrained(
169 | base_path,
170 | subfolder="transformer",
171 | torch_dtype=torch.bfloat16,
172 | device=device
173 | )
174 | pipe.transformer = transformer
175 | pipe.to(device)
176 |
177 | # Change to your EasyControl Model path!!!
178 | lora_path = "./models"
179 | control_models = {
180 | "canny": f"{lora_path}/canny.safetensors",
181 | "depth": f"{lora_path}/depth.safetensors",
182 | "hedsketch": f"{lora_path}/hedsketch.safetensors",
183 | "pose": f"{lora_path}/pose.safetensors",
184 | "seg": f"{lora_path}/seg.safetensors",
185 | "inpainting": f"{lora_path}/inpainting.safetensors",
186 | "subject": f"{lora_path}/subject.safetensors",
187 | }
188 | paths = [control_models["subject"], control_models["inpainting"]]
189 | set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512)
190 |
191 | # Subject + spatial control
192 | prompt = "A SKS on the car"
193 | subject_images = [Image.open("./test_imgs/subject_1.png").convert("RGB")]
194 | spatial_images = [Image.open("./test_imgs/inpainting.png").convert("RGB")]
195 | image = pipe(
196 | prompt,
197 | height=1024,
198 | width=1024,
199 | guidance_scale=3.5,
200 | num_inference_steps=25,
201 | max_sequence_length=512,
202 | generator=torch.Generator("cpu").manual_seed(42),
203 | subject_images=subject_images,
204 | spatial_images=spatial_images,
205 | cond_size=512,
206 | ).images[0]
207 | # Clear cache after generation
208 | clear_cache(pipe.transformer)
209 | image.save("output_multi.png")
210 | ```
211 |
212 | 4. For spatial + spatial control:
213 |
214 | ```python
215 | prompt = "A car"
216 | subject_images = []
217 | spatial_images = [Image.open("image1_path").convert("RGB"), Image.open("image2_path").convert("RGB")]
218 | image = pipe(
219 | prompt,
220 | height=1024,
221 | width=1024,
222 | guidance_scale=3.5,
223 | num_inference_steps=25,
224 | max_sequence_length=512,
225 | generator=torch.Generator("cpu").manual_seed(42),
226 | subject_images=subject_images,
227 | spatial_images=spatial_images,
228 | cond_size=512,
229 | ).images[0]
230 | # Clear cache after generation
231 | clear_cache(pipe.transformer)
232 | image.save("output_multi.png")
233 | ```
234 |
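To swap in a different control LoRA within the same session, the attention processors can be reset before loading the next model. A brief sketch, assuming the top-level `src/lora_helper.py` exposes the same `unset_lora` helper as `train/src/lora_helper.py`:

```python
from src.lora_helper import unset_lora, set_single_lora

# Detach the current LoRA processors, then load a different control model
unset_lora(pipe.transformer)
set_single_lora(pipe.transformer, control_models["canny"], lora_weights=[1], cond_size=512)
```
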
235 | ## Notes
236 |
237 | - Adjust `noise_size` and `cond_size` based on your VRAM.
238 | - Batch size is limited to 1.
239 |
--------------------------------------------------------------------------------
/train/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaojiu-z/EasyControl/1f1d7df8d52c461c052e0183005908fbb13298ef/train/src/__init__.py
--------------------------------------------------------------------------------
/train/src/jsonl_datasets.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from datasets import load_dataset
3 | from torchvision import transforms
4 | import random
5 | import torch
6 |
7 | Image.MAX_IMAGE_PIXELS = None
8 |
9 | def multiple_16(num: float):
10 | return int(round(num / 16) * 16)
11 |
12 | def get_random_resolution(min_size=512, max_size=1280, multiple=16):
13 | resolution = random.randint(min_size // multiple, max_size // multiple) * multiple
14 | return resolution
15 |
16 | def load_image_safely(image_path, size):
17 | try:
18 | image = Image.open(image_path).convert("RGB")
19 | return image
20 | except Exception as e:
21 |         print(f"file error: {image_path} ({e})")
22 | with open("failed_images.txt", "a") as f:
23 | f.write(f"{image_path}\n")
24 | return Image.new("RGB", (size, size), (255, 255, 255))
25 |
26 | def make_train_dataset(args, tokenizer, accelerator=None):
27 | if args.train_data_dir is not None:
28 | print("load_data")
29 | dataset = load_dataset('json', data_files=args.train_data_dir)
30 |
31 | column_names = dataset["train"].column_names
32 |
33 |     # Get the column names for input/target.
34 | caption_column = args.caption_column
35 | target_column = args.target_column
36 | if args.subject_column is not None:
37 | subject_columns = args.subject_column.split(",")
38 | if args.spatial_column is not None:
39 | spatial_columns= args.spatial_column.split(",")
40 |
41 | size = args.cond_size
42 |     noise_size = get_random_resolution(max_size=args.noise_size)  # randomly sampled target resolution (multiple of 16) up to args.noise_size
43 | subject_cond_train_transforms = transforms.Compose(
44 | [
45 | transforms.Lambda(lambda img: img.resize((
46 | multiple_16(size * img.size[0] / max(img.size)),
47 | multiple_16(size * img.size[1] / max(img.size))
48 | ), resample=Image.BILINEAR)),
49 | transforms.RandomHorizontalFlip(p=0.7),
50 | transforms.RandomRotation(degrees=20),
51 | transforms.Lambda(lambda img: transforms.Pad(
52 | padding=(
53 | int((size - img.size[0]) / 2),
54 | int((size - img.size[1]) / 2),
55 | int((size - img.size[0]) / 2),
56 | int((size - img.size[1]) / 2)
57 | ),
58 | fill=0
59 | )(img)),
60 | transforms.ToTensor(),
61 | transforms.Normalize([0.5], [0.5]),
62 | ]
63 | )
64 | cond_train_transforms = transforms.Compose(
65 | [
66 | transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
67 | transforms.CenterCrop((size, size)),
68 | transforms.ToTensor(),
69 | transforms.Normalize([0.5], [0.5]),
70 | ]
71 | )
72 |
73 | def train_transforms(image, noise_size):
74 | train_transforms_ = transforms.Compose(
75 | [
76 | transforms.Lambda(lambda img: img.resize((
77 | multiple_16(noise_size * img.size[0] / max(img.size)),
78 | multiple_16(noise_size * img.size[1] / max(img.size))
79 | ), resample=Image.BILINEAR)),
80 | transforms.ToTensor(),
81 | transforms.Normalize([0.5], [0.5]),
82 | ]
83 | )
84 | transformed_image = train_transforms_(image)
85 | return transformed_image
86 |
87 | def load_and_transform_cond_images(images):
88 | transformed_images = [cond_train_transforms(image) for image in images]
89 | concatenated_image = torch.cat(transformed_images, dim=1)
90 | return concatenated_image
91 |
92 | def load_and_transform_subject_images(images):
93 | transformed_images = [subject_cond_train_transforms(image) for image in images]
94 | concatenated_image = torch.cat(transformed_images, dim=1)
95 | return concatenated_image
96 |
97 | tokenizer_clip = tokenizer[0]
98 | tokenizer_t5 = tokenizer[1]
99 |
100 | def tokenize_prompt_clip_t5(examples):
101 | captions = []
102 | for caption in examples[caption_column]:
103 | if isinstance(caption, str):
104 | if random.random() < 0.1:
105 |                     captions.append(" ")  # set the caption to empty (random caption dropout)
106 | else:
107 | captions.append(caption)
108 | elif isinstance(caption, list):
109 | # take a random caption if there are multiple
110 | if random.random() < 0.1:
111 | captions.append(" ")
112 | else:
113 | captions.append(random.choice(caption))
114 | else:
115 | raise ValueError(
116 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
117 | )
118 | text_inputs = tokenizer_clip(
119 | captions,
120 | padding="max_length",
121 | max_length=77,
122 | truncation=True,
123 | return_length=False,
124 | return_overflowing_tokens=False,
125 | return_tensors="pt",
126 | )
127 | text_input_ids_1 = text_inputs.input_ids
128 |
129 | text_inputs = tokenizer_t5(
130 | captions,
131 | padding="max_length",
132 | max_length=512,
133 | truncation=True,
134 | return_length=False,
135 | return_overflowing_tokens=False,
136 | return_tensors="pt",
137 | )
138 | text_input_ids_2 = text_inputs.input_ids
139 | return text_input_ids_1, text_input_ids_2
140 |
141 | def preprocess_train(examples):
142 | _examples = {}
143 | if args.subject_column is not None:
144 | subject_images = [[load_image_safely(examples[column][i], args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
145 | _examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
146 | if args.spatial_column is not None:
147 | spatial_images = [[load_image_safely(examples[column][i], args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
148 | _examples["cond_pixel_values"] = [load_and_transform_cond_images(spatial) for spatial in spatial_images]
149 | target_images = [load_image_safely(image_path, args.cond_size) for image_path in examples[target_column]]
150 | _examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
151 | _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples)
152 | return _examples
153 |
154 | if accelerator is not None:
155 | with accelerator.main_process_first():
156 | train_dataset = dataset["train"].with_transform(preprocess_train)
157 | else:
158 | train_dataset = dataset["train"].with_transform(preprocess_train)
159 |
160 | return train_dataset
161 |
162 |
163 | def collate_fn(examples):
164 | if examples[0].get("cond_pixel_values") is not None:
165 | cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
166 | cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
167 | else:
168 | cond_pixel_values = None
169 | if examples[0].get("subject_pixel_values") is not None:
170 | subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
171 | subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
172 | else:
173 | subject_pixel_values = None
174 |
175 | target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
176 | target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
177 | token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
178 | token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
179 |
180 | return {
181 | "cond_pixel_values": cond_pixel_values,
182 | "subject_pixel_values": subject_pixel_values,
183 | "pixel_values": target_pixel_values,
184 | "text_ids_1": token_ids_clip,
185 | "text_ids_2": token_ids_t5,
186 | }
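
# Hypothetical usage sketch (kept as a string so nothing runs on import): how the dataset and
# collate_fn above could be wired into a DataLoader by the training script. `args` and
# `accelerator` are assumed to exist; tokenizer subfolders follow FLUX.1-dev's layout.
'''
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer, T5TokenizerFast

tokenizer_clip = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer_t5 = T5TokenizerFast.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
train_dataset = make_train_dataset(args, [tokenizer_clip, tokenizer_t5], accelerator)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
'''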
--------------------------------------------------------------------------------
/train/src/layers.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import math
3 | from typing import Callable, List, Optional, Tuple, Union
4 | from einops import rearrange
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from torch import Tensor
9 | from diffusers.models.attention_processor import Attention
10 |
11 | class LoRALinearLayer(nn.Module):
12 | def __init__(
13 | self,
14 | in_features: int,
15 | out_features: int,
16 | rank: int = 4,
17 | network_alpha: Optional[float] = None,
18 | device: Optional[Union[torch.device, str]] = None,
19 | dtype: Optional[torch.dtype] = None,
20 | cond_width=512,
21 | cond_height=512,
22 | number=0,
23 | n_loras=1
24 | ):
25 | super().__init__()
26 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
27 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
28 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
29 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
30 | self.network_alpha = network_alpha
31 | self.rank = rank
32 | self.out_features = out_features
33 | self.in_features = in_features
34 |
35 | nn.init.normal_(self.down.weight, std=1 / rank)
36 | nn.init.zeros_(self.up.weight)
37 |
38 | self.cond_height = cond_height
39 | self.cond_width = cond_width
40 | self.number = number
41 | self.n_loras = n_loras
42 |
43 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
44 | orig_dtype = hidden_states.dtype
45 | dtype = self.down.weight.dtype
46 |
47 | #### img condition
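        # Each LoRA branch only operates on its own condition tokens: the packed sequence is
        # [ main tokens | cond_0 | ... | cond_{n_loras-1} ], and the mask below zeroes the input
        # everywhere except the slice belonging to this branch (selected by `self.number`).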
48 | batch_size = hidden_states.shape[0]
49 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
50 | block_size = hidden_states.shape[1] - cond_size * self.n_loras
51 | shape = (batch_size, hidden_states.shape[1], 3072)
52 | mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
53 | mask[:, :block_size+self.number*cond_size, :] = 0
54 | mask[:, block_size+(self.number+1)*cond_size:, :] = 0
55 | hidden_states = mask * hidden_states
56 | ####
57 |
58 | down_hidden_states = self.down(hidden_states.to(dtype))
59 | up_hidden_states = self.up(down_hidden_states)
60 |
61 | if self.network_alpha is not None:
62 | up_hidden_states *= self.network_alpha / self.rank
63 |
64 | return up_hidden_states.to(orig_dtype)
65 |
66 |
67 | class MultiSingleStreamBlockLoraProcessor(nn.Module):
68 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
69 | super().__init__()
70 | # Initialize a list to store the LoRA layers
71 | self.n_loras = n_loras
72 | self.cond_width = cond_width
73 | self.cond_height = cond_height
74 |
75 | self.q_loras = nn.ModuleList([
76 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
77 | for i in range(n_loras)
78 | ])
79 | self.k_loras = nn.ModuleList([
80 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
81 | for i in range(n_loras)
82 | ])
83 | self.v_loras = nn.ModuleList([
84 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
85 | for i in range(n_loras)
86 | ])
87 | self.lora_weights = lora_weights
88 |
89 |
90 | def __call__(self,
91 | attn: Attention,
92 | hidden_states: torch.FloatTensor,
93 | encoder_hidden_states: torch.FloatTensor = None,
94 | attention_mask: Optional[torch.FloatTensor] = None,
95 | image_rotary_emb: Optional[torch.Tensor] = None,
96 | use_cond = False,
97 | ) -> torch.FloatTensor:
98 |
99 | batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
100 | query = attn.to_q(hidden_states)
101 | key = attn.to_k(hidden_states)
102 | value = attn.to_v(hidden_states)
103 |
104 | for i in range(self.n_loras):
105 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
106 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
107 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
108 |
109 | inner_dim = key.shape[-1]
110 | head_dim = inner_dim // attn.heads
111 |
112 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
113 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
114 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
115 |
116 | if attn.norm_q is not None:
117 | query = attn.norm_q(query)
118 | if attn.norm_k is not None:
119 | key = attn.norm_k(key)
120 |
121 | if image_rotary_emb is not None:
122 | from diffusers.models.embeddings import apply_rotary_emb
123 | query = apply_rotary_emb(query, image_rotary_emb)
124 | key = apply_rotary_emb(key, image_rotary_emb)
125 |
126 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
127 | block_size = hidden_states.shape[1] - cond_size * self.n_loras
128 | scaled_cond_size = cond_size
129 | scaled_block_size = block_size
130 | scaled_seq_len = query.shape[2]
131 |
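        # Additive attention mask over the packed sequence [ main tokens | cond_0 | ... ]:
        # main-token rows stay unmasked (they can attend to every token, conditions included),
        # while each condition block is only allowed to attend to itself.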
132 | num_cond_blocks = self.n_loras
133 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
134 |         mask[:scaled_block_size, :] = 0  # first block_size rows (main tokens) attend to all tokens
135 | for i in range(num_cond_blocks):
136 | start = i * scaled_cond_size + scaled_block_size
137 | end = (i + 1) * scaled_cond_size + scaled_block_size
138 | mask[start:end, start:end] = 0 # Diagonal blocks
139 | mask = mask * -1e20
140 | mask = mask.to(query.dtype)
141 |
142 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
143 |
144 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
145 | hidden_states = hidden_states.to(query.dtype)
146 |
147 | cond_hidden_states = hidden_states[:, block_size:,:]
148 | hidden_states = hidden_states[:, : block_size,:]
149 |
150 | return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
151 |
152 |
153 | class MultiDoubleStreamBlockLoraProcessor(nn.Module):
154 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
155 | super().__init__()
156 |
157 | # Initialize a list to store the LoRA layers
158 | self.n_loras = n_loras
159 | self.cond_width = cond_width
160 | self.cond_height = cond_height
161 | self.q_loras = nn.ModuleList([
162 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
163 | for i in range(n_loras)
164 | ])
165 | self.k_loras = nn.ModuleList([
166 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
167 | for i in range(n_loras)
168 | ])
169 | self.v_loras = nn.ModuleList([
170 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
171 | for i in range(n_loras)
172 | ])
173 | self.proj_loras = nn.ModuleList([
174 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
175 | for i in range(n_loras)
176 | ])
177 | self.lora_weights = lora_weights
178 |
179 |
180 | def __call__(self,
181 | attn: Attention,
182 | hidden_states: torch.FloatTensor,
183 | encoder_hidden_states: torch.FloatTensor = None,
184 | attention_mask: Optional[torch.FloatTensor] = None,
185 | image_rotary_emb: Optional[torch.Tensor] = None,
186 | use_cond=False,
187 | ) -> torch.FloatTensor:
188 |
189 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
190 |
191 | # `context` projections.
192 | inner_dim = 3072
193 | head_dim = inner_dim // attn.heads
194 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
195 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
196 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
197 |
198 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
199 | batch_size, -1, attn.heads, head_dim
200 | ).transpose(1, 2)
201 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
202 | batch_size, -1, attn.heads, head_dim
203 | ).transpose(1, 2)
204 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
205 | batch_size, -1, attn.heads, head_dim
206 | ).transpose(1, 2)
207 |
208 | if attn.norm_added_q is not None:
209 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
210 | if attn.norm_added_k is not None:
211 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
212 |
213 | query = attn.to_q(hidden_states)
214 | key = attn.to_k(hidden_states)
215 | value = attn.to_v(hidden_states)
216 | for i in range(self.n_loras):
217 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
218 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
219 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
220 |
221 | inner_dim = key.shape[-1]
222 | head_dim = inner_dim // attn.heads
223 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
224 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
225 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
226 |
227 | if attn.norm_q is not None:
228 | query = attn.norm_q(query)
229 | if attn.norm_k is not None:
230 | key = attn.norm_k(key)
231 |
232 | # attention
233 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
234 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
235 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
236 |
237 | if image_rotary_emb is not None:
238 | from diffusers.models.embeddings import apply_rotary_emb
239 | query = apply_rotary_emb(query, image_rotary_emb)
240 | key = apply_rotary_emb(key, image_rotary_emb)
241 |
242 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
243 | block_size = hidden_states.shape[1] - cond_size * self.n_loras
244 | scaled_cond_size = cond_size
245 | scaled_seq_len = query.shape[2]
246 | scaled_block_size = scaled_seq_len - cond_size * self.n_loras
247 |
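        # Same block-diagonal masking as in the single-stream processor, except that the
        # non-condition rows here cover both the text (encoder) and image tokens.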
248 | num_cond_blocks = self.n_loras
249 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
250 |         mask[:scaled_block_size, :] = 0  # first block_size rows (text + image tokens) attend to all tokens
251 | for i in range(num_cond_blocks):
252 | start = i * scaled_cond_size + scaled_block_size
253 | end = (i + 1) * scaled_cond_size + scaled_block_size
254 | mask[start:end, start:end] = 0 # Diagonal blocks
255 | mask = mask * -1e20
256 | mask = mask.to(query.dtype)
257 |
258 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
259 |
260 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
261 | hidden_states = hidden_states.to(query.dtype)
262 |
263 | encoder_hidden_states, hidden_states = (
264 | hidden_states[:, : encoder_hidden_states.shape[1]],
265 | hidden_states[:, encoder_hidden_states.shape[1] :],
266 | )
267 |
268 | # Linear projection (with LoRA weight applied to each proj layer)
269 | hidden_states = attn.to_out[0](hidden_states)
270 | for i in range(self.n_loras):
271 | hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
272 | # dropout
273 | hidden_states = attn.to_out[1](hidden_states)
274 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
275 |
276 | cond_hidden_states = hidden_states[:, block_size:,:]
277 | hidden_states = hidden_states[:, :block_size,:]
278 |
279 | return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
--------------------------------------------------------------------------------
/train/src/lora_helper.py:
--------------------------------------------------------------------------------
1 | from diffusers.models.attention_processor import FluxAttnProcessor2_0
2 | from safetensors import safe_open
3 | import re
4 | import torch
5 | from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
6 |
7 | device = "cuda"
8 |
9 | def load_safetensors(path):
10 | tensors = {}
11 | with safe_open(path, framework="pt", device="cpu") as f:
12 | for key in f.keys():
13 | tensors[key] = f.get_tensor(key)
14 | return tensors
15 |
16 | def get_lora_rank(checkpoint):
17 | for k in checkpoint.keys():
18 | if k.endswith(".down.weight"):
19 | return checkpoint[k].shape[0]
20 |
21 | def load_checkpoint(local_path):
22 | if local_path is not None:
23 | if '.safetensors' in local_path:
24 | print(f"Loading .safetensors checkpoint from {local_path}")
25 | checkpoint = load_safetensors(local_path)
26 | else:
27 | print(f"Loading checkpoint from {local_path}")
28 | checkpoint = torch.load(local_path, map_location='cpu')
29 | return checkpoint
30 |
31 | def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
32 | number = len(lora_weights)
33 | ranks = [get_lora_rank(checkpoint) for _ in range(number)]
34 | lora_attn_procs = {}
35 | double_blocks_idx = list(range(19))
36 | single_blocks_idx = list(range(38))
37 | for name, attn_processor in transformer.attn_processors.items():
38 | match = re.search(r'\.(\d+)\.', name)
39 | if match:
40 | layer_index = int(match.group(1))
41 |
42 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
43 |
44 | lora_state_dicts = {}
45 | for key, value in checkpoint.items():
46 | # Match based on the layer index in the key (assuming the key contains layer index)
47 | if re.search(r'\.(\d+)\.', key):
48 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
49 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
50 | lora_state_dicts[key] = value
51 |
52 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
53 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
54 | )
55 |
56 | # Load the weights from the checkpoint dictionary into the corresponding layers
57 | for n in range(number):
58 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
59 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
60 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
61 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
62 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
63 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
64 | lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
65 | lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
66 | lora_attn_procs[name].to(device)
67 |
68 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
69 |
70 | lora_state_dicts = {}
71 | for key, value in checkpoint.items():
72 | # Match based on the layer index in the key (assuming the key contains layer index)
73 | if re.search(r'\.(\d+)\.', key):
74 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
75 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
76 | lora_state_dicts[key] = value
77 |
78 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
79 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
80 | )
81 | # Load the weights from the checkpoint dictionary into the corresponding layers
82 | for n in range(number):
83 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
84 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
85 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
86 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
87 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
88 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
89 | lora_attn_procs[name].to(device)
90 | else:
91 | lora_attn_procs[name] = FluxAttnProcessor2_0()
92 |
93 | transformer.set_attn_processor(lora_attn_procs)
94 |
95 |
96 | def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
97 | ck_number = len(checkpoints)
98 | cond_lora_number = [len(ls) for ls in lora_weights]
99 | cond_number = sum(cond_lora_number)
100 | ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
101 | multi_lora_weight = []
102 | for ls in lora_weights:
103 | for n in ls:
104 | multi_lora_weight.append(n)
105 |
106 | lora_attn_procs = {}
107 | double_blocks_idx = list(range(19))
108 | single_blocks_idx = list(range(38))
109 | for name, attn_processor in transformer.attn_processors.items():
110 | match = re.search(r'\.(\d+)\.', name)
111 | if match:
112 | layer_index = int(match.group(1))
113 |
114 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
115 | lora_state_dicts = [{} for _ in range(ck_number)]
116 | for idx, checkpoint in enumerate(checkpoints):
117 | for key, value in checkpoint.items():
118 | # Match based on the layer index in the key (assuming the key contains layer index)
119 | if re.search(r'\.(\d+)\.', key):
120 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
121 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
122 | lora_state_dicts[idx][key] = value
123 |
124 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
125 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
126 | )
127 |
128 | # Load the weights from the checkpoint dictionary into the corresponding layers
129 | num = 0
130 | for idx in range(ck_number):
131 | for n in range(cond_lora_number[idx]):
132 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
133 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
134 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
135 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
136 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
137 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
138 | lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
139 | lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
140 | lora_attn_procs[name].to(device)
141 | num += 1
142 |
143 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
144 |
145 | lora_state_dicts = [{} for _ in range(ck_number)]
146 | for idx, checkpoint in enumerate(checkpoints):
147 | for key, value in checkpoint.items():
148 | # Match based on the layer index in the key (assuming the key contains layer index)
149 | if re.search(r'\.(\d+)\.', key):
150 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
151 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
152 | lora_state_dicts[idx][key] = value
153 |
154 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
155 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
156 | )
157 | # Load the weights from the checkpoint dictionary into the corresponding layers
158 | num = 0
159 | for idx in range(ck_number):
160 | for n in range(cond_lora_number[idx]):
161 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
162 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
163 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
164 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
165 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
166 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
167 | lora_attn_procs[name].to(device)
168 | num += 1
169 |
170 | else:
171 | lora_attn_procs[name] = FluxAttnProcessor2_0()
172 |
173 | transformer.set_attn_processor(lora_attn_procs)
174 |
175 |
176 | def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
177 | checkpoint = load_checkpoint(local_path)
178 | update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
179 |
180 | def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
181 | checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
182 | update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
183 |
184 | def unset_lora(transformer):
185 | lora_attn_procs = {}
186 | for name, attn_processor in transformer.attn_processors.items():
187 | lora_attn_procs[name] = FluxAttnProcessor2_0()
188 | transformer.set_attn_processor(lora_attn_procs)
189 |
190 |
191 | '''
192 | unset_lora(pipe.transformer)
193 | lora_path = "./lora.safetensors"
194 | lora_weights = [1]
195 | set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
196 | '''
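
# A second usage sketch for combining two trained LoRAs (illustrative paths; mirrors the
# multi-condition example in train/readme.md):
'''
paths = ["./models/subject.safetensors", "./models/inpainting.safetensors"]
set_multi_lora(pipe.transformer, paths, lora_weights=[[1], [1]], cond_size=512)
'''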
--------------------------------------------------------------------------------
/train/src/prompt_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def load_text_encoders(args, class_one, class_two):
5 | text_encoder_one = class_one.from_pretrained(
6 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
7 | )
8 | text_encoder_two = class_two.from_pretrained(
9 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
10 | )
11 | return text_encoder_one, text_encoder_two
12 |
13 |
14 | def tokenize_prompt(tokenizer, prompt, max_sequence_length):
15 | text_inputs = tokenizer(
16 | prompt,
17 | padding="max_length",
18 | max_length=max_sequence_length,
19 | truncation=True,
20 | return_length=False,
21 | return_overflowing_tokens=False,
22 | return_tensors="pt",
23 | )
24 | text_input_ids = text_inputs.input_ids
25 | return text_input_ids
26 |
27 |
28 | def tokenize_prompt_clip(tokenizer, prompt):
29 | text_inputs = tokenizer(
30 | prompt,
31 | padding="max_length",
32 | max_length=77,
33 | truncation=True,
34 | return_length=False,
35 | return_overflowing_tokens=False,
36 | return_tensors="pt",
37 | )
38 | text_input_ids = text_inputs.input_ids
39 | return text_input_ids
40 |
41 |
42 | def tokenize_prompt_t5(tokenizer, prompt):
43 | text_inputs = tokenizer(
44 | prompt,
45 | padding="max_length",
46 | max_length=512,
47 | truncation=True,
48 | return_length=False,
49 | return_overflowing_tokens=False,
50 | return_tensors="pt",
51 | )
52 | text_input_ids = text_inputs.input_ids
53 | return text_input_ids
54 |
55 |
56 | def _encode_prompt_with_t5(
57 | text_encoder,
58 | tokenizer,
59 | max_sequence_length=512,
60 | prompt=None,
61 | num_images_per_prompt=1,
62 | device=None,
63 | text_input_ids=None,
64 | ):
65 | prompt = [prompt] if isinstance(prompt, str) else prompt
66 | batch_size = len(prompt)
67 |
68 | if tokenizer is not None:
69 | text_inputs = tokenizer(
70 | prompt,
71 | padding="max_length",
72 | max_length=max_sequence_length,
73 | truncation=True,
74 | return_length=False,
75 | return_overflowing_tokens=False,
76 | return_tensors="pt",
77 | )
78 | text_input_ids = text_inputs.input_ids
79 | else:
80 | if text_input_ids is None:
81 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
82 |
83 | prompt_embeds = text_encoder(text_input_ids.to(device))[0]
84 |
85 | dtype = text_encoder.dtype
86 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87 |
88 | _, seq_len, _ = prompt_embeds.shape
89 |
90 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
91 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
92 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
93 |
94 | return prompt_embeds
95 |
96 |
97 | def _encode_prompt_with_clip(
98 | text_encoder,
99 | tokenizer,
100 | prompt: str,
101 | device=None,
102 | text_input_ids=None,
103 | num_images_per_prompt: int = 1,
104 | ):
105 | prompt = [prompt] if isinstance(prompt, str) else prompt
106 | batch_size = len(prompt)
107 |
108 | if tokenizer is not None:
109 | text_inputs = tokenizer(
110 | prompt,
111 | padding="max_length",
112 | max_length=77,
113 | truncation=True,
114 | return_overflowing_tokens=False,
115 | return_length=False,
116 | return_tensors="pt",
117 | )
118 |
119 | text_input_ids = text_inputs.input_ids
120 | else:
121 | if text_input_ids is None:
122 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
123 |
124 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
125 |
126 | # Use pooled output of CLIPTextModel
127 | prompt_embeds = prompt_embeds.pooler_output
128 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
129 |
130 | # duplicate text embeddings for each generation per prompt, using mps friendly method
131 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
132 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
133 |
134 | return prompt_embeds
135 |
136 |
137 | def encode_prompt(
138 | text_encoders,
139 | tokenizers,
140 | prompt: str,
141 | max_sequence_length,
142 | device=None,
143 | num_images_per_prompt: int = 1,
144 | text_input_ids_list=None,
145 | ):
146 | prompt = [prompt] if isinstance(prompt, str) else prompt
147 | dtype = text_encoders[0].dtype
148 |
149 | pooled_prompt_embeds = _encode_prompt_with_clip(
150 | text_encoder=text_encoders[0],
151 | tokenizer=tokenizers[0],
152 | prompt=prompt,
153 | device=device if device is not None else text_encoders[0].device,
154 | num_images_per_prompt=num_images_per_prompt,
155 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
156 | )
157 |
158 | prompt_embeds = _encode_prompt_with_t5(
159 | text_encoder=text_encoders[1],
160 | tokenizer=tokenizers[1],
161 | max_sequence_length=max_sequence_length,
162 | prompt=prompt,
163 | num_images_per_prompt=num_images_per_prompt,
164 | device=device if device is not None else text_encoders[1].device,
165 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
166 | )
167 |
168 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
169 |
170 | return prompt_embeds, pooled_prompt_embeds, text_ids
171 |
172 |
173 | def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
174 | text_encoder_clip = text_encoders[0]
175 | text_encoder_t5 = text_encoders[1]
176 | tokens_clip, tokens_t5 = tokens[0], tokens[1]
177 | batch_size = tokens_clip.shape[0]
178 |
179 |     # use the accelerator device unless CPU was explicitly requested
180 |     if device != "cpu":
181 |         device = accelerator.device
182 |
183 |
184 | # clip
185 | prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
186 | # Use pooled output of CLIPTextModel
187 | prompt_embeds = prompt_embeds.pooler_output
188 | prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
189 | # duplicate text embeddings for each generation per prompt, using mps friendly method
190 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191 | pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
192 | pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
193 |
194 | # t5
195 | prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
196 | dtype = text_encoder_t5.dtype
197 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
198 | _, seq_len, _ = prompt_embeds.shape
199 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
200 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
201 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
202 |
203 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
204 |
205 | return prompt_embeds, pooled_prompt_embeds, text_ids
--------------------------------------------------------------------------------
/train/src/transformer_flux.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import (
12 | Attention,
13 | AttentionProcessor,
14 | FluxAttnProcessor2_0,
15 | FluxAttnProcessor2_0_NPU,
16 | FusedFluxAttnProcessor2_0,
17 | )
18 | from diffusers.models.modeling_utils import ModelMixin
19 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
20 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
21 | from diffusers.utils.import_utils import is_torch_npu_available
22 | from diffusers.utils.torch_utils import maybe_allow_in_graph
23 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
24 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 | @maybe_allow_in_graph
29 | class FluxSingleTransformerBlock(nn.Module):
30 |
31 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
32 | super().__init__()
33 | self.mlp_hidden_dim = int(dim * mlp_ratio)
34 |
35 | self.norm = AdaLayerNormZeroSingle(dim)
36 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
37 | self.act_mlp = nn.GELU(approximate="tanh")
38 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
39 |
40 | if is_torch_npu_available():
41 | processor = FluxAttnProcessor2_0_NPU()
42 | else:
43 | processor = FluxAttnProcessor2_0()
44 | self.attn = Attention(
45 | query_dim=dim,
46 | cross_attention_dim=None,
47 | dim_head=attention_head_dim,
48 | heads=num_attention_heads,
49 | out_dim=dim,
50 | bias=True,
51 | processor=processor,
52 | qk_norm="rms_norm",
53 | eps=1e-6,
54 | pre_only=True,
55 | )
56 |
57 | def forward(
58 | self,
59 | hidden_states: torch.Tensor,
60 | cond_hidden_states: torch.Tensor,
61 | temb: torch.Tensor,
62 | cond_temb: torch.Tensor,
63 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
64 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
65 | ) -> torch.Tensor:
66 | use_cond = cond_hidden_states is not None
67 |
68 | residual = hidden_states
69 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
70 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
71 |
72 | if use_cond:
73 | residual_cond = cond_hidden_states
74 | norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
75 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
76 |
77 |         norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) if use_cond else norm_hidden_states
78 |
79 | joint_attention_kwargs = joint_attention_kwargs or {}
80 | attn_output = self.attn(
81 | hidden_states=norm_hidden_states_concat,
82 | image_rotary_emb=image_rotary_emb,
83 | use_cond=use_cond,
84 | **joint_attention_kwargs,
85 | )
86 | if use_cond:
87 | attn_output, cond_attn_output = attn_output
88 |
89 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
90 | gate = gate.unsqueeze(1)
91 | hidden_states = gate * self.proj_out(hidden_states)
92 | hidden_states = residual + hidden_states
93 |
94 | if use_cond:
95 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
96 | cond_gate = cond_gate.unsqueeze(1)
97 | condition_latents = cond_gate * self.proj_out(condition_latents)
98 | condition_latents = residual_cond + condition_latents
99 |
100 | if hidden_states.dtype == torch.float16:
101 | hidden_states = hidden_states.clip(-65504, 65504)
102 |
103 | return hidden_states, condition_latents if use_cond else None
104 |
105 |
106 | @maybe_allow_in_graph
107 | class FluxTransformerBlock(nn.Module):
108 | def __init__(
109 | self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
110 | ):
111 | super().__init__()
112 |
113 | self.norm1 = AdaLayerNormZero(dim)
114 |
115 | self.norm1_context = AdaLayerNormZero(dim)
116 |
117 | if hasattr(F, "scaled_dot_product_attention"):
118 | processor = FluxAttnProcessor2_0()
119 | else:
120 | raise ValueError(
121 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
122 | )
123 | self.attn = Attention(
124 | query_dim=dim,
125 | cross_attention_dim=None,
126 | added_kv_proj_dim=dim,
127 | dim_head=attention_head_dim,
128 | heads=num_attention_heads,
129 | out_dim=dim,
130 | context_pre_only=False,
131 | bias=True,
132 | processor=processor,
133 | qk_norm=qk_norm,
134 | eps=eps,
135 | )
136 |
137 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139 |
140 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
141 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
142 |
143 | # let chunk size default to None
144 | self._chunk_size = None
145 | self._chunk_dim = 0
146 |
147 | def forward(
148 | self,
149 | hidden_states: torch.Tensor,
150 | cond_hidden_states: torch.Tensor,
151 | encoder_hidden_states: torch.Tensor,
152 | temb: torch.Tensor,
153 | cond_temb: torch.Tensor,
154 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
156 | ) -> Tuple[torch.Tensor, torch.Tensor]:
157 | use_cond = cond_hidden_states is not None
158 |
159 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
160 | if use_cond:
161 | (
162 | norm_cond_hidden_states,
163 | cond_gate_msa,
164 | cond_shift_mlp,
165 | cond_scale_mlp,
166 | cond_gate_mlp,
167 | ) = self.norm1(cond_hidden_states, emb=cond_temb)
168 |
169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
170 | encoder_hidden_states, emb=temb
171 | )
172 |
173 |         norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) if use_cond else norm_hidden_states
174 |
175 | joint_attention_kwargs = joint_attention_kwargs or {}
176 | # Attention.
177 | attention_outputs = self.attn(
178 | hidden_states=norm_hidden_states,
179 | encoder_hidden_states=norm_encoder_hidden_states,
180 | image_rotary_emb=image_rotary_emb,
181 | use_cond=use_cond,
182 | **joint_attention_kwargs,
183 | )
184 |
185 | attn_output, context_attn_output = attention_outputs[:2]
186 | cond_attn_output = attention_outputs[2] if use_cond else None
187 |
188 | # Process attention outputs for the `hidden_states`.
189 | attn_output = gate_msa.unsqueeze(1) * attn_output
190 | hidden_states = hidden_states + attn_output
191 |
192 | if use_cond:
193 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
194 | cond_hidden_states = cond_hidden_states + cond_attn_output
195 |
196 | norm_hidden_states = self.norm2(hidden_states)
197 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
198 |
199 | if use_cond:
200 | norm_cond_hidden_states = self.norm2(cond_hidden_states)
201 | norm_cond_hidden_states = (
202 | norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
203 | + cond_shift_mlp[:, None]
204 | )
205 |
206 | ff_output = self.ff(norm_hidden_states)
207 | ff_output = gate_mlp.unsqueeze(1) * ff_output
208 | hidden_states = hidden_states + ff_output
209 |
210 | if use_cond:
211 | cond_ff_output = self.ff(norm_cond_hidden_states)
212 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
213 | cond_hidden_states = cond_hidden_states + cond_ff_output
214 |
215 | # Process attention outputs for the `encoder_hidden_states`.
216 |
217 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
218 | encoder_hidden_states = encoder_hidden_states + context_attn_output
219 |
220 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
221 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
222 |
223 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
224 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
225 | if encoder_hidden_states.dtype == torch.float16:
226 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
227 |
228 | return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
229 |
230 |
231 | class FluxTransformer2DModel(
232 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
233 | ):
234 | _supports_gradient_checkpointing = True
235 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
236 |
237 | @register_to_config
238 | def __init__(
239 | self,
240 | patch_size: int = 1,
241 | in_channels: int = 64,
242 | out_channels: Optional[int] = None,
243 | num_layers: int = 19,
244 | num_single_layers: int = 38,
245 | attention_head_dim: int = 128,
246 | num_attention_heads: int = 24,
247 | joint_attention_dim: int = 4096,
248 | pooled_projection_dim: int = 768,
249 | guidance_embeds: bool = False,
250 | axes_dims_rope: Tuple[int] = (16, 56, 56),
251 | ):
252 | super().__init__()
253 | self.out_channels = out_channels or in_channels
254 | self.inner_dim = num_attention_heads * attention_head_dim
255 |
256 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
257 |
258 | text_time_guidance_cls = (
259 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
260 | )
261 | self.time_text_embed = text_time_guidance_cls(
262 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
263 | )
264 |
265 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
266 | self.x_embedder = nn.Linear(in_channels, self.inner_dim)
267 |
268 | self.transformer_blocks = nn.ModuleList(
269 | [
270 | FluxTransformerBlock(
271 | dim=self.inner_dim,
272 | num_attention_heads=num_attention_heads,
273 | attention_head_dim=attention_head_dim,
274 | )
275 | for _ in range(num_layers)
276 | ]
277 | )
278 |
279 | self.single_transformer_blocks = nn.ModuleList(
280 | [
281 | FluxSingleTransformerBlock(
282 | dim=self.inner_dim,
283 | num_attention_heads=num_attention_heads,
284 | attention_head_dim=attention_head_dim,
285 | )
286 | for _ in range(num_single_layers)
287 | ]
288 | )
289 |
290 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
291 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
292 |
293 | self.gradient_checkpointing = False
294 |
295 | @property
296 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
297 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
298 | r"""
299 | Returns:
300 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
301 |             indexed by their weight names.
302 | """
303 | # set recursively
304 | processors = {}
305 |
306 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
307 | if hasattr(module, "get_processor"):
308 | processors[f"{name}.processor"] = module.get_processor()
309 |
310 | for sub_name, child in module.named_children():
311 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
312 |
313 | return processors
314 |
315 | for name, module in self.named_children():
316 | fn_recursive_add_processors(name, module, processors)
317 |
318 | return processors
319 |
320 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
321 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
322 | r"""
323 | Sets the attention processor to use to compute attention.
324 |
325 | Parameters:
326 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
327 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
328 | for **all** `Attention` layers.
329 |
330 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
331 | processor. This is strongly recommended when setting trainable attention processors.
332 |
333 | """
334 | count = len(self.attn_processors.keys())
335 |
336 | if isinstance(processor, dict) and len(processor) != count:
337 | raise ValueError(
338 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
339 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
340 | )
341 |
342 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
343 | if hasattr(module, "set_processor"):
344 | if not isinstance(processor, dict):
345 | module.set_processor(processor)
346 | else:
347 | module.set_processor(processor.pop(f"{name}.processor"))
348 |
349 | for sub_name, child in module.named_children():
350 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
351 |
352 | for name, module in self.named_children():
353 | fn_recursive_attn_processor(name, module, processor)
354 |
355 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
356 | def fuse_qkv_projections(self):
357 | """
358 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
359 | are fused. For cross-attention modules, key and value projection matrices are fused.
360 |
363 |         This API is 🧪 experimental.
366 |         """
367 | self.original_attn_processors = None
368 |
369 | for _, attn_processor in self.attn_processors.items():
370 | if "Added" in str(attn_processor.__class__.__name__):
371 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
372 |
373 | self.original_attn_processors = self.attn_processors
374 |
375 | for module in self.modules():
376 | if isinstance(module, Attention):
377 | module.fuse_projections(fuse=True)
378 |
379 | self.set_attn_processor(FusedFluxAttnProcessor2_0())
380 |
381 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
382 | def unfuse_qkv_projections(self):
383 | """Disables the fused QKV projection if enabled.
384 |
387 |         This API is 🧪 experimental.
391 |         """
392 | if self.original_attn_processors is not None:
393 | self.set_attn_processor(self.original_attn_processors)
394 |
395 | def _set_gradient_checkpointing(self, module, value=False):
396 | if hasattr(module, "gradient_checkpointing"):
397 | module.gradient_checkpointing = value
398 |
399 | def forward(
400 | self,
401 | hidden_states: torch.Tensor,
402 |         cond_hidden_states: Optional[torch.Tensor] = None,
403 |         encoder_hidden_states: Optional[torch.Tensor] = None,
404 |         pooled_projections: Optional[torch.Tensor] = None,
405 |         timestep: Optional[torch.LongTensor] = None,
406 |         img_ids: Optional[torch.Tensor] = None,
407 |         txt_ids: Optional[torch.Tensor] = None,
408 |         guidance: Optional[torch.Tensor] = None,
409 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
410 | controlnet_block_samples=None,
411 | controlnet_single_block_samples=None,
412 | return_dict: bool = True,
413 | controlnet_blocks_repeat: bool = False,
414 | ) -> Union[torch.Tensor, Transformer2DModelOutput]:
415 |         use_condition = cond_hidden_states is not None
419 |
420 | if joint_attention_kwargs is not None:
421 | joint_attention_kwargs = joint_attention_kwargs.copy()
422 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
423 | else:
424 | lora_scale = 1.0
425 |
426 | if USE_PEFT_BACKEND:
427 | # weight the lora layers by setting `lora_scale` for each PEFT layer
428 | scale_lora_layers(self, lora_scale)
429 | else:
430 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
431 | logger.warning(
432 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
433 | )
434 |
435 | hidden_states = self.x_embedder(hidden_states)
436 |         cond_hidden_states = self.x_embedder(cond_hidden_states) if use_condition else None
437 |
438 | timestep = timestep.to(hidden_states.dtype) * 1000
439 | if guidance is not None:
440 | guidance = guidance.to(hidden_states.dtype) * 1000
443 |
444 | temb = (
445 | self.time_text_embed(timestep, pooled_projections)
446 | if guidance is None
447 | else self.time_text_embed(timestep, guidance, pooled_projections)
448 | )
449 |
450 | cond_temb = (
451 |             self.time_text_embed(torch.zeros_like(timestep), pooled_projections)
452 | if guidance is None
453 | else self.time_text_embed(
454 |                 torch.zeros_like(timestep), guidance, pooled_projections
455 | )
456 | )
457 |
458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459 |
460 | if txt_ids.ndim == 3:
461 | logger.warning(
462 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
463 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
464 | )
465 | txt_ids = txt_ids[0]
466 | if img_ids.ndim == 3:
467 | logger.warning(
468 | "Passing `img_ids` 3d torch.Tensor is deprecated."
469 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
470 | )
471 | img_ids = img_ids[0]
472 |
473 | ids = torch.cat((txt_ids, img_ids), dim=0)
474 | image_rotary_emb = self.pos_embed(ids)
475 |
476 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
477 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
478 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
479 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
480 |
481 | for index_block, block in enumerate(self.transformer_blocks):
482 | if torch.is_grad_enabled() and self.gradient_checkpointing:
483 |
484 | def create_custom_forward(module, return_dict=None):
485 | def custom_forward(*inputs):
486 | if return_dict is not None:
487 | return module(*inputs, return_dict=return_dict)
488 | else:
489 | return module(*inputs)
490 |
491 | return custom_forward
492 |
493 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
494 |                 encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
495 | create_custom_forward(block),
496 | hidden_states,
497 | encoder_hidden_states,
498 | temb,
499 | image_rotary_emb,
500 | cond_temb=cond_temb if use_condition else None,
501 | cond_hidden_states=cond_hidden_states if use_condition else None,
502 | **ckpt_kwargs,
503 | )
504 |
505 | else:
506 | encoder_hidden_states, hidden_states, cond_hidden_states = block(
507 | hidden_states=hidden_states,
508 | encoder_hidden_states=encoder_hidden_states,
509 | cond_hidden_states=cond_hidden_states if use_condition else None,
510 | temb=temb,
511 | cond_temb=cond_temb if use_condition else None,
512 | image_rotary_emb=image_rotary_emb,
513 | joint_attention_kwargs=joint_attention_kwargs,
514 | )
515 |
516 | # controlnet residual
517 | if controlnet_block_samples is not None:
518 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
519 | interval_control = int(np.ceil(interval_control))
520 | # For Xlabs ControlNet.
521 | if controlnet_blocks_repeat:
522 | hidden_states = (
523 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
524 | )
525 | else:
526 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
527 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
528 |
529 | for index_block, block in enumerate(self.single_transformer_blocks):
530 | if torch.is_grad_enabled() and self.gradient_checkpointing:
531 |
532 | def create_custom_forward(module, return_dict=None):
533 | def custom_forward(*inputs):
534 | if return_dict is not None:
535 | return module(*inputs, return_dict=return_dict)
536 | else:
537 | return module(*inputs)
538 |
539 | return custom_forward
540 |
541 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
542 | hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
543 | create_custom_forward(block),
544 | hidden_states,
545 | temb,
546 | image_rotary_emb,
547 | cond_temb=cond_temb if use_condition else None,
548 | cond_hidden_states=cond_hidden_states if use_condition else None,
549 | **ckpt_kwargs,
550 | )
551 |
552 | else:
553 | hidden_states, cond_hidden_states = block(
554 | hidden_states=hidden_states,
555 | cond_hidden_states=cond_hidden_states if use_condition else None,
556 | temb=temb,
557 | cond_temb=cond_temb if use_condition else None,
558 | image_rotary_emb=image_rotary_emb,
559 | joint_attention_kwargs=joint_attention_kwargs,
560 | )
561 |
562 | # controlnet residual
563 | if controlnet_single_block_samples is not None:
564 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
565 | interval_control = int(np.ceil(interval_control))
566 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
567 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
568 | + controlnet_single_block_samples[index_block // interval_control]
569 | )
570 |
571 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
572 |
573 | hidden_states = self.norm_out(hidden_states, temb)
574 | output = self.proj_out(hidden_states)
575 |
576 | if USE_PEFT_BACKEND:
577 | # remove `lora_scale` from each PEFT layer
578 | unscale_lora_layers(self, lora_scale)
579 |
580 | if not return_dict:
581 | return (output,)
582 |
583 | return Transformer2DModelOutput(sample=output)
--------------------------------------------------------------------------------
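The following sketch is an editor-added example, not a file from the repository: it exercises the `attn_processors` property and `set_attn_processor` method documented in the model above. It assumes the module is importable as `src.transformer_flux` (adjust the import to wherever this file lives in your checkout) and uses a deliberately tiny layer count so the model is cheap to construct.

from src.transformer_flux import FluxTransformer2DModel

# Tiny configuration: one dual-stream and one single-stream block keep construction cheap.
model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)

# `attn_processors` returns a dict keyed by module path; every key ends in ".processor".
for name, processor in model.attn_processors.items():
    print(name, type(processor).__name__)

# `set_attn_processor` accepts a single processor or a dict with exactly those keys;
# re-applying the current mapping is a harmless round trip.
model.set_attn_processor(model.attn_processors)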
/train/train_spatial.sh:
--------------------------------------------------------------------------------
1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
2 | export OUTPUT_DIR="./models/pose_model" # your save path
3 | export CONFIG="./default_config.yaml"
4 | export TRAIN_DATA="./examples/pose.jsonl" # your data jsonl file
5 | export LOG_PATH="$OUTPUT_DIR/log"
6 |
7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \
8 | --pretrained_model_name_or_path $MODEL_DIR \
9 | --cond_size=512 \
10 | --noise_size=1024 \
11 | --subject_column="None" \
12 | --spatial_column="source" \
13 | --target_column="target" \
14 | --caption_column="caption" \
15 | --ranks 128 \
16 | --network_alphas 128 \
17 | --output_dir=$OUTPUT_DIR \
18 | --logging_dir=$LOG_PATH \
19 | --mixed_precision="bf16" \
20 | --train_data_dir=$TRAIN_DATA \
21 | --learning_rate=1e-4 \
22 | --train_batch_size=1 \
23 | --validation_prompt "A girl in the city." \
24 | --num_train_epochs=1000 \
25 | --validation_steps=20 \
26 | --checkpointing_steps=20 \
27 | --spatial_test_images "./examples/openpose_data/1.png" \
28 | --subject_test_images None \
29 | --test_h 1024 \
30 | --test_w 1024 \
31 | --num_validation_images=2
32 |
--------------------------------------------------------------------------------
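Editor's note, not part of the repository: in the script above, `--spatial_column`, `--target_column`, and `--caption_column` name the JSONL keys `source`, `target`, and `caption`; the authoritative schema is whatever train/src/jsonl_datasets.py reads. A hedged sketch of writing one such row, with a purely illustrative image pairing and output filename:

import json

# Illustrative row only; the training list shipped with the repo is ./examples/pose.jsonl.
row = {
    "source": "./examples/openpose_data/1.png",   # conditioning image (--spatial_column)
    "target": "./examples/openpose_data/2.png",   # ground-truth image (--target_column)
    "caption": "A girl in the city.",             # text prompt (--caption_column)
}

with open("pose_example.jsonl", "w") as f:
    f.write(json.dumps(row) + "\n")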
/train/train_style.sh:
--------------------------------------------------------------------------------
1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
2 | export OUTPUT_DIR="./models/style_model" # your save path
3 | export CONFIG="./default_config.yaml"
4 | export TRAIN_DATA="./examples/style.jsonl" # your data jsonl file
5 | export LOG_PATH="$OUTPUT_DIR/log"
6 |
7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \
8 | --pretrained_model_name_or_path $MODEL_DIR \
9 | --cond_size=512 \
10 | --noise_size=1024 \
11 | --subject_column="None" \
12 | --spatial_column="source" \
13 | --target_column="target" \
14 | --caption_column="caption" \
15 | --ranks 128 \
16 | --network_alphas 128 \
17 | --output_dir=$OUTPUT_DIR \
18 | --logging_dir=$LOG_PATH \
19 | --mixed_precision="bf16" \
20 | --train_data_dir=$TRAIN_DATA \
21 | --learning_rate=1e-4 \
22 | --train_batch_size=1 \
23 | --validation_prompt "Ghibli Studio style, Charming hand-drawn anime-style illustration" \
24 | --num_train_epochs=1000 \
25 | --validation_steps=20 \
26 | --checkpointing_steps=20 \
27 | --spatial_test_images "./examples/style_data/5.png" \
28 | --subject_test_images None \
29 | --test_h 1024 \
30 | --test_w 1024 \
31 | --num_validation_images=2
33 |
--------------------------------------------------------------------------------
/train/train_subject.sh:
--------------------------------------------------------------------------------
1 | export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
2 | export OUTPUT_DIR="./models/subject_model" # your save path
3 | export CONFIG="./default_config.yaml"
4 | export TRAIN_DATA="./examples/subject.jsonl" # your data jsonl file
5 | export LOG_PATH="$OUTPUT_DIR/log"
6 |
7 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file $CONFIG train.py \
8 | --pretrained_model_name_or_path $MODEL_DIR \
9 | --cond_size=512 \
10 | --noise_size=1024 \
11 | --subject_column="source" \
12 | --spatial_column="None" \
13 | --target_column="target" \
14 | --caption_column="caption" \
15 | --ranks 128 \
16 | --network_alphas 128 \
17 | --output_dir=$OUTPUT_DIR \
18 | --logging_dir=$LOG_PATH \
19 | --mixed_precision="bf16" \
20 | --train_data_dir=$TRAIN_DATA \
21 | --learning_rate=1e-4 \
22 | --train_batch_size=1 \
23 | --validation_prompt "An SKS in the city." \
24 | --num_train_epochs=1000 \
25 | --validation_steps=20 \
26 | --checkpointing_steps=20 \
27 | --spatial_test_images None \
28 | --subject_test_images "./examples/subject_data/3.png" \
29 | --test_h 1024 \
30 | --test_w 1024 \
31 | --num_validation_images=2
32 |
--------------------------------------------------------------------------------
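Editor's note, not part of the repository: train_subject.sh differs from train_spatial.sh only in which column carries the condition (`--subject_column="source"` with `--spatial_column="None"`), and its validation prompt uses the placeholder token `SKS`. A small hedged sanity check, assuming the JSONL keys match the column flags above:

import json

# Check that every row of the (assumed) subject list carries the columns the flags expect.
with open("./examples/subject.jsonl") as f:
    for line_no, line in enumerate(f, 1):
        row = json.loads(line)
        missing = {"source", "target", "caption"} - set(row)
        assert not missing, f"line {line_no} is missing keys: {missing}"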