├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── app.py
├── assets
│   ├── cat_cafe.png
│   ├── clock.png
│   ├── crystal_ball.png
│   ├── cup.png
│   ├── examples
│   │   ├── 0_0.json
│   │   ├── 0_0.png
│   │   ├── 1_0.json
│   │   ├── 1_0.png
│   │   ├── 1one2one
│   │   │   ├── config.json
│   │   │   ├── ref1.jpg
│   │   │   └── result.png
│   │   ├── 2_0.json
│   │   ├── 2_0.png
│   │   ├── 2one2one
│   │   │   ├── config.json
│   │   │   ├── ref1.png
│   │   │   └── result.png
│   │   ├── 3two2one
│   │   │   ├── config.json
│   │   │   ├── ref1.png
│   │   │   ├── ref2.png
│   │   │   └── result.png
│   │   ├── 4two2one
│   │   │   ├── config.json
│   │   │   ├── ref1.png
│   │   │   ├── ref2.png
│   │   │   └── result.png
│   │   ├── 5many2one
│   │   │   ├── config.json
│   │   │   ├── ref1.png
│   │   │   ├── ref2.png
│   │   │   ├── ref3.png
│   │   │   └── result.png
│   │   └── 6t2i
│   │       ├── config.json
│   │       └── result.png
│   ├── figurine.png
│   ├── logo.png
│   ├── simplecase.jpeg
│   ├── simplecase.jpg
│   └── teaser.jpg
├── config
│   └── deepspeed
│       ├── zero2_config.json
│       └── zero3_config.json
├── datasets
│   ├── dreambench_multiip.json
│   └── dreambench_singleip.json
├── inference.py
├── pyproject.toml
├── requirements.txt
├── train.py
└── uno
    ├── dataset
    │   └── uno.py
    ├── flux
    │   ├── math.py
    │   ├── model.py
    │   ├── modules
    │   │   ├── autoencoder.py
    │   │   ├── conditioner.py
    │   │   └── layers.py
    │   ├── pipeline.py
    │   ├── sampling.py
    │   └── util.py
    └── utils
        └── convert_yaml_to_args_file.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
176 | # User config files
177 | .vscode/
178 | output/
179 |
180 | # ckpt
181 | *.bin
182 | *.pt
183 | *.pth
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "datasets/dreambooth"]
2 | path = datasets/dreambooth
3 | url = https://github.com/google/dreambooth.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Less-to-More Generalization: Unlocking More Controllability by In-Context Generation
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | > Shaojin Wu, Mengqi Huang*, Wenxu Wu, Yufeng Cheng, Fei Ding+, Qian He
15 | >Intelligent Creation Team, ByteDance
16 |
17 |
18 |
20 |
21 |
22 | ## 🔥 News
23 | - [04/16/2025] 🔥 Our companion project [RealCustom](https://github.com/bytedance/RealCustom) is released.
24 | - [04/10/2025] 🔥 Updated the fp8 mode as the primary low-VRAM option, a gift for consumer-grade GPU users. Peak VRAM usage is now ~16GB. We may explore further inference optimizations later.
25 | - [04/03/2025] 🔥 The [demo](https://huggingface.co/spaces/bytedance-research/UNO-FLUX) of UNO is released.
26 | - [04/03/2025] 🔥 The [training code](https://github.com/bytedance/UNO), [inference code](https://github.com/bytedance/UNO), and [model](https://huggingface.co/bytedance-research/UNO) of UNO are released.
27 | - [04/02/2025] 🔥 The [project page](https://bytedance.github.io/UNO) of UNO is created.
28 | - [04/02/2025] 🔥 The arXiv [paper](https://arxiv.org/abs/2504.02160) of UNO is released.
29 |
30 | ## 📖 Introduction
31 | In this study, we propose a highly consistent data synthesis pipeline to tackle the data scalability challenge in subject-driven generation. The pipeline harnesses the intrinsic in-context generation capabilities of diffusion transformers to produce high-consistency multi-subject paired data. We also introduce UNO, a multi-image conditioned subject-to-image model iteratively trained from a text-to-image model, built on progressive cross-modal alignment and universal rotary position embedding. Extensive experiments show that our method achieves high consistency while maintaining controllability in both single-subject and multi-subject driven generation.
32 |
33 |
34 | ## ⚡️ Quick Start
35 |
36 | ### 🔧 Requirements and Installation
37 |
38 | Install the requirements
39 | ```bash
40 | # pip install -r requirements.txt # legacy installation command
41 |
42 | ## create a virtual environment with python >= 3.10 and <= 3.12, e.g.
43 | # python -m venv uno_env
44 | # source uno_env/bin/activate
45 | # or
46 | # conda create -n uno_env python=3.10 -y
47 | # conda activate uno_env
48 | # then install the requirements you need
49 |
50 | # !!! if you are using an AMD GPU, an NVIDIA RTX 50-series GPU, or macOS MPS, install the matching torch build yourself first
51 | # !!! then run the install command
52 | pip install -e .          # to run the demo/inference only
53 | pip install -e .[train]   # to also train the model
54 | ```
55 |
56 | then download the checkpoints in one of the following ways:
57 | 1. Directly run the inference scripts; the checkpoints will be downloaded automatically by the `hf_hub_download` calls in the code into your `$HF_HOME` (default: `~/.cache/huggingface`).
58 | 2. Use `huggingface-cli download` to fetch `black-forest-labs/FLUX.1-dev`, `xlabs-ai/xflux_text_encoders`, `openai/clip-vit-large-patch14`, and `bytedance-research/UNO`, then run the inference scripts. You can download only the checkpoints you need to speed up setup and save disk space, e.g. for `black-forest-labs/FLUX.1-dev` use `huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors` and `huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors`, skipping the text encoders in the `black-forest-labs/FLUX.1-dev` repo (they are only there for `diffusers` usage). All of the checkpoints together take about 37 GB of disk space.
59 | 3. Use `huggingface-cli download --local-dir` to download all the checkpoints mentioned in 2 into directories of your choice. Then set the environment variables `AE`, `FLUX_DEV` (or `FLUX_DEV_FP8` if you use fp8 mode), `T5`, `CLIP`, and `LORA` to the corresponding paths (see the sketch after this list). Finally, run the inference scripts.
60 | 4. **If you already have some of the checkpoints**, set the environment variables `AE`, `FLUX_DEV`, `T5`, `CLIP`, and `LORA` to the corresponding paths, then run the inference scripts.
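
For options 3 and 4, a minimal sketch of the environment-variable setup is shown below. The local `./ckpt` layout and the `dit_lora.safetensors` file name used for `LORA` are illustrative assumptions, not a fixed convention; point the variables at wherever your checkpoints actually live.

```bash
# hypothetical local layout; adjust every path to your own downloads
huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors ae.safetensors --local-dir ./ckpt/FLUX.1-dev
huggingface-cli download xlabs-ai/xflux_text_encoders --local-dir ./ckpt/xflux_text_encoders
huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./ckpt/clip-vit-large-patch14
huggingface-cli download bytedance-research/UNO --local-dir ./ckpt/UNO

export FLUX_DEV=./ckpt/FLUX.1-dev/flux1-dev.safetensors   # or FLUX_DEV_FP8=... for fp8 mode
export AE=./ckpt/FLUX.1-dev/ae.safetensors
export T5=./ckpt/xflux_text_encoders
export CLIP=./ckpt/clip-vit-large-patch14
export LORA=./ckpt/UNO/dit_lora.safetensors               # assumed file name inside the UNO checkpoint repo
```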
61 |
62 | ### 🌟 Gradio Demo
63 |
64 | ```bash
65 | python app.py
66 | ```
67 |
68 | **For low VRAM usage**, pass the `--offload` and `--name flux-dev-fp8` args. Peak memory usage will be around 16GB. For reference, end-to-end inference takes 40s to 1min on an RTX 3090 in fp8 and offload mode.
69 |
70 | ```bash
71 | python app.py --offload --name flux-dev-fp8
72 | ```
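
The same low-VRAM flags also apply to batch inference; assuming you pass the corresponding `InferenceArgs` options (`--offload`, `--model_type flux-dev-fp8`), a sketch:

```bash
# offload + fp8 for inference.py; prompt and reference reuse the clock example below
python inference.py --offload --model_type flux-dev-fp8 \
  --prompt "A clock on the beach is under a red sun umbrella" \
  --image_paths "assets/clock.png" --width 704 --height 704
```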
73 |
74 |
75 | ### ✍️ Inference
76 | Start from the examples below to explore and spark your creativity. ✨
77 | ```bash
78 | python inference.py --prompt "A clock on the beach is under a red sun umbrella" --image_paths "assets/clock.png" --width 704 --height 704
79 | python inference.py --prompt "The figurine is in the crystal ball" --image_paths "assets/figurine.png" "assets/crystal_ball.png" --width 704 --height 704
80 | python inference.py --prompt "The logo is printed on the cup" --image_paths "assets/cat_cafe.png" "assets/cup.png" --width 704 --height 704
81 | ```
82 |
83 | Optional preparation: if you want to run inference on DreamBench for the first time, initialize the `dreambench` dataset submodule to download the data.
84 |
85 | ```bash
86 | git submodule update --init
87 | ```
88 | Then run the following scripts:
89 | ```bash
90 | # evaluated on dreambench
91 | ## for single-subject
92 | python inference.py --eval_json_path ./datasets/dreambench_singleip.json
93 | ## for multi-subject
94 | python inference.py --eval_json_path ./datasets/dreambench_multiip.json
95 | ```
96 |
97 |
98 |
99 | ### 🚄 Training
100 |
101 | ```bash
102 | accelerate launch train.py
103 | ```
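
To override the `TrainArgs` defaults, a minimal sketch is shown below. It assumes `train.py` exposes the `TrainArgs` fields as CLI flags (for example via `HfArgumentParser`, as `inference.py` does) and that any DeepSpeed setup (e.g. `config/deepspeed/zero2_config.json`) is wired in through your `accelerate` configuration; the output directory is illustrative.

```bash
# a minimal sketch, assuming TrainArgs fields are exposed as CLI flags
accelerate launch train.py \
  --project_dir output/uno_lora \
  --mixed_precision bf16 \
  --lora_rank 512 \
  --gradient_accumulation_steps 1 \
  --pe d
```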
104 |
105 |
106 | ### 📌 Tips and Notes
107 | We integrate single-subject and multi-subject generation within a unified model. For single-subject scenarios, the longest side of the reference image is set to 512 by default, while for multi-subject scenarios, it is set to 320. UNO demonstrates remarkable flexibility across various aspect ratios, thanks to its training on a multi-scale dataset. Despite being trained within 512 buckets, it can handle higher resolutions, including 512, 568, and 704, among others.
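
To make these defaults concrete, `inference.py` exposes `--ref_size`, `--width`, and `--height` through `InferenceArgs`; the sketch below reuses the asset examples from above, and the specific sizes are only an illustration.

```bash
# single reference: the long side defaults to 512 (set explicitly here), output can be non-square
python inference.py --prompt "A clock on the beach is under a red sun umbrella" \
  --image_paths "assets/clock.png" --ref_size 512 --width 512 --height 704

# two references: the long side defaults to 320 when more than one reference is given
python inference.py --prompt "The figurine is in the crystal ball" \
  --image_paths "assets/figurine.png" "assets/crystal_ball.png" --width 704 --height 704
```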
108 |
109 | UNO excels in subject-driven generation but has room for improvement in generalization due to dataset constraints. We are actively developing an enhanced model—stay tuned for updates. Your feedback is valuable, so please feel free to share any suggestions.
110 |
111 | ## 🎨 Application Scenarios
112 |
113 |
115 |
116 |
117 | ## 📄 Disclaimer
118 |
119 | We open-source this project for academic research. The vast majority of images
120 | used in this project are either generated or licensed. If you have any concerns,
121 | please contact us, and we will promptly remove any inappropriate content.
122 | Our code is released under the Apache 2.0 License, while our models are under
123 | the CC BY-NC 4.0 License. Any models related to FLUX.1-dev
124 | base model must adhere to the original licensing terms.
125 |
126 | This research aims to advance the field of generative AI. Users are free to create images using this tool, provided they comply with local laws and exercise
127 | responsible usage. The developers are not liable for any misuse of the tool by users.
128 |
129 | ## 🚀 Updates
130 | For the purpose of fostering research and the open-source community, we plan to open-source the entire project, encompassing training, inference, weights, etc. Thank you for your patience and support! 🌟
131 | - [x] Release github repo.
132 | - [x] Release inference code.
133 | - [x] Release training code.
134 | - [x] Release model checkpoints.
135 | - [x] Release arXiv paper.
136 | - [x] Release huggingface space demo.
137 | - [ ] Release in-context data generation pipelines.
138 |
139 | ## Related resources
140 |
141 | **ComfyUI**
142 |
143 | - https://github.com/jax-explorer/ComfyUI-UNO a ComfyUI node implementation of UNO by jax-explorer.
144 | - https://github.com/HM-RunningHub/ComfyUI_RH_UNO a ComfyUI node implementation of UNO by HM-RunningHub.
145 | - https://github.com/ShmuelRonen/ComfyUI-UNO-Wrapper a ComfyUI node implementation of UNO by ShmuelRonen.
146 | - https://github.com/Yuan-ManX/ComfyUI-UNO a ComfyUI node implementation of UNO by Yuan-ManX.
147 | - https://github.com/QijiTec/ComfyUI-RED-UNO a ComfyUI node implementation of UNO by QijiTec.
148 |
149 | We thank the passionate community contributors. We have received many requests about ComfyUI, but we don't have enough time to make all of these adaptations ourselves, so if you want to try our work in ComfyUI, please try the repos above. Note that they differ slightly, so you may need some trial and error to find the one that best fits your needs.
150 |
151 | ## Citation
152 | If UNO is helpful, please help to ⭐ the repo.
153 |
154 | If you find this project useful for your research, please consider citing our paper:
155 | ```bibtex
156 | @article{wu2025less,
157 | title={Less-to-More Generalization: Unlocking More Controllability by In-Context Generation},
158 | author={Wu, Shaojin and Huang, Mengqi and Wu, Wenxu and Cheng, Yufeng and Ding, Fei and He, Qian},
159 | journal={arXiv preprint arXiv:2504.02160},
160 | year={2025}
161 | }
162 | ```
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 | import json
17 | from pathlib import Path
18 |
19 | import gradio as gr
20 | import torch
21 |
22 | from uno.flux.pipeline import UNOPipeline
23 |
24 |
25 | def get_examples(examples_dir: str = "assets/examples") -> list:
26 | examples = Path(examples_dir)
27 | ans = []
28 | for example in examples.iterdir():
29 | if not example.is_dir():
30 | continue
31 | with open(example / "config.json") as f:
32 | example_dict = json.load(f)
33 |
34 |
35 | example_list = []
36 |
37 | example_list.append(example_dict["useage"]) # case for
38 | example_list.append(example_dict["prompt"]) # prompt
39 |
40 | for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]:
41 | if key in example_dict:
42 | example_list.append(str(example / example_dict[key]))
43 | else:
44 | example_list.append(None)
45 |
46 | example_list.append(example_dict["seed"])
47 |
48 | ans.append(example_list)
49 | return ans
50 |
51 |
52 | def create_demo(
53 | model_type: str,
54 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
55 | offload: bool = False,
56 | ):
57 | pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
58 |
59 | badges_text = r"""
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | """.strip()
68 |
69 | with gr.Blocks() as demo:
70 | gr.Markdown(f"# UNO by UNO team")
71 | gr.Markdown(badges_text)
72 | with gr.Row():
73 | with gr.Column():
74 | prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
75 | with gr.Row():
76 | image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
77 | image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
78 | image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
79 | image_prompt4 = gr.Image(label="Ref img4", visible=True, interactive=True, type="pil")
80 |
81 | with gr.Row():
82 | with gr.Column():
83 | width = gr.Slider(512, 2048, 512, step=16, label="Generation Width")
84 | height = gr.Slider(512, 2048, 512, step=16, label="Generation Height")
85 | with gr.Column():
86 | gr.Markdown("📌 The model trained on 512x512 resolution.\n")
87 | gr.Markdown(
88 | "The size closer to 512 is more stable,"
89 | " and the higher size gives a better visual effect but is less stable"
90 | )
91 |
92 | with gr.Accordion("Advanced Options", open=False):
93 | with gr.Row():
94 | num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
95 | guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
96 | seed = gr.Number(-1, label="Seed (-1 for random)")
97 |
98 | generate_btn = gr.Button("Generate")
99 |
100 | with gr.Column():
101 | output_image = gr.Image(label="Generated Image")
102 | download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
103 |
104 |
105 | inputs = [
106 | prompt, width, height, guidance, num_steps,
107 | seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
108 | ]
109 | generate_btn.click(
110 | fn=pipeline.gradio_generate,
111 | inputs=inputs,
112 | outputs=[output_image, download_btn],
113 | )
114 |
115 | example_text = gr.Text("", visible=False, label="Case For:")
116 | examples = get_examples("./assets/examples")
117 |
118 | gr.Examples(
119 | examples=examples,
120 | inputs=[
121 | example_text, prompt,
122 | image_prompt1, image_prompt2, image_prompt3, image_prompt4,
123 | seed, output_image
124 | ],
125 | )
126 |
127 | return demo
128 |
129 | if __name__ == "__main__":
130 | from typing import Literal
131 |
132 | from transformers import HfArgumentParser
133 |
134 | @dataclasses.dataclass
135 | class AppArgs:
136 | name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
137 | device: Literal["cuda", "cpu"] = (
138 | "cuda" if torch.cuda.is_available() \
139 | else "mps" if torch.backends.mps.is_available() \
140 | else "cpu"
141 | )
142 | offload: bool = dataclasses.field(
143 | default=False,
144 | metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
145 | )
146 | port: int = 7860
147 |
148 | parser = HfArgumentParser([AppArgs])
149 | args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
150 | args = args_tuple[0]
151 |
152 | demo = create_demo(args.name, args.device, args.offload)
153 | demo.launch(server_port=args.port)
154 |
--------------------------------------------------------------------------------
/assets/cat_cafe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/cat_cafe.png
--------------------------------------------------------------------------------
/assets/clock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/clock.png
--------------------------------------------------------------------------------
/assets/crystal_ball.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/crystal_ball.png
--------------------------------------------------------------------------------
/assets/cup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/cup.png
--------------------------------------------------------------------------------
/assets/examples/0_0.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "A clock on the beach is under a red sun umbrella",
3 | "image_paths": [
4 | "assets/clock.png"
5 | ],
6 | "eval_json_path": null,
7 | "offload": false,
8 | "num_images_per_prompt": 1,
9 | "model_type": "flux-dev",
10 | "width": 704,
11 | "height": 704,
12 | "ref_size": 512,
13 | "num_steps": 25,
14 | "guidance": 4,
15 | "seed": 3407,
16 | "save_path": "output/inference",
17 | "only_lora": true,
18 | "concat_refs": false,
19 | "lora_rank": 512,
20 | "data_resolution": 512,
21 | "pe": "d"
22 | }
--------------------------------------------------------------------------------
/assets/examples/0_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/0_0.png
--------------------------------------------------------------------------------
/assets/examples/1_0.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "The figurine is in the crystal ball",
3 | "image_paths": [
4 | "assets/figurine.png",
5 | "assets/crystal_ball.png"
6 | ],
7 | "eval_json_path": null,
8 | "offload": false,
9 | "num_images_per_prompt": 1,
10 | "model_type": "flux-dev",
11 | "width": 704,
12 | "height": 704,
13 | "ref_size": 320,
14 | "num_steps": 25,
15 | "guidance": 4,
16 | "seed": 3407,
17 | "save_path": "output/inference",
18 | "only_lora": true,
19 | "concat_refs": false,
20 | "lora_rank": 512,
21 | "data_resolution": 512,
22 | "pe": "d"
23 | }
--------------------------------------------------------------------------------
/assets/examples/1_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1_0.png
--------------------------------------------------------------------------------
/assets/examples/1one2one/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "A clock on the beach is under a red sun umbrella",
3 | "seed": 0,
4 | "ref_long_side": 512,
5 | "useage": "one2one",
6 | "image_ref1": "./ref1.jpg",
7 | "image_result": "./result.png"
8 | }
--------------------------------------------------------------------------------
/assets/examples/1one2one/ref1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1one2one/ref1.jpg
--------------------------------------------------------------------------------
/assets/examples/1one2one/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1one2one/result.png
--------------------------------------------------------------------------------
/assets/examples/2_0.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "The logo is printed on the cup",
3 | "image_paths": [
4 | "assets/cat_cafe.png",
5 | "assets/cup.png"
6 | ],
7 | "eval_json_path": null,
8 | "offload": false,
9 | "num_images_per_prompt": 1,
10 | "model_type": "flux-dev",
11 | "width": 704,
12 | "height": 704,
13 | "ref_size": 320,
14 | "num_steps": 25,
15 | "guidance": 4,
16 | "seed": 3407,
17 | "save_path": "output/inference",
18 | "only_lora": true,
19 | "concat_refs": false,
20 | "lora_rank": 512,
21 | "data_resolution": 512,
22 | "pe": "d"
23 | }
--------------------------------------------------------------------------------
/assets/examples/2_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2_0.png
--------------------------------------------------------------------------------
/assets/examples/2one2one/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "A pretty woman wears a flower petal dress, in the flower",
3 | "seed": 1,
4 | "ref_long_side": 512,
5 | "useage": "one2one",
6 | "image_ref1": "./ref1.png",
7 | "image_result": "./result.png"
8 | }
--------------------------------------------------------------------------------
/assets/examples/2one2one/ref1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2one2one/ref1.png
--------------------------------------------------------------------------------
/assets/examples/2one2one/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2one2one/result.png
--------------------------------------------------------------------------------
/assets/examples/3two2one/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "The figurine is in the crystal ball",
3 | "seed": 0,
4 | "ref_long_side": 320,
5 | "useage": "two2one",
6 | "image_ref1": "./ref1.png",
7 | "image_ref2": "./ref2.png",
8 | "image_result": "./result.png"
9 | }
--------------------------------------------------------------------------------
/assets/examples/3two2one/ref1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/ref1.png
--------------------------------------------------------------------------------
/assets/examples/3two2one/ref2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/ref2.png
--------------------------------------------------------------------------------
/assets/examples/3two2one/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/result.png
--------------------------------------------------------------------------------
/assets/examples/4two2one/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "The logo is printed on the cup",
3 | "seed": 61733557,
4 | "ref_long_side": 320,
5 | "useage": "two2one",
6 | "image_ref1": "./ref1.png",
7 | "image_ref2": "./ref2.png",
8 | "image_result": "./result.png"
9 | }
--------------------------------------------------------------------------------
/assets/examples/4two2one/ref1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/ref1.png
--------------------------------------------------------------------------------
/assets/examples/4two2one/ref2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/ref2.png
--------------------------------------------------------------------------------
/assets/examples/4two2one/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/result.png
--------------------------------------------------------------------------------
/assets/examples/5many2one/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "A woman wears the dress and holds a bag, in the flowers.",
3 | "seed": 37635012,
4 | "ref_long_side": 320,
5 | "useage": "many2one",
6 | "image_ref1": "./ref1.png",
7 | "image_ref2": "./ref2.png",
8 | "image_ref3": "./ref3.png",
9 | "image_result": "./result.png"
10 | }
--------------------------------------------------------------------------------
/assets/examples/5many2one/ref1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref1.png
--------------------------------------------------------------------------------
/assets/examples/5many2one/ref2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref2.png
--------------------------------------------------------------------------------
/assets/examples/5many2one/ref3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref3.png
--------------------------------------------------------------------------------
/assets/examples/5many2one/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/result.png
--------------------------------------------------------------------------------
/assets/examples/6t2i/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompt": "A woman wears the dress and holds a bag, in the flowers.",
3 | "seed": 37635012,
4 | "ref_long_side": 512,
5 | "useage": "t2i",
6 | "image_result": "./result.png"
7 | }
--------------------------------------------------------------------------------
/assets/examples/6t2i/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/6t2i/result.png
--------------------------------------------------------------------------------
/assets/figurine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/figurine.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/logo.png
--------------------------------------------------------------------------------
/assets/simplecase.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/simplecase.jpeg
--------------------------------------------------------------------------------
/assets/simplecase.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/simplecase.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/teaser.jpg
--------------------------------------------------------------------------------
/config/deepspeed/zero2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 2,
7 | "offload_optimizer": {
8 | "device": "none"
9 | },
10 | "contiguous_gradients": true,
11 | "overlap_comm": true
12 | },
13 | "train_micro_batch_size_per_gpu": 1,
14 | "gradient_accumulation_steps": "auto"
15 | }
16 |
--------------------------------------------------------------------------------
/config/deepspeed/zero3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "zero_optimization": {
12 | "stage": 3,
13 | "offload_optimizer": {
14 | "device": "cpu",
15 | "pin_memory": true
16 | },
17 | "offload_param": {
18 | "device": "cpu",
19 | "pin_memory": true
20 | },
21 | "overlap_comm": true,
22 | "contiguous_gradients": true,
23 | "reduce_bucket_size": 16777216,
24 | "stage3_prefetch_bucket_size": 15099494,
25 | "stage3_param_persistence_threshold": 40960,
26 | "sub_group_size": 1e9,
27 | "stage3_max_live_parameters": 1e9,
28 | "stage3_max_reuse_distance": 1e9,
29 | "stage3_gather_16bit_weights_on_model_save": true
30 | },
31 | "gradient_accumulation_steps": "auto",
32 | "train_micro_batch_size_per_gpu": 1
33 | }
34 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import dataclasses
17 | from typing import Literal
18 |
19 | from accelerate import Accelerator
20 | from transformers import HfArgumentParser
21 | from PIL import Image
22 | import json
23 | import itertools
24 |
25 | from uno.flux.pipeline import UNOPipeline, preprocess_ref
26 |
27 |
28 | def horizontal_concat(images):
29 | widths, heights = zip(*(img.size for img in images))
30 |
31 | total_width = sum(widths)
32 | max_height = max(heights)
33 |
34 | new_im = Image.new('RGB', (total_width, max_height))
35 |
36 | x_offset = 0
37 | for img in images:
38 | new_im.paste(img, (x_offset, 0))
39 | x_offset += img.size[0]
40 |
41 | return new_im
42 |
43 | @dataclasses.dataclass
44 | class InferenceArgs:
45 | prompt: str | None = None
46 | image_paths: list[str] | None = None
47 | eval_json_path: str | None = None
48 | offload: bool = False
49 | num_images_per_prompt: int = 1
50 | model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
51 | width: int = 512
52 | height: int = 512
53 | ref_size: int = -1
54 | num_steps: int = 25
55 | guidance: float = 4
56 | seed: int = 3407
57 | save_path: str = "output/inference"
58 | only_lora: bool = True
59 | concat_refs: bool = False
60 | lora_rank: int = 512
61 | data_resolution: int = 512
62 | pe: Literal['d', 'h', 'w', 'o'] = 'd'
63 |
64 | def main(args: InferenceArgs):
65 | accelerator = Accelerator()
66 |
67 | pipeline = UNOPipeline(
68 | args.model_type,
69 | accelerator.device,
70 | args.offload,
71 | only_lora=args.only_lora,
72 | lora_rank=args.lora_rank
73 | )
74 |
75 | assert args.prompt is not None or args.eval_json_path is not None, \
76 | "Please provide either prompt or eval_json_path"
77 |
78 | if args.eval_json_path is not None:
79 | with open(args.eval_json_path, "rt") as f:
80 | data_dicts = json.load(f)
81 | data_root = os.path.dirname(args.eval_json_path)
82 | else:
83 | data_root = "./"
84 | data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}]
85 |
86 | for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt)):
87 | if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
88 | continue
89 |
90 | ref_imgs = [
91 | Image.open(os.path.join(data_root, img_path))
92 | for img_path in data_dict["image_paths"]
93 | ]
94 | if args.ref_size==-1:
95 | args.ref_size = 512 if len(ref_imgs)==1 else 320
96 |
97 | ref_imgs = [preprocess_ref(img, args.ref_size) for img in ref_imgs]
98 |
99 | image_gen = pipeline(
100 | prompt=data_dict["prompt"],
101 | width=args.width,
102 | height=args.height,
103 | guidance=args.guidance,
104 | num_steps=args.num_steps,
105 | seed=args.seed + j,
106 | ref_imgs=ref_imgs,
107 | pe=args.pe,
108 | )
109 | if args.concat_refs:
110 | image_gen = horizontal_concat([image_gen, *ref_imgs])
111 |
112 | os.makedirs(args.save_path, exist_ok=True)
113 | image_gen.save(os.path.join(args.save_path, f"{i}_{j}.png"))
114 |
115 | # save config and image
116 | args_dict = vars(args)
117 | args_dict['prompt'] = data_dict["prompt"]
118 | args_dict['image_paths'] = data_dict["image_paths"]
119 | with open(os.path.join(args.save_path, f"{i}_{j}.json"), 'w') as f:
120 | json.dump(args_dict, f, indent=4)
121 |
122 | if __name__ == "__main__":
123 | parser = HfArgumentParser([InferenceArgs])
124 | args = parser.parse_args_into_dataclasses()[0]
125 | main(args)
126 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "uno"
7 | version = "0.0.1"
8 | authors = [
9 | { name="Bytedance Ltd. and/or its affiliates" },
10 | ]
11 | maintainers = [
12 | {name = "Wu Shaojin", email = "wushaojin@bytedance.com"},
13 | {name = "Huang Mengqi", email = "huangmengqi.98@bytedance.com"},
14 | {name = "Wu Wenxu", email = "wuwenxu.01@bytedance.com"},
15 | {name = "Cheng Yufeng", email = "chengyufeng.cb1@bytedance.com"},
16 | ]
17 |
18 | description = "🔥🔥 UNO: A Universal Customization Method for Both Single and Multi-Subject Conditioning"
19 | readme = "README.md"
20 | requires-python = ">=3.10, <=3.12"
21 | classifiers = [
22 | "Programming Language :: Python :: 3",
23 | "Operating System :: OS Independent",
24 | ]
25 | license = "Apache-2.0"
26 | license-files = ["LICENSE"]
27 |
28 |
29 | dependencies = [
30 | "torch>=2.4.0",
31 | "torchvision>=0.19.0",
32 | "einops>=0.8.0",
33 | "transformers>=4.43.3",
34 | "huggingface-hub",
35 | "diffusers>=0.30.1",
36 | "sentencepiece==0.2.0",
37 | "gradio>=5.22.0",
38 | ]
39 |
40 | [project.optional-dependencies]
41 |
42 | train = [
43 | "accelerate==1.1.1",
44 | "deepspeed==0.14.4",
45 | ]
46 |
47 | dev = [
48 | "ruff",
49 | ]
50 |
51 |
52 | [project.urls]
53 | Repository = "https://github.com/bytedance/UNO"
54 | ProjectPage = "https://bytedance.github.io/UNO"
55 | Models = "https://huggingface.co/bytedance-research/UNO"
56 | Demo = "https://huggingface.co/spaces/bytedance-research/UNO-FLUX"
57 | Arxiv = "https://arxiv.org/abs/2504.02160"
58 |
59 |
60 | [tool.setuptools.packages.find]
61 | where = [""]
62 | namespaces = false # to disable scanning PEP 420 namespaces (true by default)
63 |
64 | [tool.ruff]
65 | include = ["uno/**/*.py"]
66 | line-length = 120
67 | indent-width = 4
68 | target-version = "py310"
69 | show-fixes = true
70 |
71 | [tool.ruff.lint]
72 | select = ["E4", "E7", "E9", "F", "I"]
73 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ## after the move to pyproject.toml, requirements.txt is only used to install dependencies for the Hugging Face demo, so the training dependencies are commented out
2 | # accelerate==1.1.1
3 | # deepspeed==0.14.4
4 | einops==0.8.0
5 | transformers==4.43.3
6 | huggingface-hub
7 | diffusers==0.30.1
8 | sentencepiece==0.2.0
9 | gradio==5.22.0
10 |
11 | --extra-index-url https://download.pytorch.org/whl/cu124
12 | torch==2.4.0
13 | torchvision==0.19.0
14 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 | import gc
17 | import itertools
18 | import logging
19 | import os
20 | import random
21 | from copy import deepcopy
22 | from typing import TYPE_CHECKING, Literal
23 |
24 | import torch
25 | import torch.nn.functional as F
26 | import transformers
27 | from accelerate import Accelerator, DeepSpeedPlugin
28 | from accelerate.logging import get_logger
29 | from accelerate.utils import set_seed
30 | from diffusers.optimization import get_scheduler
31 | from einops import rearrange
32 | from PIL import Image
33 | from safetensors.torch import load_file
34 | from torch.utils.data import DataLoader
35 | from tqdm import tqdm
36 |
37 | from uno.dataset.uno import FluxPairedDatasetV2
38 | from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
39 | from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5, set_lora
40 |
41 | if TYPE_CHECKING:
42 | from uno.flux.model import Flux
43 | from uno.flux.modules.autoencoder import AutoEncoder
44 | from uno.flux.modules.conditioner import HFEmbedder
45 |
46 | logger = get_logger(__name__)
47 |
48 | def get_models(name: str, device, offload: bool=False):
49 | t5 = load_t5(device, max_length=512)
50 | clip = load_clip(device)
51 | model = load_flow_model(name, device="cpu")
52 | vae = load_ae(name, device="cpu" if offload else device)
53 | return model, vae, t5, clip
54 |
55 | def inference(
56 | batch: dict,
57 | model: "Flux", t5: "HFEmbedder", clip: "HFEmbedder", ae: "AutoEncoder",
58 | accelerator: Accelerator,
59 | seed: int = 0,
60 | pe: Literal["d", "h", "w", "o"] = "d"
61 | ) -> Image.Image:
62 | ref_imgs = batch["ref_imgs"]
63 | prompt = batch["txt"]
64 | neg_prompt = ''
65 | width, height = 512, 512
66 | num_steps = 25
67 | x = get_noise(
68 | 1, height, width,
69 | device=accelerator.device,
70 | dtype=torch.bfloat16,
71 | seed=seed + accelerator.process_index
72 | )
73 | timesteps = get_schedule(
74 | num_steps,
75 | (width // 8) * (height // 8) // (16 * 16),
76 | shift=True,
77 | )
78 | with torch.no_grad():
79 | ref_imgs = [
80 | ae.encode(ref_img_.to(accelerator.device, torch.float32)).to(torch.bfloat16)
81 | for ref_img_ in ref_imgs
82 | ]
83 | inp_cond = prepare_multi_ip(
84 | t5=t5, clip=clip, img=x, prompt=prompt,
85 | ref_imgs=ref_imgs,
86 | pe=pe
87 | )
88 |
89 | x = denoise(
90 | model,
91 | **inp_cond,
92 | timesteps=timesteps,
93 | guidance=4,
94 | )
95 |
96 | x = unpack(x.float(), height, width)
97 | x = ae.decode(x)
98 |
99 | x1 = x.clamp(-1, 1)
100 | x1 = rearrange(x1[-1], "c h w -> h w c")
101 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
102 |
103 | return output_img
104 |
105 |
106 | def resume_from_checkpoint(
107 | resume_from_checkpoint: str | None | Literal["latest"],
108 | project_dir: str,
109 | accelerator: Accelerator,
110 | dit: "Flux",
111 | dit_ema_dict: dict | None = None,
112 | ) -> tuple["Flux", dict | None, int]:
113 | # Potentially load in the weights and states from a previous save
114 | if resume_from_checkpoint is None:
115 | return dit, dit_ema_dict, 0
116 |
117 | if resume_from_checkpoint == "latest":
118 | # Get the most recent checkpoint
119 | dirs = os.listdir(project_dir)
120 | dirs = [d for d in dirs if d.startswith("checkpoint")]
121 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
122 | if len(dirs) == 0:
123 | accelerator.print(
124 | f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
125 | )
126 | return dit, dit_ema_dict, 0
127 | path = dirs[-1]
128 | else:
129 | path = os.path.basename(resume_from_checkpoint)
130 |
131 |
132 | accelerator.print(f"Resuming from checkpoint {path}")
133 | lora_state = load_file(
134 | os.path.join(project_dir, path, 'dit_lora.safetensors'),
135 | device=accelerator.device.__str__()
136 | )
137 |     unwrap_dit = accelerator.unwrap_model(dit)
138 |     unwrap_dit.load_state_dict(lora_state, strict=False)
139 |     if dit_ema_dict is not None:
140 |         dit_ema_dict = load_file(
141 |             os.path.join(project_dir, path, 'dit_lora_ema.safetensors'),
142 |             device=accelerator.device.__str__()
143 |         )
144 |         if dit is not unwrap_dit:
145 |             dit_ema_dict = {f"module.{k}": v for k, v in dit_ema_dict.items() if k in unwrap_dit.state_dict()}
146 |
147 | global_step = int(path.split("-")[1])
148 |
149 | return dit, dit_ema_dict, global_step
150 |
151 | @dataclasses.dataclass
152 | class TrainArgs:
153 | ## accelerator
154 | project_dir: str | None = None
155 | mixed_precision: Literal["no", "fp16", "bf16"] = "bf16"
156 | gradient_accumulation_steps: int = 1
157 | seed: int = 42
158 | wandb_project_name: str | None = None
159 | wandb_run_name: str | None = None
160 |
161 | ## model
162 | model_name: Literal["flux-dev", "flux-schnell"] = "flux-dev"
163 | lora_rank: int = 512
164 | double_blocks_indices: list[int] | None = dataclasses.field(
165 | default=None,
166 | metadata={"help": "Indices of double blocks to apply LoRA. None means all double blocks."}
167 | )
168 | single_blocks_indices: list[int] | None = dataclasses.field(
169 | default=None,
170 |         metadata={"help": "Indices of single blocks to apply LoRA. None means all single blocks."}
171 | )
172 | pe: Literal["d", "h", "w", "o"] = "d"
173 | gradient_checkpoint: bool = True
174 | ema: bool = False
175 | ema_interval: int = 1
176 | ema_decay: float = 0.99
177 |
178 |
179 | ## optimizer
180 | learning_rate: float = 1e-2
181 | adam_betas: list[float] = dataclasses.field(default_factory=lambda: [0.9, 0.999])
182 | adam_eps: float = 1e-8
183 | adam_weight_decay: float = 0.01
184 | max_grad_norm: float = 1.0
185 |
186 | ## lr_scheduler
187 | lr_scheduler: str = "constant"
188 | lr_warmup_steps: int = 100
189 | max_train_steps: int = 100000
190 |
191 | ## dataloader
192 |     # TODO: change to your own dataset, or use the data synthesis pipeline coming in a future release. Stay tuned.
193 | train_data_json: str = "datasets/dreambench_singleip.json"
194 | batch_size: int = 1
195 | text_dropout: float = 0.1
196 | resolution: int = 512
197 | resolution_ref: int | None = None
198 |
199 | eval_data_json: str = "datasets/dreambench_singleip.json"
200 | eval_batch_size: int = 1
201 |
202 | ## misc
203 | resume_from_checkpoint: str | None | Literal["latest"] = None
204 | checkpointing_steps: int = 1000
205 |
206 | def main(
207 | args: TrainArgs,
208 | ):
209 | ## accelerator
210 | deepspeed_plugins = {
211 | "dit": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero2_config.json'),
212 | "t5": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json'),
213 | "clip": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json')
214 | }
215 | accelerator = Accelerator(
216 | project_dir=args.project_dir,
217 | gradient_accumulation_steps=args.gradient_accumulation_steps,
218 | mixed_precision=args.mixed_precision,
219 | deepspeed_plugins=deepspeed_plugins,
220 | log_with="wandb",
221 | )
222 | set_seed(args.seed, device_specific=True)
223 | accelerator.init_trackers(
224 | project_name=args.wandb_project_name,
225 | config=args.__dict__,
226 | init_kwargs={
227 | "wandb": {
228 | "name": args.wandb_run_name,
229 | "dir": accelerator.project_dir,
230 | },
231 | },
232 | )
233 | weight_dtype = {
234 | "fp16": torch.float16,
235 | "bf16": torch.bfloat16,
236 | "no": torch.float32,
237 | }.get(accelerator.mixed_precision, torch.float32)
238 |
239 | ## logger
240 | logging.basicConfig(
241 | format=f"[RANK {accelerator.process_index}] " + "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
242 | datefmt="%m/%d/%Y %H:%M:%S",
243 | level=logging.INFO,
244 | force=True
245 | )
246 | logger.info(accelerator.state)
247 | logger.info("Training script launched", main_process_only=False)
248 |
249 | ## model
250 | dit, vae, t5, clip = get_models(
251 | name=args.model_name,
252 | device=accelerator.device,
253 | )
254 |
255 | vae.requires_grad_(False)
256 | t5.requires_grad_(False)
257 | clip.requires_grad_(False)
258 |
259 | dit.requires_grad_(False)
260 | dit = set_lora(dit, args.lora_rank, args.double_blocks_indices, args.single_blocks_indices, accelerator.device)
261 | dit.train()
262 | dit.gradient_checkpointing = args.gradient_checkpoint
263 |
264 | ## ema
265 | dit_ema_dict = {
266 | f"module.{k}": deepcopy(v).requires_grad_(False) for k, v in dit.named_parameters() if v.requires_grad
267 | } if args.ema else None
268 |
269 | ## optimizer and lr scheduler
270 | optimizer = torch.optim.AdamW(
271 | [p for p in dit.parameters() if p.requires_grad],
272 | lr=args.learning_rate,
273 | betas=args.adam_betas,
274 | weight_decay=args.adam_weight_decay,
275 | eps=args.adam_eps,
276 | )
277 |
278 | lr_scheduler = get_scheduler(
279 | args.lr_scheduler,
280 | optimizer=optimizer,
281 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
282 | num_training_steps=args.max_train_steps * accelerator.num_processes,
283 | )
284 |
285 | ## resume
286 | (
287 | dit,
288 | dit_ema_dict,
289 | global_step
290 | ) = resume_from_checkpoint(
291 | args.resume_from_checkpoint,
292 | project_dir=args.project_dir,
293 | accelerator=accelerator,
294 | dit=dit,
295 | dit_ema_dict=dit_ema_dict
296 | )
297 |
298 | # dataloader
299 | dataset = FluxPairedDatasetV2(
300 | json_file=args.train_data_json,
301 | resolution=args.resolution, resolution_ref=args.resolution_ref
302 | )
303 | dataloader = DataLoader(
304 | dataset,
305 | batch_size=args.batch_size,
306 | shuffle=True,
307 | collate_fn=dataset.collate_fn
308 | )
309 | eval_dataset = FluxPairedDatasetV2(
310 | json_file=args.eval_data_json,
311 | resolution=args.resolution, resolution_ref=args.resolution_ref
312 | )
313 | eval_dataloader = DataLoader(
314 | eval_dataset,
315 | batch_size=args.eval_batch_size,
316 | shuffle=False,
317 | collate_fn=eval_dataset.collate_fn
318 | )
319 |
320 | dataloader = accelerator.prepare_data_loader(dataloader)
321 | eval_dataloader = accelerator.prepare_data_loader(eval_dataloader)
322 | dataloader = itertools.cycle(dataloader) # as infinite fetch data loader
323 |
324 | ## parallel
325 | accelerator.state.select_deepspeed_plugin("dit")
326 | dit, optimizer, lr_scheduler = accelerator.prepare(dit, optimizer, lr_scheduler)
327 | accelerator.state.select_deepspeed_plugin("t5")
328 | t5 = accelerator.prepare(t5) # type: torch.nn.Module
329 | accelerator.state.select_deepspeed_plugin("clip")
330 | clip = accelerator.prepare(clip) # type: torch.nn.Module
331 |
332 | ## noise scheduler
333 | timesteps = get_schedule(
334 | 999,
335 | (args.resolution // 8) * (args.resolution // 8) // 4,
336 | shift=True,
337 | )
338 | timesteps = torch.tensor(timesteps, device=accelerator.device)
339 | total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
340 |
341 | logger.info("***** Running training *****")
342 | logger.info(f" Instantaneous batch size per device = {args.batch_size}")
343 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
344 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
345 | logger.info(f" Total optimization steps = {args.max_train_steps}")
346 | logger.info(f" Total validation prompts = {len(eval_dataloader)}")
347 |
348 | progress_bar = tqdm(
349 | range(0, args.max_train_steps),
350 | initial=global_step,
351 | desc="Steps",
352 | total=args.max_train_steps,
353 | disable=not accelerator.is_local_main_process,
354 | )
355 |
356 | train_loss = 0.0
357 | while global_step < (args.max_train_steps):
358 | batch = next(dataloader)
359 | prompts = [txt_ if random.random() > args.text_dropout else "" for txt_ in batch["txt"]]
360 | img = batch["img"]
361 | ref_imgs = batch["ref_imgs"]
362 |
363 | with torch.no_grad():
364 | x_1 = vae.encode(img.to(accelerator.device).to(torch.float32))
365 | x_ref = [vae.encode(ref_img.to(accelerator.device).to(torch.float32)) for ref_img in ref_imgs]
366 | inp = prepare_multi_ip(t5=t5, clip=clip, img=x_1, prompt=prompts, ref_imgs=tuple(x_ref), pe=args.pe)
367 | x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
368 | x_ref = [rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for x in x_ref]
369 |
370 | bs = img.shape[0]
371 | t = torch.randint(0, 1000, (bs,), device=accelerator.device)
372 | t = timesteps[t]
373 | x_0 = torch.randn_like(x_1, device=accelerator.device)
374 | x_t = (1 - t[:, None, None]) * x_1 + t[:, None, None] * x_0
375 | guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype)
376 |
377 | with accelerator.accumulate(dit):
378 | # Predict the noise residual and compute loss
379 | model_pred = dit(
380 | img=x_t.to(weight_dtype),
381 | img_ids=inp['img_ids'].to(weight_dtype),
382 | ref_img=[x.to(weight_dtype) for x in x_ref],
383 | ref_img_ids=[ref_img_id.to(weight_dtype) for ref_img_id in inp['ref_img_ids']],
384 | txt=inp['txt'].to(weight_dtype),
385 | txt_ids=inp['txt_ids'].to(weight_dtype),
386 | y=inp['vec'].to(weight_dtype),
387 | timesteps=t.to(weight_dtype),
388 | guidance=guidance_vec.to(weight_dtype)
389 | )
390 |
391 | loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean")
392 |
393 | # Gather the losses across all processes for logging (if we use distributed training).
394 | avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean()
395 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
396 |
397 | # Backpropagate
398 | accelerator.backward(loss)
399 | if accelerator.sync_gradients:
400 | accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm)
401 | optimizer.step()
402 | lr_scheduler.step()
403 | optimizer.zero_grad()
404 |
405 | # Checks if the accelerator has performed an optimization step behind the scenes
406 | if accelerator.sync_gradients:
407 | progress_bar.update(1)
408 | global_step += 1
409 | accelerator.log({"train_loss": train_loss}, step=global_step)
410 | train_loss = 0.0
411 |
412 | if accelerator.sync_gradients and dit_ema_dict is not None and global_step % args.ema_interval == 0:
413 | src_dict = dit.state_dict()
414 | for tgt_name in dit_ema_dict:
415 | dit_ema_dict[tgt_name].data.lerp_(src_dict[tgt_name].to(dit_ema_dict[tgt_name]), 1 - args.ema_decay)
416 |
417 | if accelerator.sync_gradients and accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
418 | logger.info(f"saving checkpoint in {global_step=}")
419 | save_path = os.path.join(args.project_dir, f"checkpoint-{global_step}")
420 | os.makedirs(save_path, exist_ok=True)
421 |
422 | # save
423 | accelerator.wait_for_everyone()
424 | unwrapped_model = accelerator.unwrap_model(dit)
425 | unwrapped_model_state = unwrapped_model.state_dict()
426 | requires_grad_key = [k for k, v in unwrapped_model.named_parameters() if v.requires_grad]
427 | unwrapped_model_state = {k: unwrapped_model_state[k] for k in requires_grad_key}
428 |
429 | accelerator.save(
430 | unwrapped_model_state,
431 | os.path.join(save_path, 'dit_lora.safetensors'),
432 | safe_serialization=True
433 | )
434 | unwrapped_opt = accelerator.unwrap_model(optimizer)
435 | accelerator.save(unwrapped_opt.state_dict(), os.path.join(save_path, 'optimizer.bin'))
436 | logger.info(f"Saved state to {save_path}")
437 |
438 | if args.ema:
439 | accelerator.save(
440 | {k.split("module.")[-1]: v for k, v in dit_ema_dict.items()},
441 | os.path.join(save_path, 'dit_lora_ema.safetensors'),
442 | safe_serialization=True
443 | )
444 |
445 | # validate
446 | dit.eval()
447 | torch.set_grad_enabled(False)
448 | for i, batch in enumerate(eval_dataloader):
449 | result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
450 | accelerator.log({f"eval_gen_{i}": result}, step=global_step)
451 |
452 |
453 | if args.ema:
454 | original_state_dict = dit.state_dict()
455 | dit.load_state_dict(dit_ema_dict, strict=False)
456 |                 for i, batch in enumerate(eval_dataloader):
457 | result = inference(batch, dit, t5, clip, vae, accelerator, seed=0)
458 | accelerator.log({f"eval_ema_gen_{i}": result}, step=global_step)
459 | dit.load_state_dict(original_state_dict, strict=False)
460 |
461 | torch.cuda.empty_cache()
462 | gc.collect()
463 | torch.set_grad_enabled(True)
464 | dit.train()
465 | accelerator.wait_for_everyone()
466 |
467 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
468 | progress_bar.set_postfix(**logs)
469 |
470 | accelerator.wait_for_everyone()
471 | accelerator.end_training()
472 |
473 | if __name__ == "__main__":
474 | parser = transformers.HfArgumentParser([TrainArgs])
475 | args_tuple = parser.parse_args_into_dataclasses(args_file_flag="--config")
476 | main(*args_tuple)
477 |
--------------------------------------------------------------------------------
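The training entry point exposes every TrainArgs field as a command-line flag through HfArgumentParser, and the --config flag lets a flat args file (see uno/utils/convert_yaml_to_args_file.py) supply them instead. A minimal sketch of that parsing path with placeholder values; in practice the script is launched under accelerate/DeepSpeed rather than imported like this:

from transformers import HfArgumentParser

from train import TrainArgs

# Placeholder overrides; any other TrainArgs field can be set the same way.
parser = HfArgumentParser([TrainArgs])
(args,) = parser.parse_args_into_dataclasses(args=[
    "--project_dir", "outputs/uno_lora",
    "--train_data_json", "datasets/dreambench_singleip.json",
    "--lora_rank", "512",
    "--max_train_steps", "20000",
])
print(args.learning_rate, args.pe, args.ema)  # untouched fields keep their dataclass defaults
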
/uno/dataset/uno.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | import os
18 |
19 | import numpy as np
20 | import torch
21 | import torchvision.transforms.functional as TVF
22 | from PIL import Image
23 | from torch.utils.data import DataLoader, Dataset
24 | from torchvision.transforms import Compose, Normalize, ToTensor
25 |
26 |
27 | def bucket_images(images: list[torch.Tensor], resolution: int = 512):
28 | bucket_override=[
29 | # h w
30 | (256, 768),
31 | (320, 768),
32 | (320, 704),
33 | (384, 640),
34 | (448, 576),
35 | (512, 512),
36 | (576, 448),
37 | (640, 384),
38 | (704, 320),
39 | (768, 320),
40 | (768, 256)
41 | ]
42 | bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
43 | bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
44 |
45 | aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
46 | mean_aspect_ratio = np.mean(aspect_ratios)
47 |
48 | new_h, new_w = bucket_override[0]
49 | min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
50 | for h, w in bucket_override:
51 | aspect_diff = np.abs(h / w - mean_aspect_ratio)
52 | if aspect_diff < min_aspect_diff:
53 | min_aspect_diff = aspect_diff
54 | new_h, new_w = h, w
55 |
56 | images = [TVF.resize(image, (new_h, new_w)) for image in images]
57 | images = torch.stack(images, dim=0)
58 | return images
59 |
60 | class FluxPairedDatasetV2(Dataset):
61 | def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
62 | super().__init__()
63 | self.json_file = json_file
64 | self.resolution = resolution
65 | self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
66 | self.image_root = os.path.dirname(json_file)
67 |
68 | with open(self.json_file, "rt") as f:
69 | self.data_dicts = json.load(f)
70 |
71 | self.transform = Compose([
72 | ToTensor(),
73 | Normalize([0.5], [0.5]),
74 | ])
75 |
76 | def __getitem__(self, idx):
77 | data_dict = self.data_dicts[idx]
78 | image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
79 | txt = data_dict["prompt"]
80 | image_tgt_path = data_dict.get("image_tgt_path", None)
81 |         # image_tgt_path = data_dict.get("image_paths", None)[0]  # TODO: debugging shortcut; delete once the paired-data pipeline is released
82 | ref_imgs = [
83 | Image.open(os.path.join(self.image_root, path)).convert("RGB")
84 | for path in image_paths
85 | ]
86 | ref_imgs = [self.transform(img) for img in ref_imgs]
87 | img = None
88 | if image_tgt_path is not None:
89 | img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
90 | img = self.transform(img)
91 |
92 | return {
93 | "img": img,
94 | "txt": txt,
95 | "ref_imgs": ref_imgs,
96 | }
97 |
98 | def __len__(self):
99 | return len(self.data_dicts)
100 |
101 | def collate_fn(self, batch):
102 | img = [data["img"] for data in batch]
103 | txt = [data["txt"] for data in batch]
104 | ref_imgs = [data["ref_imgs"] for data in batch]
105 | assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
106 |
107 | n_ref = len(ref_imgs[0])
108 |
109 | img = bucket_images(img, self.resolution)
110 | ref_imgs_new = []
111 | for i in range(n_ref):
112 | ref_imgs_i = [refs[i] for refs in ref_imgs]
113 | ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
114 | ref_imgs_new.append(ref_imgs_i)
115 |
116 | return {
117 | "txt": txt,
118 | "img": img,
119 | "ref_imgs": ref_imgs_new,
120 | }
121 |
122 | if __name__ == '__main__':
123 | import argparse
124 | from pprint import pprint
125 | parser = argparse.ArgumentParser()
126 | # parser.add_argument("--json_file", type=str, required=True)
127 | parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
128 | args = parser.parse_args()
129 | dataset = FluxPairedDatasetV2(args.json_file, 512)
130 |     dataloader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
131 | 
132 |     for i, data_dict in enumerate(dataloader):
133 | pprint(i)
134 | pprint(data_dict)
135 | breakpoint()
136 |
--------------------------------------------------------------------------------
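FluxPairedDatasetV2 expects a JSON list of records with "image_paths" (or a single "image_path"), a "prompt", and an optional "image_tgt_path", all resolved relative to the JSON file; collate_fn then buckets the target and each reference slot to the aspect-ratio bucket closest to the batch mean. A self-contained sketch with throwaway images (file names and prompts are made up):

import json
import os
import tempfile

from PIL import Image
from torch.utils.data import DataLoader

from uno.dataset.uno import FluxPairedDatasetV2

# Write a tiny throwaway dataset to a temp directory so the loader can run end to end.
root = tempfile.mkdtemp()
for name in ("ref1.png", "ref2.png", "tgt1.png", "tgt2.png"):
    Image.new("RGB", (640, 480), color="gray").save(os.path.join(root, name))

records = [
    {"image_paths": ["ref1.png"], "prompt": "a clock on the beach", "image_tgt_path": "tgt1.png"},
    {"image_paths": ["ref2.png"], "prompt": "a crystal ball on a desk", "image_tgt_path": "tgt2.png"},
]
with open(os.path.join(root, "train.json"), "w") as f:
    json.dump(records, f)

dataset = FluxPairedDatasetV2(os.path.join(root, "train.json"), resolution=512, resolution_ref=320)
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
print(batch["txt"])                            # list of prompts
print(batch["img"].shape)                      # (2, 3, H, W) after aspect-ratio bucketing
print([r.shape for r in batch["ref_imgs"]])    # one (2, 3, h, w) tensor per reference slot
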
/uno/flux/math.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | from einops import rearrange
18 | from torch import Tensor
19 |
20 |
21 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
22 | q, k = apply_rope(q, k, pe)
23 |
24 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
25 | x = rearrange(x, "B H L D -> B L (H D)")
26 |
27 | return x
28 |
29 |
30 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
31 | assert dim % 2 == 0
32 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
33 | omega = 1.0 / (theta**scale)
34 | out = torch.einsum("...n,d->...nd", pos, omega)
35 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
36 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
37 | return out.float()
38 |
39 |
40 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
41 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
42 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
43 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
44 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
45 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
46 |
--------------------------------------------------------------------------------
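attention() rotates q and k with the precomputed rotary table before scaled dot-product attention and then folds the heads back into the channel dimension. A small shape-check sketch (all sizes are arbitrary; inside the model the pe tensor comes from EmbedND over the concatenated txt/img position ids):

import torch

from uno.flux.math import attention, rope

B, H, L, D = 1, 24, 16, 128                 # head dim D must be even for rope
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

pos = torch.arange(L)[None, :].float()      # (B, L) token positions
pe = rope(pos, D, theta=10_000)             # (B, L, D/2, 2, 2) rotation table
pe = pe.unsqueeze(1)                        # (B, 1, L, D/2, 2, 2), broadcast over heads as EmbedND does

out = attention(q, k, v, pe=pe)
print(out.shape)                            # (B, L, H * D)
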
/uno/flux/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 |
18 | import torch
19 | from torch import Tensor, nn
20 |
21 | from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
22 |
23 |
24 | @dataclass
25 | class FluxParams:
26 | in_channels: int
27 | vec_in_dim: int
28 | context_in_dim: int
29 | hidden_size: int
30 | mlp_ratio: float
31 | num_heads: int
32 | depth: int
33 | depth_single_blocks: int
34 | axes_dim: list[int]
35 | theta: int
36 | qkv_bias: bool
37 | guidance_embed: bool
38 |
39 |
40 | class Flux(nn.Module):
41 | """
42 | Transformer model for flow matching on sequences.
43 | """
44 | _supports_gradient_checkpointing = True
45 |
46 | def __init__(self, params: FluxParams):
47 | super().__init__()
48 |
49 | self.params = params
50 | self.in_channels = params.in_channels
51 | self.out_channels = self.in_channels
52 | if params.hidden_size % params.num_heads != 0:
53 | raise ValueError(
54 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
55 | )
56 | pe_dim = params.hidden_size // params.num_heads
57 | if sum(params.axes_dim) != pe_dim:
58 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
59 | self.hidden_size = params.hidden_size
60 | self.num_heads = params.num_heads
61 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
62 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
63 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
64 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
65 | self.guidance_in = (
66 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
67 | )
68 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
69 |
70 | self.double_blocks = nn.ModuleList(
71 | [
72 | DoubleStreamBlock(
73 | self.hidden_size,
74 | self.num_heads,
75 | mlp_ratio=params.mlp_ratio,
76 | qkv_bias=params.qkv_bias,
77 | )
78 | for _ in range(params.depth)
79 | ]
80 | )
81 |
82 | self.single_blocks = nn.ModuleList(
83 | [
84 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
85 | for _ in range(params.depth_single_blocks)
86 | ]
87 | )
88 |
89 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
90 | self.gradient_checkpointing = False
91 |
92 | def _set_gradient_checkpointing(self, module, value=False):
93 | if hasattr(module, "gradient_checkpointing"):
94 | module.gradient_checkpointing = value
95 |
96 | @property
97 | def attn_processors(self):
98 | # set recursively
99 | processors = {} # type: dict[str, nn.Module]
100 |
101 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
102 | if hasattr(module, "set_processor"):
103 | processors[f"{name}.processor"] = module.processor
104 |
105 | for sub_name, child in module.named_children():
106 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
107 |
108 | return processors
109 |
110 | for name, module in self.named_children():
111 | fn_recursive_add_processors(name, module, processors)
112 |
113 | return processors
114 |
115 | def set_attn_processor(self, processor):
116 | r"""
117 | Sets the attention processor to use to compute attention.
118 |
119 | Parameters:
120 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
121 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
122 | for **all** `Attention` layers.
123 |
124 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
125 | processor. This is strongly recommended when setting trainable attention processors.
126 |
127 | """
128 | count = len(self.attn_processors.keys())
129 |
130 | if isinstance(processor, dict) and len(processor) != count:
131 | raise ValueError(
132 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
133 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
134 | )
135 |
136 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
137 | if hasattr(module, "set_processor"):
138 | if not isinstance(processor, dict):
139 | module.set_processor(processor)
140 | else:
141 | module.set_processor(processor.pop(f"{name}.processor"))
142 |
143 | for sub_name, child in module.named_children():
144 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
145 |
146 | for name, module in self.named_children():
147 | fn_recursive_attn_processor(name, module, processor)
148 |
149 | def forward(
150 | self,
151 | img: Tensor,
152 | img_ids: Tensor,
153 | txt: Tensor,
154 | txt_ids: Tensor,
155 | timesteps: Tensor,
156 | y: Tensor,
157 | guidance: Tensor | None = None,
158 | ref_img: Tensor | None = None,
159 | ref_img_ids: Tensor | None = None,
160 | ) -> Tensor:
161 | if img.ndim != 3 or txt.ndim != 3:
162 | raise ValueError("Input img and txt tensors must have 3 dimensions.")
163 |
164 | # running on sequences img
165 | img = self.img_in(img)
166 | vec = self.time_in(timestep_embedding(timesteps, 256))
167 | if self.params.guidance_embed:
168 | if guidance is None:
169 | raise ValueError("Didn't get guidance strength for guidance distilled model.")
170 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
171 | vec = vec + self.vector_in(y)
172 | txt = self.txt_in(txt)
173 |
174 | ids = torch.cat((txt_ids, img_ids), dim=1)
175 |
176 | # concat ref_img/img
177 | img_end = img.shape[1]
178 | if ref_img is not None:
179 | if isinstance(ref_img, tuple) or isinstance(ref_img, list):
180 | img_in = [img] + [self.img_in(ref) for ref in ref_img]
181 | img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids]
182 | img = torch.cat(img_in, dim=1)
183 | ids = torch.cat(img_ids, dim=1)
184 | else:
185 | img = torch.cat((img, self.img_in(ref_img)), dim=1)
186 | ids = torch.cat((ids, ref_img_ids), dim=1)
187 | pe = self.pe_embedder(ids)
188 |
189 | for index_block, block in enumerate(self.double_blocks):
190 | if self.training and self.gradient_checkpointing:
191 | img, txt = torch.utils.checkpoint.checkpoint(
192 | block,
193 | img=img,
194 | txt=txt,
195 | vec=vec,
196 | pe=pe,
197 | use_reentrant=False,
198 | )
199 | else:
200 | img, txt = block(
201 | img=img,
202 | txt=txt,
203 | vec=vec,
204 | pe=pe
205 | )
206 |
207 | img = torch.cat((txt, img), 1)
208 | for block in self.single_blocks:
209 | if self.training and self.gradient_checkpointing:
210 | img = torch.utils.checkpoint.checkpoint(
211 | block,
212 | img, vec=vec, pe=pe,
213 | use_reentrant=False
214 | )
215 | else:
216 | img = block(img, vec=vec, pe=pe)
217 | img = img[:, txt.shape[1] :, ...]
218 | # index img
219 | img = img[:, :img_end, ...]
220 |
221 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
222 | return img
223 |
--------------------------------------------------------------------------------
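Flux.forward appends the projected reference-image tokens (and their position ids) after the noised-image tokens, runs the double- and single-stream blocks over the joint sequence, then slices back to the first img_end tokens so only the noised-image stream reaches the final layer. A toy-sized sketch of that interface; every dimension below is made up for speed, and the real flux-dev configuration (loaded via uno/flux/util.py) is far larger and uses guidance_embed=True:

import torch

from uno.flux.model import Flux, FluxParams

# Tiny made-up configuration, only to exercise the forward signature.
params = FluxParams(
    in_channels=16, vec_in_dim=32, context_in_dim=32,
    hidden_size=64, mlp_ratio=4.0, num_heads=4,
    depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000,
    qkv_bias=True, guidance_embed=False,
)
model = Flux(params).eval()

B, L_img, L_txt, L_ref = 1, 32, 8, 32
out = model(
    img=torch.randn(B, L_img, params.in_channels),
    img_ids=torch.zeros(B, L_img, 3),
    txt=torch.randn(B, L_txt, params.context_in_dim),
    txt_ids=torch.zeros(B, L_txt, 3),
    timesteps=torch.zeros(B),
    y=torch.randn(B, params.vec_in_dim),
    ref_img=[torch.randn(B, L_ref, params.in_channels)],   # one packed reference image
    ref_img_ids=[torch.zeros(B, L_ref, 3)],
)
print(out.shape)  # (B, L_img, in_channels): reference tokens are dropped before the prediction
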
/uno/flux/modules/autoencoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass
17 |
18 | import torch
19 | from einops import rearrange
20 | from torch import Tensor, nn
21 |
22 |
23 | @dataclass
24 | class AutoEncoderParams:
25 | resolution: int
26 | in_channels: int
27 | ch: int
28 | out_ch: int
29 | ch_mult: list[int]
30 | num_res_blocks: int
31 | z_channels: int
32 | scale_factor: float
33 | shift_factor: float
34 |
35 |
36 | def swish(x: Tensor) -> Tensor:
37 | return x * torch.sigmoid(x)
38 |
39 |
40 | class AttnBlock(nn.Module):
41 | def __init__(self, in_channels: int):
42 | super().__init__()
43 | self.in_channels = in_channels
44 |
45 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
46 |
47 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
51 |
52 | def attention(self, h_: Tensor) -> Tensor:
53 | h_ = self.norm(h_)
54 | q = self.q(h_)
55 | k = self.k(h_)
56 | v = self.v(h_)
57 |
58 | b, c, h, w = q.shape
59 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
60 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
61 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
62 | h_ = nn.functional.scaled_dot_product_attention(q, k, v)
63 |
64 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
65 |
66 | def forward(self, x: Tensor) -> Tensor:
67 | return x + self.proj_out(self.attention(x))
68 |
69 |
70 | class ResnetBlock(nn.Module):
71 |     def __init__(self, in_channels: int, out_channels: int | None):
72 | super().__init__()
73 | self.in_channels = in_channels
74 | out_channels = in_channels if out_channels is None else out_channels
75 | self.out_channels = out_channels
76 |
77 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
80 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
81 | if self.in_channels != self.out_channels:
82 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
83 |
84 | def forward(self, x):
85 | h = x
86 | h = self.norm1(h)
87 | h = swish(h)
88 | h = self.conv1(h)
89 |
90 | h = self.norm2(h)
91 | h = swish(h)
92 | h = self.conv2(h)
93 |
94 | if self.in_channels != self.out_channels:
95 | x = self.nin_shortcut(x)
96 |
97 | return x + h
98 |
99 |
100 | class Downsample(nn.Module):
101 | def __init__(self, in_channels: int):
102 | super().__init__()
103 | # no asymmetric padding in torch conv, must do it ourselves
104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
105 |
106 | def forward(self, x: Tensor):
107 | pad = (0, 1, 0, 1)
108 | x = nn.functional.pad(x, pad, mode="constant", value=0)
109 | x = self.conv(x)
110 | return x
111 |
112 |
113 | class Upsample(nn.Module):
114 | def __init__(self, in_channels: int):
115 | super().__init__()
116 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
117 |
118 | def forward(self, x: Tensor):
119 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
120 | x = self.conv(x)
121 | return x
122 |
123 |
124 | class Encoder(nn.Module):
125 | def __init__(
126 | self,
127 | resolution: int,
128 | in_channels: int,
129 | ch: int,
130 | ch_mult: list[int],
131 | num_res_blocks: int,
132 | z_channels: int,
133 | ):
134 | super().__init__()
135 | self.ch = ch
136 | self.num_resolutions = len(ch_mult)
137 | self.num_res_blocks = num_res_blocks
138 | self.resolution = resolution
139 | self.in_channels = in_channels
140 | # downsampling
141 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
142 |
143 | curr_res = resolution
144 | in_ch_mult = (1,) + tuple(ch_mult)
145 | self.in_ch_mult = in_ch_mult
146 | self.down = nn.ModuleList()
147 | block_in = self.ch
148 | for i_level in range(self.num_resolutions):
149 | block = nn.ModuleList()
150 | attn = nn.ModuleList()
151 | block_in = ch * in_ch_mult[i_level]
152 | block_out = ch * ch_mult[i_level]
153 | for _ in range(self.num_res_blocks):
154 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
155 | block_in = block_out
156 | down = nn.Module()
157 | down.block = block
158 | down.attn = attn
159 | if i_level != self.num_resolutions - 1:
160 | down.downsample = Downsample(block_in)
161 | curr_res = curr_res // 2
162 | self.down.append(down)
163 |
164 | # middle
165 | self.mid = nn.Module()
166 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
167 | self.mid.attn_1 = AttnBlock(block_in)
168 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
169 |
170 | # end
171 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
172 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
173 |
174 | def forward(self, x: Tensor) -> Tensor:
175 | # downsampling
176 | hs = [self.conv_in(x)]
177 | for i_level in range(self.num_resolutions):
178 | for i_block in range(self.num_res_blocks):
179 | h = self.down[i_level].block[i_block](hs[-1])
180 | if len(self.down[i_level].attn) > 0:
181 | h = self.down[i_level].attn[i_block](h)
182 | hs.append(h)
183 | if i_level != self.num_resolutions - 1:
184 | hs.append(self.down[i_level].downsample(hs[-1]))
185 |
186 | # middle
187 | h = hs[-1]
188 | h = self.mid.block_1(h)
189 | h = self.mid.attn_1(h)
190 | h = self.mid.block_2(h)
191 | # end
192 | h = self.norm_out(h)
193 | h = swish(h)
194 | h = self.conv_out(h)
195 | return h
196 |
197 |
198 | class Decoder(nn.Module):
199 | def __init__(
200 | self,
201 | ch: int,
202 | out_ch: int,
203 | ch_mult: list[int],
204 | num_res_blocks: int,
205 | in_channels: int,
206 | resolution: int,
207 | z_channels: int,
208 | ):
209 | super().__init__()
210 | self.ch = ch
211 | self.num_resolutions = len(ch_mult)
212 | self.num_res_blocks = num_res_blocks
213 | self.resolution = resolution
214 | self.in_channels = in_channels
215 | self.ffactor = 2 ** (self.num_resolutions - 1)
216 |
217 | # compute in_ch_mult, block_in and curr_res at lowest res
218 | block_in = ch * ch_mult[self.num_resolutions - 1]
219 | curr_res = resolution // 2 ** (self.num_resolutions - 1)
220 | self.z_shape = (1, z_channels, curr_res, curr_res)
221 |
222 | # z to block_in
223 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
224 |
225 | # middle
226 | self.mid = nn.Module()
227 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
228 | self.mid.attn_1 = AttnBlock(block_in)
229 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
230 |
231 | # upsampling
232 | self.up = nn.ModuleList()
233 | for i_level in reversed(range(self.num_resolutions)):
234 | block = nn.ModuleList()
235 | attn = nn.ModuleList()
236 | block_out = ch * ch_mult[i_level]
237 | for _ in range(self.num_res_blocks + 1):
238 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
239 | block_in = block_out
240 | up = nn.Module()
241 | up.block = block
242 | up.attn = attn
243 | if i_level != 0:
244 | up.upsample = Upsample(block_in)
245 | curr_res = curr_res * 2
246 | self.up.insert(0, up) # prepend to get consistent order
247 |
248 | # end
249 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
250 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
251 |
252 | def forward(self, z: Tensor) -> Tensor:
253 | # z to block_in
254 | h = self.conv_in(z)
255 |
256 | # middle
257 | h = self.mid.block_1(h)
258 | h = self.mid.attn_1(h)
259 | h = self.mid.block_2(h)
260 |
261 | # upsampling
262 | for i_level in reversed(range(self.num_resolutions)):
263 | for i_block in range(self.num_res_blocks + 1):
264 | h = self.up[i_level].block[i_block](h)
265 | if len(self.up[i_level].attn) > 0:
266 | h = self.up[i_level].attn[i_block](h)
267 | if i_level != 0:
268 | h = self.up[i_level].upsample(h)
269 |
270 | # end
271 | h = self.norm_out(h)
272 | h = swish(h)
273 | h = self.conv_out(h)
274 | return h
275 |
276 |
277 | class DiagonalGaussian(nn.Module):
278 | def __init__(self, sample: bool = True, chunk_dim: int = 1):
279 | super().__init__()
280 | self.sample = sample
281 | self.chunk_dim = chunk_dim
282 |
283 | def forward(self, z: Tensor) -> Tensor:
284 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
285 | if self.sample:
286 | std = torch.exp(0.5 * logvar)
287 | return mean + std * torch.randn_like(mean)
288 | else:
289 | return mean
290 |
291 |
292 | class AutoEncoder(nn.Module):
293 | def __init__(self, params: AutoEncoderParams):
294 | super().__init__()
295 | self.encoder = Encoder(
296 | resolution=params.resolution,
297 | in_channels=params.in_channels,
298 | ch=params.ch,
299 | ch_mult=params.ch_mult,
300 | num_res_blocks=params.num_res_blocks,
301 | z_channels=params.z_channels,
302 | )
303 | self.decoder = Decoder(
304 | resolution=params.resolution,
305 | in_channels=params.in_channels,
306 | ch=params.ch,
307 | out_ch=params.out_ch,
308 | ch_mult=params.ch_mult,
309 | num_res_blocks=params.num_res_blocks,
310 | z_channels=params.z_channels,
311 | )
312 | self.reg = DiagonalGaussian()
313 |
314 | self.scale_factor = params.scale_factor
315 | self.shift_factor = params.shift_factor
316 |
317 | def encode(self, x: Tensor) -> Tensor:
318 | z = self.reg(self.encoder(x))
319 | z = self.scale_factor * (z - self.shift_factor)
320 | return z
321 |
322 | def decode(self, z: Tensor) -> Tensor:
323 | z = z / self.scale_factor + self.shift_factor
324 | return self.decoder(z)
325 |
326 | def forward(self, x: Tensor) -> Tensor:
327 | return self.decode(self.encode(x))
328 |
--------------------------------------------------------------------------------
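AutoEncoder.encode samples from the diagonal Gaussian and applies the scale/shift normalization that decode() later inverts. A round-trip sketch with deliberately tiny, made-up parameters (the real FLUX VAE configuration lives in uno/flux/util.py and downsamples by 8x rather than 2x):

import torch

from uno.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

# Toy parameters so the example runs instantly; channel counts stay multiples of 32
# because every GroupNorm in the encoder/decoder uses 32 groups.
params = AutoEncoderParams(
    resolution=64, in_channels=3, ch=32, out_ch=3,
    ch_mult=[1, 2], num_res_blocks=1, z_channels=4,
    scale_factor=1.0, shift_factor=0.0,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 64, 64)        # image in [-1, 1]
with torch.no_grad():
    z = ae.encode(x)                 # (1, 4, 32, 32): one downsampling stage with ch_mult=[1, 2]
    recon = ae.decode(z)             # (1, 3, 64, 64)
print(z.shape, recon.shape)
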
/uno/flux/modules/conditioner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from torch import Tensor, nn
17 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
18 | T5Tokenizer)
19 |
20 |
21 | class HFEmbedder(nn.Module):
22 | def __init__(self, version: str, max_length: int, **hf_kwargs):
23 | super().__init__()
24 | self.is_clip = "clip" in version.lower()
25 | self.max_length = max_length
26 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
27 |
28 | if self.is_clip:
29 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
30 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
31 | else:
32 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
33 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
34 |
35 | self.hf_module = self.hf_module.eval().requires_grad_(False)
36 |
37 | def forward(self, text: list[str]) -> Tensor:
38 | batch_encoding = self.tokenizer(
39 | text,
40 | truncation=True,
41 | max_length=self.max_length,
42 | return_length=False,
43 | return_overflowing_tokens=False,
44 | padding="max_length",
45 | return_tensors="pt",
46 | )
47 |
48 | outputs = self.hf_module(
49 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
50 | attention_mask=None,
51 | output_hidden_states=False,
52 | )
53 | return outputs[self.output_key]
54 |
--------------------------------------------------------------------------------
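HFEmbedder hides the CLIP/T5 difference behind one call: CLIP checkpoints return the pooled output, T5 checkpoints the full last_hidden_state. A sketch of the CLIP branch; the checkpoint name below is an assumption (the names this repo actually loads are defined by load_clip/load_t5 in uno/flux/util.py):

import torch

from uno.flux.modules.conditioner import HFEmbedder

# Assumed checkpoint; swap in a local path if the weights are already on disk.
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)

with torch.no_grad():
    vec = clip(["a figurine next to a crystal ball"])
print(vec.shape)  # pooled CLIP embedding, (batch, hidden); a T5 embedder returns (batch, max_length, hidden)
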
/uno/flux/modules/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import math
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | from einops import rearrange
21 | from torch import Tensor, nn
22 |
23 | from ..math import attention, rope
24 | import torch.nn.functional as F
25 |
26 | class EmbedND(nn.Module):
27 | def __init__(self, dim: int, theta: int, axes_dim: list[int]):
28 | super().__init__()
29 | self.dim = dim
30 | self.theta = theta
31 | self.axes_dim = axes_dim
32 |
33 | def forward(self, ids: Tensor) -> Tensor:
34 | n_axes = ids.shape[-1]
35 | emb = torch.cat(
36 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
37 | dim=-3,
38 | )
39 |
40 | return emb.unsqueeze(1)
41 |
42 |
43 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
44 | """
45 | Create sinusoidal timestep embeddings.
46 | :param t: a 1-D Tensor of N indices, one per batch element.
47 | These may be fractional.
48 | :param dim: the dimension of the output.
49 | :param max_period: controls the minimum frequency of the embeddings.
50 | :return: an (N, D) Tensor of positional embeddings.
51 | """
52 | t = time_factor * t
53 | half = dim // 2
54 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
55 | t.device
56 | )
57 |
58 | args = t[:, None].float() * freqs[None]
59 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60 | if dim % 2:
61 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
62 | if torch.is_floating_point(t):
63 | embedding = embedding.to(t)
64 | return embedding
65 |
66 |
67 | class MLPEmbedder(nn.Module):
68 | def __init__(self, in_dim: int, hidden_dim: int):
69 | super().__init__()
70 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
71 | self.silu = nn.SiLU()
72 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
73 |
74 | def forward(self, x: Tensor) -> Tensor:
75 | return self.out_layer(self.silu(self.in_layer(x)))
76 |
77 |
78 | class RMSNorm(torch.nn.Module):
79 | def __init__(self, dim: int):
80 | super().__init__()
81 | self.scale = nn.Parameter(torch.ones(dim))
82 |
83 | def forward(self, x: Tensor):
84 | x_dtype = x.dtype
85 | x = x.float()
86 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
87 | return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
88 |
89 |
90 | class QKNorm(torch.nn.Module):
91 | def __init__(self, dim: int):
92 | super().__init__()
93 | self.query_norm = RMSNorm(dim)
94 | self.key_norm = RMSNorm(dim)
95 |
96 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
97 | q = self.query_norm(q)
98 | k = self.key_norm(k)
99 | return q.to(v), k.to(v)
100 |
101 | class LoRALinearLayer(nn.Module):
102 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
103 | super().__init__()
104 |
105 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
106 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
107 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
108 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
109 | self.network_alpha = network_alpha
110 | self.rank = rank
111 |
112 | nn.init.normal_(self.down.weight, std=1 / rank)
113 | nn.init.zeros_(self.up.weight)
114 |
115 | def forward(self, hidden_states):
116 | orig_dtype = hidden_states.dtype
117 | dtype = self.down.weight.dtype
118 |
119 | down_hidden_states = self.down(hidden_states.to(dtype))
120 | up_hidden_states = self.up(down_hidden_states)
121 |
122 | if self.network_alpha is not None:
123 | up_hidden_states *= self.network_alpha / self.rank
124 |
125 | return up_hidden_states.to(orig_dtype)
126 |
127 | class FLuxSelfAttnProcessor:
128 | def __call__(self, attn, x, pe, **attention_kwargs):
129 | qkv = attn.qkv(x)
130 |         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
131 | q, k = attn.norm(q, k, v)
132 | x = attention(q, k, v, pe=pe)
133 | x = attn.proj(x)
134 | return x
135 |
136 | class LoraFluxAttnProcessor(nn.Module):
137 |
138 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
139 | super().__init__()
140 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
141 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
142 | self.lora_weight = lora_weight
143 |
144 |
145 | def __call__(self, attn, x, pe, **attention_kwargs):
146 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
147 |         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
148 | q, k = attn.norm(q, k, v)
149 | x = attention(q, k, v, pe=pe)
150 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
151 | return x
152 |
153 | class SelfAttention(nn.Module):
154 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
155 | super().__init__()
156 | self.num_heads = num_heads
157 | head_dim = dim // num_heads
158 |
159 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
160 | self.norm = QKNorm(head_dim)
161 | self.proj = nn.Linear(dim, dim)
162 |     def forward(self, *args, **kwargs):
163 |         pass  # attention itself runs inside the block's processor, which calls qkv/norm/proj directly
164 |
165 |
166 | @dataclass
167 | class ModulationOut:
168 | shift: Tensor
169 | scale: Tensor
170 | gate: Tensor
171 |
172 |
173 | class Modulation(nn.Module):
174 | def __init__(self, dim: int, double: bool):
175 | super().__init__()
176 | self.is_double = double
177 | self.multiplier = 6 if double else 3
178 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
179 |
180 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
181 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
182 |
183 | return (
184 | ModulationOut(*out[:3]),
185 | ModulationOut(*out[3:]) if self.is_double else None,
186 | )
187 |
188 | class DoubleStreamBlockLoraProcessor(nn.Module):
189 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
190 | super().__init__()
191 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
192 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
193 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
194 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
195 | self.lora_weight = lora_weight
196 |
197 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
198 | img_mod1, img_mod2 = attn.img_mod(vec)
199 | txt_mod1, txt_mod2 = attn.txt_mod(vec)
200 |
201 | # prepare image for attention
202 | img_modulated = attn.img_norm1(img)
203 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
204 | img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
205 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
206 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
207 |
208 | # prepare txt for attention
209 | txt_modulated = attn.txt_norm1(txt)
210 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
211 | txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
212 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
213 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
214 |
215 | # run actual attention
216 | q = torch.cat((txt_q, img_q), dim=2)
217 | k = torch.cat((txt_k, img_k), dim=2)
218 | v = torch.cat((txt_v, img_v), dim=2)
219 |
220 | attn1 = attention(q, k, v, pe=pe)
221 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
222 |
223 |         # calculate the img blocks
224 | img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
225 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
226 |
227 |         # calculate the txt blocks
228 | txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
229 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
230 | return img, txt
231 |
232 | class DoubleStreamBlockProcessor:
233 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
234 | img_mod1, img_mod2 = attn.img_mod(vec)
235 | txt_mod1, txt_mod2 = attn.txt_mod(vec)
236 |
237 | # prepare image for attention
238 | img_modulated = attn.img_norm1(img)
239 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
240 | img_qkv = attn.img_attn.qkv(img_modulated)
241 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
242 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
243 |
244 | # prepare txt for attention
245 | txt_modulated = attn.txt_norm1(txt)
246 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
247 | txt_qkv = attn.txt_attn.qkv(txt_modulated)
248 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
249 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
250 |
251 | # run actual attention
252 | q = torch.cat((txt_q, img_q), dim=2)
253 | k = torch.cat((txt_k, img_k), dim=2)
254 | v = torch.cat((txt_v, img_v), dim=2)
255 |
256 | attn1 = attention(q, k, v, pe=pe)
257 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
258 |
259 |         # calculate the img blocks
260 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
261 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
262 |
263 |         # calculate the txt blocks
264 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
265 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
266 | return img, txt
267 |
268 | class DoubleStreamBlock(nn.Module):
269 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
270 | super().__init__()
271 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
272 | self.num_heads = num_heads
273 | self.hidden_size = hidden_size
274 | self.head_dim = hidden_size // num_heads
275 |
276 | self.img_mod = Modulation(hidden_size, double=True)
277 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
278 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
279 |
280 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
281 | self.img_mlp = nn.Sequential(
282 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
283 | nn.GELU(approximate="tanh"),
284 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
285 | )
286 |
287 | self.txt_mod = Modulation(hidden_size, double=True)
288 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
289 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
290 |
291 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
292 | self.txt_mlp = nn.Sequential(
293 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
294 | nn.GELU(approximate="tanh"),
295 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
296 | )
297 | processor = DoubleStreamBlockProcessor()
298 | self.set_processor(processor)
299 |
300 | def set_processor(self, processor) -> None:
301 | self.processor = processor
302 |
303 | def get_processor(self):
304 | return self.processor
305 |
306 | def forward(
307 | self,
308 | img: Tensor,
309 | txt: Tensor,
310 | vec: Tensor,
311 | pe: Tensor,
312 | image_proj: Tensor = None,
313 |         ip_scale: float = 1.0,
314 | ) -> tuple[Tensor, Tensor]:
315 | if image_proj is None:
316 | return self.processor(self, img, txt, vec, pe)
317 | else:
318 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
319 |
320 |
321 | class SingleStreamBlockLoraProcessor(nn.Module):
322 | def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1):
323 | super().__init__()
324 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
325 | self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)
326 | self.lora_weight = lora_weight
327 |
328 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
329 |
330 | mod, _ = attn.modulation(vec)
331 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
332 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
333 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
334 |
335 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
336 | q, k = attn.norm(q, k, v)
337 |
338 | # compute attention
339 | attn_1 = attention(q, k, v, pe=pe)
340 |
341 | # compute activation in mlp stream, cat again and run second linear layer
342 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
343 | output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
344 | output = x + mod.gate * output
345 | return output
346 |
347 |
348 | class SingleStreamBlockProcessor:
349 | def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
350 |
351 | mod, _ = attn.modulation(vec)
352 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
353 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
354 |
355 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
356 | q, k = attn.norm(q, k, v)
357 |
358 | # compute attention
359 | attn_1 = attention(q, k, v, pe=pe)
360 |
361 | # compute activation in mlp stream, cat again and run second linear layer
362 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
363 | output = x + mod.gate * output
364 | return output
365 |
366 | class SingleStreamBlock(nn.Module):
367 | """
368 | A DiT block with parallel linear layers as described in
369 | https://arxiv.org/abs/2302.05442 and adapted modulation interface.
370 | """
371 |
372 | def __init__(
373 | self,
374 | hidden_size: int,
375 | num_heads: int,
376 | mlp_ratio: float = 4.0,
377 | qk_scale: float | None = None,
378 | ):
379 | super().__init__()
380 | self.hidden_dim = hidden_size
381 | self.num_heads = num_heads
382 | self.head_dim = hidden_size // num_heads
383 | self.scale = qk_scale or self.head_dim**-0.5
384 |
385 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
386 | # qkv and mlp_in
387 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
388 | # proj and mlp_out
389 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
390 |
391 | self.norm = QKNorm(self.head_dim)
392 |
393 | self.hidden_size = hidden_size
394 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
395 |
396 | self.mlp_act = nn.GELU(approximate="tanh")
397 | self.modulation = Modulation(hidden_size, double=False)
398 |
399 | processor = SingleStreamBlockProcessor()
400 | self.set_processor(processor)
401 |
402 |
403 | def set_processor(self, processor) -> None:
404 | self.processor = processor
405 |
406 | def get_processor(self):
407 | return self.processor
408 |
409 | def forward(
410 | self,
411 | x: Tensor,
412 | vec: Tensor,
413 | pe: Tensor,
414 | image_proj: Tensor | None = None,
415 | ip_scale: float = 1.0,
416 | ) -> Tensor:
417 | if image_proj is None:
418 | return self.processor(self, x, vec, pe)
419 | else:
420 | return self.processor(self, x, vec, pe, image_proj, ip_scale)
421 |
422 |
423 |
424 | class LastLayer(nn.Module):
425 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
426 | super().__init__()
427 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
428 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
429 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
430 |
431 | def forward(self, x: Tensor, vec: Tensor) -> Tensor:
432 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
433 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
434 | x = self.linear(x)
435 | return x
436 |
--------------------------------------------------------------------------------
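The processor indirection above is what lets LoRA be attached without touching the block weights: set_processor swaps the callable that forward dispatches to. A minimal sketch of that swap, using the FLUX-dev sizes defined later in uno/flux/util.py (illustration only, not part of the repository):

    from uno.flux.modules.layers import SingleStreamBlock, SingleStreamBlockLoraProcessor

    # one single-stream block with the FLUX-dev dimensions (hidden_size=3072, 24 heads)
    block = SingleStreamBlock(hidden_size=3072, num_heads=24)

    # replace the default processor with the LoRA variant; rank=16 mirrors the pipeline default
    block.set_processor(SingleStreamBlockLoraProcessor(dim=3072, rank=16))
    assert isinstance(block.get_processor(), SingleStreamBlockLoraProcessor)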
/uno/flux/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from typing import Literal
18 |
19 | import torch
20 | from einops import rearrange
21 | from PIL import ExifTags, Image
22 | import torchvision.transforms.functional as TVF
23 |
24 | from uno.flux.modules.layers import (
25 | DoubleStreamBlockLoraProcessor,
26 | DoubleStreamBlockProcessor,
27 | SingleStreamBlockLoraProcessor,
28 | SingleStreamBlockProcessor,
29 | )
30 | from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
31 | from uno.flux.util import (
32 | get_lora_rank,
33 | load_ae,
34 | load_checkpoint,
35 | load_clip,
36 | load_flow_model,
37 | load_flow_model_only_lora,
38 | load_flow_model_quintized,
39 | load_t5,
40 | )
41 |
42 |
43 | def find_nearest_scale(image_h, image_w, predefined_scales):
44 | """
45 |     Find the predefined scale whose aspect ratio is closest to the input image.
46 | 
47 |     :param image_h: height of the image
48 |     :param image_w: width of the image
49 |     :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
50 |     :return: the nearest predefined scale (h, w)
51 | """
52 |     # compute the aspect ratio of the input image
53 | image_ratio = image_h / image_w
54 |
55 |     # initialize variables to store the minimum difference and the nearest scale
56 | min_diff = float('inf')
57 | nearest_scale = None
58 |
59 |     # iterate over all predefined scales to find the one closest to the input image's aspect ratio
60 | for scale_h, scale_w in predefined_scales:
61 | predefined_ratio = scale_h / scale_w
62 | diff = abs(predefined_ratio - image_ratio)
63 |
64 | if diff < min_diff:
65 | min_diff = diff
66 | nearest_scale = (scale_h, scale_w)
67 |
68 | return nearest_scale
69 |
70 | def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
71 |     # get the width and height of the original image
72 | image_w, image_h = raw_image.size
73 |
74 |     # compute the new long and short sides (long side scaled to long_size)
75 | if image_w >= image_h:
76 | new_w = long_size
77 | new_h = int((long_size / image_w) * image_h)
78 | else:
79 | new_h = long_size
80 | new_w = int((long_size / image_h) * image_w)
81 |
82 |     # resize proportionally to the new width and height
83 | raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
84 | target_w = new_w // 16 * 16
85 | target_h = new_h // 16 * 16
86 |
87 |     # compute the crop start coordinates for a center crop
88 | left = (new_w - target_w) // 2
89 | top = (new_h - target_h) // 2
90 | right = left + target_w
91 | bottom = top + target_h
92 |
93 |     # apply the center crop
94 | raw_image = raw_image.crop((left, top, right, bottom))
95 |
96 |     # convert to RGB
97 | raw_image = raw_image.convert("RGB")
98 | return raw_image
99 |
100 | class UNOPipeline:
101 | def __init__(
102 | self,
103 | model_type: str,
104 | device: torch.device,
105 | offload: bool = False,
106 | only_lora: bool = False,
107 | lora_rank: int = 16
108 | ):
109 | self.device = device
110 | self.offload = offload
111 | self.model_type = model_type
112 |
113 | self.clip = load_clip(self.device)
114 | self.t5 = load_t5(self.device, max_length=512)
115 | self.ae = load_ae(model_type, device="cpu" if offload else self.device)
116 | self.use_fp8 = "fp8" in model_type
117 | if only_lora:
118 | self.model = load_flow_model_only_lora(
119 | model_type,
120 | device="cpu" if offload else self.device,
121 | lora_rank=lora_rank,
122 | use_fp8=self.use_fp8
123 | )
124 | else:
125 | self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
126 |
127 |
128 | def load_ckpt(self, ckpt_path):
129 | if ckpt_path is not None:
130 | from safetensors.torch import load_file as load_sft
131 | print("Loading checkpoint to replace old keys")
132 | # load_sft doesn't support torch.device
133 | if ckpt_path.endswith('safetensors'):
134 | sd = load_sft(ckpt_path, device='cpu')
135 | missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
136 | else:
137 | dit_state = torch.load(ckpt_path, map_location='cpu')
138 | sd = {}
139 | for k in dit_state.keys():
140 | sd[k.replace('module.','')] = dit_state[k]
141 | missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
142 | self.model.to(str(self.device))
143 |             print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
144 |
145 | def set_lora(self, local_path: str = None, repo_id: str = None,
146 |                  name: str = None, lora_weight: float = 0.7):
147 | checkpoint = load_checkpoint(local_path, repo_id, name)
148 | self.update_model_with_lora(checkpoint, lora_weight)
149 |
150 |     def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
151 | checkpoint = load_checkpoint(
152 | None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
153 | )
154 | self.update_model_with_lora(checkpoint, lora_weight)
155 |
156 | def update_model_with_lora(self, checkpoint, lora_weight):
157 | rank = get_lora_rank(checkpoint)
158 | lora_attn_procs = {}
159 |
160 | for name, _ in self.model.attn_processors.items():
161 | lora_state_dict = {}
162 | for k in checkpoint.keys():
163 | if name in k:
164 | lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
165 |
166 | if len(lora_state_dict):
167 | if name.startswith("single_blocks"):
168 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
169 | else:
170 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
171 | lora_attn_procs[name].load_state_dict(lora_state_dict)
172 | lora_attn_procs[name].to(self.device)
173 | else:
174 | if name.startswith("single_blocks"):
175 | lora_attn_procs[name] = SingleStreamBlockProcessor()
176 | else:
177 | lora_attn_procs[name] = DoubleStreamBlockProcessor()
178 |
179 | self.model.set_attn_processor(lora_attn_procs)
180 |
181 |
182 | def __call__(
183 | self,
184 | prompt: str,
185 | width: int = 512,
186 | height: int = 512,
187 | guidance: float = 4,
188 | num_steps: int = 50,
189 | seed: int = 123456789,
190 | **kwargs
191 | ):
192 | width = 16 * (width // 16)
193 | height = 16 * (height // 16)
194 |
195 | device_type = self.device if isinstance(self.device, str) else self.device.type
196 | if device_type == "mps":
197 |             device_type = "cpu"  # to support macOS MPS
198 | with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
199 | return self.forward(
200 | prompt,
201 | width,
202 | height,
203 | guidance,
204 | num_steps,
205 | seed,
206 | **kwargs
207 | )
208 |
209 | @torch.inference_mode()
210 | def gradio_generate(
211 | self,
212 | prompt: str,
213 | width: int,
214 | height: int,
215 | guidance: float,
216 | num_steps: int,
217 | seed: int,
218 | image_prompt1: Image.Image,
219 | image_prompt2: Image.Image,
220 | image_prompt3: Image.Image,
221 | image_prompt4: Image.Image,
222 | ):
223 | ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
224 | ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
225 | ref_long_side = 512 if len(ref_imgs) <= 1 else 320
226 | ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
227 |
228 | seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
229 |
230 | img = self(prompt=prompt, width=width, height=height, guidance=guidance,
231 | num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
232 |
233 | filename = f"output/gradio/{seed}_{prompt[:20]}.png"
234 | os.makedirs(os.path.dirname(filename), exist_ok=True)
235 | exif_data = Image.Exif()
236 | exif_data[ExifTags.Base.Make] = "UNO"
237 | exif_data[ExifTags.Base.Model] = self.model_type
238 | info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
239 | exif_data[ExifTags.Base.ImageDescription] = info
240 | img.save(filename, format="png", exif=exif_data)
241 | return img, filename
242 |
243 |     @torch.inference_mode()
244 | def forward(
245 | self,
246 | prompt: str,
247 | width: int,
248 | height: int,
249 | guidance: float,
250 | num_steps: int,
251 | seed: int,
252 | ref_imgs: list[Image.Image] | None = None,
253 | pe: Literal['d', 'h', 'w', 'o'] = 'd',
254 | ):
255 | x = get_noise(
256 | 1, height, width, device=self.device,
257 | dtype=torch.bfloat16, seed=seed
258 | )
259 | timesteps = get_schedule(
260 | num_steps,
261 | (width // 8) * (height // 8) // (16 * 16),
262 | shift=True,
263 | )
264 | if self.offload:
265 | self.ae.encoder = self.ae.encoder.to(self.device)
266 | x_1_refs = [
267 | self.ae.encode(
268 | (TVF.to_tensor(ref_img) * 2.0 - 1.0)
269 | .unsqueeze(0).to(self.device, torch.float32)
270 | ).to(torch.bfloat16)
271 | for ref_img in ref_imgs
272 | ]
273 |
274 | if self.offload:
275 | self.offload_model_to_cpu(self.ae.encoder)
276 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
277 | inp_cond = prepare_multi_ip(
278 | t5=self.t5, clip=self.clip,
279 | img=x,
280 | prompt=prompt, ref_imgs=x_1_refs, pe=pe
281 | )
282 |
283 | if self.offload:
284 | self.offload_model_to_cpu(self.t5, self.clip)
285 | self.model = self.model.to(self.device)
286 |
287 | x = denoise(
288 | self.model,
289 | **inp_cond,
290 | timesteps=timesteps,
291 | guidance=guidance,
292 | )
293 |
294 | if self.offload:
295 | self.offload_model_to_cpu(self.model)
296 | self.ae.decoder.to(x.device)
297 | x = unpack(x.float(), height, width)
298 | x = self.ae.decode(x)
299 | self.offload_model_to_cpu(self.ae.decoder)
300 |
301 | x1 = x.clamp(-1, 1)
302 | x1 = rearrange(x1[-1], "c h w -> h w c")
303 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
304 | return output_img
305 |
306 | def offload_model_to_cpu(self, *models):
307 | if not self.offload: return
308 | for model in models:
309 | model.cpu()
310 | torch.cuda.empty_cache()
311 |
--------------------------------------------------------------------------------
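A minimal sketch of driving UNOPipeline directly (the repository exposes this through inference.py and app.py; the asset path, prompt, and step count below are placeholders):

    import os
    import torch
    from PIL import Image
    from uno.flux.pipeline import UNOPipeline, preprocess_ref

    pipe = UNOPipeline("flux-dev", torch.device("cuda"), offload=True, only_lora=True)

    # a single reference image: preprocess_ref scales the long side to 512 and center-crops to a /16 grid
    ref = preprocess_ref(Image.open("assets/figurine.png"), long_size=512)

    img = pipe(
        prompt="a figurine on a wooden desk",
        width=512, height=512,
        guidance=4, num_steps=25, seed=42,
        ref_imgs=[ref],
    )
    os.makedirs("output", exist_ok=True)
    img.save("output/example.png")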
/uno/flux/sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import math
17 | from typing import Literal
18 |
19 | import torch
20 | from einops import rearrange, repeat
21 | from torch import Tensor
22 | from tqdm import tqdm
23 |
24 | from .model import Flux
25 | from .modules.conditioner import HFEmbedder
26 |
27 |
28 | def get_noise(
29 | num_samples: int,
30 | height: int,
31 | width: int,
32 | device: torch.device,
33 | dtype: torch.dtype,
34 | seed: int,
35 | ):
36 | return torch.randn(
37 | num_samples,
38 | 16,
39 | # allow for packing
40 | 2 * math.ceil(height / 16),
41 | 2 * math.ceil(width / 16),
42 | device=device,
43 | dtype=dtype,
44 | generator=torch.Generator(device=device).manual_seed(seed),
45 | )
46 |
47 |
48 | def prepare(
49 | t5: HFEmbedder,
50 | clip: HFEmbedder,
51 | img: Tensor,
52 | prompt: str | list[str],
53 | ref_img: None | Tensor=None,
54 | pe: Literal['d', 'h', 'w', 'o'] ='d'
55 | ) -> dict[str, Tensor]:
56 | assert pe in ['d', 'h', 'w', 'o']
57 | bs, c, h, w = img.shape
58 | if bs == 1 and not isinstance(prompt, str):
59 | bs = len(prompt)
60 |
61 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62 | if img.shape[0] == 1 and bs > 1:
63 | img = repeat(img, "1 ... -> bs ...", bs=bs)
64 |
65 | img_ids = torch.zeros(h // 2, w // 2, 3)
66 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69 |
70 | if ref_img is not None:
71 | _, _, ref_h, ref_w = ref_img.shape
72 | ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
73 | if ref_img.shape[0] == 1 and bs > 1:
74 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
75 | ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
76 |         # offset the ref img ids in height/width by the main image's maxima
77 | h_offset = h // 2 if pe in {'d', 'h'} else 0
78 | w_offset = w // 2 if pe in {'d', 'w'} else 0
79 | ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
80 | ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
81 | ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
82 |
83 | if isinstance(prompt, str):
84 | prompt = [prompt]
85 | txt = t5(prompt)
86 | if txt.shape[0] == 1 and bs > 1:
87 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
88 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
89 |
90 | vec = clip(prompt)
91 | if vec.shape[0] == 1 and bs > 1:
92 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
93 |
94 | if ref_img is not None:
95 | return {
96 | "img": img,
97 | "img_ids": img_ids.to(img.device),
98 | "ref_img": ref_img,
99 | "ref_img_ids": ref_img_ids.to(img.device),
100 | "txt": txt.to(img.device),
101 | "txt_ids": txt_ids.to(img.device),
102 | "vec": vec.to(img.device),
103 | }
104 | else:
105 | return {
106 | "img": img,
107 | "img_ids": img_ids.to(img.device),
108 | "txt": txt.to(img.device),
109 | "txt_ids": txt_ids.to(img.device),
110 | "vec": vec.to(img.device),
111 | }
112 |
113 | def prepare_multi_ip(
114 | t5: HFEmbedder,
115 | clip: HFEmbedder,
116 | img: Tensor,
117 | prompt: str | list[str],
118 | ref_imgs: list[Tensor] | None = None,
119 | pe: Literal['d', 'h', 'w', 'o'] = 'd'
120 | ) -> dict[str, Tensor]:
121 | assert pe in ['d', 'h', 'w', 'o']
122 | bs, c, h, w = img.shape
123 | if bs == 1 and not isinstance(prompt, str):
124 | bs = len(prompt)
125 |
126 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
127 | if img.shape[0] == 1 and bs > 1:
128 | img = repeat(img, "1 ... -> bs ...", bs=bs)
129 |
130 | img_ids = torch.zeros(h // 2, w // 2, 3)
131 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
132 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
133 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
134 |
135 | ref_img_ids = []
136 | ref_imgs_list = []
137 | pe_shift_w, pe_shift_h = w // 2, h // 2
138 | for ref_img in ref_imgs:
139 | _, _, ref_h1, ref_w1 = ref_img.shape
140 | ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
141 | if ref_img.shape[0] == 1 and bs > 1:
142 | ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
143 | ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
144 |         # offset the ref img ids in height/width by the current maxima (pe shift)
145 | h_offset = pe_shift_h if pe in {'d', 'h'} else 0
146 | w_offset = pe_shift_w if pe in {'d', 'w'} else 0
147 | ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
148 | ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
149 | ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
150 | ref_img_ids.append(ref_img_ids1)
151 | ref_imgs_list.append(ref_img)
152 |
153 |         # update the pe shift
154 | pe_shift_h += ref_h1 // 2
155 | pe_shift_w += ref_w1 // 2
156 |
157 | if isinstance(prompt, str):
158 | prompt = [prompt]
159 | txt = t5(prompt)
160 | if txt.shape[0] == 1 and bs > 1:
161 | txt = repeat(txt, "1 ... -> bs ...", bs=bs)
162 | txt_ids = torch.zeros(bs, txt.shape[1], 3)
163 |
164 | vec = clip(prompt)
165 | if vec.shape[0] == 1 and bs > 1:
166 | vec = repeat(vec, "1 ... -> bs ...", bs=bs)
167 |
168 | return {
169 | "img": img,
170 | "img_ids": img_ids.to(img.device),
171 | "ref_img": tuple(ref_imgs_list),
172 | "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
173 | "txt": txt.to(img.device),
174 | "txt_ids": txt_ids.to(img.device),
175 | "vec": vec.to(img.device),
176 | }
177 |
178 |
179 | def time_shift(mu: float, sigma: float, t: Tensor):
180 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
181 |
182 |
183 | def get_lin_function(
184 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
185 | ):
186 | m = (y2 - y1) / (x2 - x1)
187 | b = y1 - m * x1
188 | return lambda x: m * x + b
189 |
190 |
191 | def get_schedule(
192 | num_steps: int,
193 | image_seq_len: int,
194 | base_shift: float = 0.5,
195 | max_shift: float = 1.15,
196 | shift: bool = True,
197 | ) -> list[float]:
198 | # extra step for zero
199 | timesteps = torch.linspace(1, 0, num_steps + 1)
200 |
201 | # shifting the schedule to favor high timesteps for higher signal images
202 | if shift:
203 |         # estimate mu based on linear interpolation between two points
204 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
205 | timesteps = time_shift(mu, 1.0, timesteps)
206 |
207 | return timesteps.tolist()
208 |
209 |
210 | def denoise(
211 | model: Flux,
212 | # model input
213 | img: Tensor,
214 | img_ids: Tensor,
215 | txt: Tensor,
216 | txt_ids: Tensor,
217 | vec: Tensor,
218 | # sampling parameters
219 | timesteps: list[float],
220 | guidance: float = 4.0,
221 |     ref_img: Tensor | None = None,
222 |     ref_img_ids: Tensor | None = None,
223 | ):
224 | i = 0
225 | guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
226 | for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
227 | t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
228 | pred = model(
229 | img=img,
230 | img_ids=img_ids,
231 | ref_img=ref_img,
232 | ref_img_ids=ref_img_ids,
233 | txt=txt,
234 | txt_ids=txt_ids,
235 | y=vec,
236 | timesteps=t_vec,
237 | guidance=guidance_vec
238 | )
239 | img = img + (t_prev - t_curr) * pred
240 | i += 1
241 | return img
242 |
243 |
244 | def unpack(x: Tensor, height: int, width: int) -> Tensor:
245 | return rearrange(
246 | x,
247 | "b (h w) (c ph pw) -> b c (h ph) (w pw)",
248 | h=math.ceil(height / 16),
249 | w=math.ceil(width / 16),
250 | ph=2,
251 | pw=2,
252 | )
253 |
--------------------------------------------------------------------------------
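With pe='d', prepare and prepare_multi_ip shift each reference image's ids diagonally past the running height/width maxima, so its rotary positions never collide with the main image or with earlier references. A self-contained sketch of that layout for one 320x320 reference against a 512x512 target (latent sizes follow the 8x VAE downsampling and 2x2 packing used above; values are illustrative):

    import torch

    # main latent for a 512x512 image: 64x64, giving 32x32 position ids after 2x2 packing
    h = w = 64
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] += torch.arange(h // 2)[:, None]
    img_ids[..., 2] += torch.arange(w // 2)[None, :]

    # first reference (320x320 image -> 40x40 latent), offset by (h//2, w//2) for pe='d'
    ref_h = ref_w = 40
    ref_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
    ref_ids[..., 1] += torch.arange(ref_h // 2)[:, None] + h // 2
    ref_ids[..., 2] += torch.arange(ref_w // 2)[None, :] + w // 2

    print(img_ids[..., 1].max().item(), ref_ids[..., 1].min().item())  # 31.0 32.0 -> no overlap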
/uno/flux/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 |
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | import json
21 | import numpy as np
22 | from huggingface_hub import hf_hub_download
23 | from safetensors import safe_open
24 | from safetensors.torch import load_file as load_sft
25 |
26 | from .model import Flux, FluxParams
27 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams
28 | from .modules.conditioner import HFEmbedder
29 |
30 | import re
31 | from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor
32 | def load_model(ckpt, device='cpu'):
33 | if ckpt.endswith('safetensors'):
34 | from safetensors import safe_open
35 | pl_sd = {}
36 | with safe_open(ckpt, framework="pt", device=device) as f:
37 | for k in f.keys():
38 | pl_sd[k] = f.get_tensor(k)
39 | else:
40 | pl_sd = torch.load(ckpt, map_location=device)
41 | return pl_sd
42 |
43 | def load_safetensors(path):
44 | tensors = {}
45 | with safe_open(path, framework="pt", device="cpu") as f:
46 | for key in f.keys():
47 | tensors[key] = f.get_tensor(key)
48 | return tensors
49 |
50 | def get_lora_rank(checkpoint):
51 | for k in checkpoint.keys():
52 | if k.endswith(".down.weight"):
53 | return checkpoint[k].shape[0]
54 |
55 | def load_checkpoint(local_path, repo_id, name):
56 | if local_path is not None:
57 | if '.safetensors' in local_path:
58 | print(f"Loading .safetensors checkpoint from {local_path}")
59 | checkpoint = load_safetensors(local_path)
60 | else:
61 | print(f"Loading checkpoint from {local_path}")
62 | checkpoint = torch.load(local_path, map_location='cpu')
63 | elif repo_id is not None and name is not None:
64 | print(f"Loading checkpoint {name} from repo id {repo_id}")
65 | checkpoint = load_from_repo_id(repo_id, name)
66 | else:
67 | raise ValueError(
68 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
69 | )
70 | return checkpoint
71 |
72 |
73 | def c_crop(image):
74 | width, height = image.size
75 | new_size = min(width, height)
76 | left = (width - new_size) / 2
77 | top = (height - new_size) / 2
78 | right = (width + new_size) / 2
79 | bottom = (height + new_size) / 2
80 | return image.crop((left, top, right, bottom))
81 |
82 | def pad64(x):
83 | return int(np.ceil(float(x) / 64.0) * 64 - x)
84 |
85 | def HWC3(x):
86 | assert x.dtype == np.uint8
87 | if x.ndim == 2:
88 | x = x[:, :, None]
89 | assert x.ndim == 3
90 | H, W, C = x.shape
91 | assert C == 1 or C == 3 or C == 4
92 | if C == 3:
93 | return x
94 | if C == 1:
95 | return np.concatenate([x, x, x], axis=2)
96 | if C == 4:
97 | color = x[:, :, 0:3].astype(np.float32)
98 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
99 | y = color * alpha + 255.0 * (1.0 - alpha)
100 | y = y.clip(0, 255).astype(np.uint8)
101 | return y
102 |
103 | @dataclass
104 | class ModelSpec:
105 | params: FluxParams
106 | ae_params: AutoEncoderParams
107 | ckpt_path: str | None
108 | ae_path: str | None
109 | repo_id: str | None
110 | repo_flow: str | None
111 | repo_ae: str | None
112 | repo_id_ae: str | None
113 |
114 |
115 | configs = {
116 | "flux-dev": ModelSpec(
117 | repo_id="black-forest-labs/FLUX.1-dev",
118 | repo_id_ae="black-forest-labs/FLUX.1-dev",
119 | repo_flow="flux1-dev.safetensors",
120 | repo_ae="ae.safetensors",
121 | ckpt_path=os.getenv("FLUX_DEV"),
122 | params=FluxParams(
123 | in_channels=64,
124 | vec_in_dim=768,
125 | context_in_dim=4096,
126 | hidden_size=3072,
127 | mlp_ratio=4.0,
128 | num_heads=24,
129 | depth=19,
130 | depth_single_blocks=38,
131 | axes_dim=[16, 56, 56],
132 | theta=10_000,
133 | qkv_bias=True,
134 | guidance_embed=True,
135 | ),
136 | ae_path=os.getenv("AE"),
137 | ae_params=AutoEncoderParams(
138 | resolution=256,
139 | in_channels=3,
140 | ch=128,
141 | out_ch=3,
142 | ch_mult=[1, 2, 4, 4],
143 | num_res_blocks=2,
144 | z_channels=16,
145 | scale_factor=0.3611,
146 | shift_factor=0.1159,
147 | ),
148 | ),
149 | "flux-dev-fp8": ModelSpec(
150 | repo_id="black-forest-labs/FLUX.1-dev",
151 | repo_id_ae="black-forest-labs/FLUX.1-dev",
152 | repo_flow="flux1-dev.safetensors",
153 | repo_ae="ae.safetensors",
154 | ckpt_path=os.getenv("FLUX_DEV_FP8"),
155 | params=FluxParams(
156 | in_channels=64,
157 | vec_in_dim=768,
158 | context_in_dim=4096,
159 | hidden_size=3072,
160 | mlp_ratio=4.0,
161 | num_heads=24,
162 | depth=19,
163 | depth_single_blocks=38,
164 | axes_dim=[16, 56, 56],
165 | theta=10_000,
166 | qkv_bias=True,
167 | guidance_embed=True,
168 | ),
169 | ae_path=os.getenv("AE"),
170 | ae_params=AutoEncoderParams(
171 | resolution=256,
172 | in_channels=3,
173 | ch=128,
174 | out_ch=3,
175 | ch_mult=[1, 2, 4, 4],
176 | num_res_blocks=2,
177 | z_channels=16,
178 | scale_factor=0.3611,
179 | shift_factor=0.1159,
180 | ),
181 | ),
182 | "flux-schnell": ModelSpec(
183 | repo_id="black-forest-labs/FLUX.1-schnell",
184 | repo_id_ae="black-forest-labs/FLUX.1-dev",
185 | repo_flow="flux1-schnell.safetensors",
186 | repo_ae="ae.safetensors",
187 | ckpt_path=os.getenv("FLUX_SCHNELL"),
188 | params=FluxParams(
189 | in_channels=64,
190 | vec_in_dim=768,
191 | context_in_dim=4096,
192 | hidden_size=3072,
193 | mlp_ratio=4.0,
194 | num_heads=24,
195 | depth=19,
196 | depth_single_blocks=38,
197 | axes_dim=[16, 56, 56],
198 | theta=10_000,
199 | qkv_bias=True,
200 | guidance_embed=False,
201 | ),
202 | ae_path=os.getenv("AE"),
203 | ae_params=AutoEncoderParams(
204 | resolution=256,
205 | in_channels=3,
206 | ch=128,
207 | out_ch=3,
208 | ch_mult=[1, 2, 4, 4],
209 | num_res_blocks=2,
210 | z_channels=16,
211 | scale_factor=0.3611,
212 | shift_factor=0.1159,
213 | ),
214 | ),
215 | }
216 |
217 |
218 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
219 | if len(missing) > 0 and len(unexpected) > 0:
220 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
221 | print("\n" + "-" * 79 + "\n")
222 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
223 | elif len(missing) > 0:
224 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
225 | elif len(unexpected) > 0:
226 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
227 |
228 | def load_from_repo_id(repo_id, checkpoint_name):
229 | ckpt_path = hf_hub_download(repo_id, checkpoint_name)
230 | sd = load_sft(ckpt_path, device='cpu')
231 | return sd
232 |
233 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
234 | # Loading Flux
235 | print("Init model")
236 | ckpt_path = configs[name].ckpt_path
237 | if (
238 | ckpt_path is None
239 | and configs[name].repo_id is not None
240 | and configs[name].repo_flow is not None
241 | and hf_download
242 | ):
243 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
244 |
245 | with torch.device("meta" if ckpt_path is not None else device):
246 | model = Flux(configs[name].params).to(torch.bfloat16)
247 |
248 | if ckpt_path is not None:
249 | print("Loading checkpoint")
250 | # load_sft doesn't support torch.device
251 | sd = load_model(ckpt_path, device=str(device))
252 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
253 | print_load_warning(missing, unexpected)
254 | return model
255 |
256 | def load_flow_model_only_lora(
257 | name: str,
258 | device: str | torch.device = "cuda",
259 | hf_download: bool = True,
260 | lora_rank: int = 16,
261 | use_fp8: bool = False
262 | ):
263 | # Loading Flux
264 | print("Init model")
265 | ckpt_path = configs[name].ckpt_path
266 | if (
267 | ckpt_path is None
268 | and configs[name].repo_id is not None
269 | and configs[name].repo_flow is not None
270 | and hf_download
271 | ):
272 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
273 |
274 | if hf_download:
275 | try:
276 | lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
277 |         except Exception:
278 | lora_ckpt_path = os.environ.get("LORA", None)
279 | else:
280 | lora_ckpt_path = os.environ.get("LORA", None)
281 |
282 | with torch.device("meta" if ckpt_path is not None else device):
283 | model = Flux(configs[name].params)
284 |
285 |
286 | model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device)
287 |
288 | if ckpt_path is not None:
289 | print("Loading lora")
290 | lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\
291 | else torch.load(lora_ckpt_path, map_location='cpu')
292 |
293 | print("Loading main checkpoint")
294 | # load_sft doesn't support torch.device
295 |
296 | if ckpt_path.endswith('safetensors'):
297 | if use_fp8:
298 | print(
299 | "####\n"
300 |                     "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken.\n"
301 |                     "We convert the checkpoint to fp8 on the fly from the bf16 checkpoint.\n"
302 |                     "If your storage is constrained, "
303 |                     "you can save the fp8 checkpoint and replace the bf16 checkpoint yourself.\n"
304 | )
305 | sd = load_sft(ckpt_path, device="cpu")
306 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
307 | else:
308 | sd = load_sft(ckpt_path, device=str(device))
309 |
310 | sd.update(lora_sd)
311 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
312 | else:
313 | dit_state = torch.load(ckpt_path, map_location='cpu')
314 | sd = {}
315 | for k in dit_state.keys():
316 | sd[k.replace('module.','')] = dit_state[k]
317 | sd.update(lora_sd)
318 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
319 | model.to(str(device))
320 | print_load_warning(missing, unexpected)
321 | return model
322 |
323 |
324 | def set_lora(
325 | model: Flux,
326 | lora_rank: int,
327 | double_blocks_indices: list[int] | None = None,
328 | single_blocks_indices: list[int] | None = None,
329 | device: str | torch.device = "cpu",
330 | ) -> Flux:
331 | double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices
332 | single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \
333 | else single_blocks_indices
334 |
335 | lora_attn_procs = {}
336 | with torch.device(device):
337 | for name, attn_processor in model.attn_processors.items():
338 | match = re.search(r'\.(\d+)\.', name)
339 | if match:
340 | layer_index = int(match.group(1))
341 |
342 | if name.startswith("double_blocks") and layer_index in double_blocks_indices:
343 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
344 | elif name.startswith("single_blocks") and layer_index in single_blocks_indices:
345 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank)
346 | else:
347 | lora_attn_procs[name] = attn_processor
348 | model.set_attn_processor(lora_attn_procs)
349 | return model
350 |
351 |
352 | def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
353 | # Loading Flux
354 | from optimum.quanto import requantize
355 | print("Init model")
356 | ckpt_path = configs[name].ckpt_path
357 | if (
358 | ckpt_path is None
359 | and configs[name].repo_id is not None
360 | and configs[name].repo_flow is not None
361 | and hf_download
362 | ):
363 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
364 | # json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
365 |
366 |
367 | model = Flux(configs[name].params).to(torch.bfloat16)
368 |
369 | print("Loading checkpoint")
370 | # load_sft doesn't support torch.device
371 | sd = load_sft(ckpt_path, device='cpu')
372 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()}
373 | model.load_state_dict(sd, assign=True)
374 | return model
375 |     # NOTE: unreachable quanto requantization path (the json_path download above is commented out):
376 |     # with open(json_path, "r") as f:
377 |     #     quantization_map = json.load(f)
378 |     # print("Start a quantization process...")
379 |     # requantize(model, sd, quantization_map, device=device); print("Model is quantized!")
380 |     # return model
381 |
382 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
383 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
384 | version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
385 | return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
386 |
387 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
388 | version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
389 | return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)
390 |
391 |
392 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
393 | ckpt_path = configs[name].ae_path
394 | if (
395 | ckpt_path is None
396 | and configs[name].repo_id is not None
397 | and configs[name].repo_ae is not None
398 | and hf_download
399 | ):
400 | ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
401 |
402 | # Loading the autoencoder
403 | print("Init AE")
404 | with torch.device("meta" if ckpt_path is not None else device):
405 | ae = AutoEncoder(configs[name].ae_params)
406 |
407 | if ckpt_path is not None:
408 | sd = load_sft(ckpt_path, device=str(device))
409 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
410 | print_load_warning(missing, unexpected)
411 | return ae
--------------------------------------------------------------------------------
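The loaders above resolve weights from the FLUX_DEV / AE / LORA environment variables when set and otherwise download them from the Hugging Face hub. A minimal sketch of assembling the same components that UNOPipeline.__init__ wires together (device and lora_rank mirror the pipeline defaults; illustration only):

    import torch
    from uno.flux.util import load_ae, load_clip, load_flow_model_only_lora, load_t5

    device = torch.device("cuda")

    # optional local overrides, otherwise the HF hub is used:
    #   FLUX_DEV=.../flux1-dev.safetensors  AE=.../ae.safetensors  LORA=.../dit_lora.safetensors
    clip = load_clip(device)
    t5 = load_t5(device, max_length=512)
    ae = load_ae("flux-dev", device=device)
    model = load_flow_model_only_lora("flux-dev", device=device, lora_rank=16)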
/uno/utils/convert_yaml_to_args_file.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 |
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import yaml
17 |
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("--yaml", type=str, required=True)
20 | parser.add_argument("--arg", type=str, required=True)
21 | args = parser.parse_args()
22 |
23 |
24 | with open(args.yaml, "r") as f:
25 | data = yaml.safe_load(f)
26 |
27 | with open(args.arg, "w") as f:
28 | for k, v in data.items():
29 | if isinstance(v, list):
30 | v = list(map(str, v))
31 | v = " ".join(v)
32 | if v is None:
33 | continue
34 | print(f"--{k} {v}", end=" ", file=f)
35 |
--------------------------------------------------------------------------------
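In the converter above, None values are skipped and list values are joined with spaces, so the whole config ends up as a single argparse-style line. A hypothetical round trip (file names and keys are made up; only the conversion rules are real):

    import subprocess
    import yaml

    cfg = {"lr": 0.0001, "mixed_precision": "bf16", "resume_ckpt": None, "ref_long_sides": [512, 320]}
    with open("train.yaml", "w") as f:
        yaml.safe_dump(cfg, f)

    subprocess.run(
        ["python", "uno/utils/convert_yaml_to_args_file.py", "--yaml", "train.yaml", "--arg", "train.args"],
        check=True,
    )
    print(open("train.args").read())
    # --lr 0.0001 --mixed_precision bf16 --ref_long_sides 512 320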