├── LICENSE ├── README.md ├── app ├── __init__.py └── gradio_app.py ├── flux ├── __init__.py ├── block.py ├── condition.py ├── generate.py ├── lora_controller.py ├── pipeline_tools.py └── transformer.py ├── requirements.txt └── samples ├── 1.png ├── 12.png ├── 13.png ├── 14.png ├── 18.png ├── 7.png ├── 8.png └── image_6.png /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2025] [Fotographer AI Inc] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ZenCtrl Banner 4 | 5 |

ZenCtrl

6 |
7 | 8 | **An all-in-one control framework for unified visual content creation using GenAI.** 9 | Generate multi-view, diverse-scene, and task-specific high-resolution images from a single subject image—without fine-tuning. 10 | 11 |
12 | HuggingFace Model 13 | HuggingFace Space 14 | Discord 15 | LP 16 | X 17 |
18 | 19 | --- 20 | 21 | ## 🧠 Overview 22 | 23 | **ZenCtrl** is a comprehensive toolkit built to tackle core challenges in image generation: 24 | 25 | - No fine-tuning needed – works from **a single subject image** 26 | - Maintains **control over shape, pose, camera angle, context** 27 | - Supports **high-resolution**, multi-scene generation 28 | - Modular toolkit for preprocessing, control, editing, and post-processing tasks 29 | 30 | ZenCtrl is based on OminiControl but enhanced with more fine-grained control, consistent subject preservation, and improved, ready-to-use models. Our goal is to build an **agentic visual generation system** that can orchestrate image/video creation from **LLM-driven recipes.** 31 | 32 |
33 | 34 | ZenCtrl Collage 35 | 36 |
37 | 38 | --- 39 | 40 | ## 🛠 Toolkit Components (coming soon) 41 | 42 | ### 🧹 Preprocessing 43 | 44 | - Background removal 45 | - Matting 46 | - Reshaping 47 | - Segmentation 48 | 49 | ### 🎮 Control Models 50 | 51 | - Shape (HED, Scribble, Depth) 52 | - Pose (OpenPose, DensePose) 53 | - Mask control 54 | - Camera/View control 55 | 56 | ### 🎨 Post-processing 57 | 58 | - Deblurring 59 | - Color fixing 60 | - Natural blending 61 | 62 | ### ✏️ Editing Models 63 | 64 | - Inpainting (removal, masked editing, replacement) 65 | - Outpainting 66 | - Transformation / Motion 67 | - Relighting 68 | 69 | --- 70 | 71 | ## 🎯 Supported Tasks 72 | 73 | - Background generation 74 | - Controlled background generation 75 | - Subject-consistent context-aware generation 76 | - Object and subject placement (coming soon) 77 | - In-context image/video generation (coming soon) 78 | - Multi-object/subject merging & blending (coming soon) 79 | - Video generation (coming soon) 80 | 81 | --- 82 | 83 | ## 📦 Target Use Cases 84 | 85 | - Product photography 86 | - Fashion & accessory try-on 87 | - Virtual try-on (shoes, hats, glasses, etc.) 88 | - People & portrait control 89 | - Illustration, animation, and ad creatives 90 | 91 | All of these tasks can be **mixed and layered** — ZenCtrl is designed to support real-world visual workflows with **agentic task composition**. 92 | 93 | --- 94 | 95 | ## 📢 News 96 | 97 | - **2025-03-24**: 🧠 First release — model weights available on Hugging Face! 98 | - **2025-05-06**: 📢 Update — source code release, latest model weights available on Hugging Face! 99 | - **Coming Soon**: Quick Start guide, Upscaling source code, Example notebooks 100 | - **Next**: Controlled fine-grain version on our platform and API (Pro version) 101 | - **Future**: Video generation toolkit release 102 | 103 | --- 104 | 105 | ## 🚀 Quick Start 106 | 107 | Before running the Gradio app, please install the requirements and download the weights from our HuggingFace repository: 108 | 👉 [https://huggingface.co/fotographerai/zenctrl_tools](https://huggingface.co/fotographerai/zenctrl_tools) 109 | 110 | We aligned our code with the OminiControl structure; unlike OminiControl, our model takes two condition inputs. We will release the original code soon with the LLaMA task driver — so stay tuned. We will also update the tasks for specific verticals (e.g., virtual try-on, ad creatives, etc.). A minimal programmatic example is included at the end of this README. 111 | 112 | --- 113 | 114 | ### Quick Setup (CMD) 115 | 116 | You can follow the step-by-step setup instructions below (`::` marks a comment in cmd): 117 | 118 | ```cmd 119 | :: Clone and set up ZenCtrl 120 | git clone https://github.com/FotographerAI/ZenCtrl.git 121 | cd ZenCtrl 122 | 123 | :: Create a virtual environment 124 | python -m venv venv 125 | call venv\Scripts\activate.bat 126 | 127 | :: Install PyTorch and the requirements 128 | pip install torch==2.7.0+cu128 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 129 | pip install --upgrade pip wheel setuptools 130 | pip install -r requirements.txt 131 | 132 | :: Download the model weights 133 | curl --create-dirs -L https://huggingface.co/fotographerai/zenctrl_tools/resolve/main/weights/zen2con_1440_17000/pytorch_lora_weights.safetensors -o weights\zen2con_1440_17000\pytorch_lora_weights.safetensors 134 | 135 | :: All set! Launch the Gradio app 136 | python app/gradio_app.py 137 | ``` 138 | 139 | --- 140 | 141 | ## 🎨 Demo 142 | 143 | #### Examples 144 | 145 |
146 |
147 | 148 | 149 | 150 | bottle on top of a rock 151 | 152 | 153 | 154 | 155 | bottle on top of a rock 156 | 157 |
158 |
159 | 160 | 161 | 162 | bottle on top of a rock 163 | 164 | 165 | 166 | 167 | bottle on top of a rock 168 | 169 |
170 |
171 | 172 | ### 🧪 Try it now on [Hugging Face Space](https://huggingface.co/spaces/fotographerai/ZenCtrl) 173 | 174 | 180 | 181 | --- 182 | 183 | ## 🔧 Models (Updated Weights Released) 184 | 185 | | Type | Name | Base | Resolution | Description | links | 186 | | --------------------- | --------------------- | ------------ | ---------- | --------------------------------- | ---------------------------------------------------------- | 187 | | Subject Generation | `zen2con_1440_17000` | FLUX.1 | 1024x1024 | Core model for subject-driven gen | [link](https://huggingface.co/fotographerai/zenctrl_tools/tree/main/weights/zen2con_1440_17000) | 188 | | Bg generation + Canny | `bg_canny_58000_1024` | FLUX.1 | 1024x1024 | Enhanced background control | [link](https://huggingface.co/fotographerai/zenctrl_tools) | 189 | | Deblurring Model | `deblurr_1024_10000` | OminiControl | 1024x1024 | Quality recovery post-generation | [link](https://huggingface.co/fotographerai/zenctrl_tools) | 190 | 191 | --- 192 | 193 | ## 🚧 Limitations 194 | 195 | 1. Models currently perform best with **objects**, and to some extent **humans**. 196 | 2. Resolution support is currently capped at **1024x1024** (higher quality coming soon). 197 | 3. Performance with **illustrations** is currently limited. 198 | 4. The models were **not trained on large-scale or highly diverse datasets** yet — we plan to improve quality and variation by training on larger and more diverse datasets, especially for **illustration and stylized content**. 199 | 5. Video support and the full **agentic task pipeline** are still under development. 200 | 201 | --- 202 | 203 | ## 📋 To-do 204 | 205 | - [x] Release early pretrained model weights for defined tasks 206 | - [x] Release additional task-specific models and modes 207 | - [x] Release open source code 208 | - [x] Launch API access via Baseten for easier deployment 209 | - [ ] Release Quick Start guide and example notebooks 210 | - [ ] Launch API access via our app for easier deployment 211 | - [ ] Release high-resolution models (1500×1500+) 212 | - [ ] Enable full toolkit integration with agent API 213 | - [ ] Add video generation module 214 | 215 | --- 216 | 217 | ## 🤝 Join the Community 218 | 219 | - 💬 [Discord](https://discord.com/invite/b9RuYQ3F8k) – share ideas and feedback 220 | - 🌐 [Landing Page](https://fotographer.ai/zen-control) 221 | - 🧪 [Try it now on Hugging Face Space](https://huggingface.co/fotographerai/zenctrl_tools/tree/main/weights) 222 | 223 | 224 | --- 225 | 226 | ## 🤝 Community Collaboration 227 | 228 | We hope to collaborate closely with the open-source community to make **ZenCtrl** a powerful and extensible toolkit for visual content creation. 229 | Once the source code is released, we welcome contributions in training, expanding supported use cases, and developing new task-specific modules. 230 | Our vision is to make ZenCtrl the **standard framework** for agentic, high-quality image and video generation — built together, for everyone. 
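---

### Minimal programmatic example

For readers who prefer to call the model from a script rather than through the Gradio app, the following sketch mirrors what `app/gradio_app.py` does: load FLUX.1-schnell, attach the `zen2con_1440_17000` LoRA, build two `subject` conditions from a single image, and run `flux.generate.generate`. File paths, the sample prompt, and the parameter values are illustrative; adjust them to your setup.

```python
# Minimal sketch mirroring app/gradio_app.py; paths and values are illustrative.
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from flux.condition import Condition
from flux.generate import generate
from flux.lora_controller import set_lora_scale

# Load the base model and the ZenCtrl subject LoRA.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "weights/zen2con_1440_17000/pytorch_lora_weights.safetensors",
    adapter_name="subject",
)
pipe.set_adapters(["subject"])

# One square subject image is reused for both conditions, with opposite
# positional offsets (as in process_image_and_text in app/gradio_app.py).
size = 1024
subject = Image.open("samples/1.png").convert("RGB").resize((size, size))
conditions = [
    Condition("subject", subject, position_delta=(0, size // 16)),
    Condition("subject", subject, position_delta=(0, -size // 16)),
]

with set_lora_scale(["subject"], scale=3.0):
    result = generate(
        pipe,
        prompt="A man sitting in a yellow chair drinking a cup of coffee",
        conditions=conditions,
        num_inference_steps=8,
        height=size,
        width=size,
        condition_scale=[1.0, 1.0],
        model_config={
            "union_cond_attn": True,
            "add_cond_attn": False,
            "latent_lora": False,
            "independent_condition": False,
        },
        default_lora=True,
    ).images[0]

result.save("zenctrl_output.png")
```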
231 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/app/__init__.py -------------------------------------------------------------------------------- /app/gradio_app.py: -------------------------------------------------------------------------------- 1 | # Recycled from Ominicontrol 2 | 3 | import gradio as gr 4 | import torch 5 | from PIL import Image 6 | from diffusers.pipelines import FluxPipeline 7 | from diffusers import FluxTransformer2DModel 8 | 9 | from flux.condition import Condition 10 | from flux.generate import generate 11 | from flux.lora_controller import set_lora_scale 12 | 13 | pipe = None 14 | use_int8 = False 15 | model_config = { "union_cond_attn": True, "add_cond_attn": False, "latent_lora": False, "independent_condition": False} 16 | 17 | def get_gpu_memory(): 18 | return torch.cuda.get_device_properties(0).total_memory / 1024**3 19 | 20 | 21 | def init_pipeline(): 22 | global pipe 23 | if use_int8 or get_gpu_memory() < 33: 24 | transformer_model = FluxTransformer2DModel.from_pretrained( 25 | "sayakpaul/flux.1-schell-int8wo-improved", 26 | torch_dtype=torch.bfloat16, 27 | use_safetensors=False, 28 | ) 29 | pipe = FluxPipeline.from_pretrained( 30 | "black-forest-labs/FLUX.1-schnell", 31 | transformer=transformer_model, 32 | torch_dtype=torch.bfloat16, 33 | ) 34 | else: 35 | pipe = FluxPipeline.from_pretrained( 36 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 37 | ) 38 | pipe = pipe.to("cuda") 39 | 40 | # Optional: load additional LoRA weights here; point the path at the downloaded weights. 41 | pipe.load_lora_weights("weights/zen2con_1440_17000/pytorch_lora_weights.safetensors", 42 | adapter_name="subject") 43 | pipe.set_adapters(["subject"]) 44 | 45 | def paste_on_white_background(image: Image.Image) -> Image.Image: 46 | """ 47 | Pastes a transparent image onto a white background of the same size. 48 | """ 49 | if image.mode != "RGBA": 50 | image = image.convert("RGBA") 51 | 52 | # Create white background 53 | white_bg = Image.new("RGBA", image.size, (255, 255, 255, 255)) 54 | white_bg.paste(image, (0, 0), mask=image) 55 | return white_bg.convert("RGB") # Convert back to RGB if you don't need alpha 56 | 57 | 58 | def process_image_and_text(image, text, steps=8, strength_sub=1.0, strength_spat=1.0, size=1024): 59 | # center crop image 60 | w, h, min_size = image.size[0], image.size[1], min(image.size) 61 | image = image.crop( 62 | ( 63 | (w - min_size) // 2, 64 | (h - min_size) // 2, 65 | (w + min_size) // 2, 66 | (h + min_size) // 2, 67 | ) 68 | ) 69 | image = image.resize((size, size)) 70 | image = paste_on_white_background(image) # Optional: you can remove this line if you want; just make sure the sizes match.
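    # The same subject image is used for both conditions; their position_delta
    # values shift the packed condition token ids in opposite directions
    # (Condition.encode adds the delta to the latent image ids), so the two
    # condition streams are injected at different spatial offsets.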
71 | condition0 = Condition("subject", image, position_delta=(0, size // 16)) 72 | condition1 = Condition("subject", image, position_delta=(0, -size // 16)) 73 | 74 | if pipe is None: 75 | init_pipeline() 76 | 77 | with set_lora_scale(["subject"], scale=3.0): 78 | result_img = generate( 79 | pipe, 80 | prompt=text.strip(), 81 | conditions=[condition0, condition1], 82 | num_inference_steps=steps, 83 | height=1024, 84 | width=1024, 85 | condition_scale = [strength_sub,strength_spat], 86 | model_config=model_config, 87 | default_lora=True, 88 | ).images[0] 89 | 90 | return [condition0.condition, condition1.condition, result_img] 91 | 92 | 93 | def get_samples(): 94 | sample_list = [ 95 | { 96 | "image": "samples/1.png", #place your image path here 97 | "text": "A man sitting in a yellow chair drinking a cup of coffee", 98 | } 99 | ] 100 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list] 101 | 102 | 103 | demo = gr.Interface( 104 | fn=process_image_and_text, 105 | inputs=[ 106 | gr.Image(type="pil"), 107 | gr.Textbox(lines=2), 108 | gr.Slider(minimum=2, maximum=28, value=2, label="steps"), 109 | gr.Slider(minimum=0, maximum=2.0, value=1.0, label="strength_sub"), 110 | gr.Slider(minimum=0, maximum=2.0, value=1.0, label="strength_spat"), 111 | gr.Slider(minimum=512, maximum=2048, value=1024, label="size"), 112 | ], 113 | outputs=gr.Gallery( 114 | label="Outputs", show_label=False, elem_id="gallery", 115 | columns=[3], rows=[1], object_fit="contain", height="auto" 116 | ), 117 | title="ZenCtrl / Subject driven generation", 118 | examples=get_samples(), 119 | ) 120 | 121 | if __name__ == "__main__": 122 | init_pipeline() 123 | demo.launch( 124 | debug=True, 125 | # share=True 126 | ) 127 | -------------------------------------------------------------------------------- /flux/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/flux/__init__.py -------------------------------------------------------------------------------- /flux/block.py: -------------------------------------------------------------------------------- 1 | # Recycled from Ominicontrol and modified to accept an extra condition. 2 | # While Zenctrl pursued a similar idea, it diverged structurally. 3 | # We appreciate the clarity of Omini's implementation and decided to align with it. 4 | 5 | import torch 6 | from typing import Optional, Dict, Any 7 | from diffusers.models.attention_processor import Attention, F 8 | from .lora_controller import enable_lora 9 | from diffusers.models.embeddings import apply_rotary_emb 10 | 11 | def attn_forward( 12 | attn: Attention, 13 | hidden_states: torch.FloatTensor, 14 | encoder_hidden_states: torch.FloatTensor = None, 15 | condition_latents: torch.FloatTensor = None, 16 | extra_condition_latents: torch.FloatTensor = None, 17 | attention_mask: Optional[torch.FloatTensor] = None, 18 | image_rotary_emb: Optional[torch.Tensor] = None, 19 | cond_rotary_emb: Optional[torch.Tensor] = None, 20 | extra_cond_rotary_emb: Optional[torch.Tensor] = None, 21 | model_config: Optional[Dict[str, Any]] = {}, 22 | ) -> torch.FloatTensor: 23 | batch_size, _, _ = ( 24 | hidden_states.shape 25 | if encoder_hidden_states is None 26 | else encoder_hidden_states.shape 27 | ) 28 | 29 | with enable_lora( 30 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) 31 | ): 32 | # `sample` projections. 
33 | query = attn.to_q(hidden_states) 34 | key = attn.to_k(hidden_states) 35 | value = attn.to_v(hidden_states) 36 | 37 | inner_dim = key.shape[-1] 38 | head_dim = inner_dim // attn.heads 39 | 40 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 41 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 42 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 43 | 44 | if attn.norm_q is not None: 45 | query = attn.norm_q(query) 46 | if attn.norm_k is not None: 47 | key = attn.norm_k(key) 48 | 49 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 50 | if encoder_hidden_states is not None: 51 | # `context` projections. 52 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 53 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 54 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 55 | 56 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 57 | batch_size, -1, attn.heads, head_dim 58 | ).transpose(1, 2) 59 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 60 | batch_size, -1, attn.heads, head_dim 61 | ).transpose(1, 2) 62 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 63 | batch_size, -1, attn.heads, head_dim 64 | ).transpose(1, 2) 65 | 66 | if attn.norm_added_q is not None: 67 | encoder_hidden_states_query_proj = attn.norm_added_q( 68 | encoder_hidden_states_query_proj 69 | ) 70 | if attn.norm_added_k is not None: 71 | encoder_hidden_states_key_proj = attn.norm_added_k( 72 | encoder_hidden_states_key_proj 73 | ) 74 | 75 | # attention 76 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 77 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 78 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 79 | 80 | if image_rotary_emb is not None: 81 | 82 | 83 | query = apply_rotary_emb(query, image_rotary_emb) 84 | key = apply_rotary_emb(key, image_rotary_emb) 85 | 86 | if condition_latents is not None: 87 | cond_query = attn.to_q(condition_latents) 88 | cond_key = attn.to_k(condition_latents) 89 | cond_value = attn.to_v(condition_latents) 90 | 91 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( 92 | 1, 2 93 | ) 94 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 95 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( 96 | 1, 2 97 | ) 98 | if attn.norm_q is not None: 99 | cond_query = attn.norm_q(cond_query) 100 | if attn.norm_k is not None: 101 | cond_key = attn.norm_k(cond_key) 102 | 103 | #extra condition 104 | if extra_condition_latents is not None: 105 | extra_cond_query = attn.to_q(extra_condition_latents) 106 | extra_cond_key = attn.to_k(extra_condition_latents) 107 | extra_cond_value = attn.to_v(extra_condition_latents) 108 | 109 | extra_cond_query = extra_cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( 110 | 1, 2 111 | ) 112 | extra_cond_key = extra_cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 113 | extra_cond_value = extra_cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( 114 | 1, 2 115 | ) 116 | if attn.norm_q is not None: 117 | extra_cond_query = attn.norm_q(extra_cond_query) 118 | if attn.norm_k is not None: 119 | extra_cond_key = attn.norm_k(extra_cond_key) 120 | 121 | 122 | if extra_cond_rotary_emb is not None: 123 | extra_cond_query = 
apply_rotary_emb(extra_cond_query, extra_cond_rotary_emb) 124 | extra_cond_key = apply_rotary_emb(extra_cond_key, extra_cond_rotary_emb) 125 | 126 | if cond_rotary_emb is not None: 127 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) 128 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) 129 | 130 | if condition_latents is not None: 131 | if extra_condition_latents is not None: 132 | 133 | query = torch.cat([query, cond_query, extra_cond_query], dim=2) 134 | key = torch.cat([key, cond_key, extra_cond_key], dim=2) 135 | value = torch.cat([value, cond_value, extra_cond_value], dim=2) 136 | else: 137 | query = torch.cat([query, cond_query], dim=2) 138 | key = torch.cat([key, cond_key], dim=2) 139 | value = torch.cat([value, cond_value], dim=2) 140 | print("concat Omini latents: ", query.shape, key.shape, value.shape) 141 | 142 | 143 | if not model_config.get("union_cond_attn", True): 144 | 145 | attention_mask = torch.ones( 146 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool 147 | ) 148 | condition_n = cond_query.shape[2] 149 | attention_mask[-condition_n:, :-condition_n] = False 150 | attention_mask[:-condition_n, -condition_n:] = False 151 | elif model_config.get("independent_condition", False): 152 | attention_mask = torch.ones( 153 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool 154 | ) 155 | condition_n = cond_query.shape[2] 156 | attention_mask[-condition_n:, :-condition_n] = False 157 | 158 | if hasattr(attn, "c_factor"): 159 | attention_mask = torch.zeros( 160 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype 161 | ) 162 | condition_n = cond_query.shape[2] 163 | condition_e = extra_cond_query.shape[2] 164 | bias = torch.log(attn.c_factor[0]) 165 | attention_mask[-condition_n-condition_e:-condition_e, :-condition_n-condition_e] = bias 166 | attention_mask[:-condition_n-condition_e, -condition_n-condition_e:-condition_e] = bias 167 | 168 | bias = torch.log(attn.c_factor[1]) 169 | attention_mask[-condition_e:, :-condition_n-condition_e] = bias 170 | attention_mask[:-condition_n-condition_e, -condition_e:] = bias 171 | 172 | hidden_states = F.scaled_dot_product_attention( 173 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask 174 | ) 175 | hidden_states = hidden_states.transpose(1, 2).reshape( 176 | batch_size, -1, attn.heads * head_dim 177 | ) 178 | hidden_states = hidden_states.to(query.dtype) 179 | 180 | if encoder_hidden_states is not None: 181 | if condition_latents is not None: 182 | if extra_condition_latents is not None: 183 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = ( 184 | hidden_states[:, : encoder_hidden_states.shape[1]], 185 | hidden_states[ 186 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]*2 187 | ], 188 | hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]], 189 | hidden_states[:, -condition_latents.shape[1] :], #extra condition latents 190 | ) 191 | else: 192 | encoder_hidden_states, hidden_states, condition_latents = ( 193 | hidden_states[:, : encoder_hidden_states.shape[1]], 194 | hidden_states[ 195 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] 196 | ], 197 | hidden_states[:, -condition_latents.shape[1] :] 198 | ) 199 | else: 200 | encoder_hidden_states, hidden_states = ( 201 | hidden_states[:, : encoder_hidden_states.shape[1]], 202 | hidden_states[:, encoder_hidden_states.shape[1] :], 203 | ) 204 | 205 | with enable_lora((attn.to_out[0],), 
model_config.get("latent_lora", False)): 206 | # linear proj 207 | hidden_states = attn.to_out[0](hidden_states) 208 | # dropout 209 | hidden_states = attn.to_out[1](hidden_states) 210 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 211 | 212 | if condition_latents is not None: 213 | condition_latents = attn.to_out[0](condition_latents) 214 | condition_latents = attn.to_out[1](condition_latents) 215 | 216 | if extra_condition_latents is not None: 217 | extra_condition_latents = attn.to_out[0](extra_condition_latents) 218 | extra_condition_latents = attn.to_out[1](extra_condition_latents) 219 | 220 | 221 | return ( 222 | # (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents) 223 | (hidden_states, encoder_hidden_states, condition_latents, extra_condition_latents) 224 | if condition_latents is not None 225 | else (hidden_states, encoder_hidden_states) 226 | ) 227 | elif condition_latents is not None: 228 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents 229 | if extra_condition_latents is not None: 230 | hidden_states, condition_latents, extra_condition_latents = ( 231 | hidden_states[:, : -condition_latents.shape[1]*2], 232 | hidden_states[:, -condition_latents.shape[1]*2 :-condition_latents.shape[1]], 233 | hidden_states[:, -condition_latents.shape[1] :], 234 | ) 235 | else: 236 | hidden_states, condition_latents = ( 237 | hidden_states[:, : -condition_latents.shape[1]], 238 | hidden_states[:, -condition_latents.shape[1] :], 239 | ) 240 | return hidden_states, condition_latents, extra_condition_latents 241 | else: 242 | return hidden_states 243 | 244 | 245 | def block_forward( 246 | self, 247 | hidden_states: torch.FloatTensor, 248 | encoder_hidden_states: torch.FloatTensor, 249 | condition_latents: torch.FloatTensor, 250 | extra_condition_latents: torch.FloatTensor, 251 | temb: torch.FloatTensor, 252 | cond_temb: torch.FloatTensor, 253 | extra_cond_temb: torch.FloatTensor, 254 | cond_rotary_emb=None, 255 | extra_cond_rotary_emb=None, 256 | image_rotary_emb=None, 257 | model_config: Optional[Dict[str, Any]] = {}, 258 | ): 259 | use_cond = condition_latents is not None 260 | 261 | use_extra_cond = extra_condition_latents is not None 262 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): 263 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 264 | hidden_states, emb=temb 265 | ) 266 | 267 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( 268 | self.norm1_context(encoder_hidden_states, emb=temb) 269 | ) 270 | 271 | if use_cond: 272 | ( 273 | norm_condition_latents, 274 | cond_gate_msa, 275 | cond_shift_mlp, 276 | cond_scale_mlp, 277 | cond_gate_mlp, 278 | ) = self.norm1(condition_latents, emb=cond_temb) 279 | ( 280 | norm_extra_condition_latents, 281 | extra_cond_gate_msa, 282 | extra_cond_shift_mlp, 283 | extra_cond_scale_mlp, 284 | extra_cond_gate_mlp, 285 | ) = self.norm1(extra_condition_latents, emb=extra_cond_temb) 286 | 287 | # Attention. 
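    # attn_forward concatenates the text, image, condition and extra-condition
    # tokens into a single attention sequence; if generate() has attached
    # attn.c_factor (condition_scale != [1, 1]), a log-bias attention mask
    # rescales how strongly the main tokens and each condition stream attend
    # to one another.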
288 | result = attn_forward( 289 | self.attn, 290 | model_config=model_config, 291 | hidden_states=norm_hidden_states, 292 | encoder_hidden_states=norm_encoder_hidden_states, 293 | condition_latents=norm_condition_latents if use_cond else None, 294 | extra_condition_latents=norm_extra_condition_latents if use_cond else None, 295 | image_rotary_emb=image_rotary_emb, 296 | cond_rotary_emb=cond_rotary_emb if use_cond else None, 297 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_cond else None, 298 | ) 299 | 300 | attn_output, context_attn_output = result[:2] 301 | cond_attn_output = result[2] if use_cond else None 302 | extra_condition_output = result[3] 303 | 304 | # Process attention outputs for the `hidden_states`. 305 | # 1. hidden_states 306 | attn_output = gate_msa.unsqueeze(1) * attn_output 307 | hidden_states = hidden_states + attn_output 308 | # 2. encoder_hidden_states 309 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 310 | 311 | encoder_hidden_states = encoder_hidden_states + context_attn_output 312 | # 3. condition_latents 313 | if use_cond: 314 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 315 | condition_latents = condition_latents + cond_attn_output 316 | #need to make new condition_extra and add extra_condition_output 317 | if use_extra_cond: 318 | extra_condition_output = extra_cond_gate_msa.unsqueeze(1) * extra_condition_output 319 | extra_condition_latents = extra_condition_latents + extra_condition_output 320 | 321 | if model_config.get("add_cond_attn", False): 322 | hidden_states += cond_attn_output 323 | hidden_states += extra_condition_output 324 | 325 | 326 | # LayerNorm + MLP. 327 | # 1. hidden_states 328 | norm_hidden_states = self.norm2(hidden_states) 329 | norm_hidden_states = ( 330 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 331 | ) 332 | # 2. encoder_hidden_states 333 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 334 | norm_encoder_hidden_states = ( 335 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 336 | ) 337 | # 3. condition_latents 338 | if use_cond: 339 | norm_condition_latents = self.norm2(condition_latents) 340 | norm_condition_latents = ( 341 | norm_condition_latents * (1 + cond_scale_mlp[:, None]) 342 | + cond_shift_mlp[:, None] 343 | ) 344 | 345 | if use_extra_cond: 346 | #added conditions 347 | extra_norm_condition_latents = self.norm2(extra_condition_latents) 348 | extra_norm_condition_latents = ( 349 | extra_norm_condition_latents * (1 + extra_cond_scale_mlp[:, None]) 350 | + extra_cond_shift_mlp[:, None] 351 | ) 352 | 353 | # Feed-forward. 354 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): 355 | # 1. hidden_states 356 | ff_output = self.ff(norm_hidden_states) 357 | ff_output = gate_mlp.unsqueeze(1) * ff_output 358 | # 2. encoder_hidden_states 359 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 360 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output 361 | # 3. condition_latents 362 | if use_cond: 363 | cond_ff_output = self.ff(norm_condition_latents) 364 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 365 | 366 | if use_extra_cond: 367 | extra_cond_ff_output = self.ff(extra_norm_condition_latents) 368 | extra_cond_ff_output = extra_cond_gate_mlp.unsqueeze(1) * extra_cond_ff_output 369 | 370 | # Process feed-forward outputs. 
371 | hidden_states = hidden_states + ff_output 372 | encoder_hidden_states = encoder_hidden_states + context_ff_output 373 | if use_cond: 374 | condition_latents = condition_latents + cond_ff_output 375 | if use_extra_cond: 376 | extra_condition_latents = extra_condition_latents + extra_cond_ff_output 377 | 378 | # Clip to avoid overflow. 379 | if encoder_hidden_states.dtype == torch.float16: 380 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 381 | 382 | return encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents if use_cond else None 383 | 384 | 385 | def single_block_forward( 386 | self, 387 | hidden_states: torch.FloatTensor, 388 | temb: torch.FloatTensor, 389 | image_rotary_emb=None, 390 | condition_latents: torch.FloatTensor = None, 391 | extra_condition_latents: torch.FloatTensor = None, 392 | cond_temb: torch.FloatTensor = None, 393 | extra_cond_temb: torch.FloatTensor = None, 394 | cond_rotary_emb=None, 395 | extra_cond_rotary_emb=None, 396 | model_config: Optional[Dict[str, Any]] = {}, 397 | ): 398 | 399 | using_cond = condition_latents is not None 400 | using_extra_cond = extra_condition_latents is not None 401 | residual = hidden_states 402 | with enable_lora( 403 | ( 404 | self.norm.linear, 405 | self.proj_mlp, 406 | ), 407 | model_config.get("latent_lora", False), 408 | ): 409 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 410 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 411 | if using_cond: 412 | residual_cond = condition_latents 413 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) 414 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) 415 | 416 | if using_extra_cond: 417 | extra_residual_cond = extra_condition_latents 418 | extra_norm_condition_latents, extra_cond_gate = self.norm(extra_condition_latents, emb=extra_cond_temb) 419 | extra_mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(extra_norm_condition_latents)) 420 | 421 | attn_output = attn_forward( 422 | self.attn, 423 | model_config=model_config, 424 | hidden_states=norm_hidden_states, 425 | image_rotary_emb=image_rotary_emb, 426 | **( 427 | { 428 | "condition_latents": norm_condition_latents, 429 | "cond_rotary_emb": cond_rotary_emb if using_cond else None, 430 | "extra_condition_latents": extra_norm_condition_latents if using_cond else None, 431 | "extra_cond_rotary_emb": extra_cond_rotary_emb if using_cond else None, 432 | } 433 | if using_cond 434 | else {} 435 | ), 436 | ) 437 | 438 | if using_cond: 439 | attn_output, cond_attn_output, extra_cond_attn_output = attn_output 440 | 441 | 442 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): 443 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 444 | gate = gate.unsqueeze(1) 445 | hidden_states = gate * self.proj_out(hidden_states) 446 | hidden_states = residual + hidden_states 447 | if using_cond: 448 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 449 | cond_gate = cond_gate.unsqueeze(1) 450 | condition_latents = cond_gate * self.proj_out(condition_latents) 451 | condition_latents = residual_cond + condition_latents 452 | 453 | extra_condition_latents = torch.cat([extra_cond_attn_output, extra_mlp_cond_hidden_states], dim=2) 454 | extra_cond_gate = extra_cond_gate.unsqueeze(1) 455 | extra_condition_latents = extra_cond_gate * self.proj_out(extra_condition_latents) 456 | extra_condition_latents = extra_residual_cond + 
extra_condition_latents 457 | 458 | if hidden_states.dtype == torch.float16: 459 | hidden_states = hidden_states.clip(-65504, 65504) 460 | 461 | return hidden_states if not using_cond else (hidden_states, condition_latents, extra_condition_latents) 462 | -------------------------------------------------------------------------------- /flux/condition.py: -------------------------------------------------------------------------------- 1 | # Recycled from Ominicontrol and modified to accept an extra condition. 2 | # While Zenctrl pursued a similar idea, it diverged structurally. 3 | # We appreciate the clarity of Omini's implementation and decided to align with it. 4 | 5 | import torch 6 | from typing import Union, Tuple 7 | from diffusers.pipelines import FluxPipeline 8 | from PIL import Image 9 | 10 | 11 | # from pipeline_tools import encode_images 12 | from .pipeline_tools import encode_images 13 | 14 | condition_dict = { 15 | "subject": 4, 16 | "sr": 10, 17 | "cot": 12, 18 | } 19 | 20 | 21 | class Condition(object): 22 | def __init__( 23 | self, 24 | condition_type: str, 25 | raw_img: Union[Image.Image, torch.Tensor] = None, 26 | condition: Union[Image.Image, torch.Tensor] = None, 27 | position_delta=None, 28 | ) -> None: 29 | self.condition_type = condition_type 30 | assert raw_img is not None or condition is not None 31 | if raw_img is not None: 32 | self.condition = self.get_condition(condition_type, raw_img) 33 | else: 34 | self.condition = condition 35 | self.position_delta = position_delta 36 | 37 | 38 | def get_condition( 39 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] 40 | ) -> Union[Image.Image, torch.Tensor]: 41 | """ 42 | Returns the condition image. 43 | """ 44 | if condition_type == "subject": 45 | return raw_img 46 | elif condition_type == "sr": 47 | return raw_img 48 | elif condition_type == "cot": 49 | return raw_img 50 | return self.condition 51 | 52 | 53 | @property 54 | def type_id(self) -> int: 55 | """ 56 | Returns the type id of the condition. 57 | """ 58 | return condition_dict[self.condition_type] 59 | 60 | def encode( 61 | self, pipe: FluxPipeline, empty: bool = False 62 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 63 | """ 64 | Encodes the condition into tokens, ids and type_id. 65 | """ 66 | if self.condition_type in [ 67 | "subject", 68 | "sr", 69 | "cot" 70 | ]: 71 | if empty: 72 | # make the condition black 73 | e_condition = Image.new("RGB", self.condition.size, (0, 0, 0)) 74 | e_condition = e_condition.convert("RGB") 75 | tokens, ids = encode_images(pipe, e_condition) 76 | else: 77 | tokens, ids = encode_images(pipe, self.condition) 78 | else: 79 | raise NotImplementedError( 80 | f"Condition type {self.condition_type} not implemented" 81 | ) 82 | if self.position_delta is None and self.condition_type == "subject": 83 | self.position_delta = [0, -self.condition.size[0] // 16] 84 | if self.position_delta is not None: 85 | ids[:, 1] += self.position_delta[0] 86 | ids[:, 2] += self.position_delta[1] 87 | type_id = torch.ones_like(ids[:, :1]) * self.type_id 88 | return tokens, ids, type_id 89 | -------------------------------------------------------------------------------- /flux/generate.py: -------------------------------------------------------------------------------- 1 | # Recycled from Ominicontrol and modified to accept an extra condition. 2 | # While Zenctrl pursued a similar idea, it diverged structurally. 3 | # We appreciate the clarity of Omini's implementation and decided to align with it. 
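# Flow overview: generate() mirrors FluxPipeline.__call__, but additionally
# encodes conditions[0] and conditions[1] into condition_latents and
# extra_condition_latents, which are fed to tranformer_forward at every
# denoising step. A condition_scale other than [1, 1] temporarily attaches a
# c_factor tensor to the attention modules to bias attention between the
# image tokens and each condition stream.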
4 | 5 | import torch 6 | import yaml, os 7 | from diffusers.pipelines import FluxPipeline 8 | from typing import List, Union, Optional, Dict, Any, Callable 9 | from .transformer import tranformer_forward 10 | from .condition import Condition 11 | 12 | 13 | from diffusers.pipelines.flux.pipeline_flux import ( 14 | FluxPipelineOutput, 15 | calculate_shift, 16 | retrieve_timesteps, 17 | np, 18 | ) 19 | 20 | 21 | def get_config(config_path: str = None): 22 | config_path = config_path or os.environ.get("XFL_CONFIG") 23 | if not config_path: 24 | return {} 25 | with open(config_path, "r") as f: 26 | config = yaml.safe_load(f) 27 | return config 28 | 29 | 30 | def prepare_params( 31 | prompt: Union[str, List[str]] = None, 32 | prompt_2: Optional[Union[str, List[str]]] = None, 33 | height: Optional[int] = 512, 34 | width: Optional[int] = 512, 35 | num_inference_steps: int = 28, 36 | timesteps: List[int] = None, 37 | guidance_scale: float = 3.5, 38 | num_images_per_prompt: Optional[int] = 1, 39 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 40 | latents: Optional[torch.FloatTensor] = None, 41 | prompt_embeds: Optional[torch.FloatTensor] = None, 42 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 43 | output_type: Optional[str] = "pil", 44 | return_dict: bool = True, 45 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 46 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 47 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 48 | max_sequence_length: int = 512, 49 | **kwargs: dict, 50 | ): 51 | return ( 52 | prompt, 53 | prompt_2, 54 | height, 55 | width, 56 | num_inference_steps, 57 | timesteps, 58 | guidance_scale, 59 | num_images_per_prompt, 60 | generator, 61 | latents, 62 | prompt_embeds, 63 | pooled_prompt_embeds, 64 | output_type, 65 | return_dict, 66 | joint_attention_kwargs, 67 | callback_on_step_end, 68 | callback_on_step_end_tensor_inputs, 69 | max_sequence_length, 70 | ) 71 | 72 | 73 | def seed_everything(seed: int = 42): 74 | torch.backends.cudnn.deterministic = True 75 | torch.manual_seed(seed) 76 | np.random.seed(seed) 77 | 78 | 79 | @torch.no_grad() 80 | def generate( 81 | pipeline: FluxPipeline, 82 | conditions: List[Condition] = None, 83 | config_path: str = None, 84 | model_config: Optional[Dict[str, Any]] = {}, 85 | condition_scale: float = [1, 1], 86 | default_lora: bool = False, 87 | image_guidance_scale: float = 1.0, 88 | **params: dict, 89 | ): 90 | model_config = model_config or get_config(config_path).get("model", {}) 91 | if condition_scale != [1,1]: 92 | for name, module in pipeline.transformer.named_modules(): 93 | if not name.endswith(".attn"): 94 | continue 95 | module.c_factor = torch.tensor(condition_scale) 96 | 97 | self = pipeline 98 | ( 99 | prompt, 100 | prompt_2, 101 | height, 102 | width, 103 | num_inference_steps, 104 | timesteps, 105 | guidance_scale, 106 | num_images_per_prompt, 107 | generator, 108 | latents, 109 | prompt_embeds, 110 | pooled_prompt_embeds, 111 | output_type, 112 | return_dict, 113 | joint_attention_kwargs, 114 | callback_on_step_end, 115 | callback_on_step_end_tensor_inputs, 116 | max_sequence_length, 117 | ) = prepare_params(**params) 118 | 119 | height = height or self.default_sample_size * self.vae_scale_factor 120 | width = width or self.default_sample_size * self.vae_scale_factor 121 | 122 | # 1. Check inputs. 
Raise error if not correct 123 | self.check_inputs( 124 | prompt, 125 | prompt_2, 126 | height, 127 | width, 128 | prompt_embeds=prompt_embeds, 129 | pooled_prompt_embeds=pooled_prompt_embeds, 130 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 131 | max_sequence_length=max_sequence_length, 132 | ) 133 | 134 | self._guidance_scale = guidance_scale 135 | self._joint_attention_kwargs = joint_attention_kwargs 136 | self._interrupt = False 137 | 138 | # 2. Define call parameters 139 | if prompt is not None and isinstance(prompt, str): 140 | batch_size = 1 141 | elif prompt is not None and isinstance(prompt, list): 142 | batch_size = len(prompt) 143 | else: 144 | batch_size = prompt_embeds.shape[0] 145 | 146 | device = self._execution_device 147 | 148 | lora_scale = ( 149 | self.joint_attention_kwargs.get("scale", None) 150 | if self.joint_attention_kwargs is not None 151 | else None 152 | ) 153 | ( 154 | prompt_embeds, 155 | pooled_prompt_embeds, 156 | text_ids, 157 | ) = self.encode_prompt( 158 | prompt=prompt, 159 | prompt_2=prompt_2, 160 | prompt_embeds=prompt_embeds, 161 | pooled_prompt_embeds=pooled_prompt_embeds, 162 | device=device, 163 | num_images_per_prompt=num_images_per_prompt, 164 | max_sequence_length=max_sequence_length, 165 | lora_scale=lora_scale, 166 | ) 167 | 168 | # 4. Prepare latent variables 169 | num_channels_latents = self.transformer.config.in_channels // 4 170 | latents, latent_image_ids = self.prepare_latents( 171 | batch_size * num_images_per_prompt, 172 | num_channels_latents, 173 | height, 174 | width, 175 | prompt_embeds.dtype, 176 | device, 177 | generator, 178 | latents, 179 | ) 180 | 181 | # 4.1. Prepare conditions 182 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) 183 | extra_condition_latents, extra_condition_ids, extra_condition_type_ids = ([] for _ in range(3)) 184 | use_condition = conditions is not None or [] 185 | if use_condition: 186 | if not default_lora: 187 | pipeline.set_adapters(conditions[1].condition_type) 188 | # for condition in conditions: 189 | tokens, ids, type_id = conditions[0].encode(self) 190 | condition_latents.append(tokens) # [batch_size, token_n, token_dim] 191 | condition_ids.append(ids) # [token_n, id_dim(3)] 192 | condition_type_ids.append(type_id) # [token_n, 1] 193 | condition_latents = torch.cat(condition_latents, dim=1) 194 | condition_ids = torch.cat(condition_ids, dim=0) 195 | condition_type_ids = torch.cat(condition_type_ids, dim=0) 196 | 197 | tokens, ids, type_id = conditions[1].encode(self) 198 | extra_condition_latents.append(tokens) # [batch_size, token_n, token_dim] 199 | extra_condition_ids.append(ids) # [token_n, id_dim(3)] 200 | extra_condition_type_ids.append(type_id) # [token_n, 1] 201 | extra_condition_latents = torch.cat(extra_condition_latents, dim=1) 202 | extra_condition_ids = torch.cat(extra_condition_ids, dim=0) 203 | extra_condition_type_ids = torch.cat(extra_condition_type_ids, dim=0) 204 | 205 | # 5. 
Prepare timesteps 206 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 207 | image_seq_len = latents.shape[1] 208 | mu = calculate_shift( 209 | image_seq_len, 210 | self.scheduler.config.base_image_seq_len, 211 | self.scheduler.config.max_image_seq_len, 212 | self.scheduler.config.base_shift, 213 | self.scheduler.config.max_shift, 214 | ) 215 | timesteps, num_inference_steps = retrieve_timesteps( 216 | self.scheduler, 217 | num_inference_steps, 218 | device, 219 | timesteps, 220 | sigmas, 221 | mu=mu, 222 | ) 223 | num_warmup_steps = max( 224 | len(timesteps) - num_inference_steps * self.scheduler.order, 0 225 | ) 226 | self._num_timesteps = len(timesteps) 227 | 228 | # 6. Denoising loop 229 | with self.progress_bar(total=num_inference_steps) as progress_bar: 230 | for i, t in enumerate(timesteps): 231 | if self.interrupt: 232 | continue 233 | 234 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 235 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 236 | 237 | # handle guidance 238 | if self.transformer.config.guidance_embeds: 239 | guidance = torch.tensor([guidance_scale], device=device) 240 | guidance = guidance.expand(latents.shape[0]) 241 | else: 242 | guidance = None 243 | noise_pred = tranformer_forward( 244 | self.transformer, 245 | model_config=model_config, 246 | # Inputs of the condition (new feature) 247 | condition_latents=condition_latents if use_condition else None, 248 | condition_ids=condition_ids if use_condition else None, 249 | condition_type_ids=condition_type_ids if use_condition else None, 250 | extra_condition_latents=extra_condition_latents if use_condition else None, 251 | extra_condition_ids=extra_condition_ids if use_condition else None, 252 | extra_condition_type_ids=extra_condition_type_ids if use_condition else None, 253 | # Inputs to the original transformer 254 | hidden_states=latents, 255 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 256 | timestep=timestep / 1000, 257 | guidance=guidance, 258 | pooled_projections=pooled_prompt_embeds, 259 | encoder_hidden_states=prompt_embeds, 260 | txt_ids=text_ids, 261 | img_ids=latent_image_ids, 262 | joint_attention_kwargs=self.joint_attention_kwargs, 263 | return_dict=False, 264 | )[0] 265 | 266 | if image_guidance_scale != 1.0: 267 | uncondition_latents = conditions.encode(self, empty=True)[0] 268 | unc_pred = tranformer_forward( 269 | self.transformer, 270 | model_config=model_config, 271 | # Inputs of the condition (new feature) 272 | condition_latents=uncondition_latents if use_condition else None, 273 | condition_ids=condition_ids if use_condition else None, 274 | condition_type_ids=condition_type_ids if use_condition else None, 275 | # Inputs to the original transformer 276 | hidden_states=latents, 277 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 278 | timestep=timestep / 1000, 279 | guidance=torch.ones_like(guidance), 280 | pooled_projections=pooled_prompt_embeds, 281 | encoder_hidden_states=prompt_embeds, 282 | txt_ids=text_ids, 283 | img_ids=latent_image_ids, 284 | joint_attention_kwargs=self.joint_attention_kwargs, 285 | return_dict=False, 286 | )[0] 287 | 288 | noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred) 289 | 290 | # compute the previous noisy sample x_t 
-> x_t-1 291 | latents_dtype = latents.dtype 292 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 293 | 294 | if latents.dtype != latents_dtype: 295 | if torch.backends.mps.is_available(): 296 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 297 | latents = latents.to(latents_dtype) 298 | 299 | if callback_on_step_end is not None: 300 | callback_kwargs = {} 301 | for k in callback_on_step_end_tensor_inputs: 302 | callback_kwargs[k] = locals()[k] 303 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 304 | 305 | latents = callback_outputs.pop("latents", latents) 306 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 307 | 308 | # call the callback, if provided 309 | if i == len(timesteps) - 1 or ( 310 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 311 | ): 312 | progress_bar.update() 313 | 314 | if output_type == "latent": 315 | image = latents 316 | 317 | else: 318 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 319 | latents = ( 320 | latents / self.vae.config.scaling_factor 321 | ) + self.vae.config.shift_factor 322 | image = self.vae.decode(latents, return_dict=False)[0] 323 | image = self.image_processor.postprocess(image, output_type=output_type) 324 | 325 | # Offload all models 326 | self.maybe_free_model_hooks() 327 | 328 | if condition_scale != [1,1]: 329 | for name, module in pipeline.transformer.named_modules(): 330 | if not name.endswith(".attn"): 331 | continue 332 | del module.c_factor 333 | 334 | if not return_dict: 335 | return (image,) 336 | 337 | return FluxPipelineOutput(images=image) 338 | -------------------------------------------------------------------------------- /flux/lora_controller.py: -------------------------------------------------------------------------------- 1 | #As is from OminiControl 2 | from peft.tuners.tuners_utils import BaseTunerLayer 3 | from typing import List, Any, Optional, Type 4 | from .condition import condition_dict 5 | 6 | 7 | class enable_lora: 8 | def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None: 9 | self.activated: bool = activated 10 | if activated: 11 | return 12 | self.lora_modules: List[BaseTunerLayer] = [ 13 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 14 | ] 15 | self.scales = [ 16 | { 17 | active_adapter: lora_module.scaling[active_adapter] 18 | for active_adapter in lora_module.active_adapters 19 | } 20 | for lora_module in self.lora_modules 21 | ] 22 | 23 | def __enter__(self) -> None: 24 | if self.activated: 25 | return 26 | 27 | for lora_module in self.lora_modules: 28 | if not isinstance(lora_module, BaseTunerLayer): 29 | continue 30 | for active_adapter in lora_module.active_adapters: 31 | if ( 32 | active_adapter in condition_dict.keys() 33 | or active_adapter == "default" 34 | ): 35 | lora_module.scaling[active_adapter] = 0.0 36 | 37 | def __exit__( 38 | self, 39 | exc_type: Optional[Type[BaseException]], 40 | exc_val: Optional[BaseException], 41 | exc_tb: Optional[Any], 42 | ) -> None: 43 | if self.activated: 44 | return 45 | for i, lora_module in enumerate(self.lora_modules): 46 | if not isinstance(lora_module, BaseTunerLayer): 47 | continue 48 | for active_adapter in lora_module.active_adapters: 49 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 50 | 51 | 52 | class set_lora_scale: 53 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> 
None: 54 | self.lora_modules: List[BaseTunerLayer] = [ 55 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 56 | ] 57 | self.scales = [ 58 | { 59 | active_adapter: lora_module.scaling[active_adapter] 60 | for active_adapter in lora_module.active_adapters 61 | } 62 | for lora_module in self.lora_modules 63 | ] 64 | self.scale = scale 65 | 66 | def __enter__(self) -> None: 67 | for lora_module in self.lora_modules: 68 | if not isinstance(lora_module, BaseTunerLayer): 69 | continue 70 | lora_module.scale_layer(self.scale) 71 | 72 | def __exit__( 73 | self, 74 | exc_type: Optional[Type[BaseException]], 75 | exc_val: Optional[BaseException], 76 | exc_tb: Optional[Any], 77 | ) -> None: 78 | for i, lora_module in enumerate(self.lora_modules): 79 | if not isinstance(lora_module, BaseTunerLayer): 80 | continue 81 | for active_adapter in lora_module.active_adapters: 82 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] -------------------------------------------------------------------------------- /flux/pipeline_tools.py: -------------------------------------------------------------------------------- 1 | #As is from OminiControl 2 | from diffusers.pipelines import FluxPipeline 3 | from diffusers.utils import logging 4 | from diffusers.pipelines.flux.pipeline_flux import logger 5 | from torch import Tensor 6 | 7 | 8 | def encode_images(pipeline: FluxPipeline, images: Tensor): 9 | images = pipeline.image_processor.preprocess(images) 10 | images = images.to(pipeline.device).to(pipeline.dtype) 11 | images = pipeline.vae.encode(images).latent_dist.sample() 12 | images = ( 13 | images - pipeline.vae.config.shift_factor 14 | ) * pipeline.vae.config.scaling_factor 15 | images_tokens = pipeline._pack_latents(images, *images.shape) 16 | images_ids = pipeline._prepare_latent_image_ids( 17 | images.shape[0], 18 | images.shape[2], 19 | images.shape[3], 20 | pipeline.device, 21 | pipeline.dtype, 22 | ) 23 | if images_tokens.shape[1] != images_ids.shape[0]: 24 | images_ids = pipeline._prepare_latent_image_ids( 25 | images.shape[0], 26 | images.shape[2] // 2, 27 | images.shape[3] // 2, 28 | pipeline.device, 29 | pipeline.dtype, 30 | ) 31 | return images_tokens, images_ids 32 | 33 | 34 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512): 35 | # Turn off warnings (CLIP overflow) 36 | logger.setLevel(logging.ERROR) 37 | ( 38 | prompt_embeds, 39 | pooled_prompt_embeds, 40 | text_ids, 41 | ) = pipeline.encode_prompt( 42 | prompt=prompts, 43 | prompt_2=None, 44 | prompt_embeds=None, 45 | pooled_prompt_embeds=None, 46 | device=pipeline.device, 47 | num_images_per_prompt=1, 48 | max_sequence_length=max_sequence_length, 49 | lora_scale=None, 50 | ) 51 | # Turn on warnings 52 | logger.setLevel(logging.WARNING) 53 | return prompt_embeds, pooled_prompt_embeds, text_ids 54 | -------------------------------------------------------------------------------- /flux/transformer.py: -------------------------------------------------------------------------------- 1 | # Recycled from Ominicontrol and modified to accept an extra condition. 2 | # While Zenctrl pursued a similar idea, it diverged structurally. 3 | # We appreciate the clarity of Omini's implementation and decided to align with it. 
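# Overview: `tranformer_forward` below mirrors FluxTransformer2DModel.forward, but
# threads two optional conditioning streams through every block: `condition_latents`
# (with `condition_ids` and `cond_temb`) and `extra_condition_latents` (with
# `extra_condition_ids` and `extra_cond_temb`). Each stream is embedded by the same
# `x_embedder`, receives its own rotary embedding from `pos_embed`, and is processed
# jointly with the image/text tokens inside `block_forward` and `single_block_forward`.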
4 | 5 | import torch 6 | from typing import Optional, Dict, Any 7 | from .block import block_forward, single_block_forward 8 | from .lora_controller import enable_lora 9 | from accelerate.utils import is_torch_version 10 | from diffusers.models.transformers.transformer_flux import ( 11 | FluxTransformer2DModel, 12 | Transformer2DModelOutput, 13 | USE_PEFT_BACKEND, 14 | scale_lora_layers, 15 | unscale_lora_layers, 16 | logger, 17 | ) 18 | import numpy as np 19 | 20 | 21 | def prepare_params( 22 | hidden_states: torch.Tensor, 23 | encoder_hidden_states: torch.Tensor = None, 24 | pooled_projections: torch.Tensor = None, 25 | timestep: torch.LongTensor = None, 26 | img_ids: torch.Tensor = None, 27 | txt_ids: torch.Tensor = None, 28 | guidance: torch.Tensor = None, 29 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 30 | controlnet_block_samples=None, 31 | controlnet_single_block_samples=None, 32 | return_dict: bool = True, 33 | **kwargs: dict, 34 | ): 35 | return ( 36 | hidden_states, 37 | encoder_hidden_states, 38 | pooled_projections, 39 | timestep, 40 | img_ids, 41 | txt_ids, 42 | guidance, 43 | joint_attention_kwargs, 44 | controlnet_block_samples, 45 | controlnet_single_block_samples, 46 | return_dict, 47 | ) 48 | 49 | 50 | def tranformer_forward( 51 | transformer: FluxTransformer2DModel, 52 | condition_latents: torch.Tensor, 53 | extra_condition_latents: torch.Tensor, 54 | condition_ids: torch.Tensor, 55 | condition_type_ids: torch.Tensor, 56 | extra_condition_ids: torch.Tensor, 57 | extra_condition_type_ids: torch.Tensor, 58 | model_config: Optional[Dict[str, Any]] = {}, 59 | c_t=0, 60 | **params: dict, 61 | ): 62 | self = transformer 63 | use_condition = condition_latents is not None 64 | use_extra_condition = extra_condition_latents is not None 65 | 66 | ( 67 | hidden_states, 68 | encoder_hidden_states, 69 | pooled_projections, 70 | timestep, 71 | img_ids, 72 | txt_ids, 73 | guidance, 74 | joint_attention_kwargs, 75 | controlnet_block_samples, 76 | controlnet_single_block_samples, 77 | return_dict, 78 | ) = prepare_params(**params) 79 | 80 | if joint_attention_kwargs is not None: 81 | joint_attention_kwargs = joint_attention_kwargs.copy() 82 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 83 | else: 84 | lora_scale = 1.0 85 | 86 | if USE_PEFT_BACKEND: 87 | # weight the lora layers by setting `lora_scale` for each PEFT layer 88 | scale_lora_layers(self, lora_scale) 89 | else: 90 | if ( 91 | joint_attention_kwargs is not None 92 | and joint_attention_kwargs.get("scale", None) is not None 93 | ): 94 | logger.warning( 95 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
96 | ) 97 | 98 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): 99 | hidden_states = self.x_embedder(hidden_states) 100 | condition_latents = self.x_embedder(condition_latents) if use_condition else None 101 | extra_condition_latents = self.x_embedder(extra_condition_latents) if use_extra_condition else None 102 | 103 | timestep = timestep.to(hidden_states.dtype) * 1000 104 | 105 | if guidance is not None: 106 | guidance = guidance.to(hidden_states.dtype) * 1000 107 | else: 108 | guidance = None 109 | 110 | temb = ( 111 | self.time_text_embed(timestep, pooled_projections) 112 | if guidance is None 113 | else self.time_text_embed(timestep, guidance, pooled_projections) 114 | ) 115 | 116 | cond_temb = ( 117 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 118 | if guidance is None 119 | else self.time_text_embed( 120 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections 121 | ) 122 | ) 123 | extra_cond_temb = ( 124 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 125 | if guidance is None 126 | else self.time_text_embed( 127 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections 128 | ) 129 | ) 130 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 131 | 132 | if txt_ids.ndim == 3: 133 | logger.warning( 134 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 135 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 136 | ) 137 | txt_ids = txt_ids[0] 138 | if img_ids.ndim == 3: 139 | logger.warning( 140 | "Passing `img_ids` 3d torch.Tensor is deprecated." 141 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 142 | ) 143 | img_ids = img_ids[0] 144 | 145 | ids = torch.cat((txt_ids, img_ids), dim=0) 146 | image_rotary_emb = self.pos_embed(ids) 147 | if use_condition: 148 | # condition_ids[:, :1] = condition_type_ids 149 | cond_rotary_emb = self.pos_embed(condition_ids) 150 | 151 | if use_extra_condition: 152 | extra_cond_rotary_emb = self.pos_embed(extra_condition_ids) 153 | 154 | 155 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) 156 | 157 | #print("here!") 158 | for index_block, block in enumerate(self.transformer_blocks): 159 | if self.training and self.gradient_checkpointing: 160 | ckpt_kwargs: Dict[str, Any] = ( 161 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 162 | ) 163 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = ( 164 | torch.utils.checkpoint.checkpoint( 165 | block_forward, 166 | self=block, 167 | model_config=model_config, 168 | hidden_states=hidden_states, 169 | encoder_hidden_states=encoder_hidden_states, 170 | condition_latents=condition_latents if use_condition else None, 171 | extra_condition_latents=extra_condition_latents if use_extra_condition else None, 172 | temb=temb, 173 | cond_temb=cond_temb if use_condition else None, 174 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 175 | extra_cond_temb=extra_cond_temb if use_extra_condition else None, 176 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None, 177 | image_rotary_emb=image_rotary_emb, 178 | **ckpt_kwargs, 179 | ) 180 | ) 181 | 182 | else: 183 | encoder_hidden_states, hidden_states, condition_latents, extra_condition_latents = block_forward( 184 | block, 185 | model_config=model_config, 186 | hidden_states=hidden_states, 187 | 
encoder_hidden_states=encoder_hidden_states, 188 | condition_latents=condition_latents if use_condition else None, 189 | extra_condition_latents=extra_condition_latents if use_extra_condition else None, 190 | temb=temb, 191 | cond_temb=cond_temb if use_condition else None, 192 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 193 | extra_cond_temb=cond_temb if use_extra_condition else None, 194 | extra_cond_rotary_emb=extra_cond_rotary_emb if use_extra_condition else None, 195 | image_rotary_emb=image_rotary_emb, 196 | ) 197 | 198 | # controlnet residual 199 | if controlnet_block_samples is not None: 200 | interval_control = len(self.transformer_blocks) / len( 201 | controlnet_block_samples 202 | ) 203 | interval_control = int(np.ceil(interval_control)) 204 | hidden_states = ( 205 | hidden_states 206 | + controlnet_block_samples[index_block // interval_control] 207 | ) 208 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 209 | 210 | 211 | for index_block, block in enumerate(self.single_transformer_blocks): 212 | if self.training and self.gradient_checkpointing: 213 | ckpt_kwargs: Dict[str, Any] = ( 214 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 215 | ) 216 | result = torch.utils.checkpoint.checkpoint( 217 | single_block_forward, 218 | self=block, 219 | model_config=model_config, 220 | hidden_states=hidden_states, 221 | temb=temb, 222 | image_rotary_emb=image_rotary_emb, 223 | **( 224 | { 225 | "condition_latents": condition_latents, 226 | "extra_condition_latents": extra_condition_latents, 227 | "cond_temb": cond_temb, 228 | "cond_rotary_emb": cond_rotary_emb, 229 | "extra_cond_temb": extra_cond_temb, 230 | "extra_cond_rotary_emb": extra_cond_rotary_emb, 231 | } 232 | if use_condition 233 | else {} 234 | ), 235 | **ckpt_kwargs, 236 | ) 237 | 238 | else: 239 | result = single_block_forward( 240 | block, 241 | model_config=model_config, 242 | hidden_states=hidden_states, 243 | temb=temb, 244 | image_rotary_emb=image_rotary_emb, 245 | **( 246 | { 247 | "condition_latents": condition_latents, 248 | "extra_condition_latents": extra_condition_latents, 249 | "cond_temb": cond_temb, 250 | "cond_rotary_emb": cond_rotary_emb, 251 | "extra_cond_temb": extra_cond_temb, 252 | "extra_cond_rotary_emb": extra_cond_rotary_emb, 253 | } 254 | if use_condition 255 | else {} 256 | ), 257 | ) 258 | if use_condition: 259 | hidden_states, condition_latents, extra_condition_latents = result 260 | else: 261 | hidden_states = result 262 | 263 | # controlnet residual 264 | if controlnet_single_block_samples is not None: 265 | interval_control = len(self.single_transformer_blocks) / len( 266 | controlnet_single_block_samples 267 | ) 268 | interval_control = int(np.ceil(interval_control)) 269 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 270 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 271 | + controlnet_single_block_samples[index_block // interval_control] 272 | ) 273 | 274 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
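# At this point only the image tokens remain: the text tokens prepended before the
# single-stream blocks have been sliced off, and the condition / extra-condition
# latents are dropped, since they only steer attention inside the blocks and are
# never decoded.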
275 | 276 | hidden_states = self.norm_out(hidden_states, temb) 277 | output = self.proj_out(hidden_states) 278 | 279 | if USE_PEFT_BACKEND: 280 | # remove `lora_scale` from each PEFT layer 281 | unscale_lora_layers(self, lora_scale) 282 | 283 | if not return_dict: 284 | return (output,) 285 | return Transformer2DModelOutput(sample=output) 286 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | diffusers==0.31.0 3 | peft==0.15.2 4 | opencv-python 5 | protobuf 6 | sentencepiece 7 | gradio 8 | jupyter 9 | torchao -------------------------------------------------------------------------------- /samples/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/1.png -------------------------------------------------------------------------------- /samples/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/12.png -------------------------------------------------------------------------------- /samples/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/13.png -------------------------------------------------------------------------------- /samples/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/14.png -------------------------------------------------------------------------------- /samples/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/18.png -------------------------------------------------------------------------------- /samples/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/7.png -------------------------------------------------------------------------------- /samples/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/8.png -------------------------------------------------------------------------------- /samples/image_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FotographerAI/ZenCtrl/8e5120017994a00e9f0b8b119362732a40a03770/samples/image_6.png --------------------------------------------------------------------------------
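For orientation, a minimal sketch of how the helpers in flux/pipeline_tools.py can be exercised on their own. This is a sketch under assumptions, not the repository's generation entry point: the "black-forest-labs/FLUX.1-dev" checkpoint id, the local "condition.png" path, and the CUDA/bfloat16 setup are placeholders.

import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline
from flux.pipeline_tools import encode_images, prepare_text_input

# Placeholder checkpoint and device; adjust to your environment.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Text side: prompt embeddings, pooled embeddings, and text ids for the transformer.
prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
    pipe, prompts=["a product photo on a marble table"]
)

# Image side: VAE-encode a condition image into packed latent tokens plus their ids.
cond_image = Image.open("condition.png").convert("RGB").resize((512, 512))
condition_latents, condition_ids = encode_images(pipe, cond_image)
print(condition_latents.shape, condition_ids.shape)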
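And a similarly hedged sketch of the two context managers in flux/lora_controller.py, reusing the `pipe` object from the sketch above and assuming a LoRA adapter has already been attached (for example via `pipe.load_lora_weights(...)`). As the source shows, `set_lora_scale` multiplies every active adapter's scaling by the given factor for the duration of the block and restores the saved values on exit, while `enable_lora(modules, activated=False)` zeroes the adapters named in `condition_dict` (plus "default") inside the block.

from peft.tuners.tuners_utils import BaseTunerLayer
from flux.lora_controller import enable_lora, set_lora_scale

# `pipe` is the FluxPipeline from the previous sketch, with a LoRA already attached.
lora_layers = [m for m in pipe.transformer.modules() if isinstance(m, BaseTunerLayer)]

with set_lora_scale(lora_layers, scale=0.8):
    ...  # forward passes here see every active adapter's scaling multiplied by 0.8
# original scaling values are restored on exit

with enable_lora(lora_layers, activated=False):
    ...  # adapters named in condition_dict (and "default") are zeroed inside the block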