├── LICENSE
├── README.md
├── assets
│   ├── book.jpg
│   ├── cartoon_boy.png
│   ├── clock.jpg
│   ├── coffee.png
│   ├── demo
│   │   ├── book_omini.jpg
│   │   ├── clock_omini.jpg
│   │   ├── demo_this_is_omini_control.jpg
│   │   ├── dreambooth_res.jpg
│   │   ├── man_omini.jpg
│   │   ├── monalisa_omini.jpg
│   │   ├── oranges_omini.jpg
│   │   ├── panda_omini.jpg
│   │   ├── penguin_omini.jpg
│   │   ├── rc_car_omini.jpg
│   │   ├── room_corner_canny.jpg
│   │   ├── room_corner_coloring.jpg
│   │   ├── room_corner_deblurring.jpg
│   │   ├── room_corner_depth.jpg
│   │   ├── scene_variation.jpg
│   │   ├── shirt_omini.jpg
│   │   └── try_on.jpg
│   ├── monalisa.jpg
│   ├── oranges.jpg
│   ├── penguin.jpg
│   ├── rc_car.jpg
│   ├── room_corner.jpg
│   ├── test_in.jpg
│   ├── test_out.jpg
│   ├── tshirt.jpg
│   ├── vase.jpg
│   └── vase_hq.jpg
├── examples
│   ├── inpainting.ipynb
│   ├── spatial.ipynb
│   ├── subject.ipynb
│   └── subject_1024.ipynb
├── gradio_app.py
├── requirements.txt
├── src
│   ├── flux
│   │   ├── block.py
│   │   ├── condition.py
│   │   ├── generate.py
│   │   ├── lora_controller.py
│   │   ├── pipeline_tools.py
│   │   └── transformer.py
│   ├── gradio
│   │   └── gradio_app.py
│   └── train
│       ├── callbacks.py
│       ├── data.py
│       ├── model.py
│       └── train.py
└── train
    ├── README.md
    ├── config
    │   ├── canny_512.yaml
    │   ├── cartoon_512.yaml
    │   ├── fill_1024.yaml
    │   ├── sr_512.yaml
    │   └── subject_512.yaml
    ├── requirements.txt
    └── script
        ├── data_download
        │   ├── data_download1.sh
        │   └── data_download2.sh
        ├── train_canny.sh
        ├── train_cartoon.sh
        └── train_subject.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2024] [Zhenxiong Tan]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OminiControl
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | > **OminiControl: Minimal and Universal Control for Diffusion Transformer**
14 | >
15 | > Zhenxiong Tan,
16 | > [Songhua Liu](http://121.37.94.87/),
17 | > [Xingyi Yang](https://adamdad.github.io/),
18 | > Qiaochu Xue,
19 | > and
20 | > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
21 | >
22 | > [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
23 | >
24 |
25 |
26 | ## Features
27 |
28 |
29 | OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
30 |
31 | * **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
32 |
33 | * **Minimal Design 🚀**: It injects control signals while preserving the original model structure, introducing only 0.1% additional parameters to the base model.
34 |
35 | ## OminiControlGP (OminiControl for the GPU Poor) by DeepBeepMeep
36 |
37 | Thanks to a single added line that loads the 'mmgp' module (https://github.com/deepbeepmeep/mmgp), OminiControl can generate images from a derived Flux model in less than 6s with 16 GB of VRAM (profile 1), in 9s with 8 GB of VRAM (profile 4), or in 16s with less than 6 GB of VRAM (profile 5).
38 |
39 | To run the Gradio app with profile 3 (the default and fastest profile, but it requires the most VRAM):
40 | ```bash
41 | python gradio_app.py --profile 3
42 | ```
43 | To run the Gradio app with profile 5 (a bit slower, but it requires only 6 GB of VRAM):
44 | ```bash
45 | python gradio_app.py --profile 5
46 | ```
47 |
48 | You may check the mmgp git homepage if you want to design your own profiles (for instance to disable quantization).
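For reference, the whole integration boils down to wrapping the loaded pipeline with `offload.profile` (a minimal sketch condensed from `gradio_app.py`; the profile number below is just an example):

```python
import torch
from diffusers.pipelines import FluxPipeline
from mmgp import offload  # https://github.com/deepbeepmeep/mmgp

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cpu")  # keep weights on the CPU; mmgp moves layers to the GPU as needed
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# The single mmgp line: choose the profile that matches your VRAM budget (see above).
offload.profile(pipe, profile_no=5, quantizeTransformer=False)
```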
49 |
50 | If you enjoy this application, you will certainly appreciate these as well:\
51 | - Hunyuan3D-2GP: https://github.com/deepbeepmeep/Hunyuan3D-2GP\
52 | A great image-to-3D and text-to-3D tool by the Tencent team. Thanks to mmgp, it can run with less than 6 GB of VRAM.
53 | 
54 | - HunyuanVideoGP: https://github.com/deepbeepmeep/HunyuanVideoGP\
55 | One of the best open source text-to-video generators.
56 |
57 | - FluxFillGP: https://github.com/deepbeepmeep/FluxFillGP\
58 | One of the best inpainting / outpainting tools based on Flux that can run with less than 12 GB of VRAM.
59 |
60 | - Cosmos1GP: https://github.com/deepbeepmeep/Cosmos1GP\
61 | This application includes two models: a text-to-world generator and an image/video-to-world generator (probably the best open source image-to-video generator).
62 |
63 | ## News
64 | - **2025-01-25**: ⭐️ DeepBeepMeep fork: added support for mmgp
65 | - **2024-12-26**: ⭐️ Training code is released. Now you can create your own OminiControl model by customizing any control task (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
66 |
67 | ## Quick Start
68 | ### Setup (Optional)
69 | 1. **Environment setup**
70 | ```bash
71 | conda create -n omini python=3.10
72 | conda activate omini
73 | ```
74 | 2. **Requirements installation**
75 | ```bash
76 | pip install -r requirements.txt
77 | ```
78 | ### Usage example
79 | 1. Subject-driven generation: `examples/subject.ipynb` (a condensed script version is sketched below)
80 | 2. In-painting: `examples/inpainting.ipynb`
81 | 3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
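If you prefer a plain script to the notebooks, the subject-driven flow in `examples/subject.ipynb` condenses to roughly the following sketch (the image path and prompt are just examples):

```python
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

# Load the base model and the OminiControl subject LoRA.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="omini/subject_512.safetensors",
    adapter_name="subject",
)

# Condition image (512x512), referred to in the prompt as "this item".
image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
condition = Condition("subject", image, position_delta=(0, 32))

seed_everything(0)
result = generate(
    pipe,
    prompt="On Christmas evening, this item sits on the road, covered in snow and wearing a Christmas hat.",
    conditions=[condition],
    num_inference_steps=8,
    height=512,
    width=512,
).images[0]
result.save("subject_result.jpg")
```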
82 |
83 |
84 |
85 | ### Guidelines for subject-driven generation
86 | 1. Input images are automatically center-cropped and resized to 512x512 resolution (see the preprocessing sketch after this list).
87 | 2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`. e.g.
88 | 1. *A close up view of this item. It is placed on a wooden table.*
89 | 2. *A young lady is wearing this shirt.*
90 | 3. The model currently works primarily with objects rather than human subjects, due to the absence of human data in training.
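For clarity, the automatic preprocessing mentioned in point 1 is equivalent to the center crop below (a sketch that mirrors `process_image_and_text` in `gradio_app.py`; the helper name is illustrative):

```python
from PIL import Image

def center_crop_512(image: Image.Image) -> Image.Image:
    # Crop the largest centered square, then resize to the 512x512 model input.
    w, h = image.size
    m = min(w, h)
    image = image.crop(((w - m) // 2, (h - m) // 2, (w + m) // 2, (h + m) // 2))
    return image.resize((512, 512))

condition_image = center_crop_512(Image.open("assets/tshirt.jpg").convert("RGB"))
```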
91 |
92 | ## Generated samples
93 | ### Subject-driven generation
94 |
95 |
96 | **Demos** (Left: condition image; Right: generated image)
97 |
98 |
99 |

100 |

101 |

102 |

103 |
104 |
105 |
106 | Text Prompts
107 |
108 | - Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
109 | - Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
110 | - Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
111 | - Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."*
112 |
113 |
114 | More results
115 |
116 | * Try on:
117 |
118 | * Scene variations:
119 |
120 | * Dreambooth dataset:
121 |
122 | * Oye-cartoon finetune:
123 |
124 |

125 |

126 |
127 |
128 |
129 | ### Spatially aligned control
130 | 1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
131 | - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
132 |
133 |
134 | - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
135 |
136 |
137 | 2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
138 |
139 |
140 | Click to show
141 |
142 |

143 |

144 |

145 |

146 |
147 |
148 | Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
149 |
150 |
151 |
152 |
153 |
154 | ## Models
155 |
156 | **Subject-driven control:**
157 | | Model | Base model | Description | Resolution |
158 | | ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- | ------------ |
159 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
160 | | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
161 | | [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. (To be released) | (1024, 1024) |
162 | | [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) |
163 |
164 | **Spatial aligned control:**
165 | | Model | Base model | Description | Resolution |
166 | | --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
167 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |
168 | | [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `_1024` | FLUX.1 | Supports higher resolution. (To be released) | (1024, 1024) |
169 |
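Each `weight_name` in the tables above maps directly to a `load_lora_weights` call. As an example, the spatially aligned Canny model can be driven like this (a condensed sketch of `examples/spatial.ipynb`; the prompt is just an example, and the example notebooks pair the spatial adapters with `FLUX.1-dev`):

```python
import torch
from PIL import Image
from diffusers.pipelines import FluxPipeline

from src.flux.condition import Condition
from src.flux.generate import generate, seed_everything

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
    "Yuanshi/OminiControl",
    weight_name="experimental/canny.safetensors",
    adapter_name="canny",
)

# `Condition` handles the task-specific preprocessing of the reference image.
image = Image.open("assets/coffee.png").convert("RGB").resize((512, 512))
condition = Condition("canny", image)

seed_everything()
result = generate(
    pipe,
    prompt="In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.",
    conditions=[condition],
).images[0]
result.save("canny_result.jpg")
```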
170 | ## Community Extensions
171 | - [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
172 | - [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
173 |
174 | ## Limitations
175 | 1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training.
176 | 2. The subject-driven generation model may not work well with `FLUX.1-dev`.
177 | 3. The released model currently supports only a resolution of 512x512.
178 |
179 | ## Training
180 | Training instructions can be found in this [folder](./train).
181 |
182 |
183 | ## To-do
184 | - [x] Release the training code.
185 | - [ ] Release the model for higher resolution (1024x1024).
186 |
187 | ## Citation
188 | ```
189 | @article{
190 | tan2024omini,
191 | title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
192 | author={Zhenxiong Tan, Songhua Liu, Xingyi Yang, Qiaochu Xue, and Xinchao Wang},
193 | journal={arXiv preprint arXiv:2411.15098},
194 | year={2024}
195 | }
196 | ```
197 |
--------------------------------------------------------------------------------
/assets/book.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/book.jpg
--------------------------------------------------------------------------------
/assets/cartoon_boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/cartoon_boy.png
--------------------------------------------------------------------------------
/assets/clock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/clock.jpg
--------------------------------------------------------------------------------
/assets/coffee.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/coffee.png
--------------------------------------------------------------------------------
/assets/demo/book_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/book_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/clock_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/clock_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/demo_this_is_omini_control.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/demo_this_is_omini_control.jpg
--------------------------------------------------------------------------------
/assets/demo/dreambooth_res.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/dreambooth_res.jpg
--------------------------------------------------------------------------------
/assets/demo/man_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/man_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/monalisa_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/monalisa_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/oranges_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/oranges_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/panda_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/panda_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/penguin_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/penguin_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/rc_car_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/rc_car_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/room_corner_canny.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_canny.jpg
--------------------------------------------------------------------------------
/assets/demo/room_corner_coloring.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_coloring.jpg
--------------------------------------------------------------------------------
/assets/demo/room_corner_deblurring.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_deblurring.jpg
--------------------------------------------------------------------------------
/assets/demo/room_corner_depth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/room_corner_depth.jpg
--------------------------------------------------------------------------------
/assets/demo/scene_variation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/scene_variation.jpg
--------------------------------------------------------------------------------
/assets/demo/shirt_omini.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/shirt_omini.jpg
--------------------------------------------------------------------------------
/assets/demo/try_on.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/demo/try_on.jpg
--------------------------------------------------------------------------------
/assets/monalisa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/monalisa.jpg
--------------------------------------------------------------------------------
/assets/oranges.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/oranges.jpg
--------------------------------------------------------------------------------
/assets/penguin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/penguin.jpg
--------------------------------------------------------------------------------
/assets/rc_car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/rc_car.jpg
--------------------------------------------------------------------------------
/assets/room_corner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/room_corner.jpg
--------------------------------------------------------------------------------
/assets/test_in.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/test_in.jpg
--------------------------------------------------------------------------------
/assets/test_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/test_out.jpg
--------------------------------------------------------------------------------
/assets/tshirt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/tshirt.jpg
--------------------------------------------------------------------------------
/assets/vase.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/vase.jpg
--------------------------------------------------------------------------------
/assets/vase_hq.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepbeepmeep/OminiControlGP/7f5cab917f0a816d3e8b9306f5f0831eba31d337/assets/vase_hq.jpg
--------------------------------------------------------------------------------
/examples/inpainting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "\n",
11 | "os.chdir(\"..\")"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import torch\n",
21 | "from diffusers.pipelines import FluxPipeline\n",
22 | "from src.flux.condition import Condition\n",
23 | "from PIL import Image\n",
24 | "\n",
25 | "from src.flux.generate import generate, seed_everything"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "pipe = FluxPipeline.from_pretrained(\n",
35 | " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36 | ")\n",
37 | "pipe = pipe.to(\"cuda\")"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "pipe.load_lora_weights(\n",
47 | " \"Yuanshi/OminiControl\",\n",
48 | " weight_name=f\"experimental/fill.safetensors\",\n",
49 | " adapter_name=\"fill\",\n",
50 | ")"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
60 | "\n",
61 | "masked_image = image.copy()\n",
62 | "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
63 | "\n",
64 | "condition = Condition(\"fill\", masked_image)\n",
65 | "\n",
66 | "seed_everything()\n",
67 | "result_img = generate(\n",
68 | " pipe,\n",
69 | " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
70 | " conditions=[condition],\n",
71 | ").images[0]\n",
72 | "\n",
73 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
74 | "concat_image.paste(image, (0, 0))\n",
75 | "concat_image.paste(condition.condition, (512, 0))\n",
76 | "concat_image.paste(result_img, (1024, 0))\n",
77 | "concat_image"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
87 | "\n",
88 | "w, h, min_dim = image.size + (min(image.size),)\n",
89 | "image = image.crop(\n",
90 | " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
91 | ").resize((512, 512))\n",
92 | "\n",
93 | "\n",
94 | "masked_image = image.copy()\n",
95 | "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
96 | "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
97 | "\n",
98 | "condition = Condition(\"fill\", masked_image)\n",
99 | "\n",
100 | "seed_everything()\n",
101 | "result_img = generate(\n",
102 | " pipe,\n",
103 | " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
104 | " conditions=[condition],\n",
105 | ").images[0]\n",
106 | "\n",
107 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
108 | "concat_image.paste(image, (0, 0))\n",
109 | "concat_image.paste(condition.condition, (512, 0))\n",
110 | "concat_image.paste(result_img, (1024, 0))\n",
111 | "concat_image"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": []
120 | }
121 | ],
122 | "metadata": {
123 | "kernelspec": {
124 | "display_name": "base",
125 | "language": "python",
126 | "name": "python3"
127 | },
128 | "language_info": {
129 | "codemirror_mode": {
130 | "name": "ipython",
131 | "version": 3
132 | },
133 | "file_extension": ".py",
134 | "mimetype": "text/x-python",
135 | "name": "python",
136 | "nbconvert_exporter": "python",
137 | "pygments_lexer": "ipython3",
138 | "version": "3.12.7"
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 2
143 | }
144 |
--------------------------------------------------------------------------------
/examples/spatial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "\n",
11 | "os.chdir(\"..\")"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import torch\n",
21 | "from diffusers.pipelines import FluxPipeline\n",
22 | "from src.flux.condition import Condition\n",
23 | "from PIL import Image\n",
24 | "\n",
25 | "from src.flux.generate import generate, seed_everything"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "pipe = FluxPipeline.from_pretrained(\n",
35 | " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
36 | ")\n",
37 | "pipe = pipe.to(\"cuda\")"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
47 | " pipe.load_lora_weights(\n",
48 | " \"Yuanshi/OminiControl\",\n",
49 | " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
50 | " adapter_name=condition_type,\n",
51 | " )"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
61 | "\n",
62 | "w, h, min_dim = image.size + (min(image.size),)\n",
63 | "image = image.crop(\n",
64 | " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
65 | ").resize((512, 512))\n",
66 | "\n",
67 | "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "condition = Condition(\"canny\", image)\n",
77 | "\n",
78 | "seed_everything()\n",
79 | "\n",
80 | "result_img = generate(\n",
81 | " pipe,\n",
82 | " prompt=prompt,\n",
83 | " conditions=[condition],\n",
84 | ").images[0]\n",
85 | "\n",
86 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
87 | "concat_image.paste(image, (0, 0))\n",
88 | "concat_image.paste(condition.condition, (512, 0))\n",
89 | "concat_image.paste(result_img, (1024, 0))\n",
90 | "concat_image"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "condition = Condition(\"depth\", image)\n",
100 | "\n",
101 | "seed_everything()\n",
102 | "\n",
103 | "result_img = generate(\n",
104 | " pipe,\n",
105 | " prompt=prompt,\n",
106 | " conditions=[condition],\n",
107 | ").images[0]\n",
108 | "\n",
109 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
110 | "concat_image.paste(image, (0, 0))\n",
111 | "concat_image.paste(condition.condition, (512, 0))\n",
112 | "concat_image.paste(result_img, (1024, 0))\n",
113 | "concat_image"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "condition = Condition(\"deblurring\", image)\n",
123 | "\n",
124 | "seed_everything()\n",
125 | "\n",
126 | "result_img = generate(\n",
127 | " pipe,\n",
128 | " prompt=prompt,\n",
129 | " conditions=[condition],\n",
130 | ").images[0]\n",
131 | "\n",
132 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
133 | "concat_image.paste(image, (0, 0))\n",
134 | "concat_image.paste(condition.condition, (512, 0))\n",
135 | "concat_image.paste(result_img, (1024, 0))\n",
136 | "concat_image"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "condition = Condition(\"coloring\", image)\n",
146 | "\n",
147 | "seed_everything()\n",
148 | "\n",
149 | "result_img = generate(\n",
150 | " pipe,\n",
151 | " prompt=prompt,\n",
152 | " conditions=[condition],\n",
153 | ").images[0]\n",
154 | "\n",
155 | "concat_image = Image.new(\"RGB\", (1536, 512))\n",
156 | "concat_image.paste(image, (0, 0))\n",
157 | "concat_image.paste(condition.condition, (512, 0))\n",
158 | "concat_image.paste(result_img, (1024, 0))\n",
159 | "concat_image"
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "base",
166 | "language": "python",
167 | "name": "python3"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 3
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython3",
179 | "version": "3.12.7"
180 | }
181 | },
182 | "nbformat": 4,
183 | "nbformat_minor": 2
184 | }
185 |
--------------------------------------------------------------------------------
/examples/subject.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "\n",
11 | "os.chdir(\"..\")"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import torch\n",
21 | "from diffusers.pipelines import FluxPipeline\n",
22 | "from src.flux.condition import Condition\n",
23 | "from PIL import Image\n",
24 | "\n",
25 | "from src.flux.generate import generate, seed_everything"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "pipe = FluxPipeline.from_pretrained(\n",
35 | " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36 | ")\n",
37 | "pipe = pipe.to(\"cuda\")\n",
38 | "pipe.load_lora_weights(\n",
39 | " \"Yuanshi/OminiControl\",\n",
40 | " weight_name=f\"omini/subject_512.safetensors\",\n",
41 | " adapter_name=\"subject\",\n",
42 | ")"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
52 | "\n",
53 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
54 | "\n",
55 | "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
56 | "\n",
57 | "\n",
58 | "seed_everything(0)\n",
59 | "\n",
60 | "result_img = generate(\n",
61 | " pipe,\n",
62 | " prompt=prompt,\n",
63 | " conditions=[condition],\n",
64 | " num_inference_steps=8,\n",
65 | " height=512,\n",
66 | " width=512,\n",
67 | ").images[0]\n",
68 | "\n",
69 | "concat_image = Image.new(\"RGB\", (1024, 512))\n",
70 | "concat_image.paste(image, (0, 0))\n",
71 | "concat_image.paste(result_img, (512, 0))\n",
72 | "concat_image"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
82 | "\n",
83 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
84 | "\n",
85 | "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
86 | "\n",
87 | "\n",
88 | "seed_everything()\n",
89 | "\n",
90 | "result_img = generate(\n",
91 | " pipe,\n",
92 | " prompt=prompt,\n",
93 | " conditions=[condition],\n",
94 | " num_inference_steps=8,\n",
95 | " height=512,\n",
96 | " width=512,\n",
97 | ").images[0]\n",
98 | "\n",
99 | "concat_image = Image.new(\"RGB\", (1024, 512))\n",
100 | "concat_image.paste(condition.condition, (0, 0))\n",
101 | "concat_image.paste(result_img, (512, 0))\n",
102 | "concat_image"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
112 | "\n",
113 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
114 | "\n",
115 | "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
116 | "\n",
117 | "seed_everything()\n",
118 | "\n",
119 | "result_img = generate(\n",
120 | " pipe,\n",
121 | " prompt=prompt,\n",
122 | " conditions=[condition],\n",
123 | " num_inference_steps=8,\n",
124 | " height=512,\n",
125 | " width=512,\n",
126 | ").images[0]\n",
127 | "\n",
128 | "concat_image = Image.new(\"RGB\", (1024, 512))\n",
129 | "concat_image.paste(condition.condition, (0, 0))\n",
130 | "concat_image.paste(result_img, (512, 0))\n",
131 | "concat_image"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
141 | "\n",
142 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
143 | "\n",
144 | "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
145 | "\n",
146 | "seed_everything()\n",
147 | "\n",
148 | "result_img = generate(\n",
149 | " pipe,\n",
150 | " prompt=prompt,\n",
151 | " conditions=[condition],\n",
152 | " num_inference_steps=8,\n",
153 | " height=512,\n",
154 | " width=512,\n",
155 | ").images[0]\n",
156 | "\n",
157 | "concat_image = Image.new(\"RGB\", (1024, 512))\n",
158 | "concat_image.paste(condition.condition, (0, 0))\n",
159 | "concat_image.paste(result_img, (512, 0))\n",
160 | "concat_image"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
170 | "\n",
171 | "condition = Condition(\"subject\", image, position_delta=(0, 32))\n",
172 | "\n",
173 | "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
174 | "\n",
175 | "seed_everything()\n",
176 | "\n",
177 | "result_img = generate(\n",
178 | " pipe,\n",
179 | " prompt=prompt,\n",
180 | " conditions=[condition],\n",
181 | " num_inference_steps=8,\n",
182 | " height=512,\n",
183 | " width=512,\n",
184 | ").images[0]\n",
185 | "\n",
186 | "concat_image = Image.new(\"RGB\", (1024, 512))\n",
187 | "concat_image.paste(condition.condition, (0, 0))\n",
188 | "concat_image.paste(result_img, (512, 0))\n",
189 | "concat_image"
190 | ]
191 | }
192 | ],
193 | "metadata": {
194 | "kernelspec": {
195 | "display_name": "base",
196 | "language": "python",
197 | "name": "python3"
198 | },
199 | "language_info": {
200 | "codemirror_mode": {
201 | "name": "ipython",
202 | "version": 3
203 | },
204 | "file_extension": ".py",
205 | "mimetype": "text/x-python",
206 | "name": "python",
207 | "nbconvert_exporter": "python",
208 | "pygments_lexer": "ipython3",
209 | "version": "3.12.7"
210 | }
211 | },
212 | "nbformat": 4,
213 | "nbformat_minor": 2
214 | }
215 |
--------------------------------------------------------------------------------
/examples/subject_1024.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 4,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "\n",
11 | "os.chdir(\"..\")"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "import torch\n",
21 | "from diffusers.pipelines import FluxPipeline\n",
22 | "from src.flux.condition import Condition\n",
23 | "from PIL import Image\n",
24 | "\n",
25 | "from src.flux.generate import generate, seed_everything"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "pipe = FluxPipeline.from_pretrained(\n",
35 | " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
36 | ")\n",
37 | "pipe = pipe.to(\"cuda\")\n",
38 | "pipe.load_lora_weights(\n",
39 | " \"Yuanshi/OminiControl\",\n",
40 | " weight_name=f\"omini/subject_1024_beta.safetensors\",\n",
41 | " adapter_name=\"subject\",\n",
42 | ")"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
52 | "\n",
53 | "condition = Condition(\"subject\", image)\n",
54 | "\n",
55 | "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
56 | "\n",
57 | "\n",
58 | "seed_everything(0)\n",
59 | "\n",
60 | "result_img = generate(\n",
61 | " pipe,\n",
62 | " prompt=prompt,\n",
63 | " conditions=[condition],\n",
64 | " num_inference_steps=8,\n",
65 | " height=1024,\n",
66 | " width=1024,\n",
67 | ").images[0]\n",
68 | "\n",
69 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
70 | "concat_image.paste(image, (0, 0))\n",
71 | "concat_image.paste(result_img, (512, 0))\n",
72 | "concat_image"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
82 | "\n",
83 | "condition = Condition(\"subject\", image)\n",
84 | "\n",
85 | "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
86 | "\n",
87 | "\n",
88 | "seed_everything(0)\n",
89 | "\n",
90 | "result_img = generate(\n",
91 | " pipe,\n",
92 | " prompt=prompt,\n",
93 | " conditions=[condition],\n",
94 | " num_inference_steps=8,\n",
95 | " height=1024,\n",
96 | " width=1024,\n",
97 | ").images[0]\n",
98 | "\n",
99 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
100 | "concat_image.paste(image, (0, 0))\n",
101 | "concat_image.paste(result_img, (512, 0))\n",
102 | "concat_image"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
112 | "\n",
113 | "condition = Condition(\"subject\", image)\n",
114 | "\n",
115 | "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n",
116 | "\n",
117 | "seed_everything()\n",
118 | "\n",
119 | "result_img = generate(\n",
120 | " pipe,\n",
121 | " prompt=prompt,\n",
122 | " conditions=[condition],\n",
123 | " num_inference_steps=8,\n",
124 | " height=1024,\n",
125 | " width=1024,\n",
126 | ").images[0]\n",
127 | "\n",
128 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
129 | "concat_image.paste(image, (0, 0))\n",
130 | "concat_image.paste(result_img, (512, 0))\n",
131 | "concat_image"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
141 | "\n",
142 | "condition = Condition(\"subject\", image)\n",
143 | "\n",
144 | "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
145 | "\n",
146 | "seed_everything(0)\n",
147 | "\n",
148 | "result_img = generate(\n",
149 | " pipe,\n",
150 | " prompt=prompt,\n",
151 | " conditions=[condition],\n",
152 | " num_inference_steps=8,\n",
153 | " height=1024,\n",
154 | " width=1024,\n",
155 | ").images[0]\n",
156 | "\n",
157 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
158 | "concat_image.paste(image, (0, 0))\n",
159 | "concat_image.paste(result_img, (512, 0))\n",
160 | "concat_image"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
170 | "\n",
171 | "condition = Condition(\"subject\", image)\n",
172 | "\n",
173 | "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
174 | "\n",
175 | "seed_everything()\n",
176 | "\n",
177 | "result_img = generate(\n",
178 | " pipe,\n",
179 | " prompt=prompt,\n",
180 | " conditions=[condition],\n",
181 | " num_inference_steps=8,\n",
182 | " height=1024,\n",
183 | " width=1024,\n",
184 | ").images[0]\n",
185 | "\n",
186 | "concat_image = Image.new(\"RGB\", (1024+512, 1024))\n",
187 | "concat_image.paste(image, (0, 0))\n",
188 | "concat_image.paste(result_img, (512, 0))\n",
189 | "concat_image"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": []
198 | }
199 | ],
200 | "metadata": {
201 | "kernelspec": {
202 | "display_name": "Python 3 (ipykernel)",
203 | "language": "python",
204 | "name": "python3"
205 | },
206 | "language_info": {
207 | "codemirror_mode": {
208 | "name": "ipython",
209 | "version": 3
210 | },
211 | "file_extension": ".py",
212 | "mimetype": "text/x-python",
213 | "name": "python",
214 | "nbconvert_exporter": "python",
215 | "pygments_lexer": "ipython3",
216 | "version": "3.9.21"
217 | }
218 | },
219 | "nbformat": 4,
220 | "nbformat_minor": 2
221 | }
222 |
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | from mmgp import offload
2 | import gradio as gr
3 | import torch
4 | from PIL import Image, ImageDraw, ImageFont
5 | from diffusers.pipelines import FluxPipeline
6 | from diffusers import FluxTransformer2DModel
7 | import numpy as np
8 |
9 | from src.flux.condition import Condition
10 | from src.flux.generate import seed_everything, generate
11 |
12 |
13 | pipe = None
14 | use_int8 = False
15 |
16 |
17 | def get_gpu_memory():
18 | return torch.cuda.get_device_properties(0).total_memory / 1024**3
19 |
20 |
21 | def init_pipeline():
22 | global pipe
23 | if False and (use_int8 or get_gpu_memory() < 33):
24 | transformer_model = FluxTransformer2DModel.from_pretrained(
25 | "sayakpaul/flux.1-schell-int8wo-improved",
26 | torch_dtype=torch.bfloat16,
27 | use_safetensors=False,
28 | )
29 | pipe = FluxPipeline.from_pretrained(
30 | "black-forest-labs/FLUX.1-schnell",
31 | transformer=transformer_model,
32 | torch_dtype=torch.bfloat16,
33 | )
34 | else:
35 | pipe = FluxPipeline.from_pretrained(
36 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
37 | )
38 | pipe = pipe.to("cpu")
39 | pipe.load_lora_weights(
40 | "Yuanshi/OminiControl",
41 | weight_name="omini/subject_512.safetensors",
42 | adapter_name="subject",
43 | )
44 |     offload.profile(pipe, profile_no=int(args.profile), verboseLevel=int(args.verbose), quantizeTransformer=False)
45 | 
46 |
47 | def process_image_and_text(image, text):
48 | # center crop image
49 | w, h, min_size = image.size[0], image.size[1], min(image.size)
50 | image = image.crop(
51 | (
52 | (w - min_size) // 2,
53 | (h - min_size) // 2,
54 | (w + min_size) // 2,
55 | (h + min_size) // 2,
56 | )
57 | )
58 | image = image.resize((512, 512))
59 |
60 | condition = Condition("subject", image, position_delta=(0, 32))
61 |
62 | if pipe is None:
63 | init_pipeline()
64 |
65 | result_img = generate(
66 | pipe,
67 | prompt=text.strip(),
68 | conditions=[condition],
69 | num_inference_steps=8,
70 | height=512,
71 | width=512,
72 | ).images[0]
73 |
74 | return result_img
75 |
76 |
77 | def get_samples():
78 | sample_list = [
79 | {
80 | "image": "assets/oranges.jpg",
81 | "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
82 | },
83 | {
84 | "image": "assets/penguin.jpg",
85 | "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
86 | },
87 | {
88 | "image": "assets/rc_car.jpg",
89 |             "text": "A film style shot. On the moon, this item drives across the moon surface. In the background, Earth looms large.",
90 | },
91 | {
92 | "image": "assets/clock.jpg",
93 | "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
94 | },
95 | {
96 | "image": "assets/tshirt.jpg",
97 |             "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
98 | },
99 | ]
100 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
101 |
102 |
103 | demo = gr.Interface(
104 | fn=process_image_and_text,
105 | inputs=[
106 | gr.Image(type="pil"),
107 | gr.Textbox(lines=2),
108 | ],
109 | outputs=gr.Image(type="pil"),
110 | title="OminiControlGP / Subject driven generation for the GPU Poor",
111 |
112 | examples=get_samples(),
113 | )
114 |
115 | if __name__ == "__main__":
116 | import argparse
117 |
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument('--profile', type=str, default="3")
120 | parser.add_argument('--verbose', type=str, default="1")
121 |
122 | args = parser.parse_args()
123 |
124 | init_pipeline()
125 | demo.launch(
126 | debug=True,
127 | )
128 |
--------------------------------------------------------------------------------
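
A minimal sketch of driving the demo programmatically instead of through the Gradio UI. It assumes the repository root is on the Python path and that mmgp and the model weights are installed as above; setting `gradio_app.args` by hand stands in for the argparse block that normally runs under `__main__`.

    # Hypothetical programmatic use of gradio_app.py (sketch, not a repository file).
    import argparse
    from PIL import Image
    import gradio_app

    # init_pipeline() reads the module-level `args`; emulate the CLI defaults here.
    gradio_app.args = argparse.Namespace(profile="3", verbose="1")

    image = Image.open("assets/oranges.jpg")
    prompt = "A close up view of this item, placed on a wooden table."
    result = gradio_app.process_image_and_text(image, prompt)  # builds the pipeline lazily on first call
    result.save("oranges_subject.jpg")
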
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | diffusers
3 | peft
4 | opencv-python
5 | protobuf
6 | sentencepiece
7 | gradio
8 | jupyter
9 | torchao
10 | mmgp==3.1.4.post1
--------------------------------------------------------------------------------
/src/flux/block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Union, Optional, Dict, Any, Callable
3 | from diffusers.models.attention_processor import Attention, F
4 | from .lora_controller import enable_lora
5 |
6 |
7 | def attn_forward(
8 | attn: Attention,
9 | hidden_states: torch.FloatTensor,
10 | encoder_hidden_states: torch.FloatTensor = None,
11 | condition_latents: torch.FloatTensor = None,
12 | attention_mask: Optional[torch.FloatTensor] = None,
13 | image_rotary_emb: Optional[torch.Tensor] = None,
14 | cond_rotary_emb: Optional[torch.Tensor] = None,
15 | model_config: Optional[Dict[str, Any]] = {},
16 | ) -> torch.FloatTensor:
17 | batch_size, _, _ = (
18 | hidden_states.shape
19 | if encoder_hidden_states is None
20 | else encoder_hidden_states.shape
21 | )
22 |
23 | with enable_lora(
24 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False)
25 | ):
26 | # `sample` projections.
27 | query = attn.to_q(hidden_states)
28 | key = attn.to_k(hidden_states)
29 | value = attn.to_v(hidden_states)
30 |
31 | inner_dim = key.shape[-1]
32 | head_dim = inner_dim // attn.heads
33 |
34 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
35 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
36 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
37 |
38 | if attn.norm_q is not None:
39 | query = attn.norm_q(query)
40 | if attn.norm_k is not None:
41 | key = attn.norm_k(key)
42 |
43 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
44 | if encoder_hidden_states is not None:
45 | # `context` projections.
46 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
47 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
48 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
49 |
50 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
51 | batch_size, -1, attn.heads, head_dim
52 | ).transpose(1, 2)
53 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
54 | batch_size, -1, attn.heads, head_dim
55 | ).transpose(1, 2)
56 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
57 | batch_size, -1, attn.heads, head_dim
58 | ).transpose(1, 2)
59 |
60 | if attn.norm_added_q is not None:
61 | encoder_hidden_states_query_proj = attn.norm_added_q(
62 | encoder_hidden_states_query_proj
63 | )
64 | if attn.norm_added_k is not None:
65 | encoder_hidden_states_key_proj = attn.norm_added_k(
66 | encoder_hidden_states_key_proj
67 | )
68 |
69 | # attention
70 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
71 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
72 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
73 |
74 |     from diffusers.models.embeddings import apply_rotary_emb  # also used by the condition branch below
75 |
76 |     if image_rotary_emb is not None:
77 |         query = apply_rotary_emb(query, image_rotary_emb)
78 |         key = apply_rotary_emb(key, image_rotary_emb)
79 |
80 | if condition_latents is not None:
81 | cond_query = attn.to_q(condition_latents)
82 | cond_key = attn.to_k(condition_latents)
83 | cond_value = attn.to_v(condition_latents)
84 |
85 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose(
86 | 1, 2
87 | )
88 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
89 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose(
90 | 1, 2
91 | )
92 | if attn.norm_q is not None:
93 | cond_query = attn.norm_q(cond_query)
94 | if attn.norm_k is not None:
95 | cond_key = attn.norm_k(cond_key)
96 |
97 | if cond_rotary_emb is not None:
98 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb)
99 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb)
100 |
101 | if condition_latents is not None:
102 | query = torch.cat([query, cond_query], dim=2)
103 | key = torch.cat([key, cond_key], dim=2)
104 | value = torch.cat([value, cond_value], dim=2)
105 |
106 | if not model_config.get("union_cond_attn", True):
107 | # If we don't want to use the union condition attention, we need to mask the attention
108 | # between the hidden states and the condition latents
109 | attention_mask = torch.ones(
110 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool
111 | )
112 | condition_n = cond_query.shape[2]
113 | attention_mask[-condition_n:, :-condition_n] = False
114 | attention_mask[:-condition_n, -condition_n:] = False
115 | if hasattr(attn, "c_factor"):
116 | attention_mask = torch.zeros(
117 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype
118 | )
119 | condition_n = cond_query.shape[2]
120 | bias = torch.log(attn.c_factor[0])
121 | attention_mask[-condition_n:, :-condition_n] = bias
122 | attention_mask[:-condition_n, -condition_n:] = bias
123 | hidden_states = F.scaled_dot_product_attention(
124 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask
125 | )
126 | hidden_states = hidden_states.transpose(1, 2).reshape(
127 | batch_size, -1, attn.heads * head_dim
128 | )
129 | hidden_states = hidden_states.to(query.dtype)
130 |
131 | if encoder_hidden_states is not None:
132 | if condition_latents is not None:
133 | encoder_hidden_states, hidden_states, condition_latents = (
134 | hidden_states[:, : encoder_hidden_states.shape[1]],
135 | hidden_states[
136 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1]
137 | ],
138 | hidden_states[:, -condition_latents.shape[1] :],
139 | )
140 | else:
141 | encoder_hidden_states, hidden_states = (
142 | hidden_states[:, : encoder_hidden_states.shape[1]],
143 | hidden_states[:, encoder_hidden_states.shape[1] :],
144 | )
145 |
146 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)):
147 | # linear proj
148 | hidden_states = attn.to_out[0](hidden_states)
149 | # dropout
150 | hidden_states = attn.to_out[1](hidden_states)
151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152 |
153 | if condition_latents is not None:
154 | condition_latents = attn.to_out[0](condition_latents)
155 | condition_latents = attn.to_out[1](condition_latents)
156 |
157 | return (
158 | (hidden_states, encoder_hidden_states, condition_latents)
159 | if condition_latents is not None
160 | else (hidden_states, encoder_hidden_states)
161 | )
162 | elif condition_latents is not None:
163 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents
164 | hidden_states, condition_latents = (
165 | hidden_states[:, : -condition_latents.shape[1]],
166 | hidden_states[:, -condition_latents.shape[1] :],
167 | )
168 | return hidden_states, condition_latents
169 | else:
170 | return hidden_states
171 |
172 |
173 | def block_forward(
174 | self,
175 | hidden_states: torch.FloatTensor,
176 | encoder_hidden_states: torch.FloatTensor,
177 | condition_latents: torch.FloatTensor,
178 | temb: torch.FloatTensor,
179 | cond_temb: torch.FloatTensor,
180 | cond_rotary_emb=None,
181 | image_rotary_emb=None,
182 | model_config: Optional[Dict[str, Any]] = {},
183 | ):
184 | use_cond = condition_latents is not None
185 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)):
186 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
187 | hidden_states, emb=temb
188 | )
189 |
190 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
191 | self.norm1_context(encoder_hidden_states, emb=temb)
192 | )
193 |
194 | if use_cond:
195 | (
196 | norm_condition_latents,
197 | cond_gate_msa,
198 | cond_shift_mlp,
199 | cond_scale_mlp,
200 | cond_gate_mlp,
201 | ) = self.norm1(condition_latents, emb=cond_temb)
202 |
203 | # Attention.
204 | result = attn_forward(
205 | self.attn,
206 | model_config=model_config,
207 | hidden_states=norm_hidden_states,
208 | encoder_hidden_states=norm_encoder_hidden_states,
209 | condition_latents=norm_condition_latents if use_cond else None,
210 | image_rotary_emb=image_rotary_emb,
211 | cond_rotary_emb=cond_rotary_emb if use_cond else None,
212 | )
213 | attn_output, context_attn_output = result[:2]
214 | cond_attn_output = result[2] if use_cond else None
215 |
216 | # Process attention outputs for the `hidden_states`.
217 | # 1. hidden_states
218 | attn_output = gate_msa.unsqueeze(1) * attn_output
219 | hidden_states = hidden_states + attn_output
220 | # 2. encoder_hidden_states
221 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
222 | encoder_hidden_states = encoder_hidden_states + context_attn_output
223 | # 3. condition_latents
224 | if use_cond:
225 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
226 | condition_latents = condition_latents + cond_attn_output
227 | if model_config.get("add_cond_attn", False):
228 | hidden_states += cond_attn_output
229 |
230 | # LayerNorm + MLP.
231 | # 1. hidden_states
232 | norm_hidden_states = self.norm2(hidden_states)
233 | norm_hidden_states = (
234 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
235 | )
236 | # 2. encoder_hidden_states
237 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
238 | norm_encoder_hidden_states = (
239 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
240 | )
241 | # 3. condition_latents
242 | if use_cond:
243 | norm_condition_latents = self.norm2(condition_latents)
244 | norm_condition_latents = (
245 | norm_condition_latents * (1 + cond_scale_mlp[:, None])
246 | + cond_shift_mlp[:, None]
247 | )
248 |
249 | # Feed-forward.
250 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)):
251 | # 1. hidden_states
252 | ff_output = self.ff(norm_hidden_states)
253 | ff_output = gate_mlp.unsqueeze(1) * ff_output
254 | # 2. encoder_hidden_states
255 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
256 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
257 | # 3. condition_latents
258 | if use_cond:
259 | cond_ff_output = self.ff(norm_condition_latents)
260 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
261 |
262 | # Process feed-forward outputs.
263 | hidden_states = hidden_states + ff_output
264 | encoder_hidden_states = encoder_hidden_states + context_ff_output
265 | if use_cond:
266 | condition_latents = condition_latents + cond_ff_output
267 |
268 | # Clip to avoid overflow.
269 | if encoder_hidden_states.dtype == torch.float16:
270 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
271 |
272 | return encoder_hidden_states, hidden_states, condition_latents if use_cond else None
273 |
274 |
275 | def single_block_forward(
276 | self,
277 | hidden_states: torch.FloatTensor,
278 | temb: torch.FloatTensor,
279 | image_rotary_emb=None,
280 | condition_latents: torch.FloatTensor = None,
281 | cond_temb: torch.FloatTensor = None,
282 | cond_rotary_emb=None,
283 | model_config: Optional[Dict[str, Any]] = {},
284 | ):
285 |
286 | using_cond = condition_latents is not None
287 | residual = hidden_states
288 | with enable_lora(
289 | (
290 | self.norm.linear,
291 | self.proj_mlp,
292 | ),
293 | model_config.get("latent_lora", False),
294 | ):
295 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
296 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
297 | if using_cond:
298 | residual_cond = condition_latents
299 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb)
300 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents))
301 |
302 | attn_output = attn_forward(
303 | self.attn,
304 | model_config=model_config,
305 | hidden_states=norm_hidden_states,
306 | image_rotary_emb=image_rotary_emb,
307 | **(
308 | {
309 | "condition_latents": norm_condition_latents,
310 | "cond_rotary_emb": cond_rotary_emb if using_cond else None,
311 | }
312 | if using_cond
313 | else {}
314 | ),
315 | )
316 | if using_cond:
317 | attn_output, cond_attn_output = attn_output
318 |
319 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)):
320 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
321 | gate = gate.unsqueeze(1)
322 | hidden_states = gate * self.proj_out(hidden_states)
323 | hidden_states = residual + hidden_states
324 | if using_cond:
325 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
326 | cond_gate = cond_gate.unsqueeze(1)
327 | condition_latents = cond_gate * self.proj_out(condition_latents)
328 | condition_latents = residual_cond + condition_latents
329 |
330 | if hidden_states.dtype == torch.float16:
331 | hidden_states = hidden_states.clip(-65504, 65504)
332 |
333 | return hidden_states if not using_cond else (hidden_states, condition_latents)
334 |
--------------------------------------------------------------------------------
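
The `c_factor` branch near the end of `attn_forward` is how `condition_scale` is applied: a bias of `log(c_factor)` is added to the attention logits between the condition tokens and the text/image tokens, which, after the softmax, multiplies those unnormalized attention weights by `c_factor`. A standalone numeric sketch of that equivalence (illustration only, with made-up tensor sizes):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    q = torch.randn(1, 1, 6, 8)            # (batch, heads, tokens, head_dim); last 2 tokens act as "condition" tokens
    k, v = torch.randn(1, 1, 6, 8), torch.randn(1, 1, 6, 8)

    c_factor = torch.tensor(2.0)
    bias = torch.zeros(6, 6)
    bias[-2:, :-2] = torch.log(c_factor)   # condition tokens attending to text/image tokens
    bias[:-2, -2:] = torch.log(c_factor)   # text/image tokens attending to condition tokens

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

    # Same result computed by hand: exp(logit + log(c)) == c * exp(logit), so the
    # cross-block weights are scaled by c_factor before softmax normalization.
    logits = (q @ k.transpose(-1, -2)) / (8 ** 0.5)
    manual = torch.softmax(logits + bias, dim=-1) @ v
    assert torch.allclose(out, manual, atol=1e-6)
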
/src/flux/condition.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Union, List, Tuple
3 | from diffusers.pipelines import FluxPipeline
4 | from PIL import Image, ImageFilter
5 | import numpy as np
6 | import cv2
7 |
8 | from .pipeline_tools import encode_images
9 |
10 | condition_dict = {
11 | "depth": 0,
12 | "canny": 1,
13 | "subject": 4,
14 | "coloring": 6,
15 | "deblurring": 7,
16 | "depth_pred": 8,
17 | "fill": 9,
18 | "sr": 10,
19 | "cartoon": 11,
20 | }
21 |
22 |
23 | class Condition(object):
24 | def __init__(
25 | self,
26 | condition_type: str,
27 | raw_img: Union[Image.Image, torch.Tensor] = None,
28 | condition: Union[Image.Image, torch.Tensor] = None,
29 | mask=None,
30 | position_delta=None,
31 | ) -> None:
32 | self.condition_type = condition_type
33 | assert raw_img is not None or condition is not None
34 | if raw_img is not None:
35 | self.condition = self.get_condition(condition_type, raw_img)
36 | else:
37 | self.condition = condition
38 | self.position_delta = position_delta
39 | # TODO: Add mask support
40 | assert mask is None, "Mask not supported yet"
41 |
42 | def get_condition(
43 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
44 | ) -> Union[Image.Image, torch.Tensor]:
45 | """
46 | Returns the condition image.
47 | """
48 | if condition_type == "depth":
49 | from transformers import pipeline
50 |
51 | depth_pipe = pipeline(
52 | task="depth-estimation",
53 | model="LiheYoung/depth-anything-small-hf",
54 | device="cuda",
55 | )
56 | source_image = raw_img.convert("RGB")
57 | condition_img = depth_pipe(source_image)["depth"].convert("RGB")
58 | return condition_img
59 | elif condition_type == "canny":
60 | img = np.array(raw_img)
61 | edges = cv2.Canny(img, 100, 200)
62 | edges = Image.fromarray(edges).convert("RGB")
63 | return edges
64 | elif condition_type == "subject":
65 | return raw_img
66 | elif condition_type == "coloring":
67 | return raw_img.convert("L").convert("RGB")
68 | elif condition_type == "deblurring":
69 | condition_image = (
70 | raw_img.convert("RGB")
71 | .filter(ImageFilter.GaussianBlur(10))
72 | .convert("RGB")
73 | )
74 | return condition_image
75 | elif condition_type == "fill":
76 | return raw_img.convert("RGB")
77 | elif condition_type == "cartoon":
78 | return raw_img.convert("RGB")
79 |         return raw_img  # fallback for condition types without extra preprocessing (e.g. "sr", "depth_pred")
80 |
81 | @property
82 | def type_id(self) -> int:
83 | """
84 | Returns the type id of the condition.
85 | """
86 | return condition_dict[self.condition_type]
87 |
88 | @classmethod
89 | def get_type_id(cls, condition_type: str) -> int:
90 | """
91 | Returns the type id of the condition.
92 | """
93 | return condition_dict[condition_type]
94 |
95 | def encode(self, pipe: FluxPipeline) -> Tuple[torch.Tensor, torch.Tensor, int]:
96 | """
97 | Encodes the condition into tokens, ids and type_id.
98 | """
99 | if self.condition_type in [
100 | "depth",
101 | "canny",
102 | "subject",
103 | "coloring",
104 | "deblurring",
105 | "depth_pred",
106 | "fill",
107 | "sr",
108 | "cartoon"
109 | ]:
110 | tokens, ids = encode_images(pipe, self.condition)
111 | else:
112 | raise NotImplementedError(
113 | f"Condition type {self.condition_type} not implemented"
114 | )
115 | if self.position_delta is None and self.condition_type == "subject":
116 | self.position_delta = [0, -self.condition.size[0] // 16]
117 | if self.position_delta is not None:
118 | ids[:, 1] += self.position_delta[0]
119 | ids[:, 2] += self.position_delta[1]
120 | type_id = torch.ones_like(ids[:, :1]) * self.type_id
121 | return tokens, ids, type_id
122 |
--------------------------------------------------------------------------------
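
A short usage sketch for `Condition`. It assumes a FluxPipeline named `pipe` with the matching OminiControl LoRA is already loaded (as in gradio_app.py); only the commented line needs it.

    from PIL import Image
    from src.flux.condition import Condition

    raw = Image.open("assets/vase.jpg").convert("RGB").resize((512, 512))

    # Spatial condition: the Canny edge map is computed inside get_condition().
    canny_cond = Condition("canny", raw_img=raw)

    # Subject condition: the raw image is used as-is; position_delta shifts the condition
    # tokens' positional ids so they sit next to, rather than on top of, the target grid.
    subject_cond = Condition("subject", raw_img=raw, position_delta=(0, 32))

    # Encoding needs a loaded FluxPipeline; it returns the packed latent tokens, their
    # 3-column positional ids (shifted by position_delta), and the numeric type id.
    # tokens, ids, type_id = subject_cond.encode(pipe)
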
/src/flux/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml, os
3 | from diffusers.pipelines import FluxPipeline
4 | from typing import List, Union, Optional, Dict, Any, Callable
5 | from .transformer import tranformer_forward
6 | from .condition import Condition
7 |
8 | from diffusers.pipelines.flux.pipeline_flux import (
9 | FluxPipelineOutput,
10 | calculate_shift,
11 | retrieve_timesteps,
12 | np,
13 | )
14 |
15 |
16 | def get_config(config_path: str = None):
17 | config_path = config_path or os.environ.get("XFL_CONFIG")
18 | if not config_path:
19 | return {}
20 | with open(config_path, "r") as f:
21 | config = yaml.safe_load(f)
22 | return config
23 |
24 |
25 | def prepare_params(
26 | prompt: Union[str, List[str]] = None,
27 | prompt_2: Optional[Union[str, List[str]]] = None,
28 | height: Optional[int] = 512,
29 | width: Optional[int] = 512,
30 | num_inference_steps: int = 28,
31 | timesteps: List[int] = None,
32 | guidance_scale: float = 3.5,
33 | num_images_per_prompt: Optional[int] = 1,
34 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
35 | latents: Optional[torch.FloatTensor] = None,
36 | prompt_embeds: Optional[torch.FloatTensor] = None,
37 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
38 | output_type: Optional[str] = "pil",
39 | return_dict: bool = True,
40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
41 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
42 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
43 | max_sequence_length: int = 512,
44 | **kwargs: dict,
45 | ):
46 | return (
47 | prompt,
48 | prompt_2,
49 | height,
50 | width,
51 | num_inference_steps,
52 | timesteps,
53 | guidance_scale,
54 | num_images_per_prompt,
55 | generator,
56 | latents,
57 | prompt_embeds,
58 | pooled_prompt_embeds,
59 | output_type,
60 | return_dict,
61 | joint_attention_kwargs,
62 | callback_on_step_end,
63 | callback_on_step_end_tensor_inputs,
64 | max_sequence_length,
65 | )
66 |
67 |
68 | def seed_everything(seed: int = 42):
69 | torch.backends.cudnn.deterministic = True
70 | torch.manual_seed(seed)
71 | np.random.seed(seed)
72 |
73 |
74 | @torch.no_grad()
75 | def generate(
76 | pipeline: FluxPipeline,
77 | conditions: List[Condition] = None,
78 | config_path: str = None,
79 | model_config: Optional[Dict[str, Any]] = {},
80 | condition_scale: float = 1.0,
81 | default_lora: bool = False,
82 | **params: dict,
83 | ):
84 | model_config = model_config or get_config(config_path).get("model", {})
85 | if condition_scale != 1:
86 | for name, module in pipeline.transformer.named_modules():
87 | if not name.endswith(".attn"):
88 | continue
89 | module.c_factor = torch.ones(1, 1) * condition_scale
90 |
91 | self = pipeline
92 | (
93 | prompt,
94 | prompt_2,
95 | height,
96 | width,
97 | num_inference_steps,
98 | timesteps,
99 | guidance_scale,
100 | num_images_per_prompt,
101 | generator,
102 | latents,
103 | prompt_embeds,
104 | pooled_prompt_embeds,
105 | output_type,
106 | return_dict,
107 | joint_attention_kwargs,
108 | callback_on_step_end,
109 | callback_on_step_end_tensor_inputs,
110 | max_sequence_length,
111 | ) = prepare_params(**params)
112 |
113 | height = height or self.default_sample_size * self.vae_scale_factor
114 | width = width or self.default_sample_size * self.vae_scale_factor
115 |
116 | # 1. Check inputs. Raise error if not correct
117 | self.check_inputs(
118 | prompt,
119 | prompt_2,
120 | height,
121 | width,
122 | prompt_embeds=prompt_embeds,
123 | pooled_prompt_embeds=pooled_prompt_embeds,
124 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
125 | max_sequence_length=max_sequence_length,
126 | )
127 |
128 | self._guidance_scale = guidance_scale
129 | self._joint_attention_kwargs = joint_attention_kwargs
130 | self._interrupt = False
131 |
132 | # 2. Define call parameters
133 | if prompt is not None and isinstance(prompt, str):
134 | batch_size = 1
135 | elif prompt is not None and isinstance(prompt, list):
136 | batch_size = len(prompt)
137 | else:
138 | batch_size = prompt_embeds.shape[0]
139 |
140 | device = self._execution_device
141 |
142 | lora_scale = (
143 | self.joint_attention_kwargs.get("scale", None)
144 | if self.joint_attention_kwargs is not None
145 | else None
146 | )
147 | (
148 | prompt_embeds,
149 | pooled_prompt_embeds,
150 | text_ids,
151 | ) = self.encode_prompt(
152 | prompt=prompt,
153 | prompt_2=prompt_2,
154 | prompt_embeds=prompt_embeds,
155 | pooled_prompt_embeds=pooled_prompt_embeds,
156 | device=device,
157 | num_images_per_prompt=num_images_per_prompt,
158 | max_sequence_length=max_sequence_length,
159 | lora_scale=lora_scale,
160 | )
161 |
162 | # 4. Prepare latent variables
163 | num_channels_latents = self.transformer.config.in_channels // 4
164 | latents, latent_image_ids = self.prepare_latents(
165 | batch_size * num_images_per_prompt,
166 | num_channels_latents,
167 | height,
168 | width,
169 | prompt_embeds.dtype,
170 | device,
171 | generator,
172 | latents,
173 | )
174 |
175 | # 4.1. Prepare conditions
176 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3))
177 |     use_condition = conditions is not None and len(conditions) > 0
178 | if use_condition:
179 | assert len(conditions) <= 1, "Only one condition is supported for now."
180 | if not default_lora:
181 | pipeline.set_adapters(conditions[0].condition_type)
182 | for condition in conditions:
183 | tokens, ids, type_id = condition.encode(self)
184 | condition_latents.append(tokens) # [batch_size, token_n, token_dim]
185 | condition_ids.append(ids) # [token_n, id_dim(3)]
186 | condition_type_ids.append(type_id) # [token_n, 1]
187 | condition_latents = torch.cat(condition_latents, dim=1)
188 | condition_ids = torch.cat(condition_ids, dim=0)
189 | condition_type_ids = torch.cat(condition_type_ids, dim=0)
190 |
191 | # 5. Prepare timesteps
192 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
193 | image_seq_len = latents.shape[1]
194 | mu = calculate_shift(
195 | image_seq_len,
196 | self.scheduler.config.base_image_seq_len,
197 | self.scheduler.config.max_image_seq_len,
198 | self.scheduler.config.base_shift,
199 | self.scheduler.config.max_shift,
200 | )
201 | timesteps, num_inference_steps = retrieve_timesteps(
202 | self.scheduler,
203 | num_inference_steps,
204 | device,
205 | timesteps,
206 | sigmas,
207 | mu=mu,
208 | )
209 | num_warmup_steps = max(
210 | len(timesteps) - num_inference_steps * self.scheduler.order, 0
211 | )
212 | self._num_timesteps = len(timesteps)
213 |
214 | # 6. Denoising loop
215 | with self.progress_bar(total=num_inference_steps) as progress_bar:
216 | for i, t in enumerate(timesteps):
217 | if self.interrupt:
218 | continue
219 |
220 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
221 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
222 |
223 | # handle guidance
224 | if self.transformer.config.guidance_embeds:
225 | guidance = torch.tensor([guidance_scale], device=device)
226 | guidance = guidance.expand(latents.shape[0])
227 | else:
228 | guidance = None
229 | noise_pred = tranformer_forward(
230 | self.transformer,
231 | model_config=model_config,
232 | # Inputs of the condition (new feature)
233 | condition_latents=condition_latents if use_condition else None,
234 | condition_ids=condition_ids if use_condition else None,
235 | condition_type_ids=condition_type_ids if use_condition else None,
236 | # Inputs to the original transformer
237 | hidden_states=latents,
238 |                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
239 | timestep=timestep / 1000,
240 | guidance=guidance,
241 | pooled_projections=pooled_prompt_embeds,
242 | encoder_hidden_states=prompt_embeds,
243 | txt_ids=text_ids,
244 | img_ids=latent_image_ids,
245 | joint_attention_kwargs=self.joint_attention_kwargs,
246 | return_dict=False,
247 | )[0]
248 |
249 | # compute the previous noisy sample x_t -> x_t-1
250 | latents_dtype = latents.dtype
251 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
252 |
253 | if latents.dtype != latents_dtype:
254 | if torch.backends.mps.is_available():
255 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
256 | latents = latents.to(latents_dtype)
257 |
258 | if callback_on_step_end is not None:
259 | callback_kwargs = {}
260 | for k in callback_on_step_end_tensor_inputs:
261 | callback_kwargs[k] = locals()[k]
262 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
263 |
264 | latents = callback_outputs.pop("latents", latents)
265 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
266 |
267 | # call the callback, if provided
268 | if i == len(timesteps) - 1 or (
269 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
270 | ):
271 | progress_bar.update()
272 |
273 | if output_type == "latent":
274 | image = latents
275 |
276 | else:
277 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
278 | latents = (
279 | latents / self.vae.config.scaling_factor
280 | ) + self.vae.config.shift_factor
281 | image = self.vae.decode(latents, return_dict=False)[0]
282 | image = self.image_processor.postprocess(image, output_type=output_type)
283 |
284 | # Offload all models
285 | self.maybe_free_model_hooks()
286 |
287 | if condition_scale != 1:
288 | for name, module in pipeline.transformer.named_modules():
289 | if not name.endswith(".attn"):
290 | continue
291 | del module.c_factor
292 |
293 | if not return_dict:
294 | return (image,)
295 |
296 | return FluxPipelineOutput(images=image)
297 |
--------------------------------------------------------------------------------
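
A hedged end-to-end sketch of `generate()`, mirroring the notebook and the Gradio apps; model names and the LoRA weight name are taken from those files. `condition_scale` is the knob that sets `c_factor` on every attention module (see block.py), so values above 1.0 strengthen and values below 1.0 weaken the condition.

    import torch
    from PIL import Image
    from diffusers.pipelines import FluxPipeline
    from src.flux.condition import Condition
    from src.flux.generate import generate, seed_everything

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    image = Image.open("assets/penguin.jpg").convert("RGB").resize((512, 512))
    condition = Condition("subject", image, position_delta=(0, 32))

    seed_everything(42)
    result = generate(
        pipe,
        prompt="This item sits on a snowy sidewalk, wearing a Christmas hat.",
        conditions=[condition],
        num_inference_steps=8,
        height=512,
        width=512,
        condition_scale=1.0,  # values other than 1.0 add the log(c_factor) attention bias
    ).images[0]
    result.save("penguin_subject.jpg")
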
/src/flux/lora_controller.py:
--------------------------------------------------------------------------------
1 | from peft.tuners.tuners_utils import BaseTunerLayer
2 | from typing import List, Any, Optional, Type
3 |
4 |
5 | class enable_lora:
6 | def __init__(self, lora_modules: List[BaseTunerLayer], activated: bool) -> None:
7 | self.activated: bool = activated
8 | if activated:
9 | return
10 | self.lora_modules: List[BaseTunerLayer] = [
11 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
12 | ]
13 | self.scales = [
14 | {
15 | active_adapter: lora_module.scaling[active_adapter]
16 | for active_adapter in lora_module.active_adapters
17 | }
18 | for lora_module in self.lora_modules
19 | ]
20 |
21 | def __enter__(self) -> None:
22 | if self.activated:
23 | return
24 |
25 | for lora_module in self.lora_modules:
26 | if not isinstance(lora_module, BaseTunerLayer):
27 | continue
28 | lora_module.scale_layer(0)
29 |
30 | def __exit__(
31 | self,
32 | exc_type: Optional[Type[BaseException]],
33 | exc_val: Optional[BaseException],
34 | exc_tb: Optional[Any],
35 | ) -> None:
36 | if self.activated:
37 | return
38 | for i, lora_module in enumerate(self.lora_modules):
39 | if not isinstance(lora_module, BaseTunerLayer):
40 | continue
41 | for active_adapter in lora_module.active_adapters:
42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
43 |
44 |
45 | class set_lora_scale:
46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None:
47 | self.lora_modules: List[BaseTunerLayer] = [
48 | each for each in lora_modules if isinstance(each, BaseTunerLayer)
49 | ]
50 | self.scales = [
51 | {
52 | active_adapter: lora_module.scaling[active_adapter]
53 | for active_adapter in lora_module.active_adapters
54 | }
55 | for lora_module in self.lora_modules
56 | ]
57 | self.scale = scale
58 |
59 | def __enter__(self) -> None:
60 | for lora_module in self.lora_modules:
61 | if not isinstance(lora_module, BaseTunerLayer):
62 | continue
63 | lora_module.scale_layer(self.scale)
64 |
65 | def __exit__(
66 | self,
67 | exc_type: Optional[Type[BaseException]],
68 | exc_val: Optional[BaseException],
69 | exc_tb: Optional[Any],
70 | ) -> None:
71 | for i, lora_module in enumerate(self.lora_modules):
72 | if not isinstance(lora_module, BaseTunerLayer):
73 | continue
74 | for active_adapter in lora_module.active_adapters:
75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter]
76 |
--------------------------------------------------------------------------------
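
Both classes are context managers over PEFT `BaseTunerLayer` modules. `enable_lora(modules, activated)` is slightly counterintuitive: when `activated` is False it zeroes the LoRA scaling for the duration of the block and restores the saved scales on exit, and when True it does nothing. `set_lora_scale` temporarily applies an arbitrary scale the same way. A sketch, assuming a LoRA adapter has already been injected into the transformer:

    import torch
    from diffusers.pipelines import FluxPipeline
    from peft.tuners.tuners_utils import BaseTunerLayer
    from src.flux.lora_controller import enable_lora, set_lora_scale

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights(
        "Yuanshi/OminiControl",
        weight_name="omini/subject_512.safetensors",
        adapter_name="subject",
    )

    lora_layers = tuple(
        m for m in pipe.transformer.modules() if isinstance(m, BaseTunerLayer)
    )

    with enable_lora(lora_layers, False):
        pass  # LoRA scaling is 0 inside this block (base weights only); restored afterwards

    with set_lora_scale(lora_layers, 0.5):
        pass  # every active adapter runs at half strength inside this block
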
/src/flux/pipeline_tools.py:
--------------------------------------------------------------------------------
1 | from diffusers.pipelines import FluxPipeline
2 | from diffusers.utils import logging
3 | from diffusers.pipelines.flux.pipeline_flux import logger
4 | from torch import Tensor
5 |
6 |
7 | def encode_images(pipeline: FluxPipeline, images: Tensor):
8 | images = pipeline.image_processor.preprocess(images)
9 | images = images.to(pipeline.device).to(pipeline.dtype)
10 | images = pipeline.vae.encode(images).latent_dist.sample()
11 | images = (
12 | images - pipeline.vae.config.shift_factor
13 | ) * pipeline.vae.config.scaling_factor
14 | images_tokens = pipeline._pack_latents(images, *images.shape)
15 | images_ids = pipeline._prepare_latent_image_ids(
16 | images.shape[0],
17 | images.shape[2],
18 | images.shape[3],
19 | pipeline.device,
20 | pipeline.dtype,
21 | )
22 | if images_tokens.shape[1] != images_ids.shape[0]:
23 | images_ids = pipeline._prepare_latent_image_ids(
24 | images.shape[0],
25 | images.shape[2] // 2,
26 | images.shape[3] // 2,
27 | pipeline.device,
28 | pipeline.dtype,
29 | )
30 | return images_tokens, images_ids
31 |
32 |
33 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
34 | # Turn off warnings (CLIP overflow)
35 | logger.setLevel(logging.ERROR)
36 | (
37 | prompt_embeds,
38 | pooled_prompt_embeds,
39 | text_ids,
40 | ) = pipeline.encode_prompt(
41 | prompt=prompts,
42 | prompt_2=None,
43 | prompt_embeds=None,
44 | pooled_prompt_embeds=None,
45 | device=pipeline.device,
46 | num_images_per_prompt=1,
47 | max_sequence_length=max_sequence_length,
48 | lora_scale=None,
49 | )
50 | # Turn on warnings
51 | logger.setLevel(logging.WARNING)
52 | return prompt_embeds, pooled_prompt_embeds, text_ids
53 |
--------------------------------------------------------------------------------
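
`encode_images` turns a condition image into Flux's packed latent tokens plus their positional ids; the retry with halved spatial dimensions covers diffusers versions where `_prepare_latent_image_ids` expects the post-packing grid size instead of the raw latent size. A rough shape sketch for a 512x512 input, assuming the usual 8x VAE downsample and 2x2 latent packing of FLUX:

    # Expected shapes (illustrative; exact values depend on the diffusers version):
    #   VAE latents:   (1, 16, 64, 64)   # 512 / 8 per side, 16 latent channels
    #   packed tokens: (1, 1024, 64)     # (64/2) * (64/2) tokens of dimension 16 * 2 * 2
    #   image ids:     (1024, 3)         # one 3-column position row per packed token
    from PIL import Image
    from src.flux.pipeline_tools import encode_images

    image = Image.open("assets/vase.jpg").convert("RGB").resize((512, 512))
    # tokens, ids = encode_images(pipe, image)   # `pipe` is a loaded FluxPipeline, as in gradio_app.py
    # assert tokens.shape[1] == ids.shape[0]
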
/src/flux/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.pipelines import FluxPipeline
3 | from typing import List, Union, Optional, Dict, Any, Callable
4 | from .block import block_forward, single_block_forward
5 | from .lora_controller import enable_lora
6 | from diffusers.models.transformers.transformer_flux import (
7 | FluxTransformer2DModel,
8 | Transformer2DModelOutput,
9 | USE_PEFT_BACKEND,
10 | is_torch_version,
11 | scale_lora_layers,
12 | unscale_lora_layers,
13 | logger,
14 | )
15 | import numpy as np
16 |
17 |
18 | def prepare_params(
19 | hidden_states: torch.Tensor,
20 | encoder_hidden_states: torch.Tensor = None,
21 | pooled_projections: torch.Tensor = None,
22 | timestep: torch.LongTensor = None,
23 | img_ids: torch.Tensor = None,
24 | txt_ids: torch.Tensor = None,
25 | guidance: torch.Tensor = None,
26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
27 | controlnet_block_samples=None,
28 | controlnet_single_block_samples=None,
29 | return_dict: bool = True,
30 | **kwargs: dict,
31 | ):
32 | return (
33 | hidden_states,
34 | encoder_hidden_states,
35 | pooled_projections,
36 | timestep,
37 | img_ids,
38 | txt_ids,
39 | guidance,
40 | joint_attention_kwargs,
41 | controlnet_block_samples,
42 | controlnet_single_block_samples,
43 | return_dict,
44 | )
45 |
46 |
47 | def tranformer_forward(
48 | transformer: FluxTransformer2DModel,
49 | condition_latents: torch.Tensor,
50 | condition_ids: torch.Tensor,
51 | condition_type_ids: torch.Tensor,
52 | model_config: Optional[Dict[str, Any]] = {},
53 | c_t=0,
54 | **params: dict,
55 | ):
56 | self = transformer
57 | use_condition = condition_latents is not None
58 |
59 | (
60 | hidden_states,
61 | encoder_hidden_states,
62 | pooled_projections,
63 | timestep,
64 | img_ids,
65 | txt_ids,
66 | guidance,
67 | joint_attention_kwargs,
68 | controlnet_block_samples,
69 | controlnet_single_block_samples,
70 | return_dict,
71 | ) = prepare_params(**params)
72 |
73 | if joint_attention_kwargs is not None:
74 | joint_attention_kwargs = joint_attention_kwargs.copy()
75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
76 | else:
77 | lora_scale = 1.0
78 |
79 | if USE_PEFT_BACKEND:
80 | # weight the lora layers by setting `lora_scale` for each PEFT layer
81 | scale_lora_layers(self, lora_scale)
82 | else:
83 | if (
84 | joint_attention_kwargs is not None
85 | and joint_attention_kwargs.get("scale", None) is not None
86 | ):
87 | logger.warning(
88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
89 | )
90 |
91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)):
92 | hidden_states = self.x_embedder(hidden_states)
93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None
94 |
95 | timestep = timestep.to(hidden_states.dtype) * 1000
96 |
97 | if guidance is not None:
98 | guidance = guidance.to(hidden_states.dtype) * 1000
99 | else:
100 | guidance = None
101 |
102 | temb = (
103 | self.time_text_embed(timestep, pooled_projections)
104 | if guidance is None
105 | else self.time_text_embed(timestep, guidance, pooled_projections)
106 | )
107 |
108 | cond_temb = (
109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections)
110 | if guidance is None
111 | else self.time_text_embed(
112 | torch.ones_like(timestep) * c_t * 1000, guidance, pooled_projections
113 | )
114 | )
115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
116 |
117 | if txt_ids.ndim == 3:
118 | logger.warning(
119 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
120 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
121 | )
122 | txt_ids = txt_ids[0]
123 | if img_ids.ndim == 3:
124 | logger.warning(
125 | "Passing `img_ids` 3d torch.Tensor is deprecated."
126 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
127 | )
128 | img_ids = img_ids[0]
129 |
130 | ids = torch.cat((txt_ids, img_ids), dim=0)
131 | image_rotary_emb = self.pos_embed(ids)
132 | if use_condition:
133 | # condition_ids[:, :1] = condition_type_ids
134 | cond_rotary_emb = self.pos_embed(condition_ids)
135 |
136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1)
137 |
138 | for index_block, block in enumerate(self.transformer_blocks):
139 | if self.training and self.gradient_checkpointing:
140 | ckpt_kwargs: Dict[str, Any] = (
141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
142 | )
143 | encoder_hidden_states, hidden_states, condition_latents = (
144 | torch.utils.checkpoint.checkpoint(
145 | block_forward,
146 | self=block,
147 | model_config=model_config,
148 | hidden_states=hidden_states,
149 | encoder_hidden_states=encoder_hidden_states,
150 | condition_latents=condition_latents if use_condition else None,
151 | temb=temb,
152 | cond_temb=cond_temb if use_condition else None,
153 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
154 | image_rotary_emb=image_rotary_emb,
155 | **ckpt_kwargs,
156 | )
157 | )
158 |
159 | else:
160 | encoder_hidden_states, hidden_states, condition_latents = block_forward(
161 | block,
162 | model_config=model_config,
163 | hidden_states=hidden_states,
164 | encoder_hidden_states=encoder_hidden_states,
165 | condition_latents=condition_latents if use_condition else None,
166 | temb=temb,
167 | cond_temb=cond_temb if use_condition else None,
168 | cond_rotary_emb=cond_rotary_emb if use_condition else None,
169 | image_rotary_emb=image_rotary_emb,
170 | )
171 |
172 | # controlnet residual
173 | if controlnet_block_samples is not None:
174 | interval_control = len(self.transformer_blocks) / len(
175 | controlnet_block_samples
176 | )
177 | interval_control = int(np.ceil(interval_control))
178 | hidden_states = (
179 | hidden_states
180 | + controlnet_block_samples[index_block // interval_control]
181 | )
182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
183 |
184 | for index_block, block in enumerate(self.single_transformer_blocks):
185 | if self.training and self.gradient_checkpointing:
186 | ckpt_kwargs: Dict[str, Any] = (
187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188 | )
189 | result = torch.utils.checkpoint.checkpoint(
190 | single_block_forward,
191 | self=block,
192 | model_config=model_config,
193 | hidden_states=hidden_states,
194 | temb=temb,
195 | image_rotary_emb=image_rotary_emb,
196 | **(
197 | {
198 | "condition_latents": condition_latents,
199 | "cond_temb": cond_temb,
200 | "cond_rotary_emb": cond_rotary_emb,
201 | }
202 | if use_condition
203 | else {}
204 | ),
205 | **ckpt_kwargs,
206 | )
207 |
208 | else:
209 | result = single_block_forward(
210 | block,
211 | model_config=model_config,
212 | hidden_states=hidden_states,
213 | temb=temb,
214 | image_rotary_emb=image_rotary_emb,
215 | **(
216 | {
217 | "condition_latents": condition_latents,
218 | "cond_temb": cond_temb,
219 | "cond_rotary_emb": cond_rotary_emb,
220 | }
221 | if use_condition
222 | else {}
223 | ),
224 | )
225 | if use_condition:
226 | hidden_states, condition_latents = result
227 | else:
228 | hidden_states = result
229 |
230 | # controlnet residual
231 | if controlnet_single_block_samples is not None:
232 | interval_control = len(self.single_transformer_blocks) / len(
233 | controlnet_single_block_samples
234 | )
235 | interval_control = int(np.ceil(interval_control))
236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
238 | + controlnet_single_block_samples[index_block // interval_control]
239 | )
240 |
241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
242 |
243 | hidden_states = self.norm_out(hidden_states, temb)
244 | output = self.proj_out(hidden_states)
245 |
246 | if USE_PEFT_BACKEND:
247 | # remove `lora_scale` from each PEFT layer
248 | unscale_lora_layers(self, lora_scale)
249 |
250 | if not return_dict:
251 | return (output,)
252 | return Transformer2DModelOutput(sample=output)
253 |
--------------------------------------------------------------------------------
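
Note that in `tranformer_forward` the condition tokens are never concatenated into `hidden_states`; they run through the same double and single blocks as a separate stream, with their own timestep embedding (`cond_temb` at `c_t = 0`) and their own rotary embeddings, and only meet the text/image tokens inside `attn_forward`. A back-of-the-envelope token count for the default 512x512 setting, under the usual FLUX assumptions (8x VAE downsample, 2x2 packing, T5 max length 512):

    # Rough sequence-length arithmetic for one denoising step (sketch, not repository code).
    def packed_tokens(pixels: int) -> int:
        latent = pixels // 8           # assumed VAE downsample factor
        return (latent // 2) ** 2      # 2x2 packing of the latent grid

    text_tokens = 512                      # max_sequence_length default in generate.py
    image_tokens = packed_tokens(512)      # 1024 for a 512x512 target
    condition_tokens = packed_tokens(512)  # 1024 for a 512x512 condition image

    # attn_forward attends over the concatenation [text | image | condition]:
    print(text_tokens + image_tokens + condition_tokens)  # 2560
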
/src/gradio/gradio_app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | from PIL import Image, ImageDraw, ImageFont
4 | from diffusers.pipelines import FluxPipeline
5 | from diffusers import FluxTransformer2DModel
6 | import numpy as np
7 |
8 | from ..flux.condition import Condition
9 | from ..flux.generate import seed_everything, generate
10 |
11 | pipe = None
12 | use_int8 = False
13 |
14 |
15 | def get_gpu_memory():
16 | return torch.cuda.get_device_properties(0).total_memory / 1024**3
17 |
18 |
19 | def init_pipeline():
20 | global pipe
21 | if use_int8 or get_gpu_memory() < 33:
22 | transformer_model = FluxTransformer2DModel.from_pretrained(
23 | "sayakpaul/flux.1-schell-int8wo-improved",
24 | torch_dtype=torch.bfloat16,
25 | use_safetensors=False,
26 | )
27 | pipe = FluxPipeline.from_pretrained(
28 | "black-forest-labs/FLUX.1-schnell",
29 | transformer=transformer_model,
30 | torch_dtype=torch.bfloat16,
31 | )
32 | else:
33 | pipe = FluxPipeline.from_pretrained(
34 | "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
35 | )
36 | pipe = pipe.to("cuda")
37 | pipe.load_lora_weights(
38 | "Yuanshi/OminiControl",
39 | weight_name="omini/subject_512.safetensors",
40 | adapter_name="subject",
41 | )
42 |
43 |
44 | def process_image_and_text(image, text):
45 | # center crop image
46 | w, h, min_size = image.size[0], image.size[1], min(image.size)
47 | image = image.crop(
48 | (
49 | (w - min_size) // 2,
50 | (h - min_size) // 2,
51 | (w + min_size) // 2,
52 | (h + min_size) // 2,
53 | )
54 | )
55 | image = image.resize((512, 512))
56 |
57 | condition = Condition("subject", image, position_delta=(0, 32))
58 |
59 | if pipe is None:
60 | init_pipeline()
61 |
62 | result_img = generate(
63 | pipe,
64 | prompt=text.strip(),
65 | conditions=[condition],
66 | num_inference_steps=8,
67 | height=512,
68 | width=512,
69 | ).images[0]
70 |
71 | return result_img
72 |
73 |
74 | def get_samples():
75 | sample_list = [
76 | {
77 | "image": "assets/oranges.jpg",
78 | "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
79 | },
80 | {
81 | "image": "assets/penguin.jpg",
82 | "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
83 | },
84 | {
85 | "image": "assets/rc_car.jpg",
86 |             "text": "A film style shot. On the moon, this item drives across the moon surface. In the background, Earth looms large.",
87 | },
88 | {
89 | "image": "assets/clock.jpg",
90 | "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
91 | },
92 | {
93 | "image": "assets/tshirt.jpg",
94 |             "text": "On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her.",
95 | },
96 | ]
97 | return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
98 |
99 |
100 | demo = gr.Interface(
101 | fn=process_image_and_text,
102 | inputs=[
103 | gr.Image(type="pil"),
104 | gr.Textbox(lines=2),
105 | ],
106 | outputs=gr.Image(type="pil"),
107 | title="OminiControl / Subject driven generation",
108 | examples=get_samples(),
109 | )
110 |
111 | if __name__ == "__main__":
112 | init_pipeline()
113 | demo.launch(
114 | debug=True,
115 | )
116 |
--------------------------------------------------------------------------------
/src/train/callbacks.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from PIL import Image, ImageFilter, ImageDraw
3 | import numpy as np
4 | from transformers import pipeline
5 | import cv2
6 | import torch
7 | import os
8 |
9 | try:
10 | import wandb
11 | except ImportError:
12 | wandb = None
13 |
14 | from ..flux.condition import Condition
15 | from ..flux.generate import generate
16 |
17 |
18 | class TrainingCallback(L.Callback):
19 | def __init__(self, run_name, training_config: dict = {}):
20 | self.run_name, self.training_config = run_name, training_config
21 |
22 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
23 | self.save_interval = training_config.get("save_interval", 1000)
24 | self.sample_interval = training_config.get("sample_interval", 1000)
25 | self.save_path = training_config.get("save_path", "./output")
26 |
27 | self.wandb_config = training_config.get("wandb", None)
28 | self.use_wandb = (
29 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None
30 | )
31 |
32 | self.total_steps = 0
33 |
34 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
35 | gradient_size = 0
36 | max_gradient_size = 0
37 | count = 0
38 | for _, param in pl_module.named_parameters():
39 | if param.grad is not None:
40 | gradient_size += param.grad.norm(2).item()
41 | max_gradient_size = max(max_gradient_size, param.grad.norm(2).item())
42 | count += 1
43 | if count > 0:
44 | gradient_size /= count
45 |
46 | self.total_steps += 1
47 |
48 | # Print training progress every n steps
49 | if self.use_wandb:
50 | report_dict = {
51 |                 "batch_idx": batch_idx,
52 |                 "steps": self.total_steps,
53 | "epoch": trainer.current_epoch,
54 | "gradient_size": gradient_size,
55 | }
56 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
57 | report_dict["loss"] = loss_value
58 | report_dict["t"] = pl_module.last_t
59 | wandb.log(report_dict)
60 |
61 | if self.total_steps % self.print_every_n_steps == 0:
62 | print(
63 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}"
64 | )
65 |
66 | # Save LoRA weights at specified intervals
67 | if self.total_steps % self.save_interval == 0:
68 | print(
69 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights"
70 | )
71 | pl_module.save_lora(
72 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}"
73 | )
74 |
75 | # Generate and save a sample image at specified intervals
76 | if self.total_steps % self.sample_interval == 0:
77 | print(
78 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
79 | )
80 | self.generate_a_sample(
81 | trainer,
82 | pl_module,
83 | f"{self.save_path}/{self.run_name}/output",
84 | f"lora_{self.total_steps}",
85 | batch["condition_type"][
86 | 0
87 | ], # Use the condition type from the current batch
88 | )
89 |
90 | @torch.no_grad()
91 | def generate_a_sample(
92 | self,
93 | trainer,
94 | pl_module,
95 | save_path,
96 | file_name,
97 | condition_type="super_resolution",
98 | ):
99 |         # TODO: change these two variables to parameters
100 | condition_size = trainer.training_config["dataset"]["condition_size"]
101 | target_size = trainer.training_config["dataset"]["target_size"]
102 |
103 | generator = torch.Generator(device=pl_module.device)
104 | generator.manual_seed(42)
105 |
106 | test_list = []
107 |
108 | if condition_type == "subject":
109 | test_list.extend(
110 | [
111 | (
112 | Image.open("assets/test_in.jpg"),
113 | [0, -32],
114 | "Resting on the picnic table at a lakeside campsite, it's caught in the golden glow of early morning, with mist rising from the water and tall pines casting long shadows behind the scene.",
115 | ),
116 | (
117 | Image.open("assets/test_out.jpg"),
118 | [0, -32],
119 | "In a bright room. It is placed on a table.",
120 | ),
121 | ]
122 | )
123 | elif condition_type == "canny":
124 | condition_img = Image.open("assets/vase_hq.jpg").resize(
125 | (condition_size, condition_size)
126 | )
127 | condition_img = np.array(condition_img)
128 | condition_img = cv2.Canny(condition_img, 100, 200)
129 | condition_img = Image.fromarray(condition_img).convert("RGB")
130 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
131 | elif condition_type == "coloring":
132 | condition_img = (
133 | Image.open("assets/vase_hq.jpg")
134 | .resize((condition_size, condition_size))
135 | .convert("L")
136 | .convert("RGB")
137 | )
138 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
139 | elif condition_type == "depth":
140 |             if not hasattr(self, "depth_pipe"):
141 |                 self.depth_pipe = pipeline(
142 | task="depth-estimation",
143 | model="LiheYoung/depth-anything-small-hf",
144 | device="cpu",
145 | )
146 | condition_img = (
147 | Image.open("assets/vase_hq.jpg")
148 | .resize((condition_size, condition_size))
149 | .convert("RGB")
150 | )
151 |             condition_img = self.depth_pipe(condition_img)["depth"].convert("RGB")
152 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
153 | elif condition_type == "depth_pred":
154 | condition_img = (
155 | Image.open("assets/vase_hq.jpg")
156 | .resize((condition_size, condition_size))
157 | .convert("RGB")
158 | )
159 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
160 | elif condition_type == "deblurring":
161 | blur_radius = 5
162 | image = Image.open("./assets/vase_hq.jpg")
163 | condition_img = (
164 | image.convert("RGB")
165 | .resize((condition_size, condition_size))
166 | .filter(ImageFilter.GaussianBlur(blur_radius))
167 | .convert("RGB")
168 | )
169 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
170 | elif condition_type == "fill":
171 | condition_img = (
172 | Image.open("./assets/vase_hq.jpg")
173 | .resize((condition_size, condition_size))
174 | .convert("RGB")
175 | )
176 | mask = Image.new("L", condition_img.size, 0)
177 | draw = ImageDraw.Draw(mask)
178 | a = condition_img.size[0] // 4
179 | b = a * 3
180 | draw.rectangle([a, a, b, b], fill=255)
181 | condition_img = Image.composite(
182 | condition_img, Image.new("RGB", condition_img.size, (0, 0, 0)), mask
183 | )
184 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
185 | elif condition_type == "sr":
186 | condition_img = (
187 | Image.open("assets/vase_hq.jpg")
188 | .resize((condition_size, condition_size))
189 | .convert("RGB")
190 | )
191 | test_list.append((condition_img, [0, -16], "A beautiful vase on a table."))
192 | elif condition_type == "cartoon":
193 | condition_img = (
194 | Image.open("assets/cartoon_boy.png")
195 | .resize((condition_size, condition_size))
196 | .convert("RGB")
197 | )
198 | test_list.append((condition_img, [0, -16], "A cartoon character in a white background. He is looking right, and running."))
199 | else:
200 | raise NotImplementedError
201 |
202 | if not os.path.exists(save_path):
203 | os.makedirs(save_path)
204 | for i, (condition_img, position_delta, prompt) in enumerate(test_list):
205 | condition = Condition(
206 | condition_type=condition_type,
207 | condition=condition_img.resize(
208 | (condition_size, condition_size)
209 | ).convert("RGB"),
210 | position_delta=position_delta,
211 | )
212 | res = generate(
213 | pl_module.flux_pipe,
214 | prompt=prompt,
215 | conditions=[condition],
216 | height=target_size,
217 | width=target_size,
218 | generator=generator,
219 | model_config=pl_module.model_config,
220 | default_lora=True,
221 | )
222 | res.images[0].save(
223 | os.path.join(save_path, f"{file_name}_{condition_type}_{i}.jpg")
224 | )
225 |
--------------------------------------------------------------------------------
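
The callback reads only a few keys from `training_config` (plus `trainer.training_config["dataset"]` when sampling). A minimal sketch of that dictionary with the keys used above; the values are illustrative and the authoritative schemas are the YAML files under train/config/.

    # Illustrative training_config for TrainingCallback (keys taken from the code above).
    training_config = {
        "print_every_n_steps": 10,   # console logging interval
        "save_interval": 1000,       # LoRA checkpoints go to {save_path}/{run_name}/ckpt/{step}
        "sample_interval": 1000,     # how often generate_a_sample() is called
        "save_path": "./output",
        "wandb": {"project": "OminiControl"},  # stored by the callback; logging also needs WANDB_API_KEY
        "dataset": {"condition_size": 512, "target_size": 512},
    }

    # from src.train.callbacks import TrainingCallback   # needs the train/ extras (lightning, wandb)
    # callback = TrainingCallback(run_name="subject_512", training_config=training_config)
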
/src/train/data.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageFilter, ImageDraw
2 | import cv2
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as T
6 | import random
7 |
8 |
9 | class Subject200KDataset(Dataset):
10 | def __init__(
11 | self,
12 | base_dataset,
13 | condition_size: int = 512,
14 | target_size: int = 512,
15 | image_size: int = 512,
16 | padding: int = 0,
17 | condition_type: str = "subject",
18 | drop_text_prob: float = 0.1,
19 | drop_image_prob: float = 0.1,
20 | return_pil_image: bool = False,
21 | ):
22 | self.base_dataset = base_dataset
23 | self.condition_size = condition_size
24 | self.target_size = target_size
25 | self.image_size = image_size
26 | self.padding = padding
27 | self.condition_type = condition_type
28 | self.drop_text_prob = drop_text_prob
29 | self.drop_image_prob = drop_image_prob
30 | self.return_pil_image = return_pil_image
31 |
32 | self.to_tensor = T.ToTensor()
33 |
34 | def __len__(self):
35 | return len(self.base_dataset) * 2
36 |
37 | def __getitem__(self, idx):
38 | # If target is 0, left image is target, right image is condition
39 | target = idx % 2
40 | item = self.base_dataset[idx // 2]
41 |
42 | # Crop the image to target and condition
43 | image = item["image"]
44 | left_img = image.crop(
45 | (
46 | self.padding,
47 | self.padding,
48 | self.image_size + self.padding,
49 | self.image_size + self.padding,
50 | )
51 | )
52 | right_img = image.crop(
53 | (
54 | self.image_size + self.padding * 2,
55 | self.padding,
56 | self.image_size * 2 + self.padding * 2,
57 | self.image_size + self.padding,
58 | )
59 | )
60 |
61 | # Get the target and condition image
62 | target_image, condition_img = (
63 | (left_img, right_img) if target == 0 else (right_img, left_img)
64 | )
65 |
66 | # Resize the image
67 | condition_img = condition_img.resize(
68 | (self.condition_size, self.condition_size)
69 | ).convert("RGB")
70 | target_image = target_image.resize(
71 | (self.target_size, self.target_size)
72 | ).convert("RGB")
73 |
74 | # Get the description
75 | description = item["description"][
76 | "description_0" if target == 0 else "description_1"
77 | ]
78 |
79 | # Randomly drop text or image
80 | drop_text = random.random() < self.drop_text_prob
81 | drop_image = random.random() < self.drop_image_prob
82 | if drop_text:
83 | description = ""
84 | if drop_image:
85 | condition_img = Image.new(
86 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
87 | )
88 |
89 | return {
90 | "image": self.to_tensor(target_image),
91 | "condition": self.to_tensor(condition_img),
92 | "condition_type": self.condition_type,
93 | "description": description,
94 | # 16 is the downscale factor of the image
95 | "position_delta": np.array([0, -self.condition_size // 16]),
96 | **({"pil_image": image} if self.return_pil_image else {}),
97 | }
98 |
99 |
100 | class ImageConditionDataset(Dataset):
101 | def __init__(
102 | self,
103 | base_dataset,
104 | condition_size: int = 512,
105 | target_size: int = 512,
106 | condition_type: str = "canny",
107 | drop_text_prob: float = 0.1,
108 | drop_image_prob: float = 0.1,
109 | return_pil_image: bool = False,
110 | ):
111 | self.base_dataset = base_dataset
112 | self.condition_size = condition_size
113 | self.target_size = target_size
114 | self.condition_type = condition_type
115 | self.drop_text_prob = drop_text_prob
116 | self.drop_image_prob = drop_image_prob
117 | self.return_pil_image = return_pil_image
118 |
119 | self.to_tensor = T.ToTensor()
120 |
121 | def __len__(self):
122 | return len(self.base_dataset)
123 |
124 | @property
125 | def depth_pipe(self):
126 | if not hasattr(self, "_depth_pipe"):
127 | from transformers import pipeline
128 |
129 | self._depth_pipe = pipeline(
130 | task="depth-estimation",
131 | model="LiheYoung/depth-anything-small-hf",
132 | device="cpu",
133 | )
134 | return self._depth_pipe
135 |
136 | def _get_canny_edge(self, img):
137 | resize_ratio = self.condition_size / max(img.size)
138 | img = img.resize(
139 | (int(img.size[0] * resize_ratio), int(img.size[1] * resize_ratio))
140 | )
141 | img_np = np.array(img)
142 | img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
143 | edges = cv2.Canny(img_gray, 100, 200)
144 | return Image.fromarray(edges).convert("RGB")
145 |
146 | def __getitem__(self, idx):
147 | image = self.base_dataset[idx]["jpg"]
148 | image = image.resize((self.target_size, self.target_size)).convert("RGB")
149 | description = self.base_dataset[idx]["json"]["prompt"]
150 |
151 | # Get the condition image
152 | position_delta = np.array([0, 0])
153 | if self.condition_type == "canny":
154 | condition_img = self._get_canny_edge(image)
155 | elif self.condition_type == "coloring":
156 | condition_img = (
157 | image.resize((self.condition_size, self.condition_size))
158 | .convert("L")
159 | .convert("RGB")
160 | )
161 | elif self.condition_type == "deblurring":
162 | blur_radius = random.randint(1, 10)
163 | condition_img = (
164 | image.convert("RGB")
165 | .resize((self.condition_size, self.condition_size))
166 | .filter(ImageFilter.GaussianBlur(blur_radius))
167 | .convert("RGB")
168 | )
169 | elif self.condition_type == "depth":
170 | condition_img = self.depth_pipe(image)["depth"].convert("RGB")
171 | elif self.condition_type == "depth_pred":
172 | condition_img = image
173 | image = self.depth_pipe(condition_img)["depth"].convert("RGB")
174 | description = f"[depth] {description}"
175 | elif self.condition_type == "fill":
176 | condition_img = image.resize(
177 | (self.condition_size, self.condition_size)
178 | ).convert("RGB")
179 | w, h = image.size
180 | x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
181 | y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
182 | mask = Image.new("L", image.size, 0)
183 | draw = ImageDraw.Draw(mask)
184 | draw.rectangle([x1, y1, x2, y2], fill=255)
185 | if random.random() > 0.5:
186 | mask = Image.eval(mask, lambda a: 255 - a)
187 | condition_img = Image.composite(
188 | image, Image.new("RGB", image.size, (0, 0, 0)), mask
189 | )
190 | elif self.condition_type == "sr":
191 | condition_img = image.resize(
192 | (self.condition_size, self.condition_size)
193 | ).convert("RGB")
194 | position_delta = np.array([0, -self.condition_size // 16])
195 |
196 | else:
197 | raise ValueError(f"Condition type {self.condition_type} not implemented")
198 |
199 | # Randomly drop text or image
200 | drop_text = random.random() < self.drop_text_prob
201 | drop_image = random.random() < self.drop_image_prob
202 | if drop_text:
203 | description = ""
204 | if drop_image:
205 | condition_img = Image.new(
206 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
207 | )
208 |
209 | return {
210 | "image": self.to_tensor(image),
211 | "condition": self.to_tensor(condition_img),
212 | "condition_type": self.condition_type,
213 | "description": description,
214 | "position_delta": position_delta,
215 | **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
216 | }
217 |
218 |
219 | class CartoonDataset(Dataset):
220 | def __init__(
221 | self,
222 | base_dataset,
223 | condition_size: int = 1024,
224 | target_size: int = 1024,
225 | image_size: int = 1024,
226 | padding: int = 0,
227 | condition_type: str = "cartoon",
228 | drop_text_prob: float = 0.1,
229 | drop_image_prob: float = 0.1,
230 | return_pil_image: bool = False,
231 | ):
232 | self.base_dataset = base_dataset
233 | self.condition_size = condition_size
234 | self.target_size = target_size
235 | self.image_size = image_size
236 | self.padding = padding
237 | self.condition_type = condition_type
238 | self.drop_text_prob = drop_text_prob
239 | self.drop_image_prob = drop_image_prob
240 | self.return_pil_image = return_pil_image
241 |
242 | self.to_tensor = T.ToTensor()
243 |
244 |
245 | def __len__(self):
246 | return len(self.base_dataset)
247 |
248 | def __getitem__(self, idx):
249 | data = self.base_dataset[idx]
250 | condition_img = data['condition']
251 | target_image = data['target']
252 |
253 | # Tag
254 | tag = data['tags'][0]
255 |
256 | target_description = data['target_description']
257 |
258 | description = {
259 | "lion": "lion like animal",
260 | "bear": "bear like animal",
261 | "gorilla": "gorilla like animal",
262 | "dog": "dog like animal",
263 | "elephant": "elephant like animal",
264 | "eagle": "eagle like bird",
265 | "tiger": "tiger like animal",
266 | "owl": "owl like bird",
267 | "woman": "woman",
268 | "parrot": "parrot like bird",
269 | "mouse": "mouse like animal",
270 | "man": "man",
271 | "pigeon": "pigeon like bird",
272 | "girl": "girl",
273 | "panda": "panda like animal",
274 | "crocodile": "crocodile like animal",
275 | "rabbit": "rabbit like animal",
276 | "boy": "boy",
277 | "monkey": "monkey like animal",
278 | "cat": "cat like animal"
279 | }
280 |
281 | # Resize the image
282 | condition_img = condition_img.resize(
283 | (self.condition_size, self.condition_size)
284 | ).convert("RGB")
285 | target_image = target_image.resize(
286 | (self.target_size, self.target_size)
287 | ).convert("RGB")
288 |
289 | # Process datum to create description
290 | description = data.get("description", f"Photo of a {description[tag]} cartoon character in a white background. Character is facing {target_description['facing_direction']}. Character pose is {target_description['pose']}.")
291 |
292 | # Randomly drop text or image
293 | drop_text = random.random() < self.drop_text_prob
294 | drop_image = random.random() < self.drop_image_prob
295 | if drop_text:
296 | description = ""
297 | if drop_image:
298 | condition_img = Image.new(
299 | "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
300 | )
301 |
302 |
303 | return {
304 | "image": self.to_tensor(target_image),
305 | "condition": self.to_tensor(condition_img),
306 | "condition_type": self.condition_type,
307 | "description": description,
308 | # 16 is the downscale factor of the image
309 | "position_delta": np.array([0, -16]),
310 | }
--------------------------------------------------------------------------------
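To make the contract of these dataset classes concrete, here is a hypothetical smoke test (not part of the repository). The fake in-memory record, its captions, and the assumption that the repository root is on `PYTHONPATH` (so `src.train.data` is importable) are all illustrative; only the keys and shapes follow from the class as written.

```python
from PIL import Image

from src.train.data import Subject200KDataset

# One fake Subjects200K-style record: a 1024x512 side-by-side image pair
# plus one description per half (fields assumed for illustration).
fake_base = [
    {
        "image": Image.new("RGB", (1024, 512), (128, 128, 128)),
        "description": {
            "description_0": "A toy car on a desk.",
            "description_1": "The same toy car on a beach.",
        },
    }
]

dataset = Subject200KDataset(
    fake_base,
    condition_size=512,
    target_size=512,
    image_size=512,
    padding=0,
    drop_text_prob=0.0,   # disable random dropping for a deterministic check
    drop_image_prob=0.0,
)

sample = dataset[0]
print(len(dataset))               # 2 (each pair yields two target/condition splits)
print(sample["image"].shape)      # torch.Size([3, 512, 512])
print(sample["condition"].shape)  # torch.Size([3, 512, 512])
print(sample["position_delta"])   # [  0 -32], i.e. -condition_size // 16
print(sample["description"])      # "A toy car on a desk."
```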
/src/train/model.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from diffusers.pipelines import FluxPipeline
3 | import torch
4 | from peft import LoraConfig, get_peft_model_state_dict
5 |
6 | import prodigyopt
7 |
8 | from ..flux.transformer import tranformer_forward
9 | from ..flux.condition import Condition
10 | from ..flux.pipeline_tools import encode_images, prepare_text_input
11 |
12 |
13 | class OminiModel(L.LightningModule):
14 | def __init__(
15 | self,
16 | flux_pipe_id: str,
17 | lora_path: str = None,
18 | lora_config: dict = None,
19 | device: str = "cuda",
20 | dtype: torch.dtype = torch.bfloat16,
21 | model_config: dict = {},
22 | optimizer_config: dict = None,
23 | gradient_checkpointing: bool = False,
24 | ):
25 | # Initialize the LightningModule
26 | super().__init__()
27 | self.model_config = model_config
28 | self.optimizer_config = optimizer_config
29 |
30 | # Load the Flux pipeline
31 | self.flux_pipe: FluxPipeline = (
32 | FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
33 | )
34 | self.transformer = self.flux_pipe.transformer
35 | self.transformer.gradient_checkpointing = gradient_checkpointing
36 | self.transformer.train()
37 |
38 | # Freeze the Flux pipeline
39 | self.flux_pipe.text_encoder.requires_grad_(False).eval()
40 | self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
41 | self.flux_pipe.vae.requires_grad_(False).eval()
42 |
43 | # Initialize LoRA layers
44 | self.lora_layers = self.init_lora(lora_path, lora_config)
45 |
46 | self.to(device).to(dtype)
47 |
48 | def init_lora(self, lora_path: str, lora_config: dict):
49 | assert lora_path or lora_config
50 | if lora_path:
51 | # TODO: Implement this
52 | raise NotImplementedError
53 | else:
54 | self.transformer.add_adapter(LoraConfig(**lora_config))
55 | # TODO: Check if this is correct (p.requires_grad)
56 | lora_layers = filter(
57 | lambda p: p.requires_grad, self.transformer.parameters()
58 | )
59 | return list(lora_layers)
60 |
61 | def save_lora(self, path: str):
62 | FluxPipeline.save_lora_weights(
63 | save_directory=path,
64 | transformer_lora_layers=get_peft_model_state_dict(self.transformer),
65 | safe_serialization=True,
66 | )
67 |
68 | def configure_optimizers(self):
69 | # Freeze the transformer
70 | self.transformer.requires_grad_(False)
71 | opt_config = self.optimizer_config
72 |
73 | # Set the trainable parameters
74 | self.trainable_params = self.lora_layers
75 |
76 | # Unfreeze trainable parameters
77 | for p in self.trainable_params:
78 | p.requires_grad_(True)
79 |
80 | # Initialize the optimizer
81 | if opt_config["type"] == "AdamW":
82 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
83 | elif opt_config["type"] == "Prodigy":
84 | optimizer = prodigyopt.Prodigy(
85 | self.trainable_params,
86 | **opt_config["params"],
87 | )
88 | elif opt_config["type"] == "SGD":
89 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
90 | else:
91 | raise NotImplementedError
92 |
93 | return optimizer
94 |
95 | def training_step(self, batch, batch_idx):
96 | step_loss = self.step(batch)
97 | self.log_loss = (
98 | step_loss.item()
99 | if not hasattr(self, "log_loss")
100 | else self.log_loss * 0.95 + step_loss.item() * 0.05
101 | )
102 | return step_loss
103 |
104 | def step(self, batch):
105 | imgs = batch["image"]
106 | conditions = batch["condition"]
107 | condition_types = batch["condition_type"]
108 | prompts = batch["description"]
109 | position_delta = batch["position_delta"][0]
110 |
111 | # Prepare inputs
112 | with torch.no_grad():
113 | # Prepare image input
114 | x_0, img_ids = encode_images(self.flux_pipe, imgs)
115 |
116 | # Prepare text input
117 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
118 | self.flux_pipe, prompts
119 | )
120 |
121 | # Prepare t and x_t
122 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
123 | x_1 = torch.randn_like(x_0).to(self.device)
124 | t_ = t.unsqueeze(1).unsqueeze(1)
125 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
126 |
127 | # Prepare conditions
128 | condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)
129 |
130 | # Add position delta
131 | condition_ids[:, 1] += position_delta[0]
132 | condition_ids[:, 2] += position_delta[1]
133 |
134 | # Prepare condition type
135 | condition_type_ids = torch.tensor(
136 | [
137 | Condition.get_type_id(condition_type)
138 | for condition_type in condition_types
139 | ]
140 | ).to(self.device)
141 | condition_type_ids = (
142 | torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
143 | ).unsqueeze(1)
144 |
145 | # Prepare guidance
146 | guidance = (
147 | torch.ones_like(t).to(self.device)
148 | if self.transformer.config.guidance_embeds
149 | else None
150 | )
151 |
152 | # Forward pass
153 | transformer_out = tranformer_forward(
154 | self.transformer,
155 | # Model config
156 | model_config=self.model_config,
157 | # Inputs of the condition (new feature)
158 | condition_latents=condition_latents,
159 | condition_ids=condition_ids,
160 | condition_type_ids=condition_type_ids,
161 | # Inputs to the original transformer
162 | hidden_states=x_t,
163 | timestep=t,
164 | guidance=guidance,
165 | pooled_projections=pooled_prompt_embeds,
166 | encoder_hidden_states=prompt_embeds,
167 | txt_ids=text_ids,
168 | img_ids=img_ids,
169 | joint_attention_kwargs=None,
170 | return_dict=False,
171 | )
172 | pred = transformer_out[0]
173 |
174 | # Compute loss
175 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
176 | self.last_t = t.mean().item()
177 | return loss
178 |
--------------------------------------------------------------------------------
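The objective computed in `OminiModel.step()` is a rectified-flow style loss: the clean image latents `x_0` are linearly blended with Gaussian noise `x_1` at a sigmoid-distributed timestep `t`, and the transformer regresses the velocity `x_1 - x_0`. The self-contained sketch below reproduces just that interpolation and loss; the tensor shapes and the zero "prediction" are stand-ins for illustration, not repository code.

```python
import torch

# Stand-in shapes: (batch, packed latent tokens, channels); real values come from encode_images.
x_0 = torch.randn(2, 1024, 64)                  # clean image latents
x_1 = torch.randn_like(x_0)                     # pure Gaussian noise
t = torch.sigmoid(torch.randn(x_0.shape[0]))    # timesteps in (0, 1), as in step()
t_ = t.view(-1, 1, 1)

x_t = (1 - t_) * x_0 + t_ * x_1                 # noisy latent fed to the transformer
target = x_1 - x_0                              # velocity the model is trained to predict

pred = torch.zeros_like(target)                 # placeholder for the transformer output
loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
print(loss.item())
```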
/src/train/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torch
3 | import lightning as L
4 | import yaml
5 | import os
6 | import time
7 |
8 | from datasets import load_dataset
9 |
10 | from .data import (
11 | ImageConditionDataset,
12 | Subject200KDataset,
13 | CartoonDataset
14 | )
15 | from .model import OminiModel
16 | from .callbacks import TrainingCallback
17 |
18 |
19 | def get_rank():
20 | try:
21 | rank = int(os.environ.get("LOCAL_RANK"))
22 |     except (TypeError, ValueError):
23 | rank = 0
24 | return rank
25 |
26 |
27 | def get_config():
28 | config_path = os.environ.get("XFL_CONFIG")
29 | assert config_path is not None, "Please set the XFL_CONFIG environment variable"
30 | with open(config_path, "r") as f:
31 | config = yaml.safe_load(f)
32 | return config
33 |
34 |
35 | def init_wandb(wandb_config, run_name):
36 | import wandb
37 |
38 | try:
39 | assert os.environ.get("WANDB_API_KEY") is not None
40 | wandb.init(
41 | project=wandb_config["project"],
42 | name=run_name,
43 | config={},
44 | )
45 | except Exception as e:
46 |         print("Failed to initialize WandB:", e)
47 |
48 |
49 | def main():
50 | # Initialize
51 | is_main_process, rank = get_rank() == 0, get_rank()
52 | torch.cuda.set_device(rank)
53 | config = get_config()
54 | training_config = config["train"]
55 | run_name = time.strftime("%Y%m%d-%H%M%S")
56 |
57 |     # Initialize WandB
58 | wandb_config = training_config.get("wandb", None)
59 | if wandb_config is not None and is_main_process:
60 | init_wandb(wandb_config, run_name)
61 |
62 | print("Rank:", rank)
63 | if is_main_process:
64 | print("Config:", config)
65 |
66 | # Initialize dataset and dataloader
67 | if training_config["dataset"]["type"] == "subject":
68 | dataset = load_dataset("Yuanshi/Subjects200K")
69 |
70 | # Define filter function
71 | def filter_func(item):
72 | if not item.get("quality_assessment"):
73 | return False
74 | return all(
75 | item["quality_assessment"].get(key, 0) >= 5
76 | for key in ["compositeStructure", "objectConsistency", "imageQuality"]
77 | )
78 |
79 | # Filter dataset
80 | if not os.path.exists("./cache/dataset"):
81 | os.makedirs("./cache/dataset")
82 | data_valid = dataset["train"].filter(
83 | filter_func,
84 | num_proc=16,
85 | cache_file_name="./cache/dataset/data_valid.arrow",
86 | )
87 | dataset = Subject200KDataset(
88 | data_valid,
89 | condition_size=training_config["dataset"]["condition_size"],
90 | target_size=training_config["dataset"]["target_size"],
91 | image_size=training_config["dataset"]["image_size"],
92 | padding=training_config["dataset"]["padding"],
93 | condition_type=training_config["condition_type"],
94 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
95 | drop_image_prob=training_config["dataset"]["drop_image_prob"],
96 | )
97 | elif training_config["dataset"]["type"] == "img":
98 | # Load dataset text-to-image-2M
99 | dataset = load_dataset(
100 | "webdataset",
101 | data_files={"train": training_config["dataset"]["urls"]},
102 | split="train",
103 | cache_dir="cache/t2i2m",
104 | num_proc=32,
105 | )
106 | dataset = ImageConditionDataset(
107 | dataset,
108 | condition_size=training_config["dataset"]["condition_size"],
109 | target_size=training_config["dataset"]["target_size"],
110 | condition_type=training_config["condition_type"],
111 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
112 | drop_image_prob=training_config["dataset"]["drop_image_prob"],
113 | )
114 | elif training_config["dataset"]["type"] == "cartoon":
115 | dataset = load_dataset("saquiboye/oye-cartoon", split="train")
116 | dataset = CartoonDataset(
117 | dataset,
118 | condition_size=training_config["dataset"]["condition_size"],
119 | target_size=training_config["dataset"]["target_size"],
120 | image_size=training_config["dataset"]["image_size"],
121 | padding=training_config["dataset"]["padding"],
122 | condition_type=training_config["condition_type"],
123 | drop_text_prob=training_config["dataset"]["drop_text_prob"],
124 | drop_image_prob=training_config["dataset"]["drop_image_prob"],
125 | )
126 | else:
127 | raise NotImplementedError
128 |
129 | print("Dataset length:", len(dataset))
130 | train_loader = DataLoader(
131 | dataset,
132 | batch_size=training_config["batch_size"],
133 | shuffle=True,
134 | num_workers=training_config["dataloader_workers"],
135 | )
136 |
137 | # Initialize model
138 | trainable_model = OminiModel(
139 | flux_pipe_id=config["flux_path"],
140 | lora_config=training_config["lora_config"],
141 |         device="cuda",
142 | dtype=getattr(torch, config["dtype"]),
143 | optimizer_config=training_config["optimizer"],
144 | model_config=config.get("model", {}),
145 | gradient_checkpointing=training_config.get("gradient_checkpointing", False),
146 | )
147 |
148 | # Callbacks for logging and saving checkpoints
149 | training_callbacks = (
150 | [TrainingCallback(run_name, training_config=training_config)]
151 | if is_main_process
152 | else []
153 | )
154 |
155 | # Initialize trainer
156 | trainer = L.Trainer(
157 | accumulate_grad_batches=training_config["accumulate_grad_batches"],
158 | callbacks=training_callbacks,
159 | enable_checkpointing=False,
160 | enable_progress_bar=False,
161 | logger=False,
162 | max_steps=training_config.get("max_steps", -1),
163 | max_epochs=training_config.get("max_epochs", -1),
164 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
165 | )
166 |
167 | setattr(trainer, "training_config", training_config)
168 |
169 | # Save config
170 | save_path = training_config.get("save_path", "./output")
171 | if is_main_process:
172 | os.makedirs(f"{save_path}/{run_name}")
173 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
174 | yaml.dump(config, f)
175 |
176 | # Start training
177 | trainer.fit(trainable_model, train_loader)
178 |
179 |
180 | if __name__ == "__main__":
181 | main()
182 |
--------------------------------------------------------------------------------
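For a quick local check, the entry point can also be driven directly from Python. The snippet below is a hypothetical one-off driver, not part of the repository: the supported way to launch training is through the shell scripts under `train/script`, and the sketch assumes execution from the repository root on a CUDA-capable machine.

```python
# Hypothetical driver (illustrative only; train/script/*.sh is the supported entry point).
import os

# get_config() reads the YAML file pointed to by XFL_CONFIG.
os.environ["XFL_CONFIG"] = "train/config/canny_512.yaml"

from src.train.train import main

main()  # builds the dataset, OminiModel, and Lightning Trainer, then starts fitting
```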
/train/README.md:
--------------------------------------------------------------------------------
1 | # OminiControl Training 🛠️
2 |
3 | ## Preparation
4 |
5 | ### Setup
6 | 1. **Environment**
7 | ```bash
8 | conda create -n omini python=3.10
9 | conda activate omini
10 | ```
11 | 2. **Requirements**
12 | ```bash
13 | pip install -r train/requirements.txt
14 | ```
15 |
16 | ### Dataset
17 | 1. Download the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (**subject-driven generation**):
18 |    ```bash
19 |    bash train/script/data_download/data_download1.sh
20 |    ```
21 | 2. Download the [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M) dataset (**spatial control tasks**):
22 |    ```bash
23 |    bash train/script/data_download/data_download2.sh
24 |    ```
25 | **Note:** By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional data files. Remember to update the config file to specify the training data accordingly.
26 |
27 | ## Training
28 |
29 | ### Start training
30 | **Config file path**: `./train/config`
31 |
32 | **Scripts path**: `./train/script`
33 |
34 | 1. Subject-driven generation
35 | ```bash
36 | bash train/script/train_subject.sh
37 | ```
38 | 2. Spatial control task
39 | ```bash
40 | bash train/script/train_canny.sh
41 | ```
42 |
43 | **Note**: Detailed WandB settings and GPU settings can be found in the script files and the config files.
44 |
45 | ### Other spatial control tasks
46 | This repository supports 7 spatial control tasks:
47 | 1. Canny edge to image (`canny`)
48 | 2. Image colorization (`coloring`)
49 | 3. Image deblurring (`deblurring`)
50 | 4. Depth map to image (`depth`)
51 | 5. Image to depth map (`depth_pred`)
52 | 6. Image inpainting (`fill`)
53 | 7. Super resolution (`sr`)
54 |
55 | You can switch between these tasks by modifying the `condition_type` parameter in the config file `config/canny_512.yaml`.
56 |
57 | ### Customize your own task
58 | You can customize your own task by constructing a new dataset and modifying the training code.
59 |
60 |
61 | **Instructions:**
62 |
63 | 1. **Dataset** :
64 |
65 | Construct a new dataset with the following format: (`src/train/data.py`)
66 | ```python
67 | class MyDataset(Dataset):
68 | def __init__(self, ...):
69 | ...
70 | def __len__(self):
71 | ...
72 | def __getitem__(self, idx):
73 | ...
74 | return {
75 | "image": image,
76 | "condition": condition_img,
77 | "condition_type": "your_condition_type",
78 | "description": description,
79 | "position_delta": position_delta
80 | }
81 | ```
82 |    **Note:** For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial tasks (e.g. subject-driven generation), set `position_delta` to `[0, -condition_width // 16]`, i.e. `[0, -32]` for a 512-pixel-wide condition image. A fuller dataset sketch is given after these instructions.
83 | 2. **Condition**:
84 |
85 | Add a new condition type in the `Condition` class. (`src/flux/condition.py`)
86 | ```python
87 | condition_dict = {
88 | ...
89 | "your_condition_type": your_condition_id_number, # Add your condition type here
90 | }
91 | ...
92 | if condition_type in [
93 | ...
94 | "your_condition_type", # Add your condition type here
95 | ]:
96 | ...
97 | ```
98 | 3. **Test**:
99 |
100 | Add a new test function for your task. (`src/train/callbacks.py`)
101 | ```python
102 | if self.condition_type == "your_condition_type":
103 | condition_img = (
104 | Image.open("images/vase.jpg")
105 | .resize((condition_size, condition_size))
106 | .convert("RGB")
107 | )
108 | ...
109 | test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
110 | ```
111 |
112 | 4. **Import the relevant dataset in the training script**
113 |     Update the following section of the file (`src/train/train.py`):
114 | ```python
115 | from .data import (
116 | ImageConditionDataset,
117 |     Subject200KDataset,
118 | MyDataset
119 | )
120 | ...
121 |
122 | # Initialize dataset and dataloader
123 | if training_config["dataset"]["type"] == "your_condition_type":
124 | ...
125 | ```
126 |
127 |
128 |
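Putting the steps above together, here is a fuller sketch of a custom dataset. Everything in it is illustrative: `MyDataset`, the `"scribble"` condition type, and the assumed fields of each `base_dataset` item (`image` and `condition` as PIL images, `description` as a string) are placeholders, not part of the repository.

```python
import random

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class MyDataset(Dataset):
    def __init__(self, base_dataset, condition_size=512, target_size=512,
                 drop_text_prob=0.1, drop_image_prob=0.1):
        self.base_dataset = base_dataset
        self.condition_size = condition_size
        self.target_size = target_size
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        item = self.base_dataset[idx]
        image = item["image"].resize((self.target_size, self.target_size)).convert("RGB")
        condition_img = item["condition"].resize(
            (self.condition_size, self.condition_size)
        ).convert("RGB")
        description = item["description"]

        # Randomly drop text / image, as the built-in datasets do.
        if random.random() < self.drop_text_prob:
            description = ""
        if random.random() < self.drop_image_prob:
            condition_img = Image.new(
                "RGB", (self.condition_size, self.condition_size), (0, 0, 0)
            )

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition_img),
            "condition_type": "scribble",  # must match the id added in condition.py
            "description": description,
            # Spatially aligned task -> [0, 0]; otherwise [0, -condition_size // 16].
            "position_delta": np.array([0, 0]),
        }
```

The returned dictionary must keep the same keys as the built-in datasets in `src/train/data.py`, since `OminiModel.step()` reads `image`, `condition`, `condition_type`, `description`, and `position_delta` from each batch.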
129 | ## Hardware requirement
130 | **Note**: Memory optimization (like dynamic T5 model loading) is pending implementation.
131 |
132 | **Recommended**
133 | - Hardware: 2x NVIDIA H100 GPUs
134 | - Memory: ~80GB GPU memory
135 |
136 | **Minimal**
137 | - Hardware: 1x NVIDIA L20 GPU
138 | - Memory: ~48GB GPU memory
--------------------------------------------------------------------------------
/train/config/canny_512.yaml:
--------------------------------------------------------------------------------
1 | flux_path: "black-forest-labs/FLUX.1-dev"
2 | dtype: "bfloat16"
3 |
4 | model:
5 | union_cond_attn: true
6 | add_cond_attn: false
7 | latent_lora: false
8 |
9 | train:
10 | batch_size: 1
11 | accumulate_grad_batches: 1
12 | dataloader_workers: 5
13 | save_interval: 1000
14 | sample_interval: 100
15 | max_steps: -1
16 | gradient_checkpointing: true
17 | save_path: "runs"
18 |
19 | # Specify the type of condition to use.
20 | # Options: ["canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"]
21 | condition_type: "canny"
22 | dataset:
23 | type: "img"
24 | urls:
25 | - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
26 | - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
27 | cache_name: "data_512_2M"
28 | condition_size: 512
29 | target_size: 512
30 | drop_text_prob: 0.1
31 | drop_image_prob: 0.1
32 |
33 | wandb:
34 | project: "OminiControl"
35 |
36 | lora_config:
37 | r: 4
38 | lora_alpha: 4
39 | init_lora_weights: "gaussian"
40 | target_modules: "(.*x_embedder|.*(?