├── LICENSE
├── README.md
├── app
│   ├── __init__.py
│   └── gradio_app.py
├── flux
│   ├── __init__.py
│   ├── block.py
│   ├── condition.py
│   ├── generate.py
│   ├── lora_controller.py
│   └── transformer.py
├── requirements.txt
└── samples
    ├── 1.png
    ├── 12.png
    ├── 13.png
    ├── 14.png
    ├── 18.png
    ├── 7.png
    ├── 8.png
    └── image_6.png
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2025] [Fotographer AI Inc]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
7 |
8 | **An all-in-one control framework for unified visual content creation using GenAI.**
9 | Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning.
10 |
19 | ---
20 |
21 | ## 🧠 Overview
22 |
23 | **ZenCtrl** is a comprehensive toolkit built to tackle core challenges in image generation:
24 |
25 | - No fine-tuning needed – works from **a single subject image**
26 | - Maintains **control over shape, pose, camera angle, context**
27 | - Supports **high-resolution**, multi-scene generation
28 | - Modular toolkit for preprocessing, control, editing, and post-processing tasks
29 |
30 | ZenCtrl is based on OminiControl but enhanced with finer-grained control, consistent subject preservation, and improved, ready-to-use models. Our goal is to build an **agentic visual generation system** that can orchestrate image/video creation from **LLM-driven recipes**.
31 |
32 |
37 |
38 | ---
39 |
40 | ## 🛠 Toolkit Components (coming soon)
41 |
42 | ### 🧹 Preprocessing
43 |
44 | - Background removal
45 | - Matting
46 | - Reshaping
47 | - Segmentation
48 |
49 | ### 🎮 Control Models
50 |
51 | - Shape (HED, Scribble, Depth)
52 | - Pose (OpenPose, DensePose)
53 | - Mask control
54 | - Camera/View control
55 |
56 | ### 🎨 Post-processing
57 |
58 | - Deblurring
59 | - Color fixing
60 | - Natural blending
61 |
62 | ### ✏️ Editing Models
63 |
64 | - Inpainting (removal, masked editing, replacement)
65 | - Outpainting
66 | - Transformation / Motion
67 | - Relighting
68 |
69 | ---
70 |
71 | ## 🎯 Supported Tasks
72 |
73 | - Background generation
74 | - Controlled background generation
75 | - Subject-consistent context-aware generation
76 | - Object and subject placement (coming soon)
77 | - In-context image/video generation (coming soon)
78 | - Multi-object/subject merging & blending (coming soon)
79 | - Video generation (coming soon)
80 |
81 | ---
82 |
83 | ## 📦 Target Use Cases
84 |
85 | - Product photography
86 | - Fashion & accessory try-on
87 | - Virtual try-on (shoes, hats, glasses, etc.)
88 | - People & portrait control
89 | - Illustration, animation, and ad creatives
90 |
91 | All of these tasks can be **mixed and layered** — ZenCtrl is designed to support real-world visual workflows with **agentic task composition**.
92 |
93 | ---
94 |
95 | ## 📢 News
96 |
97 | - **2025-03-24**: 🧠 First release — model weights available on Hugging Face!
98 | - **2025-05-06**: 📢 Update: source code released, latest model weights available on Hugging Face!
99 | - **Coming Soon**: Quick Start guide, Upscaling source code, Example notebooks
100 | - **Next**: Controlled fine-grain version on our platform and API (Pro version)
101 | - **Future**: Video generation toolkit release
102 |
103 | ---
104 |
105 | ## 🚀 Quick Start
106 |
107 | Before running the Gradio code, please install the requirements and download the weights from our HuggingFace repository:
108 | 👉 [https://huggingface.co/fotographerai/zenctrl_tools](https://huggingface.co/fotographerai/zenctrl_tools)
109 |
110 | We aligned our original code with the OminiControl structure; our model takes two condition inputs instead of one. We will release the original code, together with the LLaMA task driver, soon, so stay tuned. We will also update the tasks for specific verticals (e.g., virtual try-on, ad creatives). A minimal Python sketch of running the model is shown after the setup steps below.
111 |
112 | ---
113 |
114 | ### Quick Setup (CMD)
115 |
116 | You can follow the step-by-step setup instructions below:
117 |
118 | ```cmd
119 | :: Cloning and setting up ZenCtrl
120 | git clone https://github.com/FotographerAI/ZenCtrl.git
121 | cd ZenCtrl
122 |
123 | :: Creating virtual environment
124 | python -m venv venv
125 | call venv\Scripts\activate.bat
126 |
127 | :: Installing PyTorch and requirements
128 | pip install torch==2.7.0+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
129 | pip install --upgrade pip wheel setuptools
130 | pip install -r requirements.txt
131 |
132 | :: Downloading model weights
133 | curl --create-dirs -L https://huggingface.co/fotographerai/zenctrl_tools/resolve/main/weights/zen2con_1440_17000/pytorch_lora_weights.safetensors -o weights\zen2con_1440_17000\pytorch_lora_weights.safetensors
134 |
135 | :: All set! Launching Gradio app
136 | python app/gradio_app.py
137 | ```
138 |
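If you prefer to call the pipeline from Python rather than through the Gradio UI, the sketch below mirrors `app/gradio_app.py`. It assumes you run it from the repository root and that the LoRA weights were downloaded to `weights/zen2con_1440_17000/` as in the setup above; the prompt, sample path, and scales are illustrative.

```python
# Minimal sketch mirroring app/gradio_app.py (paths, prompt, and scales are illustrative).
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from flux.condition import Condition
from flux.generate import generate
from flux.lora_controller import set_lora_scale

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
    adapter_name="subject",
)
pipe.set_adapters(["subject"])

size = 1024
image = Image.open("samples/1.png").convert("RGB").resize((size, size))

# Two condition inputs built from the same subject image, offset in opposite directions.
conditions = [
    Condition("subject", image, position_delta=(0, size // 16)),
    Condition("subject", image, position_delta=(0, -size // 16)),
]

with set_lora_scale(["subject"], scale=3.0):
    result = generate(
        pipe,
        prompt="A man sitting in a yellow chair drinking a cup of coffee",
        conditions=conditions,
        num_inference_steps=8,
        height=size,
        width=size,
        condition_scale=[1.0, 1.0],
        model_config={
            "union_cond_attn": True,
            "add_cond_attn": False,
            "latent_lora": False,
            "independent_condition": False,
        },
        default_lora=True,
    ).images[0]

result.save("output.png")
```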
139 | ---
140 |
141 | ## 🎨 Demo
142 |
143 | #### Examples
144 |
172 | ### 🧪 Try it now on [Hugging Face Space](https://huggingface.co/spaces/fotographerai/ZenCtrl)
173 |
174 |
180 |
181 | ---
182 |
183 | ## 🔧 Models (Updated Weights Released)
184 |
185 | | Type | Name | Base | Resolution | Description | Links |
186 | | --------------------- | --------------------- | ------------ | ---------- | --------------------------------- | ---------------------------------------------------------- |
187 | | Subject generation | `zen2con_1440_17000` | FLUX.1 | 1024x1024 | Core model for subject-driven generation | [link](https://huggingface.co/fotographerai/zenctrl_tools/tree/main/weights/zen2con_1440_17000) |
188 | | Background generation + Canny | `bg_canny_58000_1024` | FLUX.1 | 1024x1024 | Enhanced background control | [link](https://huggingface.co/fotographerai/zenctrl_tools) |
189 | | Deblurring model | `deblurr_1024_10000` | OminiControl | 1024x1024 | Quality recovery post-generation | [link](https://huggingface.co/fotographerai/zenctrl_tools) |
190 |
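The weight files listed above can also be fetched programmatically. The snippet below is a sketch (not an official API) that downloads the subject-generation LoRA with `huggingface_hub` and attaches it to a FLUX.1 pipeline; the adapter name is arbitrary.

```python
# Sketch: fetch one of the LoRA weights listed above and attach it to FLUX.1-schnell.
import torch
from diffusers.pipelines import FluxPipeline
from huggingface_hub import hf_hub_download

lora_path = hf_hub_download(
    repo_id="fotographerai/zenctrl_tools",
    filename="weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(lora_path, adapter_name="subject")  # adapter name is arbitrary
pipe.set_adapters(["subject"])
```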
191 | ---
192 |
193 | ## 🚧 Limitations
194 |
195 | 1. Models currently perform best with **objects**, and to some extent **humans**.
196 | 2. Resolution support is currently capped at **1024x1024** (higher quality coming soon).
197 | 3. Performance with **illustrations** is currently limited.
198 | 4. The models were **not trained on large-scale or highly diverse datasets** yet — we plan to improve quality and variation by training on larger and more diverse datasets, especially for **illustration and stylized content**.
199 | 5. Video support and the full **agentic task pipeline** are still under development.
200 |
201 | ---
202 |
203 | ## 📋 To-do
204 |
205 | - [x] Release early pretrained model weights for defined tasks
206 | - [x] Release additional task-specific models and modes
207 | - [x] Release open source code
208 | - [x] Launch API access via Baseten for easier deployment
209 | - [ ] Release Quick Start guide and example notebooks
210 | - [ ] Launch API access via our app for easier deployment
211 | - [ ] Release high-resolution models (1500×1500+)
212 | - [ ] Enable full toolkit integration with agent API
213 | - [ ] Add video generation module
214 |
215 | ---
216 |
217 | ## 🤝 Join the Community
218 |
219 | - 💬 [Discord](https://discord.com/invite/b9RuYQ3F8k) – share ideas and feedback
220 | - 🌐 [Landing Page](https://fotographer.ai/zen-control)
221 | - 🧪 [Model weights on Hugging Face](https://huggingface.co/fotographerai/zenctrl_tools/tree/main/weights)
222 |
223 |
224 | ---
225 |
226 | ## 🤝 Community Collaboration
227 |
228 | We hope to collaborate closely with the open-source community to make **ZenCtrl** a powerful and extensible toolkit for visual content creation.
229 | Now that the source code is released, we welcome contributions in training, expanding supported use cases, and developing new task-specific modules.
230 | Our vision is to make ZenCtrl the **standard framework** for agentic, high-quality image and video generation — built together, for everyone.
231 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/app/__init__.py
--------------------------------------------------------------------------------
/app/gradio_app.py:
--------------------------------------------------------------------------------
1 | # Recycled from OminiControl
2 |
3 | import gradio as gr
4 | import torch
5 | from PIL import Image
6 | from diffusers.pipelines import FluxPipeline
7 | from diffusers import FluxTransformer2DModel
8 |
9 | from flux.condition import Condition
10 | from flux.generate import generate
11 | from flux.lora_controller import set_lora_scale
12 |
13 | pipe = None
14 | use_int8 = False
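# The flags below are read by attn_forward / block_forward in flux/block.py:
#   union_cond_attn:       if False, condition tokens and image/text tokens are masked from attending to each other
#   add_cond_attn:         if True, the condition attention outputs are added directly onto the image hidden states
#   latent_lora:           if True, the LoRA adapters are also applied to the base latent projections
#   independent_condition: if True, condition tokens cannot attend to the image/text tokens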
15 | model_config = { "union_cond_attn": True, "add_cond_attn": False, "latent_lora": False, "independent_condition": False}
16 |
17 | def get_gpu_memory():
18 | return torch.cuda.get_device_properties(0).total_memory / 1024**3
19 |
20 |
21 | def init_pipeline():
22 | global pipe
23 | if use_int8 or get_gpu_memory() < 33:
24 | transformer_model = FluxTransformer2DModel.from_pretrained(
25 | "sayakpaul/flux.1-schell-int8wo-improved",
26 | torch_dtype=torch.bfloat16,
27 | use_safetensors=False,
28 | )
29 | pipe = FluxPipeline.from_pretrained(
30 | "black-forest-labs/FLUX.1-schnell",
31 | transformer=transformer_model,
32 | torch_dtype=torch.bfloat16,
33 | )
34 | else:
35 | pipe = FluxPipeline.from_pretrained(
36 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
37 | )
38 | pipe = pipe.to("cuda")
39 |
40 | # Optional: load additional LoRA weights; put the downloaded weights here!
41 | pipe.load_lora_weights("weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
42 | adapter_name="subject")
43 | pipe.set_adapters(["subject"])
44 |
45 | def paste_on_white_background(image: Image.Image) -> Image.Image:
46 | """
47 | Pastes a transparent image onto a white background of the same size.
48 | """
49 | if image.mode != "RGBA":
50 | image = image.convert("RGBA")
51 |
52 | # Create white background
53 | white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255))
54 | white_bg.paste(image, (0, 0), mask=image)
55 | return white_bg.convert("RGB") # Convert back to RGB if you don't need alpha
56 |
57 |
58 | def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024):
59 | # center crop image
60 | w, h, min_size = image.size[0], image.size[1], min(image.size)
61 | image = image.crop(
62 | (
63 | (w - min_size) // 2,
64 | (h - min_size) // 2,
65 | (w + min_size) // 2,
66 | (h + min_size) // 2,
67 | )
68 | )
69 | image = image.resize((size, size))
70 | image = paste_on_white_background(image)  # Optional: you can remove this line if you want; just make sure the sizes match.
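# Both conditions reuse the same subject image; position_delta shifts the condition
# token ids in Condition.encode, so the two copies sit at opposite horizontal offsets
# (one latent-grid width, size // 16 tokens) relative to the generated image.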
71 | condition0 = Condition("subject", image, position_delta=(0, size // 16))
72 | condition1 = Condition("subject", image, position_delta=(0, -size // 16))
73 |
74 | if pipe is None:
75 | init_pipeline()
76 |
77 | with set_lora_scale(["subject"], scale=3.0):
78 | result_img = generate(
79 | pipe,
80 | prompt=text.strip(),
81 | conditions=[condition0, condition1],
82 | num_inference_steps=steps,
83 | height=1024,
84 | width=1024,
85 | condition_scale = [strength_sub,strength_spat],
86 | model_config=model_config,
87 | default_lora=True,
88 | ).images[0]
89 |
90 | return [condition0.condition, condition1.condition, result_img]
91 |
92 |
93 | def get_samples():
94 | sample_list = [
95 | {
96 | "image": "samples/1.png", #place your image path here
97 | "text": "A man sitting in a yellow chair drinking a cup of coffee",
98 | }
99 | ]
100 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
101 |
102 |
103 | demo = gr.Interface(
104 | fn=process_image_and_text,
105 | inputs=[
106 | gr.Image(type="pil"),
107 | gr.Textbox(lines=2),
108 | gr.Slider(minimum=2, maximum=28, value=2, label="steps"),
109 | gr.Slider(minimum=0, maximum=2.0, value=1.0, label="strength_sub"),
110 | gr.Slider(minimum=0, maximum=2.0, value=1.0, label="strength_spat"),
111 | gr.Slider(minimum=512, maximum=2048, value=1024, label="size"),
112 | ],
113 | outputs=gr.Gallery(
114 | label="Outputs", show_label=False, elem_id="gallery",
115 | columns=[3], rows=[1], object_fit="contain", height="auto"
116 | ),
117 | title="ZenCtrl / Subject driven generation",
118 | examples=get_samples(),
119 | )
120 |
121 | if __name__ == "__main__":
122 | init_pipeline()
123 | demo.launch(
124 | debug=True,
125 | # share=True
126 | )
127 |
--------------------------------------------------------------------------------
/flux/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/flux/__init__.py
--------------------------------------------------------------------------------
/flux/block.py:
--------------------------------------------------------------------------------
1 | # Recycled from OminiControl and modified to accept an extra condition.
2 | # While ZenCtrl pursued a similar idea, it diverged structurally.
3 | # We appreciate the clarity of Omini's implementation and decided to align with it.
4 |
5 | import torch
6 | from typing import Optional, Dict, Any
7 | from diffusers.models.attention_processor import Attention, F
8 | from .lora_controller import enable_lora
9 | from diffusers.models.embeddings import apply_rotary_emb
10 |
11 | def attn_forward(
12 | attn: Attention,
13 | hidden_states: torch.FloatTensor,
14 | encoder_hidden_states: torch.FloatTensor = None,
15 | condition_latents: torch.FloatTensor = None,
16 | extra_condition_latents: torch.FloatTensor = None,
17 | attention_mask: Optional[torch.FloatTensor] = None,
18 | image_rotary_emb: Optional[torch.Tensor] = None,
19 | cond_rotary_emb: Optional[torch.Tensor] = None,
20 | extra_cond_rotary_emb: Optional[torch.Tensor] = None,
21 | model_config: Optional[Dict[str, Any]] = {},
22 | ) -> torch.FloatTensor:
23 | batch_size, _, _ = (
24 | hidden_states.shape
25 | if encoder_hidden_states is None
26 | else encoder_hidden_states.shape
27 | )
28 |
29 | with enable_lora(
30 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
31 | ):
32 | # `sample` projections.
33 | query = attn.to_q(hidden_states)
34 | key = attn.to_k(hidden_states)
35 | value = attn.to_v(hidden_states)
36 |
37 | inner_dim = key.shape[-1]
38 | head_dim = inner_dim // attn.heads
39 |
40 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
41 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
42 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
43 |
44 | if attn.norm_q is not None:
45 | query = attn.norm_q(query)
46 | if attn.norm_k is not None:
47 | key = attn.norm_k(key)
48 |
49 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
50 | if encoder_hidden_states is not None:
51 | # `context` projections.
52 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
53 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
54 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
55 |
56 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
57 | batch_size, -1, attn.heads, head_dim
58 | ).transpose(1, 2)
59 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
60 | batch_size, -1, attn.heads, head_dim
61 | ).transpose(1, 2)
62 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
63 | batch_size, -1, attn.heads, head_dim
64 | ).transpose(1, 2)
65 |
66 | if attn.norm_added_q is not None:
67 | encoder_hidden_states_query_proj = attn.norm_added_q(
68 | encoder_hidden_states_query_proj
69 | )
70 | if attn.norm_added_k is not None:
71 | encoder_hidden_states_key_proj = attn.norm_added_k(
72 | encoder_hidden_states_key_proj
73 | )
74 |
75 | # attention
76 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
77 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
78 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
79 |
80 | if image_rotary_emb is not None:
81 |
82 |
83 | query = apply_rotary_emb(query, image_rotary_emb)
84 | key = apply_rotary_emb(key, image_rotary_emb)
85 |
86 | if condition_latents is not None:
87 | cond_query = attn.to_q(condition_latents)
88 | cond_key = attn.to_k(condition_latents)
89 | cond_value = attn.to_v(condition_latents)
90 |
91 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
92 | 1, 2
93 | )
94 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
95 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
96 | 1, 2
97 | )
98 | if attn.norm_q is not None:
99 | cond_query = attn.norm_q(cond_query)
100 | if attn.norm_k is not None:
101 | cond_key = attn.norm_k(cond_key)
102 |
103 | #extra condition
104 | if extra_condition_latents is not None:
105 | extra_cond_query = attn.to_q(extra_condition_latents)
106 | extra_cond_key = attn.to_k(extra_condition_latents)
107 | extra_cond_value = attn.to_v(extra_condition_latents)
108 |
109 | extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
110 | 1, 2
111 | )
112 | extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
113 | extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
114 | 1, 2
115 | )
116 | if attn.norm_q is not None:
117 | extra_cond_query = attn.norm_q(extra_cond_query)
118 | if attn.norm_k is not None:
119 | extra_cond_key = attn.norm_k(extra_cond_key)
120 |
121 |
122 | if extra_cond_rotary_emb is not None:
123 | extra_cond_query = apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb)
124 | extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb)
125 |
126 | if cond_rotary_emb is not None:
127 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
128 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
129 |
130 | if condition_latents is not None:
131 | if extra_condition_latents is not None:
132 |
133 | query = torch.cat([query, cond_query, extra_cond_query], dim=2)
134 | key = torch.cat([key, cond_key, extra_cond_key], dim=2)
135 | value = torch.cat([value, cond_value, extra_cond_value], dim=2)
136 | else:
137 | query = torch.cat([query, cond_query], dim=2)
138 | key = torch.cat([key, cond_key], dim=2)
139 | value = torch.cat([value, cond_value], dim=2)
140 | print("concat Omini latents: ", query.shape, key.shape, value.shape)
141 |
142 |
143 | if not model_config.get("union_cond_attn", True):
144 |
145 | attention_mask = torch.ones(
146 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
147 | )
148 | condition_n = cond_query.shape[2]
149 | attention_mask[-condition_n:, :-condition_n] = False
150 | attention_mask[:-condition_n, -condition_n:] = False
151 | elif model_config.get("independent_condition", False):
152 | attention_mask = torch.ones(
153 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
154 | )
155 | condition_n = cond_query.shape[2]
156 | attention_mask[-condition_n:, :-condition_n] = False
157 |
158 | if hasattr(attn, "c_factor"):
159 | attention_mask = torch.zeros(
160 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
161 | )
162 | condition_n = cond_query.shape[2]
163 | condition_e = extra_cond_query.shape[2]
164 | bias = torch.log(attn.c_factor[0])
165 | attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias
166 | attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias
167 |
168 | bias = torch.log(attn.c_factor[1])
169 | attention_mask[-condition_e:, :-condition_n-condition_e] = bias
170 | attention_mask[:-condition_n-condition_e, -condition_e:] = bias
171 |
172 | hidden_states = F.scaled_dot_product_attention(
173 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
174 | )
175 | hidden_states = hidden_states.transpose(1, 2).reshape(
176 | batch_size, -1, attn.heads * head_dim
177 | )
178 | hidden_states = hidden_states.to(query.dtype)
179 |
180 | if encoder_hidden_states is not None:
181 | if condition_latents is not None:
182 | if extra_condition_latents is not None:
183 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
184 | hidden_states[:, : encoder_hidden_states.shape[1]],
185 | hidden_states[
186 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2
187 | ],
188 | hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
189 | hidden_states[:, -condition_latents.shape[1] :], #extra condition latents
190 | )
191 | else:
192 | encoder_hidden_states, hidden_states, condition_latents = (
193 | hidden_states[:, : encoder_hidden_states.shape[1]],
194 | hidden_states[
195 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
196 | ],
197 | hidden_states[:, -condition_latents.shape[1] :]
198 | )
199 | else:
200 | encoder_hidden_states, hidden_states = (
201 | hidden_states[:, : encoder_hidden_states.shape[1]],
202 | hidden_states[:, encoder_hidden_states.shape[1] :],
203 | )
204 |
205 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
206 | # linear proj
207 | hidden_states = attn.to_out[0](hidden_states)
208 | # dropout
209 | hidden_states = attn.to_out[1](hidden_states)
210 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
211 |
212 | if condition_latents is not None:
213 | condition_latents = attn.to_out[0](condition_latents)
214 | condition_latents = attn.to_out[1](condition_latents)
215 |
216 | if extra_condition_latents is not None:
217 | extra_condition_latents = attn.to_out[0](extra_condition_latents)
218 | extra_condition_latents = attn.to_out[1](extra_condition_latents)
219 |
220 |
221 | return (
222 | # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
223 | (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents)
224 | if condition_latents is not None
225 | else (hidden_states, encoder_hidden_states)
226 | )
227 | elif condition_latents is not None:
228 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents
229 | if extra_condition_latents is not None:
230 | hidden_states, condition_latents, extra_condition_latents = (
231 | hidden_states[:, : -condition_latents.shape[1]*2],
232 | hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]],
233 | hidden_states[:, -condition_latents.shape[1] :],
234 | )
235 | else:
236 | hidden_states, condition_latents = (
237 | hidden_states[:, : -condition_latents.shape[1]],
238 | hidden_states[:, -condition_latents.shape[1] :],
239 | )
240 | return hidden_states, condition_latents, extra_condition_latents
241 | else:
242 | return hidden_states
243 |
244 |
245 | def block_forward(
246 | self,
247 | hidden_states: torch.FloatTensor,
248 | encoder_hidden_states: torch.FloatTensor,
249 | condition_latents: torch.FloatTensor,
250 | extra_condition_latents: torch.FloatTensor,
251 | temb: torch.FloatTensor,
252 | cond_temb: torch.FloatTensor,
253 | extra_cond_temb: torch.FloatTensor,
254 | cond_rotary_emb=None,
255 | extra_cond_rotary_emb=None,
256 | image_rotary_emb=None,
257 | model_config: Optional[Dict[str, Any]] = {},
258 | ):
259 | use_cond = condition_latents is not None
260 |
261 | use_extra_cond = extra_condition_latents is not None
262 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
263 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
264 | hidden_states, emb=temb
265 | )
266 |
267 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
268 | self.norm1_context(encoder_hidden_states, emb=temb)
269 | )
270 |
271 | if use_cond:
272 | (
273 | norm_condition_latents,
274 | cond_gate_msa,
275 | cond_shift_mlp,
276 | cond_scale_mlp,
277 | cond_gate_mlp,
278 | ) = self.norm1(condition_latents, emb=cond_temb)
279 | (
280 | norm_extra_condition_latents,
281 | extra_cond_gate_msa,
282 | extra_cond_shift_mlp,
283 | extra_cond_scale_mlp,
284 | extra_cond_gate_mlp,
285 | ) = self.norm1(extra_condition_latents, emb=extra_cond_temb)
286 |
287 | # Attention.
288 | result = attn_forward(
289 | self.attn,
290 | model_config=model_config,
291 | hidden_states=norm_hidden_states,
292 | encoder_hidden_states=norm_encoder_hidden_states,
293 | condition_latents=norm_condition_latents if use_cond else None,
294 | extra_condition_latents=norm_extra_condition_latents if use_cond else None,
295 | image_rotary_emb=image_rotary_emb,
296 | cond_rotary_emb=cond_rotary_emb if use_cond else None,
297 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None,
298 | )
299 |
300 | attn_output, context_attn_output = result[:2]
301 | cond_attn_output = result[2] if use_cond else None
302 | extra_condition_output = result[3]
303 |
304 | # Process attention outputs for the `hidden_states`.
305 | # 1. hidden_states
306 | attn_output = gate_msa.unsqueeze(1) * attn_output
307 | hidden_states = hidden_states + attn_output
308 | # 2. encoder_hidden_states
309 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
310 |
311 | encoder_hidden_states = encoder_hidden_states + context_attn_output
312 | # 3. condition_latents
313 | if use_cond:
314 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
315 | condition_latents = condition_latents + cond_attn_output
316 | #need to make new condition_extra and add extra_condition_output
317 | if use_extra_cond:
318 | extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output
319 | extra_condition_latents = extra_condition_latents + extra_condition_output
320 |
321 | if model_config.get("add_cond_attn", False):
322 | hidden_states += cond_attn_output
323 | hidden_states += extra_condition_output
324 |
325 |
326 | # LayerNorm + MLP.
327 | # 1. hidden_states
328 | norm_hidden_states = self.norm2(hidden_states)
329 | norm_hidden_states = (
330 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
331 | )
332 | # 2. encoder_hidden_states
333 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
334 | norm_encoder_hidden_states = (
335 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
336 | )
337 | # 3. condition_latents
338 | if use_cond:
339 | norm_condition_latents = self.norm2(condition_latents)
340 | norm_condition_latents = (
341 | norm_condition_latents * (1 + cond_scale_mlp[:, None])
342 | + cond_shift_mlp[:, None]
343 | )
344 |
345 | if use_extra_cond:
346 | #added conditions
347 | extra_norm_condition_latents = self.norm2(extra_condition_latents)
348 | extra_norm_condition_latents = (
349 | extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None])
350 | + extra_cond_shift_mlp[:, None]
351 | )
352 |
353 | # Feed-forward.
354 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
355 | # 1. hidden_states
356 | ff_output = self.ff(norm_hidden_states)
357 | ff_output = gate_mlp.unsqueeze(1) * ff_output
358 | # 2. encoder_hidden_states
359 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
360 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
361 | # 3. condition_latents
362 | if use_cond:
363 | cond_ff_output = self.ff(norm_condition_latents)
364 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
365 |
366 | if use_extra_cond:
367 | extra_cond_ff_output = self.ff(extra_norm_condition_latents)
368 | extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output
369 |
370 | # Process feed-forward outputs.
371 | hidden_states = hidden_states + ff_output
372 | encoder_hidden_states = encoder_hidden_states + context_ff_output
373 | if use_cond:
374 | condition_latents = condition_latents + cond_ff_output
375 | if use_extra_cond:
376 | extra_condition_latents = extra_condition_latents + extra_cond_ff_output
377 |
378 | # Clip to avoid overflow.
379 | if encoder_hidden_states.dtype == torch.float16:
380 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
381 |
382 | return encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents if use_cond else None
383 |
384 |
385 | def single_block_forward(
386 | self,
387 | hidden_states: torch.FloatTensor,
388 | temb: torch.FloatTensor,
389 | image_rotary_emb=None,
390 | condition_latents: torch.FloatTensor = None,
391 | extra_condition_latents: torch.FloatTensor = None,
392 | cond_temb: torch.FloatTensor = None,
393 | extra_cond_temb: torch.FloatTensor = None,
394 | cond_rotary_emb=None,
395 | extra_cond_rotary_emb=None,
396 | model_config: Optional[Dict[str, Any]] = {},
397 | ):
398 |
399 | using_cond = condition_latents is not None
400 | using_extra_cond = extra_condition_latents is not None
401 | residual = hidden_states
402 | with enable_lora(
403 | (
404 | self.norm.linear,
405 | self.proj_mlp,
406 | ),
407 | model_config.get("latent_lora", False),
408 | ):
409 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
410 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
411 | if using_cond:
412 | residual_cond = condition_latents
413 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
414 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
415 |
416 | if using_extra_cond:
417 | extra_residual_cond = extra_condition_latents
418 | extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb)
419 | extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents))
420 |
421 | attn_output = attn_forward(
422 | self.attn,
423 | model_config=model_config,
424 | hidden_states=norm_hidden_states,
425 | image_rotary_emb=image_rotary_emb,
426 | **(
427 | {
428 | "condition_latents": norm_condition_latents,
429 | "cond_rotary_emb": cond_rotary_emb if using_cond else None,
430 | "extra_condition_latents": extra_norm_condition_latents if using_cond else None,
431 | "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None,
432 | }
433 | if using_cond
434 | else {}
435 | ),
436 | )
437 |
438 | if using_cond:
439 | attn_output, cond_attn_output, extra_cond_attn_output = attn_output
440 |
441 |
442 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
443 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
444 | gate = gate.unsqueeze(1)
445 | hidden_states = gate * self.proj_out(hidden_states)
446 | hidden_states = residual + hidden_states
447 | if using_cond:
448 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
449 | cond_gate = cond_gate.unsqueeze(1)
450 | condition_latents = cond_gate * self.proj_out(condition_latents)
451 | condition_latents = residual_cond + condition_latents
452 |
453 | extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2)
454 | extra_cond_gate = extra_cond_gate.unsqueeze(1)
455 | extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents)
456 | extra_condition_latents = extra_residual_cond + extra_condition_latents
457 |
458 | if hidden_states.dtype == torch.float16:
459 | hidden_states = hidden_states.clip(-65504, 65504)
460 |
461 | return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents)
462 |
--------------------------------------------------------------------------------
/flux/condition.py:
--------------------------------------------------------------------------------
1 | # Recycled from OminiControl and modified to accept an extra condition.
2 | # While ZenCtrl pursued a similar idea, it diverged structurally.
3 | # We appreciate the clarity of Omini's implementation and decided to align with it.
4 |
5 | import torch
6 | from typing import Union, Tuple
7 | from diffusers.pipelines import FluxPipeline
8 | from PIL import Image
9 |
10 |
11 | # from pipeline_tools import encode_images
12 | from .pipeline_tools import encode_images
13 |
14 | condition_dict = {
15 | "subject": 4,
16 | "sr": 10,
17 | "cot": 12,
18 | }
19 |
20 |
21 | class Condition(object):
22 | def __init__(
23 | self,
24 | condition_type: str,
25 | raw_img: Union[Image.Image, torch.Tensor] = None,
26 | condition: Union[Image.Image, torch.Tensor] = None,
27 | position_delta=None,
28 | ) -> None:
29 | self.condition_type = condition_type
30 | assert raw_img is not None or condition is not None
31 | if raw_img is not None:
32 | self.condition = self.get_condition(condition_type, raw_img)
33 | else:
34 | self.condition = condition
35 | self.position_delta = position_delta
36 |
37 |
38 | def get_condition(
39 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
40 | ) -> Union[Image.Image, torch.Tensor]:
41 | """
42 | Returns the condition image.
43 | """
44 | if condition_type == "subject":
45 | return raw_img
46 | elif condition_type == "sr":
47 | return raw_img
48 | elif condition_type == "cot":
49 | return raw_img
50 | return self.condition
51 |
52 |
53 | @property
54 | def type_id(self) -> int:
55 | """
56 | Returns the type id of the condition.
57 | """
58 | return condition_dict[self.condition_type]
59 |
60 | def encode(
61 | self, pipe: FluxPipeline, empty: bool = False
62 | ) -> Tuple[torch.Tensor, torch.Tensor, int]:
63 | """
64 | Encodes the condition into tokens, ids and type_id.
65 | """
66 | if self.condition_type in [
67 | "subject",
68 | "sr",
69 | "cot"
70 | ]:
71 | if empty:
72 | # make the condition black
73 | e_condition = Image.new("RGB", self.condition.size, (0, 0, 0))
74 | e_condition = e_condition.convert("RGB")
75 | tokens, ids = encode_images(pipe, e_condition)
76 | else:
77 | tokens, ids = encode_images(pipe, self.condition)
78 | else:
79 | raise NotImplementedError(
80 | f"Condition type {self.condition_type} not implemented"
81 | )
82 | if self.position_delta is None and self.condition_type == "subject":
83 | self.position_delta = [0, -self.condition.size[0] // 16]
84 | if self.position_delta is not None:
85 | ids[:, 1] += self.position_delta[0]
86 | ids[:, 2] += self.position_delta[1]
87 | type_id = torch.ones_like(ids[:, :1]) * self.type_id
88 | return tokens, ids, type_id
89 |
--------------------------------------------------------------------------------
/flux/generate.py:
--------------------------------------------------------------------------------
1 | # Recycled from OminiControl and modified to accept an extra condition.
2 | # While ZenCtrl pursued a similar idea, it diverged structurally.
3 | # We appreciate the clarity of Omini's implementation and decided to align with it.
4 |
5 | import torch
6 | import yaml, os
7 | from diffusers.pipelines import FluxPipeline
8 | from typing import List, Union, Optional, Dict, Any, Callable
9 | from .transformer import tranformer_forward
10 | from .condition import Condition
11 |
12 |
13 | from diffusers.pipelines.flux.pipeline_flux import (
14 | FluxPipelineOutput,
15 | calculate_shift,
16 | retrieve_timesteps,
17 | np,
18 | )
19 |
20 |
21 | def get_config(config_path: str = None):
22 | config_path = config_path or os.environ.get("XFL_CONFIG")
23 | if not config_path:
24 | return {}
25 | with open(config_path, "r") as f:
26 | config = yaml.safe_load(f)
27 | return config
28 |
29 |
30 | def prepare_params(
31 | prompt: Union[str, List[str]] = None,
32 | prompt_2: Optional[Union[str, List[str]]] = None,
33 | height: Optional[int] = 512,
34 | width: Optional[int] = 512,
35 | num_inference_steps: int = 28,
36 | timesteps: List[int] = None,
37 | guidance_scale: float = 3.5,
38 | num_images_per_prompt: Optional[int] = 1,
39 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
40 | latents: Optional[torch.FloatTensor] = None,
41 | prompt_embeds: Optional[torch.FloatTensor] = None,
42 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
43 | output_type: Optional[str] = "pil",
44 | return_dict: bool = True,
45 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
46 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
47 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
48 | max_sequence_length: int = 512,
49 | **kwargs: dict,
50 | ):
51 | return (
52 | prompt,
53 | prompt_2,
54 | height,
55 | width,
56 | num_inference_steps,
57 | timesteps,
58 | guidance_scale,
59 | num_images_per_prompt,
60 | generator,
61 | latents,
62 | prompt_embeds,
63 | pooled_prompt_embeds,
64 | output_type,
65 | return_dict,
66 | joint_attention_kwargs,
67 | callback_on_step_end,
68 | callback_on_step_end_tensor_inputs,
69 | max_sequence_length,
70 | )
71 |
72 |
73 | def seed_everything(seed: int = 42):
74 | torch.backends.cudnn.deterministic = True
75 | torch.manual_seed(seed)
76 | np.random.seed(seed)
77 |
78 |
79 | @torch.no_grad()
80 | def generate(
81 | pipeline: FluxPipeline,
82 | conditions: List[Condition] = None,
83 | config_path: str = None,
84 | model_config: Optional[Dict[str, Any]] = {},
85 | condition_scale: List[float] = [1, 1],
86 | default_lora: bool = False,
87 | image_guidance_scale: float = 1.0,
88 | **params: dict,
89 | ):
90 | model_config = model_config or get_config(config_path).get("model", {})
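# condition_scale sets a per-condition attention bias: the two values are attached to
# every attention module as `c_factor`, and attn_forward (flux/block.py) converts them
# into log-biases on the attention mask between the image tokens and each condition stream.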
91 | if condition_scale != [1,1]:
92 | for name, module in pipeline.transformer.named_modules():
93 | if not name.endswith(".attn"):
94 | continue
95 | module.c_factor = torch.tensor(condition_scale)
96 |
97 | self = pipeline
98 | (
99 | prompt,
100 | prompt_2,
101 | height,
102 | width,
103 | num_inference_steps,
104 | timesteps,
105 | guidance_scale,
106 | num_images_per_prompt,
107 | generator,
108 | latents,
109 | prompt_embeds,
110 | pooled_prompt_embeds,
111 | output_type,
112 | return_dict,
113 | joint_attention_kwargs,
114 | callback_on_step_end,
115 | callback_on_step_end_tensor_inputs,
116 | max_sequence_length,
117 | ) = prepare_params(**params)
118 |
119 | height = height or self.default_sample_size * self.vae_scale_factor
120 | width = width or self.default_sample_size * self.vae_scale_factor
121 |
122 | # 1. Check inputs. Raise error if not correct
123 | self.check_inputs(
124 | prompt,
125 | prompt_2,
126 | height,
127 | width,
128 | prompt_embeds=prompt_embeds,
129 | pooled_prompt_embeds=pooled_prompt_embeds,
130 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
131 | max_sequence_length=max_sequence_length,
132 | )
133 |
134 | self._guidance_scale = guidance_scale
135 | self._joint_attention_kwargs = joint_attention_kwargs
136 | self._interrupt = False
137 |
138 | # 2. Define call parameters
139 | if prompt is not None and isinstance(prompt, str):
140 | batch_size = 1
141 | elif prompt is not None and isinstance(prompt, list):
142 | batch_size = len(prompt)
143 | else:
144 | batch_size = prompt_embeds.shape[0]
145 |
146 | device = self._execution_device
147 |
148 | lora_scale = (
149 | self.joint_attention_kwargs.get("scale", None)
150 | if self.joint_attention_kwargs is not None
151 | else None
152 | )
153 | (
154 | prompt_embeds,
155 | pooled_prompt_embeds,
156 | text_ids,
157 | ) = self.encode_prompt(
158 | prompt=prompt,
159 | prompt_2=prompt_2,
160 | prompt_embeds=prompt_embeds,
161 | pooled_prompt_embeds=pooled_prompt_embeds,
162 | device=device,
163 | num_images_per_prompt=num_images_per_prompt,
164 | max_sequence_length=max_sequence_length,
165 | lora_scale=lora_scale,
166 | )
167 |
168 | # 4. Prepare latent variables
169 | num_channels_latents = self.transformer.config.in_channels // 4
170 | latents, latent_image_ids = self.prepare_latents(
171 | batch_size * num_images_per_prompt,
172 | num_channels_latents,
173 | height,
174 | width,
175 | prompt_embeds.dtype,
176 | device,
177 | generator,
178 | latents,
179 | )
180 |
181 | # 4.1. Prepare conditions
182 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
183 | extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3))
184 | use_condition = conditions is not None and len(conditions) > 0
185 | if use_condition:
186 | if not default_lora:
187 | pipeline.set_adapters(conditions[1].condition_type)
188 | # for condition in conditions:
189 | tokens, ids, type_id = conditions[0].encode(self)
190 | condition_latents.append(tokens) # [batch_size, token_n, token_dim]
191 | condition_ids.append(ids) # [token_n, id_dim(3)]
192 | condition_type_ids.append(type_id) # [token_n, 1]
193 | condition_latents = torch.cat(condition_latents, dim=1)
194 | condition_ids = torch.cat(condition_ids, dim=0)
195 | condition_type_ids = torch.cat(condition_type_ids, dim=0)
196 |
197 | tokens, ids, type_id = conditions[1].encode(self)
198 | extra_condition_latents.append(tokens) # [batch_size, token_n, token_dim]
199 | extra_condition_ids.append(ids) # [token_n, id_dim(3)]
200 | extra_condition_type_ids.append(type_id) # [token_n, 1]
201 | extra_condition_latents = torch.cat(extra_condition_latents, dim=1)
202 | extra_condition_ids = torch.cat(extra_condition_ids, dim=0)
203 | extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0)
204 |
205 | # 5. Prepare timesteps
206 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
207 | image_seq_len = latents.shape[1]
208 | mu = calculate_shift(
209 | image_seq_len,
210 | self.scheduler.config.base_image_seq_len,
211 | self.scheduler.config.max_image_seq_len,
212 | self.scheduler.config.base_shift,
213 | self.scheduler.config.max_shift,
214 | )
215 | timesteps, num_inference_steps = retrieve_timesteps(
216 | self.scheduler,
217 | num_inference_steps,
218 | device,
219 | timesteps,
220 | sigmas,
221 | mu=mu,
222 | )
223 | num_warmup_steps = max(
224 | len(timesteps) - num_inference_steps * self.scheduler.order, 0
225 | )
226 | self._num_timesteps = len(timesteps)
227 |
228 | # 6. Denoising loop
229 | with self.progress_bar(total=num_inference_steps) as progress_bar:
230 | for i, t in enumerate(timesteps):
231 | if self.interrupt:
232 | continue
233 |
234 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
235 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
236 |
237 | # handle guidance
238 | if self.transformer.config.guidance_embeds:
239 | guidance = torch.tensor([guidance_scale], device=device)
240 | guidance = guidance.expand(latents.shape[0])
241 | else:
242 | guidance = None
243 | noise_pred = tranformer_forward(
244 | self.transformer,
245 | model_config=model_config,
246 | # Inputs of the condition (new feature)
247 | condition_latents=condition_latents if use_condition else None,
248 | condition_ids=condition_ids if use_condition else None,
249 | condition_type_ids=condition_type_ids if use_condition else None,
250 | extra_condition_latents=extra_condition_latents if use_condition else None,
251 | extra_condition_ids=extra_condition_ids if use_condition else None,
252 | extra_condition_type_ids=extra_condition_type_ids if use_condition else None,
253 | # Inputs to the original transformer
254 | hidden_states=latents,
255 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
256 | timestep=timestep / 1000,
257 | guidance=guidance,
258 | pooled_projections=pooled_prompt_embeds,
259 | encoder_hidden_states=prompt_embeds,
260 | txt_ids=text_ids,
261 | img_ids=latent_image_ids,
262 | joint_attention_kwargs=self.joint_attention_kwargs,
263 | return_dict=False,
264 | )[0]
265 |
266 | if image_guidance_scale != 1.0:
267 | uncondition_latents = conditions[0].encode(self, empty=True)[0]
268 | unc_pred = tranformer_forward(
269 | self.transformer,
270 | model_config=model_config,
271 | # Inputs of the condition (new feature)
272 | condition_latents=uncondition_latents if use_condition else None,
273 | condition_ids=condition_ids if use_condition else None,
274 | condition_type_ids=condition_type_ids if use_condition else None,
275 | # Inputs to the original transformer
276 | hidden_states=latents,
278 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
278 | timestep=timestep / 1000,
279 | guidance=torch.ones_like(guidance),
280 | pooled_projections=pooled_prompt_embeds,
281 | encoder_hidden_states=prompt_embeds,
282 | txt_ids=text_ids,
283 | img_ids=latent_image_ids,
284 | joint_attention_kwargs=self.joint_attention_kwargs,
285 | return_dict=False,
286 | )[0]
287 |
288 | noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
289 |
290 | # compute the previous noisy sample x_t -> x_t-1
291 | latents_dtype = latents.dtype
292 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
293 |
294 | if latents.dtype != latents_dtype:
295 | if torch.backends.mps.is_available():
296 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
297 | latents = latents.to(latents_dtype)
298 |
299 | if callback_on_step_end is not None:
300 | callback_kwargs = {}
301 | for k in callback_on_step_end_tensor_inputs:
302 | callback_kwargs[k] = locals()[k]
303 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
304 |
305 | latents = callback_outputs.pop("latents", latents)
306 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
307 |
308 | # call the callback, if provided
309 | if i == len(timesteps) - 1 or (
310 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
311 | ):
312 | progress_bar.update()
313 |
314 | if output_type == "latent":
315 | image = latents
316 |
317 | else:
318 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
319 | latents = (
320 | latents / self.vae.config.scaling_factor
321 | ) + self.vae.config.shift_factor
322 | image = self.vae.decode(latents, return_dict=False)[0]
323 | image = self.image_processor.postprocess(image, output_type=output_type)
324 |
325 | # Offload all models
326 | self.maybe_free_model_hooks()
327 |
328 | if condition_scale != [1,1]:
329 | for name, module in pipeline.transformer.named_modules():
330 | if not name.endswith(".attn"):
331 | continue
332 | del module.c_factor
333 |
334 | if not return_dict:
335 | return (image,)
336 |
337 | return FluxPipelineOutput(images=image)
338 |
--------------------------------------------------------------------------------
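The image-guidance branch in the denoising loop above blends the conditional and unconditional predictions in the usual classifier-free-guidance form: uncond + scale * (cond - uncond). A minimal sketch of that blend with toy tensors standing in for the real predictions (cond, uncond, and scale are illustrative names, not part of the pipeline):

import torch

# toy stand-ins for the conditional / unconditional noise predictions
cond = torch.tensor([1.0, 2.0, 3.0])
uncond = torch.tensor([0.5, 1.0, 1.5])
scale = 1.5  # plays the role of image_guidance_scale

guided = uncond + scale * (cond - uncond)
# scale == 1.0 returns the conditional prediction unchanged (which is why the
# branch above is skipped when image_guidance_scale == 1.0); larger values
# extrapolate away from the unconditional prediction.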
/flux/lora_controller.py:
--------------------------------------------------------------------------------
1 | # As-is from OminiControl
2 | from peft.tuners.tuners_utils import BaseTunerLayer
3 | from typing import List, Any, Optional, Type
4 | from .condition import condition_dict
5 |
6 |
7 | class enable_lora:
8 | def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
9 | self.activated: bool = activated
10 | if activated:
11 | return
12 | self.lora_modules: List[BaseTunerLayer] = [
13 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
14 | ]
15 | self.scales = [
16 | {
17 | active_adapter: lora_module.scaling[active_adapter]
18 | for active_adapter in lora_module.active_adapters
19 | }
20 | for lora_module in self.lora_modules
21 | ]
22 |
23 | def __enter__(self) -> None:
24 | if self.activated:
25 | return
26 |
27 | for lora_module in self.lora_modules:
28 | if not isinstance(lora_module, BaseTunerLayer):
29 | continue
30 | for active_adapter in lora_module.active_adapters:
31 | if (
32 | active_adapter in condition_dict.keys()
33 | or active_adapter == "default"
34 | ):
35 | lora_module.scaling[active_adapter] = 0.0
36 |
37 | def __exit__(
38 | self,
39 | exc_type: Optional[Type[BaseException]],
40 | exc_val: Optional[BaseException],
41 | exc_tb: Optional[Any],
42 | ) -> None:
43 | if self.activated:
44 | return
45 | for i, lora_module in enumerate(self.lora_modules):
46 | if not isinstance(lora_module, BaseTunerLayer):
47 | continue
48 | for active_adapter in lora_module.active_adapters:
49 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
50 |
51 |
52 | class set_lora_scale:
53 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
54 | self.lora_modules: List[BaseTunerLayer] = [
55 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
56 | ]
57 | self.scales = [
58 | {
59 | active_adapter: lora_module.scaling[active_adapter]
60 | for active_adapter in lora_module.active_adapters
61 | }
62 | for lora_module in self.lora_modules
63 | ]
64 | self.scale = scale
65 |
66 | def __enter__(self) -> None:
67 | for lora_module in self.lora_modules:
68 | if not isinstance(lora_module, BaseTunerLayer):
69 | continue
70 | lora_module.scale_layer(self.scale)
71 |
72 | def __exit__(
73 | self,
74 | exc_type: Optional[Type[BaseException]],
75 | exc_val: Optional[BaseException],
76 | exc_tb: Optional[Any],
77 | ) -> None:
78 | for i, lora_module in enumerate(self.lora_modules):
79 | if not isinstance(lora_module, BaseTunerLayer):
80 | continue
81 | for active_adapter in lora_module.active_adapters:
82 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
--------------------------------------------------------------------------------
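Both helpers above are context managers that temporarily rewrite the PEFT scaling table and restore the saved values on exit. A hedged usage sketch, assuming transformer is a PEFT-wrapped FluxTransformer2DModel with LoRA adapters attached (the variable names are illustrative):

from peft.tuners.tuners_utils import BaseTunerLayer
from flux.lora_controller import enable_lora, set_lora_scale

lora_layers = [m for m in transformer.modules() if isinstance(m, BaseTunerLayer)]

# activated=False zeroes the scaling of the condition/default adapters for the
# duration of the block; activated=True makes the context manager a no-op.
with enable_lora(lora_layers, activated=False):
    ...  # forward pass without the condition LoRA contribution

# scale_layer() multiplies every active adapter's scaling; the saved values are
# restored on exit.
with set_lora_scale(lora_layers, scale=0.5):
    ...  # forward pass with the adapters at half strength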
/flux/pipeline_tools.py:
--------------------------------------------------------------------------------
1 | # As-is from OminiControl
2 | from diffusers.pipelines import FluxPipeline
3 | from diffusers.utils import logging
4 | from diffusers.pipelines.flux.pipeline_flux import logger
5 | from torch import Tensor
6 |
7 |
8 | def encode_images(pipeline: FluxPipeline, images: Tensor):
9 | images = pipeline.image_processor.preprocess(images)
10 | images = images.to(pipeline.device).to(pipeline.dtype)
11 | images = pipeline.vae.encode(images).latent_dist.sample()
12 | images = (
13 | images - pipeline.vae.config.shift_factor
14 | ) * pipeline.vae.config.scaling_factor
15 | images_tokens = pipeline._pack_latents(images, *images.shape)
16 | images_ids = pipeline._prepare_latent_image_ids(
17 | images.shape[0],
18 | images.shape[2],
19 | images.shape[3],
20 | pipeline.device,
21 | pipeline.dtype,
22 | )
23 | if images_tokens.shape[1] != images_ids.shape[0]:
24 | images_ids = pipeline._prepare_latent_image_ids(
25 | images.shape[0],
26 | images.shape[2] // 2,
27 | images.shape[3] // 2,
28 | pipeline.device,
29 | pipeline.dtype,
30 | )
31 | return images_tokens, images_ids
32 |
33 |
34 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
35 | # Turn off warnings (CLIP overflow)
36 | logger.setLevel(logging.ERROR)
37 | (
38 | prompt_embeds,
39 | pooled_prompt_embeds,
40 | text_ids,
41 | ) = pipeline.encode_prompt(
42 | prompt=prompts,
43 | prompt_2=None,
44 | prompt_embeds=None,
45 | pooled_prompt_embeds=None,
46 | device=pipeline.device,
47 | num_images_per_prompt=1,
48 | max_sequence_length=max_sequence_length,
49 | lora_scale=None,
50 | )
51 | # Turn on warnings
52 | logger.setLevel(logging.WARNING)
53 | return prompt_embeds, pooled_prompt_embeds, text_ids
54 |
--------------------------------------------------------------------------------
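A minimal sketch of how these two helpers might be wired together. The checkpoint id and image path are placeholders, and encode_images accepts anything image_processor.preprocess understands (PIL images or tensors):

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from flux.pipeline_tools import encode_images, prepare_text_input

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

condition_image = Image.open("samples/1.png").convert("RGB").resize((512, 512))
cond_tokens, cond_ids = encode_images(pipe, condition_image)   # packed latents + position ids
prompt_embeds, pooled_embeds, text_ids = prepare_text_input(pipe, ["a product photo"])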
/flux/transformer.py:
--------------------------------------------------------------------------------
1 | # Recycled from OminiControl and modified to accept an extra condition.
2 | # While ZenCtrl pursued a similar idea, it diverged structurally.
3 | # We appreciate the clarity of OminiControl's implementation and decided to align with it.
4 |
5 | import torch
6 | from typing import Optional, Dict, Any
7 | from .block import block_forward, single_block_forward
8 | from .lora_controller import enable_lora
9 | from accelerate.utils import is_torch_version
10 | from diffusers.models.transformers.transformer_flux import (
11 | FluxTransformer2DModel,
12 | Transformer2DModelOutput,
13 | USE_PEFT_BACKEND,
14 | scale_lora_layers,
15 | unscale_lora_layers,
16 | logger,
17 | )
18 | import numpy as np
19 |
20 |
21 | def prepare_params(
22 | hidden_states: torch.Tensor,
23 | encoder_hidden_states: torch.Tensor = None,
24 | pooled_projections: torch.Tensor = None,
25 | timestep: torch.LongTensor = None,
26 | img_ids: torch.Tensor = None,
27 | txt_ids: torch.Tensor = None,
28 | guidance: torch.Tensor = None,
29 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
30 | controlnet_block_samples=None,
31 | controlnet_single_block_samples=None,
32 | return_dict: bool = True,
33 | **kwargs: dict,
34 | ):
35 | return (
36 | hidden_states,
37 | encoder_hidden_states,
38 | pooled_projections,
39 | timestep,
40 | img_ids,
41 | txt_ids,
42 | guidance,
43 | joint_attention_kwargs,
44 | controlnet_block_samples,
45 | controlnet_single_block_samples,
46 | return_dict,
47 | )
48 |
49 |
50 | def tranformer_forward(
51 | transformer: FluxTransformer2DModel,
52 | condition_latents: torch.Tensor,
53 | extra_condition_latents: torch.Tensor,
54 | condition_ids: torch.Tensor,
55 | condition_type_ids: torch.Tensor,
56 | extra_condition_ids: torch.Tensor,
57 | extra_condition_type_ids: torch.Tensor,
58 | model_config: Optional[Dict[str, Any]] = {},
59 | c_t=0,
60 | **params: dict,
61 | ):
62 | self = transformer
63 | use_condition = condition_latents is not None
64 | use_extra_condition = extra_condition_latents is not None
65 |
66 | (
67 | hidden_states,
68 | encoder_hidden_states,
69 | pooled_projections,
70 | timestep,
71 | img_ids,
72 | txt_ids,
73 | guidance,
74 | joint_attention_kwargs,
75 | controlnet_block_samples,
76 | controlnet_single_block_samples,
77 | return_dict,
78 | ) = prepare_params(**params)
79 |
80 | if joint_attention_kwargs is not None:
81 | joint_attention_kwargs = joint_attention_kwargs.copy()
82 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
83 | else:
84 | lora_scale = 1.0
85 |
86 | if USE_PEFT_BACKEND:
87 | # weight the lora layers by setting `lora_scale` for each PEFT layer
88 | scale_lora_layers(self, lora_scale)
89 | else:
90 | if (
91 | joint_attention_kwargs is not None
92 | and joint_attention_kwargs.get("scale", None) is not None
93 | ):
94 | logger.warning(
95 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
96 | )
97 |
98 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
99 | hidden_states = self.x_embedder(hidden_states)
100 | condition_latents = self.x_embedder(condition_latents) if use_condition else None
101 | extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None
102 |
103 | timestep = timestep.to(hidden_states.dtype) * 1000
104 |
105 | if guidance is not None:
106 | guidance = guidance.to(hidden_states.dtype) * 1000
107 | else:
108 | guidance = None
109 |
110 | temb = (
111 | self.time_text_embed(timestep, pooled_projections)
112 | if guidance is None
113 | else self.time_text_embed(timestep, guidance, pooled_projections)
114 | )
115 |
116 | cond_temb = (
117 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
118 | if guidance is None
119 | else self.time_text_embed(
120 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
121 | )
122 | )
123 | extra_cond_temb = (
124 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
125 | if guidance is None
126 | else self.time_text_embed(
127 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections
128 | )
129 | )
130 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
131 |
132 | if txt_ids.ndim == 3:
133 | logger.warning(
134 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
135 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
136 | )
137 | txt_ids = txt_ids[0]
138 | if img_ids.ndim == 3:
139 | logger.warning(
140 | "Passing `img_ids` 3d torch.Tensor is deprecated."
141 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
142 | )
143 | img_ids = img_ids[0]
144 |
145 | ids = torch.cat((txt_ids, img_ids), dim=0)
146 | image_rotary_emb = self.pos_embed(ids)
147 | if use_condition:
148 | # condition_ids[:, :1] = condition_type_ids
149 | cond_rotary_emb = self.pos_embed(condition_ids)
150 |
151 | if use_extra_condition:
152 | extra_cond_rotary_emb = self.pos_embed(extra_condition_ids)
153 |
154 |
155 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
156 |
157 |
158 | for index_block, block in enumerate(self.transformer_blocks):
159 | if self.training and self.gradient_checkpointing:
160 | ckpt_kwargs: Dict[str, Any] = (
161 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
162 | )
163 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = (
164 | torch.utils.checkpoint.checkpoint(
165 | block_forward,
166 | self=block,
167 | model_config=model_config,
168 | hidden_states=hidden_states,
169 | encoder_hidden_states=encoder_hidden_states,
170 | condition_latents=condition_latents if use_condition else None,
171 | extra_condition_latents=extra_condition_latents if use_extra_condition else None,
172 | temb=temb,
173 | cond_temb=cond_temb if use_condition else None,
174 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
175 | extra_cond_temb=extra_cond_temb if use_extra_condition else None,
176 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
177 | image_rotary_emb=image_rotary_emb,
178 | **ckpt_kwargs,
179 | )
180 | )
181 |
182 | else:
183 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward(
184 | block,
185 | model_config=model_config,
186 | hidden_states=hidden_states,
187 | encoder_hidden_states=encoder_hidden_states,
188 | condition_latents=condition_latents if use_condition else None,
189 | extra_condition_latents=extra_condition_latents if use_extra_condition else None,
190 | temb=temb,
191 | cond_temb=cond_temb if use_condition else None,
192 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
193 | extra_cond_temb=extra_cond_temb if use_extra_condition else None,
194 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None,
195 | image_rotary_emb=image_rotary_emb,
196 | )
197 |
198 | # controlnet residual
199 | if controlnet_block_samples is not None:
200 | interval_control = len(self.transformer_blocks) / len(
201 | controlnet_block_samples
202 | )
203 | interval_control = int(np.ceil(interval_control))
204 | hidden_states = (
205 | hidden_states
206 | + controlnet_block_samples[index_block // interval_control]
207 | )
208 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
209 |
210 |
211 | for index_block, block in enumerate(self.single_transformer_blocks):
212 | if self.training and self.gradient_checkpointing:
213 | ckpt_kwargs: Dict[str, Any] = (
214 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
215 | )
216 | result = torch.utils.checkpoint.checkpoint(
217 | single_block_forward,
218 | self=block,
219 | model_config=model_config,
220 | hidden_states=hidden_states,
221 | temb=temb,
222 | image_rotary_emb=image_rotary_emb,
223 | **(
224 | {
225 | "condition_latents": condition_latents,
226 | "extra_condition_latents": extra_condition_latents,
227 | "cond_temb": cond_temb,
228 | "cond_rotary_emb": cond_rotary_emb,
229 | "extra_cond_temb": extra_cond_temb,
230 | "extra_cond_rotary_emb": extra_cond_rotary_emb,
231 | }
232 | if use_condition
233 | else {}
234 | ),
235 | **ckpt_kwargs,
236 | )
237 |
238 | else:
239 | result = single_block_forward(
240 | block,
241 | model_config=model_config,
242 | hidden_states=hidden_states,
243 | temb=temb,
244 | image_rotary_emb=image_rotary_emb,
245 | **(
246 | {
247 | "condition_latents": condition_latents,
248 | "extra_condition_latents": extra_condition_latents,
249 | "cond_temb": cond_temb,
250 | "cond_rotary_emb": cond_rotary_emb,
251 | "extra_cond_temb": extra_cond_temb,
252 | "extra_cond_rotary_emb": extra_cond_rotary_emb,
253 | }
254 | if use_condition
255 | else {}
256 | ),
257 | )
258 | if use_condition:
259 | hidden_states, condition_latents, extra_condition_latents = result
260 | else:
261 | hidden_states = result
262 |
263 | # controlnet residual
264 | if controlnet_single_block_samples is not None:
265 | interval_control = len(self.single_transformer_blocks) / len(
266 | controlnet_single_block_samples
267 | )
268 | interval_control = int(np.ceil(interval_control))
269 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
270 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
271 | + controlnet_single_block_samples[index_block // interval_control]
272 | )
273 |
274 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
275 |
276 | hidden_states = self.norm_out(hidden_states, temb)
277 | output = self.proj_out(hidden_states)
278 |
279 | if USE_PEFT_BACKEND:
280 | # remove `lora_scale` from each PEFT layer
281 | unscale_lora_layers(self, lora_scale)
282 |
283 | if not return_dict:
284 | return (output,)
285 | return Transformer2DModelOutput(sample=output)
286 |
--------------------------------------------------------------------------------
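One detail worth calling out: the condition branch is always embedded at timestep c_t (0 by default), so condition latents are treated as already-clean images regardless of the current denoising step, and guidance for that branch is pinned to 1. A small sketch of the two embeddings, assuming transformer is a loaded FluxTransformer2DModel with guidance embeddings (all tensors are placeholders):

import torch

c_t = 0
timestep = torch.tensor([0.75])      # current denoising timestep in [0, 1]
guidance = torch.tensor([3.5])       # distilled-guidance scale
pooled = torch.randn(1, 768)         # placeholder CLIP pooled projection

# main branch: the actual timestep, scaled by 1000 as in tranformer_forward
temb = transformer.time_text_embed(timestep * 1000, guidance * 1000, pooled)

# condition branch: a fixed timestep c_t and unit guidance, matching cond_temb above
cond_temb = transformer.time_text_embed(
    torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled
)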
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | diffusers==0.31.0
3 | peft==0.15.2
4 | opencv-python
5 | protobuf
6 | sentencepiece
7 | gradio
8 | jupyter
9 | torchao
--------------------------------------------------------------------------------
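Since the private _pack_latents / _prepare_latent_image_ids helpers and the PEFT scaling internals used above can shift between releases, it may be worth verifying the pinned versions at runtime. A small optional sketch:

from importlib.metadata import version

for pkg, pinned in {"diffusers": "0.31.0", "peft": "0.15.2"}.items():
    installed = version(pkg)
    if installed != pinned:
        print(f"warning: {pkg} {installed} installed; this repo pins {pinned}")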
/samples/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/1.png
--------------------------------------------------------------------------------
/samples/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/12.png
--------------------------------------------------------------------------------
/samples/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/13.png
--------------------------------------------------------------------------------
/samples/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/14.png
--------------------------------------------------------------------------------
/samples/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/18.png
--------------------------------------------------------------------------------
/samples/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/7.png
--------------------------------------------------------------------------------
/samples/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/8.png
--------------------------------------------------------------------------------
/samples/image_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/image_6.png
--------------------------------------------------------------------------------