├── LICENSE ├── README.md ├── assets ├── book.jpg ├── cartoon_boy.png ├── clock.jpg ├── coffee.png ├── demo │ ├── book_omini.jpg │ ├── clock_omini.jpg │ ├── demo_this_is_omini_control.jpg │ ├── dreambooth_res.jpg │ ├── man_omini.jpg │ ├── monalisa_omini.jpg │ ├── oranges_omini.jpg │ ├── panda_omini.jpg │ ├── penguin_omini.jpg │ ├── rc_car_omini.jpg │ ├── room_corner_canny.jpg │ ├── room_corner_coloring.jpg │ ├── room_corner_deblurring.jpg │ ├── room_corner_depth.jpg │ ├── scene_variation.jpg │ ├── shirt_omini.jpg │ └── try_on.jpg ├── monalisa.jpg ├── oranges.jpg ├── penguin.jpg ├── rc_car.jpg ├── room_corner.jpg ├── test_in.jpg ├── test_out.jpg ├── tshirt.jpg ├── vase.jpg └── vase_hq.jpg ├── examples ├── inpainting.ipynb ├── spatial.ipynb ├── subject.ipynb └── subject_1024.ipynb ├── gradio_app.py ├── requirements.txt ├── src ├── flux │ ├── block.py │ ├── condition.py │ ├── generate.py │ ├── lora_controller.py │ ├── pipeline_tools.py │ └── transformer.py ├── gradio │ └── gradio_app.py └── train │ ├── callbacks.py │ ├── data.py │ ├── model.py │ └── train.py └── train ├── README.md ├── config ├── canny_512.yaml ├── cartoon_512.yaml ├── fill_1024.yaml ├── sr_512.yaml └── subject_512.yaml ├── requirements.txt └── script ├── data_download ├── data_download1.sh └── data_download2.sh ├── train_canny.sh ├── train_cartoon.sh └── train_subject.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2024] [Zhenxiong Tan] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OminiControl 2 | 3 | 4 | 5 |
6 | 7 | arXiv 8 | HuggingFace 9 | HuggingFace 10 | GitHub 11 | HuggingFace 12 | 13 | > **OminiControl: Minimal and Universal Control for Diffusion Transformer** 14 | >
15 | > Zhenxiong Tan, 16 | > [Songhua Liu](http://121.37.94.87/), 17 | > [Xingyi Yang](https://adamdad.github.io/), 18 | > Qiaochu Xue, 19 | > and 20 | > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) 21 | >
22 | > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore 23 | >
24 | 25 | 26 | ## Features 27 | 28 | 29 | OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux). 30 | 31 | * **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation). 32 | 33 | * **Minimal Design 🚀**: Injects control signals while preserving the original model structure, introducing only 0.1% additional parameters to the base model. 34 | 35 | ## OminiControlGP (OminiControl for the GPU Poor) by DeepBeepMeep 36 | 37 | With just one added line that loads the 'mmgp' module (https://github.com/deepbeepmeep/mmgp), OminiControl can generate images from a derived Flux model in less than 6s with 16 GB of VRAM (profile 1), in 9s with 8 GB of VRAM (profile 4), or in 16s with less than 6 GB of VRAM (profile 5). 38 | 39 | To run the Gradio app with profile 3 (the default profile, the fastest, but it requires the most VRAM): 40 | ```bash 41 | python gradio_app.py --profile 3 42 | ``` 43 | To run the Gradio app with profile 5 (a bit slower, but it requires only 6 GB of VRAM): 44 | ```bash 45 | python gradio_app.py --profile 5 46 | ``` 47 | 48 | You may check the mmgp GitHub homepage if you want to design your own profiles (for instance, to disable quantization). 49 | 50 | If you enjoy this application, you will certainly appreciate these ones:\ 51 | - Hunyuan3D-2GP: https://github.com/deepbeepmeep/Hunyuan3D-2GP\ 52 | A great image-to-3D and text-to-3D tool by the Tencent team. Thanks to mmgp, it can run with less than 6 GB of VRAM. 53 | 54 | - HunyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP\ 55 | One of the best open source text-to-video generators. 56 | 57 | - FluxFillGP: https://github.com/deepbeepmeep/FluxFillGP\ 58 | One of the best inpainting / outpainting tools based on Flux that can run with less than 12 GB of VRAM. 59 | 60 | - Cosmos1GP: https://github.com/deepbeepmeep/Cosmos1GP\ 61 | This application includes two models: a text-to-world generator and an image / video-to-world generator (probably the best open source image-to-video generator). 62 | 63 | ## News 64 | - **2025-01-25**: ⭐️ DeepBeepMeep fork: added support for mmgp. 65 | - **2024-12-26**: ⭐️ Training code is released. Now you can create your own OminiControl model by customizing any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details. 66 | 67 | ## Quick Start 68 | ### Setup (Optional) 69 | 1. **Environment setup** 70 | ```bash 71 | conda create -n omini python=3.10 72 | conda activate omini 73 | ``` 74 | 2. **Requirements installation** 75 | ```bash 76 | pip install -r requirements.txt 77 | ``` 78 | ### Usage example 79 | 1. Subject-driven generation: `examples/subject.ipynb` (a minimal Python sketch is also shown after the guidelines below) 80 | 2. In-painting: `examples/inpainting.ipynb` 81 | 3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb` 82 | 83 | 84 | 85 | ### Guidelines for subject-driven generation 86 | 1. Input images are automatically center-cropped and resized to 512x512 resolution. 87 | 2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`, e.g. 88 | 1. *A close up view of this item. It is placed on a wooden table.* 89 | 2. *A young lady is wearing this shirt.* 90 | 3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training. 
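The following is a minimal, self-contained sketch of the subject-driven workflow, condensed from `examples/subject.ipynb`; the condition image and the output filename are illustrative and can be replaced with your own:

```python
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load the FLUX.1-schnell base pipeline and attach the subject-driven LoRA adapter.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# The condition image is resized to 512x512 and wrapped as a "subject" condition.
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition("subject", image, position_delta=(0, 32))

seed_everything(0)
result_img = generate(
    pipe,
    prompt="On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.",
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]
result_img.save("subject_result.jpg")  # output path is arbitrary
```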
91 | 92 | ## Generated samples 93 | ### Subject-driven generation 94 | HuggingFace 95 | 96 | **Demos** (Left: condition image; Right: generated image) 97 | 98 |
99 | 100 | 101 | 102 | 103 |
104 | 105 |
106 | Text Prompts 107 | 108 | - Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'* 109 | - Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.* 110 | - Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.* 111 | - Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."* 112 |
113 |
114 | More results 115 | 116 | * Try on: 117 | 118 | * Scene variations: 119 | 120 | * Dreambooth dataset: 121 | 122 | * Oye-cartoon finetune: 123 |
124 | 125 | 126 |
127 |
128 | 129 | ### Spatially aligned control 130 | 1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image) 131 | - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.* 132 |
133 | 134 | - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.* 135 |
136 | 137 | 2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring) 138 |
139 |
140 | Click to show 141 |
142 | 143 | 144 | 145 | 146 |
147 | 148 | Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.* 149 |
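The spatially aligned demos above follow the workflow in `examples/spatial.ipynb`. A condensed sketch for the Canny-conditioned case is shown below (FLUX.1-dev base, as in the notebook); the condition image and prompt are illustrative, and the other tasks differ only in the adapter name (`depth`, `coloring`, `deblurring`) passed to `load_lora_weights` and `Condition`:

```python
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load the FLUX.1-dev base pipeline and attach the experimental Canny adapter.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="experimental/canny.safetensors",
    adapter_name="canny",
)

# The raw image is resized to 512x512 and passed directly;
# the Condition class handles the task-specific preprocessing.
image = Image.open("assets/room_corner.jpg").convert("RGB").resize((512, 512))
condition = Condition("canny", image)

seed_everything()
result_img = generate(
    pipe,
    prompt="A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow.",
    conditions=[condition],
).images[0]
result_img.save("canny_result.jpg")  # output path is arbitrary
```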
150 | 151 | 152 | 153 | 154 | ## Models 155 | 156 | **Subject-driven control:** 157 | | Model | Base model | Description | Resolution | 158 | | ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- | ------------ | 159 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) | 160 | | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) | 161 | | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) | 162 | | [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on the [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) | 163 | 164 | **Spatially aligned control:** 165 | | Model | Base model | Description | Resolution | 166 | | --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ | 167 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) | 168 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `<task_name>_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) | 169 | 170 | ## Community Extensions 171 | - [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron) 172 | - [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub) 173 | 174 | ## Limitations 175 | 1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training. 176 | 2. The subject-driven generation model may not work well with `FLUX.1-dev`. 177 | 3. The released model currently only supports a resolution of 512x512. 178 | 179 | ## Training 180 | Training instructions can be found in this [folder](./train). 181 | 182 | 183 | ## To-do 184 | - [x] Release the training code. 185 | - [ ] Release the model for higher resolution (1024x1024). 
186 | 187 | ## Citation 188 | ``` 189 | @article{ 190 | tan2024omini, 191 | title={OminiControl: Minimal and Universal Control for Diffusion Transformer}, 192 | author={Zhenxiong Tan, Songhua Liu, Xingyi Yang, Qiaochu Xue, and Xinchao Wang}, 193 | journal={arXiv preprint arXiv:2411.15098}, 194 | year={2024} 195 | } 196 | ``` 197 | -------------------------------------------------------------------------------- /assets/book.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/book.jpg -------------------------------------------------------------------------------- /assets/cartoon_boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/cartoon_boy.png -------------------------------------------------------------------------------- /assets/clock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/clock.jpg -------------------------------------------------------------------------------- /assets/coffee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/coffee.png -------------------------------------------------------------------------------- /assets/demo/book_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/book_omini.jpg -------------------------------------------------------------------------------- /assets/demo/clock_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/clock_omini.jpg -------------------------------------------------------------------------------- /assets/demo/demo_this_is_omini_control.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/demo_this_is_omini_control.jpg -------------------------------------------------------------------------------- /assets/demo/dreambooth_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/dreambooth_res.jpg -------------------------------------------------------------------------------- /assets/demo/man_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/man_omini.jpg -------------------------------------------------------------------------------- /assets/demo/monalisa_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/monalisa_omini.jpg 
-------------------------------------------------------------------------------- /assets/demo/oranges_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/oranges_omini.jpg -------------------------------------------------------------------------------- /assets/demo/panda_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/panda_omini.jpg -------------------------------------------------------------------------------- /assets/demo/penguin_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/penguin_omini.jpg -------------------------------------------------------------------------------- /assets/demo/rc_car_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/rc_car_omini.jpg -------------------------------------------------------------------------------- /assets/demo/room_corner_canny.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_canny.jpg -------------------------------------------------------------------------------- /assets/demo/room_corner_coloring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_coloring.jpg -------------------------------------------------------------------------------- /assets/demo/room_corner_deblurring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_deblurring.jpg -------------------------------------------------------------------------------- /assets/demo/room_corner_depth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_depth.jpg -------------------------------------------------------------------------------- /assets/demo/scene_variation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/scene_variation.jpg -------------------------------------------------------------------------------- /assets/demo/shirt_omini.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/shirt_omini.jpg -------------------------------------------------------------------------------- /assets/demo/try_on.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/try_on.jpg -------------------------------------------------------------------------------- /assets/monalisa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/monalisa.jpg -------------------------------------------------------------------------------- /assets/oranges.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/oranges.jpg -------------------------------------------------------------------------------- /assets/penguin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/penguin.jpg -------------------------------------------------------------------------------- /assets/rc_car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/rc_car.jpg -------------------------------------------------------------------------------- /assets/room_corner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/room_corner.jpg -------------------------------------------------------------------------------- /assets/test_in.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/test_in.jpg -------------------------------------------------------------------------------- /assets/test_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/test_out.jpg -------------------------------------------------------------------------------- /assets/tshirt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/tshirt.jpg -------------------------------------------------------------------------------- /assets/vase.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/vase.jpg -------------------------------------------------------------------------------- /assets/vase_hq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/vase_hq.jpg -------------------------------------------------------------------------------- /examples/inpainting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | 
"os.chdir(\"..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from diffusers.pipelines import FluxPipeline\n", 22 | "from src.flux.condition import Condition\n", 23 | "from PIL import Image\n", 24 | "\n", 25 | "from src.flux.generate import generate, seed_everything" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pipe = FluxPipeline.from_pretrained(\n", 35 | " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", 36 | ")\n", 37 | "pipe = pipe.to(\"cuda\")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "pipe.load_lora_weights(\n", 47 | " \"Yuanshi/OminiControl\",\n", 48 | " weight_name=f\"experimental/fill.safetensors\",\n", 49 | " adapter_name=\"fill\",\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n", 60 | "\n", 61 | "masked_image = image.copy()\n", 62 | "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n", 63 | "\n", 64 | "condition = Condition(\"fill\", masked_image)\n", 65 | "\n", 66 | "seed_everything()\n", 67 | "result_img = generate(\n", 68 | " pipe,\n", 69 | " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n", 70 | " conditions=[condition],\n", 71 | ").images[0]\n", 72 | "\n", 73 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 74 | "concat_image.paste(image, (0, 0))\n", 75 | "concat_image.paste(condition.condition, (512, 0))\n", 76 | "concat_image.paste(result_img, (1024, 0))\n", 77 | "concat_image" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n", 87 | "\n", 88 | "w, h, min_dim = image.size + (min(image.size),)\n", 89 | "image = image.crop(\n", 90 | " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", 91 | ").resize((512, 512))\n", 92 | "\n", 93 | "\n", 94 | "masked_image = image.copy()\n", 95 | "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n", 96 | "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n", 97 | "\n", 98 | "condition = Condition(\"fill\", masked_image)\n", 99 | "\n", 100 | "seed_everything()\n", 101 | "result_img = generate(\n", 102 | " pipe,\n", 103 | " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. 
The text 'for FLUX' appears at the bottom.\",\n", 104 | " conditions=[condition],\n", 105 | ").images[0]\n", 106 | "\n", 107 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 108 | "concat_image.paste(image, (0, 0))\n", 109 | "concat_image.paste(condition.condition, (512, 0))\n", 110 | "concat_image.paste(result_img, (1024, 0))\n", 111 | "concat_image" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "base", 125 | "language": "python", 126 | "name": "python3" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 3 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython3", 138 | "version": "3.12.7" 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 2 143 | } 144 | -------------------------------------------------------------------------------- /examples/spatial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "os.chdir(\"..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from diffusers.pipelines import FluxPipeline\n", 22 | "from src.flux.condition import Condition\n", 23 | "from PIL import Image\n", 24 | "\n", 25 | "from src.flux.generate import generate, seed_everything" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pipe = FluxPipeline.from_pretrained(\n", 35 | " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", 36 | ")\n", 37 | "pipe = pipe.to(\"cuda\")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n", 47 | " pipe.load_lora_weights(\n", 48 | " \"Yuanshi/OminiControl\",\n", 49 | " weight_name=f\"experimental/{condition_type}.safetensors\",\n", 50 | " adapter_name=condition_type,\n", 51 | " )" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n", 61 | "\n", 62 | "w, h, min_dim = image.size + (min(image.size),)\n", 63 | "image = image.crop(\n", 64 | " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", 65 | ").resize((512, 512))\n", 66 | "\n", 67 | "prompt = \"In a bright room. A cup of a coffee with some beans on the side. 
They are placed on a dark wooden table.\"" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "condition = Condition(\"canny\", image)\n", 77 | "\n", 78 | "seed_everything()\n", 79 | "\n", 80 | "result_img = generate(\n", 81 | " pipe,\n", 82 | " prompt=prompt,\n", 83 | " conditions=[condition],\n", 84 | ").images[0]\n", 85 | "\n", 86 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 87 | "concat_image.paste(image, (0, 0))\n", 88 | "concat_image.paste(condition.condition, (512, 0))\n", 89 | "concat_image.paste(result_img, (1024, 0))\n", 90 | "concat_image" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "condition = Condition(\"depth\", image)\n", 100 | "\n", 101 | "seed_everything()\n", 102 | "\n", 103 | "result_img = generate(\n", 104 | " pipe,\n", 105 | " prompt=prompt,\n", 106 | " conditions=[condition],\n", 107 | ").images[0]\n", 108 | "\n", 109 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 110 | "concat_image.paste(image, (0, 0))\n", 111 | "concat_image.paste(condition.condition, (512, 0))\n", 112 | "concat_image.paste(result_img, (1024, 0))\n", 113 | "concat_image" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "condition = Condition(\"deblurring\", image)\n", 123 | "\n", 124 | "seed_everything()\n", 125 | "\n", 126 | "result_img = generate(\n", 127 | " pipe,\n", 128 | " prompt=prompt,\n", 129 | " conditions=[condition],\n", 130 | ").images[0]\n", 131 | "\n", 132 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 133 | "concat_image.paste(image, (0, 0))\n", 134 | "concat_image.paste(condition.condition, (512, 0))\n", 135 | "concat_image.paste(result_img, (1024, 0))\n", 136 | "concat_image" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "condition = Condition(\"coloring\", image)\n", 146 | "\n", 147 | "seed_everything()\n", 148 | "\n", 149 | "result_img = generate(\n", 150 | " pipe,\n", 151 | " prompt=prompt,\n", 152 | " conditions=[condition],\n", 153 | ").images[0]\n", 154 | "\n", 155 | "concat_image = Image.new(\"RGB\", (1536, 512))\n", 156 | "concat_image.paste(image, (0, 0))\n", 157 | "concat_image.paste(condition.condition, (512, 0))\n", 158 | "concat_image.paste(result_img, (1024, 0))\n", 159 | "concat_image" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "base", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.12.7" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /examples/subject.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "os.chdir(\"..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | 
"metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from diffusers.pipelines import FluxPipeline\n", 22 | "from src.flux.condition import Condition\n", 23 | "from PIL import Image\n", 24 | "\n", 25 | "from src.flux.generate import generate, seed_everything" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pipe = FluxPipeline.from_pretrained(\n", 35 | " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n", 36 | ")\n", 37 | "pipe = pipe.to(\"cuda\")\n", 38 | "pipe.load_lora_weights(\n", 39 | " \"Yuanshi/OminiControl\",\n", 40 | " weight_name=f\"omini/subject_512.safetensors\",\n", 41 | " adapter_name=\"subject\",\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n", 52 | "\n", 53 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", 54 | "\n", 55 | "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n", 56 | "\n", 57 | "\n", 58 | "seed_everything(0)\n", 59 | "\n", 60 | "result_img = generate(\n", 61 | " pipe,\n", 62 | " prompt=prompt,\n", 63 | " conditions=[condition],\n", 64 | " num_inference_steps=8,\n", 65 | " height=512,\n", 66 | " width=512,\n", 67 | ").images[0]\n", 68 | "\n", 69 | "concat_image = Image.new(\"RGB\", (1024, 512))\n", 70 | "concat_image.paste(image, (0, 0))\n", 71 | "concat_image.paste(result_img, (512, 0))\n", 72 | "concat_image" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n", 82 | "\n", 83 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", 84 | "\n", 85 | "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n", 86 | "\n", 87 | "\n", 88 | "seed_everything()\n", 89 | "\n", 90 | "result_img = generate(\n", 91 | " pipe,\n", 92 | " prompt=prompt,\n", 93 | " conditions=[condition],\n", 94 | " num_inference_steps=8,\n", 95 | " height=512,\n", 96 | " width=512,\n", 97 | ").images[0]\n", 98 | "\n", 99 | "concat_image = Image.new(\"RGB\", (1024, 512))\n", 100 | "concat_image.paste(condition.condition, (0, 0))\n", 101 | "concat_image.paste(result_img, (512, 0))\n", 102 | "concat_image" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n", 112 | "\n", 113 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", 114 | "\n", 115 | "prompt = \"A film style shot. On the moon, this item drives across the moon surface. 
The background is that Earth looms large in the foreground.\"\n", 116 | "\n", 117 | "seed_everything()\n", 118 | "\n", 119 | "result_img = generate(\n", 120 | " pipe,\n", 121 | " prompt=prompt,\n", 122 | " conditions=[condition],\n", 123 | " num_inference_steps=8,\n", 124 | " height=512,\n", 125 | " width=512,\n", 126 | ").images[0]\n", 127 | "\n", 128 | "concat_image = Image.new(\"RGB\", (1024, 512))\n", 129 | "concat_image.paste(condition.condition, (0, 0))\n", 130 | "concat_image.paste(result_img, (512, 0))\n", 131 | "concat_image" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n", 141 | "\n", 142 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", 143 | "\n", 144 | "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n", 145 | "\n", 146 | "seed_everything()\n", 147 | "\n", 148 | "result_img = generate(\n", 149 | " pipe,\n", 150 | " prompt=prompt,\n", 151 | " conditions=[condition],\n", 152 | " num_inference_steps=8,\n", 153 | " height=512,\n", 154 | " width=512,\n", 155 | ").images[0]\n", 156 | "\n", 157 | "concat_image = Image.new(\"RGB\", (1024, 512))\n", 158 | "concat_image.paste(condition.condition, (0, 0))\n", 159 | "concat_image.paste(result_img, (512, 0))\n", 160 | "concat_image" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n", 170 | "\n", 171 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n", 172 | "\n", 173 | "prompt = \"A very close up view of this item. It is placed on a wooden table. 
The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n", 174 | "\n", 175 | "seed_everything()\n", 176 | "\n", 177 | "result_img = generate(\n", 178 | " pipe,\n", 179 | " prompt=prompt,\n", 180 | " conditions=[condition],\n", 181 | " num_inference_steps=8,\n", 182 | " height=512,\n", 183 | " width=512,\n", 184 | ").images[0]\n", 185 | "\n", 186 | "concat_image = Image.new(\"RGB\", (1024, 512))\n", 187 | "concat_image.paste(condition.condition, (0, 0))\n", 188 | "concat_image.paste(result_img, (512, 0))\n", 189 | "concat_image" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | "kernelspec": { 195 | "display_name": "base", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.12.7" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 2 214 | } 215 | -------------------------------------------------------------------------------- /examples/subject_1024.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "os.chdir(\"..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from diffusers.pipelines import FluxPipeline\n", 22 | "from src.flux.condition import Condition\n", 23 | "from PIL import Image\n", 24 | "\n", 25 | "from src.flux.generate import generate, seed_everything" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pipe = FluxPipeline.from_pretrained(\n", 35 | " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n", 36 | ")\n", 37 | "pipe = pipe.to(\"cuda\")\n", 38 | "pipe.load_lora_weights(\n", 39 | " \"Yuanshi/OminiControl\",\n", 40 | " weight_name=f\"omini/subject_1024_beta.safetensors\",\n", 41 | " adapter_name=\"subject\",\n", 42 | ")" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n", 52 | "\n", 53 | "condition = Condition(\"subject\", image)\n", 54 | "\n", 55 | "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n", 56 | "\n", 57 | "\n", 58 | "seed_everything(0)\n", 59 | "\n", 60 | "result_img = generate(\n", 61 | " pipe,\n", 62 | " prompt=prompt,\n", 63 | " conditions=[condition],\n", 64 | " num_inference_steps=8,\n", 65 | " height=1024,\n", 66 | " width=1024,\n", 67 | ").images[0]\n", 68 | "\n", 69 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", 70 | "concat_image.paste(image, (0, 0))\n", 71 | "concat_image.paste(result_img, (512, 0))\n", 72 | "concat_image" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n", 82 | "\n", 83 | "condition = Condition(\"subject\", image)\n", 84 | "\n", 85 | "prompt = 
\"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n", 86 | "\n", 87 | "\n", 88 | "seed_everything(0)\n", 89 | "\n", 90 | "result_img = generate(\n", 91 | " pipe,\n", 92 | " prompt=prompt,\n", 93 | " conditions=[condition],\n", 94 | " num_inference_steps=8,\n", 95 | " height=1024,\n", 96 | " width=1024,\n", 97 | ").images[0]\n", 98 | "\n", 99 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", 100 | "concat_image.paste(image, (0, 0))\n", 101 | "concat_image.paste(result_img, (512, 0))\n", 102 | "concat_image" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n", 112 | "\n", 113 | "condition = Condition(\"subject\", image)\n", 114 | "\n", 115 | "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n", 116 | "\n", 117 | "seed_everything()\n", 118 | "\n", 119 | "result_img = generate(\n", 120 | " pipe,\n", 121 | " prompt=prompt,\n", 122 | " conditions=[condition],\n", 123 | " num_inference_steps=8,\n", 124 | " height=1024,\n", 125 | " width=1024,\n", 126 | ").images[0]\n", 127 | "\n", 128 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", 129 | "concat_image.paste(image, (0, 0))\n", 130 | "concat_image.paste(result_img, (512, 0))\n", 131 | "concat_image" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n", 141 | "\n", 142 | "condition = Condition(\"subject\", image)\n", 143 | "\n", 144 | "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n", 145 | "\n", 146 | "seed_everything(0)\n", 147 | "\n", 148 | "result_img = generate(\n", 149 | " pipe,\n", 150 | " prompt=prompt,\n", 151 | " conditions=[condition],\n", 152 | " num_inference_steps=8,\n", 153 | " height=1024,\n", 154 | " width=1024,\n", 155 | ").images[0]\n", 156 | "\n", 157 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", 158 | "concat_image.paste(image, (0, 0))\n", 159 | "concat_image.paste(result_img, (512, 0))\n", 160 | "concat_image" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n", 170 | "\n", 171 | "condition = Condition(\"subject\", image)\n", 172 | "\n", 173 | "prompt = \"A very close up view of this item. It is placed on a wooden table. 
The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n", 174 | "\n", 175 | "seed_everything()\n", 176 | "\n", 177 | "result_img = generate(\n", 178 | " pipe,\n", 179 | " prompt=prompt,\n", 180 | " conditions=[condition],\n", 181 | " num_inference_steps=8,\n", 182 | " height=1024,\n", 183 | " width=1024,\n", 184 | ").images[0]\n", 185 | "\n", 186 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n", 187 | "concat_image.paste(image, (0, 0))\n", 188 | "concat_image.paste(result_img, (512, 0))\n", 189 | "concat_image" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [] 198 | } 199 | ], 200 | "metadata": { 201 | "kernelspec": { 202 | "display_name": "Python 3 (ipykernel)", 203 | "language": "python", 204 | "name": "python3" 205 | }, 206 | "language_info": { 207 | "codemirror_mode": { 208 | "name": "ipython", 209 | "version": 3 210 | }, 211 | "file_extension": ".py", 212 | "mimetype": "text/x-python", 213 | "name": "python", 214 | "nbconvert_exporter": "python", 215 | "pygments_lexer": "ipython3", 216 | "version": "3.9.21" 217 | } 218 | }, 219 | "nbformat": 4, 220 | "nbformat_minor": 2 221 | } 222 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | from mmgp import offload 2 | import gradio as gr 3 | import torch 4 | from PIL import Image, ImageDraw, ImageFont 5 | from diffusers.pipelines import FluxPipeline 6 | from diffusers import FluxTransformer2DModel 7 | import numpy as np 8 | 9 | from src.flux.condition import Condition 10 | from src.flux.generate import seed_everything, generate 11 | 12 | 13 | pipe = None 14 | use_int8 = False 15 | 16 | 17 | def get_gpu_memory(): 18 | return torch.cuda.get_device_properties(0).total_memory / 1024**3 19 | 20 | 21 | def init_pipeline(): 22 | global pipe 23 | if False and (use_int8 or get_gpu_memory() < 33): 24 | transformer_model = FluxTransformer2DModel.from_pretrained( 25 | "sayakpaul/flux.1-schell-int8wo-improved", 26 | torch_dtype=torch.bfloat16, 27 | use_safetensors=False, 28 | ) 29 | pipe = FluxPipeline.from_pretrained( 30 | "black-forest-labs/FLUX.1-schnell", 31 | transformer=transformer_model, 32 | torch_dtype=torch.bfloat16, 33 | ) 34 | else: 35 | pipe = FluxPipeline.from_pretrained( 36 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 37 | ) 38 | pipe = pipe.to("cpu") 39 | pipe.load_lora_weights( 40 | "Yuanshi/OminiControl", 41 | weight_name="omini/subject_512.safetensors", 42 | adapter_name="subject", 43 | ) 44 | offload.profile(pipe, profile_no=int(args.profile), verboseLevel=int(args.verbose), quantizeTransformer= False 45 | ) 46 | 47 | def process_image_and_text(image, text): 48 | # center crop image 49 | w, h, min_size = image.size[0], image.size[1], min(image.size) 50 | image = image.crop( 51 | ( 52 | (w - min_size) // 2, 53 | (h - min_size) // 2, 54 | (w + min_size) // 2, 55 | (h + min_size) // 2, 56 | ) 57 | ) 58 | image = image.resize((512, 512)) 59 | 60 | condition = Condition("subject", image, position_delta=(0, 32)) 61 | 62 | if pipe is None: 63 | init_pipeline() 64 | 65 | result_img = generate( 66 | pipe, 67 | prompt=text.strip(), 68 | conditions=[condition], 69 | num_inference_steps=8, 70 | height=512, 71 | width=512, 72 | ).images[0] 73 | 74 | return result_img 75 | 76 | 77 | def get_samples(): 78 | sample_list = [ 79 | { 80 | "image": 
"assets/oranges.jpg", 81 | "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'", 82 | }, 83 | { 84 | "image": "assets/penguin.jpg", 85 | "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'", 86 | }, 87 | { 88 | "image": "assets/rc_car.jpg", 89 | "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.", 90 | }, 91 | { 92 | "image": "assets/clock.jpg", 93 | "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.", 94 | }, 95 | { 96 | "image": "assets/tshirt.jpg", 97 | "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her.", 98 | }, 99 | ] 100 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list] 101 | 102 | 103 | demo = gr.Interface( 104 | fn=process_image_and_text, 105 | inputs=[ 106 | gr.Image(type="pil"), 107 | gr.Textbox(lines=2), 108 | ], 109 | outputs=gr.Image(type="pil"), 110 | title="OminiControlGP / Subject driven generation for the GPU Poor", 111 | 112 | examples=get_samples(), 113 | ) 114 | 115 | if __name__ == "__main__": 116 | import argparse 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--profile', type=str, default="3") 120 | parser.add_argument('--verbose', type=str, default="1") 121 | 122 | args = parser.parse_args() 123 | 124 | init_pipeline() 125 | demo.launch( 126 | debug=True, 127 | ) 128 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | diffusers 3 | peft 4 | opencv-python 5 | protobuf 6 | sentencepiece 7 | gradio 8 | jupyter 9 | torchao 10 | mmgp==3.1.4.post1 -------------------------------------------------------------------------------- /src/flux/block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Union, Optional, Dict, Any, Callable 3 | from diffusers.models.attention_processor import Attention, F 4 | from .lora_controller import enable_lora 5 | 6 | 7 | def attn_forward( 8 | attn: Attention, 9 | hidden_states: torch.FloatTensor, 10 | encoder_hidden_states: torch.FloatTensor = None, 11 | condition_latents: torch.FloatTensor = None, 12 | attention_mask: Optional[torch.FloatTensor] = None, 13 | image_rotary_emb: Optional[torch.Tensor] = None, 14 | cond_rotary_emb: Optional[torch.Tensor] = None, 15 | model_config: Optional[Dict[str, Any]] = {}, 16 | ) -> torch.FloatTensor: 17 | batch_size, _, _ = ( 18 | hidden_states.shape 19 | if encoder_hidden_states is None 20 | else encoder_hidden_states.shape 21 | ) 22 | 23 | with enable_lora( 24 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) 25 | ): 26 | # `sample` projections. 
27 | query = attn.to_q(hidden_states) 28 | key = attn.to_k(hidden_states) 29 | value = attn.to_v(hidden_states) 30 | 31 | inner_dim = key.shape[-1] 32 | head_dim = inner_dim // attn.heads 33 | 34 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 35 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 36 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 37 | 38 | if attn.norm_q is not None: 39 | query = attn.norm_q(query) 40 | if attn.norm_k is not None: 41 | key = attn.norm_k(key) 42 | 43 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 44 | if encoder_hidden_states is not None: 45 | # `context` projections. 46 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 47 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 48 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 49 | 50 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 51 | batch_size, -1, attn.heads, head_dim 52 | ).transpose(1, 2) 53 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 54 | batch_size, -1, attn.heads, head_dim 55 | ).transpose(1, 2) 56 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 57 | batch_size, -1, attn.heads, head_dim 58 | ).transpose(1, 2) 59 | 60 | if attn.norm_added_q is not None: 61 | encoder_hidden_states_query_proj = attn.norm_added_q( 62 | encoder_hidden_states_query_proj 63 | ) 64 | if attn.norm_added_k is not None: 65 | encoder_hidden_states_key_proj = attn.norm_added_k( 66 | encoder_hidden_states_key_proj 67 | ) 68 | 69 | # attention 70 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 71 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 72 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 73 | 74 | if image_rotary_emb is not None: 75 | from diffusers.models.embeddings import apply_rotary_emb 76 | 77 | query = apply_rotary_emb(query, image_rotary_emb) 78 | key = apply_rotary_emb(key, image_rotary_emb) 79 | 80 | if condition_latents is not None: 81 | cond_query = attn.to_q(condition_latents) 82 | cond_key = attn.to_k(condition_latents) 83 | cond_value = attn.to_v(condition_latents) 84 | 85 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( 86 | 1, 2 87 | ) 88 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 89 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( 90 | 1, 2 91 | ) 92 | if attn.norm_q is not None: 93 | cond_query = attn.norm_q(cond_query) 94 | if attn.norm_k is not None: 95 | cond_key = attn.norm_k(cond_key) 96 | 97 | if cond_rotary_emb is not None: 98 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) 99 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) 100 | 101 | if condition_latents is not None: 102 | query = torch.cat([query, cond_query], dim=2) 103 | key = torch.cat([key, cond_key], dim=2) 104 | value = torch.cat([value, cond_value], dim=2) 105 | 106 | if not model_config.get("union_cond_attn", True): 107 | # If we don't want to use the union condition attention, we need to mask the attention 108 | # between the hidden states and the condition latents 109 | attention_mask = torch.ones( 110 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool 111 | ) 112 | condition_n = cond_query.shape[2] 113 | attention_mask[-condition_n:, :-condition_n] = False 114 | 
attention_mask[:-condition_n, -condition_n:] = False 115 | if hasattr(attn, "c_factor"): 116 | attention_mask = torch.zeros( 117 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype 118 | ) 119 | condition_n = cond_query.shape[2] 120 | bias = torch.log(attn.c_factor[0]) 121 | attention_mask[-condition_n:, :-condition_n] = bias 122 | attention_mask[:-condition_n, -condition_n:] = bias 123 | hidden_states = F.scaled_dot_product_attention( 124 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask 125 | ) 126 | hidden_states = hidden_states.transpose(1, 2).reshape( 127 | batch_size, -1, attn.heads * head_dim 128 | ) 129 | hidden_states = hidden_states.to(query.dtype) 130 | 131 | if encoder_hidden_states is not None: 132 | if condition_latents is not None: 133 | encoder_hidden_states, hidden_states, condition_latents = ( 134 | hidden_states[:, : encoder_hidden_states.shape[1]], 135 | hidden_states[ 136 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] 137 | ], 138 | hidden_states[:, -condition_latents.shape[1] :], 139 | ) 140 | else: 141 | encoder_hidden_states, hidden_states = ( 142 | hidden_states[:, : encoder_hidden_states.shape[1]], 143 | hidden_states[:, encoder_hidden_states.shape[1] :], 144 | ) 145 | 146 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)): 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 152 | 153 | if condition_latents is not None: 154 | condition_latents = attn.to_out[0](condition_latents) 155 | condition_latents = attn.to_out[1](condition_latents) 156 | 157 | return ( 158 | (hidden_states, encoder_hidden_states, condition_latents) 159 | if condition_latents is not None 160 | else (hidden_states, encoder_hidden_states) 161 | ) 162 | elif condition_latents is not None: 163 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents 164 | hidden_states, condition_latents = ( 165 | hidden_states[:, : -condition_latents.shape[1]], 166 | hidden_states[:, -condition_latents.shape[1] :], 167 | ) 168 | return hidden_states, condition_latents 169 | else: 170 | return hidden_states 171 | 172 | 173 | def block_forward( 174 | self, 175 | hidden_states: torch.FloatTensor, 176 | encoder_hidden_states: torch.FloatTensor, 177 | condition_latents: torch.FloatTensor, 178 | temb: torch.FloatTensor, 179 | cond_temb: torch.FloatTensor, 180 | cond_rotary_emb=None, 181 | image_rotary_emb=None, 182 | model_config: Optional[Dict[str, Any]] = {}, 183 | ): 184 | use_cond = condition_latents is not None 185 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): 186 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 187 | hidden_states, emb=temb 188 | ) 189 | 190 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( 191 | self.norm1_context(encoder_hidden_states, emb=temb) 192 | ) 193 | 194 | if use_cond: 195 | ( 196 | norm_condition_latents, 197 | cond_gate_msa, 198 | cond_shift_mlp, 199 | cond_scale_mlp, 200 | cond_gate_mlp, 201 | ) = self.norm1(condition_latents, emb=cond_temb) 202 | 203 | # Attention. 
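# The call below runs text, image, and (optionally) condition tokens through a single joint
# attention pass: attn_forward concatenates the three token streams along the sequence dimension
# and returns two tensors, or three when condition latents are present. When `union_cond_attn`
# is disabled the image<->condition cross terms are masked out, and a `c_factor` (set by
# generate(..., condition_scale=...)) adds log(condition_scale) as a bias on those cross terms.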
204 | result = attn_forward( 205 | self.attn, 206 | model_config=model_config, 207 | hidden_states=norm_hidden_states, 208 | encoder_hidden_states=norm_encoder_hidden_states, 209 | condition_latents=norm_condition_latents if use_cond else None, 210 | image_rotary_emb=image_rotary_emb, 211 | cond_rotary_emb=cond_rotary_emb if use_cond else None, 212 | ) 213 | attn_output, context_attn_output = result[:2] 214 | cond_attn_output = result[2] if use_cond else None 215 | 216 | # Process attention outputs for the `hidden_states`. 217 | # 1. hidden_states 218 | attn_output = gate_msa.unsqueeze(1) * attn_output 219 | hidden_states = hidden_states + attn_output 220 | # 2. encoder_hidden_states 221 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 222 | encoder_hidden_states = encoder_hidden_states + context_attn_output 223 | # 3. condition_latents 224 | if use_cond: 225 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 226 | condition_latents = condition_latents + cond_attn_output 227 | if model_config.get("add_cond_attn", False): 228 | hidden_states += cond_attn_output 229 | 230 | # LayerNorm + MLP. 231 | # 1. hidden_states 232 | norm_hidden_states = self.norm2(hidden_states) 233 | norm_hidden_states = ( 234 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 235 | ) 236 | # 2. encoder_hidden_states 237 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 238 | norm_encoder_hidden_states = ( 239 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 240 | ) 241 | # 3. condition_latents 242 | if use_cond: 243 | norm_condition_latents = self.norm2(condition_latents) 244 | norm_condition_latents = ( 245 | norm_condition_latents * (1 + cond_scale_mlp[:, None]) 246 | + cond_shift_mlp[:, None] 247 | ) 248 | 249 | # Feed-forward. 250 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): 251 | # 1. hidden_states 252 | ff_output = self.ff(norm_hidden_states) 253 | ff_output = gate_mlp.unsqueeze(1) * ff_output 254 | # 2. encoder_hidden_states 255 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 256 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output 257 | # 3. condition_latents 258 | if use_cond: 259 | cond_ff_output = self.ff(norm_condition_latents) 260 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 261 | 262 | # Process feed-forward outputs. 263 | hidden_states = hidden_states + ff_output 264 | encoder_hidden_states = encoder_hidden_states + context_ff_output 265 | if use_cond: 266 | condition_latents = condition_latents + cond_ff_output 267 | 268 | # Clip to avoid overflow. 
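# 65504 is the largest finite float16 value; clamping keeps a half-precision run from producing
# inf activations that would turn into NaNs in later blocks.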
269 | if encoder_hidden_states.dtype == torch.float16: 270 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 271 | 272 | return encoder_hidden_states, hidden_states, condition_latents if use_cond else None 273 | 274 | 275 | def single_block_forward( 276 | self, 277 | hidden_states: torch.FloatTensor, 278 | temb: torch.FloatTensor, 279 | image_rotary_emb=None, 280 | condition_latents: torch.FloatTensor = None, 281 | cond_temb: torch.FloatTensor = None, 282 | cond_rotary_emb=None, 283 | model_config: Optional[Dict[str, Any]] = {}, 284 | ): 285 | 286 | using_cond = condition_latents is not None 287 | residual = hidden_states 288 | with enable_lora( 289 | ( 290 | self.norm.linear, 291 | self.proj_mlp, 292 | ), 293 | model_config.get("latent_lora", False), 294 | ): 295 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 296 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 297 | if using_cond: 298 | residual_cond = condition_latents 299 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) 300 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) 301 | 302 | attn_output = attn_forward( 303 | self.attn, 304 | model_config=model_config, 305 | hidden_states=norm_hidden_states, 306 | image_rotary_emb=image_rotary_emb, 307 | **( 308 | { 309 | "condition_latents": norm_condition_latents, 310 | "cond_rotary_emb": cond_rotary_emb if using_cond else None, 311 | } 312 | if using_cond 313 | else {} 314 | ), 315 | ) 316 | if using_cond: 317 | attn_output, cond_attn_output = attn_output 318 | 319 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): 320 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 321 | gate = gate.unsqueeze(1) 322 | hidden_states = gate * self.proj_out(hidden_states) 323 | hidden_states = residual + hidden_states 324 | if using_cond: 325 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 326 | cond_gate = cond_gate.unsqueeze(1) 327 | condition_latents = cond_gate * self.proj_out(condition_latents) 328 | condition_latents = residual_cond + condition_latents 329 | 330 | if hidden_states.dtype == torch.float16: 331 | hidden_states = hidden_states.clip(-65504, 65504) 332 | 333 | return hidden_states if not using_cond else (hidden_states, condition_latents) 334 | -------------------------------------------------------------------------------- /src/flux/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union, List, Tuple 3 | from diffusers.pipelines import FluxPipeline 4 | from PIL import Image, ImageFilter 5 | import numpy as np 6 | import cv2 7 | 8 | from .pipeline_tools import encode_images 9 | 10 | condition_dict = { 11 | "depth": 0, 12 | "canny": 1, 13 | "subject": 4, 14 | "coloring": 6, 15 | "deblurring": 7, 16 | "depth_pred": 8, 17 | "fill": 9, 18 | "sr": 10, 19 | "cartoon": 11, 20 | } 21 | 22 | 23 | class Condition(object): 24 | def __init__( 25 | self, 26 | condition_type: str, 27 | raw_img: Union[Image.Image, torch.Tensor] = None, 28 | condition: Union[Image.Image, torch.Tensor] = None, 29 | mask=None, 30 | position_delta=None, 31 | ) -> None: 32 | self.condition_type = condition_type 33 | assert raw_img is not None or condition is not None 34 | if raw_img is not None: 35 | self.condition = self.get_condition(condition_type, raw_img) 36 | else: 37 | self.condition = condition 38 | self.position_delta = 
position_delta 39 | # TODO: Add mask support 40 | assert mask is None, "Mask not supported yet" 41 | 42 | def get_condition( 43 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] 44 | ) -> Union[Image.Image, torch.Tensor]: 45 | """ 46 | Returns the condition image. 47 | """ 48 | if condition_type == "depth": 49 | from transformers import pipeline 50 | 51 | depth_pipe = pipeline( 52 | task="depth-estimation", 53 | model="LiheYoung/depth-anything-small-hf", 54 | device="cuda", 55 | ) 56 | source_image = raw_img.convert("RGB") 57 | condition_img = depth_pipe(source_image)["depth"].convert("RGB") 58 | return condition_img 59 | elif condition_type == "canny": 60 | img = np.array(raw_img) 61 | edges = cv2.Canny(img, 100, 200) 62 | edges = Image.fromarray(edges).convert("RGB") 63 | return edges 64 | elif condition_type == "subject": 65 | return raw_img 66 | elif condition_type == "coloring": 67 | return raw_img.convert("L").convert("RGB") 68 | elif condition_type == "deblurring": 69 | condition_image = ( 70 | raw_img.convert("RGB") 71 | .filter(ImageFilter.GaussianBlur(10)) 72 | .convert("RGB") 73 | ) 74 | return condition_image 75 | elif condition_type == "fill": 76 | return raw_img.convert("RGB") 77 | elif condition_type == "cartoon": 78 | return raw_img.convert("RGB") 79 | return self.condition 80 | 81 | @property 82 | def type_id(self) -> int: 83 | """ 84 | Returns the type id of the condition. 85 | """ 86 | return condition_dict[self.condition_type] 87 | 88 | @classmethod 89 | def get_type_id(cls, condition_type: str) -> int: 90 | """ 91 | Returns the type id of the condition. 92 | """ 93 | return condition_dict[condition_type] 94 | 95 | def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]: 96 | """ 97 | Encodes the condition into tokens, ids and type_id. 
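
`tokens` are the packed VAE latents of the condition image, `ids` are the per-token
(index, row, column) position ids used for the rotary embeddings (shifted by `position_delta`),
and `type_id` is a column filled with this condition's index from `condition_dict`.

Illustrative usage (mirrors gradio_app.py; `pipe` is assumed to be an already loaded
FluxPipeline):

    condition = Condition("subject", image, position_delta=(0, 32))
    tokens, ids, type_id = condition.encode(pipe)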
98 | """ 99 | if self.condition_type in [ 100 | "depth", 101 | "canny", 102 | "subject", 103 | "coloring", 104 | "deblurring", 105 | "depth_pred", 106 | "fill", 107 | "sr", 108 | "cartoon" 109 | ]: 110 | tokens, ids = encode_images(pipe, self.condition) 111 | else: 112 | raise NotImplementedError( 113 | f"Condition type {self.condition_type} not implemented" 114 | ) 115 | if self.position_delta is None and self.condition_type == "subject": 116 | self.position_delta = [0, -self.condition.size[0] // 16] 117 | if self.position_delta is not None: 118 | ids[:, 1] += self.position_delta[0] 119 | ids[:, 2] += self.position_delta[1] 120 | type_id = torch.ones_like(ids[:, :1]) * self.type_id 121 | return tokens, ids, type_id 122 | -------------------------------------------------------------------------------- /src/flux/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml, os 3 | from diffusers.pipelines import FluxPipeline 4 | from typing import List, Union, Optional, Dict, Any, Callable 5 | from .transformer import tranformer_forward 6 | from .condition import Condition 7 | 8 | from diffusers.pipelines.flux.pipeline_flux import ( 9 | FluxPipelineOutput, 10 | calculate_shift, 11 | retrieve_timesteps, 12 | np, 13 | ) 14 | 15 | 16 | def get_config(config_path: str = None): 17 | config_path = config_path or os.environ.get("XFL_CONFIG") 18 | if not config_path: 19 | return {} 20 | with open(config_path, "r") as f: 21 | config = yaml.safe_load(f) 22 | return config 23 | 24 | 25 | def prepare_params( 26 | prompt: Union[str, List[str]] = None, 27 | prompt_2: Optional[Union[str, List[str]]] = None, 28 | height: Optional[int] = 512, 29 | width: Optional[int] = 512, 30 | num_inference_steps: int = 28, 31 | timesteps: List[int] = None, 32 | guidance_scale: float = 3.5, 33 | num_images_per_prompt: Optional[int] = 1, 34 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 35 | latents: Optional[torch.FloatTensor] = None, 36 | prompt_embeds: Optional[torch.FloatTensor] = None, 37 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 38 | output_type: Optional[str] = "pil", 39 | return_dict: bool = True, 40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 41 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 42 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 43 | max_sequence_length: int = 512, 44 | **kwargs: dict, 45 | ): 46 | return ( 47 | prompt, 48 | prompt_2, 49 | height, 50 | width, 51 | num_inference_steps, 52 | timesteps, 53 | guidance_scale, 54 | num_images_per_prompt, 55 | generator, 56 | latents, 57 | prompt_embeds, 58 | pooled_prompt_embeds, 59 | output_type, 60 | return_dict, 61 | joint_attention_kwargs, 62 | callback_on_step_end, 63 | callback_on_step_end_tensor_inputs, 64 | max_sequence_length, 65 | ) 66 | 67 | 68 | def seed_everything(seed: int = 42): 69 | torch.backends.cudnn.deterministic = True 70 | torch.manual_seed(seed) 71 | np.random.seed(seed) 72 | 73 | 74 | @torch.no_grad() 75 | def generate( 76 | pipeline: FluxPipeline, 77 | conditions: List[Condition] = None, 78 | config_path: str = None, 79 | model_config: Optional[Dict[str, Any]] = {}, 80 | condition_scale: float = 1.0, 81 | default_lora: bool = False, 82 | **params: dict, 83 | ): 84 | model_config = model_config or get_config(config_path).get("model", {}) 85 | if condition_scale != 1: 86 | for name, module in pipeline.transformer.named_modules(): 87 | if not name.endswith(".attn"): 
88 | continue 89 | module.c_factor = torch.ones(1, 1) * condition_scale 90 | 91 | self = pipeline 92 | ( 93 | prompt, 94 | prompt_2, 95 | height, 96 | width, 97 | num_inference_steps, 98 | timesteps, 99 | guidance_scale, 100 | num_images_per_prompt, 101 | generator, 102 | latents, 103 | prompt_embeds, 104 | pooled_prompt_embeds, 105 | output_type, 106 | return_dict, 107 | joint_attention_kwargs, 108 | callback_on_step_end, 109 | callback_on_step_end_tensor_inputs, 110 | max_sequence_length, 111 | ) = prepare_params(**params) 112 | 113 | height = height or self.default_sample_size * self.vae_scale_factor 114 | width = width or self.default_sample_size * self.vae_scale_factor 115 | 116 | # 1. Check inputs. Raise error if not correct 117 | self.check_inputs( 118 | prompt, 119 | prompt_2, 120 | height, 121 | width, 122 | prompt_embeds=prompt_embeds, 123 | pooled_prompt_embeds=pooled_prompt_embeds, 124 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 125 | max_sequence_length=max_sequence_length, 126 | ) 127 | 128 | self._guidance_scale = guidance_scale 129 | self._joint_attention_kwargs = joint_attention_kwargs 130 | self._interrupt = False 131 | 132 | # 2. Define call parameters 133 | if prompt is not None and isinstance(prompt, str): 134 | batch_size = 1 135 | elif prompt is not None and isinstance(prompt, list): 136 | batch_size = len(prompt) 137 | else: 138 | batch_size = prompt_embeds.shape[0] 139 | 140 | device = self._execution_device 141 | 142 | lora_scale = ( 143 | self.joint_attention_kwargs.get("scale", None) 144 | if self.joint_attention_kwargs is not None 145 | else None 146 | ) 147 | ( 148 | prompt_embeds, 149 | pooled_prompt_embeds, 150 | text_ids, 151 | ) = self.encode_prompt( 152 | prompt=prompt, 153 | prompt_2=prompt_2, 154 | prompt_embeds=prompt_embeds, 155 | pooled_prompt_embeds=pooled_prompt_embeds, 156 | device=device, 157 | num_images_per_prompt=num_images_per_prompt, 158 | max_sequence_length=max_sequence_length, 159 | lora_scale=lora_scale, 160 | ) 161 | 162 | # 4. Prepare latent variables 163 | num_channels_latents = self.transformer.config.in_channels // 4 164 | latents, latent_image_ids = self.prepare_latents( 165 | batch_size * num_images_per_prompt, 166 | num_channels_latents, 167 | height, 168 | width, 169 | prompt_embeds.dtype, 170 | device, 171 | generator, 172 | latents, 173 | ) 174 | 175 | # 4.1. Prepare conditions 176 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) 177 | use_condition = conditions is not None or [] 178 | if use_condition: 179 | assert len(conditions) <= 1, "Only one condition is supported for now." 180 | if not default_lora: 181 | pipeline.set_adapters(conditions[0].condition_type) 182 | for condition in conditions: 183 | tokens, ids, type_id = condition.encode(self) 184 | condition_latents.append(tokens) # [batch_size, token_n, token_dim] 185 | condition_ids.append(ids) # [token_n, id_dim(3)] 186 | condition_type_ids.append(type_id) # [token_n, 1] 187 | condition_latents = torch.cat(condition_latents, dim=1) 188 | condition_ids = torch.cat(condition_ids, dim=0) 189 | condition_type_ids = torch.cat(condition_type_ids, dim=0) 190 | 191 | # 5. 
Prepare timesteps 192 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 193 | image_seq_len = latents.shape[1] 194 | mu = calculate_shift( 195 | image_seq_len, 196 | self.scheduler.config.base_image_seq_len, 197 | self.scheduler.config.max_image_seq_len, 198 | self.scheduler.config.base_shift, 199 | self.scheduler.config.max_shift, 200 | ) 201 | timesteps, num_inference_steps = retrieve_timesteps( 202 | self.scheduler, 203 | num_inference_steps, 204 | device, 205 | timesteps, 206 | sigmas, 207 | mu=mu, 208 | ) 209 | num_warmup_steps = max( 210 | len(timesteps) - num_inference_steps * self.scheduler.order, 0 211 | ) 212 | self._num_timesteps = len(timesteps) 213 | 214 | # 6. Denoising loop 215 | with self.progress_bar(total=num_inference_steps) as progress_bar: 216 | for i, t in enumerate(timesteps): 217 | if self.interrupt: 218 | continue 219 | 220 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 221 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 222 | 223 | # handle guidance 224 | if self.transformer.config.guidance_embeds: 225 | guidance = torch.tensor([guidance_scale], device=device) 226 | guidance = guidance.expand(latents.shape[0]) 227 | else: 228 | guidance = None 229 | noise_pred = tranformer_forward( 230 | self.transformer, 231 | model_config=model_config, 232 | # Inputs of the condition (new feature) 233 | condition_latents=condition_latents if use_condition else None, 234 | condition_ids=condition_ids if use_condition else None, 235 | condition_type_ids=condition_type_ids if use_condition else None, 236 | # Inputs to the original transformer 237 | hidden_states=latents, 238 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 239 | timestep=timestep / 1000, 240 | guidance=guidance, 241 | pooled_projections=pooled_prompt_embeds, 242 | encoder_hidden_states=prompt_embeds, 243 | txt_ids=text_ids, 244 | img_ids=latent_image_ids, 245 | joint_attention_kwargs=self.joint_attention_kwargs, 246 | return_dict=False, 247 | )[0] 248 | 249 | # compute the previous noisy sample x_t -> x_t-1 250 | latents_dtype = latents.dtype 251 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 252 | 253 | if latents.dtype != latents_dtype: 254 | if torch.backends.mps.is_available(): 255 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 256 | latents = latents.to(latents_dtype) 257 | 258 | if callback_on_step_end is not None: 259 | callback_kwargs = {} 260 | for k in callback_on_step_end_tensor_inputs: 261 | callback_kwargs[k] = locals()[k] 262 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 263 | 264 | latents = callback_outputs.pop("latents", latents) 265 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 266 | 267 | # call the callback, if provided 268 | if i == len(timesteps) - 1 or ( 269 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 270 | ): 271 | progress_bar.update() 272 | 273 | if output_type == "latent": 274 | image = latents 275 | 276 | else: 277 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 278 | latents = ( 279 | latents / self.vae.config.scaling_factor 280 | ) + self.vae.config.shift_factor 281 | image = self.vae.decode(latents, return_dict=False)[0] 282 | image = self.image_processor.postprocess(image, output_type=output_type) 283 | 284 | # Offload all models 285 | self.maybe_free_model_hooks() 286 | 287 | if condition_scale != 1: 288 | for name, module in pipeline.transformer.named_modules(): 289 | if not name.endswith(".attn"): 290 | continue 291 | del module.c_factor 292 | 293 | if not return_dict: 294 | return (image,) 295 | 296 | return FluxPipelineOutput(images=image) 297 | -------------------------------------------------------------------------------- /src/flux/lora_controller.py: -------------------------------------------------------------------------------- 1 | from peft.tuners.tuners_utils import BaseTunerLayer 2 | from typing import List, Any, Optional, Type 3 | 4 | 5 | class enable_lora: 6 | def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None: 7 | self.activated: bool = activated 8 | if activated: 9 | return 10 | self.lora_modules: List[BaseTunerLayer] = [ 11 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 12 | ] 13 | self.scales = [ 14 | { 15 | active_adapter: lora_module.scaling[active_adapter] 16 | for active_adapter in lora_module.active_adapters 17 | } 18 | for lora_module in self.lora_modules 19 | ] 20 | 21 | def __enter__(self) -> None: 22 | if self.activated: 23 | return 24 | 25 | for lora_module in self.lora_modules: 26 | if not isinstance(lora_module, BaseTunerLayer): 27 | continue 28 | lora_module.scale_layer(0) 29 | 30 | def __exit__( 31 | self, 32 | exc_type: Optional[Type[BaseException]], 33 | exc_val: Optional[BaseException], 34 | exc_tb: Optional[Any], 35 | ) -> None: 36 | if self.activated: 37 | return 38 | for i, lora_module in enumerate(self.lora_modules): 39 | if not isinstance(lora_module, BaseTunerLayer): 40 | continue 41 | for active_adapter in lora_module.active_adapters: 42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 43 | 44 | 45 | class set_lora_scale: 46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: 47 | self.lora_modules: List[BaseTunerLayer] = [ 48 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 49 | ] 50 | self.scales = [ 51 | { 52 | active_adapter: lora_module.scaling[active_adapter] 53 | for active_adapter in lora_module.active_adapters 54 | } 55 | for lora_module in self.lora_modules 56 | ] 57 | self.scale = scale 58 | 59 | def __enter__(self) -> None: 60 | for lora_module in self.lora_modules: 61 | if not isinstance(lora_module, BaseTunerLayer): 62 
| continue 63 | lora_module.scale_layer(self.scale) 64 | 65 | def __exit__( 66 | self, 67 | exc_type: Optional[Type[BaseException]], 68 | exc_val: Optional[BaseException], 69 | exc_tb: Optional[Any], 70 | ) -> None: 71 | for i, lora_module in enumerate(self.lora_modules): 72 | if not isinstance(lora_module, BaseTunerLayer): 73 | continue 74 | for active_adapter in lora_module.active_adapters: 75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 76 | -------------------------------------------------------------------------------- /src/flux/pipeline_tools.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines import FluxPipeline 2 | from diffusers.utils import logging 3 | from diffusers.pipelines.flux.pipeline_flux import logger 4 | from torch import Tensor 5 | 6 | 7 | def encode_images(pipeline: FluxPipeline, images: Tensor): 8 | images = pipeline.image_processor.preprocess(images) 9 | images = images.to(pipeline.device).to(pipeline.dtype) 10 | images = pipeline.vae.encode(images).latent_dist.sample() 11 | images = ( 12 | images - pipeline.vae.config.shift_factor 13 | ) * pipeline.vae.config.scaling_factor 14 | images_tokens = pipeline._pack_latents(images, *images.shape) 15 | images_ids = pipeline._prepare_latent_image_ids( 16 | images.shape[0], 17 | images.shape[2], 18 | images.shape[3], 19 | pipeline.device, 20 | pipeline.dtype, 21 | ) 22 | if images_tokens.shape[1] != images_ids.shape[0]: 23 | images_ids = pipeline._prepare_latent_image_ids( 24 | images.shape[0], 25 | images.shape[2] // 2, 26 | images.shape[3] // 2, 27 | pipeline.device, 28 | pipeline.dtype, 29 | ) 30 | return images_tokens, images_ids 31 | 32 | 33 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512): 34 | # Turn off warnings (CLIP overflow) 35 | logger.setLevel(logging.ERROR) 36 | ( 37 | prompt_embeds, 38 | pooled_prompt_embeds, 39 | text_ids, 40 | ) = pipeline.encode_prompt( 41 | prompt=prompts, 42 | prompt_2=None, 43 | prompt_embeds=None, 44 | pooled_prompt_embeds=None, 45 | device=pipeline.device, 46 | num_images_per_prompt=1, 47 | max_sequence_length=max_sequence_length, 48 | lora_scale=None, 49 | ) 50 | # Turn on warnings 51 | logger.setLevel(logging.WARNING) 52 | return prompt_embeds, pooled_prompt_embeds, text_ids 53 | -------------------------------------------------------------------------------- /src/flux/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines import FluxPipeline 3 | from typing import List, Union, Optional, Dict, Any, Callable 4 | from .block import block_forward, single_block_forward 5 | from .lora_controller import enable_lora 6 | from diffusers.models.transformers.transformer_flux import ( 7 | FluxTransformer2DModel, 8 | Transformer2DModelOutput, 9 | USE_PEFT_BACKEND, 10 | is_torch_version, 11 | scale_lora_layers, 12 | unscale_lora_layers, 13 | logger, 14 | ) 15 | import numpy as np 16 | 17 | 18 | def prepare_params( 19 | hidden_states: torch.Tensor, 20 | encoder_hidden_states: torch.Tensor = None, 21 | pooled_projections: torch.Tensor = None, 22 | timestep: torch.LongTensor = None, 23 | img_ids: torch.Tensor = None, 24 | txt_ids: torch.Tensor = None, 25 | guidance: torch.Tensor = None, 26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 27 | controlnet_block_samples=None, 28 | controlnet_single_block_samples=None, 29 | return_dict: bool = True, 30 | **kwargs: dict, 31 | ): 32 | return 
( 33 | hidden_states, 34 | encoder_hidden_states, 35 | pooled_projections, 36 | timestep, 37 | img_ids, 38 | txt_ids, 39 | guidance, 40 | joint_attention_kwargs, 41 | controlnet_block_samples, 42 | controlnet_single_block_samples, 43 | return_dict, 44 | ) 45 | 46 | 47 | def tranformer_forward( 48 | transformer: FluxTransformer2DModel, 49 | condition_latents: torch.Tensor, 50 | condition_ids: torch.Tensor, 51 | condition_type_ids: torch.Tensor, 52 | model_config: Optional[Dict[str, Any]] = {}, 53 | c_t=0, 54 | **params: dict, 55 | ): 56 | self = transformer 57 | use_condition = condition_latents is not None 58 | 59 | ( 60 | hidden_states, 61 | encoder_hidden_states, 62 | pooled_projections, 63 | timestep, 64 | img_ids, 65 | txt_ids, 66 | guidance, 67 | joint_attention_kwargs, 68 | controlnet_block_samples, 69 | controlnet_single_block_samples, 70 | return_dict, 71 | ) = prepare_params(**params) 72 | 73 | if joint_attention_kwargs is not None: 74 | joint_attention_kwargs = joint_attention_kwargs.copy() 75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 76 | else: 77 | lora_scale = 1.0 78 | 79 | if USE_PEFT_BACKEND: 80 | # weight the lora layers by setting `lora_scale` for each PEFT layer 81 | scale_lora_layers(self, lora_scale) 82 | else: 83 | if ( 84 | joint_attention_kwargs is not None 85 | and joint_attention_kwargs.get("scale", None) is not None 86 | ): 87 | logger.warning( 88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 89 | ) 90 | 91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): 92 | hidden_states = self.x_embedder(hidden_states) 93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None 94 | 95 | timestep = timestep.to(hidden_states.dtype) * 1000 96 | 97 | if guidance is not None: 98 | guidance = guidance.to(hidden_states.dtype) * 1000 99 | else: 100 | guidance = None 101 | 102 | temb = ( 103 | self.time_text_embed(timestep, pooled_projections) 104 | if guidance is None 105 | else self.time_text_embed(timestep, guidance, pooled_projections) 106 | ) 107 | 108 | cond_temb = ( 109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 110 | if guidance is None 111 | else self.time_text_embed( 112 | torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections 113 | ) 114 | ) 115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 116 | 117 | if txt_ids.ndim == 3: 118 | logger.warning( 119 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 120 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 121 | ) 122 | txt_ids = txt_ids[0] 123 | if img_ids.ndim == 3: 124 | logger.warning( 125 | "Passing `img_ids` 3d torch.Tensor is deprecated." 
126 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 127 | ) 128 | img_ids = img_ids[0] 129 | 130 | ids = torch.cat((txt_ids, img_ids), dim=0) 131 | image_rotary_emb = self.pos_embed(ids) 132 | if use_condition: 133 | # condition_ids[:, :1] = condition_type_ids 134 | cond_rotary_emb = self.pos_embed(condition_ids) 135 | 136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) 137 | 138 | for index_block, block in enumerate(self.transformer_blocks): 139 | if self.training and self.gradient_checkpointing: 140 | ckpt_kwargs: Dict[str, Any] = ( 141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 142 | ) 143 | encoder_hidden_states, hidden_states, condition_latents = ( 144 | torch.utils.checkpoint.checkpoint( 145 | block_forward, 146 | self=block, 147 | model_config=model_config, 148 | hidden_states=hidden_states, 149 | encoder_hidden_states=encoder_hidden_states, 150 | condition_latents=condition_latents if use_condition else None, 151 | temb=temb, 152 | cond_temb=cond_temb if use_condition else None, 153 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 154 | image_rotary_emb=image_rotary_emb, 155 | **ckpt_kwargs, 156 | ) 157 | ) 158 | 159 | else: 160 | encoder_hidden_states, hidden_states, condition_latents = block_forward( 161 | block, 162 | model_config=model_config, 163 | hidden_states=hidden_states, 164 | encoder_hidden_states=encoder_hidden_states, 165 | condition_latents=condition_latents if use_condition else None, 166 | temb=temb, 167 | cond_temb=cond_temb if use_condition else None, 168 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 169 | image_rotary_emb=image_rotary_emb, 170 | ) 171 | 172 | # controlnet residual 173 | if controlnet_block_samples is not None: 174 | interval_control = len(self.transformer_blocks) / len( 175 | controlnet_block_samples 176 | ) 177 | interval_control = int(np.ceil(interval_control)) 178 | hidden_states = ( 179 | hidden_states 180 | + controlnet_block_samples[index_block // interval_control] 181 | ) 182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 183 | 184 | for index_block, block in enumerate(self.single_transformer_blocks): 185 | if self.training and self.gradient_checkpointing: 186 | ckpt_kwargs: Dict[str, Any] = ( 187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 188 | ) 189 | result = torch.utils.checkpoint.checkpoint( 190 | single_block_forward, 191 | self=block, 192 | model_config=model_config, 193 | hidden_states=hidden_states, 194 | temb=temb, 195 | image_rotary_emb=image_rotary_emb, 196 | **( 197 | { 198 | "condition_latents": condition_latents, 199 | "cond_temb": cond_temb, 200 | "cond_rotary_emb": cond_rotary_emb, 201 | } 202 | if use_condition 203 | else {} 204 | ), 205 | **ckpt_kwargs, 206 | ) 207 | 208 | else: 209 | result = single_block_forward( 210 | block, 211 | model_config=model_config, 212 | hidden_states=hidden_states, 213 | temb=temb, 214 | image_rotary_emb=image_rotary_emb, 215 | **( 216 | { 217 | "condition_latents": condition_latents, 218 | "cond_temb": cond_temb, 219 | "cond_rotary_emb": cond_rotary_emb, 220 | } 221 | if use_condition 222 | else {} 223 | ), 224 | ) 225 | if use_condition: 226 | hidden_states, condition_latents = result 227 | else: 228 | hidden_states = result 229 | 230 | # controlnet residual 231 | if controlnet_single_block_samples is not None: 232 | interval_control = len(self.single_transformer_blocks) / len( 233 | controlnet_single_block_samples 234 | ) 235 | 
interval_control = int(np.ceil(interval_control)) 236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 238 | + controlnet_single_block_samples[index_block // interval_control] 239 | ) 240 | 241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 242 | 243 | hidden_states = self.norm_out(hidden_states, temb) 244 | output = self.proj_out(hidden_states) 245 | 246 | if USE_PEFT_BACKEND: 247 | # remove `lora_scale` from each PEFT layer 248 | unscale_lora_layers(self, lora_scale) 249 | 250 | if not return_dict: 251 | return (output,) 252 | return Transformer2DModelOutput(sample=output) 253 | -------------------------------------------------------------------------------- /src/gradio/gradio_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | from PIL import Image, ImageDraw, ImageFont 4 | from diffusers.pipelines import FluxPipeline 5 | from diffusers import FluxTransformer2DModel 6 | import numpy as np 7 | 8 | from ..flux.condition import Condition 9 | from ..flux.generate import seed_everything, generate 10 | 11 | pipe = None 12 | use_int8 = False 13 | 14 | 15 | def get_gpu_memory(): 16 | return torch.cuda.get_device_properties(0).total_memory / 1024**3 17 | 18 | 19 | def init_pipeline(): 20 | global pipe 21 | if use_int8 or get_gpu_memory() < 33: 22 | transformer_model = FluxTransformer2DModel.from_pretrained( 23 | "sayakpaul/flux.1-schell-int8wo-improved", 24 | torch_dtype=torch.bfloat16, 25 | use_safetensors=False, 26 | ) 27 | pipe = FluxPipeline.from_pretrained( 28 | "black-forest-labs/FLUX.1-schnell", 29 | transformer=transformer_model, 30 | torch_dtype=torch.bfloat16, 31 | ) 32 | else: 33 | pipe = FluxPipeline.from_pretrained( 34 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 35 | ) 36 | pipe = pipe.to("cuda") 37 | pipe.load_lora_weights( 38 | "Yuanshi/OminiControl", 39 | weight_name="omini/subject_512.safetensors", 40 | adapter_name="subject", 41 | ) 42 | 43 | 44 | def process_image_and_text(image, text): 45 | # center crop image 46 | w, h, min_size = image.size[0], image.size[1], min(image.size) 47 | image = image.crop( 48 | ( 49 | (w - min_size) // 2, 50 | (h - min_size) // 2, 51 | (w + min_size) // 2, 52 | (h + min_size) // 2, 53 | ) 54 | ) 55 | image = image.resize((512, 512)) 56 | 57 | condition = Condition("subject", image, position_delta=(0, 32)) 58 | 59 | if pipe is None: 60 | init_pipeline() 61 | 62 | result_img = generate( 63 | pipe, 64 | prompt=text.strip(), 65 | conditions=[condition], 66 | num_inference_steps=8, 67 | height=512, 68 | width=512, 69 | ).images[0] 70 | 71 | return result_img 72 | 73 | 74 | def get_samples(): 75 | sample_list = [ 76 | { 77 | "image": "assets/oranges.jpg", 78 | "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'", 79 | }, 80 | { 81 | "image": "assets/penguin.jpg", 82 | "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'", 83 | }, 84 | { 85 | "image": "assets/rc_car.jpg", 86 | "text": "A film style shot. On the moon, this item drives across the moon surface. 
The background is that Earth looms large in the foreground.", 87 | }, 88 | { 89 | "image": "assets/clock.jpg", 90 | "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.", 91 | }, 92 | { 93 | "image": "assets/tshirt.jpg", 94 | "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.", 95 | }, 96 | ] 97 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list] 98 | 99 | 100 | demo = gr.Interface( 101 | fn=process_image_and_text, 102 | inputs=[ 103 | gr.Image(type="pil"), 104 | gr.Textbox(lines=2), 105 | ], 106 | outputs=gr.Image(type="pil"), 107 | title="OminiControl / Subject driven generation", 108 | examples=get_samples(), 109 | ) 110 | 111 | if __name__ == "__main__": 112 | init_pipeline() 113 | demo.launch( 114 | debug=True, 115 | ) 116 | -------------------------------------------------------------------------------- /src/train/callbacks.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from PIL import Image, ImageFilter, ImageDraw 3 | import numpy as np 4 | from transformers import pipeline 5 | import cv2 6 | import torch 7 | import os 8 | 9 | try: 10 | import wandb 11 | except ImportError: 12 | wandb = None 13 | 14 | from ..flux.condition import Condition 15 | from ..flux.generate import generate 16 | 17 | 18 | class TrainingCallback(L.Callback): 19 | def __init__(self, run_name, training_config: dict = {}): 20 | self.run_name, self.training_config = run_name, training_config 21 | 22 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10) 23 | self.save_interval = training_config.get("save_interval", 1000) 24 | self.sample_interval = training_config.get("sample_interval", 1000) 25 | self.save_path = training_config.get("save_path", "./output") 26 | 27 | self.wandb_config = training_config.get("wandb", None) 28 | self.use_wandb = ( 29 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None 30 | ) 31 | 32 | self.total_steps = 0 33 | 34 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 35 | gradient_size = 0 36 | max_gradient_size = 0 37 | count = 0 38 | for _, param in pl_module.named_parameters(): 39 | if param.grad is not None: 40 | gradient_size += param.grad.norm(2).item() 41 | max_gradient_size = max(max_gradient_size, param.grad.norm(2).item()) 42 | count += 1 43 | if count > 0: 44 | gradient_size /= count 45 | 46 | self.total_steps += 1 47 | 48 | # Report training metrics to wandb every step 49 | if self.use_wandb: 50 | report_dict = { 51 | "batch_idx": batch_idx, 52 | "steps": self.total_steps, 53 | "epoch": trainer.current_epoch, 54 | "gradient_size": gradient_size, 55 | } 56 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches 57 | report_dict["loss"] = loss_value 58 | report_dict["t"] = pl_module.last_t 59 | wandb.log(report_dict) 60 | 61 | if self.total_steps % self.print_every_n_steps == 0: 62 | print( 63 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}" 64 | ) 65 | 66 | # Save LoRA weights at specified intervals 67 | if self.total_steps % self.save_interval == 0: 68 | print( 69 | f"Epoch: {trainer.current_epoch}, Steps: 
{self.total_steps} - Saving LoRA weights" 70 | ) 71 | pl_module.save_lora( 72 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}" 73 | ) 74 | 75 | # Generate and save a sample image at specified intervals 76 | if self.total_steps % self.sample_interval == 0: 77 | print( 78 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample" 79 | ) 80 | self.generate_a_sample( 81 | trainer, 82 | pl_module, 83 | f"{self.save_path}/{self.run_name}/output", 84 | f"lora_{self.total_steps}", 85 | batch["condition_type"][ 86 | 0 87 | ], # Use the condition type from the current batch 88 | ) 89 | 90 | @torch.no_grad() 91 | def generate_a_sample( 92 | self, 93 | trainer, 94 | pl_module, 95 | save_path, 96 | file_name, 97 | condition_type="super_resolution", 98 | ): 99 | # TODO: change this two variables to parameters 100 | condition_size = trainer.training_config["dataset"]["condition_size"] 101 | target_size = trainer.training_config["dataset"]["target_size"] 102 | 103 | generator = torch.Generator(device=pl_module.device) 104 | generator.manual_seed(42) 105 | 106 | test_list = [] 107 | 108 | if condition_type == "subject": 109 | test_list.extend( 110 | [ 111 | ( 112 | Image.open("assets/test_in.jpg"), 113 | [0, -32], 114 | "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.", 115 | ), 116 | ( 117 | Image.open("assets/test_out.jpg"), 118 | [0, -32], 119 | "In a bright room. It is placed on a table.", 120 | ), 121 | ] 122 | ) 123 | elif condition_type == "canny": 124 | condition_img = Image.open("assets/vase_hq.jpg").resize( 125 | (condition_size, condition_size) 126 | ) 127 | condition_img = np.array(condition_img) 128 | condition_img = cv2.Canny(condition_img, 100, 200) 129 | condition_img = Image.fromarray(condition_img).convert("RGB") 130 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 131 | elif condition_type == "coloring": 132 | condition_img = ( 133 | Image.open("assets/vase_hq.jpg") 134 | .resize((condition_size, condition_size)) 135 | .convert("L") 136 | .convert("RGB") 137 | ) 138 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 139 | elif condition_type == "depth": 140 | if not hasattr(self, "deepth_pipe"): 141 | self.deepth_pipe = pipeline( 142 | task="depth-estimation", 143 | model="LiheYoung/depth-anything-small-hf", 144 | device="cpu", 145 | ) 146 | condition_img = ( 147 | Image.open("assets/vase_hq.jpg") 148 | .resize((condition_size, condition_size)) 149 | .convert("RGB") 150 | ) 151 | condition_img = self.deepth_pipe(condition_img)["depth"].convert("RGB") 152 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 153 | elif condition_type == "depth_pred": 154 | condition_img = ( 155 | Image.open("assets/vase_hq.jpg") 156 | .resize((condition_size, condition_size)) 157 | .convert("RGB") 158 | ) 159 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 160 | elif condition_type == "deblurring": 161 | blur_radius = 5 162 | image = Image.open("./assets/vase_hq.jpg") 163 | condition_img = ( 164 | image.convert("RGB") 165 | .resize((condition_size, condition_size)) 166 | .filter(ImageFilter.GaussianBlur(blur_radius)) 167 | .convert("RGB") 168 | ) 169 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 170 | elif condition_type == "fill": 171 | condition_img = ( 172 | Image.open("./assets/vase_hq.jpg") 173 | 
.resize((condition_size, condition_size)) 174 | .convert("RGB") 175 | ) 176 | mask = Image.new("L", condition_img.size, 0) 177 | draw = ImageDraw.Draw(mask) 178 | a = condition_img.size[0] // 4 179 | b = a * 3 180 | draw.rectangle([a, a, b, b], fill=255) 181 | condition_img = Image.composite( 182 | condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask 183 | ) 184 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 185 | elif condition_type == "sr": 186 | condition_img = ( 187 | Image.open("assets/vase_hq.jpg") 188 | .resize((condition_size, condition_size)) 189 | .convert("RGB") 190 | ) 191 | test_list.append((condition_img, [0, -16], "A beautiful vase on a table.")) 192 | elif condition_type == "cartoon": 193 | condition_img = ( 194 | Image.open("assets/cartoon_boy.png") 195 | .resize((condition_size, condition_size)) 196 | .convert("RGB") 197 | ) 198 | test_list.append((condition_img, [0, -16], "A cartoon character in a white background. He is looking right, and running.")) 199 | else: 200 | raise NotImplementedError 201 | 202 | if not os.path.exists(save_path): 203 | os.makedirs(save_path) 204 | for i, (condition_img, position_delta, prompt) in enumerate(test_list): 205 | condition = Condition( 206 | condition_type=condition_type, 207 | condition=condition_img.resize( 208 | (condition_size, condition_size) 209 | ).convert("RGB"), 210 | position_delta=position_delta, 211 | ) 212 | res = generate( 213 | pl_module.flux_pipe, 214 | prompt=prompt, 215 | conditions=[condition], 216 | height=target_size, 217 | width=target_size, 218 | generator=generator, 219 | model_config=pl_module.model_config, 220 | default_lora=True, 221 | ) 222 | res.images[0].save( 223 | os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg") 224 | ) 225 | -------------------------------------------------------------------------------- /src/train/data.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFilter, ImageDraw 2 | import cv2 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as T 6 | import random 7 | 8 | 9 | class Subject200KDataset(Dataset): 10 | def __init__( 11 | self, 12 | base_dataset, 13 | condition_size: int = 512, 14 | target_size: int = 512, 15 | image_size: int = 512, 16 | padding: int = 0, 17 | condition_type: str = "subject", 18 | drop_text_prob: float = 0.1, 19 | drop_image_prob: float = 0.1, 20 | return_pil_image: bool = False, 21 | ): 22 | self.base_dataset = base_dataset 23 | self.condition_size = condition_size 24 | self.target_size = target_size 25 | self.image_size = image_size 26 | self.padding = padding 27 | self.condition_type = condition_type 28 | self.drop_text_prob = drop_text_prob 29 | self.drop_image_prob = drop_image_prob 30 | self.return_pil_image = return_pil_image 31 | 32 | self.to_tensor = T.ToTensor() 33 | 34 | def __len__(self): 35 | return len(self.base_dataset) * 2 36 | 37 | def __getitem__(self, idx): 38 | # If target is 0, left image is target, right image is condition 39 | target = idx % 2 40 | item = self.base_dataset[idx // 2] 41 | 42 | # Crop the image to target and condition 43 | image = item["image"] 44 | left_img = image.crop( 45 | ( 46 | self.padding, 47 | self.padding, 48 | self.image_size + self.padding, 49 | self.image_size + self.padding, 50 | ) 51 | ) 52 | right_img = image.crop( 53 | ( 54 | self.image_size + self.padding * 2, 55 | self.padding, 56 | self.image_size * 2 + self.padding * 2, 57 | 
self.image_size + self.padding, 58 | ) 59 | ) 60 | 61 | # Get the target and condition image 62 | target_image, condition_img = ( 63 | (left_img, right_img) if target == 0 else (right_img, left_img) 64 | ) 65 | 66 | # Resize the image 67 | condition_img = condition_img.resize( 68 | (self.condition_size, self.condition_size) 69 | ).convert("RGB") 70 | target_image = target_image.resize( 71 | (self.target_size, self.target_size) 72 | ).convert("RGB") 73 | 74 | # Get the description 75 | description = item["description"][ 76 | "description_0" if target == 0 else "description_1" 77 | ] 78 | 79 | # Randomly drop text or image 80 | drop_text = random.random() < self.drop_text_prob 81 | drop_image = random.random() < self.drop_image_prob 82 | if drop_text: 83 | description = "" 84 | if drop_image: 85 | condition_img = Image.new( 86 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0) 87 | ) 88 | 89 | return { 90 | "image": self.to_tensor(target_image), 91 | "condition": self.to_tensor(condition_img), 92 | "condition_type": self.condition_type, 93 | "description": description, 94 | # 16 is the downscale factor of the image 95 | "position_delta": np.array([0, -self.condition_size // 16]), 96 | **({"pil_image": image} if self.return_pil_image else {}), 97 | } 98 | 99 | 100 | class ImageConditionDataset(Dataset): 101 | def __init__( 102 | self, 103 | base_dataset, 104 | condition_size: int = 512, 105 | target_size: int = 512, 106 | condition_type: str = "canny", 107 | drop_text_prob: float = 0.1, 108 | drop_image_prob: float = 0.1, 109 | return_pil_image: bool = False, 110 | ): 111 | self.base_dataset = base_dataset 112 | self.condition_size = condition_size 113 | self.target_size = target_size 114 | self.condition_type = condition_type 115 | self.drop_text_prob = drop_text_prob 116 | self.drop_image_prob = drop_image_prob 117 | self.return_pil_image = return_pil_image 118 | 119 | self.to_tensor = T.ToTensor() 120 | 121 | def __len__(self): 122 | return len(self.base_dataset) 123 | 124 | @property 125 | def depth_pipe(self): 126 | if not hasattr(self, "_depth_pipe"): 127 | from transformers import pipeline 128 | 129 | self._depth_pipe = pipeline( 130 | task="depth-estimation", 131 | model="LiheYoung/depth-anything-small-hf", 132 | device="cpu", 133 | ) 134 | return self._depth_pipe 135 | 136 | def _get_canny_edge(self, img): 137 | resize_ratio = self.condition_size / max(img.size) 138 | img = img.resize( 139 | (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio)) 140 | ) 141 | img_np = np.array(img) 142 | img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) 143 | edges = cv2.Canny(img_gray, 100, 200) 144 | return Image.fromarray(edges).convert("RGB") 145 | 146 | def __getitem__(self, idx): 147 | image = self.base_dataset[idx]["jpg"] 148 | image = image.resize((self.target_size, self.target_size)).convert("RGB") 149 | description = self.base_dataset[idx]["json"]["prompt"] 150 | 151 | # Get the condition image 152 | position_delta = np.array([0, 0]) 153 | if self.condition_type == "canny": 154 | condition_img = self._get_canny_edge(image) 155 | elif self.condition_type == "coloring": 156 | condition_img = ( 157 | image.resize((self.condition_size, self.condition_size)) 158 | .convert("L") 159 | .convert("RGB") 160 | ) 161 | elif self.condition_type == "deblurring": 162 | blur_radius = random.randint(1, 10) 163 | condition_img = ( 164 | image.convert("RGB") 165 | .resize((self.condition_size, self.condition_size)) 166 | .filter(ImageFilter.GaussianBlur(blur_radius)) 167 | 
.convert("RGB") 168 | ) 169 | elif self.condition_type == "depth": 170 | condition_img = self.depth_pipe(image)["depth"].convert("RGB") 171 | elif self.condition_type == "depth_pred": 172 | condition_img = image 173 | image = self.depth_pipe(condition_img)["depth"].convert("RGB") 174 | description = f"[depth] {description}" 175 | elif self.condition_type == "fill": 176 | condition_img = image.resize( 177 | (self.condition_size, self.condition_size) 178 | ).convert("RGB") 179 | w, h = image.size 180 | x1, x2 = sorted([random.randint(0, w), random.randint(0, w)]) 181 | y1, y2 = sorted([random.randint(0, h), random.randint(0, h)]) 182 | mask = Image.new("L", image.size, 0) 183 | draw = ImageDraw.Draw(mask) 184 | draw.rectangle([x1, y1, x2, y2], fill=255) 185 | if random.random() > 0.5: 186 | mask = Image.eval(mask, lambda a: 255 - a) 187 | condition_img = Image.composite( 188 | image, Image.new("RGB", image.size, (0, 0, 0)), mask 189 | ) 190 | elif self.condition_type == "sr": 191 | condition_img = image.resize( 192 | (self.condition_size, self.condition_size) 193 | ).convert("RGB") 194 | position_delta = np.array([0, -self.condition_size // 16]) 195 | 196 | else: 197 | raise ValueError(f"Condition type {self.condition_type} not implemented") 198 | 199 | # Randomly drop text or image 200 | drop_text = random.random() < self.drop_text_prob 201 | drop_image = random.random() < self.drop_image_prob 202 | if drop_text: 203 | description = "" 204 | if drop_image: 205 | condition_img = Image.new( 206 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0) 207 | ) 208 | 209 | return { 210 | "image": self.to_tensor(image), 211 | "condition": self.to_tensor(condition_img), 212 | "condition_type": self.condition_type, 213 | "description": description, 214 | "position_delta": position_delta, 215 | **({"pil_image": [image, condition_img]} if self.return_pil_image else {}), 216 | } 217 | 218 | 219 | class CartoonDataset(Dataset): 220 | def __init__( 221 | self, 222 | base_dataset, 223 | condition_size: int = 1024, 224 | target_size: int = 1024, 225 | image_size: int = 1024, 226 | padding: int = 0, 227 | condition_type: str = "cartoon", 228 | drop_text_prob: float = 0.1, 229 | drop_image_prob: float = 0.1, 230 | return_pil_image: bool = False, 231 | ): 232 | self.base_dataset = base_dataset 233 | self.condition_size = condition_size 234 | self.target_size = target_size 235 | self.image_size = image_size 236 | self.padding = padding 237 | self.condition_type = condition_type 238 | self.drop_text_prob = drop_text_prob 239 | self.drop_image_prob = drop_image_prob 240 | self.return_pil_image = return_pil_image 241 | 242 | self.to_tensor = T.ToTensor() 243 | 244 | 245 | def __len__(self): 246 | return len(self.base_dataset) 247 | 248 | def __getitem__(self, idx): 249 | data = self.base_dataset[idx] 250 | condition_img = data['condition'] 251 | target_image = data['target'] 252 | 253 | # Tag 254 | tag = data['tags'][0] 255 | 256 | target_description = data['target_description'] 257 | 258 | description = { 259 | "lion": "lion like animal", 260 | "bear": "bear like animal", 261 | "gorilla": "gorilla like animal", 262 | "dog": "dog like animal", 263 | "elephant": "elephant like animal", 264 | "eagle": "eagle like bird", 265 | "tiger": "tiger like animal", 266 | "owl": "owl like bird", 267 | "woman": "woman", 268 | "parrot": "parrot like bird", 269 | "mouse": "mouse like animal", 270 | "man": "man", 271 | "pigeon": "pigeon like bird", 272 | "girl": "girl", 273 | "panda": "panda like animal", 274 | 
"crocodile": "crocodile like animal", 275 | "rabbit": "rabbit like animal", 276 | "boy": "boy", 277 | "monkey": "monkey like animal", 278 | "cat": "cat like animal" 279 | } 280 | 281 | # Resize the image 282 | condition_img = condition_img.resize( 283 | (self.condition_size, self.condition_size) 284 | ).convert("RGB") 285 | target_image = target_image.resize( 286 | (self.target_size, self.target_size) 287 | ).convert("RGB") 288 | 289 | # Process datum to create description 290 | description = data.get("description", f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.") 291 | 292 | # Randomly drop text or image 293 | drop_text = random.random() < self.drop_text_prob 294 | drop_image = random.random() < self.drop_image_prob 295 | if drop_text: 296 | description = "" 297 | if drop_image: 298 | condition_img = Image.new( 299 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0) 300 | ) 301 | 302 | 303 | return { 304 | "image": self.to_tensor(target_image), 305 | "condition": self.to_tensor(condition_img), 306 | "condition_type": self.condition_type, 307 | "description": description, 308 | # 16 is the downscale factor of the image 309 | "position_delta": np.array([0, -16]), 310 | } -------------------------------------------------------------------------------- /src/train/model.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from diffusers.pipelines import FluxPipeline 3 | import torch 4 | from peft import LoraConfig, get_peft_model_state_dict 5 | 6 | import prodigyopt 7 | 8 | from ..flux.transformer import tranformer_forward 9 | from ..flux.condition import Condition 10 | from ..flux.pipeline_tools import encode_images, prepare_text_input 11 | 12 | 13 | class OminiModel(L.LightningModule): 14 | def __init__( 15 | self, 16 | flux_pipe_id: str, 17 | lora_path: str = None, 18 | lora_config: dict = None, 19 | device: str = "cuda", 20 | dtype: torch.dtype = torch.bfloat16, 21 | model_config: dict = {}, 22 | optimizer_config: dict = None, 23 | gradient_checkpointing: bool = False, 24 | ): 25 | # Initialize the LightningModule 26 | super().__init__() 27 | self.model_config = model_config 28 | self.optimizer_config = optimizer_config 29 | 30 | # Load the Flux pipeline 31 | self.flux_pipe: FluxPipeline = ( 32 | FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device) 33 | ) 34 | self.transformer = self.flux_pipe.transformer 35 | self.transformer.gradient_checkpointing = gradient_checkpointing 36 | self.transformer.train() 37 | 38 | # Freeze the Flux pipeline 39 | self.flux_pipe.text_encoder.requires_grad_(False).eval() 40 | self.flux_pipe.text_encoder_2.requires_grad_(False).eval() 41 | self.flux_pipe.vae.requires_grad_(False).eval() 42 | 43 | # Initialize LoRA layers 44 | self.lora_layers = self.init_lora(lora_path, lora_config) 45 | 46 | self.to(device).to(dtype) 47 | 48 | def init_lora(self, lora_path: str, lora_config: dict): 49 | assert lora_path or lora_config 50 | if lora_path: 51 | # TODO: Implement this 52 | raise NotImplementedError 53 | else: 54 | self.transformer.add_adapter(LoraConfig(**lora_config)) 55 | # TODO: Check if this is correct (p.requires_grad) 56 | lora_layers = filter( 57 | lambda p: p.requires_grad, self.transformer.parameters() 58 | ) 59 | return list(lora_layers) 60 | 61 | def save_lora(self, path: str): 62 | FluxPipeline.save_lora_weights( 63 | 
save_directory=path, 64 | transformer_lora_layers=get_peft_model_state_dict(self.transformer), 65 | safe_serialization=True, 66 | ) 67 | 68 | def configure_optimizers(self): 69 | # Freeze the transformer 70 | self.transformer.requires_grad_(False) 71 | opt_config = self.optimizer_config 72 | 73 | # Set the trainable parameters 74 | self.trainable_params = self.lora_layers 75 | 76 | # Unfreeze trainable parameters 77 | for p in self.trainable_params: 78 | p.requires_grad_(True) 79 | 80 | # Initialize the optimizer 81 | if opt_config["type"] == "AdamW": 82 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"]) 83 | elif opt_config["type"] == "Prodigy": 84 | optimizer = prodigyopt.Prodigy( 85 | self.trainable_params, 86 | **opt_config["params"], 87 | ) 88 | elif opt_config["type"] == "SGD": 89 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"]) 90 | else: 91 | raise NotImplementedError 92 | 93 | return optimizer 94 | 95 | def training_step(self, batch, batch_idx): 96 | step_loss = self.step(batch) 97 | self.log_loss = ( 98 | step_loss.item() 99 | if not hasattr(self, "log_loss") 100 | else self.log_loss * 0.95 + step_loss.item() * 0.05 101 | ) 102 | return step_loss 103 | 104 | def step(self, batch): 105 | imgs = batch["image"] 106 | conditions = batch["condition"] 107 | condition_types = batch["condition_type"] 108 | prompts = batch["description"] 109 | position_delta = batch["position_delta"][0] 110 | 111 | # Prepare inputs 112 | with torch.no_grad(): 113 | # Prepare image input 114 | x_0, img_ids = encode_images(self.flux_pipe, imgs) 115 | 116 | # Prepare text input 117 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input( 118 | self.flux_pipe, prompts 119 | ) 120 | 121 | # Prepare t and x_t 122 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) 123 | x_1 = torch.randn_like(x_0).to(self.device) 124 | t_ = t.unsqueeze(1).unsqueeze(1) 125 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) 126 | 127 | # Prepare conditions 128 | condition_latents, condition_ids = encode_images(self.flux_pipe, conditions) 129 | 130 | # Add position delta 131 | condition_ids[:, 1] += position_delta[0] 132 | condition_ids[:, 2] += position_delta[1] 133 | 134 | # Prepare condition type 135 | condition_type_ids = torch.tensor( 136 | [ 137 | Condition.get_type_id(condition_type) 138 | for condition_type in condition_types 139 | ] 140 | ).to(self.device) 141 | condition_type_ids = ( 142 | torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0] 143 | ).unsqueeze(1) 144 | 145 | # Prepare guidance 146 | guidance = ( 147 | torch.ones_like(t).to(self.device) 148 | if self.transformer.config.guidance_embeds 149 | else None 150 | ) 151 | 152 | # Forward pass 153 | transformer_out = tranformer_forward( 154 | self.transformer, 155 | # Model config 156 | model_config=self.model_config, 157 | # Inputs of the condition (new feature) 158 | condition_latents=condition_latents, 159 | condition_ids=condition_ids, 160 | condition_type_ids=condition_type_ids, 161 | # Inputs to the original transformer 162 | hidden_states=x_t, 163 | timestep=t, 164 | guidance=guidance, 165 | pooled_projections=pooled_prompt_embeds, 166 | encoder_hidden_states=prompt_embeds, 167 | txt_ids=text_ids, 168 | img_ids=img_ids, 169 | joint_attention_kwargs=None, 170 | return_dict=False, 171 | ) 172 | pred = transformer_out[0] 173 | 174 | # Compute loss 175 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean") 176 | self.last_t = t.mean().item() 177 | 
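        # Note: the regression target (x_1 - x_0) used below is the constant velocity of the
        # linear interpolation x_t = (1 - t) * x_0 + t * x_1 (x_0: clean image latents, x_1: noise),
        # i.e. a rectified-flow / flow-matching objective. Timesteps t are drawn from a
        # logit-normal distribution via sigmoid(randn).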
return loss 178 | -------------------------------------------------------------------------------- /src/train/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import lightning as L 4 | import yaml 5 | import os 6 | import time 7 | 8 | from datasets import load_dataset 9 | 10 | from .data import ( 11 | ImageConditionDataset, 12 | Subject200KDataset, 13 | CartoonDataset 14 | ) 15 | from .model import OminiModel 16 | from .callbacks import TrainingCallback 17 | 18 | 19 | def get_rank(): 20 | try: 21 | rank = int(os.environ.get("LOCAL_RANK")) 22 | except: 23 | rank = 0 24 | return rank 25 | 26 | 27 | def get_config(): 28 | config_path = os.environ.get("XFL_CONFIG") 29 | assert config_path is not None, "Please set the XFL_CONFIG environment variable" 30 | with open(config_path, "r") as f: 31 | config = yaml.safe_load(f) 32 | return config 33 | 34 | 35 | def init_wandb(wandb_config, run_name): 36 | import wandb 37 | 38 | try: 39 | assert os.environ.get("WANDB_API_KEY") is not None 40 | wandb.init( 41 | project=wandb_config["project"], 42 | name=run_name, 43 | config={}, 44 | ) 45 | except Exception as e: 46 | print("Failed to initialize WanDB:", e) 47 | 48 | 49 | def main(): 50 | # Initialize 51 | is_main_process, rank = get_rank() == 0, get_rank() 52 | torch.cuda.set_device(rank) 53 | config = get_config() 54 | training_config = config["train"] 55 | run_name = time.strftime("%Y%m%d-%H%M%S") 56 | 57 | # Initialize WanDB 58 | wandb_config = training_config.get("wandb", None) 59 | if wandb_config is not None and is_main_process: 60 | init_wandb(wandb_config, run_name) 61 | 62 | print("Rank:", rank) 63 | if is_main_process: 64 | print("Config:", config) 65 | 66 | # Initialize dataset and dataloader 67 | if training_config["dataset"]["type"] == "subject": 68 | dataset = load_dataset("Yuanshi/Subjects200K") 69 | 70 | # Define filter function 71 | def filter_func(item): 72 | if not item.get("quality_assessment"): 73 | return False 74 | return all( 75 | item["quality_assessment"].get(key, 0) >= 5 76 | for key in ["compositeStructure", "objectConsistency", "imageQuality"] 77 | ) 78 | 79 | # Filter dataset 80 | if not os.path.exists("./cache/dataset"): 81 | os.makedirs("./cache/dataset") 82 | data_valid = dataset["train"].filter( 83 | filter_func, 84 | num_proc=16, 85 | cache_file_name="./cache/dataset/data_valid.arrow", 86 | ) 87 | dataset = Subject200KDataset( 88 | data_valid, 89 | condition_size=training_config["dataset"]["condition_size"], 90 | target_size=training_config["dataset"]["target_size"], 91 | image_size=training_config["dataset"]["image_size"], 92 | padding=training_config["dataset"]["padding"], 93 | condition_type=training_config["condition_type"], 94 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 95 | drop_image_prob=training_config["dataset"]["drop_image_prob"], 96 | ) 97 | elif training_config["dataset"]["type"] == "img": 98 | # Load dataset text-to-image-2M 99 | dataset = load_dataset( 100 | "webdataset", 101 | data_files={"train": training_config["dataset"]["urls"]}, 102 | split="train", 103 | cache_dir="cache/t2i2m", 104 | num_proc=32, 105 | ) 106 | dataset = ImageConditionDataset( 107 | dataset, 108 | condition_size=training_config["dataset"]["condition_size"], 109 | target_size=training_config["dataset"]["target_size"], 110 | condition_type=training_config["condition_type"], 111 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 112 | 
drop_image_prob=training_config["dataset"]["drop_image_prob"], 113 | ) 114 | elif training_config["dataset"]["type"] == "cartoon": 115 | dataset = load_dataset("saquiboye/oye-cartoon", split="train") 116 | dataset = CartoonDataset( 117 | dataset, 118 | condition_size=training_config["dataset"]["condition_size"], 119 | target_size=training_config["dataset"]["target_size"], 120 | image_size=training_config["dataset"]["image_size"], 121 | padding=training_config["dataset"]["padding"], 122 | condition_type=training_config["condition_type"], 123 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 124 | drop_image_prob=training_config["dataset"]["drop_image_prob"], 125 | ) 126 | else: 127 | raise NotImplementedError 128 | 129 | print("Dataset length:", len(dataset)) 130 | train_loader = DataLoader( 131 | dataset, 132 | batch_size=training_config["batch_size"], 133 | shuffle=True, 134 | num_workers=training_config["dataloader_workers"], 135 | ) 136 | 137 | # Initialize model 138 | trainable_model = OminiModel( 139 | flux_pipe_id=config["flux_path"], 140 | lora_config=training_config["lora_config"], 141 | device=f"cuda", 142 | dtype=getattr(torch, config["dtype"]), 143 | optimizer_config=training_config["optimizer"], 144 | model_config=config.get("model", {}), 145 | gradient_checkpointing=training_config.get("gradient_checkpointing", False), 146 | ) 147 | 148 | # Callbacks for logging and saving checkpoints 149 | training_callbacks = ( 150 | [TrainingCallback(run_name, training_config=training_config)] 151 | if is_main_process 152 | else [] 153 | ) 154 | 155 | # Initialize trainer 156 | trainer = L.Trainer( 157 | accumulate_grad_batches=training_config["accumulate_grad_batches"], 158 | callbacks=training_callbacks, 159 | enable_checkpointing=False, 160 | enable_progress_bar=False, 161 | logger=False, 162 | max_steps=training_config.get("max_steps", -1), 163 | max_epochs=training_config.get("max_epochs", -1), 164 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5), 165 | ) 166 | 167 | setattr(trainer, "training_config", training_config) 168 | 169 | # Save config 170 | save_path = training_config.get("save_path", "./output") 171 | if is_main_process: 172 | os.makedirs(f"{save_path}/{run_name}") 173 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f: 174 | yaml.dump(config, f) 175 | 176 | # Start training 177 | trainer.fit(trainable_model, train_loader) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # OminiControl Training 🛠️ 2 | 3 | ## Preparation 4 | 5 | ### Setup 6 | 1. **Environment** 7 | ```bash 8 | conda create -n omini python=3.10 9 | conda activate omini 10 | ``` 11 | 2. **Requirements** 12 | ```bash 13 | pip install -r train/requirements.txt 14 | ``` 15 | 16 | ### Dataset 17 | 1. Download dataset [Subject200K](https://huggingface.co/datasets/Yuanshi/Subjects200K). (**subject-driven generation**) 18 | ``` 19 | bash train/script/data_download/data_download1.sh 20 | ``` 21 | 2. Download dataset [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M). (**spatial control task**) 22 | ``` 23 | bash train/script/data_download/data_download2.sh 24 | ``` 25 | **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. 
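   For reference, the downloaded shards are referenced from the training config as a list of `urls` under the `dataset` section. The provided `train/config/canny_512.yaml` points at two text-to-image-2M shards like this; add or replace the `urls` entries to match the shards you actually downloaded:
   ```yaml
   train:
     dataset:
       type: "img"
       urls:
         - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
         - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
   ```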
Remember to update the config file to specify the training data accordingly. 26 | 27 | ## Training 28 | 29 | ### Start training 30 | **Config file path**: `./train/config` 31 | 32 | **Scripts path**: `./train/script` 33 | 34 | 1. Subject-driven generation 35 | ```bash 36 | bash train/script/train_subject.sh 37 | ``` 38 | 2. Spatial control task 39 | ```bash 40 | bash train/script/train_canny.sh 41 | ``` 42 | 43 | **Note**: Detailed WandB and GPU settings can be found in the script files and the config files. 44 | 45 | ### Other spatial control tasks 46 | This repository supports 7 spatial control tasks: 47 | 1. Canny edge to image (`canny`) 48 | 2. Image colorization (`coloring`) 49 | 3. Image deblurring (`deblurring`) 50 | 4. Depth map to image (`depth`) 51 | 5. Image to depth map (`depth_pred`) 52 | 6. Image inpainting (`fill`) 53 | 7. Super resolution (`sr`) 54 | 55 | You can modify the `condition_type` parameter in the config file `config/canny_512.yaml` to switch between these tasks. 56 | 57 | ### Customize your own task 58 | You can customize your own task by constructing a new dataset and modifying the training code. 59 | 60 |
61 | **Instructions** 62 | 63 | 1. **Dataset**: 64 | 65 | Construct a new dataset with the following format: (`src/train/data.py`) 66 | ```python 67 | class MyDataset(Dataset): 68 | def __init__(self, ...): 69 | ... 70 | def __len__(self): 71 | ... 72 | def __getitem__(self, idx): 73 | ... 74 | return { 75 | "image": image, 76 | "condition": condition_img, 77 | "condition_type": "your_condition_type", 78 | "description": description, 79 | "position_delta": position_delta 80 | } 81 | ``` 82 | **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set `position_delta` to `[0, -condition_width // 16]`. 83 | 2. **Condition**: 84 | 85 | Add a new condition type in the `Condition` class. (`src/flux/condition.py`) 86 | ```python 87 | condition_dict = { 88 | ... 89 | "your_condition_type": your_condition_id_number, # Add your condition type here 90 | } 91 | ... 92 | if condition_type in [ 93 | ... 94 | "your_condition_type", # Add your condition type here 95 | ]: 96 | ... 97 | ``` 98 | 3. **Test**: 99 | 100 | Add a new test case for your task. (`src/train/callbacks.py`) 101 | ```python 102 | if self.condition_type == "your_condition_type": 103 | condition_img = ( 104 | Image.open("assets/vase.jpg") 105 | .resize((condition_size, condition_size)) 106 | .convert("RGB") 107 | ) 108 | ... 109 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table.")) 110 | ``` 111 | 112 | 4. **Import the relevant dataset in the training script** 113 | Update the imports in the following section (`src/train/train.py`). A fuller, self-contained sketch of a custom dataset class is given right after these instructions. 114 | ```python 115 | from .data import ( 116 | ImageConditionDataset, 117 | Subject200KDataset, 118 | MyDataset 119 | ) 120 | ... 121 | 122 | # Initialize dataset and dataloader 123 | if training_config["dataset"]["type"] == "your_condition_type": 124 | ... 125 | ``` 126 | 127 |
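As referenced in step 4 above, here is a fuller sketch of what a custom dataset might look like, following the same conventions as the datasets in `src/train/data.py` (fixed-size resizing, random text/image dropping, and a `position_delta` entry). The class name `MyDataset` and the keys used to read the underlying samples (`item["image"]`, `item["caption"]`) are illustrative assumptions, not part of this repository:

```python
import random

import numpy as np
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(
        self,
        base_dataset,
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "your_condition_type",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
    ):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # How a raw sample is read depends on your dataset; these keys are assumptions.
        item = self.base_dataset[idx]
        image = item["image"].resize((self.target_size, self.target_size)).convert("RGB")
        description = item["caption"]

        # Build the condition image however your task requires (identity here, as a placeholder).
        condition_img = image.resize((self.condition_size, self.condition_size)).convert("RGB")

        # Randomly drop the text or the condition image, as the built-in datasets do.
        if random.random() < self.drop_text_prob:
            description = ""
        if random.random() < self.drop_image_prob:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": self.condition_type,
            "description": description,
            # [0, 0] for spatial control tasks;
            # [0, -self.condition_size // 16] for non-spatial (e.g. subject-driven) tasks.
            "position_delta": np.array([0, 0]),
        }
```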
128 | 129 | ## Hardware requirements 130 | **Note**: Memory optimization (e.g. dynamic T5 model loading) is pending implementation. 131 | 132 | **Recommended** 133 | - Hardware: 2x NVIDIA H100 GPUs 134 | - Memory: ~80GB GPU memory 135 | 136 | **Minimal** 137 | - Hardware: 1x NVIDIA L20 GPU 138 | - Memory: ~48GB GPU memory -------------------------------------------------------------------------------- /train/config/canny_512.yaml: -------------------------------------------------------------------------------- 1 | flux_path: "black-forest-labs/FLUX.1-dev" 2 | dtype: "bfloat16" 3 | 4 | model: 5 | union_cond_attn: true 6 | add_cond_attn: false 7 | latent_lora: false 8 | 9 | train: 10 | batch_size: 1 11 | accumulate_grad_batches: 1 12 | dataloader_workers: 5 13 | save_interval: 1000 14 | sample_interval: 100 15 | max_steps: -1 16 | gradient_checkpointing: true 17 | save_path: "runs" 18 | 19 | # Specify the type of condition to use. 20 | # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"] 21 | condition_type: "canny" 22 | dataset: 23 | type: "img" 24 | urls: 25 | - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar" 26 | - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar" 27 | cache_name: "data_512_2M" 28 | condition_size: 512 29 | target_size: 512 30 | drop_text_prob: 0.1 31 | drop_image_prob: 0.1 32 | 33 | wandb: 34 | project: "OminiControl" 35 | 36 | lora_config: 37 | r: 4 38 | lora_alpha: 4 39 | init_lora_weights: "gaussian" 40 | target_modules: "(.*x_embedder|.*(?