├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── app.py ├── assets ├── cat_cafe.png ├── clock.png ├── crystal_ball.png ├── cup.png ├── examples │ ├── 0_0.json │ ├── 0_0.png │ ├── 1_0.json │ ├── 1_0.png │ ├── 1one2one │ │ ├── config.json │ │ ├── ref1.jpg │ │ └── result.png │ ├── 2_0.json │ ├── 2_0.png │ ├── 2one2one │ │ ├── config.json │ │ ├── ref1.png │ │ └── result.png │ ├── 3two2one │ │ ├── config.json │ │ ├── ref1.png │ │ ├── ref2.png │ │ └── result.png │ ├── 4two2one │ │ ├── config.json │ │ ├── ref1.png │ │ ├── ref2.png │ │ └── result.png │ ├── 5many2one │ │ ├── config.json │ │ ├── ref1.png │ │ ├── ref2.png │ │ ├── ref3.png │ │ └── result.png │ └── 6t2i │ │ ├── config.json │ │ └── result.png ├── figurine.png ├── logo.png ├── simplecase.jpeg ├── simplecase.jpg └── teaser.jpg ├── config └── deepspeed │ ├── zero2_config.json │ └── zero3_config.json ├── datasets ├── dreambench_multiip.json └── dreambench_singleip.json ├── inference.py ├── pyproject.toml ├── requirements.txt ├── train.py └── uno ├── dataset └── uno.py ├── flux ├── math.py ├── model.py ├── modules │ ├── autoencoder.py │ ├── conditioner.py │ └── layers.py ├── pipeline.py ├── sampling.py └── util.py └── utils └── convert_yaml_to_args_file.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # User config files 177 | .vscode/ 178 | output/ 179 | 180 | # ckpt 181 | *.bin 182 | *.pt 183 | *.pth -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "datasets/dreambooth"] 2 | path = datasets/dreambooth 3 | url = https://github.com/google/dreambooth.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Logo 3 | Less-to-More Generalization: Unlocking More Controllability by In-Context Generation 4 |

5 | 6 |

7 | Build 8 | Build 9 | Build 10 | 11 | 12 |

13 | 14 | >

Shaojin Wu, Mengqi Huang*, Wenxu Wu, Yufeng Cheng, Fei Ding+, Qian He
15 | >Intelligent Creation Team, ByteDance

16 | 17 |

18 | 20 |

21 | 22 | ## 🔥 News 23 | - [04/16/2025] 🔥 Our companion project [RealCustom](https://github.com/bytedance/RealCustom) is released. 24 | - [04/10/2025] 🔥 Added fp8 mode as the primary low-VRAM option, a gift for consumer-grade GPU users. Peak VRAM usage is now ~16GB. We may explore further inference optimizations later. 25 | - [04/03/2025] 🔥 The [demo](https://huggingface.co/spaces/bytedance-research/UNO-FLUX) of UNO is released. 26 | - [04/03/2025] 🔥 The [training code](https://github.com/bytedance/UNO), [inference code](https://github.com/bytedance/UNO), and [model](https://huggingface.co/bytedance-research/UNO) of UNO are released. 27 | - [04/02/2025] 🔥 The [project page](https://bytedance.github.io/UNO) of UNO is created. 28 | - [04/02/2025] 🔥 The arXiv [paper](https://arxiv.org/abs/2504.02160) of UNO is released. 29 | 30 | ## 📖 Introduction 31 | In this study, we propose a highly consistent data synthesis pipeline to tackle the scarcity of high-consistency, multi-subject paired data. The pipeline harnesses the intrinsic in-context generation capabilities of diffusion transformers to produce such data at scale. On top of it, we introduce UNO, a multi-image-conditioned subject-to-image model iteratively trained from a text-to-image model, built on progressive cross-modal alignment and universal rotary position embedding. Extensive experiments show that our method achieves high consistency while preserving controllability in both single-subject and multi-subject driven generation. 32 | 33 | 34 | ## ⚡️ Quick Start 35 | 36 | ### 🔧 Requirements and Installation 37 | 38 | Install the requirements: 39 | ```bash 40 | # pip install -r requirements.txt # legacy installation command 41 | 42 | ## create a virtual environment with Python >= 3.10 and <= 3.12, e.g. 43 | # python -m venv uno_env 44 | # source uno_env/bin/activate 45 | # or 46 | # conda create -n uno_env python=3.10 -y 47 | # conda activate uno_env 48 | # then install the requirements you need 49 | 50 | # !!! if you are using an AMD GPU, an NVIDIA RTX 50-series GPU, or macOS MPS, install the correct torch build yourself first 51 | # !!! then run the install command 52 | pip install -e . # if you only want to run the demo/inference 53 | pip install -e .[train] # if you also want to train the model 54 | ``` 55 | 56 | Then download the checkpoints in one of the following four ways: 57 | 1. Directly run the inference scripts; the checkpoints will be downloaded automatically by the `hf_hub_download` function in the code to your `$HF_HOME` (default: `~/.cache/huggingface`). 58 | 2. Use `huggingface-cli download ` to fetch `black-forest-labs/FLUX.1-dev`, `xlabs-ai/xflux_text_encoders`, `openai/clip-vit-large-patch14`, and `bytedance-research/UNO`, then run the inference scripts. You can download only the checkpoints you need to speed up setup and save disk space, e.g. for `black-forest-labs/FLUX.1-dev` use `huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors` and `huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors`, ignoring the text encoders in the `black-forest-labs/FLUX.1-dev` model repo (they are only there for `diffusers` usage). All of the checkpoints together take about 37 GB of disk space. 59 | 3. Use `huggingface-cli download --local-dir ` to download all the checkpoints mentioned in 2. to the directories you want. Then set the environment variables `AE`, `FLUX_DEV` (or `FLUX_DEV_FP8` if you use fp8 mode), `T5`, `CLIP`, and `LORA` to the corresponding paths (see the sketch below). Finally, run the inference scripts. 60 | 4. **If you already have some of the checkpoints**, you can set the environment variables `AE`, `FLUX_DEV`, `T5`, `CLIP`, and `LORA` to the corresponding paths, then run the inference scripts.
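Below is a minimal sketch of option 3, assuming everything is downloaded into a local `./ckpts` directory. The directory layout, the `dit_lora.safetensors` filename for the UNO weights (it matches the name used in `train.py` checkpoints), and whether each variable expects a file or a directory path are illustrative assumptions; adjust them to wherever your checkpoints actually live.

```bash
# download only the files/repos you need into local directories (illustrative layout)
huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir ./ckpts/flux
huggingface-cli download black-forest-labs/FLUX.1-dev ae.safetensors --local-dir ./ckpts/flux
huggingface-cli download xlabs-ai/xflux_text_encoders --local-dir ./ckpts/xflux_text_encoders
huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./ckpts/clip-vit-large-patch14
huggingface-cli download bytedance-research/UNO --local-dir ./ckpts/UNO

# point the environment variables to the corresponding paths
export FLUX_DEV=./ckpts/flux/flux1-dev.safetensors   # or set FLUX_DEV_FP8 instead when using fp8 mode
export AE=./ckpts/flux/ae.safetensors
export T5=./ckpts/xflux_text_encoders
export CLIP=./ckpts/clip-vit-large-patch14
export LORA=./ckpts/UNO/dit_lora.safetensors         # assumed filename; check the downloaded repo

# then run the demo or inference as usual
python inference.py --prompt "A clock on the beach is under a red sun umbrella" --image_paths "assets/clock.png" --width 704 --height 704
```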
61 | 62 | ### 🌟 Gradio Demo 63 | 64 | ```bash 65 | python app.py 66 | ``` 67 | 68 | **For low VRAM usage**, pass the `--offload` and `--name flux-dev-fp8` args. Peak memory usage will be about 16GB. For reference, end-to-end inference takes 40s to 1min on an RTX 3090 in fp8 and offload mode. 69 | 70 | ```bash 71 | python app.py --offload --name flux-dev-fp8 72 | ``` 73 | 74 | 75 | ### ✍️ Inference 76 | Start from the examples below to explore and spark your creativity. ✨ 77 | ```bash 78 | python inference.py --prompt "A clock on the beach is under a red sun umbrella" --image_paths "assets/clock.png" --width 704 --height 704 79 | python inference.py --prompt "The figurine is in the crystal ball" --image_paths "assets/figurine.png" "assets/crystal_ball.png" --width 704 --height 704 80 | python inference.py --prompt "The logo is printed on the cup" --image_paths "assets/cat_cafe.png" "assets/cup.png" --width 704 --height 704 81 | ``` 82 | 83 | Optional preparation: the first time you want to run inference on DreamBench, clone the `dreambench` submodule to download the dataset. 84 | 85 | ```bash 86 | git submodule update --init 87 | ``` 88 | Then run the following scripts: 89 | ```bash 90 | # evaluated on dreambench 91 | ## for single-subject 92 | python inference.py --eval_json_path ./datasets/dreambench_singleip.json 93 | ## for multi-subject 94 | python inference.py --eval_json_path ./datasets/dreambench_multiip.json 95 | ``` 96 | 97 | 98 | 99 | ### 🚄 Training 100 | 101 | ```bash 102 | accelerate launch train.py 103 | ``` 104 | 105 | 106 | ### 📌 Tips and Notes 107 | We integrate single-subject and multi-subject generation within a unified model. For single-subject scenarios, the longest side of the reference image is set to 512 by default, while for multi-subject scenarios, it is set to 320. UNO demonstrates remarkable flexibility across various aspect ratios, thanks to its training on a multi-scale dataset. Despite being trained within 512 buckets, it can handle higher resolutions, including 512, 568, and 704, among others. 108 | 109 | UNO excels in subject-driven generation but has room for improvement in generalization due to dataset constraints. We are actively developing an enhanced model; stay tuned for updates. Your feedback is valuable, so please feel free to share any suggestions. 110 | 111 | ## 🎨 Application Scenarios 112 |

113 | 115 |

116 | 117 | ## 📄 Disclaimer 118 |

119 | We open-source this project for academic research. The vast majority of images 120 | used in this project are either generated or licensed. If you have any concerns, 121 | please contact us, and we will promptly remove any inappropriate content. 122 | Our code is released under the Apache 2.0 License, while our models are under 123 | the CC BY-NC 4.0 License. Any models related to the FLUX.1-dev 124 | base model must adhere to its original licensing terms. 125 |

This research aims to advance the field of generative AI. Users are free to 126 | create images using this tool, provided they comply with local laws and exercise 127 | responsible usage. The developers are not liable for any misuse of the tool by users.

128 | 129 | ## 🚀 Updates 130 | For the purpose of fostering research and the open-source community, we plan to open-source the entire project, encompassing training, inference, weights, etc. Thank you for your patience and support! 🌟 131 | - [x] Release GitHub repo. 132 | - [x] Release inference code. 133 | - [x] Release training code. 134 | - [x] Release model checkpoints. 135 | - [x] Release arXiv paper. 136 | - [x] Release Hugging Face Space demo. 137 | - [ ] Release in-context data generation pipelines. 138 | 139 | ## Related resources 140 | 141 | **ComfyUI** 142 | 143 | - https://github.com/jax-explorer/ComfyUI-UNO a ComfyUI node implementation of UNO by jax-explorer. 144 | - https://github.com/HM-RunningHub/ComfyUI_RH_UNO a ComfyUI node implementation of UNO by HM-RunningHub. 145 | - https://github.com/ShmuelRonen/ComfyUI-UNO-Wrapper a ComfyUI node implementation of UNO by ShmuelRonen. 146 | - https://github.com/Yuan-ManX/ComfyUI-UNO a ComfyUI node implementation of UNO by Yuan-ManX. 147 | - https://github.com/QijiTec/ComfyUI-RED-UNO a ComfyUI node implementation of UNO by QijiTec. 148 | 149 | We thank the passionate community contributors. We have received many requests about ComfyUI, but we do not have enough time to make all of these adaptations ourselves, so if you want to try our work in ComfyUI, please use the repos above. Note that they differ slightly, so you may need some trial and error to find the one that best fits your needs. 150 | 151 | ## Citation 152 | If UNO is helpful, please ⭐ the repo. 153 | 154 | If you find this project useful for your research, please consider citing our paper: 155 | ```bibtex 156 | @article{wu2025less, 157 | title={Less-to-More Generalization: Unlocking More Controllability by In-Context Generation}, 158 | author={Wu, Shaojin and Huang, Mengqi and Wu, Wenxu and Cheng, Yufeng and Ding, Fei and He, Qian}, 159 | journal={arXiv preprint arXiv:2504.02160}, 160 | year={2025} 161 | } 162 | ``` -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
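# Gradio demo entry point for UNO: get_examples() loads the bundled example cases
# from assets/examples/*/config.json, and create_demo() wraps UNOPipeline in a web UI
# with up to four reference-image inputs plus size, step, guidance, and seed controls.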
14 | 15 | import dataclasses 16 | import json 17 | from pathlib import Path 18 | 19 | import gradio as gr 20 | import torch 21 | 22 | from uno.flux.pipeline import UNOPipeline 23 | 24 | 25 | def get_examples(examples_dir: str = "assets/examples") -> list: 26 | examples = Path(examples_dir) 27 | ans = [] 28 | for example in examples.iterdir(): 29 | if not example.is_dir(): 30 | continue 31 | with open(example / "config.json") as f: 32 | example_dict = json.load(f) 33 | 34 | 35 | example_list = [] 36 | 37 | example_list.append(example_dict["useage"]) # case for 38 | example_list.append(example_dict["prompt"]) # prompt 39 | 40 | for key in ["image_ref1", "image_ref2", "image_ref3", "image_ref4"]: 41 | if key in example_dict: 42 | example_list.append(str(example / example_dict[key])) 43 | else: 44 | example_list.append(None) 45 | 46 | example_list.append(example_dict["seed"]) 47 | 48 | ans.append(example_list) 49 | return ans 50 | 51 | 52 | def create_demo( 53 | model_type: str, 54 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 55 | offload: bool = False, 56 | ): 57 | pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512) 58 | 59 | badges_text = r""" 60 |
61 | Build 62 | Build 63 | Build 64 | 65 | 66 |
67 | """.strip() 68 | 69 | with gr.Blocks() as demo: 70 | gr.Markdown(f"# UNO by UNO team") 71 | gr.Markdown(badges_text) 72 | with gr.Row(): 73 | with gr.Column(): 74 | prompt = gr.Textbox(label="Prompt", value="handsome woman in the city") 75 | with gr.Row(): 76 | image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil") 77 | image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil") 78 | image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil") 79 | image_prompt4 = gr.Image(label="Ref img4", visible=True, interactive=True, type="pil") 80 | 81 | with gr.Row(): 82 | with gr.Column(): 83 | width = gr.Slider(512, 2048, 512, step=16, label="Gneration Width") 84 | height = gr.Slider(512, 2048, 512, step=16, label="Gneration Height") 85 | with gr.Column(): 86 | gr.Markdown("📌 The model trained on 512x512 resolution.\n") 87 | gr.Markdown( 88 | "The size closer to 512 is more stable," 89 | " and the higher size gives a better visual effect but is less stable" 90 | ) 91 | 92 | with gr.Accordion("Advanced Options", open=False): 93 | with gr.Row(): 94 | num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps") 95 | guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True) 96 | seed = gr.Number(-1, label="Seed (-1 for random)") 97 | 98 | generate_btn = gr.Button("Generate") 99 | 100 | with gr.Column(): 101 | output_image = gr.Image(label="Generated Image") 102 | download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False) 103 | 104 | 105 | inputs = [ 106 | prompt, width, height, guidance, num_steps, 107 | seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4 108 | ] 109 | generate_btn.click( 110 | fn=pipeline.gradio_generate, 111 | inputs=inputs, 112 | outputs=[output_image, download_btn], 113 | ) 114 | 115 | example_text = gr.Text("", visible=False, label="Case For:") 116 | examples = get_examples("./assets/examples") 117 | 118 | gr.Examples( 119 | examples=examples, 120 | inputs=[ 121 | example_text, prompt, 122 | image_prompt1, image_prompt2, image_prompt3, image_prompt4, 123 | seed, output_image 124 | ], 125 | ) 126 | 127 | return demo 128 | 129 | if __name__ == "__main__": 130 | from typing import Literal 131 | 132 | from transformers import HfArgumentParser 133 | 134 | @dataclasses.dataclass 135 | class AppArgs: 136 | name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev" 137 | device: Literal["cuda", "cpu"] = ( 138 | "cuda" if torch.cuda.is_available() \ 139 | else "mps" if torch.backends.mps.is_available() \ 140 | else "cpu" 141 | ) 142 | offload: bool = dataclasses.field( 143 | default=False, 144 | metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."} 145 | ) 146 | port: int = 7860 147 | 148 | parser = HfArgumentParser([AppArgs]) 149 | args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs] 150 | args = args_tuple[0] 151 | 152 | demo = create_demo(args.name, args.device, args.offload) 153 | demo.launch(server_port=args.port) 154 | -------------------------------------------------------------------------------- /assets/cat_cafe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/cat_cafe.png -------------------------------------------------------------------------------- /assets/clock.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/clock.png -------------------------------------------------------------------------------- /assets/crystal_ball.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/crystal_ball.png -------------------------------------------------------------------------------- /assets/cup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/cup.png -------------------------------------------------------------------------------- /assets/examples/0_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A clock on the beach is under a red sun umbrella", 3 | "image_paths": [ 4 | "assets/clock.png" 5 | ], 6 | "eval_json_path": null, 7 | "offload": false, 8 | "num_images_per_prompt": 1, 9 | "model_type": "flux-dev", 10 | "width": 704, 11 | "height": 704, 12 | "ref_size": 512, 13 | "num_steps": 25, 14 | "guidance": 4, 15 | "seed": 3407, 16 | "save_path": "output/inference", 17 | "only_lora": true, 18 | "concat_refs": false, 19 | "lora_rank": 512, 20 | "data_resolution": 512, 21 | "pe": "d" 22 | } -------------------------------------------------------------------------------- /assets/examples/0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/0_0.png -------------------------------------------------------------------------------- /assets/examples/1_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "The figurine is in the crystal ball", 3 | "image_paths": [ 4 | "assets/figurine.png", 5 | "assets/crystal_ball.png" 6 | ], 7 | "eval_json_path": null, 8 | "offload": false, 9 | "num_images_per_prompt": 1, 10 | "model_type": "flux-dev", 11 | "width": 704, 12 | "height": 704, 13 | "ref_size": 320, 14 | "num_steps": 25, 15 | "guidance": 4, 16 | "seed": 3407, 17 | "save_path": "output/inference", 18 | "only_lora": true, 19 | "concat_refs": false, 20 | "lora_rank": 512, 21 | "data_resolution": 512, 22 | "pe": "d" 23 | } -------------------------------------------------------------------------------- /assets/examples/1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1_0.png -------------------------------------------------------------------------------- /assets/examples/1one2one/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A clock on the beach is under a red sun umbrella", 3 | "seed": 0, 4 | "ref_long_side": 512, 5 | "useage": "one2one", 6 | "image_ref1": "./ref1.jpg", 7 | "image_result": "./result.png" 8 | } -------------------------------------------------------------------------------- /assets/examples/1one2one/ref1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1one2one/ref1.jpg 
-------------------------------------------------------------------------------- /assets/examples/1one2one/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/1one2one/result.png -------------------------------------------------------------------------------- /assets/examples/2_0.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "The logo is printed on the cup", 3 | "image_paths": [ 4 | "assets/cat_cafe.png", 5 | "assets/cup.png" 6 | ], 7 | "eval_json_path": null, 8 | "offload": false, 9 | "num_images_per_prompt": 1, 10 | "model_type": "flux-dev", 11 | "width": 704, 12 | "height": 704, 13 | "ref_size": 320, 14 | "num_steps": 25, 15 | "guidance": 4, 16 | "seed": 3407, 17 | "save_path": "output/inference", 18 | "only_lora": true, 19 | "concat_refs": false, 20 | "lora_rank": 512, 21 | "data_resolution": 512, 22 | "pe": "d" 23 | } -------------------------------------------------------------------------------- /assets/examples/2_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2_0.png -------------------------------------------------------------------------------- /assets/examples/2one2one/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A pretty woman wears a flower petal dress, in the flower", 3 | "seed": 1, 4 | "ref_long_side": 512, 5 | "useage": "one2one", 6 | "image_ref1": "./ref1.png", 7 | "image_result": "./result.png" 8 | } -------------------------------------------------------------------------------- /assets/examples/2one2one/ref1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2one2one/ref1.png -------------------------------------------------------------------------------- /assets/examples/2one2one/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/2one2one/result.png -------------------------------------------------------------------------------- /assets/examples/3two2one/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "The figurine is in the crystal ball", 3 | "seed": 0, 4 | "ref_long_side": 320, 5 | "useage": "two2one", 6 | "image_ref1": "./ref1.png", 7 | "image_ref2": "./ref2.png", 8 | "image_result": "./result.png" 9 | } -------------------------------------------------------------------------------- /assets/examples/3two2one/ref1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/ref1.png -------------------------------------------------------------------------------- /assets/examples/3two2one/ref2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/ref2.png 
-------------------------------------------------------------------------------- /assets/examples/3two2one/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/3two2one/result.png -------------------------------------------------------------------------------- /assets/examples/4two2one/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "The logo is printed on the cup", 3 | "seed": 61733557, 4 | "ref_long_side": 320, 5 | "useage": "two2one", 6 | "image_ref1": "./ref1.png", 7 | "image_ref2": "./ref2.png", 8 | "image_result": "./result.png" 9 | } -------------------------------------------------------------------------------- /assets/examples/4two2one/ref1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/ref1.png -------------------------------------------------------------------------------- /assets/examples/4two2one/ref2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/ref2.png -------------------------------------------------------------------------------- /assets/examples/4two2one/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/4two2one/result.png -------------------------------------------------------------------------------- /assets/examples/5many2one/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A woman wears the dress and holds a bag, in the flowers.", 3 | "seed": 37635012, 4 | "ref_long_side": 320, 5 | "useage": "many2one", 6 | "image_ref1": "./ref1.png", 7 | "image_ref2": "./ref2.png", 8 | "image_ref3": "./ref3.png", 9 | "image_result": "./result.png" 10 | } -------------------------------------------------------------------------------- /assets/examples/5many2one/ref1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref1.png -------------------------------------------------------------------------------- /assets/examples/5many2one/ref2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref2.png -------------------------------------------------------------------------------- /assets/examples/5many2one/ref3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/ref3.png -------------------------------------------------------------------------------- /assets/examples/5many2one/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/5many2one/result.png 
-------------------------------------------------------------------------------- /assets/examples/6t2i/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "prompt": "A woman wears the dress and holds a bag, in the flowers.", 3 | "seed": 37635012, 4 | "ref_long_side": 512, 5 | "useage": "t2i", 6 | "image_result": "./result.png" 7 | } -------------------------------------------------------------------------------- /assets/examples/6t2i/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/examples/6t2i/result.png -------------------------------------------------------------------------------- /assets/figurine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/figurine.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/logo.png -------------------------------------------------------------------------------- /assets/simplecase.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/simplecase.jpeg -------------------------------------------------------------------------------- /assets/simplecase.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/simplecase.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/UNO/d305981b9e301315b5c110aef9c7eed827457858/assets/teaser.jpg -------------------------------------------------------------------------------- /config/deepspeed/zero2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "offload_optimizer": { 8 | "device": "none" 9 | }, 10 | "contiguous_gradients": true, 11 | "overlap_comm": true 12 | }, 13 | "train_micro_batch_size_per_gpu": 1, 14 | "gradient_accumulation_steps": "auto" 15 | } 16 | -------------------------------------------------------------------------------- /config/deepspeed/zero3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "zero_optimization": { 12 | "stage": 3, 13 | "offload_optimizer": { 14 | "device": "cpu", 15 | "pin_memory": true 16 | }, 17 | "offload_param": { 18 | "device": "cpu", 19 | "pin_memory": true 20 | }, 21 | "overlap_comm": true, 22 | "contiguous_gradients": true, 23 | "reduce_bucket_size": 16777216, 24 | "stage3_prefetch_bucket_size": 15099494, 25 | "stage3_param_persistence_threshold": 40960, 26 | "sub_group_size": 1e9, 27 | "stage3_max_live_parameters": 1e9, 28 | 
"stage3_max_reuse_distance": 1e9, 29 | "stage3_gather_16bit_weights_on_model_save": true 30 | }, 31 | "gradient_accumulation_steps": "auto", 32 | "train_micro_batch_size_per_gpu": 1 33 | } 34 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import dataclasses 17 | from typing import Literal 18 | 19 | from accelerate import Accelerator 20 | from transformers import HfArgumentParser 21 | from PIL import Image 22 | import json 23 | import itertools 24 | 25 | from uno.flux.pipeline import UNOPipeline, preprocess_ref 26 | 27 | 28 | def horizontal_concat(images): 29 | widths, heights = zip(*(img.size for img in images)) 30 | 31 | total_width = sum(widths) 32 | max_height = max(heights) 33 | 34 | new_im = Image.new('RGB', (total_width, max_height)) 35 | 36 | x_offset = 0 37 | for img in images: 38 | new_im.paste(img, (x_offset, 0)) 39 | x_offset += img.size[0] 40 | 41 | return new_im 42 | 43 | @dataclasses.dataclass 44 | class InferenceArgs: 45 | prompt: str | None = None 46 | image_paths: list[str] | None = None 47 | eval_json_path: str | None = None 48 | offload: bool = False 49 | num_images_per_prompt: int = 1 50 | model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev" 51 | width: int = 512 52 | height: int = 512 53 | ref_size: int = -1 54 | num_steps: int = 25 55 | guidance: float = 4 56 | seed: int = 3407 57 | save_path: str = "output/inference" 58 | only_lora: bool = True 59 | concat_refs: bool = False 60 | lora_rank: int = 512 61 | data_resolution: int = 512 62 | pe: Literal['d', 'h', 'w', 'o'] = 'd' 63 | 64 | def main(args: InferenceArgs): 65 | accelerator = Accelerator() 66 | 67 | pipeline = UNOPipeline( 68 | args.model_type, 69 | accelerator.device, 70 | args.offload, 71 | only_lora=args.only_lora, 72 | lora_rank=args.lora_rank 73 | ) 74 | 75 | assert args.prompt is not None or args.eval_json_path is not None, \ 76 | "Please provide either prompt or eval_json_path" 77 | 78 | if args.eval_json_path is not None: 79 | with open(args.eval_json_path, "rt") as f: 80 | data_dicts = json.load(f) 81 | data_root = os.path.dirname(args.eval_json_path) 82 | else: 83 | data_root = "./" 84 | data_dicts = [{"prompt": args.prompt, "image_paths": args.image_paths}] 85 | 86 | for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt)): 87 | if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index: 88 | continue 89 | 90 | ref_imgs = [ 91 | Image.open(os.path.join(data_root, img_path)) 92 | for img_path in data_dict["image_paths"] 93 | ] 94 | if args.ref_size==-1: 95 | args.ref_size = 512 if len(ref_imgs)==1 else 320 96 | 97 | ref_imgs = [preprocess_ref(img, args.ref_size) for img in ref_imgs] 
98 | 99 | image_gen = pipeline( 100 | prompt=data_dict["prompt"], 101 | width=args.width, 102 | height=args.height, 103 | guidance=args.guidance, 104 | num_steps=args.num_steps, 105 | seed=args.seed + j, 106 | ref_imgs=ref_imgs, 107 | pe=args.pe, 108 | ) 109 | if args.concat_refs: 110 | image_gen = horizontal_concat([image_gen, *ref_imgs]) 111 | 112 | os.makedirs(args.save_path, exist_ok=True) 113 | image_gen.save(os.path.join(args.save_path, f"{i}_{j}.png")) 114 | 115 | # save config and image 116 | args_dict = vars(args) 117 | args_dict['prompt'] = data_dict["prompt"] 118 | args_dict['image_paths'] = data_dict["image_paths"] 119 | with open(os.path.join(args.save_path, f"{i}_{j}.json"), 'w') as f: 120 | json.dump(args_dict, f, indent=4) 121 | 122 | if __name__ == "__main__": 123 | parser = HfArgumentParser([InferenceArgs]) 124 | args = parser.parse_args_into_dataclasses()[0] 125 | main(args) 126 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "uno" 7 | version = "0.0.1" 8 | authors = [ 9 | { name="Bytedance Ltd. and/or its affiliates" }, 10 | ] 11 | maintainers = [ 12 | {name = "Wu Shaojin", email = "wushaojin@bytedance.com"}, 13 | {name = "Huang Mengqi", email = "huangmengqi.98@bytedance.com"}, 14 | {name = "Wu Wenxu", email = "wuwenxu.01@bytedance.com"}, 15 | {name = "Cheng Yufeng", email = "chengyufeng.cb1@bytedance.com"}, 16 | ] 17 | 18 | description = "🔥🔥 UNO: A Universal Customization Method for Both Single and Multi-Subject Conditioning" 19 | readme = "README.md" 20 | requires-python = ">=3.10, <=3.12" 21 | classifiers = [ 22 | "Programming Language :: Python :: 3", 23 | "Operating System :: OS Independent", 24 | ] 25 | license = "Apache-2.0" 26 | license-files = ["LICENSE"] 27 | 28 | 29 | dependencies = [ 30 | "torch>=2.4.0", 31 | "torchvision>=0.19.0", 32 | "einops>=0.8.0", 33 | "transformers>=4.43.3", 34 | "huggingface-hub", 35 | "diffusers>=0.30.1", 36 | "sentencepiece==0.2.0", 37 | "gradio>=5.22.0", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | 42 | train = [ 43 | "accelerate==1.1.1", 44 | "deepspeed==0.14.4", 45 | ] 46 | 47 | dev = [ 48 | "ruff", 49 | ] 50 | 51 | 52 | [project.urls] 53 | Repository = "https://github.com/bytedance/UNO" 54 | ProjectPage = "https://bytedance.github.io/UNO" 55 | Models = "https://huggingface.co/bytedance-research/UNO" 56 | Demo = "https://huggingface.co/spaces/bytedance-research/UNO-FLUX" 57 | Arxiv = "https://arxiv.org/abs/2504.02160" 58 | 59 | 60 | [tool.setuptools.packages.find] 61 | where = [""] 62 | namespaces = false # to disable scanning PEP 420 namespaces (true by default) 63 | 64 | [tool.ruff] 65 | include = ["uno/**/*.py"] 66 | line-length = 120 67 | indent-width = 4 68 | target-version = "py310" 69 | show-fixes = true 70 | 71 | [tool.ruff.lint] 72 | select = ["E4", "E7", "E9", "F", "I"] 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ## after update to pyproject.toml, the only usage of requirements.txt is to install the dependencies in huggingface demo, so comment out the training dependencies 2 | # accelerate==1.1.1 3 | # deepspeed==0.14.4 4 | einops==0.8.0 5 | transformers==4.43.3 6 | huggingface-hub 7 | 
diffusers==0.30.1 8 | sentencepiece==0.2.0 9 | gradio==5.22.0 10 | 11 | --extra-index-url https://download.pytorch.org/whl/cu124 12 | torch==2.4.0 13 | torchvision==0.19.0 14 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | import gc 17 | import itertools 18 | import logging 19 | import os 20 | import random 21 | from copy import deepcopy 22 | from typing import TYPE_CHECKING, Literal 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | import transformers 27 | from accelerate import Accelerator, DeepSpeedPlugin 28 | from accelerate.logging import get_logger 29 | from accelerate.utils import set_seed 30 | from diffusers.optimization import get_scheduler 31 | from einops import rearrange 32 | from PIL import Image 33 | from safetensors.torch import load_file 34 | from torch.utils.data import DataLoader 35 | from tqdm import tqdm 36 | 37 | from uno.dataset.uno import FluxPairedDatasetV2 38 | from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack 39 | from uno.flux.util import load_ae, load_clip, load_flow_model, load_t5, set_lora 40 | 41 | if TYPE_CHECKING: 42 | from uno.flux.model import Flux 43 | from uno.flux.modules.autoencoder import AutoEncoder 44 | from uno.flux.modules.conditioner import HFEmbedder 45 | 46 | logger = get_logger(__name__) 47 | 48 | def get_models(name: str, device, offload: bool=False): 49 | t5 = load_t5(device, max_length=512) 50 | clip = load_clip(device) 51 | model = load_flow_model(name, device="cpu") 52 | vae = load_ae(name, device="cpu" if offload else device) 53 | return model, vae, t5, clip 54 | 55 | def inference( 56 | batch: dict, 57 | model: "Flux", t5: "HFEmbedder", clip: "HFEmbedder", ae: "AutoEncoder", 58 | accelerator: Accelerator, 59 | seed: int = 0, 60 | pe: Literal["d", "h", "w", "o"] = "d" 61 | ) -> Image.Image: 62 | ref_imgs = batch["ref_imgs"] 63 | prompt = batch["txt"] 64 | neg_prompt = '' 65 | width, height = 512, 512 66 | num_steps = 25 67 | x = get_noise( 68 | 1, height, width, 69 | device=accelerator.device, 70 | dtype=torch.bfloat16, 71 | seed=seed + accelerator.process_index 72 | ) 73 | timesteps = get_schedule( 74 | num_steps, 75 | (width // 8) * (height // 8) // (16 * 16), 76 | shift=True, 77 | ) 78 | with torch.no_grad(): 79 | ref_imgs = [ 80 | ae.encode(ref_img_.to(accelerator.device, torch.float32)).to(torch.bfloat16) 81 | for ref_img_ in ref_imgs 82 | ] 83 | inp_cond = prepare_multi_ip( 84 | t5=t5, clip=clip, img=x, prompt=prompt, 85 | ref_imgs=ref_imgs, 86 | pe=pe 87 | ) 88 | 89 | x = denoise( 90 | model, 91 | **inp_cond, 92 | timesteps=timesteps, 93 | guidance=4, 94 | ) 95 | 96 | x = unpack(x.float(), height, width) 97 | x = ae.decode(x) 98 | 99 | x1 = x.clamp(-1, 1) 100 
| x1 = rearrange(x1[-1], "c h w -> h w c") 101 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 102 | 103 | return output_img 104 | 105 | 106 | def resume_from_checkpoint( 107 | resume_from_checkpoint: str | None | Literal["latest"], 108 | project_dir: str, 109 | accelerator: Accelerator, 110 | dit: "Flux", 111 | dit_ema_dict: dict | None = None, 112 | ) -> tuple["Flux", torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, dict | None, int]: 113 | # Potentially load in the weights and states from a previous save 114 | if resume_from_checkpoint is None: 115 | return dit, dit_ema_dict, 0 116 | 117 | if resume_from_checkpoint == "latest": 118 | # Get the most recent checkpoint 119 | dirs = os.listdir(project_dir) 120 | dirs = [d for d in dirs if d.startswith("checkpoint")] 121 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 122 | if len(dirs) == 0: 123 | accelerator.print( 124 | f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." 125 | ) 126 | return dit, dit_ema_dict, 0 127 | path = dirs[-1] 128 | else: 129 | path = os.path.basename(resume_from_checkpoint) 130 | 131 | 132 | accelerator.print(f"Resuming from checkpoint {path}") 133 | lora_state = load_file( 134 | os.path.join(project_dir, path, 'dit_lora.safetensors'), 135 | device=accelerator.device.__str__() 136 | ) 137 | unwarp_dit = accelerator.unwrap_model(dit) 138 | unwarp_dit.load_state_dict(lora_state, strict=False) 139 | if dit_ema_dict is not None: 140 | dit_ema_dict = load_file( 141 | os.path.join(project_dir, path, 'dit_lora_ema.safetensors'), 142 | device=accelerator.device.__str__() 143 | ) 144 | if dit is not unwarp_dit: 145 | dit_ema_dict = {f"module.{k}": v for k, v in dit_ema_dict.items() if k in unwarp_dit.state_dict()} 146 | 147 | global_step = int(path.split("-")[1]) 148 | 149 | return dit, dit_ema_dict, global_step 150 | 151 | @dataclasses.dataclass 152 | class TrainArgs: 153 | ## accelerator 154 | project_dir: str | None = None 155 | mixed_precision: Literal["no", "fp16", "bf16"] = "bf16" 156 | gradient_accumulation_steps: int = 1 157 | seed: int = 42 158 | wandb_project_name: str | None = None 159 | wandb_run_name: str | None = None 160 | 161 | ## model 162 | model_name: Literal["flux-dev", "flux-schnell"] = "flux-dev" 163 | lora_rank: int = 512 164 | double_blocks_indices: list[int] | None = dataclasses.field( 165 | default=None, 166 | metadata={"help": "Indices of double blocks to apply LoRA. None means all double blocks."} 167 | ) 168 | single_blocks_indices: list[int] | None = dataclasses.field( 169 | default=None, 170 | metadata={"help": "Indices of double blocks to apply LoRA. None means all single blocks."} 171 | ) 172 | pe: Literal["d", "h", "w", "o"] = "d" 173 | gradient_checkpoint: bool = True 174 | ema: bool = False 175 | ema_interval: int = 1 176 | ema_decay: float = 0.99 177 | 178 | 179 | ## optimizer 180 | learning_rate: float = 1e-2 181 | adam_betas: list[float] = dataclasses.field(default_factory=lambda: [0.9, 0.999]) 182 | adam_eps: float = 1e-8 183 | adam_weight_decay: float = 0.01 184 | max_grad_norm: float = 1.0 185 | 186 | ## lr_scheduler 187 | lr_scheduler: str = "constant" 188 | lr_warmup_steps: int = 100 189 | max_train_steps: int = 100000 190 | 191 | ## dataloader 192 | # TODO: change to your own dataset, or use one data syenthsize pipeline comming in the future. 
stay tuned 193 | train_data_json: str = "datasets/dreambench_singleip.json" 194 | batch_size: int = 1 195 | text_dropout: float = 0.1 196 | resolution: int = 512 197 | resolution_ref: int | None = None 198 | 199 | eval_data_json: str = "datasets/dreambench_singleip.json" 200 | eval_batch_size: int = 1 201 | 202 | ## misc 203 | resume_from_checkpoint: str | None | Literal["latest"] = None 204 | checkpointing_steps: int = 1000 205 | 206 | def main( 207 | args: TrainArgs, 208 | ): 209 | ## accelerator 210 | deepspeed_plugins = { 211 | "dit": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero2_config.json'), 212 | "t5": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json'), 213 | "clip": DeepSpeedPlugin(hf_ds_config='config/deepspeed/zero3_config.json') 214 | } 215 | accelerator = Accelerator( 216 | project_dir=args.project_dir, 217 | gradient_accumulation_steps=args.gradient_accumulation_steps, 218 | mixed_precision=args.mixed_precision, 219 | deepspeed_plugins=deepspeed_plugins, 220 | log_with="wandb", 221 | ) 222 | set_seed(args.seed, device_specific=True) 223 | accelerator.init_trackers( 224 | project_name=args.wandb_project_name, 225 | config=args.__dict__, 226 | init_kwargs={ 227 | "wandb": { 228 | "name": args.wandb_run_name, 229 | "dir": accelerator.project_dir, 230 | }, 231 | }, 232 | ) 233 | weight_dtype = { 234 | "fp16": torch.float16, 235 | "bf16": torch.bfloat16, 236 | "no": torch.float32, 237 | }.get(accelerator.mixed_precision, torch.float32) 238 | 239 | ## logger 240 | logging.basicConfig( 241 | format=f"[RANK {accelerator.process_index}] " + "%(asctime)s - %(levelname)s - %(name)s - %(message)s", 242 | datefmt="%m/%d/%Y %H:%M:%S", 243 | level=logging.INFO, 244 | force=True 245 | ) 246 | logger.info(accelerator.state) 247 | logger.info("Training script launched", main_process_only=False) 248 | 249 | ## model 250 | dit, vae, t5, clip = get_models( 251 | name=args.model_name, 252 | device=accelerator.device, 253 | ) 254 | 255 | vae.requires_grad_(False) 256 | t5.requires_grad_(False) 257 | clip.requires_grad_(False) 258 | 259 | dit.requires_grad_(False) 260 | dit = set_lora(dit, args.lora_rank, args.double_blocks_indices, args.single_blocks_indices, accelerator.device) 261 | dit.train() 262 | dit.gradient_checkpointing = args.gradient_checkpoint 263 | 264 | ## ema 265 | dit_ema_dict = { 266 | f"module.{k}": deepcopy(v).requires_grad_(False) for k, v in dit.named_parameters() if v.requires_grad 267 | } if args.ema else None 268 | 269 | ## optimizer and lr scheduler 270 | optimizer = torch.optim.AdamW( 271 | [p for p in dit.parameters() if p.requires_grad], 272 | lr=args.learning_rate, 273 | betas=args.adam_betas, 274 | weight_decay=args.adam_weight_decay, 275 | eps=args.adam_eps, 276 | ) 277 | 278 | lr_scheduler = get_scheduler( 279 | args.lr_scheduler, 280 | optimizer=optimizer, 281 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 282 | num_training_steps=args.max_train_steps * accelerator.num_processes, 283 | ) 284 | 285 | ## resume 286 | ( 287 | dit, 288 | dit_ema_dict, 289 | global_step 290 | ) = resume_from_checkpoint( 291 | args.resume_from_checkpoint, 292 | project_dir=args.project_dir, 293 | accelerator=accelerator, 294 | dit=dit, 295 | dit_ema_dict=dit_ema_dict 296 | ) 297 | 298 | # dataloader 299 | dataset = FluxPairedDatasetV2( 300 | json_file=args.train_data_json, 301 | resolution=args.resolution, resolution_ref=args.resolution_ref 302 | ) 303 | dataloader = DataLoader( 304 | dataset, 305 | batch_size=args.batch_size, 306 | 
shuffle=True, 307 | collate_fn=dataset.collate_fn 308 | ) 309 | eval_dataset = FluxPairedDatasetV2( 310 | json_file=args.eval_data_json, 311 | resolution=args.resolution, resolution_ref=args.resolution_ref 312 | ) 313 | eval_dataloader = DataLoader( 314 | eval_dataset, 315 | batch_size=args.eval_batch_size, 316 | shuffle=False, 317 | collate_fn=eval_dataset.collate_fn 318 | ) 319 | 320 | dataloader = accelerator.prepare_data_loader(dataloader) 321 | eval_dataloader = accelerator.prepare_data_loader(eval_dataloader) 322 | dataloader = itertools.cycle(dataloader) # as infinite fetch data loader 323 | 324 | ## parallel 325 | accelerator.state.select_deepspeed_plugin("dit") 326 | dit, optimizer, lr_scheduler = accelerator.prepare(dit, optimizer, lr_scheduler) 327 | accelerator.state.select_deepspeed_plugin("t5") 328 | t5 = accelerator.prepare(t5) # type: torch.nn.Module 329 | accelerator.state.select_deepspeed_plugin("clip") 330 | clip = accelerator.prepare(clip) # type: torch.nn.Module 331 | 332 | ## noise scheduler 333 | timesteps = get_schedule( 334 | 999, 335 | (args.resolution // 8) * (args.resolution // 8) // 4, 336 | shift=True, 337 | ) 338 | timesteps = torch.tensor(timesteps, device=accelerator.device) 339 | total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps 340 | 341 | logger.info("***** Running training *****") 342 | logger.info(f" Instantaneous batch size per device = {args.batch_size}") 343 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 344 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 345 | logger.info(f" Total optimization steps = {args.max_train_steps}") 346 | logger.info(f" Total validation prompts = {len(eval_dataloader)}") 347 | 348 | progress_bar = tqdm( 349 | range(0, args.max_train_steps), 350 | initial=global_step, 351 | desc="Steps", 352 | total=args.max_train_steps, 353 | disable=not accelerator.is_local_main_process, 354 | ) 355 | 356 | train_loss = 0.0 357 | while global_step < (args.max_train_steps): 358 | batch = next(dataloader) 359 | prompts = [txt_ if random.random() > args.text_dropout else "" for txt_ in batch["txt"]] 360 | img = batch["img"] 361 | ref_imgs = batch["ref_imgs"] 362 | 363 | with torch.no_grad(): 364 | x_1 = vae.encode(img.to(accelerator.device).to(torch.float32)) 365 | x_ref = [vae.encode(ref_img.to(accelerator.device).to(torch.float32)) for ref_img in ref_imgs] 366 | inp = prepare_multi_ip(t5=t5, clip=clip, img=x_1, prompt=prompts, ref_imgs=tuple(x_ref), pe=args.pe) 367 | x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 368 | x_ref = [rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for x in x_ref] 369 | 370 | bs = img.shape[0] 371 | t = torch.randint(0, 1000, (bs,), device=accelerator.device) 372 | t = timesteps[t] 373 | x_0 = torch.randn_like(x_1, device=accelerator.device) 374 | x_t = (1 - t[:, None, None]) * x_1 + t[:, None, None] * x_0 375 | guidance_vec = torch.full((x_t.shape[0],), 1, device=x_t.device, dtype=x_t.dtype) 376 | 377 | with accelerator.accumulate(dit): 378 | # Predict the noise residual and compute loss 379 | model_pred = dit( 380 | img=x_t.to(weight_dtype), 381 | img_ids=inp['img_ids'].to(weight_dtype), 382 | ref_img=[x.to(weight_dtype) for x in x_ref], 383 | ref_img_ids=[ref_img_id.to(weight_dtype) for ref_img_id in inp['ref_img_ids']], 384 | txt=inp['txt'].to(weight_dtype), 385 | txt_ids=inp['txt_ids'].to(weight_dtype), 386 | 
y=inp['vec'].to(weight_dtype), 387 | timesteps=t.to(weight_dtype), 388 | guidance=guidance_vec.to(weight_dtype) 389 | ) 390 | 391 | loss = F.mse_loss(model_pred.float(), (x_0 - x_1).float(), reduction="mean") 392 | 393 | # Gather the losses across all processes for logging (if we use distributed training). 394 | avg_loss = accelerator.gather(loss.repeat(args.batch_size)).mean() 395 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 396 | 397 | # Backpropagate 398 | accelerator.backward(loss) 399 | if accelerator.sync_gradients: 400 | accelerator.clip_grad_norm_(dit.parameters(), args.max_grad_norm) 401 | optimizer.step() 402 | lr_scheduler.step() 403 | optimizer.zero_grad() 404 | 405 | # Checks if the accelerator has performed an optimization step behind the scenes 406 | if accelerator.sync_gradients: 407 | progress_bar.update(1) 408 | global_step += 1 409 | accelerator.log({"train_loss": train_loss}, step=global_step) 410 | train_loss = 0.0 411 | 412 | if accelerator.sync_gradients and dit_ema_dict is not None and global_step % args.ema_interval == 0: 413 | src_dict = dit.state_dict() 414 | for tgt_name in dit_ema_dict: 415 | dit_ema_dict[tgt_name].data.lerp_(src_dict[tgt_name].to(dit_ema_dict[tgt_name]), 1 - args.ema_decay) 416 | 417 | if accelerator.sync_gradients and accelerator.is_main_process and global_step % args.checkpointing_steps == 0: 418 | logger.info(f"saving checkpoint in {global_step=}") 419 | save_path = os.path.join(args.project_dir, f"checkpoint-{global_step}") 420 | os.makedirs(save_path, exist_ok=True) 421 | 422 | # save 423 | accelerator.wait_for_everyone() 424 | unwrapped_model = accelerator.unwrap_model(dit) 425 | unwrapped_model_state = unwrapped_model.state_dict() 426 | requires_grad_key = [k for k, v in unwrapped_model.named_parameters() if v.requires_grad] 427 | unwrapped_model_state = {k: unwrapped_model_state[k] for k in requires_grad_key} 428 | 429 | accelerator.save( 430 | unwrapped_model_state, 431 | os.path.join(save_path, 'dit_lora.safetensors'), 432 | safe_serialization=True 433 | ) 434 | unwrapped_opt = accelerator.unwrap_model(optimizer) 435 | accelerator.save(unwrapped_opt.state_dict(), os.path.join(save_path, 'optimizer.bin')) 436 | logger.info(f"Saved state to {save_path}") 437 | 438 | if args.ema: 439 | accelerator.save( 440 | {k.split("module.")[-1]: v for k, v in dit_ema_dict.items()}, 441 | os.path.join(save_path, 'dit_lora_ema.safetensors'), 442 | safe_serialization=True 443 | ) 444 | 445 | # validate 446 | dit.eval() 447 | torch.set_grad_enabled(False) 448 | for i, batch in enumerate(eval_dataloader): 449 | result = inference(batch, dit, t5, clip, vae, accelerator, seed=0) 450 | accelerator.log({f"eval_gen_{i}": result}, step=global_step) 451 | 452 | 453 | if args.ema: 454 | original_state_dict = dit.state_dict() 455 | dit.load_state_dict(dit_ema_dict, strict=False) 456 | for batch in eval_dataloader: 457 | result = inference(batch, dit, t5, clip, vae, accelerator, seed=0) 458 | accelerator.log({f"eval_ema_gen_{i}": result}, step=global_step) 459 | dit.load_state_dict(original_state_dict, strict=False) 460 | 461 | torch.cuda.empty_cache() 462 | gc.collect() 463 | torch.set_grad_enabled(True) 464 | dit.train() 465 | accelerator.wait_for_everyone() 466 | 467 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 468 | progress_bar.set_postfix(**logs) 469 | 470 | accelerator.wait_for_everyone() 471 | accelerator.end_training() 472 | 473 | if __name__ == "__main__": 474 | parser = 
transformers.HfArgumentParser([TrainArgs]) 475 | args_tuple = parser.parse_args_into_dataclasses(args_file_flag="--config") 476 | main(*args_tuple) 477 | -------------------------------------------------------------------------------- /uno/dataset/uno.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | import os 18 | 19 | import numpy as np 20 | import torch 21 | import torchvision.transforms.functional as TVF 22 | from PIL import Image 23 | from torch.utils.data import DataLoader, Dataset 24 | from torchvision.transforms import Compose, Normalize, ToTensor 25 | 26 | 27 | def bucket_images(images: list[torch.Tensor], resolution: int = 512): 28 | bucket_override=[ 29 | # h w 30 | (256, 768), 31 | (320, 768), 32 | (320, 704), 33 | (384, 640), 34 | (448, 576), 35 | (512, 512), 36 | (576, 448), 37 | (640, 384), 38 | (704, 320), 39 | (768, 320), 40 | (768, 256) 41 | ] 42 | bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override] 43 | bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override] 44 | 45 | aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images] 46 | mean_aspect_ratio = np.mean(aspect_ratios) 47 | 48 | new_h, new_w = bucket_override[0] 49 | min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio) 50 | for h, w in bucket_override: 51 | aspect_diff = np.abs(h / w - mean_aspect_ratio) 52 | if aspect_diff < min_aspect_diff: 53 | min_aspect_diff = aspect_diff 54 | new_h, new_w = h, w 55 | 56 | images = [TVF.resize(image, (new_h, new_w)) for image in images] 57 | images = torch.stack(images, dim=0) 58 | return images 59 | 60 | class FluxPairedDatasetV2(Dataset): 61 | def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None): 62 | super().__init__() 63 | self.json_file = json_file 64 | self.resolution = resolution 65 | self.resolution_ref = resolution_ref if resolution_ref is not None else resolution 66 | self.image_root = os.path.dirname(json_file) 67 | 68 | with open(self.json_file, "rt") as f: 69 | self.data_dicts = json.load(f) 70 | 71 | self.transform = Compose([ 72 | ToTensor(), 73 | Normalize([0.5], [0.5]), 74 | ]) 75 | 76 | def __getitem__(self, idx): 77 | data_dict = self.data_dicts[idx] 78 | image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"] 79 | txt = data_dict["prompt"] 80 | image_tgt_path = data_dict.get("image_tgt_path", None) 81 | # image_tgt_path = data_dict.get("image_paths", None)[0] # TODO: for debugging delete it when release paired data pipeline 82 | ref_imgs = [ 83 | Image.open(os.path.join(self.image_root, path)).convert("RGB") 84 | for path in image_paths 85 | ] 86 | ref_imgs = [self.transform(img) for img 
in ref_imgs] 87 | img = None 88 | if image_tgt_path is not None: 89 | img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB") 90 | img = self.transform(img) 91 | 92 | return { 93 | "img": img, 94 | "txt": txt, 95 | "ref_imgs": ref_imgs, 96 | } 97 | 98 | def __len__(self): 99 | return len(self.data_dicts) 100 | 101 | def collate_fn(self, batch): 102 | img = [data["img"] for data in batch] 103 | txt = [data["txt"] for data in batch] 104 | ref_imgs = [data["ref_imgs"] for data in batch] 105 | assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))]) 106 | 107 | n_ref = len(ref_imgs[0]) 108 | 109 | img = bucket_images(img, self.resolution) 110 | ref_imgs_new = [] 111 | for i in range(n_ref): 112 | ref_imgs_i = [refs[i] for refs in ref_imgs] 113 | ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref) 114 | ref_imgs_new.append(ref_imgs_i) 115 | 116 | return { 117 | "txt": txt, 118 | "img": img, 119 | "ref_imgs": ref_imgs_new, 120 | } 121 | 122 | if __name__ == '__main__': 123 | import argparse 124 | from pprint import pprint 125 | parser = argparse.ArgumentParser() 126 | # parser.add_argument("--json_file", type=str, required=True) 127 | parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json") 128 | args = parser.parse_args() 129 | dataset = FluxPairedDatasetV2(args.json_file, 512) 130 | dataloder = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn) 131 | 132 | for i, data_dict in enumerate(dataloder): 133 | pprint(i) 134 | pprint(data_dict) 135 | breakpoint() 136 | -------------------------------------------------------------------------------- /uno/flux/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
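# === Editor's sketch (illustrative, not part of this file) ===
# How bucket_images() from uno/dataset/uno.py above picks a resolution bucket:
# every image in a batch is resized to the predefined (h, w) whose aspect
# ratio is closest to the batch-mean ratio, then the batch is stacked.
# The tensor sizes below are invented for illustration and the snippet
# assumes the uno package is importable.
if __name__ == "__main__":
    import torch
    from uno.dataset.uno import bucket_images

    # two roughly 3:2 portrait images as (C, H, W) tensors
    imgs = [torch.rand(3, 720, 480), torch.rand(3, 800, 520)]
    batch = bucket_images(imgs, resolution=512)
    # mean aspect ratio ~1.52 -> nearest override bucket is (640, 384)
    print(batch.shape)  # torch.Size([2, 3, 640, 384])
# === end of editor's sketch ===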
15 | 16 | import torch 17 | from einops import rearrange 18 | from torch import Tensor 19 | 20 | 21 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 22 | q, k = apply_rope(q, k, pe) 23 | 24 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 25 | x = rearrange(x, "B H L D -> B L (H D)") 26 | 27 | return x 28 | 29 | 30 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 31 | assert dim % 2 == 0 32 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 33 | omega = 1.0 / (theta**scale) 34 | out = torch.einsum("...n,d->...nd", pos, omega) 35 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 36 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 37 | return out.float() 38 | 39 | 40 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 41 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 42 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 43 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 44 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 45 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 46 | -------------------------------------------------------------------------------- /uno/flux/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from torch import Tensor, nn 20 | 21 | from .modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding 22 | 23 | 24 | @dataclass 25 | class FluxParams: 26 | in_channels: int 27 | vec_in_dim: int 28 | context_in_dim: int 29 | hidden_size: int 30 | mlp_ratio: float 31 | num_heads: int 32 | depth: int 33 | depth_single_blocks: int 34 | axes_dim: list[int] 35 | theta: int 36 | qkv_bias: bool 37 | guidance_embed: bool 38 | 39 | 40 | class Flux(nn.Module): 41 | """ 42 | Transformer model for flow matching on sequences. 
43 | """ 44 | _supports_gradient_checkpointing = True 45 | 46 | def __init__(self, params: FluxParams): 47 | super().__init__() 48 | 49 | self.params = params 50 | self.in_channels = params.in_channels 51 | self.out_channels = self.in_channels 52 | if params.hidden_size % params.num_heads != 0: 53 | raise ValueError( 54 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 55 | ) 56 | pe_dim = params.hidden_size // params.num_heads 57 | if sum(params.axes_dim) != pe_dim: 58 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 59 | self.hidden_size = params.hidden_size 60 | self.num_heads = params.num_heads 61 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 62 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 63 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 64 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 65 | self.guidance_in = ( 66 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 67 | ) 68 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 69 | 70 | self.double_blocks = nn.ModuleList( 71 | [ 72 | DoubleStreamBlock( 73 | self.hidden_size, 74 | self.num_heads, 75 | mlp_ratio=params.mlp_ratio, 76 | qkv_bias=params.qkv_bias, 77 | ) 78 | for _ in range(params.depth) 79 | ] 80 | ) 81 | 82 | self.single_blocks = nn.ModuleList( 83 | [ 84 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 85 | for _ in range(params.depth_single_blocks) 86 | ] 87 | ) 88 | 89 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 90 | self.gradient_checkpointing = False 91 | 92 | def _set_gradient_checkpointing(self, module, value=False): 93 | if hasattr(module, "gradient_checkpointing"): 94 | module.gradient_checkpointing = value 95 | 96 | @property 97 | def attn_processors(self): 98 | # set recursively 99 | processors = {} # type: dict[str, nn.Module] 100 | 101 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): 102 | if hasattr(module, "set_processor"): 103 | processors[f"{name}.processor"] = module.processor 104 | 105 | for sub_name, child in module.named_children(): 106 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 107 | 108 | return processors 109 | 110 | for name, module in self.named_children(): 111 | fn_recursive_add_processors(name, module, processors) 112 | 113 | return processors 114 | 115 | def set_attn_processor(self, processor): 116 | r""" 117 | Sets the attention processor to use to compute attention. 118 | 119 | Parameters: 120 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 121 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 122 | for **all** `Attention` layers. 123 | 124 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 125 | processor. This is strongly recommended when setting trainable attention processors. 126 | 127 | """ 128 | count = len(self.attn_processors.keys()) 129 | 130 | if isinstance(processor, dict) and len(processor) != count: 131 | raise ValueError( 132 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 133 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
134 | ) 135 | 136 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 137 | if hasattr(module, "set_processor"): 138 | if not isinstance(processor, dict): 139 | module.set_processor(processor) 140 | else: 141 | module.set_processor(processor.pop(f"{name}.processor")) 142 | 143 | for sub_name, child in module.named_children(): 144 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 145 | 146 | for name, module in self.named_children(): 147 | fn_recursive_attn_processor(name, module, processor) 148 | 149 | def forward( 150 | self, 151 | img: Tensor, 152 | img_ids: Tensor, 153 | txt: Tensor, 154 | txt_ids: Tensor, 155 | timesteps: Tensor, 156 | y: Tensor, 157 | guidance: Tensor | None = None, 158 | ref_img: Tensor | None = None, 159 | ref_img_ids: Tensor | None = None, 160 | ) -> Tensor: 161 | if img.ndim != 3 or txt.ndim != 3: 162 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 163 | 164 | # running on sequences img 165 | img = self.img_in(img) 166 | vec = self.time_in(timestep_embedding(timesteps, 256)) 167 | if self.params.guidance_embed: 168 | if guidance is None: 169 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 170 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 171 | vec = vec + self.vector_in(y) 172 | txt = self.txt_in(txt) 173 | 174 | ids = torch.cat((txt_ids, img_ids), dim=1) 175 | 176 | # concat ref_img/img 177 | img_end = img.shape[1] 178 | if ref_img is not None: 179 | if isinstance(ref_img, tuple) or isinstance(ref_img, list): 180 | img_in = [img] + [self.img_in(ref) for ref in ref_img] 181 | img_ids = [ids] + [ref_ids for ref_ids in ref_img_ids] 182 | img = torch.cat(img_in, dim=1) 183 | ids = torch.cat(img_ids, dim=1) 184 | else: 185 | img = torch.cat((img, self.img_in(ref_img)), dim=1) 186 | ids = torch.cat((ids, ref_img_ids), dim=1) 187 | pe = self.pe_embedder(ids) 188 | 189 | for index_block, block in enumerate(self.double_blocks): 190 | if self.training and self.gradient_checkpointing: 191 | img, txt = torch.utils.checkpoint.checkpoint( 192 | block, 193 | img=img, 194 | txt=txt, 195 | vec=vec, 196 | pe=pe, 197 | use_reentrant=False, 198 | ) 199 | else: 200 | img, txt = block( 201 | img=img, 202 | txt=txt, 203 | vec=vec, 204 | pe=pe 205 | ) 206 | 207 | img = torch.cat((txt, img), 1) 208 | for block in self.single_blocks: 209 | if self.training and self.gradient_checkpointing: 210 | img = torch.utils.checkpoint.checkpoint( 211 | block, 212 | img, vec=vec, pe=pe, 213 | use_reentrant=False 214 | ) 215 | else: 216 | img = block(img, vec=vec, pe=pe) 217 | img = img[:, txt.shape[1] :, ...] 218 | # index img 219 | img = img[:, :img_end, ...] 220 | 221 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 222 | return img 223 | -------------------------------------------------------------------------------- /uno/flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass 17 | 18 | import torch 19 | from einops import rearrange 20 | from torch import Tensor, nn 21 | 22 | 23 | @dataclass 24 | class AutoEncoderParams: 25 | resolution: int 26 | in_channels: int 27 | ch: int 28 | out_ch: int 29 | ch_mult: list[int] 30 | num_res_blocks: int 31 | z_channels: int 32 | scale_factor: float 33 | shift_factor: float 34 | 35 | 36 | def swish(x: Tensor) -> Tensor: 37 | return x * torch.sigmoid(x) 38 | 39 | 40 | class AttnBlock(nn.Module): 41 | def __init__(self, in_channels: int): 42 | super().__init__() 43 | self.in_channels = in_channels 44 | 45 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 46 | 47 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 48 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 49 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 50 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 51 | 52 | def attention(self, h_: Tensor) -> Tensor: 53 | h_ = self.norm(h_) 54 | q = self.q(h_) 55 | k = self.k(h_) 56 | v = self.v(h_) 57 | 58 | b, c, h, w = q.shape 59 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 60 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 61 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 62 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 63 | 64 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 65 | 66 | def forward(self, x: Tensor) -> Tensor: 67 | return x + self.proj_out(self.attention(x)) 68 | 69 | 70 | class ResnetBlock(nn.Module): 71 | def __init__(self, in_channels: int, out_channels: int): 72 | super().__init__() 73 | self.in_channels = in_channels 74 | out_channels = in_channels if out_channels is None else out_channels 75 | self.out_channels = out_channels 76 | 77 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 79 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 80 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 81 | if self.in_channels != self.out_channels: 82 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 83 | 84 | def forward(self, x): 85 | h = x 86 | h = self.norm1(h) 87 | h = swish(h) 88 | h = self.conv1(h) 89 | 90 | h = self.norm2(h) 91 | h = swish(h) 92 | h = self.conv2(h) 93 | 94 | if self.in_channels != self.out_channels: 95 | x = self.nin_shortcut(x) 96 | 97 | return x + h 98 | 99 | 100 | class Downsample(nn.Module): 101 | def __init__(self, in_channels: int): 102 | super().__init__() 103 | # no asymmetric padding in torch conv, must do it ourselves 104 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 105 | 106 | def forward(self, x: Tensor): 107 | pad = (0, 1, 0, 1) 108 | x = nn.functional.pad(x, pad, mode="constant", value=0) 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | 
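# === Editor's sketch (illustrative, not part of this file) ===
# Downsample above pads only the right and bottom edges by one pixel (the
# asymmetric padding a stride-2, kernel-3 conv needs), so even spatial sizes
# are halved exactly. A minimal shape check with toy dimensions:
if __name__ == "__main__":
    import torch
    x = torch.randn(1, 64, 64, 64)  # (B, C, H, W)
    y = Downsample(64)(x)
    print(y.shape)                  # torch.Size([1, 64, 32, 32])
# === end of editor's sketch ===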
class Upsample(nn.Module): 114 | def __init__(self, in_channels: int): 115 | super().__init__() 116 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 117 | 118 | def forward(self, x: Tensor): 119 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 120 | x = self.conv(x) 121 | return x 122 | 123 | 124 | class Encoder(nn.Module): 125 | def __init__( 126 | self, 127 | resolution: int, 128 | in_channels: int, 129 | ch: int, 130 | ch_mult: list[int], 131 | num_res_blocks: int, 132 | z_channels: int, 133 | ): 134 | super().__init__() 135 | self.ch = ch 136 | self.num_resolutions = len(ch_mult) 137 | self.num_res_blocks = num_res_blocks 138 | self.resolution = resolution 139 | self.in_channels = in_channels 140 | # downsampling 141 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 142 | 143 | curr_res = resolution 144 | in_ch_mult = (1,) + tuple(ch_mult) 145 | self.in_ch_mult = in_ch_mult 146 | self.down = nn.ModuleList() 147 | block_in = self.ch 148 | for i_level in range(self.num_resolutions): 149 | block = nn.ModuleList() 150 | attn = nn.ModuleList() 151 | block_in = ch * in_ch_mult[i_level] 152 | block_out = ch * ch_mult[i_level] 153 | for _ in range(self.num_res_blocks): 154 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 155 | block_in = block_out 156 | down = nn.Module() 157 | down.block = block 158 | down.attn = attn 159 | if i_level != self.num_resolutions - 1: 160 | down.downsample = Downsample(block_in) 161 | curr_res = curr_res // 2 162 | self.down.append(down) 163 | 164 | # middle 165 | self.mid = nn.Module() 166 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 167 | self.mid.attn_1 = AttnBlock(block_in) 168 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 169 | 170 | # end 171 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 172 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 173 | 174 | def forward(self, x: Tensor) -> Tensor: 175 | # downsampling 176 | hs = [self.conv_in(x)] 177 | for i_level in range(self.num_resolutions): 178 | for i_block in range(self.num_res_blocks): 179 | h = self.down[i_level].block[i_block](hs[-1]) 180 | if len(self.down[i_level].attn) > 0: 181 | h = self.down[i_level].attn[i_block](h) 182 | hs.append(h) 183 | if i_level != self.num_resolutions - 1: 184 | hs.append(self.down[i_level].downsample(hs[-1])) 185 | 186 | # middle 187 | h = hs[-1] 188 | h = self.mid.block_1(h) 189 | h = self.mid.attn_1(h) 190 | h = self.mid.block_2(h) 191 | # end 192 | h = self.norm_out(h) 193 | h = swish(h) 194 | h = self.conv_out(h) 195 | return h 196 | 197 | 198 | class Decoder(nn.Module): 199 | def __init__( 200 | self, 201 | ch: int, 202 | out_ch: int, 203 | ch_mult: list[int], 204 | num_res_blocks: int, 205 | in_channels: int, 206 | resolution: int, 207 | z_channels: int, 208 | ): 209 | super().__init__() 210 | self.ch = ch 211 | self.num_resolutions = len(ch_mult) 212 | self.num_res_blocks = num_res_blocks 213 | self.resolution = resolution 214 | self.in_channels = in_channels 215 | self.ffactor = 2 ** (self.num_resolutions - 1) 216 | 217 | # compute in_ch_mult, block_in and curr_res at lowest res 218 | block_in = ch * ch_mult[self.num_resolutions - 1] 219 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 220 | self.z_shape = (1, z_channels, curr_res, curr_res) 221 | 222 | # z to block_in 223 | self.conv_in = 
nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 224 | 225 | # middle 226 | self.mid = nn.Module() 227 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 228 | self.mid.attn_1 = AttnBlock(block_in) 229 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 230 | 231 | # upsampling 232 | self.up = nn.ModuleList() 233 | for i_level in reversed(range(self.num_resolutions)): 234 | block = nn.ModuleList() 235 | attn = nn.ModuleList() 236 | block_out = ch * ch_mult[i_level] 237 | for _ in range(self.num_res_blocks + 1): 238 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 239 | block_in = block_out 240 | up = nn.Module() 241 | up.block = block 242 | up.attn = attn 243 | if i_level != 0: 244 | up.upsample = Upsample(block_in) 245 | curr_res = curr_res * 2 246 | self.up.insert(0, up) # prepend to get consistent order 247 | 248 | # end 249 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 250 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 251 | 252 | def forward(self, z: Tensor) -> Tensor: 253 | # z to block_in 254 | h = self.conv_in(z) 255 | 256 | # middle 257 | h = self.mid.block_1(h) 258 | h = self.mid.attn_1(h) 259 | h = self.mid.block_2(h) 260 | 261 | # upsampling 262 | for i_level in reversed(range(self.num_resolutions)): 263 | for i_block in range(self.num_res_blocks + 1): 264 | h = self.up[i_level].block[i_block](h) 265 | if len(self.up[i_level].attn) > 0: 266 | h = self.up[i_level].attn[i_block](h) 267 | if i_level != 0: 268 | h = self.up[i_level].upsample(h) 269 | 270 | # end 271 | h = self.norm_out(h) 272 | h = swish(h) 273 | h = self.conv_out(h) 274 | return h 275 | 276 | 277 | class DiagonalGaussian(nn.Module): 278 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 279 | super().__init__() 280 | self.sample = sample 281 | self.chunk_dim = chunk_dim 282 | 283 | def forward(self, z: Tensor) -> Tensor: 284 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 285 | if self.sample: 286 | std = torch.exp(0.5 * logvar) 287 | return mean + std * torch.randn_like(mean) 288 | else: 289 | return mean 290 | 291 | 292 | class AutoEncoder(nn.Module): 293 | def __init__(self, params: AutoEncoderParams): 294 | super().__init__() 295 | self.encoder = Encoder( 296 | resolution=params.resolution, 297 | in_channels=params.in_channels, 298 | ch=params.ch, 299 | ch_mult=params.ch_mult, 300 | num_res_blocks=params.num_res_blocks, 301 | z_channels=params.z_channels, 302 | ) 303 | self.decoder = Decoder( 304 | resolution=params.resolution, 305 | in_channels=params.in_channels, 306 | ch=params.ch, 307 | out_ch=params.out_ch, 308 | ch_mult=params.ch_mult, 309 | num_res_blocks=params.num_res_blocks, 310 | z_channels=params.z_channels, 311 | ) 312 | self.reg = DiagonalGaussian() 313 | 314 | self.scale_factor = params.scale_factor 315 | self.shift_factor = params.shift_factor 316 | 317 | def encode(self, x: Tensor) -> Tensor: 318 | z = self.reg(self.encoder(x)) 319 | z = self.scale_factor * (z - self.shift_factor) 320 | return z 321 | 322 | def decode(self, z: Tensor) -> Tensor: 323 | z = z / self.scale_factor + self.shift_factor 324 | return self.decoder(z) 325 | 326 | def forward(self, x: Tensor) -> Tensor: 327 | return self.decode(self.encode(x)) 328 | -------------------------------------------------------------------------------- /uno/flux/modules/conditioner.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from torch import Tensor, nn 17 | from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, 18 | T5Tokenizer) 19 | 20 | 21 | class HFEmbedder(nn.Module): 22 | def __init__(self, version: str, max_length: int, **hf_kwargs): 23 | super().__init__() 24 | self.is_clip = "clip" in version.lower() 25 | self.max_length = max_length 26 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 27 | 28 | if self.is_clip: 29 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) 30 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) 31 | else: 32 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) 33 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) 34 | 35 | self.hf_module = self.hf_module.eval().requires_grad_(False) 36 | 37 | def forward(self, text: list[str]) -> Tensor: 38 | batch_encoding = self.tokenizer( 39 | text, 40 | truncation=True, 41 | max_length=self.max_length, 42 | return_length=False, 43 | return_overflowing_tokens=False, 44 | padding="max_length", 45 | return_tensors="pt", 46 | ) 47 | 48 | outputs = self.hf_module( 49 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 50 | attention_mask=None, 51 | output_hidden_states=False, 52 | ) 53 | return outputs[self.output_key] 54 | -------------------------------------------------------------------------------- /uno/flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
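# === Editor's sketch (illustrative, not part of this file) ===
# HFEmbedder from conditioner.py above returns the pooled "pooler_output"
# for CLIP-style checkpoints and the per-token "last_hidden_state" otherwise.
# The checkpoint names and dtype below are assumptions for illustration; the
# checkpoints actually used are configured via load_clip/load_t5 in
# uno/flux/util.py.
if __name__ == "__main__":
    import torch
    from uno.flux.modules.conditioner import HFEmbedder

    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
    t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)
    vec = clip(["a cat sitting in a cafe"])  # pooled vector, shape (1, hidden)
    txt = t5(["a cat sitting in a cafe"])    # token states, shape (1, 512, hidden)
    print(vec.shape, txt.shape)
# === end of editor's sketch ===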
15 | 16 | import math 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | from einops import rearrange 21 | from torch import Tensor, nn 22 | 23 | from ..math import attention, rope 24 | import torch.nn.functional as F 25 | 26 | class EmbedND(nn.Module): 27 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 28 | super().__init__() 29 | self.dim = dim 30 | self.theta = theta 31 | self.axes_dim = axes_dim 32 | 33 | def forward(self, ids: Tensor) -> Tensor: 34 | n_axes = ids.shape[-1] 35 | emb = torch.cat( 36 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 37 | dim=-3, 38 | ) 39 | 40 | return emb.unsqueeze(1) 41 | 42 | 43 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 44 | """ 45 | Create sinusoidal timestep embeddings. 46 | :param t: a 1-D Tensor of N indices, one per batch element. 47 | These may be fractional. 48 | :param dim: the dimension of the output. 49 | :param max_period: controls the minimum frequency of the embeddings. 50 | :return: an (N, D) Tensor of positional embeddings. 51 | """ 52 | t = time_factor * t 53 | half = dim // 2 54 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 55 | t.device 56 | ) 57 | 58 | args = t[:, None].float() * freqs[None] 59 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 60 | if dim % 2: 61 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 62 | if torch.is_floating_point(t): 63 | embedding = embedding.to(t) 64 | return embedding 65 | 66 | 67 | class MLPEmbedder(nn.Module): 68 | def __init__(self, in_dim: int, hidden_dim: int): 69 | super().__init__() 70 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 71 | self.silu = nn.SiLU() 72 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | return self.out_layer(self.silu(self.in_layer(x))) 76 | 77 | 78 | class RMSNorm(torch.nn.Module): 79 | def __init__(self, dim: int): 80 | super().__init__() 81 | self.scale = nn.Parameter(torch.ones(dim)) 82 | 83 | def forward(self, x: Tensor): 84 | x_dtype = x.dtype 85 | x = x.float() 86 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 87 | return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) 88 | 89 | 90 | class QKNorm(torch.nn.Module): 91 | def __init__(self, dim: int): 92 | super().__init__() 93 | self.query_norm = RMSNorm(dim) 94 | self.key_norm = RMSNorm(dim) 95 | 96 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 97 | q = self.query_norm(q) 98 | k = self.key_norm(k) 99 | return q.to(v), k.to(v) 100 | 101 | class LoRALinearLayer(nn.Module): 102 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): 103 | super().__init__() 104 | 105 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 106 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 107 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
108 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 109 | self.network_alpha = network_alpha 110 | self.rank = rank 111 | 112 | nn.init.normal_(self.down.weight, std=1 / rank) 113 | nn.init.zeros_(self.up.weight) 114 | 115 | def forward(self, hidden_states): 116 | orig_dtype = hidden_states.dtype 117 | dtype = self.down.weight.dtype 118 | 119 | down_hidden_states = self.down(hidden_states.to(dtype)) 120 | up_hidden_states = self.up(down_hidden_states) 121 | 122 | if self.network_alpha is not None: 123 | up_hidden_states *= self.network_alpha / self.rank 124 | 125 | return up_hidden_states.to(orig_dtype) 126 | 127 | class FLuxSelfAttnProcessor: 128 | def __call__(self, attn, x, pe, **attention_kwargs): 129 | qkv = attn.qkv(x) 130 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 131 | q, k = attn.norm(q, k, v) 132 | x = attention(q, k, v, pe=pe) 133 | x = attn.proj(x) 134 | return x 135 | 136 | class LoraFluxAttnProcessor(nn.Module): 137 | 138 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 139 | super().__init__() 140 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 141 | self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) 142 | self.lora_weight = lora_weight 143 | 144 | 145 | def __call__(self, attn, x, pe, **attention_kwargs): 146 | qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight 147 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 148 | q, k = attn.norm(q, k, v) 149 | x = attention(q, k, v, pe=pe) 150 | x = attn.proj(x) + self.proj_lora(x) * self.lora_weight 151 | return x 152 | 153 | class SelfAttention(nn.Module): 154 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 155 | super().__init__() 156 | self.num_heads = num_heads 157 | head_dim = dim // num_heads 158 | 159 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 160 | self.norm = QKNorm(head_dim) 161 | self.proj = nn.Linear(dim, dim) 162 | def forward(): 163 | pass 164 | 165 | 166 | @dataclass 167 | class ModulationOut: 168 | shift: Tensor 169 | scale: Tensor 170 | gate: Tensor 171 | 172 | 173 | class Modulation(nn.Module): 174 | def __init__(self, dim: int, double: bool): 175 | super().__init__() 176 | self.is_double = double 177 | self.multiplier = 6 if double else 3 178 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 179 | 180 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 181 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 182 | 183 | return ( 184 | ModulationOut(*out[:3]), 185 | ModulationOut(*out[3:]) if self.is_double else None, 186 | ) 187 | 188 | class DoubleStreamBlockLoraProcessor(nn.Module): 189 | def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): 190 | super().__init__() 191 | self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 192 | self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) 193 | self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 194 | self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) 195 | self.lora_weight = lora_weight 196 | 197 | def forward(self, attn, img, txt, vec, pe, **attention_kwargs): 198 | img_mod1, img_mod2 = attn.img_mod(vec) 199 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 200 | 201 | # prepare image for attention 202 | img_modulated = attn.img_norm1(img) 203 | img_modulated = (1 + img_mod1.scale) * img_modulated + 
img_mod1.shift 204 | img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight 205 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 206 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 207 | 208 | # prepare txt for attention 209 | txt_modulated = attn.txt_norm1(txt) 210 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 211 | txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight 212 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 213 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 214 | 215 | # run actual attention 216 | q = torch.cat((txt_q, img_q), dim=2) 217 | k = torch.cat((txt_k, img_k), dim=2) 218 | v = torch.cat((txt_v, img_v), dim=2) 219 | 220 | attn1 = attention(q, k, v, pe=pe) 221 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 222 | 223 | # calculate the img bloks 224 | img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight) 225 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 226 | 227 | # calculate the txt bloks 228 | txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight) 229 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 230 | return img, txt 231 | 232 | class DoubleStreamBlockProcessor: 233 | def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): 234 | img_mod1, img_mod2 = attn.img_mod(vec) 235 | txt_mod1, txt_mod2 = attn.txt_mod(vec) 236 | 237 | # prepare image for attention 238 | img_modulated = attn.img_norm1(img) 239 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 240 | img_qkv = attn.img_attn.qkv(img_modulated) 241 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 242 | img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) 243 | 244 | # prepare txt for attention 245 | txt_modulated = attn.txt_norm1(txt) 246 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 247 | txt_qkv = attn.txt_attn.qkv(txt_modulated) 248 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) 249 | txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) 250 | 251 | # run actual attention 252 | q = torch.cat((txt_q, img_q), dim=2) 253 | k = torch.cat((txt_k, img_k), dim=2) 254 | v = torch.cat((txt_v, img_v), dim=2) 255 | 256 | attn1 = attention(q, k, v, pe=pe) 257 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 258 | 259 | # calculate the img bloks 260 | img = img + img_mod1.gate * attn.img_attn.proj(img_attn) 261 | img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) 262 | 263 | # calculate the txt bloks 264 | txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) 265 | txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) 266 | return img, txt 267 | 268 | class DoubleStreamBlock(nn.Module): 269 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 270 | super().__init__() 271 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 272 | self.num_heads = num_heads 273 | self.hidden_size = hidden_size 274 | self.head_dim = hidden_size // 
num_heads 275 | 276 | self.img_mod = Modulation(hidden_size, double=True) 277 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 278 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 279 | 280 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 281 | self.img_mlp = nn.Sequential( 282 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 283 | nn.GELU(approximate="tanh"), 284 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 285 | ) 286 | 287 | self.txt_mod = Modulation(hidden_size, double=True) 288 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 289 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 290 | 291 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 292 | self.txt_mlp = nn.Sequential( 293 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 294 | nn.GELU(approximate="tanh"), 295 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 296 | ) 297 | processor = DoubleStreamBlockProcessor() 298 | self.set_processor(processor) 299 | 300 | def set_processor(self, processor) -> None: 301 | self.processor = processor 302 | 303 | def get_processor(self): 304 | return self.processor 305 | 306 | def forward( 307 | self, 308 | img: Tensor, 309 | txt: Tensor, 310 | vec: Tensor, 311 | pe: Tensor, 312 | image_proj: Tensor = None, 313 | ip_scale: float =1.0, 314 | ) -> tuple[Tensor, Tensor]: 315 | if image_proj is None: 316 | return self.processor(self, img, txt, vec, pe) 317 | else: 318 | return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) 319 | 320 | 321 | class SingleStreamBlockLoraProcessor(nn.Module): 322 | def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): 323 | super().__init__() 324 | self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) 325 | self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) 326 | self.lora_weight = lora_weight 327 | 328 | def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 329 | 330 | mod, _ = attn.modulation(vec) 331 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 332 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 333 | qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight 334 | 335 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 336 | q, k = attn.norm(q, k, v) 337 | 338 | # compute attention 339 | attn_1 = attention(q, k, v, pe=pe) 340 | 341 | # compute activation in mlp stream, cat again and run second linear layer 342 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 343 | output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight 344 | output = x + mod.gate * output 345 | return output 346 | 347 | 348 | class SingleStreamBlockProcessor: 349 | def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor: 350 | 351 | mod, _ = attn.modulation(vec) 352 | x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift 353 | qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) 354 | 355 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) 356 | q, k = attn.norm(q, k, v) 357 | 358 | # compute attention 359 | attn_1 = attention(q, k, v, pe=pe) 360 | 361 | # compute activation in mlp stream, cat again and run second linear 
layer 362 | output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) 363 | output = x + mod.gate * output 364 | return output 365 | 366 | class SingleStreamBlock(nn.Module): 367 | """ 368 | A DiT block with parallel linear layers as described in 369 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 370 | """ 371 | 372 | def __init__( 373 | self, 374 | hidden_size: int, 375 | num_heads: int, 376 | mlp_ratio: float = 4.0, 377 | qk_scale: float | None = None, 378 | ): 379 | super().__init__() 380 | self.hidden_dim = hidden_size 381 | self.num_heads = num_heads 382 | self.head_dim = hidden_size // num_heads 383 | self.scale = qk_scale or self.head_dim**-0.5 384 | 385 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 386 | # qkv and mlp_in 387 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 388 | # proj and mlp_out 389 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 390 | 391 | self.norm = QKNorm(self.head_dim) 392 | 393 | self.hidden_size = hidden_size 394 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 395 | 396 | self.mlp_act = nn.GELU(approximate="tanh") 397 | self.modulation = Modulation(hidden_size, double=False) 398 | 399 | processor = SingleStreamBlockProcessor() 400 | self.set_processor(processor) 401 | 402 | 403 | def set_processor(self, processor) -> None: 404 | self.processor = processor 405 | 406 | def get_processor(self): 407 | return self.processor 408 | 409 | def forward( 410 | self, 411 | x: Tensor, 412 | vec: Tensor, 413 | pe: Tensor, 414 | image_proj: Tensor | None = None, 415 | ip_scale: float = 1.0, 416 | ) -> Tensor: 417 | if image_proj is None: 418 | return self.processor(self, x, vec, pe) 419 | else: 420 | return self.processor(self, x, vec, pe, image_proj, ip_scale) 421 | 422 | 423 | 424 | class LastLayer(nn.Module): 425 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 426 | super().__init__() 427 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 428 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 429 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 430 | 431 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 432 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 433 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 434 | x = self.linear(x) 435 | return x 436 | -------------------------------------------------------------------------------- /uno/flux/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
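Note on the block/processor split in uno/flux/modules/layers.py above: each DoubleStreamBlock and SingleStreamBlock delegates its forward pass to a swappable processor object (set_processor / get_processor), which is how the LoRA variants are attached in uno/flux/pipeline.py and uno/flux/util.py without modifying the base attention weights. A minimal sketch of that swap follows; the hidden size (3072) and head count (24) match the FLUX-dev configuration in uno/flux/util.py, while the rank value is an illustrative assumption rather than something taken from this repository.

from uno.flux.modules.layers import DoubleStreamBlock, DoubleStreamBlockLoraProcessor

# One double-stream block at the FLUX-dev width (hidden_size=3072, num_heads=24).
block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0, qkv_bias=True)

# Replace the default processor with a LoRA processor (rank 16 is an assumed example value);
# block.forward() now routes through the LoRA path while the base weights stay untouched.
block.set_processor(DoubleStreamBlockLoraProcessor(dim=3072, rank=16))
print(type(block.get_processor()).__name__)  # DoubleStreamBlockLoraProcessor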
15 |
16 | import os
17 | from typing import Literal
18 |
19 | import torch
20 | from einops import rearrange
21 | from PIL import ExifTags, Image
22 | import torchvision.transforms.functional as TVF
23 |
24 | from uno.flux.modules.layers import (
25 |     DoubleStreamBlockLoraProcessor,
26 |     DoubleStreamBlockProcessor,
27 |     SingleStreamBlockLoraProcessor,
28 |     SingleStreamBlockProcessor,
29 | )
30 | from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
31 | from uno.flux.util import (
32 |     get_lora_rank,
33 |     load_ae,
34 |     load_checkpoint,
35 |     load_clip,
36 |     load_flow_model,
37 |     load_flow_model_only_lora,
38 |     load_flow_model_quintized,
39 |     load_t5,
40 | )
41 |
42 |
43 | def find_nearest_scale(image_h, image_w, predefined_scales):
44 |     """
45 |     Find the nearest predefined scale for a given image height and width.
46 |
47 |     :param image_h: height of the image
48 |     :param image_w: width of the image
49 |     :param predefined_scales: list of predefined scales [(h1, w1), (h2, w2), ...]
50 |     :return: the nearest predefined scale (h, w)
51 |     """
52 |     # compute the aspect ratio of the input image
53 |     image_ratio = image_h / image_w
54 |
55 |     # track the smallest difference and the nearest scale found so far
56 |     min_diff = float('inf')
57 |     nearest_scale = None
58 |
59 |     # iterate over all predefined scales and pick the one whose aspect ratio is closest to the input image
60 |     for scale_h, scale_w in predefined_scales:
61 |         predefined_ratio = scale_h / scale_w
62 |         diff = abs(predefined_ratio - image_ratio)
63 |
64 |         if diff < min_diff:
65 |             min_diff = diff
66 |             nearest_scale = (scale_h, scale_w)
67 |
68 |     return nearest_scale
69 |
70 | def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
71 |     # get the width and height of the original image
72 |     image_w, image_h = raw_image.size
73 |
74 |     # compute the new long and short sides
75 |     if image_w >= image_h:
76 |         new_w = long_size
77 |         new_h = int((long_size / image_w) * image_h)
78 |     else:
79 |         new_h = long_size
80 |         new_w = int((long_size / image_h) * image_w)
81 |
82 |     # resize proportionally to the new width and height
83 |     raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
84 |     target_w = new_w // 16 * 16
85 |     target_h = new_h // 16 * 16
86 |
87 |     # compute the crop origin for a center crop
88 |     left = (new_w - target_w) // 2
89 |     top = (new_h - target_h) // 2
90 |     right = left + target_w
91 |     bottom = top + target_h
92 |
93 |     # apply the center crop
94 |     raw_image = raw_image.crop((left, top, right, bottom))
95 |
96 |     # convert to RGB mode
97 |     raw_image = raw_image.convert("RGB")
98 |     return raw_image
99 |
100 | class UNOPipeline:
101 |     def __init__(
102 |         self,
103 |         model_type: str,
104 |         device: torch.device,
105 |         offload: bool = False,
106 |         only_lora: bool = False,
107 |         lora_rank: int = 16
108 |     ):
109 |         self.device = device
110 |         self.offload = offload
111 |         self.model_type = model_type
112 |
113 |         self.clip = load_clip(self.device)
114 |         self.t5 = load_t5(self.device, max_length=512)
115 |         self.ae = load_ae(model_type, device="cpu" if offload else self.device)
116 |         self.use_fp8 = "fp8" in model_type
117 |         if only_lora:
118 |             self.model = load_flow_model_only_lora(
119 |                 model_type,
120 |                 device="cpu" if offload else self.device,
121 |                 lora_rank=lora_rank,
122 |                 use_fp8=self.use_fp8
123 |             )
124 |         else:
125 |             self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
126 |
127 |
128 |     def load_ckpt(self, ckpt_path):
129 |         if ckpt_path is not None:
130 |             from safetensors.torch import load_file as load_sft
131 |             print("Loading checkpoint to replace old keys")
132 |             # load_sft doesn't support torch.device
133 |             if ckpt_path.endswith('safetensors'):
134 |                 sd = load_sft(ckpt_path, device='cpu')
135 |                 missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
136 |             else:
137 |                 dit_state = torch.load(ckpt_path, map_location='cpu')
138 |                 sd = {}
139 |                 for k in dit_state.keys():
140 |                     sd[k.replace('module.','')] = dit_state[k]
141 |                 missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
142 |             self.model.to(str(self.device))
143 |             print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")
144 |
145 |     def set_lora(self, local_path: str = None, repo_id: str = None,
146 |                  name: str = None, lora_weight: float = 0.7):
147 |         checkpoint = load_checkpoint(local_path, repo_id, name)
148 |         self.update_model_with_lora(checkpoint, lora_weight)
149 |
150 |     def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
151 |         checkpoint = load_checkpoint(
152 |             None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
153 |         )
154 |         self.update_model_with_lora(checkpoint, lora_weight)
155 |
156 |     def update_model_with_lora(self, checkpoint, lora_weight):
157 |         rank = get_lora_rank(checkpoint)
158 |         lora_attn_procs = {}
159 |
160 |         for name, _ in self.model.attn_processors.items():
161 |             lora_state_dict = {}
162 |             for k in checkpoint.keys():
163 |                 if name in k:
164 |                     lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
165 |
166 |             if len(lora_state_dict):
167 |                 if name.startswith("single_blocks"):
168 |                     lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
169 |                 else:
170 |                     lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
171 |                 lora_attn_procs[name].load_state_dict(lora_state_dict)
172 |                 lora_attn_procs[name].to(self.device)
173 |             else:
174 |                 if name.startswith("single_blocks"):
175 |                     lora_attn_procs[name] = SingleStreamBlockProcessor()
176 |                 else:
177 |                     lora_attn_procs[name] = DoubleStreamBlockProcessor()
178 |
179 |         self.model.set_attn_processor(lora_attn_procs)
180 |
181 |
182 |     def __call__(
183 |         self,
184 |         prompt: str,
185 |         width: int = 512,
186 |         height: int = 512,
187 |         guidance: float = 4,
188 |         num_steps: int = 50,
189 |         seed: int = 123456789,
190 |         **kwargs
191 |     ):
192 |         width = 16 * (width // 16)
193 |         height = 16 * (height // 16)
194 |
195 |         device_type = self.device if isinstance(self.device, str) else self.device.type
196 |         if device_type == "mps":
197 |             device_type = "cpu"  # to support macOS MPS
198 |         with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
199 |             return self.forward(
200 |                 prompt,
201 |                 width,
202 |                 height,
203 |                 guidance,
204 |                 num_steps,
205 |                 seed,
206 |                 **kwargs
207 |             )
208 |
209 |     @torch.inference_mode()
210 |     def gradio_generate(
211 |         self,
212 |         prompt: str,
213 |         width: int,
214 |         height: int,
215 |         guidance: float,
216 |         num_steps: int,
217 |         seed: int,
218 |         image_prompt1: Image.Image,
219 |         image_prompt2: Image.Image,
220 |         image_prompt3: Image.Image,
221 |         image_prompt4: Image.Image,
222 |     ):
223 |         ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
224 |         ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
225 |         ref_long_side = 512 if len(ref_imgs) <= 1 else 320
226 |         ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]
227 |
228 |         seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
229 |
230 |         img = self(prompt=prompt, width=width, height=height, guidance=guidance,
231 |                    num_steps=num_steps, seed=seed, ref_imgs=ref_imgs)
232 |
233 |         filename = f"output/gradio/{seed}_{prompt[:20]}.png"
234 |         os.makedirs(os.path.dirname(filename), exist_ok=True)
235 |         exif_data = Image.Exif()
236 |         exif_data[ExifTags.Base.Make] = "UNO"
237 |         exif_data[ExifTags.Base.Model] =
self.model_type 238 | info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}" 239 | exif_data[ExifTags.Base.ImageDescription] = info 240 | img.save(filename, format="png", exif=exif_data) 241 | return img, filename 242 | 243 | @torch.inference_mode 244 | def forward( 245 | self, 246 | prompt: str, 247 | width: int, 248 | height: int, 249 | guidance: float, 250 | num_steps: int, 251 | seed: int, 252 | ref_imgs: list[Image.Image] | None = None, 253 | pe: Literal['d', 'h', 'w', 'o'] = 'd', 254 | ): 255 | x = get_noise( 256 | 1, height, width, device=self.device, 257 | dtype=torch.bfloat16, seed=seed 258 | ) 259 | timesteps = get_schedule( 260 | num_steps, 261 | (width // 8) * (height // 8) // (16 * 16), 262 | shift=True, 263 | ) 264 | if self.offload: 265 | self.ae.encoder = self.ae.encoder.to(self.device) 266 | x_1_refs = [ 267 | self.ae.encode( 268 | (TVF.to_tensor(ref_img) * 2.0 - 1.0) 269 | .unsqueeze(0).to(self.device, torch.float32) 270 | ).to(torch.bfloat16) 271 | for ref_img in ref_imgs 272 | ] 273 | 274 | if self.offload: 275 | self.offload_model_to_cpu(self.ae.encoder) 276 | self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) 277 | inp_cond = prepare_multi_ip( 278 | t5=self.t5, clip=self.clip, 279 | img=x, 280 | prompt=prompt, ref_imgs=x_1_refs, pe=pe 281 | ) 282 | 283 | if self.offload: 284 | self.offload_model_to_cpu(self.t5, self.clip) 285 | self.model = self.model.to(self.device) 286 | 287 | x = denoise( 288 | self.model, 289 | **inp_cond, 290 | timesteps=timesteps, 291 | guidance=guidance, 292 | ) 293 | 294 | if self.offload: 295 | self.offload_model_to_cpu(self.model) 296 | self.ae.decoder.to(x.device) 297 | x = unpack(x.float(), height, width) 298 | x = self.ae.decode(x) 299 | self.offload_model_to_cpu(self.ae.decoder) 300 | 301 | x1 = x.clamp(-1, 1) 302 | x1 = rearrange(x1[-1], "c h w -> h w c") 303 | output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 304 | return output_img 305 | 306 | def offload_model_to_cpu(self, *models): 307 | if not self.offload: return 308 | for model in models: 309 | model.cpu() 310 | torch.cuda.empty_cache() 311 | -------------------------------------------------------------------------------- /uno/flux/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
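For orientation, here is a usage sketch of the UNOPipeline class defined above; it is not taken from the repository, and the device, checkpoint variant, reference-image paths, and sampling settings are illustrative assumptions. The pipeline returns a PIL image, and reference images are normally run through preprocess_ref first (the Gradio helper uses a long side of 512 for a single reference and 320 when several are given).

import torch
from PIL import Image

from uno.flux.pipeline import UNOPipeline, preprocess_ref

# Assumed setup: FLUX-dev weights plus the UNO DiT LoRA, on a single CUDA device,
# with offloading enabled so frozen submodules are moved to CPU between stages.
pipeline = UNOPipeline("flux-dev", device=torch.device("cuda"), offload=True, only_lora=True)

# Two hypothetical reference images, resized and center-cropped to multiples of 16.
refs = [preprocess_ref(Image.open(p), long_size=320) for p in ("ref1.png", "ref2.png")]

image = pipeline(
    prompt="a figurine on a wooden desk",
    width=512,
    height=512,
    guidance=4,
    num_steps=25,
    seed=42,
    ref_imgs=refs,  # forwarded through **kwargs to UNOPipeline.forward
)
image.save("uno_example.png")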
15 |
16 | import math
17 | from typing import Literal
18 |
19 | import torch
20 | from einops import rearrange, repeat
21 | from torch import Tensor
22 | from tqdm import tqdm
23 |
24 | from .model import Flux
25 | from .modules.conditioner import HFEmbedder
26 |
27 |
28 | def get_noise(
29 |     num_samples: int,
30 |     height: int,
31 |     width: int,
32 |     device: torch.device,
33 |     dtype: torch.dtype,
34 |     seed: int,
35 | ):
36 |     return torch.randn(
37 |         num_samples,
38 |         16,
39 |         # allow for packing
40 |         2 * math.ceil(height / 16),
41 |         2 * math.ceil(width / 16),
42 |         device=device,
43 |         dtype=dtype,
44 |         generator=torch.Generator(device=device).manual_seed(seed),
45 |     )
46 |
47 |
48 | def prepare(
49 |     t5: HFEmbedder,
50 |     clip: HFEmbedder,
51 |     img: Tensor,
52 |     prompt: str | list[str],
53 |     ref_img: None | Tensor=None,
54 |     pe: Literal['d', 'h', 'w', 'o'] ='d'
55 | ) -> dict[str, Tensor]:
56 |     assert pe in ['d', 'h', 'w', 'o']
57 |     bs, c, h, w = img.shape
58 |     if bs == 1 and not isinstance(prompt, str):
59 |         bs = len(prompt)
60 |
61 |     img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
62 |     if img.shape[0] == 1 and bs > 1:
63 |         img = repeat(img, "1 ... -> bs ...", bs=bs)
64 |
65 |     img_ids = torch.zeros(h // 2, w // 2, 3)
66 |     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
67 |     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
68 |     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
69 |
70 |     if ref_img is not None:
71 |         _, _, ref_h, ref_w = ref_img.shape
72 |         ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
73 |         if ref_img.shape[0] == 1 and bs > 1:
74 |             ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
75 |         ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
76 |         # offset the ref img ids along height/width by the maximum ids of the target image
77 |         h_offset = h // 2 if pe in {'d', 'h'} else 0
78 |         w_offset = w // 2 if pe in {'d', 'w'} else 0
79 |         ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None] + h_offset
80 |         ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :] + w_offset
81 |         ref_img_ids = repeat(ref_img_ids, "h w c -> b (h w) c", b=bs)
82 |
83 |     if isinstance(prompt, str):
84 |         prompt = [prompt]
85 |     txt = t5(prompt)
86 |     if txt.shape[0] == 1 and bs > 1:
87 |         txt = repeat(txt, "1 ... -> bs ...", bs=bs)
88 |     txt_ids = torch.zeros(bs, txt.shape[1], 3)
89 |
90 |     vec = clip(prompt)
91 |     if vec.shape[0] == 1 and bs > 1:
92 |         vec = repeat(vec, "1 ... -> bs ...", bs=bs)
93 |
94 |     if ref_img is not None:
95 |         return {
96 |             "img": img,
97 |             "img_ids": img_ids.to(img.device),
98 |             "ref_img": ref_img,
99 |             "ref_img_ids": ref_img_ids.to(img.device),
100 |             "txt": txt.to(img.device),
101 |             "txt_ids": txt_ids.to(img.device),
102 |             "vec": vec.to(img.device),
103 |         }
104 |     else:
105 |         return {
106 |             "img": img,
107 |             "img_ids": img_ids.to(img.device),
108 |             "txt": txt.to(img.device),
109 |             "txt_ids": txt_ids.to(img.device),
110 |             "vec": vec.to(img.device),
111 |         }
112 |
113 | def prepare_multi_ip(
114 |     t5: HFEmbedder,
115 |     clip: HFEmbedder,
116 |     img: Tensor,
117 |     prompt: str | list[str],
118 |     ref_imgs: list[Tensor] | None = None,
119 |     pe: Literal['d', 'h', 'w', 'o'] = 'd'
120 | ) -> dict[str, Tensor]:
121 |     assert pe in ['d', 'h', 'w', 'o']
122 |     bs, c, h, w = img.shape
123 |     if bs == 1 and not isinstance(prompt, str):
124 |         bs = len(prompt)
125 |
126 |     img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
127 |     if img.shape[0] == 1 and bs > 1:
128 |         img = repeat(img, "1 ... -> bs ...", bs=bs)
129 |
130 |     img_ids = torch.zeros(h // 2, w // 2, 3)
131 |     img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
132 |     img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
133 |     img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
134 |
135 |     ref_img_ids = []
136 |     ref_imgs_list = []
137 |     pe_shift_w, pe_shift_h = w // 2, h // 2
138 |     for ref_img in ref_imgs:
139 |         _, _, ref_h1, ref_w1 = ref_img.shape
140 |         ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
141 |         if ref_img.shape[0] == 1 and bs > 1:
142 |             ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
143 |         ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
144 |         # offset each reference image's ids along height/width by the current running maxima
145 |         h_offset = pe_shift_h if pe in {'d', 'h'} else 0
146 |         w_offset = pe_shift_w if pe in {'d', 'w'} else 0
147 |         ref_img_ids1[..., 1] = ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
148 |         ref_img_ids1[..., 2] = ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
149 |         ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
150 |         ref_img_ids.append(ref_img_ids1)
151 |         ref_imgs_list.append(ref_img)
152 |
153 |         # update the pe shift for the next reference image
154 |         pe_shift_h += ref_h1 // 2
155 |         pe_shift_w += ref_w1 // 2
156 |
157 |     if isinstance(prompt, str):
158 |         prompt = [prompt]
159 |     txt = t5(prompt)
160 |     if txt.shape[0] == 1 and bs > 1:
161 |         txt = repeat(txt, "1 ... -> bs ...", bs=bs)
162 |     txt_ids = torch.zeros(bs, txt.shape[1], 3)
163 |
164 |     vec = clip(prompt)
165 |     if vec.shape[0] == 1 and bs > 1:
166 |         vec = repeat(vec, "1 ... -> bs ...", bs=bs)
167 |
168 |     return {
169 |         "img": img,
170 |         "img_ids": img_ids.to(img.device),
171 |         "ref_img": tuple(ref_imgs_list),
172 |         "ref_img_ids": [ref_img_id.to(img.device) for ref_img_id in ref_img_ids],
173 |         "txt": txt.to(img.device),
174 |         "txt_ids": txt_ids.to(img.device),
175 |         "vec": vec.to(img.device),
176 |     }
177 |
178 |
179 | def time_shift(mu: float, sigma: float, t: Tensor):
180 |     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
181 |
182 |
183 | def get_lin_function(
184 |     x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
185 | ):
186 |     m = (y2 - y1) / (x2 - x1)
187 |     b = y1 - m * x1
188 |     return lambda x: m * x + b
189 |
190 |
191 | def get_schedule(
192 |     num_steps: int,
193 |     image_seq_len: int,
194 |     base_shift: float = 0.5,
195 |     max_shift: float = 1.15,
196 |     shift: bool = True,
197 | ) -> list[float]:
198 |     # extra step for zero
199 |     timesteps = torch.linspace(1, 0, num_steps + 1)
200 |
201 |     # shifting the schedule to favor high timesteps for higher signal images
202 |     if shift:
203 |         # estimate mu based on linear interpolation between two points
204 |         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
205 |         timesteps = time_shift(mu, 1.0, timesteps)
206 |
207 |     return timesteps.tolist()
208 |
209 |
210 | def denoise(
211 |     model: Flux,
212 |     # model input
213 |     img: Tensor,
214 |     img_ids: Tensor,
215 |     txt: Tensor,
216 |     txt_ids: Tensor,
217 |     vec: Tensor,
218 |     # sampling parameters
219 |     timesteps: list[float],
220 |     guidance: float = 4.0,
221 |     ref_img: Tensor=None,
222 |     ref_img_ids: Tensor=None,
223 | ):
224 |     i = 0
225 |     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
226 |     for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1):
227 |         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
228 |         pred = model(
229 |             img=img,
230 |
img_ids=img_ids, 231 | ref_img=ref_img, 232 | ref_img_ids=ref_img_ids, 233 | txt=txt, 234 | txt_ids=txt_ids, 235 | y=vec, 236 | timesteps=t_vec, 237 | guidance=guidance_vec 238 | ) 239 | img = img + (t_prev - t_curr) * pred 240 | i += 1 241 | return img 242 | 243 | 244 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 245 | return rearrange( 246 | x, 247 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 248 | h=math.ceil(height / 16), 249 | w=math.ceil(width / 16), 250 | ph=2, 251 | pw=2, 252 | ) 253 | -------------------------------------------------------------------------------- /uno/flux/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | import json 21 | import numpy as np 22 | from huggingface_hub import hf_hub_download 23 | from safetensors import safe_open 24 | from safetensors.torch import load_file as load_sft 25 | 26 | from .model import Flux, FluxParams 27 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams 28 | from .modules.conditioner import HFEmbedder 29 | 30 | import re 31 | from uno.flux.modules.layers import DoubleStreamBlockLoraProcessor, SingleStreamBlockLoraProcessor 32 | def load_model(ckpt, device='cpu'): 33 | if ckpt.endswith('safetensors'): 34 | from safetensors import safe_open 35 | pl_sd = {} 36 | with safe_open(ckpt, framework="pt", device=device) as f: 37 | for k in f.keys(): 38 | pl_sd[k] = f.get_tensor(k) 39 | else: 40 | pl_sd = torch.load(ckpt, map_location=device) 41 | return pl_sd 42 | 43 | def load_safetensors(path): 44 | tensors = {} 45 | with safe_open(path, framework="pt", device="cpu") as f: 46 | for key in f.keys(): 47 | tensors[key] = f.get_tensor(key) 48 | return tensors 49 | 50 | def get_lora_rank(checkpoint): 51 | for k in checkpoint.keys(): 52 | if k.endswith(".down.weight"): 53 | return checkpoint[k].shape[0] 54 | 55 | def load_checkpoint(local_path, repo_id, name): 56 | if local_path is not None: 57 | if '.safetensors' in local_path: 58 | print(f"Loading .safetensors checkpoint from {local_path}") 59 | checkpoint = load_safetensors(local_path) 60 | else: 61 | print(f"Loading checkpoint from {local_path}") 62 | checkpoint = torch.load(local_path, map_location='cpu') 63 | elif repo_id is not None and name is not None: 64 | print(f"Loading checkpoint {name} from repo id {repo_id}") 65 | checkpoint = load_from_repo_id(repo_id, name) 66 | else: 67 | raise ValueError( 68 | "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" 69 | ) 70 | return checkpoint 71 | 72 | 73 | def c_crop(image): 74 | width, height = image.size 75 | new_size = min(width, height) 76 | left = (width - new_size) / 2 77 | top = (height - new_size) / 2 78 | right 
= (width + new_size) / 2 79 | bottom = (height + new_size) / 2 80 | return image.crop((left, top, right, bottom)) 81 | 82 | def pad64(x): 83 | return int(np.ceil(float(x) / 64.0) * 64 - x) 84 | 85 | def HWC3(x): 86 | assert x.dtype == np.uint8 87 | if x.ndim == 2: 88 | x = x[:, :, None] 89 | assert x.ndim == 3 90 | H, W, C = x.shape 91 | assert C == 1 or C == 3 or C == 4 92 | if C == 3: 93 | return x 94 | if C == 1: 95 | return np.concatenate([x, x, x], axis=2) 96 | if C == 4: 97 | color = x[:, :, 0:3].astype(np.float32) 98 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 99 | y = color * alpha + 255.0 * (1.0 - alpha) 100 | y = y.clip(0, 255).astype(np.uint8) 101 | return y 102 | 103 | @dataclass 104 | class ModelSpec: 105 | params: FluxParams 106 | ae_params: AutoEncoderParams 107 | ckpt_path: str | None 108 | ae_path: str | None 109 | repo_id: str | None 110 | repo_flow: str | None 111 | repo_ae: str | None 112 | repo_id_ae: str | None 113 | 114 | 115 | configs = { 116 | "flux-dev": ModelSpec( 117 | repo_id="black-forest-labs/FLUX.1-dev", 118 | repo_id_ae="black-forest-labs/FLUX.1-dev", 119 | repo_flow="flux1-dev.safetensors", 120 | repo_ae="ae.safetensors", 121 | ckpt_path=os.getenv("FLUX_DEV"), 122 | params=FluxParams( 123 | in_channels=64, 124 | vec_in_dim=768, 125 | context_in_dim=4096, 126 | hidden_size=3072, 127 | mlp_ratio=4.0, 128 | num_heads=24, 129 | depth=19, 130 | depth_single_blocks=38, 131 | axes_dim=[16, 56, 56], 132 | theta=10_000, 133 | qkv_bias=True, 134 | guidance_embed=True, 135 | ), 136 | ae_path=os.getenv("AE"), 137 | ae_params=AutoEncoderParams( 138 | resolution=256, 139 | in_channels=3, 140 | ch=128, 141 | out_ch=3, 142 | ch_mult=[1, 2, 4, 4], 143 | num_res_blocks=2, 144 | z_channels=16, 145 | scale_factor=0.3611, 146 | shift_factor=0.1159, 147 | ), 148 | ), 149 | "flux-dev-fp8": ModelSpec( 150 | repo_id="black-forest-labs/FLUX.1-dev", 151 | repo_id_ae="black-forest-labs/FLUX.1-dev", 152 | repo_flow="flux1-dev.safetensors", 153 | repo_ae="ae.safetensors", 154 | ckpt_path=os.getenv("FLUX_DEV_FP8"), 155 | params=FluxParams( 156 | in_channels=64, 157 | vec_in_dim=768, 158 | context_in_dim=4096, 159 | hidden_size=3072, 160 | mlp_ratio=4.0, 161 | num_heads=24, 162 | depth=19, 163 | depth_single_blocks=38, 164 | axes_dim=[16, 56, 56], 165 | theta=10_000, 166 | qkv_bias=True, 167 | guidance_embed=True, 168 | ), 169 | ae_path=os.getenv("AE"), 170 | ae_params=AutoEncoderParams( 171 | resolution=256, 172 | in_channels=3, 173 | ch=128, 174 | out_ch=3, 175 | ch_mult=[1, 2, 4, 4], 176 | num_res_blocks=2, 177 | z_channels=16, 178 | scale_factor=0.3611, 179 | shift_factor=0.1159, 180 | ), 181 | ), 182 | "flux-schnell": ModelSpec( 183 | repo_id="black-forest-labs/FLUX.1-schnell", 184 | repo_id_ae="black-forest-labs/FLUX.1-dev", 185 | repo_flow="flux1-schnell.safetensors", 186 | repo_ae="ae.safetensors", 187 | ckpt_path=os.getenv("FLUX_SCHNELL"), 188 | params=FluxParams( 189 | in_channels=64, 190 | vec_in_dim=768, 191 | context_in_dim=4096, 192 | hidden_size=3072, 193 | mlp_ratio=4.0, 194 | num_heads=24, 195 | depth=19, 196 | depth_single_blocks=38, 197 | axes_dim=[16, 56, 56], 198 | theta=10_000, 199 | qkv_bias=True, 200 | guidance_embed=False, 201 | ), 202 | ae_path=os.getenv("AE"), 203 | ae_params=AutoEncoderParams( 204 | resolution=256, 205 | in_channels=3, 206 | ch=128, 207 | out_ch=3, 208 | ch_mult=[1, 2, 4, 4], 209 | num_res_blocks=2, 210 | z_channels=16, 211 | scale_factor=0.3611, 212 | shift_factor=0.1159, 213 | ), 214 | ), 215 | } 216 | 217 | 218 | def 
print_load_warning(missing: list[str], unexpected: list[str]) -> None: 219 | if len(missing) > 0 and len(unexpected) > 0: 220 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 221 | print("\n" + "-" * 79 + "\n") 222 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 223 | elif len(missing) > 0: 224 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 225 | elif len(unexpected) > 0: 226 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 227 | 228 | def load_from_repo_id(repo_id, checkpoint_name): 229 | ckpt_path = hf_hub_download(repo_id, checkpoint_name) 230 | sd = load_sft(ckpt_path, device='cpu') 231 | return sd 232 | 233 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 234 | # Loading Flux 235 | print("Init model") 236 | ckpt_path = configs[name].ckpt_path 237 | if ( 238 | ckpt_path is None 239 | and configs[name].repo_id is not None 240 | and configs[name].repo_flow is not None 241 | and hf_download 242 | ): 243 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 244 | 245 | with torch.device("meta" if ckpt_path is not None else device): 246 | model = Flux(configs[name].params).to(torch.bfloat16) 247 | 248 | if ckpt_path is not None: 249 | print("Loading checkpoint") 250 | # load_sft doesn't support torch.device 251 | sd = load_model(ckpt_path, device=str(device)) 252 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 253 | print_load_warning(missing, unexpected) 254 | return model 255 | 256 | def load_flow_model_only_lora( 257 | name: str, 258 | device: str | torch.device = "cuda", 259 | hf_download: bool = True, 260 | lora_rank: int = 16, 261 | use_fp8: bool = False 262 | ): 263 | # Loading Flux 264 | print("Init model") 265 | ckpt_path = configs[name].ckpt_path 266 | if ( 267 | ckpt_path is None 268 | and configs[name].repo_id is not None 269 | and configs[name].repo_flow is not None 270 | and hf_download 271 | ): 272 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) 273 | 274 | if hf_download: 275 | try: 276 | lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors") 277 | except: 278 | lora_ckpt_path = os.environ.get("LORA", None) 279 | else: 280 | lora_ckpt_path = os.environ.get("LORA", None) 281 | 282 | with torch.device("meta" if ckpt_path is not None else device): 283 | model = Flux(configs[name].params) 284 | 285 | 286 | model = set_lora(model, lora_rank, device="meta" if lora_ckpt_path is not None else device) 287 | 288 | if ckpt_path is not None: 289 | print("Loading lora") 290 | lora_sd = load_sft(lora_ckpt_path, device=str(device)) if lora_ckpt_path.endswith("safetensors")\ 291 | else torch.load(lora_ckpt_path, map_location='cpu') 292 | 293 | print("Loading main checkpoint") 294 | # load_sft doesn't support torch.device 295 | 296 | if ckpt_path.endswith('safetensors'): 297 | if use_fp8: 298 | print( 299 | "####\n" 300 | "We are in fp8 mode right now, since the fp8 checkpoint of XLabs-AI/flux-dev-fp8 seems broken\n" 301 | "we convert the fp8 checkpoint on flight from bf16 checkpoint\n" 302 | "If your storage is constrained" 303 | "you can save the fp8 checkpoint and replace the bf16 checkpoint by yourself\n" 304 | ) 305 | sd = load_sft(ckpt_path, device="cpu") 306 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} 307 | else: 308 | sd = load_sft(ckpt_path, 
device=str(device)) 309 | 310 | sd.update(lora_sd) 311 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 312 | else: 313 | dit_state = torch.load(ckpt_path, map_location='cpu') 314 | sd = {} 315 | for k in dit_state.keys(): 316 | sd[k.replace('module.','')] = dit_state[k] 317 | sd.update(lora_sd) 318 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 319 | model.to(str(device)) 320 | print_load_warning(missing, unexpected) 321 | return model 322 | 323 | 324 | def set_lora( 325 | model: Flux, 326 | lora_rank: int, 327 | double_blocks_indices: list[int] | None = None, 328 | single_blocks_indices: list[int] | None = None, 329 | device: str | torch.device = "cpu", 330 | ) -> Flux: 331 | double_blocks_indices = list(range(model.params.depth)) if double_blocks_indices is None else double_blocks_indices 332 | single_blocks_indices = list(range(model.params.depth_single_blocks)) if single_blocks_indices is None \ 333 | else single_blocks_indices 334 | 335 | lora_attn_procs = {} 336 | with torch.device(device): 337 | for name, attn_processor in model.attn_processors.items(): 338 | match = re.search(r'\.(\d+)\.', name) 339 | if match: 340 | layer_index = int(match.group(1)) 341 | 342 | if name.startswith("double_blocks") and layer_index in double_blocks_indices: 343 | lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) 344 | elif name.startswith("single_blocks") and layer_index in single_blocks_indices: 345 | lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=model.params.hidden_size, rank=lora_rank) 346 | else: 347 | lora_attn_procs[name] = attn_processor 348 | model.set_attn_processor(lora_attn_procs) 349 | return model 350 | 351 | 352 | def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 353 | # Loading Flux 354 | from optimum.quanto import requantize 355 | print("Init model") 356 | ckpt_path = configs[name].ckpt_path 357 | if ( 358 | ckpt_path is None 359 | and configs[name].repo_id is not None 360 | and configs[name].repo_flow is not None 361 | and hf_download 362 | ): 363 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 364 | # json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') 365 | 366 | 367 | model = Flux(configs[name].params).to(torch.bfloat16) 368 | 369 | print("Loading checkpoint") 370 | # load_sft doesn't support torch.device 371 | sd = load_sft(ckpt_path, device='cpu') 372 | sd = {k: v.to(dtype=torch.float8_e4m3fn, device=device) for k, v in sd.items()} 373 | model.load_state_dict(sd, assign=True) 374 | return model 375 | with open(json_path, "r") as f: 376 | quantization_map = json.load(f) 377 | print("Start a quantization process...") 378 | requantize(model, sd, quantization_map, device=device) 379 | print("Model is quantized!") 380 | return model 381 | 382 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 383 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 384 | version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders") 385 | return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device) 386 | 387 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 388 | version = os.environ.get("CLIP", "openai/clip-vit-large-patch14") 389 | return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device) 390 | 391 | 392 | def load_ae(name: 
str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: 393 | ckpt_path = configs[name].ae_path 394 | if ( 395 | ckpt_path is None 396 | and configs[name].repo_id is not None 397 | and configs[name].repo_ae is not None 398 | and hf_download 399 | ): 400 | ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) 401 | 402 | # Loading the autoencoder 403 | print("Init AE") 404 | with torch.device("meta" if ckpt_path is not None else device): 405 | ae = AutoEncoder(configs[name].ae_params) 406 | 407 | if ckpt_path is not None: 408 | sd = load_sft(ckpt_path, device=str(device)) 409 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 410 | print_load_warning(missing, unexpected) 411 | return ae -------------------------------------------------------------------------------- /uno/utils/convert_yaml_to_args_file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import yaml 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--yaml", type=str, required=True) 20 | parser.add_argument("--arg", type=str, required=True) 21 | args = parser.parse_args() 22 | 23 | 24 | with open(args.yaml, "r") as f: 25 | data = yaml.safe_load(f) 26 | 27 | with open(args.arg, "w") as f: 28 | for k, v in data.items(): 29 | if isinstance(v, list): 30 | v = list(map(str, v)) 31 | v = " ".join(v) 32 | if v is None: 33 | continue 34 | print(f"--{k} {v}", end=" ", file=f) 35 | --------------------------------------------------------------------------------
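To make the converter above concrete, here is a small round-trip example; it is a sketch run from the repository root, and the YAML keys and values are illustrative assumptions rather than a real training config from this project. Each top-level key becomes a "--key value" pair on one line, list values are joined with spaces, and keys whose value is null are skipped.

import subprocess
import tempfile

import yaml

# Hypothetical config; only the shape matters for the illustration.
cfg = {"learning_rate": 1e-4, "lora_rank": 16, "ref_size": [320, 320], "resume_ckpt": None}

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    yaml.safe_dump(cfg, f)
    yaml_path = f.name
args_path = yaml_path.replace(".yaml", ".args")

# Invoke the converter exactly as the training scripts would.
subprocess.run(
    ["python", "uno/utils/convert_yaml_to_args_file.py", "--yaml", yaml_path, "--arg", args_path],
    check=True,
)

with open(args_path) as f:
    print(f.read())
# prints something like: --learning_rate 0.0001 --lora_rank 16 --ref_size 320 320
# (resume_ckpt is dropped because its value is None)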