├── .gitignore ├── LICENSE ├── README.md ├── api.py ├── configs ├── config-dev-1-RTX6000ADA.json ├── config-dev-cuda0.json ├── config-dev-eval.json ├── config-dev-gigaquant.json ├── config-dev-offload-1-4080.json ├── config-dev-offload-1-4090.json ├── config-dev-offload.json ├── config-dev-prequant.json ├── config-dev.json ├── config-schnell-cuda0.json └── config-schnell.json ├── float8_quantize.py ├── flux_emphasis.py ├── flux_pipeline.py ├── image_encoder.py ├── lora_loading.py ├── main.py ├── main_gr.py ├── modules ├── autoencoder.py ├── conditioner.py └── flux_model.py ├── requirements.txt └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.jpg 3 | *.png 4 | *.jpeg 5 | *.gif 6 | *.bmp 7 | *.webp 8 | *.mp4 9 | *.mp3 10 | *.mp3 11 | *.txt 12 | .copilotignore 13 | .misc 14 | BFL-flux-diffusers 15 | .env 16 | .env.* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Alex Redden 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flux FP8 (true) Matmul Implementation with FastAPI 2 | 3 | This repository contains an implementation of the Flux model, along with an API that allows you to generate images based on text prompts. And also a simple single line of code to use the generator as a single object, similar to diffusers pipelines. 
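For example, a minimal sketch of that single-object usage (the config file and prompt here are just placeholders; the full pipeline example appears further below):

```python
from flux_pipeline import FluxPipeline

# Load the flow transformer, text encoders and autoencoder from one config file
pipe = FluxPipeline.load_pipeline_from_config_path("configs/config-dev-offload-1-4090.json")

# Returns an io.BytesIO containing the JPEG bytes
jpeg_bytes = pipe.generate(prompt="a misty forest at sunrise", width=1024, height=1024)
```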
4 | 5 | ## Speed Comparison 6 | 7 | Note: 8 | 9 | - The "bfl codebase" refers to the original [BFL codebase](https://github.com/black-forest-labs/flux), not this repo. 10 | - The "fp8 wo quant" refers to the original BFL codebase using fp8 weight only quantization, not using fp8 matmul which is default in this repo. 11 | - The "compile blocks & extras" refers to the option within this repo setting the config values `"compile_blocks" true` & `"compile_extras": true`. ❌ means both were set to false, ✅ means both were set to true. 12 | - All generations which including a ❌ or ✅ are using this repo. 13 | 14 | | Resolution | Device | Test | Average it/s | 15 | | ---------- | ---------- | -------------------------- | ------------ | 16 | | 1024x1024 | RTX4090 | bfl codebase fp8 wo quant | 1.7 | 17 | | 1024x1024 | RTX4090 | ❌ compile blocks & extras | 2.55 | 18 | | 1024x1024 | RTX4090 | ✅ compile blocks & extras | 3.51 | 19 | | 1024x1024 | RTX4000ADA | ❌ compile blocks & extras | 0.79 | 20 | | 1024x1024 | RTX4000ADA | ✅ compile blocks & extras | 1.26 | 21 | | 1024x1024 | RTX6000ADA | bfl codebase | 1.74 | 22 | | 1024x1024 | RTX6000ADA | ❌ compile blocks & extras | 2.08 | 23 | | 1024x1024 | RTX6000ADA | ✅ compile blocks & extras | 2.8 | 24 | | 1024x1024 | H100 | ❌ compile blocks & extras | 6.1 | 25 | | 1024x1024 | H100 | ✅ compile blocks & extras | 11.5 | 26 | | 768x768 | RTX4090 | bfl codebase fp8 wo quant | 2.32 | 27 | | 768x768 | RTX4090 | ❌ compile blocks & extras | 4.47 | 28 | | 768x768 | RTX4090 | ✅ compile blocks & extras | 6.2 | 29 | | 768x768 | RTX4000 | ❌ compile blocks & extras | 1.41 | 30 | | 768x768 | RTX4000 | ✅ compile blocks & extras | 2.19 | 31 | | 768x768 | RTX6000ADA | bfl codebase | 3.01 | 32 | | 768x768 | RTX6000ADA | ❌ compile blocks & extras | 3.43 | 33 | | 768x768 | RTX6000ADA | ✅ compile blocks & extras | 4.46 | 34 | | 768x768 | H100 | ❌ compile blocks & extras | 10.3 | 35 | | 768x768 | H100 | ✅ compile blocks & extras | 20.8 | 36 | | 1024x720 | RTX4090 | bfl codebase fp8 wo quant | 3.01 | 37 | | 1024x720 | RTX4090 | ❌ compile blocks & extras | 3.6 | 38 | | 1024x720 | RTX4090 | ✅ compile blocks & extras | 4.96 | 39 | | 1024x720 | RTX4000 | ❌ compile blocks & extras | 1.14 | 40 | | 1024x720 | RTX4000 | ✅ compile blocks & extras | 1.78 | 41 | | 1024x720 | RTX6000ADA | bfl codebase | 2.37 | 42 | | 1024x720 | RTX6000ADA | ❌ compile blocks & extras | 2.87 | 43 | | 1024x720 | RTX6000ADA | ✅ compile blocks & extras | 3.78 | 44 | | 1024x720 | H100 | ❌ compile blocks & extras | 8.2 | 45 | | 1024x720 | H100 | ✅ compile blocks & extras | 15.7 | 46 | 47 | ## Table of Contents 48 | 49 | - [Installation](#installation) 50 | - [Usage](#usage) 51 | - [Configuration](#configuration) 52 | - [API Endpoints](#api-endpoints) 53 | - [Examples](#examples) 54 | - [License](https://github.com/aredden/flux-fp8-api/blob/main/LICENSE) 55 | 56 | ### Updates 08/24/24 57 | 58 | - Add config options for levels of quantization for the flow transformer: 59 | - `quantize_modulation`: Quantize the modulation layers in the flow model. If false, adds ~2GB vram usage for moderate precision improvements `(default: true)` 60 | - `quantize_flow_embedder_layers`: Quantize the flow embedder layers in the flow model. If false, adds ~512MB vram usage, but precision improves considerably. `(default: false)` 61 | - Override default config values when loading FluxPipeline, e.g. 
`FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)` 62 | 63 | #### Fixes 64 | 65 | - Fixed a bug where loading the text encoder from HF with bitsandbytes would error if the device was not set to cuda:0 66 | 67 | **note:** prequantized flow models will only work with the same quantization levels that were used when they were created. e.g. if you create a prequantized flow model with `quantize_modulation` set to false, it will only work with `quantize_modulation` set to false, same with `quantize_flow_embedder_layers`. 68 | 69 | ### Updates 08/25/24 70 | 71 | - Added LoRA loading functionality to FluxPipeline. Simple example: 72 | 73 | ```python 74 | from flux_pipeline import FluxPipeline 75 | 76 | config_path = "path/to/config/file.json" 77 | config_overrides = { 78 | #... 79 | } 80 | 81 | lora_path = "path/to/lora/file.safetensors" 82 | 83 | pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides) 84 | 85 | pipeline.load_lora(lora_path, scale=1.0) 86 | ``` 87 | 88 | ### Updates 09/07/24 89 | 90 | - Improve quality by ensuring that the RMSNorm layers use fp32 91 | - Raise the clamp range for single blocks & double blocks to +/-32000 to reduce deviation from expected outputs. 92 | - Make BF16 _not_ clamp, which improves quality and isn't needed because bf16 is the expected dtype for flux. **I would now recommend always using `"flow_dtype": "bfloat16"` in the config**, though it will slow things down on consumer gpus - but not by much, since most of the compute still happens via fp8. 93 | - Allow the T5 model to be run without any quantization by specifying `"text_enc_quantization_dtype": "bfloat16"` in the config - or `"float16"`, though that is not recommended since t5 deviates a bit when running in float16. I noticed that even with qint8/qfloat8 there is a bit of deviation from bf16 text encoder outputs - so for those who want more accurate / expected text encoder outputs, you can use this option. 94 | 95 | ### Updates 10/3/24 96 | 97 | - #### Adding configurable clip model path 98 | Now you can specify the clip model's path in the config, using the `clip_path` parameter in a config file. 99 | - #### Improved lora loading 100 | I believe I have fixed the lora loading bug that caused loras not to apply properly, as well as the case where not all of the q/k/v/o linear weights had loras attached (previously the lora could not be applied if only some of them did). 101 | - #### Lora loading via api endpoint 102 | 103 | You can now post to the `/lora` endpoint with a json body containing `scale`, `path`, `name`, and `action` parameters. 104 | 105 | The `path` should be the path to the lora safetensors file, either absolute or relative to the root of this repo. 106 | 107 | The `name` is an optional parameter, mainly for checking purposes to verify that the correct lora was loaded. It is used as an identifier to check whether a lora has already been loaded, and which lora to unload when `action` is `unload` (you can also pass the exact same path that was loaded previously to unload that lora). 108 | 109 | The `action` should be either `load` or `unload`, to load or unload the lora. 110 | 111 | The `scale` should be a float, which is the scale applied to the lora weights. 112 | 113 | e.g.
114 | 115 | ```json 116 | { 117 | 118 | "path": "./fluxloras/loras/aidmaImageUpgrader-FLUX-V0.2.safetensors", 119 | 120 | "name": "imgupgrade", 121 | 122 | "action": "load", 123 | 124 | "scale": 0.6 125 | } 126 | ``` 127 | 128 | ## Installation 129 | 130 | This repo _requires_ at least pytorch 2.4 built with cuda 12.4 and an ADA (or newer) gpu with fp8 support, otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba: 131 | 132 | ```bash 133 | mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia 134 | mamba activate flux-fp8-matmul-api 135 | 136 | # or with conda 137 | conda create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia 138 | conda activate flux-fp8-matmul-api 139 | 140 | # or with nightly... (which is what I am using) - also, just switch 'mamba' to 'conda' if you are using conda 141 | mamba create -n flux-fp8-matmul-api python=3.11 pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch-nightly -c nvidia 142 | mamba activate flux-fp8-matmul-api 143 | 144 | # or with pip 145 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 146 | # or pip nightly 147 | python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 148 | ``` 149 | 150 | To install the required dependencies, run: 151 | 152 | ```bash 153 | python -m pip install -r requirements.txt 154 | ``` 155 | 156 | If you get errors installing `torch-cublas-hgemm`, feel free to comment it out in requirements.txt; it is not necessary, but it will speed up inference for non-fp8 linear layers. 157 | 158 | ## Usage 159 | 160 | For a single ADA GPU with more than 16GB and less than 24GB vram, you should use the `configs/config-dev-offload-1-4080.json` config file as a base, and then tweak the parameters to fit your needs. It offloads all models to CPU when not in use, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to qfloat8. 161 | 162 | For a single ADA GPU with more than ~32GB vram, you should use the `configs/config-dev-1-RTX6000ADA.json` config file as a base, and then tweak the parameters to fit your needs. It does not offload any models to CPU, compiles the flow model with extra optimizations, and quantizes the text encoder to qfloat8, while the autoencoder stays as bfloat16. 163 | 164 | For a single 4090 GPU, you should use the `configs/config-dev-offload-1-4090.json` config file as a base, and then tweak the parameters to fit your needs. It offloads the text encoder and the autoencoder to CPU, compiles the flow model with extra optimizations, and quantizes the text encoder to nf4 and the autoencoder to qfloat8. 165 | 166 | **NOTE:** For all of these configs, you must change the `ckpt_path`, `ae_path`, and `text_enc_path` parameters to the paths of your own checkpoint, autoencoder, and text encoder. 167 | 168 | You can run the API server using the following command: 169 | 170 | ```bash 171 | python main.py --config-path <path_to_config> --port <port> --host <host> 172 | ``` 173 | 174 | ### API Command-Line Arguments 175 | 176 | - `--config-path`: Path to the configuration file. If not provided, the model will be loaded from the command line arguments. 177 | - `--port`: Port to run the server on (default: 8088). 178 | - `--host`: Host to run the server on (default: 0.0.0.0). 179 | - `--flow-model-path`: Path to the flow model.
180 | - `--text-enc-path`: Path to the text encoder. 181 | - `--autoencoder-path`: Path to the autoencoder. 182 | - `--model-version`: Choose model version (`flux-dev` or `flux-schnell`). 183 | - `--flux-device`: Device to run the flow model on (default: cuda:0). 184 | - `--text-enc-device`: Device to run the text encoder on (default: cuda:0). 185 | - `--autoencoder-device`: Device to run the autoencoder on (default: cuda:0). 186 | - `--compile`: Compile the flow model with extra optimizations (default: False). 187 | - `--quant-text-enc`: Quantize the T5 text encoder to the given dtype (`qint4`, `qfloat8`, `qint2`, `qint8`, `bf16`), if `bf16`, will not quantize (default: `qfloat8`). 188 | - `--quant-ae`: Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16 (default: False). 189 | - `--offload-flow`: Offload the flow model to the CPU when not being used to save memory (default: False). 190 | - `--no-offload-ae`: Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]). 191 | - `--no-offload-text-enc`: Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed (default: True [implies it will offload, setting this flag sets it to False]). 192 | - `--prequantized-flow`: Load the flow model from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: False). 193 | - `--no-quantize-flow-modulation`: Disable quantization of the modulation layers in the flow transformer, which improves precision _moderately_ but adds ~2GB vram usage. 194 | - `--quantize-flow-embedder-layers`: Quantize the flow embedder layers in the flow transformer, reduces precision _considerably_ but saves ~512MB vram usage. 195 | 196 | ## Configuration 197 | 198 | The configuration files are located in the `configs` directory. You can specify different configurations for different model versions and devices. 
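If you load the pipeline from Python rather than via `main.py`, individual values from a config file can also be overridden at load time via `FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)`, as mentioned in the updates above. A rough sketch, assuming the override keyword names mirror the config fields shown below:

```python
from flux_pipeline import FluxPipeline

# Hypothetical overrides -- the key names are assumed to match the config fields below
pipe = FluxPipeline.load_pipeline_from_config_path(
    "configs/config-dev-offload-1-4090.json",
    compile_blocks=False,   # skip compiling the single/double blocks for faster startup
    offload_flow=True,      # offload the flow transformer to CPU when not in use
)
```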
199 | 200 | Example configuration file for a single 4090 (`configs/config-dev-offload-1-4090.json`): 201 | 202 | ```js 203 | { 204 | "version": "flux-dev", // or flux-schnell 205 | "params": { 206 | "in_channels": 64, 207 | "vec_in_dim": 768, 208 | "context_in_dim": 4096, 209 | "hidden_size": 3072, 210 | "mlp_ratio": 4.0, 211 | "num_heads": 24, 212 | "depth": 19, 213 | "depth_single_blocks": 38, 214 | "axes_dim": [16, 56, 56], 215 | "theta": 10000, 216 | "qkv_bias": true, 217 | "guidance_embed": true // if you are using flux-schnell, set this to false 218 | }, 219 | "ae_params": { 220 | "resolution": 256, 221 | "in_channels": 3, 222 | "ch": 128, 223 | "out_ch": 3, 224 | "ch_mult": [1, 2, 4, 4], 225 | "num_res_blocks": 2, 226 | "z_channels": 16, 227 | "scale_factor": 0.3611, 228 | "shift_factor": 0.1159 229 | }, 230 | "ckpt_path": "/your/path/to/flux1-dev.sft", // local path to original bf16 BFL flux checkpoint 231 | "ae_path": "/your/path/to/ae.sft", // local path to original bf16 BFL autoencoder checkpoint 232 | "repo_id": "black-forest-labs/FLUX.1-dev", // can ignore 233 | "repo_flow": "flux1-dev.sft", // can ignore 234 | "repo_ae": "ae.sft", // can ignore 235 | "text_enc_max_length": 512, // use 256 if you are using flux-schnell 236 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", // or custom HF full bf16 T5EncoderModel repo id 237 | "text_enc_device": "cuda:0", 238 | "ae_device": "cuda:0", 239 | "flux_device": "cuda:0", 240 | "flow_dtype": "float16", 241 | "ae_dtype": "bfloat16", 242 | "text_enc_dtype": "bfloat16", 243 | "flow_quantization_dtype": "qfloat8", // will always be qfloat8, so can ignore 244 | "text_enc_quantization_dtype": "qint4", // choose between qint4, qint8, qfloat8, qint2 or delete entry for no quantization 245 | "ae_quantization_dtype": "qfloat8", // can either be qfloat8 or delete entry for no quantization 246 | "compile_extras": true, // compile the layers not included in the single-blocks or double-blocks 247 | "compile_blocks": true, // compile the single-blocks and double-blocks 248 | "offload_text_encoder": true, // offload the text encoder to cpu when not in use 249 | "offload_vae": true, // offload the autoencoder to cpu when not in use 250 | "offload_flow": false, // offload the flow transformer to cpu when not in use 251 | "prequantized_flow": false, // load the flow transformer from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: false) 252 | "quantize_modulation": true, // quantize the modulation layers in the flow transformer, which reduces precision moderately but saves ~2GB vram usage (default: true) 253 | "quantize_flow_embedder_layers": false, // quantize the flow embedder layers in the flow transformer, if false, improves precision considerably at the cost of adding ~512MB vram usage (default: false) 254 | } 255 | ``` 256 | 257 | The only things you should need to change in general are the: 258 | 259 | ```json5 260 | "ckpt_path": "/path/to/your/flux1-dev.sft", // path to your original BFL flow transformer (not diffusers) 261 | "ae_path": "/path/to/your/ae.sft", // path to your original BFL autoencoder (not diffusers) 262 | "text_enc_path": "path/to/your/t5-v1_1-xxl-encoder-bf16", // HF T5EncoderModel - can use "city96/t5-v1_1-xxl-encoder-bf16" for a simple to download version 263 | ``` 264 | 265 | Other things to change can be the 266 | 267 | - `"text_enc_max_length": 512` 268 | max length for the text encoder, 256 if you are using flux-schnell 269 | 270 | - 
`"ae_quantization_dtype": "qfloat8"` 271 | quantization dtype for the autoencoder, can be `qfloat8` or delete entry for no quantization, will use the float8 linear layer implementation included in this repo. 272 | 273 | - `"text_enc_quantization_dtype": "qfloat8"` 274 | quantization dtype for the text encoder, if `qfloat8` or `qint2` will use quanto, `qint4`, `qint8` will use bitsandbytes 275 | 276 | - `"compile_extras": true,` 277 | compiles all modules that are not the single-blocks or double-blocks (default: false) 278 | 279 | - `"compile_blocks": true,` 280 | compiles all single-blocks and double-blocks (default: false) 281 | 282 | - `"text_enc_offload": false,` 283 | offload text encoder to cpu (default: false) - set to true if you only have a single 4090 and no other GPUs, otherwise you can set this to false and reduce latency [NOTE: this will be slow, if you have multiple GPUs, change the text_enc_device to a different device so you can set offloading for text_enc to false] 284 | 285 | - `"ae_offload": false,` 286 | offload autoencoder to cpu (default: false) - set to true if you only have a single 4090 and no other GPUs, otherwise you can set this to false and reduce latency [NOTE: this will be slow, if you have multiple GPUs, change the ae_device to a different device so you can set offloading for ae to false] 287 | 288 | - `"flux_offload": false,` 289 | offload flow transformer to cpu (default: false) - set to true if you only have a single 4090 and no other GPUs, otherwise you can set this to false and reduce latency [NOTE: this will be slow, if you have multiple GPUs, change the flux_device to a different device so you can set offloading for flux to false] 290 | 291 | - `"flux_device": "cuda:0",` 292 | device for flow transformer (default: cuda:0) - this gpu must have fp8 support and at least 16GB of memory, does not need to be the same as text_enc_device or ae_device 293 | 294 | - `"text_enc_device": "cuda:0",` 295 | device for text encoder (default: cuda:0) - set this to a different device - e.g. `"cuda:1"` if you have multiple gpus so you can set offloading for text_enc to false, does not need to be the same as flux_device or ae_device 296 | 297 | - `"ae_device": "cuda:0",` 298 | device for autoencoder (default: cuda:0) - set this to a different device - e.g. `"cuda:1"` if you have multiple gpus so you can set offloading for ae to false, does not need to be the same as flux_device or text_enc_device 299 | 300 | - `"prequantized_flow": false,` 301 | load the flow transformer from a prequantized checkpoint, which reduces the size of the checkpoint by about 50% & reduces startup time (default: false) 302 | 303 | - Note: MUST be a prequantized checkpoint created with the same quantization settings as the current config, and must have been quantized using this repo. 304 | 305 | - `"quantize_modulation": true,` 306 | quantize the modulation layers in the flow transformer, which improves precision at the cost of adding ~2GB vram usage (default: true) 307 | 308 | - `"quantize_flow_embedder_layers": false,` 309 | quantize the flow embedder layers in the flow transformer, which improves precision considerably at the cost of adding ~512MB vram usage (default: false) 310 | 311 | ## API Endpoints 312 | 313 | ### Generate Image 314 | 315 | - **URL**: `/generate` 316 | - **Method**: `POST` 317 | - **Request Body**: 318 | 319 | - `prompt` (str): The text prompt for image generation. 320 | - `width` (int, optional): The width of the generated image (default: 720). 
321 | - `height` (int, optional): The height of the generated image (default: 1024). 322 | - `num_steps` (int, optional): The number of steps for the generation process (default: 24). 323 | - `guidance` (float, optional): The guidance scale for the generation process (default: 3.5). 324 | - `seed` (int, optional): The seed for random number generation. 325 | - `init_image` (str, optional): The base64 encoded image to be used as a reference for the generation process. 326 | - `strength` (float, optional): The strength of the diffusion process when an init image is provided (default: 1.0). 327 | 328 | - **Response**: A JPEG image stream. 329 | 330 | ## Examples 331 | 332 | ### Running the Server 333 | 334 | ```bash 335 | python main.py --config-path configs/config-dev-offload-1-4090.json --port 8088 --host 0.0.0.0 336 | ``` 337 | 338 | Or, if you need more granular control over all of the settings, you can run the server with something like this: 339 | 340 | ```bash 341 | python main.py --port 8088 --host 0.0.0.0 \ 342 | --flow-model-path /path/to/your/flux1-dev.sft \ 343 | --text-enc-path /path/to/your/t5-v1_1-xxl-encoder-bf16 \ 344 | --autoencoder-path /path/to/your/ae.sft \ 345 | --model-version flux-dev \ 346 | --flux-device cuda:0 \ 347 | --text-enc-device cuda:0 \ 348 | --autoencoder-device cuda:0 \ 349 | --compile \ 350 | --quant-text-enc qfloat8 \ 351 | --quant-ae 352 | ``` 353 | 354 | ### Generating an image on a client 355 | 356 | Send a POST request to `http://<host>:<port>/generate` with the following JSON body: 357 | 358 | ```json 359 | { 360 | "prompt": "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns", 361 | "width": 1024, 362 | "height": 1024, 363 | "num_steps": 24, 364 | "guidance": 3.0, 365 | "seed": 13456 366 | } 367 | ``` 368 | 369 | For an example of how to generate from a python client using the FastAPI server: 370 | 371 | ```py 372 | import requests 373 | import io 374 | 375 | prompt = "a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns" 376 | res = requests.post( 377 | "http://localhost:8088/generate", 378 | json={ 379 | "width": 1024, 380 | "height": 720, 381 | "num_steps": 20, 382 | "guidance": 4, 383 | "prompt": prompt, 384 | }, 385 | stream=True, 386 | ) 387 | 388 | with open("output.jpg", "wb") as f: 389 | f.write(io.BytesIO(res.content).read()) 390 | 391 | ``` 392 | 393 | You can also generate an image by directly importing the FluxPipeline class and using it to generate an image. This is useful if you have a custom model configuration and want to generate an image without having to run the server.
394 | 395 | ```py 396 | import io 397 | from flux_pipeline import FluxPipeline 398 | 399 | 400 | pipe = FluxPipeline.load_pipeline_from_config_path( 401 | "configs/config-dev-offload-1-4090.json" # or whatever your config is 402 | ) 403 | 404 | output_jpeg_bytes: io.BytesIO = pipe.generate( 405 | # Required args: 406 | prompt="A beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns", 407 | # Optional args: 408 | width=1024, 409 | height=1024, 410 | num_steps=20, 411 | guidance=3.5, 412 | seed=13456, 413 | init_image="path/to/your/init_image.jpg", 414 | strength=0.8, 415 | ) 416 | 417 | with open("output.jpg", "wb") as f: 418 | f.write(output_jpeg_bytes.getvalue()) 419 | 420 | ``` 421 | -------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, TYPE_CHECKING 2 | 3 | import numpy as np 4 | from fastapi import FastAPI 5 | from fastapi.responses import StreamingResponse, JSONResponse 6 | from pydantic import BaseModel, Field 7 | from platform import system 8 | 9 | if TYPE_CHECKING: 10 | from flux_pipeline import FluxPipeline 11 | 12 | if system() == "Windows": 13 | MAX_RAND = 2**16 - 1 14 | else: 15 | MAX_RAND = 2**32 - 1 16 | 17 | 18 | class AppState: 19 | model: "FluxPipeline" 20 | 21 | 22 | class FastAPIApp(FastAPI): 23 | state: AppState 24 | 25 | 26 | class LoraArgs(BaseModel): 27 | scale: Optional[float] = 1.0 28 | path: Optional[str] = None 29 | name: Optional[str] = None 30 | action: Optional[Literal["load", "unload"]] = "load" 31 | 32 | 33 | class LoraLoadResponse(BaseModel): 34 | status: Literal["success", "error"] 35 | message: Optional[str] = None 36 | 37 | 38 | class GenerateArgs(BaseModel): 39 | prompt: str 40 | width: Optional[int] = Field(default=720) 41 | height: Optional[int] = Field(default=1024) 42 | num_steps: Optional[int] = Field(default=24) 43 | guidance: Optional[float] = Field(default=3.5) 44 | seed: Optional[int] = Field( 45 | default_factory=lambda: np.random.randint(0, MAX_RAND), gt=0, lt=MAX_RAND 46 | ) 47 | strength: Optional[float] = 1.0 48 | init_image: Optional[str] = None 49 | 50 | 51 | app = FastAPIApp() 52 | 53 | 54 | @app.post("/generate") 55 | def generate(args: GenerateArgs): 56 | """ 57 | Generates an image from the Flux flow transformer. 58 | 59 | Args: 60 | args (GenerateArgs): Arguments for image generation: 61 | 62 | - `prompt`: The prompt used for image generation. 63 | 64 | - `width`: The width of the image. 65 | 66 | - `height`: The height of the image. 67 | 68 | - `num_steps`: The number of steps for the image generation. 69 | 70 | - `guidance`: The guidance for image generation, represents the 71 | influence of the prompt on the image generation. 72 | 73 | - `seed`: The seed for the image generation. 74 | 75 | - `strength`: strength for image generation, 0.0 - 1.0. 76 | Represents the percent of diffusion steps to run, 77 | setting the init_image as the noised latent at the 78 | given number of steps. 79 | 80 | - `init_image`: Base64 encoded image or path to image to use as the init image. 81 | 82 | Returns: 83 | StreamingResponse: The generated image as streaming jpeg bytes. 
84 | """ 85 | result = app.state.model.generate(**args.model_dump()) 86 | return StreamingResponse(result, media_type="image/jpeg") 87 | 88 | 89 | @app.post("/lora", response_model=LoraLoadResponse) 90 | def lora_action(args: LoraArgs): 91 | """ 92 | Loads or unloads a LoRA checkpoint into / from the Flux flow transformer. 93 | 94 | Args: 95 | args (LoraArgs): Arguments for the LoRA action: 96 | 97 | - `scale`: The scaling factor for the LoRA weights. 98 | - `path`: The path to the LoRA checkpoint. 99 | - `name`: The name of the LoRA checkpoint. 100 | - `action`: The action to perform, either "load" or "unload". 101 | 102 | Returns: 103 | LoraLoadResponse: The status of the LoRA action. 104 | """ 105 | try: 106 | if args.action == "load": 107 | app.state.model.load_lora(args.path, args.scale, args.name) 108 | elif args.action == "unload": 109 | app.state.model.unload_lora(args.name if args.name else args.path) 110 | else: 111 | return JSONResponse( 112 | content={ 113 | "status": "error", 114 | "message": f"Invalid action, expected 'load' or 'unload', got {args.action}", 115 | }, 116 | status_code=400, 117 | ) 118 | except Exception as e: 119 | return JSONResponse( 120 | status_code=500, content={"status": "error", "message": str(e)} 121 | ) 122 | return JSONResponse(status_code=200, content={"status": "success"}) 123 | -------------------------------------------------------------------------------- /configs/config-dev-1-RTX6000ADA.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "flow_quantization_dtype": "qfloat8", 51 | "text_enc_quantization_dtype": "qfloat8", 52 | "compile_extras": true, 53 | "compile_blocks": true, 54 | "offload_text_encoder": false, 55 | "offload_vae": false, 56 | "offload_flow": false 57 | } -------------------------------------------------------------------------------- /configs/config-dev-cuda0.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | 
"ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "text_enc_quantization_dtype": "qfloat8", 51 | "compile_extras": false, 52 | "compile_blocks": false, 53 | "offload_ae": false, 54 | "offload_text_enc": false, 55 | "offload_flow": false 56 | } -------------------------------------------------------------------------------- /configs/config-dev-eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:1", 45 | "ae_device": "cuda:1", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "flow_quantization_dtype": "qfloat8", 51 | "text_enc_quantization_dtype": "qfloat8", 52 | "compile_extras": false, 53 | "compile_blocks": false, 54 | "offload_ae": false, 55 | "offload_text_enc": false, 56 | "offload_flow": false 57 | } -------------------------------------------------------------------------------- /configs/config-dev-gigaquant.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | 
"ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "num_to_quant": 220, 51 | "flow_quantization_dtype": "qint4", 52 | "text_enc_quantization_dtype": "qint4", 53 | "ae_quantization_dtype": "qint4", 54 | "clip_quantization_dtype": "qint4", 55 | "compile_extras": false, 56 | "compile_blocks": false, 57 | "quantize_extras": true 58 | } -------------------------------------------------------------------------------- /configs/config-dev-offload-1-4080.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "flow_quantization_dtype": "qfloat8", 51 | "text_enc_quantization_dtype": "qint4", 52 | "ae_quantization_dtype": "qfloat8", 53 | "compile_extras": true, 54 | "compile_blocks": true, 55 | "offload_text_encoder": true, 56 | "offload_vae": true, 57 | "offload_flow": true 58 | } -------------------------------------------------------------------------------- /configs/config-dev-offload-1-4090.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 
| "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "flow_quantization_dtype": "qfloat8", 51 | "text_enc_quantization_dtype": "qint4", 52 | "ae_quantization_dtype": "qfloat8", 53 | "compile_extras": true, 54 | "compile_blocks": true, 55 | "offload_text_encoder": true, 56 | "offload_vae": true, 57 | "offload_flow": false 58 | } -------------------------------------------------------------------------------- /configs/config-dev-offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "flow_quantization_dtype": "qfloat8", 51 | "text_enc_quantization_dtype": "qint4", 52 | "ae_quantization_dtype": "qfloat8", 53 | "compile_extras": false, 54 | "compile_blocks": false, 55 | "offload_text_encoder": true, 56 | "offload_vae": true, 57 | "offload_flow": true 58 | } -------------------------------------------------------------------------------- /configs/config-dev-prequant.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/flux-fp16-acc/flux_fp8.safetensors", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:1", 45 | "ae_device": "cuda:1", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | 
"text_enc_dtype": "bfloat16", 50 | "text_enc_quantization_dtype": "qfloat8", 51 | "compile_extras": false, 52 | "compile_blocks": false, 53 | "prequantized_flow": true, 54 | "offload_ae": false, 55 | "offload_text_enc": false, 56 | "offload_flow": false 57 | } -------------------------------------------------------------------------------- /configs/config-dev.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 512, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:1", 45 | "ae_device": "cuda:1", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "text_enc_quantization_dtype": "qfloat8", 51 | "ae_quantization_dtype": "qfloat8", 52 | "compile_extras": true, 53 | "compile_blocks": true, 54 | "offload_ae": false, 55 | "offload_text_enc": false, 56 | "offload_flow": false 57 | } -------------------------------------------------------------------------------- /configs/config-schnell-cuda0.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-schnell", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": false 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-schnell", 40 | "repo_flow": "flux1-schnell.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 256, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:0", 45 | "ae_device": "cuda:0", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "text_enc_quantization_dtype": "qfloat8", 51 | "ae_quantization_dtype": "qfloat8", 52 | "compile_extras": false, 53 | "compile_blocks": false, 54 | "offload_ae": false, 55 | "offload_text_enc": false, 56 | "offload_flow": false 57 | } 
-------------------------------------------------------------------------------- /configs/config-schnell.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-schnell", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": false 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/flux1-schnell.sft", 38 | "ae_path": "/big/generator-ui/flux-testing/flux/model-dir-schnell/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-schnell", 40 | "repo_flow": "flux1-schnell.sft", 41 | "repo_ae": "ae.sft", 42 | "text_enc_max_length": 256, 43 | "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16", 44 | "text_enc_device": "cuda:1", 45 | "ae_device": "cuda:1", 46 | "flux_device": "cuda:0", 47 | "flow_dtype": "float16", 48 | "ae_dtype": "bfloat16", 49 | "text_enc_dtype": "bfloat16", 50 | "text_enc_quantization_dtype": "qfloat8", 51 | "ae_quantization_dtype": "qfloat8", 52 | "compile_extras": true, 53 | "compile_blocks": true, 54 | "offload_ae": false, 55 | "offload_text_enc": false, 56 | "offload_flow": false 57 | } -------------------------------------------------------------------------------- /float8_quantize.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import math 6 | from torch.compiler import is_compiling 7 | from torch import __version__ 8 | from torch.version import cuda 9 | 10 | from modules.flux_model import Modulation 11 | 12 | IS_TORCH_2_4 = __version__ < (2, 4, 9) 13 | LT_TORCH_2_4 = __version__ < (2, 4) 14 | if LT_TORCH_2_4: 15 | if not hasattr(torch, "_scaled_mm"): 16 | raise RuntimeError( 17 | "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later." 18 | ) 19 | CUDA_VERSION = float(cuda) if cuda else 0 20 | if CUDA_VERSION < 12.4: 21 | raise RuntimeError( 22 | f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}." 
23 | ) 24 | try: 25 | from cublas_ops import CublasLinear 26 | except ImportError: 27 | CublasLinear = type(None) 28 | 29 | 30 | class F8Linear(nn.Module): 31 | 32 | def __init__( 33 | self, 34 | in_features: int, 35 | out_features: int, 36 | bias: bool = True, 37 | device=None, 38 | dtype=torch.float16, 39 | float8_dtype=torch.float8_e4m3fn, 40 | float_weight: torch.Tensor = None, 41 | float_bias: torch.Tensor = None, 42 | num_scale_trials: int = 12, 43 | input_float8_dtype=torch.float8_e5m2, 44 | ) -> None: 45 | super().__init__() 46 | self.in_features = in_features 47 | self.out_features = out_features 48 | self.float8_dtype = float8_dtype 49 | self.input_float8_dtype = input_float8_dtype 50 | self.input_scale_initialized = False 51 | self.weight_initialized = False 52 | self.max_value = torch.finfo(self.float8_dtype).max 53 | self.input_max_value = torch.finfo(self.input_float8_dtype).max 54 | factory_kwargs = {"dtype": dtype, "device": device} 55 | if float_weight is None: 56 | self.weight = nn.Parameter( 57 | torch.empty((out_features, in_features), **factory_kwargs) 58 | ) 59 | else: 60 | self.weight = nn.Parameter( 61 | float_weight, requires_grad=float_weight.requires_grad 62 | ) 63 | if float_bias is None: 64 | if bias: 65 | self.bias = nn.Parameter( 66 | torch.empty(out_features, **factory_kwargs), 67 | ) 68 | else: 69 | self.register_parameter("bias", None) 70 | else: 71 | self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad) 72 | self.num_scale_trials = num_scale_trials 73 | self.input_amax_trials = torch.zeros( 74 | num_scale_trials, requires_grad=False, device=device, dtype=torch.float32 75 | ) 76 | self.trial_index = 0 77 | self.register_buffer("scale", None) 78 | self.register_buffer( 79 | "input_scale", 80 | None, 81 | ) 82 | self.register_buffer( 83 | "float8_data", 84 | None, 85 | ) 86 | self.scale_reciprocal = self.register_buffer("scale_reciprocal", None) 87 | self.input_scale_reciprocal = self.register_buffer( 88 | "input_scale_reciprocal", None 89 | ) 90 | 91 | def _load_from_state_dict( 92 | self, 93 | state_dict, 94 | prefix, 95 | local_metadata, 96 | strict, 97 | missing_keys, 98 | unexpected_keys, 99 | error_msgs, 100 | ): 101 | sd = {k.replace(prefix, ""): v for k, v in state_dict.items()} 102 | if "weight" in sd: 103 | if ( 104 | "float8_data" not in sd 105 | or sd["float8_data"] is None 106 | and sd["weight"].shape == (self.out_features, self.in_features) 107 | ): 108 | # Initialize as if it's an F8Linear that needs to be quantized 109 | self._parameters["weight"] = nn.Parameter( 110 | sd["weight"], requires_grad=False 111 | ) 112 | if "bias" in sd: 113 | self._parameters["bias"] = nn.Parameter( 114 | sd["bias"], requires_grad=False 115 | ) 116 | self.quantize_weight() 117 | elif sd["float8_data"].shape == ( 118 | self.out_features, 119 | self.in_features, 120 | ) and sd["weight"] == torch.zeros_like(sd["weight"]): 121 | w = sd["weight"] 122 | # Set the init values as if it's already quantized float8_data 123 | self._buffers["float8_data"] = sd["float8_data"] 124 | self._parameters["weight"] = nn.Parameter( 125 | torch.zeros( 126 | 1, 127 | dtype=w.dtype, 128 | device=w.device, 129 | requires_grad=False, 130 | ) 131 | ) 132 | if "bias" in sd: 133 | self._parameters["bias"] = nn.Parameter( 134 | sd["bias"], requires_grad=False 135 | ) 136 | self.weight_initialized = True 137 | 138 | # Check if scales and reciprocals are initialized 139 | if all( 140 | key in sd 141 | for key in [ 142 | "scale", 143 | "input_scale", 144 | 
"scale_reciprocal", 145 | "input_scale_reciprocal", 146 | ] 147 | ): 148 | self.scale = sd["scale"].float() 149 | self.input_scale = sd["input_scale"].float() 150 | self.scale_reciprocal = sd["scale_reciprocal"].float() 151 | self.input_scale_reciprocal = sd["input_scale_reciprocal"].float() 152 | self.input_scale_initialized = True 153 | self.trial_index = self.num_scale_trials 154 | elif "scale" in sd and "scale_reciprocal" in sd: 155 | self.scale = sd["scale"].float() 156 | self.input_scale = ( 157 | sd["input_scale"].float() if "input_scale" in sd else None 158 | ) 159 | self.scale_reciprocal = sd["scale_reciprocal"].float() 160 | self.input_scale_reciprocal = ( 161 | sd["input_scale_reciprocal"].float() 162 | if "input_scale_reciprocal" in sd 163 | else None 164 | ) 165 | self.input_scale_initialized = ( 166 | True if "input_scale" in sd else False 167 | ) 168 | self.trial_index = ( 169 | self.num_scale_trials if "input_scale" in sd else 0 170 | ) 171 | self.input_amax_trials = torch.zeros( 172 | self.num_scale_trials, 173 | requires_grad=False, 174 | dtype=torch.float32, 175 | device=self.weight.device, 176 | ) 177 | self.input_scale_initialized = False 178 | self.trial_index = 0 179 | else: 180 | # If scales are not initialized, reset trials 181 | self.input_scale_initialized = False 182 | self.trial_index = 0 183 | self.input_amax_trials = torch.zeros( 184 | self.num_scale_trials, requires_grad=False, dtype=torch.float32 185 | ) 186 | else: 187 | raise RuntimeError( 188 | f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}" 189 | ) 190 | else: 191 | raise RuntimeError( 192 | "Weight tensor not found or has incorrect shape in state dict" 193 | ) 194 | 195 | def quantize_weight(self): 196 | if self.weight_initialized: 197 | return 198 | amax = torch.max(torch.abs(self.weight.data)).float() 199 | self.scale = self.amax_to_scale(amax, self.max_value) 200 | self.float8_data = self.to_fp8_saturated( 201 | self.weight.data, self.scale, self.max_value 202 | ).to(self.float8_dtype) 203 | self.scale_reciprocal = self.scale.reciprocal() 204 | self.weight.data = torch.zeros( 205 | 1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False 206 | ) 207 | self.weight_initialized = True 208 | 209 | def set_weight_tensor(self, tensor: torch.Tensor): 210 | self.weight.data = tensor 211 | self.weight_initialized = False 212 | self.quantize_weight() 213 | 214 | def amax_to_scale(self, amax, max_val): 215 | return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val) 216 | 217 | def to_fp8_saturated(self, x, scale, max_val): 218 | return (x * scale).clamp(-max_val, max_val) 219 | 220 | def quantize_input(self, x: torch.Tensor): 221 | if self.input_scale_initialized: 222 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 223 | self.input_float8_dtype 224 | ) 225 | elif self.trial_index < self.num_scale_trials: 226 | 227 | amax = torch.max(torch.abs(x)).float() 228 | 229 | self.input_amax_trials[self.trial_index] = amax 230 | self.trial_index += 1 231 | self.input_scale = self.amax_to_scale( 232 | self.input_amax_trials[: self.trial_index].max(), self.input_max_value 233 | ) 234 | self.input_scale_reciprocal = self.input_scale.reciprocal() 235 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 236 | self.input_float8_dtype 237 | ) 238 | else: 239 | self.input_scale = self.amax_to_scale( 240 | self.input_amax_trials.max(), self.input_max_value 241 | ) 242 | self.input_scale_reciprocal = 
self.input_scale.reciprocal() 243 | self.input_scale_initialized = True 244 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 245 | self.input_float8_dtype 246 | ) 247 | 248 | def reset_parameters(self) -> None: 249 | if self.weight_initialized: 250 | self.weight = nn.Parameter( 251 | torch.empty( 252 | (self.out_features, self.in_features), 253 | **{ 254 | "dtype": self.weight.dtype, 255 | "device": self.weight.device, 256 | }, 257 | ) 258 | ) 259 | self.weight_initialized = False 260 | self.input_scale_initialized = False 261 | self.trial_index = 0 262 | self.input_amax_trials.zero_() 263 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 264 | if self.bias is not None: 265 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 266 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 267 | init.uniform_(self.bias, -bound, bound) 268 | self.quantize_weight() 269 | self.max_value = torch.finfo(self.float8_dtype).max 270 | self.input_max_value = torch.finfo(self.input_float8_dtype).max 271 | 272 | def forward(self, x: torch.Tensor) -> torch.Tensor: 273 | if self.input_scale_initialized or is_compiling(): 274 | x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 275 | self.input_float8_dtype 276 | ) 277 | else: 278 | x = self.quantize_input(x) 279 | 280 | prev_dims = x.shape[:-1] 281 | x = x.view(-1, self.in_features) 282 | 283 | # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices! 284 | out = torch._scaled_mm( 285 | x, 286 | self.float8_data.T, 287 | scale_a=self.input_scale_reciprocal, 288 | scale_b=self.scale_reciprocal, 289 | bias=self.bias, 290 | out_dtype=self.weight.dtype, 291 | use_fast_accum=True, 292 | ) 293 | if IS_TORCH_2_4: 294 | out = out[0] 295 | out = out.view(*prev_dims, self.out_features) 296 | return out 297 | 298 | @classmethod 299 | def from_linear( 300 | cls, 301 | linear: nn.Linear, 302 | float8_dtype=torch.float8_e4m3fn, 303 | input_float8_dtype=torch.float8_e5m2, 304 | ) -> "F8Linear": 305 | f8_lin = cls( 306 | in_features=linear.in_features, 307 | out_features=linear.out_features, 308 | bias=linear.bias is not None, 309 | device=linear.weight.device, 310 | dtype=linear.weight.dtype, 311 | float8_dtype=float8_dtype, 312 | float_weight=linear.weight.data, 313 | float_bias=(linear.bias.data if linear.bias is not None else None), 314 | input_float8_dtype=input_float8_dtype, 315 | ) 316 | f8_lin.quantize_weight() 317 | return f8_lin 318 | 319 | 320 | @torch.inference_mode() 321 | def recursive_swap_linears( 322 | model: nn.Module, 323 | float8_dtype=torch.float8_e4m3fn, 324 | input_float8_dtype=torch.float8_e5m2, 325 | quantize_modulation: bool = True, 326 | ignore_keys: list[str] = [], 327 | ) -> None: 328 | """ 329 | Recursively swaps all nn.Linear modules in the given model with F8Linear modules. 330 | 331 | This function traverses the model's structure and replaces each nn.Linear 332 | instance with an F8Linear instance, which uses 8-bit floating point 333 | quantization for weights. The original linear layer's weights are deleted 334 | after conversion to save memory. 335 | 336 | Args: 337 | model (nn.Module): The PyTorch model to modify. 338 | 339 | Note: 340 | This function modifies the model in-place. After calling this function, 341 | all linear layers in the model will be using 8-bit quantization. 
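The remaining keyword arguments: float8_dtype and input_float8_dtype select the float8 formats used for the quantized weights and activations respectively, quantize_modulation=False leaves Modulation submodules untouched, and ignore_keys skips direct children by name at every level of the recursion.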
342 | """ 343 | for name, child in model.named_children(): 344 | if name in ignore_keys: 345 | continue 346 | if isinstance(child, Modulation) and not quantize_modulation: 347 | continue 348 | if isinstance(child, nn.Linear) and not isinstance( 349 | child, (F8Linear, CublasLinear) 350 | ): 351 | 352 | setattr( 353 | model, 354 | name, 355 | F8Linear.from_linear( 356 | child, 357 | float8_dtype=float8_dtype, 358 | input_float8_dtype=input_float8_dtype, 359 | ), 360 | ) 361 | del child 362 | else: 363 | recursive_swap_linears( 364 | child, 365 | float8_dtype=float8_dtype, 366 | input_float8_dtype=input_float8_dtype, 367 | quantize_modulation=quantize_modulation, 368 | ignore_keys=ignore_keys, 369 | ) 370 | 371 | 372 | @torch.inference_mode() 373 | def swap_to_cublaslinear(model: nn.Module): 374 | if CublasLinear == type(None): 375 | return 376 | for name, child in model.named_children(): 377 | if isinstance(child, nn.Linear) and not isinstance( 378 | child, (F8Linear, CublasLinear) 379 | ): 380 | cublas_lin = CublasLinear( 381 | child.in_features, 382 | child.out_features, 383 | bias=child.bias is not None, 384 | dtype=child.weight.dtype, 385 | device=child.weight.device, 386 | ) 387 | cublas_lin.weight.data = child.weight.clone().detach() 388 | cublas_lin.bias.data = child.bias.clone().detach() 389 | setattr(model, name, cublas_lin) 390 | del child 391 | else: 392 | swap_to_cublaslinear(child) 393 | 394 | 395 | @torch.inference_mode() 396 | def quantize_flow_transformer_and_dispatch_float8( 397 | flow_model: nn.Module, 398 | device=torch.device("cuda"), 399 | float8_dtype=torch.float8_e4m3fn, 400 | input_float8_dtype=torch.float8_e5m2, 401 | offload_flow=False, 402 | swap_linears_with_cublaslinear=True, 403 | flow_dtype=torch.float16, 404 | quantize_modulation: bool = True, 405 | quantize_flow_embedder_layers: bool = True, 406 | ) -> nn.Module: 407 | """ 408 | Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device. 409 | 410 | Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes. 411 | 412 | Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory. 413 | 414 | After dispatching, if offload_flow is True, offloads the model to cpu. 415 | 416 | if swap_linears_with_cublaslinear is true, and flow_dtype == torch.float16, then swap all linears with cublaslinears for 2x performance boost on consumer GPUs. 417 | Otherwise will skip the cublaslinear swap. 418 | 419 | For added extra precision, you can set quantize_flow_embedder_layers to False, 420 | this helps maintain the output quality of the flow transformer moreso than fully quantizing, 421 | at the expense of ~512MB more VRAM usage. 422 | 423 | For added extra precision, you can set quantize_modulation to False, 424 | this helps maintain the output quality of the flow transformer moreso than fully quantizing, 425 | at the expense of ~2GB more VRAM usage, but- has a much higher impact on image quality than the embedder layers. 
426 | """ 427 | for module in flow_model.double_blocks: 428 | module.to(device) 429 | module.eval() 430 | recursive_swap_linears( 431 | module, 432 | float8_dtype=float8_dtype, 433 | input_float8_dtype=input_float8_dtype, 434 | quantize_modulation=quantize_modulation, 435 | ) 436 | torch.cuda.empty_cache() 437 | for module in flow_model.single_blocks: 438 | module.to(device) 439 | module.eval() 440 | recursive_swap_linears( 441 | module, 442 | float8_dtype=float8_dtype, 443 | input_float8_dtype=input_float8_dtype, 444 | quantize_modulation=quantize_modulation, 445 | ) 446 | torch.cuda.empty_cache() 447 | to_gpu_extras = [ 448 | "vector_in", 449 | "img_in", 450 | "txt_in", 451 | "time_in", 452 | "guidance_in", 453 | "final_layer", 454 | "pe_embedder", 455 | ] 456 | for module in to_gpu_extras: 457 | m_extra = getattr(flow_model, module) 458 | if m_extra is None: 459 | continue 460 | m_extra.to(device) 461 | m_extra.eval() 462 | if isinstance(m_extra, nn.Linear) and not isinstance( 463 | m_extra, (F8Linear, CublasLinear) 464 | ): 465 | if quantize_flow_embedder_layers: 466 | setattr( 467 | flow_model, 468 | module, 469 | F8Linear.from_linear( 470 | m_extra, 471 | float8_dtype=float8_dtype, 472 | input_float8_dtype=input_float8_dtype, 473 | ), 474 | ) 475 | del m_extra 476 | elif module != "final_layer": 477 | if quantize_flow_embedder_layers: 478 | recursive_swap_linears( 479 | m_extra, 480 | float8_dtype=float8_dtype, 481 | input_float8_dtype=input_float8_dtype, 482 | quantize_modulation=quantize_modulation, 483 | ) 484 | torch.cuda.empty_cache() 485 | if ( 486 | swap_linears_with_cublaslinear 487 | and flow_dtype == torch.float16 488 | and CublasLinear != type(None) 489 | ): 490 | swap_to_cublaslinear(flow_model) 491 | elif swap_linears_with_cublaslinear and flow_dtype != torch.float16: 492 | logger.warning("Skipping cublas linear swap because flow_dtype is not float16") 493 | if offload_flow: 494 | flow_model.to("cpu") 495 | torch.cuda.empty_cache() 496 | return flow_model 497 | -------------------------------------------------------------------------------- /flux_emphasis.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional 2 | from pydash import flatten 3 | 4 | import torch 5 | from transformers.models.clip.tokenization_clip import CLIPTokenizer 6 | from einops import repeat 7 | 8 | if TYPE_CHECKING: 9 | from flux_pipeline import FluxPipeline 10 | 11 | 12 | def parse_prompt_attention(text): 13 | """ 14 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 
15 | Accepted tokens are: 16 | (abc) - increases attention to abc by a multiplier of 1.1 17 | (abc:3.12) - increases attention to abc by a multiplier of 3.12 18 | [abc] - decreases attention to abc by a multiplier of 1.1 19 | \\( - literal character '(' 20 | \\[ - literal character '[' 21 | \\) - literal character ')' 22 | \\] - literal character ']' 23 | \\ - literal character '\' 24 | anything else - just text 25 | 26 | >>> parse_prompt_attention('normal text') 27 | [['normal text', 1.0]] 28 | >>> parse_prompt_attention('an (important) word') 29 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]] 30 | >>> parse_prompt_attention('(unbalanced') 31 | [['unbalanced', 1.1]] 32 | >>> parse_prompt_attention('\\(literal\\]') 33 | [['(literal]', 1.0]] 34 | >>> parse_prompt_attention('(unnecessary)(parens)') 35 | [['unnecessaryparens', 1.1]] 36 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') 37 | [['a ', 1.0], 38 | ['house', 1.5730000000000004], 39 | [' ', 1.1], 40 | ['on', 1.0], 41 | [' a ', 1.1], 42 | ['hill', 0.55], 43 | [', sun, ', 1.1], 44 | ['sky', 1.4641000000000006], 45 | ['.', 1.1]] 46 | """ 47 | import re 48 | 49 | re_attention = re.compile( 50 | r""" 51 | \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)| 52 | \)|]|[^\\()\[\]:]+|: 53 | """, 54 | re.X, 55 | ) 56 | 57 | re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) 58 | 59 | res = [] 60 | round_brackets = [] 61 | square_brackets = [] 62 | 63 | round_bracket_multiplier = 1.1 64 | square_bracket_multiplier = 1 / 1.1 65 | 66 | def multiply_range(start_position, multiplier): 67 | for p in range(start_position, len(res)): 68 | res[p][1] *= multiplier 69 | 70 | for m in re_attention.finditer(text): 71 | text = m.group(0) 72 | weight = m.group(1) 73 | 74 | if text.startswith("\\"): 75 | res.append([text[1:], 1.0]) 76 | elif text == "(": 77 | round_brackets.append(len(res)) 78 | elif text == "[": 79 | square_brackets.append(len(res)) 80 | elif weight is not None and len(round_brackets) > 0: 81 | multiply_range(round_brackets.pop(), float(weight)) 82 | elif text == ")" and len(round_brackets) > 0: 83 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 84 | elif text == "]" and len(square_brackets) > 0: 85 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 86 | else: 87 | parts = re.split(re_break, text) 88 | for i, part in enumerate(parts): 89 | if i > 0: 90 | res.append(["BREAK", -1]) 91 | res.append([part, 1.0]) 92 | 93 | for pos in round_brackets: 94 | multiply_range(pos, round_bracket_multiplier) 95 | 96 | for pos in square_brackets: 97 | multiply_range(pos, square_bracket_multiplier) 98 | 99 | if len(res) == 0: 100 | res = [["", 1.0]] 101 | 102 | # merge runs of identical weights 103 | i = 0 104 | while i + 1 < len(res): 105 | if res[i][1] == res[i + 1][1]: 106 | res[i][0] += res[i + 1][0] 107 | res.pop(i + 1) 108 | else: 109 | i += 1 110 | 111 | return res 112 | 113 | 114 | def get_prompts_tokens_with_weights( 115 | clip_tokenizer: CLIPTokenizer, prompt: str, debug: bool = False 116 | ): 117 | """ 118 | Get prompt token ids and weights, this function works for both prompt and negative prompt 119 | 120 | Args: 121 | pipe (CLIPTokenizer) 122 | A CLIPTokenizer 123 | prompt (str) 124 | A prompt string with weights 125 | 126 | Returns: 127 | text_tokens (list) 128 | A list contains token ids 129 | text_weight (list) 130 | A list contains the correspodent weight of token ids 131 | 132 | Example: 133 | import torch 134 | from transformers import CLIPTokenizer 135 | 136 | 
clip_tokenizer = CLIPTokenizer.from_pretrained( 137 | "stablediffusionapi/deliberate-v2" 138 | , subfolder = "tokenizer" 139 | , dtype = torch.float16 140 | ) 141 | 142 | token_id_list, token_weight_list = get_prompts_tokens_with_weights( 143 | clip_tokenizer = clip_tokenizer 144 | ,prompt = "a (red:1.5) cat"*70 145 | ) 146 | """ 147 | texts_and_weights = parse_prompt_attention(prompt) 148 | text_tokens, text_weights = [], [] 149 | maxlen = clip_tokenizer.model_max_length 150 | for word, weight in texts_and_weights: 151 | # tokenize and discard the starting and the ending token 152 | token = clip_tokenizer( 153 | word, truncation=False, padding=False, add_special_tokens=False 154 | ).input_ids 155 | # so that tokenize whatever length prompt 156 | # the returned token is a 1d list: [320, 1125, 539, 320] 157 | if debug: 158 | print( 159 | token, 160 | "|FOR MODEL LEN{}|".format(maxlen), 161 | clip_tokenizer.decode( 162 | token, skip_special_tokens=True, clean_up_tokenization_spaces=True 163 | ), 164 | ) 165 | # merge the new tokens to the all tokens holder: text_tokens 166 | text_tokens = [*text_tokens, *token] 167 | 168 | # each token chunk will come with one weight, like ['red cat', 2.0] 169 | # need to expand weight for each token. 170 | chunk_weights = [weight] * len(token) 171 | 172 | # append the weight back to the weight holder: text_weights 173 | text_weights = [*text_weights, *chunk_weights] 174 | return text_tokens, text_weights 175 | 176 | 177 | def group_tokens_and_weights( 178 | token_ids: list, 179 | weights: list, 180 | pad_last_block=False, 181 | bos=49406, 182 | eos=49407, 183 | max_length=77, 184 | pad_tokens=True, 185 | ): 186 | """ 187 | Produce tokens and weights in groups and pad the missing tokens 188 | 189 | Args: 190 | token_ids (list) 191 | The token ids from tokenizer 192 | weights (list) 193 | The weights list from function get_prompts_tokens_with_weights 194 | pad_last_block (bool) 195 | Control if fill the last token list to 75 tokens with eos 196 | Returns: 197 | new_token_ids (2d list) 198 | new_weights (2d list) 199 | 200 | Example: 201 | token_groups,weight_groups = group_tokens_and_weights( 202 | token_ids = token_id_list 203 | , weights = token_weight_list 204 | ) 205 | """ 206 | # TODO: Possibly need to fix this, since this doesn't seem correct. 207 | # Ignoring for now since I don't know what the consequences might be 208 | # if changed to <= instead of <. 
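# Note how this plays out in practice: the CLIP path calls this with max_length=77, so the
# `max_length < 77` branch is False and max_len stays 77; each full chunk then holds 77 prompt
# tokens plus bos/eos (79 ids total). Only max_length < 77 would actually reserve the two
# special-token slots, which is the off-by-one the TODO above refers to.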
209 | max_len = max_length - 2 if max_length < 77 else max_length 210 | # this will be a 2d list 211 | new_token_ids = [] 212 | new_weights = [] 213 | while len(token_ids) >= max_len: 214 | # get the first 75 tokens 215 | temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)] 216 | temp_77_weights = [weights.pop(0) for _ in range(max_len)] 217 | 218 | # extract token ids and weights 219 | 220 | if pad_tokens: 221 | if bos is not None: 222 | temp_77_token_ids = [bos] + temp_77_token_ids + [eos] 223 | temp_77_weights = [1.0] + temp_77_weights + [1.0] 224 | else: 225 | temp_77_token_ids = temp_77_token_ids + [eos] 226 | temp_77_weights = temp_77_weights + [1.0] 227 | 228 | # add 77 token and weights chunk to the holder list 229 | new_token_ids.append(temp_77_token_ids) 230 | new_weights.append(temp_77_weights) 231 | 232 | # padding the left 233 | if len(token_ids) > 0: 234 | if pad_tokens: 235 | padding_len = max_len - len(token_ids) if pad_last_block else 0 236 | 237 | temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos] 238 | new_token_ids.append(temp_77_token_ids) 239 | 240 | temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0] 241 | new_weights.append(temp_77_weights) 242 | else: 243 | new_token_ids.append(token_ids) 244 | new_weights.append(weights) 245 | return new_token_ids, new_weights 246 | 247 | 248 | def standardize_tensor( 249 | input_tensor: torch.Tensor, target_mean: float, target_std: float 250 | ) -> torch.Tensor: 251 | """ 252 | This function standardizes an input tensor so that it has a specific mean and standard deviation. 253 | 254 | Parameters: 255 | input_tensor (torch.Tensor): The tensor to standardize. 256 | target_mean (float): The target mean for the tensor. 257 | target_std (float): The target standard deviation for the tensor. 258 | 259 | Returns: 260 | torch.Tensor: The standardized tensor. 
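In other words: output = (input - input.mean()) / input.std() * target_std + target_mean.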
261 | """ 262 | 263 | # First, compute the mean and std of the input tensor 264 | mean = input_tensor.mean() 265 | std = input_tensor.std() 266 | 267 | # Then, standardize the tensor to have a mean of 0 and std of 1 268 | standardized_tensor = (input_tensor - mean) / std 269 | 270 | # Finally, scale the tensor to the target mean and std 271 | output_tensor = standardized_tensor * target_std + target_mean 272 | 273 | return output_tensor 274 | 275 | 276 | def apply_weights( 277 | prompt_tokens: torch.Tensor, 278 | weight_tensor: torch.Tensor, 279 | token_embedding: torch.Tensor, 280 | eos_token_id: int, 281 | pad_last_block: bool = True, 282 | ) -> torch.FloatTensor: 283 | mean = token_embedding.mean() 284 | std = token_embedding.std() 285 | if pad_last_block: 286 | pooled_tensor = token_embedding[ 287 | torch.arange(token_embedding.shape[0], device=token_embedding.device), 288 | ( 289 | prompt_tokens.to(dtype=torch.int, device=token_embedding.device) 290 | == eos_token_id 291 | ) 292 | .int() 293 | .argmax(dim=-1), 294 | ] 295 | else: 296 | pooled_tensor = token_embedding[:, -1] 297 | 298 | for j in range(len(weight_tensor)): 299 | if weight_tensor[j] != 1.0: 300 | token_embedding[:, j] = ( 301 | pooled_tensor 302 | + (token_embedding[:, j] - pooled_tensor) * weight_tensor[j] 303 | ) 304 | return standardize_tensor(token_embedding, mean, std) 305 | 306 | 307 | @torch.inference_mode() 308 | def get_weighted_text_embeddings_flux( 309 | pipe: "FluxPipeline", 310 | prompt: str = "", 311 | num_images_per_prompt: int = 1, 312 | device: Optional[torch.device] = None, 313 | target_device: Optional[torch.device] = torch.device("cuda:0"), 314 | target_dtype: Optional[torch.dtype] = torch.bfloat16, 315 | debug: bool = False, 316 | ): 317 | """ 318 | This function can process long prompt with weights, no length limitation 319 | for Stable Diffusion XL 320 | 321 | Args: 322 | pipe (StableDiffusionPipeline) 323 | prompt (str) 324 | prompt_2 (str) 325 | neg_prompt (str) 326 | neg_prompt_2 (str) 327 | num_images_per_prompt (int) 328 | device (torch.device) 329 | Returns: 330 | prompt_embeds (torch.Tensor) 331 | neg_prompt_embeds (torch.Tensor) 332 | """ 333 | device = device or pipe._execution_device 334 | 335 | eos = pipe.clip.tokenizer.eos_token_id 336 | eos_2 = pipe.t5.tokenizer.eos_token_id 337 | bos = pipe.clip.tokenizer.bos_token_id 338 | bos_2 = pipe.t5.tokenizer.bos_token_id 339 | 340 | clip = pipe.clip.hf_module 341 | t5 = pipe.t5.hf_module 342 | 343 | tokenizer_clip = pipe.clip.tokenizer 344 | tokenizer_t5 = pipe.t5.tokenizer 345 | 346 | t5_length = 512 if pipe.name == "flux-dev" else 256 347 | clip_length = 77 348 | 349 | # tokenizer 1 350 | prompt_tokens_clip, prompt_weights_clip = get_prompts_tokens_with_weights( 351 | tokenizer_clip, prompt, debug=debug 352 | ) 353 | 354 | # tokenizer 2 355 | prompt_tokens_t5, prompt_weights_t5 = get_prompts_tokens_with_weights( 356 | tokenizer_t5, prompt, debug=debug 357 | ) 358 | 359 | prompt_tokens_clip_grouped, prompt_weights_clip_grouped = group_tokens_and_weights( 360 | prompt_tokens_clip, 361 | prompt_weights_clip, 362 | pad_last_block=True, 363 | bos=bos, 364 | eos=eos, 365 | max_length=clip_length, 366 | ) 367 | prompt_tokens_t5_grouped, prompt_weights_t5_grouped = group_tokens_and_weights( 368 | prompt_tokens_t5, 369 | prompt_weights_t5, 370 | pad_last_block=True, 371 | bos=bos_2, 372 | eos=eos_2, 373 | max_length=t5_length, 374 | pad_tokens=False, 375 | ) 376 | prompt_tokens_t5 = flatten(prompt_tokens_t5_grouped) 377 | prompt_weights_t5 = 
flatten(prompt_weights_t5_grouped) 378 | prompt_tokens_clip = flatten(prompt_tokens_clip_grouped) 379 | prompt_weights_clip = flatten(prompt_weights_clip_grouped) 380 | 381 | prompt_tokens_clip = tokenizer_clip.decode( 382 | prompt_tokens_clip, skip_special_tokens=True, clean_up_tokenization_spaces=True 383 | ) 384 | prompt_tokens_clip = tokenizer_clip( 385 | prompt_tokens_clip, 386 | add_special_tokens=True, 387 | padding="max_length", 388 | truncation=True, 389 | max_length=clip_length, 390 | return_tensors="pt", 391 | ).input_ids.to(device) 392 | prompt_tokens_t5 = tokenizer_t5.decode( 393 | prompt_tokens_t5, skip_special_tokens=True, clean_up_tokenization_spaces=True 394 | ) 395 | prompt_tokens_t5 = tokenizer_t5( 396 | prompt_tokens_t5, 397 | add_special_tokens=True, 398 | padding="max_length", 399 | truncation=True, 400 | max_length=t5_length, 401 | return_tensors="pt", 402 | ).input_ids.to(device) 403 | 404 | prompt_weights_t5 = torch.cat( 405 | [ 406 | torch.tensor(prompt_weights_t5, dtype=torch.float32), 407 | torch.full( 408 | (t5_length - torch.tensor(prompt_weights_t5).numel(),), 409 | 1.0, 410 | dtype=torch.float32, 411 | ), 412 | ], 413 | dim=0, 414 | ).to(device) 415 | 416 | clip_embeds = clip( 417 | prompt_tokens_clip, output_hidden_states=True, attention_mask=None 418 | )["pooler_output"] 419 | if clip_embeds.shape[0] == 1 and num_images_per_prompt > 1: 420 | clip_embeds = repeat(clip_embeds, "1 ... -> bs ...", bs=num_images_per_prompt) 421 | 422 | weight_tensor_t5 = torch.tensor( 423 | flatten(prompt_weights_t5), dtype=torch.float32, device=device 424 | ) 425 | t5_embeds = t5(prompt_tokens_t5, output_hidden_states=True, attention_mask=None)[ 426 | "last_hidden_state" 427 | ] 428 | t5_embeds = apply_weights(prompt_tokens_t5, weight_tensor_t5, t5_embeds, eos_2) 429 | if debug: 430 | print(t5_embeds.shape) 431 | if t5_embeds.shape[0] == 1 and num_images_per_prompt > 1: 432 | t5_embeds = repeat(t5_embeds, "1 ... 
-> bs ...", bs=num_images_per_prompt) 433 | txt_ids = torch.zeros( 434 | num_images_per_prompt, 435 | t5_embeds.shape[1], 436 | 3, 437 | device=target_device, 438 | dtype=target_dtype, 439 | ) 440 | t5_embeds = t5_embeds.to(target_device, dtype=target_dtype) 441 | clip_embeds = clip_embeds.to(target_device, dtype=target_dtype) 442 | 443 | return ( 444 | clip_embeds, 445 | t5_embeds, 446 | txt_ids, 447 | ) 448 | -------------------------------------------------------------------------------- /flux_pipeline.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | import random 4 | import warnings 5 | from typing import TYPE_CHECKING, Callable, List, Optional, OrderedDict, Union 6 | 7 | import numpy as np 8 | from PIL import Image 9 | 10 | warnings.filterwarnings("ignore", category=UserWarning) 11 | warnings.filterwarnings("ignore", category=FutureWarning) 12 | warnings.filterwarnings("ignore", category=DeprecationWarning) 13 | import torch 14 | from einops import rearrange 15 | 16 | from flux_emphasis import get_weighted_text_embeddings_flux 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch.backends.cudnn.allow_tf32 = True 20 | torch.backends.cudnn.benchmark = True 21 | torch.backends.cudnn.benchmark_limit = 20 22 | torch.set_float32_matmul_precision("high") 23 | from pybase64 import standard_b64decode 24 | from torch._dynamo import config 25 | from torch._inductor import config as ind_config 26 | 27 | config.cache_size_limit = 10000000000 28 | ind_config.shape_padding = True 29 | import platform 30 | 31 | from loguru import logger 32 | from torchvision.transforms import functional as TF 33 | from tqdm import tqdm 34 | 35 | import lora_loading 36 | from image_encoder import ImageEncoder 37 | from util import ( 38 | ModelSpec, 39 | ModelVersion, 40 | into_device, 41 | into_dtype, 42 | load_config_from_path, 43 | load_models_from_config, 44 | ) 45 | 46 | if platform.system() == "Windows": 47 | MAX_RAND = 2**16 - 1 48 | else: 49 | MAX_RAND = 2**32 - 1 50 | 51 | 52 | if TYPE_CHECKING: 53 | from modules.autoencoder import AutoEncoder 54 | from modules.conditioner import HFEmbedder 55 | from modules.flux_model import Flux 56 | 57 | 58 | class FluxPipeline: 59 | """ 60 | FluxPipeline is a class that provides a pipeline for generating images using the Flux model. 61 | It handles input preparation, timestep generation, noise generation, device management 62 | and model compilation. 63 | """ 64 | 65 | def __init__( 66 | self, 67 | name: str, 68 | offload: bool = False, 69 | clip: "HFEmbedder" = None, 70 | t5: "HFEmbedder" = None, 71 | model: "Flux" = None, 72 | ae: "AutoEncoder" = None, 73 | dtype: torch.dtype = torch.float16, 74 | verbose: bool = False, 75 | flux_device: torch.device | str = "cuda:0", 76 | ae_device: torch.device | str = "cuda:1", 77 | clip_device: torch.device | str = "cuda:1", 78 | t5_device: torch.device | str = "cuda:1", 79 | config: ModelSpec = None, 80 | debug: bool = False, 81 | ): 82 | """ 83 | Initialize the FluxPipeline class. 84 | 85 | This class is responsible for preparing input tensors for the Flux model, generating 86 | timesteps and noise, and handling device management for model offloading. 
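In practice, instances are usually created via the load_pipeline_from_config_path / load_pipeline_from_config classmethods below, which build the text encoders, autoencoder, and (optionally float8-quantized) flow model from a ModelSpec config.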
87 | """ 88 | 89 | if config is None: 90 | raise ValueError("ModelSpec config is required!") 91 | 92 | self.debug = debug 93 | self.name = name 94 | self.device_flux = into_device(flux_device) 95 | self.device_ae = into_device(ae_device) 96 | self.device_clip = into_device(clip_device) 97 | self.device_t5 = into_device(t5_device) 98 | self.dtype = into_dtype(dtype) 99 | self.offload = offload 100 | self.clip: "HFEmbedder" = clip 101 | self.t5: "HFEmbedder" = t5 102 | self.model: "Flux" = model 103 | self.ae: "AutoEncoder" = ae 104 | self.rng = torch.Generator(device="cpu") 105 | self.img_encoder = ImageEncoder() 106 | self.verbose = verbose 107 | self.ae_dtype = torch.bfloat16 108 | self.config = config 109 | self.offload_text_encoder = config.offload_text_encoder 110 | self.offload_vae = config.offload_vae 111 | self.offload_flow = config.offload_flow 112 | # If models are not offloaded, move them to the appropriate devices 113 | 114 | if not self.offload_flow: 115 | self.model.to(self.device_flux) 116 | if not self.offload_vae: 117 | self.ae.to(self.device_ae) 118 | if not self.offload_text_encoder: 119 | self.clip.to(self.device_clip) 120 | self.t5.to(self.device_t5) 121 | 122 | # compile the model if needed 123 | if config.compile_blocks or config.compile_extras: 124 | self.compile() 125 | 126 | def set_seed( 127 | self, seed: int | None = None, seed_globally: bool = False 128 | ) -> torch.Generator: 129 | if isinstance(seed, (int, float)): 130 | seed = int(abs(seed)) % MAX_RAND 131 | cuda_generator = torch.Generator("cuda").manual_seed(seed) 132 | elif isinstance(seed, str): 133 | try: 134 | seed = abs(int(seed)) % MAX_RAND 135 | except Exception as e: 136 | logger.warning( 137 | f"Recieved string representation of seed, but was not able to convert to int: {seed}, using random seed" 138 | ) 139 | seed = abs(self.rng.seed()) % MAX_RAND 140 | cuda_generator = torch.Generator("cuda").manual_seed(seed) 141 | else: 142 | seed = abs(self.rng.seed()) % MAX_RAND 143 | cuda_generator = torch.Generator("cuda").manual_seed(seed) 144 | 145 | if seed_globally: 146 | torch.cuda.manual_seed_all(seed) 147 | np.random.seed(seed) 148 | random.seed(seed) 149 | return cuda_generator, seed 150 | 151 | def load_lora( 152 | self, 153 | lora_path: Union[str, OrderedDict[str, torch.Tensor]], 154 | scale: float, 155 | name: Optional[str] = None, 156 | ): 157 | """ 158 | Loads a LoRA checkpoint into the Flux flow transformer. 159 | 160 | Currently supports LoRA checkpoints from either diffusers checkpoints which usually start with transformer.[...], 161 | or loras which contain keys which start with lora_unet_[...]. 162 | 163 | Args: 164 | lora_path (str | OrderedDict[str, torch.Tensor]): Path to the LoRA checkpoint or an ordered dictionary containing the LoRA weights. 165 | scale (float): Scaling factor for the LoRA weights. 166 | name (str): Name of the LoRA checkpoint, optionally can be left as None, since it only acts as an identifier. 167 | """ 168 | self.model.load_lora(path=lora_path, scale=scale, name=name) 169 | 170 | def unload_lora(self, path_or_identifier: str): 171 | """ 172 | Unloads the LoRA checkpoint from the Flux flow transformer. 173 | 174 | Args: 175 | path_or_identifier (str): Path to the LoRA checkpoint or the name given to the LoRA checkpoint when it was loaded. 176 | """ 177 | self.model.unload_lora(path_or_identifier=path_or_identifier) 178 | 179 | @torch.inference_mode() 180 | def compile(self): 181 | """ 182 | Compiles the model and extras. 
183 | 184 | First, if: 185 | 186 | - A) Checkpoint which already has float8 quantized weights and tuned input scales. 187 | In which case, it will not run warmups since it assumes the input scales are already tuned. 188 | 189 | - B) Checkpoint which has not been quantized, in which case it will be quantized 190 | and the input scales will be tuned. via running a warmup loop. 191 | - If the model is flux-schnell, it will run 3 warmup loops since each loop is 4 steps. 192 | - If the model is flux-dev, it will run 1 warmup loop for 12 steps. 193 | 194 | """ 195 | 196 | # Run warmups if the checkpoint is not prequantized 197 | if not self.config.prequantized_flow: 198 | logger.info("Running warmups for compile...") 199 | warmup_dict = dict( 200 | prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉", 201 | height=768, 202 | width=768, 203 | num_steps=12, 204 | guidance=3.5, 205 | seed=10, 206 | ) 207 | if self.config.version == ModelVersion.flux_schnell: 208 | warmup_dict["num_steps"] = 4 209 | for _ in range(3): 210 | self.generate(**warmup_dict) 211 | else: 212 | self.generate(**warmup_dict) 213 | 214 | # Compile the model and extras 215 | to_gpu_extras = [ 216 | "vector_in", 217 | "img_in", 218 | "txt_in", 219 | "time_in", 220 | "guidance_in", 221 | "final_layer", 222 | "pe_embedder", 223 | ] 224 | if self.config.compile_blocks: 225 | for block in self.model.double_blocks: 226 | block.compile() 227 | for block in self.model.single_blocks: 228 | block.compile() 229 | if self.config.compile_extras: 230 | for extra in to_gpu_extras: 231 | getattr(self.model, extra).compile() 232 | 233 | @torch.inference_mode() 234 | def prepare( 235 | self, 236 | img: torch.Tensor, 237 | prompt: str | list[str], 238 | target_device: torch.device = torch.device("cuda:0"), 239 | target_dtype: torch.dtype = torch.float16, 240 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 241 | """ 242 | Prepare input tensors for the Flux model. 243 | 244 | This function processes the input image and text prompt, converting them into 245 | the appropriate format and embedding representations required by the model. 246 | 247 | Args: 248 | img (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width). 249 | prompt (str | list[str]): Text prompt or list of prompts guiding the image generation. 250 | target_device (torch.device, optional): The target device for the output tensors. 251 | Defaults to torch.device("cuda:0"). 252 | target_dtype (torch.dtype, optional): The target data type for the output tensors. 253 | Defaults to torch.float16. 254 | 255 | Returns: 256 | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: 257 | - img: Processed image tensor. 258 | - img_ids: Image position IDs. 259 | - vec: Clip text embedding vector. 260 | - txt: T5 text embedding hidden states. 261 | - txt_ids: Text position IDs. 262 | 263 | Note: 264 | This function handles the necessary device management for text encoder offloading 265 | if enabled in the configuration. 
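Shape example: a 1024x1024 generation uses a (bs, 16, 128, 128) latent, which this method patchifies to (bs, 4096, 64), i.e. (h//2)*(w//2) tokens of c*2*2 channels, matching the transformer's in_channels of 64.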
266 | """ 267 | bs, c, h, w = img.shape 268 | if bs == 1 and not isinstance(prompt, str): 269 | bs = len(prompt) 270 | img = img.unfold(2, 2, 2).unfold(3, 2, 2).permute(0, 2, 3, 1, 4, 5) 271 | img = img.reshape(img.shape[0], -1, img.shape[3] * img.shape[4] * img.shape[5]) 272 | assert img.shape == ( 273 | bs, 274 | (h // 2) * (w // 2), 275 | c * 2 * 2, 276 | ), f"{img.shape} != {(bs, (h//2)*(w//2), c*2*2)}" 277 | if img.shape[0] == 1 and bs > 1: 278 | img = img[None].repeat_interleave(bs, dim=0) 279 | 280 | img_ids = torch.zeros( 281 | h // 2, w // 2, 3, device=target_device, dtype=target_dtype 282 | ) 283 | img_ids[..., 1] = ( 284 | img_ids[..., 1] 285 | + torch.arange(h // 2, device=target_device, dtype=target_dtype)[:, None] 286 | ) 287 | img_ids[..., 2] = ( 288 | img_ids[..., 2] 289 | + torch.arange(w // 2, device=target_device, dtype=target_dtype)[None, :] 290 | ) 291 | 292 | img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2) 293 | if self.offload_text_encoder: 294 | self.clip.to(device=self.device_clip) 295 | self.t5.to(device=self.device_t5) 296 | 297 | # get the text embeddings 298 | vec, txt, txt_ids = get_weighted_text_embeddings_flux( 299 | self, 300 | prompt, 301 | num_images_per_prompt=bs, 302 | device=self.device_clip, 303 | target_device=target_device, 304 | target_dtype=target_dtype, 305 | debug=self.debug, 306 | ) 307 | # offload text encoder to cpu if needed 308 | if self.offload_text_encoder: 309 | self.clip.to("cpu") 310 | self.t5.to("cpu") 311 | torch.cuda.empty_cache() 312 | return img, img_ids, vec, txt, txt_ids 313 | 314 | @torch.inference_mode() 315 | def time_shift(self, mu: float, sigma: float, t: torch.Tensor): 316 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 317 | 318 | def get_lin_function( 319 | self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 320 | ) -> Callable[[float], float]: 321 | m = (y2 - y1) / (x2 - x1) 322 | b = y1 - m * x1 323 | return lambda x: m * x + b 324 | 325 | @torch.inference_mode() 326 | def get_schedule( 327 | self, 328 | num_steps: int, 329 | image_seq_len: int, 330 | base_shift: float = 0.5, 331 | max_shift: float = 1.15, 332 | shift: bool = True, 333 | ) -> list[float]: 334 | """Generates a schedule of timesteps for the given number of steps and image sequence length.""" 335 | # extra step for zero 336 | timesteps = torch.linspace(1, 0, num_steps + 1) 337 | 338 | # shifting the schedule to favor high timesteps for higher signal images 339 | if shift: 340 | # eastimate mu based on linear estimation between two points 341 | mu = self.get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 342 | timesteps = self.time_shift(mu, 1.0, timesteps) 343 | 344 | return timesteps.tolist() 345 | 346 | @torch.inference_mode() 347 | def get_noise( 348 | self, 349 | num_samples: int, 350 | height: int, 351 | width: int, 352 | generator: torch.Generator, 353 | dtype=None, 354 | device=None, 355 | ) -> torch.Tensor: 356 | """Generates a latent noise tensor of the given shape and dtype on the given device.""" 357 | if device is None: 358 | device = self.device_flux 359 | if dtype is None: 360 | dtype = self.dtype 361 | return torch.randn( 362 | num_samples, 363 | 16, 364 | # allow for packing 365 | 2 * math.ceil(height / 16), 366 | 2 * math.ceil(width / 16), 367 | device=device, 368 | dtype=dtype, 369 | generator=generator, 370 | requires_grad=False, 371 | ) 372 | 373 | @torch.inference_mode() 374 | def into_bytes(self, x: torch.Tensor, jpeg_quality: int = 99) -> io.BytesIO: 375 | """Converts the 
image tensor to bytes.""" 376 | # bring into PIL format and save 377 | num_images = x.shape[0] 378 | images: List[torch.Tensor] = [] 379 | for i in range(num_images): 380 | x = ( 381 | x[i] 382 | .clamp(-1, 1) 383 | .add(1.0) 384 | .mul(127.5) 385 | .clamp(0, 255) 386 | .contiguous() 387 | .type(torch.uint8) 388 | ) 389 | images.append(x) 390 | if len(images) == 1: 391 | im = images[0] 392 | else: 393 | im = torch.vstack(images) 394 | 395 | im = self.img_encoder.encode_torch(im, quality=jpeg_quality) 396 | images.clear() 397 | return im 398 | 399 | @torch.inference_mode() 400 | def load_init_image_if_needed( 401 | self, init_image: torch.Tensor | str | Image.Image | np.ndarray 402 | ) -> torch.Tensor: 403 | """ 404 | Loads the initial image if it is a string, numpy array, or PIL.Image, 405 | if torch.Tensor, expects it to be in the correct format and returns it as is. 406 | """ 407 | if isinstance(init_image, str): 408 | try: 409 | init_image = Image.open(init_image) 410 | except Exception as e: 411 | init_image = Image.open( 412 | io.BytesIO(standard_b64decode(init_image.split(",")[-1])) 413 | ) 414 | init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8) 415 | elif isinstance(init_image, np.ndarray): 416 | init_image = torch.from_numpy(init_image).type(torch.uint8) 417 | elif isinstance(init_image, Image.Image): 418 | init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8) 419 | 420 | return init_image 421 | 422 | @torch.inference_mode() 423 | def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: 424 | """Decodes the latent tensor to the pixel space.""" 425 | if self.offload_vae: 426 | self.ae.to(self.device_ae) 427 | x = x.to(self.device_ae) 428 | else: 429 | x = x.to(self.device_ae) 430 | x = self.unpack(x.float(), height, width) 431 | with torch.autocast( 432 | device_type=self.device_ae.type, dtype=torch.bfloat16, cache_enabled=False 433 | ): 434 | x = self.ae.decode(x) 435 | if self.offload_vae: 436 | self.ae.to("cpu") 437 | torch.cuda.empty_cache() 438 | return x 439 | 440 | def unpack(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: 441 | return rearrange( 442 | x, 443 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 444 | h=math.ceil(height / 16), 445 | w=math.ceil(width / 16), 446 | ph=2, 447 | pw=2, 448 | ) 449 | 450 | @torch.inference_mode() 451 | def resize_center_crop( 452 | self, img: torch.Tensor, height: int, width: int 453 | ) -> torch.Tensor: 454 | """Resizes and crops the image to the given height and width.""" 455 | img = TF.resize(img, min(width, height)) 456 | img = TF.center_crop(img, (height, width)) 457 | return img 458 | 459 | @torch.inference_mode() 460 | def preprocess_latent( 461 | self, 462 | init_image: torch.Tensor | np.ndarray = None, 463 | height: int = 720, 464 | width: int = 1024, 465 | num_steps: int = 20, 466 | strength: float = 1.0, 467 | generator: torch.Generator = None, 468 | num_images: int = 1, 469 | ) -> tuple[torch.Tensor, List[float]]: 470 | """ 471 | Preprocesses the latent tensor for the given number of steps and image sequence length. 472 | Also, if an initial image is provided, it is vae encoded and injected with the appropriate noise 473 | given the strength and number of steps replacing the latent tensor. 
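Example: with num_steps=20 and strength=0.75, t_idx = int((1 - 0.75) * 20) = 5, so the encoded init image is blended in at timesteps[5] and only the remaining 15 denoising steps are run.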
474 | """ 475 | # prepare input 476 | 477 | if init_image is not None: 478 | if isinstance(init_image, np.ndarray): 479 | init_image = torch.from_numpy(init_image) 480 | 481 | init_image = ( 482 | init_image.permute(2, 0, 1) 483 | .contiguous() 484 | .to(self.device_ae, dtype=self.ae_dtype) 485 | .div(127.5) 486 | .sub(1)[None, ...] 487 | ) 488 | init_image = self.resize_center_crop(init_image, height, width) 489 | with torch.autocast( 490 | device_type=self.device_ae.type, 491 | dtype=torch.bfloat16, 492 | cache_enabled=False, 493 | ): 494 | if self.offload_vae: 495 | self.ae.to(self.device_ae) 496 | init_image = ( 497 | self.ae.encode(init_image) 498 | .to(dtype=self.dtype, device=self.device_flux) 499 | .repeat(num_images, 1, 1, 1) 500 | ) 501 | if self.offload_vae: 502 | self.ae.to("cpu") 503 | torch.cuda.empty_cache() 504 | 505 | x = self.get_noise( 506 | num_images, 507 | height, 508 | width, 509 | device=self.device_flux, 510 | dtype=self.dtype, 511 | generator=generator, 512 | ) 513 | timesteps = self.get_schedule( 514 | num_steps=num_steps, 515 | image_seq_len=x.shape[-1] * x.shape[-2] // 4, 516 | shift=(self.name != "flux-schnell"), 517 | ) 518 | if init_image is not None: 519 | t_idx = int((1 - strength) * num_steps) 520 | t = timesteps[t_idx] 521 | timesteps = timesteps[t_idx:] 522 | x = t * x + (1.0 - t) * init_image 523 | return x, timesteps 524 | 525 | @torch.inference_mode() 526 | def generate( 527 | self, 528 | prompt: str, 529 | width: int = 720, 530 | height: int = 1024, 531 | num_steps: int = 24, 532 | guidance: float = 3.5, 533 | seed: int | None = None, 534 | init_image: torch.Tensor | str | Image.Image | np.ndarray | None = None, 535 | strength: float = 1.0, 536 | silent: bool = False, 537 | num_images: int = 1, 538 | return_seed: bool = False, 539 | jpeg_quality: int = 99, 540 | ) -> io.BytesIO: 541 | """ 542 | Generate images based on the given prompt and parameters. 543 | 544 | Args: 545 | prompt `(str)`: The text prompt to guide the image generation. 546 | 547 | width `(int, optional)`: Width of the generated image. Defaults to 720. 548 | 549 | height `(int, optional)`: Height of the generated image. Defaults to 1024. 550 | 551 | num_steps `(int, optional)`: Number of denoising steps. Defaults to 24. 552 | 553 | guidance `(float, optional)`: Guidance scale for text-to-image generation. Defaults to 3.5. 554 | 555 | seed `(int | None, optional)`: Random seed for reproducibility. If None, a random seed is used. Defaults to None. 556 | 557 | init_image `(torch.Tensor | str | Image.Image | np.ndarray | None, optional)`: Initial image for image-to-image generation. Defaults to None. 558 | 559 | -- note: if the image's height/width do not match the height/width of the generated image, the image is resized and centered cropped to match the height/width arguments. 560 | 561 | -- If a string is provided, it is assumed to be either a path to an image file or a base64 encoded image. 562 | 563 | -- If a numpy array is provided, it is assumed to be an RGB numpy array of shape (height, width, 3) and dtype uint8. 564 | 565 | -- If a PIL.Image is provided, it is assumed to be an RGB PIL.Image. 566 | 567 | -- If a torch.Tensor is provided, it is assumed to be a torch.Tensor of shape (height, width, 3) and dtype uint8 with range [0, 255]. 568 | 569 | strength `(float, optional)`: Strength of the init_image in image-to-image generation. Defaults to 1.0. 570 | 571 | silent `(bool, optional)`: If True, suppresses progress bar. Defaults to False. 
572 | 573 | num_images `(int, optional)`: Number of images to generate. Defaults to 1. 574 | 575 | return_seed `(bool, optional)`: If True, returns the seed along with the generated image. Defaults to False. 576 | 577 | jpeg_quality `(int, optional)`: Quality of the JPEG compression. Defaults to 99. 578 | 579 | Returns: 580 | io.BytesIO: Generated image(s) in bytes format. 581 | int: Seed used for generation (only if return_seed is True). 582 | """ 583 | num_steps = 4 if self.name == "flux-schnell" else num_steps 584 | 585 | init_image = self.load_init_image_if_needed(init_image) 586 | 587 | # allow for packing and conversion to latent space 588 | height = 16 * (height // 16) 589 | width = 16 * (width // 16) 590 | 591 | generator, seed = self.set_seed(seed) 592 | 593 | if not silent: 594 | logger.info(f"Generating with:\nSeed: {seed}\nPrompt: {prompt}") 595 | 596 | # preprocess the latent 597 | img, timesteps = self.preprocess_latent( 598 | init_image=init_image, 599 | height=height, 600 | width=width, 601 | num_steps=num_steps, 602 | strength=strength, 603 | generator=generator, 604 | num_images=num_images, 605 | ) 606 | 607 | # prepare inputs 608 | img, img_ids, vec, txt, txt_ids = map( 609 | lambda x: x.contiguous(), 610 | self.prepare( 611 | img=img, 612 | prompt=prompt, 613 | target_device=self.device_flux, 614 | target_dtype=self.dtype, 615 | ), 616 | ) 617 | 618 | # this is ignored for schnell 619 | guidance_vec = torch.full( 620 | (img.shape[0],), guidance, device=self.device_flux, dtype=self.dtype 621 | ) 622 | t_vec = None 623 | # dispatch to gpu if offloaded 624 | if self.offload_flow: 625 | self.model.to(self.device_flux) 626 | 627 | # perform the denoising loop 628 | for t_curr, t_prev in tqdm( 629 | zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, disable=silent 630 | ): 631 | if t_vec is None: 632 | t_vec = torch.full( 633 | (img.shape[0],), 634 | t_curr, 635 | dtype=self.dtype, 636 | device=self.device_flux, 637 | ) 638 | else: 639 | t_vec = t_vec.reshape((img.shape[0],)).fill_(t_curr) 640 | 641 | pred = self.model.forward( 642 | img=img, 643 | img_ids=img_ids, 644 | txt=txt, 645 | txt_ids=txt_ids, 646 | y=vec, 647 | timesteps=t_vec, 648 | guidance=guidance_vec, 649 | ) 650 | 651 | img = img + (t_prev - t_curr) * pred 652 | 653 | # offload the model to cpu if needed 654 | if self.offload_flow: 655 | self.model.to("cpu") 656 | torch.cuda.empty_cache() 657 | 658 | # decode latents to pixel space 659 | img = self.vae_decode(img, height, width) 660 | 661 | if return_seed: 662 | return self.into_bytes(img, jpeg_quality=jpeg_quality), seed 663 | return self.into_bytes(img, jpeg_quality=jpeg_quality) 664 | 665 | @classmethod 666 | def load_pipeline_from_config_path( 667 | cls, path: str, flow_model_path: str = None, debug: bool = False, **kwargs 668 | ) -> "FluxPipeline": 669 | with torch.inference_mode(): 670 | config = load_config_from_path(path) 671 | if flow_model_path: 672 | config.ckpt_path = flow_model_path 673 | for k, v in kwargs.items(): 674 | if hasattr(config, k): 675 | logger.info( 676 | f"Overriding config {k}:{getattr(config, k)} with value {v}" 677 | ) 678 | setattr(config, k, v) 679 | return cls.load_pipeline_from_config(config, debug=debug) 680 | 681 | @classmethod 682 | def load_pipeline_from_config( 683 | cls, config: ModelSpec, debug: bool = False 684 | ) -> "FluxPipeline": 685 | from float8_quantize import quantize_flow_transformer_and_dispatch_float8 686 | 687 | with torch.inference_mode(): 688 | if debug: 689 | logger.info( 690 | f"Loading as 
prequantized flow transformer? {config.prequantized_flow}" 691 | ) 692 | 693 | models = load_models_from_config(config) 694 | config = models.config 695 | flux_device = into_device(config.flux_device) 696 | ae_device = into_device(config.ae_device) 697 | clip_device = into_device(config.text_enc_device) 698 | t5_device = into_device(config.text_enc_device) 699 | flux_dtype = into_dtype(config.flow_dtype) 700 | flow_model = models.flow 701 | 702 | if not config.prequantized_flow: 703 | flow_model = quantize_flow_transformer_and_dispatch_float8( 704 | flow_model, 705 | flux_device, 706 | offload_flow=config.offload_flow, 707 | swap_linears_with_cublaslinear=flux_dtype == torch.float16, 708 | flow_dtype=flux_dtype, 709 | quantize_modulation=config.quantize_modulation, 710 | quantize_flow_embedder_layers=config.quantize_flow_embedder_layers, 711 | ) 712 | else: 713 | flow_model.eval().requires_grad_(False) 714 | 715 | return cls( 716 | name=config.version, 717 | clip=models.clip, 718 | t5=models.t5, 719 | model=flow_model, 720 | ae=models.ae, 721 | dtype=flux_dtype, 722 | verbose=False, 723 | flux_device=flux_device, 724 | ae_device=ae_device, 725 | clip_device=clip_device, 726 | t5_device=t5_device, 727 | config=config, 728 | debug=debug, 729 | ) 730 | -------------------------------------------------------------------------------- /image_encoder.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class ImageEncoder: 8 | 9 | @torch.inference_mode() 10 | def encode_torch(self, img: torch.Tensor, quality=95): 11 | if img.ndim == 2: 12 | img = ( 13 | img[None] 14 | .repeat_interleave(3, dim=0) 15 | .permute(1, 2, 0) 16 | .contiguous() 17 | .clamp(0, 255) 18 | .type(torch.uint8) 19 | ) 20 | elif img.ndim == 3: 21 | if img.shape[0] == 3: 22 | img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8) 23 | elif img.shape[2] == 3: 24 | img = img.contiguous().clamp(0, 255).type(torch.uint8) 25 | else: 26 | raise ValueError(f"Unsupported image shape: {img.shape}") 27 | else: 28 | raise ValueError(f"Unsupported image num dims: {img.ndim}") 29 | 30 | img = img.cpu().numpy().astype(np.uint8) 31 | im = Image.fromarray(img) 32 | iob = io.BytesIO() 33 | im.save(iob, format="JPEG", quality=quality) 34 | iob.seek(0) 35 | return iob 36 | -------------------------------------------------------------------------------- /lora_loading.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Optional, OrderedDict, Tuple, TypeAlias, Union 3 | import torch 4 | from loguru import logger 5 | from safetensors.torch import load_file 6 | from tqdm import tqdm 7 | from torch import nn 8 | 9 | try: 10 | from cublas_ops import CublasLinear 11 | except Exception as e: 12 | CublasLinear = type(None) 13 | from float8_quantize import F8Linear 14 | from modules.flux_model import Flux 15 | 16 | path_regex = re.compile(r"/|\\") 17 | 18 | StateDict: TypeAlias = OrderedDict[str, torch.Tensor] 19 | 20 | 21 | class LoraWeights: 22 | def __init__( 23 | self, 24 | weights: StateDict, 25 | path: str, 26 | name: str = None, 27 | scale: float = 1.0, 28 | ) -> None: 29 | self.path = path 30 | self.weights = weights 31 | self.name = name if name else path_regex.split(path)[-1] 32 | self.scale = scale 33 | 34 | 35 | def swap_scale_shift(weight): 36 | scale, shift = weight.chunk(2, dim=0) 37 | new_weight = torch.cat([shift, scale], dim=0) 38 | 
return new_weight 39 | 40 | 41 | def check_if_lora_exists(state_dict, lora_name): 42 | subkey = lora_name.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0] 43 | for key in state_dict.keys(): 44 | if subkey in key: 45 | return subkey 46 | return False 47 | 48 | 49 | def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name): 50 | if (original_stubkey := check_if_lora_exists(state_dict, lora_name)) != False: 51 | weights_to_pop = [k for k in state_dict.keys() if original_stubkey in k] 52 | for key in weights_to_pop: 53 | key_replacement = key.replace( 54 | original_stubkey, flux_layer_name.replace(".weight", "") 55 | ) 56 | new_state_dict[key_replacement] = state_dict.pop(key) 57 | return new_state_dict, state_dict 58 | else: 59 | return new_state_dict, state_dict 60 | 61 | 62 | def convert_diffusers_to_flux_transformer_checkpoint( 63 | diffusers_state_dict, 64 | num_layers, 65 | num_single_layers, 66 | has_guidance=True, 67 | prefix="", 68 | ): 69 | original_state_dict = {} 70 | 71 | # time_text_embed.timestep_embedder -> time_in 72 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 73 | original_state_dict, 74 | diffusers_state_dict, 75 | f"{prefix}time_text_embed.timestep_embedder.linear_1.weight", 76 | "time_in.in_layer.weight", 77 | ) 78 | # time_text_embed.text_embedder -> vector_in 79 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 80 | original_state_dict, 81 | diffusers_state_dict, 82 | f"{prefix}time_text_embed.text_embedder.linear_1.weight", 83 | "vector_in.in_layer.weight", 84 | ) 85 | 86 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 87 | original_state_dict, 88 | diffusers_state_dict, 89 | f"{prefix}time_text_embed.text_embedder.linear_2.weight", 90 | "vector_in.out_layer.weight", 91 | ) 92 | 93 | if has_guidance: 94 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 95 | original_state_dict, 96 | diffusers_state_dict, 97 | f"{prefix}time_text_embed.guidance_embedder.linear_1.weight", 98 | "guidance_in.in_layer.weight", 99 | ) 100 | 101 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 102 | original_state_dict, 103 | diffusers_state_dict, 104 | f"{prefix}time_text_embed.guidance_embedder.linear_2.weight", 105 | "guidance_in.out_layer.weight", 106 | ) 107 | 108 | # context_embedder -> txt_in 109 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 110 | original_state_dict, 111 | diffusers_state_dict, 112 | f"{prefix}context_embedder.weight", 113 | "txt_in.weight", 114 | ) 115 | 116 | # x_embedder -> img_in 117 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 118 | original_state_dict, 119 | diffusers_state_dict, 120 | f"{prefix}x_embedder.weight", 121 | "img_in.weight", 122 | ) 123 | # double transformer blocks 124 | for i in range(num_layers): 125 | block_prefix = f"transformer_blocks.{i}." 
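# Each diffusers "transformer_blocks.{i}" entry maps onto the BFL "double_blocks.{i}" layout:
# norm1.linear / norm1_context.linear become img_mod.lin / txt_mod.lin, and the separate
# to_q/to_k/to_v (and add_q/k/v_proj) LoRA A/B pairs are concatenated below into the fused
# qkv projections expected by the flux model.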
126 | # norms 127 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 128 | original_state_dict, 129 | diffusers_state_dict, 130 | f"{prefix}{block_prefix}norm1.linear.weight", 131 | f"double_blocks.{i}.img_mod.lin.weight", 132 | ) 133 | 134 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 135 | original_state_dict, 136 | diffusers_state_dict, 137 | f"{prefix}{block_prefix}norm1_context.linear.weight", 138 | f"double_blocks.{i}.txt_mod.lin.weight", 139 | ) 140 | 141 | # Q, K, V 142 | temp_dict = {} 143 | 144 | expected_shape_qkv_a = None 145 | expected_shape_qkv_b = None 146 | expected_shape_add_qkv_a = None 147 | expected_shape_add_qkv_b = None 148 | dtype = None 149 | device = None 150 | 151 | for component in [ 152 | "to_q", 153 | "to_k", 154 | "to_v", 155 | "add_q_proj", 156 | "add_k_proj", 157 | "add_v_proj", 158 | ]: 159 | 160 | sample_component_A_key = ( 161 | f"{prefix}{block_prefix}attn.{component}.lora_A.weight" 162 | ) 163 | sample_component_B_key = ( 164 | f"{prefix}{block_prefix}attn.{component}.lora_B.weight" 165 | ) 166 | if ( 167 | sample_component_A_key in diffusers_state_dict 168 | and sample_component_B_key in diffusers_state_dict 169 | ): 170 | sample_component_A = diffusers_state_dict.pop(sample_component_A_key) 171 | sample_component_B = diffusers_state_dict.pop(sample_component_B_key) 172 | temp_dict[f"{component}"] = [sample_component_A, sample_component_B] 173 | if expected_shape_qkv_a is None and not component.startswith("add_"): 174 | expected_shape_qkv_a = sample_component_A.shape 175 | expected_shape_qkv_b = sample_component_B.shape 176 | dtype = sample_component_A.dtype 177 | device = sample_component_A.device 178 | if expected_shape_add_qkv_a is None and component.startswith("add_"): 179 | expected_shape_add_qkv_a = sample_component_A.shape 180 | expected_shape_add_qkv_b = sample_component_B.shape 181 | dtype = sample_component_A.dtype 182 | device = sample_component_A.device 183 | else: 184 | logger.info( 185 | f"Skipping layer {i} since no LoRA weight is available for {sample_component_A_key}" 186 | ) 187 | temp_dict[f"{component}"] = [None, None] 188 | 189 | if device is not None: 190 | if expected_shape_qkv_a is not None: 191 | 192 | if (sq := temp_dict["to_q"])[0] is not None: 193 | sample_q_A, sample_q_B = sq 194 | else: 195 | sample_q_A, sample_q_B = [ 196 | torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device), 197 | torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device), 198 | ] 199 | if (sq := temp_dict["to_k"])[0] is not None: 200 | sample_k_A, sample_k_B = sq 201 | else: 202 | sample_k_A, sample_k_B = [ 203 | torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device), 204 | torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device), 205 | ] 206 | if (sq := temp_dict["to_v"])[0] is not None: 207 | sample_v_A, sample_v_B = sq 208 | else: 209 | sample_v_A, sample_v_B = [ 210 | torch.zeros(expected_shape_qkv_a, dtype=dtype, device=device), 211 | torch.zeros(expected_shape_qkv_b, dtype=dtype, device=device), 212 | ] 213 | original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = ( 214 | torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0) 215 | ) 216 | original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = ( 217 | torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0) 218 | ) 219 | if expected_shape_add_qkv_a is not None: 220 | 221 | if (sq := temp_dict["add_q_proj"])[0] is not None: 222 | context_q_A, context_q_B = sq 223 | else: 224 | context_q_A, context_q_B = 
[ 225 | torch.zeros( 226 | expected_shape_add_qkv_a, dtype=dtype, device=device 227 | ), 228 | torch.zeros( 229 | expected_shape_add_qkv_b, dtype=dtype, device=device 230 | ), 231 | ] 232 | if (sq := temp_dict["add_k_proj"])[0] is not None: 233 | context_k_A, context_k_B = sq 234 | else: 235 | context_k_A, context_k_B = [ 236 | torch.zeros( 237 | expected_shape_add_qkv_a, dtype=dtype, device=device 238 | ), 239 | torch.zeros( 240 | expected_shape_add_qkv_b, dtype=dtype, device=device 241 | ), 242 | ] 243 | if (sq := temp_dict["add_v_proj"])[0] is not None: 244 | context_v_A, context_v_B = sq 245 | else: 246 | context_v_A, context_v_B = [ 247 | torch.zeros( 248 | expected_shape_add_qkv_a, dtype=dtype, device=device 249 | ), 250 | torch.zeros( 251 | expected_shape_add_qkv_b, dtype=dtype, device=device 252 | ), 253 | ] 254 | 255 | original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = ( 256 | torch.cat([context_q_A, context_k_A, context_v_A], dim=0) 257 | ) 258 | original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = ( 259 | torch.cat([context_q_B, context_k_B, context_v_B], dim=0) 260 | ) 261 | 262 | # qk_norm 263 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 264 | original_state_dict, 265 | diffusers_state_dict, 266 | f"{prefix}{block_prefix}attn.norm_q.weight", 267 | f"double_blocks.{i}.img_attn.norm.query_norm.scale", 268 | ) 269 | 270 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 271 | original_state_dict, 272 | diffusers_state_dict, 273 | f"{prefix}{block_prefix}attn.norm_k.weight", 274 | f"double_blocks.{i}.img_attn.norm.key_norm.scale", 275 | ) 276 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 277 | original_state_dict, 278 | diffusers_state_dict, 279 | f"{prefix}{block_prefix}attn.norm_added_q.weight", 280 | f"double_blocks.{i}.txt_attn.norm.query_norm.scale", 281 | ) 282 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 283 | original_state_dict, 284 | diffusers_state_dict, 285 | f"{prefix}{block_prefix}attn.norm_added_k.weight", 286 | f"double_blocks.{i}.txt_attn.norm.key_norm.scale", 287 | ) 288 | 289 | # ff img_mlp 290 | 291 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 292 | original_state_dict, 293 | diffusers_state_dict, 294 | f"{prefix}{block_prefix}ff.net.0.proj.weight", 295 | f"double_blocks.{i}.img_mlp.0.weight", 296 | ) 297 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 298 | original_state_dict, 299 | diffusers_state_dict, 300 | f"{prefix}{block_prefix}ff.net.2.weight", 301 | f"double_blocks.{i}.img_mlp.2.weight", 302 | ) 303 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 304 | original_state_dict, 305 | diffusers_state_dict, 306 | f"{prefix}{block_prefix}ff_context.net.0.proj.weight", 307 | f"double_blocks.{i}.txt_mlp.0.weight", 308 | ) 309 | 310 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 311 | original_state_dict, 312 | diffusers_state_dict, 313 | f"{prefix}{block_prefix}ff_context.net.2.weight", 314 | f"double_blocks.{i}.txt_mlp.2.weight", 315 | ) 316 | # output projections 317 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 318 | original_state_dict, 319 | diffusers_state_dict, 320 | f"{prefix}{block_prefix}attn.to_out.0.weight", 321 | f"double_blocks.{i}.img_attn.proj.weight", 322 | ) 323 | 324 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 325 | original_state_dict, 326 | diffusers_state_dict, 327 | 
f"{prefix}{block_prefix}attn.to_add_out.weight", 328 | f"double_blocks.{i}.txt_attn.proj.weight", 329 | ) 330 | 331 | # single transformer blocks 332 | for i in range(num_single_layers): 333 | block_prefix = f"single_transformer_blocks.{i}." 334 | # norm.linear -> single_blocks.0.modulation.lin 335 | key_norm = f"{prefix}{block_prefix}norm.linear.weight" 336 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 337 | original_state_dict, 338 | diffusers_state_dict, 339 | key_norm, 340 | f"single_blocks.{i}.modulation.lin.weight", 341 | ) 342 | 343 | has_q, has_k, has_v, has_mlp = False, False, False, False 344 | shape_qkv_a = None 345 | shape_qkv_b = None 346 | # Q, K, V, mlp 347 | q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight") 348 | q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight") 349 | if q_A is not None and q_B is not None: 350 | has_q = True 351 | shape_qkv_a = q_A.shape 352 | shape_qkv_b = q_B.shape 353 | k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight") 354 | k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight") 355 | if k_A is not None and k_B is not None: 356 | has_k = True 357 | shape_qkv_a = k_A.shape 358 | shape_qkv_b = k_B.shape 359 | v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight") 360 | v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight") 361 | if v_A is not None and v_B is not None: 362 | has_v = True 363 | shape_qkv_a = v_A.shape 364 | shape_qkv_b = v_B.shape 365 | mlp_A = diffusers_state_dict.pop( 366 | f"{prefix}{block_prefix}proj_mlp.lora_A.weight" 367 | ) 368 | mlp_B = diffusers_state_dict.pop( 369 | f"{prefix}{block_prefix}proj_mlp.lora_B.weight" 370 | ) 371 | if mlp_A is not None and mlp_B is not None: 372 | has_mlp = True 373 | shape_qkv_a = mlp_A.shape 374 | shape_qkv_b = mlp_B.shape 375 | if any([has_q, has_k, has_v, has_mlp]): 376 | if not has_q: 377 | q_A, q_B = [ 378 | torch.zeros(shape_qkv_a, dtype=dtype, device=device), 379 | torch.zeros(shape_qkv_b, dtype=dtype, device=device), 380 | ] 381 | if not has_k: 382 | k_A, k_B = [ 383 | torch.zeros(shape_qkv_a, dtype=dtype, device=device), 384 | torch.zeros(shape_qkv_b, dtype=dtype, device=device), 385 | ] 386 | if not has_v: 387 | v_A, v_B = [ 388 | torch.zeros(shape_qkv_a, dtype=dtype, device=device), 389 | torch.zeros(shape_qkv_b, dtype=dtype, device=device), 390 | ] 391 | if not has_mlp: 392 | mlp_A, mlp_B = [ 393 | torch.zeros(shape_qkv_a, dtype=dtype, device=device), 394 | torch.zeros(shape_qkv_b, dtype=dtype, device=device), 395 | ] 396 | original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat( 397 | [q_A, k_A, v_A, mlp_A], dim=0 398 | ) 399 | original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat( 400 | [q_B, k_B, v_B, mlp_B], dim=0 401 | ) 402 | 403 | # output projections 404 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 405 | original_state_dict, 406 | diffusers_state_dict, 407 | f"{prefix}{block_prefix}proj_out.weight", 408 | f"single_blocks.{i}.linear2.weight", 409 | ) 410 | 411 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 412 | original_state_dict, 413 | diffusers_state_dict, 414 | f"{prefix}proj_out.weight", 415 | "final_layer.linear.weight", 416 | ) 417 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 418 | original_state_dict, 419 | diffusers_state_dict, 420 | f"{prefix}proj_out.bias", 421 | 
"final_layer.linear.bias", 422 | ) 423 | original_state_dict, diffusers_state_dict = convert_if_lora_exists( 424 | original_state_dict, 425 | diffusers_state_dict, 426 | f"{prefix}norm_out.linear.weight", 427 | "final_layer.adaLN_modulation.1.weight", 428 | ) 429 | if len(list(diffusers_state_dict.keys())) > 0: 430 | logger.warning("Unexpected keys:", diffusers_state_dict.keys()) 431 | 432 | return original_state_dict 433 | 434 | 435 | def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict: 436 | """ 437 | Convert the state dict from the original Flux checkpoint format to the new format. 438 | 439 | Args: 440 | original_state_dict (Dict[str, torch.Tensor]): The original Flux checkpoint state dict. 441 | 442 | Returns: 443 | Dict[str, torch.Tensor]: The converted state dict in the new format. 444 | """ 445 | sd = { 446 | k.replace("lora_unet_", "") 447 | .replace("double_blocks_", "double_blocks.") 448 | .replace("single_blocks_", "single_blocks.") 449 | .replace("_img_attn_", ".img_attn.") 450 | .replace("_txt_attn_", ".txt_attn.") 451 | .replace("_img_mod_", ".img_mod.") 452 | .replace("_txt_mod_", ".txt_mod.") 453 | .replace("_img_mlp_", ".img_mlp.") 454 | .replace("_txt_mlp_", ".txt_mlp.") 455 | .replace("_linear1", ".linear1") 456 | .replace("_linear2", ".linear2") 457 | .replace("_modulation_", ".modulation.") 458 | .replace("lora_up", "lora_B") 459 | .replace("lora_down", "lora_A"): v 460 | for k, v in original_state_dict.items() 461 | if "lora" in k 462 | } 463 | return sd 464 | 465 | 466 | def get_module_for_key( 467 | key: str, model: Flux 468 | ) -> F8Linear | torch.nn.Linear | CublasLinear: 469 | parts = key.split(".") 470 | module = model 471 | for part in parts: 472 | module = getattr(module, part) 473 | return module 474 | 475 | 476 | def get_lora_for_key( 477 | key: str, lora_weights: dict 478 | ) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: 479 | """ 480 | Get LoRA weights for a specific key. 481 | 482 | Args: 483 | key (str): The key to look up in the LoRA weights. 484 | lora_weights (dict): Dictionary containing LoRA weights. 485 | 486 | Returns: 487 | Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]: A tuple containing lora_A, lora_B, and alpha if found, None otherwise. 
488 | """ 489 | prefix = key.split(".lora")[0] 490 | lora_A = lora_weights.get(f"{prefix}.lora_A.weight") 491 | lora_B = lora_weights.get(f"{prefix}.lora_B.weight") 492 | alpha = lora_weights.get(f"{prefix}.alpha") 493 | 494 | if lora_A is None or lora_B is None: 495 | return None 496 | return lora_A, lora_B, alpha 497 | 498 | 499 | def get_module_for_key( 500 | key: str, model: Flux 501 | ) -> F8Linear | torch.nn.Linear | CublasLinear: 502 | parts = key.split(".") 503 | module = model 504 | for part in parts: 505 | module = getattr(module, part) 506 | return module 507 | 508 | 509 | def calculate_lora_weight( 510 | lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]], 511 | rank: Optional[int] = None, 512 | lora_scale: float = 1.0, 513 | device: Optional[Union[torch.device, int, str]] = None, 514 | ): 515 | lora_A, lora_B, alpha = lora_weights 516 | if device is None: 517 | device = lora_A.device 518 | 519 | uneven_rank = lora_B.shape[1] != lora_A.shape[0] 520 | rank_diff = lora_A.shape[0] / lora_B.shape[1] 521 | 522 | if rank is None: 523 | rank = lora_B.shape[1] 524 | if alpha is None: 525 | alpha = rank 526 | 527 | dtype = torch.float32 528 | w_up = lora_A.to(dtype=dtype, device=device) 529 | w_down = lora_B.to(dtype=dtype, device=device) 530 | 531 | if alpha != rank: 532 | w_up = w_up * alpha / rank 533 | if uneven_rank: 534 | # Fuse each lora instead of repeat interleave for each individual lora, 535 | # seems to fuse more correctly. 536 | fused_lora = torch.zeros( 537 | (lora_B.shape[0], lora_A.shape[1]), device=device, dtype=dtype 538 | ) 539 | w_up = w_up.chunk(int(rank_diff), dim=0) 540 | for w_up_chunk in w_up: 541 | fused_lora = fused_lora + (lora_scale * torch.mm(w_down, w_up_chunk)) 542 | else: 543 | fused_lora = lora_scale * torch.mm(w_down, w_up) 544 | return fused_lora 545 | 546 | 547 | @torch.inference_mode() 548 | def unfuse_lora_weight_from_module( 549 | fused_weight: torch.Tensor, 550 | lora_weights: dict, 551 | rank: Optional[int] = None, 552 | lora_scale: float = 1.0, 553 | ): 554 | w_dtype = fused_weight.dtype 555 | dtype = torch.float32 556 | device = fused_weight.device 557 | 558 | fused_weight = fused_weight.to(dtype=dtype, device=device) 559 | fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device) 560 | module_weight = fused_weight - fused_lora 561 | return module_weight.to(dtype=w_dtype, device=device) 562 | 563 | 564 | @torch.inference_mode() 565 | def apply_lora_weight_to_module( 566 | module_weight: torch.Tensor, 567 | lora_weights: dict, 568 | rank: int = None, 569 | lora_scale: float = 1.0, 570 | ): 571 | w_dtype = module_weight.dtype 572 | dtype = torch.float32 573 | device = module_weight.device 574 | 575 | fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device) 576 | fused_weight = module_weight.to(dtype=dtype) + fused_lora 577 | return fused_weight.to(dtype=w_dtype, device=device) 578 | 579 | 580 | def resolve_lora_state_dict(lora_weights, has_guidance: bool = True): 581 | check_if_starts_with_transformer = [ 582 | k for k in lora_weights.keys() if k.startswith("transformer.") 583 | ] 584 | if len(check_if_starts_with_transformer) > 0: 585 | lora_weights = convert_diffusers_to_flux_transformer_checkpoint( 586 | lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer." 
587 | ) 588 | else: 589 | lora_weights = convert_from_original_flux_checkpoint(lora_weights) 590 | logger.info("LoRA weights loaded") 591 | logger.debug("Extracting keys") 592 | keys_without_ab = list( 593 | set( 594 | [ 595 | key.replace(".lora_A.weight", "") 596 | .replace(".lora_B.weight", "") 597 | .replace(".lora_A", "") 598 | .replace(".lora_B", "") 599 | .replace(".alpha", "") 600 | for key in lora_weights.keys() 601 | ] 602 | ) 603 | ) 604 | logger.debug("Keys extracted") 605 | return keys_without_ab, lora_weights 606 | 607 | 608 | def get_lora_weights(lora_path: str | StateDict): 609 | if isinstance(lora_path, (dict, LoraWeights)): 610 | return lora_path, True 611 | else: 612 | return load_file(lora_path, "cpu"), False 613 | 614 | 615 | def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]): 616 | dtype = linear.weight.dtype 617 | weight_is_f8 = False 618 | if isinstance(linear, F8Linear): 619 | weight_is_f8 = True 620 | weight = ( 621 | linear.float8_data.clone() 622 | .detach() 623 | .float() 624 | .mul(linear.scale_reciprocal) 625 | .to(linear.weight.device) 626 | ) 627 | elif isinstance(linear, torch.nn.Linear): 628 | weight = linear.weight.clone().detach().float() 629 | elif isinstance(linear, CublasLinear) and CublasLinear != type(None): 630 | weight = linear.weight.clone().detach().float() 631 | return weight, weight_is_f8, dtype 632 | 633 | 634 | @torch.inference_mode() 635 | def apply_lora_to_model( 636 | model: Flux, 637 | lora_path: str | StateDict, 638 | lora_scale: float = 1.0, 639 | return_lora_resolved: bool = False, 640 | ) -> Flux: 641 | has_guidance = model.params.guidance_embed 642 | logger.info(f"Loading LoRA weights for {lora_path}") 643 | lora_weights, already_loaded = get_lora_weights(lora_path) 644 | 645 | if not already_loaded: 646 | keys_without_ab, lora_weights = resolve_lora_state_dict( 647 | lora_weights, has_guidance 648 | ) 649 | elif isinstance(lora_weights, LoraWeights): 650 | b_ = lora_weights 651 | lora_weights = b_.weights 652 | keys_without_ab = list( 653 | set( 654 | [ 655 | key.replace(".lora_A.weight", "") 656 | .replace(".lora_B.weight", "") 657 | .replace(".lora_A", "") 658 | .replace(".lora_B", "") 659 | .replace(".alpha", "") 660 | for key in lora_weights.keys() 661 | ] 662 | ) 663 | ) 664 | else: 665 | lora_weights = lora_weights 666 | keys_without_ab = list( 667 | set( 668 | [ 669 | key.replace(".lora_A.weight", "") 670 | .replace(".lora_B.weight", "") 671 | .replace(".lora_A", "") 672 | .replace(".lora_B", "") 673 | .replace(".alpha", "") 674 | for key in lora_weights.keys() 675 | ] 676 | ) 677 | ) 678 | for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)): 679 | module = get_module_for_key(key, model) 680 | weight, is_f8, dtype = extract_weight_from_linear(module) 681 | lora_sd = get_lora_for_key(key, lora_weights) 682 | if lora_sd is None: 683 | # Skipping LoRA application for this module 684 | continue 685 | weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale) 686 | if is_f8: 687 | module.set_weight_tensor(weight.type(dtype)) 688 | else: 689 | module.weight.data = weight.type(dtype) 690 | logger.success("Lora applied") 691 | if return_lora_resolved: 692 | return model, lora_weights 693 | return model 694 | 695 | 696 | def remove_lora_from_module( 697 | model: Flux, 698 | lora_path: str | StateDict, 699 | lora_scale: float = 1.0, 700 | ): 701 | has_guidance = model.params.guidance_embed 702 | logger.info(f"Loading LoRA weights for {lora_path}") 703 | 
lora_weights, already_loaded = get_lora_weights(lora_path) 704 | 705 | if not already_loaded: 706 | keys_without_ab, lora_weights = resolve_lora_state_dict( 707 | lora_weights, has_guidance 708 | ) 709 | elif isinstance(lora_weights, LoraWeights): 710 | b_ = lora_weights 711 | lora_weights = b_.weights 712 | keys_without_ab = list( 713 | set( 714 | [ 715 | key.replace(".lora_A.weight", "") 716 | .replace(".lora_B.weight", "") 717 | .replace(".lora_A", "") 718 | .replace(".lora_B", "") 719 | .replace(".alpha", "") 720 | for key in lora_weights.keys() 721 | ] 722 | ) 723 | ) 724 | lora_scale = b_.scale 725 | else: 726 | lora_weights = lora_weights 727 | keys_without_ab = list( 728 | set( 729 | [ 730 | key.replace(".lora_A.weight", "") 731 | .replace(".lora_B.weight", "") 732 | .replace(".lora_A", "") 733 | .replace(".lora_B", "") 734 | .replace(".alpha", "") 735 | for key in lora_weights.keys() 736 | ] 737 | ) 738 | ) 739 | 740 | for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)): 741 | module = get_module_for_key(key, model) 742 | weight, is_f8, dtype = extract_weight_from_linear(module) 743 | lora_sd = get_lora_for_key(key, lora_weights) 744 | if lora_sd is None: 745 | # Skipping LoRA application for this module 746 | continue 747 | weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale) 748 | if is_f8: 749 | module.set_weight_tensor(weight.type(dtype)) 750 | else: 751 | module.weight.data = weight.type(dtype) 752 | logger.success("Lora unfused") 753 | return model 754 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import uvicorn 3 | from api import app 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description="Launch Flux API server") 8 | parser.add_argument( 9 | "-c", 10 | "--config-path", 11 | type=str, 12 | help="Path to the configuration file, if not provided, the model will be loaded from the command line arguments", 13 | ) 14 | parser.add_argument( 15 | "-p", 16 | "--port", 17 | type=int, 18 | default=8088, 19 | help="Port to run the server on", 20 | ) 21 | parser.add_argument( 22 | "-H", 23 | "--host", 24 | type=str, 25 | default="0.0.0.0", 26 | help="Host to run the server on", 27 | ) 28 | parser.add_argument( 29 | "-f", "--flow-model-path", type=str, help="Path to the flow model" 30 | ) 31 | parser.add_argument( 32 | "-t", "--text-enc-path", type=str, help="Path to the text encoder" 33 | ) 34 | parser.add_argument( 35 | "-a", "--autoencoder-path", type=str, help="Path to the autoencoder" 36 | ) 37 | parser.add_argument( 38 | "-m", 39 | "--model-version", 40 | type=str, 41 | choices=["flux-dev", "flux-schnell"], 42 | default="flux-dev", 43 | help="Choose model version", 44 | ) 45 | parser.add_argument( 46 | "-F", 47 | "--flux-device", 48 | type=str, 49 | default="cuda:0", 50 | help="Device to run the flow model on", 51 | ) 52 | parser.add_argument( 53 | "-T", 54 | "--text-enc-device", 55 | type=str, 56 | default="cuda:0", 57 | help="Device to run the text encoder on", 58 | ) 59 | parser.add_argument( 60 | "-A", 61 | "--autoencoder-device", 62 | type=str, 63 | default="cuda:0", 64 | help="Device to run the autoencoder on", 65 | ) 66 | parser.add_argument( 67 | "-q", 68 | "--num-to-quant", 69 | type=int, 70 | default=20, 71 | help="Number of linear layers in flow transformer (the 'unet') to quantize", 72 | ) 73 | parser.add_argument( 74 | "-C", 75 | 
"--compile", 76 | action="store_true", 77 | default=False, 78 | help="Compile the flow model with extra optimizations", 79 | ) 80 | parser.add_argument( 81 | "-qT", 82 | "--quant-text-enc", 83 | type=str, 84 | default="qfloat8", 85 | choices=["qint4", "qfloat8", "qint2", "qint8", "bf16"], 86 | help="Quantize the t5 text encoder to the given dtype, if bf16, will not quantize", 87 | dest="quant_text_enc", 88 | ) 89 | parser.add_argument( 90 | "-qA", 91 | "--quant-ae", 92 | action="store_true", 93 | default=False, 94 | help="Quantize the autoencoder with float8 linear layers, otherwise will use bfloat16", 95 | dest="quant_ae", 96 | ) 97 | parser.add_argument( 98 | "-OF", 99 | "--offload-flow", 100 | action="store_true", 101 | default=False, 102 | dest="offload_flow", 103 | help="Offload the flow model to the CPU when not being used to save memory", 104 | ) 105 | parser.add_argument( 106 | "-OA", 107 | "--no-offload-ae", 108 | action="store_false", 109 | default=True, 110 | dest="offload_ae", 111 | help="Disable offloading the autoencoder to the CPU when not being used to increase e2e inference speed", 112 | ) 113 | parser.add_argument( 114 | "-OT", 115 | "--no-offload-text-enc", 116 | action="store_false", 117 | default=True, 118 | dest="offload_text_enc", 119 | help="Disable offloading the text encoder to the CPU when not being used to increase e2e inference speed", 120 | ) 121 | parser.add_argument( 122 | "-PF", 123 | "--prequantized-flow", 124 | action="store_true", 125 | default=False, 126 | dest="prequantized_flow", 127 | help="Load the flow model from a prequantized checkpoint " 128 | + "(requires loading the flow model, running a minimum of 24 steps, " 129 | + "and then saving the state_dict as a safetensors file), " 130 | + "which reduces the size of the checkpoint by about 50% & reduces startup time", 131 | ) 132 | parser.add_argument( 133 | "-nqfm", 134 | "--no-quantize-flow-modulation", 135 | action="store_false", 136 | default=True, 137 | dest="quantize_modulation", 138 | help="Disable quantization of the modulation layers in the flow model, adds ~2GB vram usage for moderate precision improvements", 139 | ) 140 | parser.add_argument( 141 | "-qfl", 142 | "--quantize-flow-embedder-layers", 143 | action="store_true", 144 | default=False, 145 | dest="quantize_flow_embedder_layers", 146 | help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable", 147 | ) 148 | return parser.parse_args() 149 | 150 | 151 | def main(): 152 | args = parse_args() 153 | 154 | # lazy loading so cli returns fast instead of waiting for torch to load modules 155 | from flux_pipeline import FluxPipeline 156 | from util import load_config, ModelVersion 157 | 158 | if args.config_path: 159 | app.state.model = FluxPipeline.load_pipeline_from_config_path( 160 | args.config_path, flow_model_path=args.flow_model_path 161 | ) 162 | else: 163 | model_version = ( 164 | ModelVersion.flux_dev 165 | if args.model_version == "flux-dev" 166 | else ModelVersion.flux_schnell 167 | ) 168 | config = load_config( 169 | model_version, 170 | flux_path=args.flow_model_path, 171 | flux_device=args.flux_device, 172 | ae_path=args.autoencoder_path, 173 | ae_device=args.autoencoder_device, 174 | text_enc_path=args.text_enc_path, 175 | text_enc_device=args.text_enc_device, 176 | flow_dtype="float16", 177 | text_enc_dtype="bfloat16", 178 | ae_dtype="bfloat16", 179 | num_to_quant=args.num_to_quant, 180 | compile_extras=args.compile, 181 | compile_blocks=args.compile, 182 | 
quant_text_enc=( 183 | None if args.quant_text_enc == "bf16" else args.quant_text_enc 184 | ), 185 | quant_ae=args.quant_ae, 186 | offload_flow=args.offload_flow, 187 | offload_ae=args.offload_ae, 188 | offload_text_enc=args.offload_text_enc, 189 | prequantized_flow=args.prequantized_flow, 190 | quantize_modulation=args.quantize_modulation, 191 | quantize_flow_embedder_layers=args.quantize_flow_embedder_layers, 192 | ) 193 | app.state.model = FluxPipeline.load_pipeline_from_config(config) 194 | 195 | uvicorn.run(app, host=args.host, port=args.port) 196 | 197 | 198 | if __name__ == "__main__": 199 | main() 200 | -------------------------------------------------------------------------------- /main_gr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from flux_pipeline import FluxPipeline 4 | import gradio as gr # type: ignore 5 | from PIL import Image 6 | 7 | 8 | def create_demo( 9 | config_path: str, 10 | ): 11 | generator = FluxPipeline.load_pipeline_from_config_path(config_path) 12 | 13 | def generate_image( 14 | prompt, 15 | width, 16 | height, 17 | num_steps, 18 | guidance, 19 | seed, 20 | init_image, 21 | image2image_strength, 22 | add_sampling_metadata, 23 | ): 24 | 25 | seed = int(seed) 26 | if seed == -1: 27 | seed = None 28 | out = generator.generate( 29 | prompt, 30 | width, 31 | height, 32 | num_steps=num_steps, 33 | guidance=guidance, 34 | seed=seed, 35 | init_image=init_image, 36 | strength=image2image_strength, 37 | silent=False, 38 | num_images=1, 39 | return_seed=True, 40 | ) 41 | image_bytes = out[0] 42 | return Image.open(image_bytes), str(out[1]), None 43 | 44 | is_schnell = generator.config.version == "flux-schnell" 45 | 46 | with gr.Blocks() as demo: 47 | gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}") 48 | 49 | with gr.Row(): 50 | with gr.Column(): 51 | prompt = gr.Textbox( 52 | label="Prompt", 53 | value='a photo of a forest with mist swirling around the tree trunks. 
The word "FLUX" is painted over it in big, red brush strokes with visible texture', 54 | ) 55 | do_img2img = gr.Checkbox( 56 | label="Image to Image", value=False, interactive=not is_schnell 57 | ) 58 | init_image = gr.Image(label="Input Image", visible=False) 59 | image2image_strength = gr.Slider( 60 | 0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False 61 | ) 62 | 63 | with gr.Accordion("Advanced Options", open=False): 64 | width = gr.Slider(128, 8192, 1152, step=16, label="Width") 65 | height = gr.Slider(128, 8192, 640, step=16, label="Height") 66 | num_steps = gr.Slider( 67 | 1, 50, 4 if is_schnell else 20, step=1, label="Number of steps" 68 | ) 69 | guidance = gr.Slider( 70 | 1.0, 71 | 10.0, 72 | 3.5, 73 | step=0.1, 74 | label="Guidance", 75 | interactive=not is_schnell, 76 | ) 77 | seed = gr.Textbox(-1, label="Seed (-1 for random)") 78 | add_sampling_metadata = gr.Checkbox( 79 | label="Add sampling parameters to metadata?", value=True 80 | ) 81 | 82 | generate_btn = gr.Button("Generate") 83 | 84 | with gr.Column(min_width="960px"): 85 | output_image = gr.Image(label="Generated Image") 86 | seed_output = gr.Number(label="Used Seed") 87 | warning_text = gr.Textbox(label="Warning", visible=False) 88 | # download_btn = gr.File(label="Download full-resolution") 89 | 90 | def update_img2img(do_img2img): 91 | return { 92 | init_image: gr.update(visible=do_img2img), 93 | image2image_strength: gr.update(visible=do_img2img), 94 | } 95 | 96 | do_img2img.change( 97 | update_img2img, do_img2img, [init_image, image2image_strength] 98 | ) 99 | 100 | generate_btn.click( 101 | fn=generate_image, 102 | inputs=[ 103 | prompt, 104 | width, 105 | height, 106 | num_steps, 107 | guidance, 108 | seed, 109 | init_image, 110 | image2image_strength, 111 | add_sampling_metadata, 112 | ], 113 | outputs=[output_image, seed_output, warning_text], 114 | ) 115 | 116 | return demo 117 | 118 | 119 | if __name__ == "__main__": 120 | import argparse 121 | 122 | parser = argparse.ArgumentParser(description="Flux") 123 | parser.add_argument( 124 | "--config", type=str, default="configs/config-dev.json", help="Config file path" 125 | ) 126 | parser.add_argument( 127 | "--share", action="store_true", help="Create a public link to your demo" 128 | ) 129 | args = parser.parse_args() 130 | 131 | demo = create_demo(args.config) 132 | demo.launch(share=args.share) 133 | -------------------------------------------------------------------------------- /modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor, nn 4 | from pydantic import BaseModel 5 | 6 | 7 | class AutoEncoderParams(BaseModel): 8 | resolution: int 9 | in_channels: int 10 | ch: int 11 | out_ch: int 12 | ch_mult: list[int] 13 | num_res_blocks: int 14 | z_channels: int 15 | scale_factor: float 16 | shift_factor: float 17 | 18 | 19 | def swish(x: Tensor) -> Tensor: 20 | return x * torch.sigmoid(x) 21 | 22 | 23 | class AttnBlock(nn.Module): 24 | def __init__(self, in_channels: int): 25 | super().__init__() 26 | self.in_channels = in_channels 27 | 28 | self.norm = nn.GroupNorm( 29 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 30 | ) 31 | 32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | 37 | def 
attention(self, h_: Tensor) -> Tensor: 38 | h_ = self.norm(h_) 39 | q = self.q(h_) 40 | k = self.k(h_) 41 | v = self.v(h_) 42 | 43 | b, c, h, w = q.shape 44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 48 | 49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | return x + self.proj_out(self.attention(x)) 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__(self, in_channels: int, out_channels: int): 57 | super().__init__() 58 | self.in_channels = in_channels 59 | out_channels = in_channels if out_channels is None else out_channels 60 | self.out_channels = out_channels 61 | 62 | self.norm1 = nn.GroupNorm( 63 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 64 | ) 65 | self.conv1 = nn.Conv2d( 66 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 67 | ) 68 | self.norm2 = nn.GroupNorm( 69 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True 70 | ) 71 | self.conv2 = nn.Conv2d( 72 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 73 | ) 74 | if self.in_channels != self.out_channels: 75 | self.nin_shortcut = nn.Conv2d( 76 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 77 | ) 78 | 79 | def forward(self, x): 80 | h = x 81 | h = self.norm1(h) 82 | h = swish(h) 83 | h = self.conv1(h) 84 | 85 | h = self.norm2(h) 86 | h = swish(h) 87 | h = self.conv2(h) 88 | 89 | if self.in_channels != self.out_channels: 90 | x = self.nin_shortcut(x) 91 | 92 | return x + h 93 | 94 | 95 | class Downsample(nn.Module): 96 | def __init__(self, in_channels: int): 97 | super().__init__() 98 | # no asymmetric padding in torch conv, must do it ourselves 99 | self.conv = nn.Conv2d( 100 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 101 | ) 102 | 103 | def forward(self, x: Tensor): 104 | pad = (0, 1, 0, 1) 105 | x = nn.functional.pad(x, pad, mode="constant", value=0) 106 | x = self.conv(x) 107 | return x 108 | 109 | 110 | class Upsample(nn.Module): 111 | def __init__(self, in_channels: int): 112 | super().__init__() 113 | self.conv = nn.Conv2d( 114 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 115 | ) 116 | 117 | def forward(self, x: Tensor): 118 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 119 | x = self.conv(x) 120 | return x 121 | 122 | 123 | class Encoder(nn.Module): 124 | def __init__( 125 | self, 126 | resolution: int, 127 | in_channels: int, 128 | ch: int, 129 | ch_mult: list[int], 130 | num_res_blocks: int, 131 | z_channels: int, 132 | ): 133 | super().__init__() 134 | self.ch = ch 135 | self.num_resolutions = len(ch_mult) 136 | self.num_res_blocks = num_res_blocks 137 | self.resolution = resolution 138 | self.in_channels = in_channels 139 | # downsampling 140 | self.conv_in = nn.Conv2d( 141 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 142 | ) 143 | 144 | curr_res = resolution 145 | in_ch_mult = (1,) + tuple(ch_mult) 146 | self.in_ch_mult = in_ch_mult 147 | self.down = nn.ModuleList() 148 | block_in = self.ch 149 | for i_level in range(self.num_resolutions): 150 | block = nn.ModuleList() 151 | attn = nn.ModuleList() 152 | block_in = ch * in_ch_mult[i_level] 153 | block_out = ch * ch_mult[i_level] 154 | for _ in range(self.num_res_blocks): 155 | block.append(ResnetBlock(in_channels=block_in, 
out_channels=block_out)) 156 | block_in = block_out 157 | down = nn.Module() 158 | down.block = block 159 | down.attn = attn 160 | if i_level != self.num_resolutions - 1: 161 | down.downsample = Downsample(block_in) 162 | curr_res = curr_res // 2 163 | self.down.append(down) 164 | 165 | # middle 166 | self.mid = nn.Module() 167 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 168 | self.mid.attn_1 = AttnBlock(block_in) 169 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 170 | 171 | # end 172 | self.norm_out = nn.GroupNorm( 173 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 174 | ) 175 | self.conv_out = nn.Conv2d( 176 | block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 177 | ) 178 | 179 | def forward(self, x: Tensor) -> Tensor: 180 | # downsampling 181 | hs = [self.conv_in(x)] 182 | for i_level in range(self.num_resolutions): 183 | for i_block in range(self.num_res_blocks): 184 | h = self.down[i_level].block[i_block](hs[-1]) 185 | if len(self.down[i_level].attn) > 0: 186 | h = self.down[i_level].attn[i_block](h) 187 | hs.append(h) 188 | if i_level != self.num_resolutions - 1: 189 | hs.append(self.down[i_level].downsample(hs[-1])) 190 | 191 | # middle 192 | h = hs[-1] 193 | h = self.mid.block_1(h) 194 | h = self.mid.attn_1(h) 195 | h = self.mid.block_2(h) 196 | # end 197 | h = self.norm_out(h) 198 | h = swish(h) 199 | h = self.conv_out(h) 200 | return h 201 | 202 | 203 | class Decoder(nn.Module): 204 | def __init__( 205 | self, 206 | ch: int, 207 | out_ch: int, 208 | ch_mult: list[int], 209 | num_res_blocks: int, 210 | in_channels: int, 211 | resolution: int, 212 | z_channels: int, 213 | ): 214 | super().__init__() 215 | self.ch = ch 216 | self.num_resolutions = len(ch_mult) 217 | self.num_res_blocks = num_res_blocks 218 | self.resolution = resolution 219 | self.in_channels = in_channels 220 | self.ffactor = 2 ** (self.num_resolutions - 1) 221 | 222 | # compute in_ch_mult, block_in and curr_res at lowest res 223 | block_in = ch * ch_mult[self.num_resolutions - 1] 224 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 225 | self.z_shape = (1, z_channels, curr_res, curr_res) 226 | 227 | # z to block_in 228 | self.conv_in = nn.Conv2d( 229 | z_channels, block_in, kernel_size=3, stride=1, padding=1 230 | ) 231 | 232 | # middle 233 | self.mid = nn.Module() 234 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 235 | self.mid.attn_1 = AttnBlock(block_in) 236 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 237 | 238 | # upsampling 239 | self.up = nn.ModuleList() 240 | for i_level in reversed(range(self.num_resolutions)): 241 | block = nn.ModuleList() 242 | attn = nn.ModuleList() 243 | block_out = ch * ch_mult[i_level] 244 | for _ in range(self.num_res_blocks + 1): 245 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 246 | block_in = block_out 247 | up = nn.Module() 248 | up.block = block 249 | up.attn = attn 250 | if i_level != 0: 251 | up.upsample = Upsample(block_in) 252 | curr_res = curr_res * 2 253 | self.up.insert(0, up) # prepend to get consistent order 254 | 255 | # end 256 | self.norm_out = nn.GroupNorm( 257 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 258 | ) 259 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 260 | 261 | def forward(self, z: Tensor) -> Tensor: 262 | # z to block_in 263 | h = self.conv_in(z) 264 | 265 | # middle 266 | h = self.mid.block_1(h) 267 | h = 
self.mid.attn_1(h) 268 | h = self.mid.block_2(h) 269 | 270 | # upsampling 271 | for i_level in reversed(range(self.num_resolutions)): 272 | for i_block in range(self.num_res_blocks + 1): 273 | h = self.up[i_level].block[i_block](h) 274 | if len(self.up[i_level].attn) > 0: 275 | h = self.up[i_level].attn[i_block](h) 276 | if i_level != 0: 277 | h = self.up[i_level].upsample(h) 278 | 279 | # end 280 | h = self.norm_out(h) 281 | h = swish(h) 282 | h = self.conv_out(h) 283 | return h 284 | 285 | 286 | class DiagonalGaussian(nn.Module): 287 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 288 | super().__init__() 289 | self.sample = sample 290 | self.chunk_dim = chunk_dim 291 | 292 | def forward(self, z: Tensor) -> Tensor: 293 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 294 | if self.sample: 295 | std = torch.exp(0.5 * logvar) 296 | return mean + std * torch.randn_like(mean) 297 | else: 298 | return mean 299 | 300 | 301 | class AutoEncoder(nn.Module): 302 | def __init__(self, params: AutoEncoderParams): 303 | super().__init__() 304 | self.encoder = Encoder( 305 | resolution=params.resolution, 306 | in_channels=params.in_channels, 307 | ch=params.ch, 308 | ch_mult=params.ch_mult, 309 | num_res_blocks=params.num_res_blocks, 310 | z_channels=params.z_channels, 311 | ) 312 | self.decoder = Decoder( 313 | resolution=params.resolution, 314 | in_channels=params.in_channels, 315 | ch=params.ch, 316 | out_ch=params.out_ch, 317 | ch_mult=params.ch_mult, 318 | num_res_blocks=params.num_res_blocks, 319 | z_channels=params.z_channels, 320 | ) 321 | self.reg = DiagonalGaussian() 322 | 323 | self.scale_factor = params.scale_factor 324 | self.shift_factor = params.shift_factor 325 | 326 | def encode(self, x: Tensor) -> Tensor: 327 | z = self.reg(self.encoder(x)) 328 | z = self.scale_factor * (z - self.shift_factor) 329 | return z 330 | 331 | def decode(self, z: Tensor) -> Tensor: 332 | z = z / self.scale_factor + self.shift_factor 333 | return self.decoder(z) 334 | 335 | def forward(self, x: Tensor) -> Tensor: 336 | return self.decode(self.encode(x)) 337 | -------------------------------------------------------------------------------- /modules/conditioner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from transformers import ( 6 | CLIPTextModel, 7 | CLIPTokenizer, 8 | T5EncoderModel, 9 | T5Tokenizer, 10 | __version__, 11 | ) 12 | from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig 13 | 14 | CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface") 15 | 16 | 17 | def auto_quantization_config( 18 | quantization_dtype: str, 19 | ) -> QuantoConfig | BitsAndBytesConfig: 20 | if quantization_dtype == "qfloat8": 21 | return QuantoConfig(weights="float8") 22 | elif quantization_dtype == "qint4": 23 | return BitsAndBytesConfig( 24 | load_in_4bit=True, 25 | bnb_4bit_compute_dtype=torch.bfloat16, 26 | bnb_4bit_quant_type="nf4", 27 | ) 28 | elif quantization_dtype == "qint8": 29 | return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False) 30 | elif quantization_dtype == "qint2": 31 | return QuantoConfig(weights="int2") 32 | elif quantization_dtype is None or quantization_dtype == "bfloat16": 33 | return None 34 | else: 35 | raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}") 36 | 37 | 38 | class HFEmbedder(nn.Module): 39 | def __init__( 40 | self, 41 | version: str, 42 | max_length: int, 43 | device: torch.device 
| int, 44 | quantization_dtype: str | None = None, 45 | offloading_device: torch.device | int | None = torch.device("cpu"), 46 | is_clip: bool = False, 47 | **hf_kwargs, 48 | ): 49 | super().__init__() 50 | self.offloading_device = ( 51 | offloading_device 52 | if isinstance(offloading_device, torch.device) 53 | else torch.device(offloading_device) 54 | ) 55 | self.device = ( 56 | device if isinstance(device, torch.device) else torch.device(device) 57 | ) 58 | self.is_clip = version.startswith("openai") or is_clip 59 | self.max_length = max_length 60 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 61 | 62 | auto_quant_config = ( 63 | auto_quantization_config(quantization_dtype) 64 | if quantization_dtype is not None 65 | and quantization_dtype != "bfloat16" 66 | and quantization_dtype != "float16" 67 | else None 68 | ) 69 | 70 | # BNB will move to cuda:0 by default if not specified 71 | if isinstance(auto_quant_config, BitsAndBytesConfig): 72 | hf_kwargs["device_map"] = {"": self.device.index} 73 | if auto_quant_config is not None: 74 | hf_kwargs["quantization_config"] = auto_quant_config 75 | 76 | if self.is_clip: 77 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( 78 | version, max_length=max_length 79 | ) 80 | 81 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( 82 | version, 83 | **hf_kwargs, 84 | ) 85 | 86 | else: 87 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( 88 | version, max_length=max_length 89 | ) 90 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( 91 | version, 92 | **hf_kwargs, 93 | ) 94 | 95 | def offload(self): 96 | self.hf_module.to(device=self.offloading_device) 97 | torch.cuda.empty_cache() 98 | 99 | def cuda(self): 100 | self.hf_module.to(device=self.device) 101 | 102 | def forward(self, text: list[str]) -> Tensor: 103 | batch_encoding = self.tokenizer( 104 | text, 105 | truncation=True, 106 | max_length=self.max_length, 107 | return_length=False, 108 | return_overflowing_tokens=False, 109 | padding="max_length", 110 | return_tensors="pt", 111 | ) 112 | outputs = self.hf_module( 113 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 114 | attention_mask=None, 115 | output_hidden_states=False, 116 | ) 117 | return outputs[self.output_key] 118 | 119 | 120 | if __name__ == "__main__": 121 | model = HFEmbedder( 122 | "city96/t5-v1_1-xxl-encoder-bf16", 123 | max_length=512, 124 | device=0, 125 | quantization_dtype="qfloat8", 126 | ) 127 | o = model(["hello"]) 128 | print(o) 129 | -------------------------------------------------------------------------------- /modules/flux_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | from typing import TYPE_CHECKING, List 4 | 5 | import torch 6 | from loguru import logger 7 | 8 | if TYPE_CHECKING: 9 | from lora_loading import LoraWeights 10 | from util import ModelSpec 11 | DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1" 12 | torch.backends.cuda.matmul.allow_tf32 = True 13 | torch.backends.cudnn.allow_tf32 = True 14 | torch.backends.cudnn.benchmark = True 15 | torch.backends.cudnn.benchmark_limit = 20 16 | torch.set_float32_matmul_precision("high") 17 | import math 18 | 19 | from pydantic import BaseModel 20 | from torch import Tensor, nn 21 | from torch.nn import functional as F 22 | 23 | 24 | class FluxParams(BaseModel): 25 | in_channels: int 26 | vec_in_dim: int 27 | context_in_dim: int 28 | hidden_size: int 29 | 
mlp_ratio: float 30 | num_heads: int 31 | depth: int 32 | depth_single_blocks: int 33 | axes_dim: list[int] 34 | theta: int 35 | qkv_bias: bool 36 | guidance_embed: bool 37 | 38 | 39 | # attention is always same shape each time it's called per H*W, so compile with fullgraph 40 | # @torch.compile(mode="reduce-overhead", fullgraph=True, disable=DISABLE_COMPILE) 41 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 42 | q, k = apply_rope(q, k, pe) 43 | x = F.scaled_dot_product_attention(q, k, v).transpose(1, 2) 44 | x = x.reshape(*x.shape[:-2], -1) 45 | return x 46 | 47 | 48 | # @torch.compile(mode="reduce-overhead", disable=DISABLE_COMPILE) 49 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 50 | scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim 51 | omega = 1.0 / (theta**scale) 52 | out = torch.einsum("...n,d->...nd", pos, omega) 53 | out = torch.stack( 54 | [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 55 | ) 56 | out = out.reshape(*out.shape[:-1], 2, 2) 57 | return out 58 | 59 | 60 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 61 | xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) 62 | xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) 63 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 64 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 65 | return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape) 66 | 67 | 68 | class EmbedND(nn.Module): 69 | def __init__( 70 | self, 71 | dim: int, 72 | theta: int, 73 | axes_dim: list[int], 74 | dtype: torch.dtype = torch.bfloat16, 75 | ): 76 | super().__init__() 77 | self.dim = dim 78 | self.theta = theta 79 | self.axes_dim = axes_dim 80 | self.dtype = dtype 81 | 82 | def forward(self, ids: Tensor) -> Tensor: 83 | n_axes = ids.shape[-1] 84 | emb = torch.cat( 85 | [ 86 | rope(ids[..., i], self.axes_dim[i], self.theta).type(self.dtype) 87 | for i in range(n_axes) 88 | ], 89 | dim=-3, 90 | ) 91 | 92 | return emb.unsqueeze(1) 93 | 94 | 95 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 96 | """ 97 | Create sinusoidal timestep embeddings. 98 | :param t: a 1-D Tensor of N indices, one per batch element. 99 | These may be fractional. 100 | :param dim: the dimension of the output. 101 | :param max_period: controls the minimum frequency of the embeddings. 102 | :return: an (N, D) Tensor of positional embeddings. 
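
    Illustrative shapes (not part of the original source):

        t = torch.rand(8)                     # one fractional timestep per batch element
        emb = timestep_embedding(t, dim=256)  # -> Tensor of shape (8, 256)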
103 | """ 104 | t = time_factor * t 105 | half = dim // 2 106 | freqs = torch.exp( 107 | -math.log(max_period) 108 | * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) 109 | / half 110 | ) 111 | 112 | args = t[:, None].float() * freqs[None] 113 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 114 | if dim % 2: 115 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 116 | return embedding 117 | 118 | 119 | class MLPEmbedder(nn.Module): 120 | def __init__( 121 | self, in_dim: int, hidden_dim: int, prequantized: bool = False, quantized=False 122 | ): 123 | from float8_quantize import F8Linear 124 | 125 | super().__init__() 126 | self.in_layer = ( 127 | nn.Linear(in_dim, hidden_dim, bias=True) 128 | if not prequantized 129 | else ( 130 | F8Linear( 131 | in_features=in_dim, 132 | out_features=hidden_dim, 133 | bias=True, 134 | ) 135 | if quantized 136 | else nn.Linear(in_dim, hidden_dim, bias=True) 137 | ) 138 | ) 139 | self.silu = nn.SiLU() 140 | self.out_layer = ( 141 | nn.Linear(hidden_dim, hidden_dim, bias=True) 142 | if not prequantized 143 | else ( 144 | F8Linear( 145 | in_features=hidden_dim, 146 | out_features=hidden_dim, 147 | bias=True, 148 | ) 149 | if quantized 150 | else nn.Linear(hidden_dim, hidden_dim, bias=True) 151 | ) 152 | ) 153 | 154 | def forward(self, x: Tensor) -> Tensor: 155 | return self.out_layer(self.silu(self.in_layer(x))) 156 | 157 | 158 | class RMSNorm(torch.nn.Module): 159 | def __init__(self, dim: int): 160 | super().__init__() 161 | self.scale = nn.Parameter(torch.ones(dim)) 162 | 163 | def forward(self, x: Tensor): 164 | return F.rms_norm(x.float(), self.scale.shape, self.scale, eps=1e-6).to(x) 165 | 166 | 167 | class QKNorm(torch.nn.Module): 168 | def __init__(self, dim: int): 169 | super().__init__() 170 | self.query_norm = RMSNorm(dim) 171 | self.key_norm = RMSNorm(dim) 172 | 173 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 174 | q = self.query_norm(q) 175 | k = self.key_norm(k) 176 | return q, k 177 | 178 | 179 | class SelfAttention(nn.Module): 180 | def __init__( 181 | self, 182 | dim: int, 183 | num_heads: int = 8, 184 | qkv_bias: bool = False, 185 | prequantized: bool = False, 186 | ): 187 | super().__init__() 188 | from float8_quantize import F8Linear 189 | 190 | self.num_heads = num_heads 191 | head_dim = dim // num_heads 192 | 193 | self.qkv = ( 194 | nn.Linear(dim, dim * 3, bias=qkv_bias) 195 | if not prequantized 196 | else F8Linear( 197 | in_features=dim, 198 | out_features=dim * 3, 199 | bias=qkv_bias, 200 | ) 201 | ) 202 | self.norm = QKNorm(head_dim) 203 | self.proj = ( 204 | nn.Linear(dim, dim) 205 | if not prequantized 206 | else F8Linear( 207 | in_features=dim, 208 | out_features=dim, 209 | bias=True, 210 | ) 211 | ) 212 | self.K = 3 213 | self.H = self.num_heads 214 | self.KH = self.K * self.H 215 | 216 | def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: 217 | B, L, D = x.shape 218 | q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4) 219 | return q, k, v 220 | 221 | def forward(self, x: Tensor, pe: Tensor) -> Tensor: 222 | qkv = self.qkv(x) 223 | q, k, v = self.rearrange_for_norm(qkv) 224 | q, k = self.norm(q, k, v) 225 | x = attention(q, k, v, pe=pe) 226 | x = self.proj(x) 227 | return x 228 | 229 | 230 | ModulationOut = namedtuple("ModulationOut", ["shift", "scale", "gate"]) 231 | 232 | 233 | class Modulation(nn.Module): 234 | def __init__(self, dim: int, double: bool, 
quantized_modulation: bool = False): 235 | super().__init__() 236 | from float8_quantize import F8Linear 237 | 238 | self.is_double = double 239 | self.multiplier = 6 if double else 3 240 | self.lin = ( 241 | nn.Linear(dim, self.multiplier * dim, bias=True) 242 | if not quantized_modulation 243 | else F8Linear( 244 | in_features=dim, 245 | out_features=self.multiplier * dim, 246 | bias=True, 247 | ) 248 | ) 249 | self.act = nn.SiLU() 250 | 251 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 252 | out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1) 253 | 254 | return ( 255 | ModulationOut(*out[:3]), 256 | ModulationOut(*out[3:]) if self.is_double else None, 257 | ) 258 | 259 | 260 | class DoubleStreamBlock(nn.Module): 261 | def __init__( 262 | self, 263 | hidden_size: int, 264 | num_heads: int, 265 | mlp_ratio: float, 266 | qkv_bias: bool = False, 267 | dtype: torch.dtype = torch.float16, 268 | quantized_modulation: bool = False, 269 | prequantized: bool = False, 270 | ): 271 | super().__init__() 272 | from float8_quantize import F8Linear 273 | 274 | self.dtype = dtype 275 | 276 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 277 | self.num_heads = num_heads 278 | self.hidden_size = hidden_size 279 | self.img_mod = Modulation( 280 | hidden_size, double=True, quantized_modulation=quantized_modulation 281 | ) 282 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 283 | self.img_attn = SelfAttention( 284 | dim=hidden_size, 285 | num_heads=num_heads, 286 | qkv_bias=qkv_bias, 287 | prequantized=prequantized, 288 | ) 289 | 290 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 291 | self.img_mlp = nn.Sequential( 292 | ( 293 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True) 294 | if not prequantized 295 | else F8Linear( 296 | in_features=hidden_size, 297 | out_features=mlp_hidden_dim, 298 | bias=True, 299 | ) 300 | ), 301 | nn.GELU(approximate="tanh"), 302 | ( 303 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True) 304 | if not prequantized 305 | else F8Linear( 306 | in_features=mlp_hidden_dim, 307 | out_features=hidden_size, 308 | bias=True, 309 | ) 310 | ), 311 | ) 312 | 313 | self.txt_mod = Modulation( 314 | hidden_size, double=True, quantized_modulation=quantized_modulation 315 | ) 316 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 317 | self.txt_attn = SelfAttention( 318 | dim=hidden_size, 319 | num_heads=num_heads, 320 | qkv_bias=qkv_bias, 321 | prequantized=prequantized, 322 | ) 323 | 324 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 325 | self.txt_mlp = nn.Sequential( 326 | ( 327 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True) 328 | if not prequantized 329 | else F8Linear( 330 | in_features=hidden_size, 331 | out_features=mlp_hidden_dim, 332 | bias=True, 333 | ) 334 | ), 335 | nn.GELU(approximate="tanh"), 336 | ( 337 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True) 338 | if not prequantized 339 | else F8Linear( 340 | in_features=mlp_hidden_dim, 341 | out_features=hidden_size, 342 | bias=True, 343 | ) 344 | ), 345 | ) 346 | self.K = 3 347 | self.H = self.num_heads 348 | self.KH = self.K * self.H 349 | self.do_clamp = dtype == torch.float16 350 | 351 | def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]: 352 | B, L, D = x.shape 353 | q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4) 354 | return q, k, v 355 | 356 | def forward( 357 | self, 358 | img: 
Tensor, 359 | txt: Tensor, 360 | vec: Tensor, 361 | pe: Tensor, 362 | ) -> tuple[Tensor, Tensor]: 363 | img_mod1, img_mod2 = self.img_mod(vec) 364 | txt_mod1, txt_mod2 = self.txt_mod(vec) 365 | 366 | # prepare image for attention 367 | img_modulated = self.img_norm1(img) 368 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 369 | img_qkv = self.img_attn.qkv(img_modulated) 370 | img_q, img_k, img_v = self.rearrange_for_norm(img_qkv) 371 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 372 | 373 | # prepare txt for attention 374 | txt_modulated = self.txt_norm1(txt) 375 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 376 | txt_qkv = self.txt_attn.qkv(txt_modulated) 377 | txt_q, txt_k, txt_v = self.rearrange_for_norm(txt_qkv) 378 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 379 | 380 | q = torch.cat((txt_q, img_q), dim=2) 381 | k = torch.cat((txt_k, img_k), dim=2) 382 | v = torch.cat((txt_v, img_v), dim=2) 383 | 384 | attn = attention(q, k, v, pe=pe) 385 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] 386 | # calculate the img bloks 387 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 388 | img = img + img_mod2.gate * self.img_mlp( 389 | (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift 390 | ) 391 | 392 | # calculate the txt bloks 393 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 394 | txt = txt + txt_mod2.gate * self.txt_mlp( 395 | (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift 396 | ) 397 | if self.do_clamp: 398 | img = img.clamp(min=-32000, max=32000) 399 | txt = txt.clamp(min=-32000, max=32000) 400 | return img, txt 401 | 402 | 403 | class SingleStreamBlock(nn.Module): 404 | """ 405 | A DiT block with parallel linear layers as described in 406 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
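
    Data flow (summarizing the forward pass below): the modulated input is projected by
    `linear1` into qkv (3 * hidden_size) plus an MLP stream (mlp_hidden_dim); the attention
    output and the activated MLP stream are concatenated and projected back to hidden_size
    by `linear2`, then added to the input through a gated residual connection.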
407 | """ 408 | 409 | def __init__( 410 | self, 411 | hidden_size: int, 412 | num_heads: int, 413 | mlp_ratio: float = 4.0, 414 | qk_scale: float | None = None, 415 | dtype: torch.dtype = torch.float16, 416 | quantized_modulation: bool = False, 417 | prequantized: bool = False, 418 | ): 419 | super().__init__() 420 | from float8_quantize import F8Linear 421 | 422 | self.dtype = dtype 423 | self.hidden_dim = hidden_size 424 | self.num_heads = num_heads 425 | head_dim = hidden_size // num_heads 426 | self.scale = qk_scale or head_dim**-0.5 427 | 428 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 429 | # qkv and mlp_in 430 | self.linear1 = ( 431 | nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 432 | if not prequantized 433 | else F8Linear( 434 | in_features=hidden_size, 435 | out_features=hidden_size * 3 + self.mlp_hidden_dim, 436 | bias=True, 437 | ) 438 | ) 439 | # proj and mlp_out 440 | self.linear2 = ( 441 | nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 442 | if not prequantized 443 | else F8Linear( 444 | in_features=hidden_size + self.mlp_hidden_dim, 445 | out_features=hidden_size, 446 | bias=True, 447 | ) 448 | ) 449 | 450 | self.norm = QKNorm(head_dim) 451 | 452 | self.hidden_size = hidden_size 453 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 454 | 455 | self.mlp_act = nn.GELU(approximate="tanh") 456 | self.modulation = Modulation( 457 | hidden_size, 458 | double=False, 459 | quantized_modulation=quantized_modulation and prequantized, 460 | ) 461 | 462 | self.K = 3 463 | self.H = self.num_heads 464 | self.KH = self.K * self.H 465 | self.do_clamp = dtype == torch.float16 466 | 467 | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 468 | mod = self.modulation(vec)[0] 469 | pre_norm = self.pre_norm(x) 470 | x_mod = (1 + mod.scale) * pre_norm + mod.shift 471 | qkv, mlp = torch.split( 472 | self.linear1(x_mod), 473 | [3 * self.hidden_size, self.mlp_hidden_dim], 474 | dim=-1, 475 | ) 476 | B, L, D = qkv.shape 477 | q, k, v = qkv.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4) 478 | q, k = self.norm(q, k, v) 479 | attn = attention(q, k, v, pe=pe) 480 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 481 | if self.do_clamp: 482 | out = (x + mod.gate * output).clamp(min=-32000, max=32000) 483 | else: 484 | out = x + mod.gate * output 485 | return out 486 | 487 | 488 | class LastLayer(nn.Module): 489 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 490 | super().__init__() 491 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 492 | self.linear = nn.Linear( 493 | hidden_size, patch_size * patch_size * out_channels, bias=True 494 | ) 495 | self.adaLN_modulation = nn.Sequential( 496 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 497 | ) 498 | 499 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 500 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 501 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 502 | x = self.linear(x) 503 | return x 504 | 505 | 506 | class Flux(nn.Module): 507 | """ 508 | Transformer model for flow matching on sequences. 
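    Double-stream blocks process the image and text token sequences as two
    streams; the text tokens are then prepended and the single-stream blocks
    run over the combined sequence. Conditioning comes from the timestep
    embedding, the optional guidance embedding, and the `y` vector embedding
    via the modulation layers.

    Minimal construction sketch (illustrative; assumes the bundled config's
    ckpt_path points at a locally available checkpoint):

        model = Flux.from_pretrained("configs/config-dev.json", dtype=torch.float16)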
509 | """ 510 | 511 | def __init__(self, config: "ModelSpec", dtype: torch.dtype = torch.float16): 512 | super().__init__() 513 | 514 | self.dtype = dtype 515 | self.params = config.params 516 | self.in_channels = config.params.in_channels 517 | self.out_channels = self.in_channels 518 | self.loras: List[LoraWeights] = [] 519 | prequantized_flow = config.prequantized_flow 520 | quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow 521 | quantized_modulation = config.quantize_modulation and prequantized_flow 522 | from float8_quantize import F8Linear 523 | 524 | if config.params.hidden_size % config.params.num_heads != 0: 525 | raise ValueError( 526 | f"Hidden size {config.params.hidden_size} must be divisible by num_heads {config.params.num_heads}" 527 | ) 528 | pe_dim = config.params.hidden_size // config.params.num_heads 529 | if sum(config.params.axes_dim) != pe_dim: 530 | raise ValueError( 531 | f"Got {config.params.axes_dim} but expected positional dim {pe_dim}" 532 | ) 533 | self.hidden_size = config.params.hidden_size 534 | self.num_heads = config.params.num_heads 535 | self.pe_embedder = EmbedND( 536 | dim=pe_dim, 537 | theta=config.params.theta, 538 | axes_dim=config.params.axes_dim, 539 | dtype=self.dtype, 540 | ) 541 | self.img_in = ( 542 | nn.Linear(self.in_channels, self.hidden_size, bias=True) 543 | if not prequantized_flow 544 | else ( 545 | F8Linear( 546 | in_features=self.in_channels, 547 | out_features=self.hidden_size, 548 | bias=True, 549 | ) 550 | if quantized_embedders 551 | else nn.Linear(self.in_channels, self.hidden_size, bias=True) 552 | ) 553 | ) 554 | self.time_in = MLPEmbedder( 555 | in_dim=256, 556 | hidden_dim=self.hidden_size, 557 | prequantized=prequantized_flow, 558 | quantized=quantized_embedders, 559 | ) 560 | self.vector_in = MLPEmbedder( 561 | config.params.vec_in_dim, 562 | self.hidden_size, 563 | prequantized=prequantized_flow, 564 | quantized=quantized_embedders, 565 | ) 566 | self.guidance_in = ( 567 | MLPEmbedder( 568 | in_dim=256, 569 | hidden_dim=self.hidden_size, 570 | prequantized=prequantized_flow, 571 | quantized=quantized_embedders, 572 | ) 573 | if config.params.guidance_embed 574 | else nn.Identity() 575 | ) 576 | self.txt_in = ( 577 | nn.Linear(config.params.context_in_dim, self.hidden_size) 578 | if not quantized_embedders 579 | else ( 580 | F8Linear( 581 | in_features=config.params.context_in_dim, 582 | out_features=self.hidden_size, 583 | bias=True, 584 | ) 585 | if quantized_embedders 586 | else nn.Linear(config.params.context_in_dim, self.hidden_size) 587 | ) 588 | ) 589 | 590 | self.double_blocks = nn.ModuleList( 591 | [ 592 | DoubleStreamBlock( 593 | self.hidden_size, 594 | self.num_heads, 595 | mlp_ratio=config.params.mlp_ratio, 596 | qkv_bias=config.params.qkv_bias, 597 | dtype=self.dtype, 598 | quantized_modulation=quantized_modulation, 599 | prequantized=prequantized_flow, 600 | ) 601 | for _ in range(config.params.depth) 602 | ] 603 | ) 604 | 605 | self.single_blocks = nn.ModuleList( 606 | [ 607 | SingleStreamBlock( 608 | self.hidden_size, 609 | self.num_heads, 610 | mlp_ratio=config.params.mlp_ratio, 611 | dtype=self.dtype, 612 | quantized_modulation=quantized_modulation, 613 | prequantized=prequantized_flow, 614 | ) 615 | for _ in range(config.params.depth_single_blocks) 616 | ] 617 | ) 618 | 619 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 620 | 621 | def get_lora(self, identifier: str): 622 | for lora in self.loras: 623 | if lora.path == identifier or lora.name == 
identifier: 624 | return lora 625 | 626 | def has_lora(self, identifier: str): 627 | for lora in self.loras: 628 | if lora.path == identifier or lora.name == identifier: 629 | return True 630 | 631 | def load_lora(self, path: str, scale: float, name: str = None): 632 | from lora_loading import ( 633 | LoraWeights, 634 | apply_lora_to_model, 635 | remove_lora_from_module, 636 | ) 637 | 638 | if self.has_lora(path): 639 | lora = self.get_lora(path) 640 | if lora.scale == scale: 641 | logger.warning( 642 | f"Lora {lora.name} already loaded with same scale - ignoring!" 643 | ) 644 | else: 645 | remove_lora_from_module(self, lora, lora.scale) 646 | apply_lora_to_model(self, lora, scale) 647 | for idx, lora_ in enumerate(self.loras): 648 | if lora_.path == lora.path: 649 | self.loras[idx].scale = scale 650 | break 651 | else: 652 | _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True) 653 | self.loras.append(LoraWeights(lora, path, name, scale)) 654 | 655 | def unload_lora(self, path_or_identifier: str): 656 | from lora_loading import remove_lora_from_module 657 | 658 | removed = False 659 | for idx, lora_ in enumerate(list(self.loras)): 660 | if lora_.path == path_or_identifier or lora_.name == path_or_identifier: 661 | remove_lora_from_module(self, lora_.weights, lora_.scale) 662 | self.loras.pop(idx) 663 | removed = True 664 | break 665 | if not removed: 666 | logger.warning( 667 | f"Couldn't remove lora {path_or_identifier} as it wasn't found fused to the model!" 668 | ) 669 | else: 670 | logger.info("Successfully removed lora from module.") 671 | 672 | def forward( 673 | self, 674 | img: Tensor, 675 | img_ids: Tensor, 676 | txt: Tensor, 677 | txt_ids: Tensor, 678 | timesteps: Tensor, 679 | y: Tensor, 680 | guidance: Tensor | None = None, 681 | ) -> Tensor: 682 | if img.ndim != 3 or txt.ndim != 3: 683 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 684 | 685 | # running on sequences img 686 | img = self.img_in(img) 687 | vec = self.time_in(timestep_embedding(timesteps, 256).type(self.dtype)) 688 | 689 | if self.params.guidance_embed: 690 | if guidance is None: 691 | raise ValueError( 692 | "Didn't get guidance strength for guidance distilled model." 693 | ) 694 | vec = vec + self.guidance_in( 695 | timestep_embedding(guidance, 256).type(self.dtype) 696 | ) 697 | vec = vec + self.vector_in(y) 698 | 699 | txt = self.txt_in(txt) 700 | 701 | ids = torch.cat((txt_ids, img_ids), dim=1) 702 | pe = self.pe_embedder(ids) 703 | 704 | # double stream blocks 705 | for block in self.double_blocks: 706 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 707 | 708 | img = torch.cat((txt, img), 1) 709 | 710 | # single stream blocks 711 | for block in self.single_blocks: 712 | img = block(img, vec=vec, pe=pe) 713 | 714 | img = img[:, txt.shape[1] :, ...] 
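        # the text tokens prepended before the single-stream blocks were sliced
        # off above, so only image tokens reach the final projection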
715 |         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
716 |         return img
717 | 
718 |     @classmethod
719 |     def from_pretrained(
720 |         cls: "Flux", path: str, dtype: torch.dtype = torch.float16
721 |     ) -> "Flux":
722 |         from safetensors.torch import load_file
723 | 
724 |         from util import load_config_from_path
725 | 
726 |         config = load_config_from_path(path)
727 |         with torch.device("meta"):  # params are created without storage; load_state_dict(assign=True) supplies the real tensors
728 |             klass = cls(config=config, dtype=dtype)
729 |             if not config.prequantized_flow:
730 |                 klass.type(dtype)
731 | 
732 |         ckpt = load_file(config.ckpt_path, device="cpu")
733 |         klass.load_state_dict(ckpt, assign=True)
734 |         return klass.to("cpu")
735 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/aredden/torch-cublas-hgemm.git@master
2 | einops
3 | PyTurboJPEG
4 | pydantic
5 | fastapi
6 | bitsandbytes
7 | loguru
8 | transformers
9 | tokenizers
10 | sentencepiece
11 | click
12 | accelerate
13 | quanto
14 | pydash
15 | pybase64
16 | uvicorn
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Literal, Optional
4 | 
5 | import torch
6 | from modules.autoencoder import AutoEncoder, AutoEncoderParams
7 | from modules.conditioner import HFEmbedder
8 | from modules.flux_model import Flux, FluxParams
9 | from safetensors.torch import load_file as load_sft
10 | 
11 | try:
12 |     from enum import StrEnum
13 | except ImportError:  # StrEnum is only available on Python 3.11+
14 |     from enum import Enum
15 | 
16 |     class StrEnum(str, Enum):
17 |         pass
18 | 
19 | 
20 | from pydantic import BaseModel, ConfigDict
21 | from loguru import logger
22 | 
23 | 
24 | class ModelVersion(StrEnum):
25 |     flux_dev = "flux-dev"
26 |     flux_schnell = "flux-schnell"
27 | 
28 | 
29 | class QuantizationDtype(StrEnum):
30 |     qfloat8 = "qfloat8"
31 |     qint2 = "qint2"
32 |     qint4 = "qint4"
33 |     qint8 = "qint8"
34 |     bfloat16 = "bfloat16"
35 |     float16 = "float16"
36 | 
37 | 
38 | class ModelSpec(BaseModel):
39 |     version: ModelVersion
40 |     params: FluxParams
41 |     ae_params: AutoEncoderParams
42 |     ckpt_path: str | None
43 |     # Add option to pass in custom clip model
44 |     clip_path: str | None = "openai/clip-vit-large-patch14"
45 |     ae_path: str | None
46 |     repo_id: str | None
47 |     repo_flow: str | None
48 |     repo_ae: str | None
49 |     text_enc_max_length: int = 512
50 |     text_enc_path: str | None
51 |     text_enc_device: str | torch.device | None = "cuda:0"
52 |     ae_device: str | torch.device | None = "cuda:0"
53 |     flux_device: str | torch.device | None = "cuda:0"
54 |     flow_dtype: str = "float16"
55 |     ae_dtype: str = "bfloat16"
56 |     text_enc_dtype: str = "bfloat16"
57 |     # unused / deprecated
58 |     num_to_quant: Optional[int] = 20
59 |     quantize_extras: bool = False
60 |     compile_extras: bool = False
61 |     compile_blocks: bool = False
62 |     flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
63 |     text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8
64 |     ae_quantization_dtype: Optional[QuantizationDtype] = None
65 |     clip_quantization_dtype: Optional[QuantizationDtype] = None
66 |     offload_text_encoder: bool = False
67 |     offload_vae: bool = False
68 |     offload_flow: bool = False
69 |     prequantized_flow: bool = False
70 | 
71 |     # Improved precision by not quantizing the modulation linear layers
72 |     quantize_modulation: bool = True
73 |     # Improved precision by not quantizing the flow embedder layers
74 |     quantize_flow_embedder_layers: bool = False
75 | 
76 |     model_config: ConfigDict = {
77 |         "arbitrary_types_allowed": True,
78 |         "use_enum_values": True,
79 |     }
80 | 
81 | 
82 | def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
83 |     flow = load_flow_model(config)
84 |     ae = load_autoencoder(config)
85 |     clip, t5 = load_text_encoders(config)
86 |     return flow, ae, clip, t5
87 | 
88 | 
89 | def parse_device(device: str | torch.device | None) -> torch.device:
90 |     if isinstance(device, str):
91 |         return torch.device(device)
92 |     elif isinstance(device, torch.device):
93 |         return device
94 |     else:
95 |         return torch.device("cuda:0")
96 | 
97 | 
98 | def into_dtype(dtype: str | torch.dtype) -> torch.dtype:
99 |     if isinstance(dtype, torch.dtype):
100 |         return dtype
101 |     if dtype == "float16":
102 |         return torch.float16
103 |     elif dtype == "bfloat16":
104 |         return torch.bfloat16
105 |     elif dtype == "float32":
106 |         return torch.float32
107 |     else:
108 |         raise ValueError(f"Invalid dtype: {dtype}")
109 | 
110 | 
111 | def into_device(device: str | torch.device | int | None) -> torch.device:
112 |     if isinstance(device, str):
113 |         return torch.device(device)
114 |     elif isinstance(device, torch.device):
115 |         return device
116 |     elif isinstance(device, int):
117 |         return torch.device(f"cuda:{device}")
118 |     else:
119 |         return torch.device("cuda:0")
120 | 
121 | 
122 | def load_config(
123 |     name: ModelVersion = ModelVersion.flux_dev,
124 |     flux_path: str | None = None,
125 |     ae_path: str | None = None,
126 |     text_enc_path: str | None = None,
127 |     text_enc_device: str | torch.device | None = None,
128 |     ae_device: str | torch.device | None = None,
129 |     flux_device: str | torch.device | None = None,
130 |     flow_dtype: str = "float16",
131 |     ae_dtype: str = "bfloat16",
132 |     text_enc_dtype: str = "bfloat16",
133 |     num_to_quant: Optional[int] = 20,
134 |     compile_extras: bool = False,
135 |     compile_blocks: bool = False,
136 |     offload_text_enc: bool = False,
137 |     offload_ae: bool = False,
138 |     offload_flow: bool = False,
139 |     quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
140 |     quant_ae: bool = False,
141 |     prequantized_flow: bool = False,
142 |     quantize_modulation: bool = True,
143 |     quantize_flow_embedder_layers: bool = False,
144 | ) -> ModelSpec:
145 |     """
146 |     Load a model configuration using the passed arguments.
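
    Minimal sketch (paths are placeholders, not repository defaults; every
    other option keeps the default shown in the signature):

        config = load_config(
            ModelVersion.flux_dev,
            flux_path="/path/to/flux1-dev.sft",
            ae_path="/path/to/ae.sft",
            text_enc_path="/path/to/t5-encoder",
        )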
147 | """ 148 | text_enc_device = str(parse_device(text_enc_device)) 149 | ae_device = str(parse_device(ae_device)) 150 | flux_device = str(parse_device(flux_device)) 151 | return ModelSpec( 152 | version=name, 153 | repo_id=( 154 | "black-forest-labs/FLUX.1-dev" 155 | if name == ModelVersion.flux_dev 156 | else "black-forest-labs/FLUX.1-schnell" 157 | ), 158 | repo_flow=( 159 | "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft" 160 | ), 161 | repo_ae="ae.sft", 162 | ckpt_path=flux_path, 163 | params=FluxParams( 164 | in_channels=64, 165 | vec_in_dim=768, 166 | context_in_dim=4096, 167 | hidden_size=3072, 168 | mlp_ratio=4.0, 169 | num_heads=24, 170 | depth=19, 171 | depth_single_blocks=38, 172 | axes_dim=[16, 56, 56], 173 | theta=10_000, 174 | qkv_bias=True, 175 | guidance_embed=name == ModelVersion.flux_dev, 176 | ), 177 | ae_path=ae_path, 178 | ae_params=AutoEncoderParams( 179 | resolution=256, 180 | in_channels=3, 181 | ch=128, 182 | out_ch=3, 183 | ch_mult=[1, 2, 4, 4], 184 | num_res_blocks=2, 185 | z_channels=16, 186 | scale_factor=0.3611, 187 | shift_factor=0.1159, 188 | ), 189 | text_enc_path=text_enc_path, 190 | text_enc_device=text_enc_device, 191 | ae_device=ae_device, 192 | flux_device=flux_device, 193 | flow_dtype=flow_dtype, 194 | ae_dtype=ae_dtype, 195 | text_enc_dtype=text_enc_dtype, 196 | text_enc_max_length=512 if name == ModelVersion.flux_dev else 256, 197 | num_to_quant=num_to_quant, 198 | compile_extras=compile_extras, 199 | compile_blocks=compile_blocks, 200 | offload_flow=offload_flow, 201 | offload_text_encoder=offload_text_enc, 202 | offload_vae=offload_ae, 203 | text_enc_quantization_dtype={ 204 | "float8": QuantizationDtype.qfloat8, 205 | "qint2": QuantizationDtype.qint2, 206 | "qint4": QuantizationDtype.qint4, 207 | "qint8": QuantizationDtype.qint8, 208 | }.get(quant_text_enc, None), 209 | ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None, 210 | prequantized_flow=prequantized_flow, 211 | quantize_modulation=quantize_modulation, 212 | quantize_flow_embedder_layers=quantize_flow_embedder_layers, 213 | ) 214 | 215 | 216 | def load_config_from_path(path: str) -> ModelSpec: 217 | path_path = Path(path) 218 | if not path_path.exists(): 219 | raise ValueError(f"Path {path} does not exist") 220 | if not path_path.is_file(): 221 | raise ValueError(f"Path {path} is not a file") 222 | return ModelSpec(**json.loads(path_path.read_text())) 223 | 224 | 225 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 226 | if len(missing) > 0 and len(unexpected) > 0: 227 | logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 228 | logger.warning("\n" + "-" * 79 + "\n") 229 | logger.warning( 230 | f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected) 231 | ) 232 | elif len(missing) > 0: 233 | logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 234 | elif len(unexpected) > 0: 235 | logger.warning( 236 | f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected) 237 | ) 238 | 239 | 240 | def load_flow_model(config: ModelSpec) -> Flux: 241 | ckpt_path = config.ckpt_path 242 | FluxClass = Flux 243 | 244 | with torch.device("meta"): 245 | model = FluxClass(config, dtype=into_dtype(config.flow_dtype)) 246 | if not config.prequantized_flow: 247 | model.type(into_dtype(config.flow_dtype)) 248 | 249 | if ckpt_path is not None: 250 | # load_sft doesn't support torch.device 251 | sd = load_sft(ckpt_path, device="cpu") 252 | missing, unexpected = 
model.load_state_dict(sd, strict=False, assign=True) 253 | print_load_warning(missing, unexpected) 254 | if not config.prequantized_flow: 255 | model.type(into_dtype(config.flow_dtype)) 256 | return model 257 | 258 | 259 | def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]: 260 | clip = HFEmbedder( 261 | config.clip_path, 262 | max_length=77, 263 | torch_dtype=into_dtype(config.text_enc_dtype), 264 | device=into_device(config.text_enc_device).index or 0, 265 | is_clip=True, 266 | quantization_dtype=config.clip_quantization_dtype, 267 | ) 268 | t5 = HFEmbedder( 269 | config.text_enc_path, 270 | max_length=config.text_enc_max_length, 271 | torch_dtype=into_dtype(config.text_enc_dtype), 272 | device=into_device(config.text_enc_device).index or 0, 273 | quantization_dtype=config.text_enc_quantization_dtype, 274 | ) 275 | return clip, t5 276 | 277 | 278 | def load_autoencoder(config: ModelSpec) -> AutoEncoder: 279 | ckpt_path = config.ae_path 280 | with torch.device("meta" if ckpt_path is not None else config.ae_device): 281 | ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype)) 282 | 283 | if ckpt_path is not None: 284 | sd = load_sft(ckpt_path, device=str(config.ae_device)) 285 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 286 | print_load_warning(missing, unexpected) 287 | ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype)) 288 | if config.ae_quantization_dtype is not None: 289 | from float8_quantize import recursive_swap_linears 290 | 291 | recursive_swap_linears(ae) 292 | if config.offload_vae: 293 | ae.to("cpu") 294 | torch.cuda.empty_cache() 295 | return ae 296 | 297 | 298 | class LoadedModels(BaseModel): 299 | flow: Flux 300 | ae: AutoEncoder 301 | clip: HFEmbedder 302 | t5: HFEmbedder 303 | config: ModelSpec 304 | 305 | model_config = { 306 | "arbitrary_types_allowed": True, 307 | "use_enum_values": True, 308 | } 309 | 310 | 311 | def load_models_from_config_path( 312 | path: str, 313 | ) -> LoadedModels: 314 | config = load_config_from_path(path) 315 | clip, t5 = load_text_encoders(config) 316 | return LoadedModels( 317 | flow=load_flow_model(config), 318 | ae=load_autoencoder(config), 319 | clip=clip, 320 | t5=t5, 321 | config=config, 322 | ) 323 | 324 | 325 | def load_models_from_config(config: ModelSpec) -> LoadedModels: 326 | clip, t5 = load_text_encoders(config) 327 | return LoadedModels( 328 | flow=load_flow_model(config), 329 | ae=load_autoencoder(config), 330 | clip=clip, 331 | t5=t5, 332 | config=config, 333 | ) 334 | --------------------------------------------------------------------------------
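A minimal sketch (not part of the repository) of driving the loaders in util.py
end to end, assuming configs/config-dev.json has been edited so that ckpt_path,
ae_path and text_enc_path point at locally available weights:

    from util import load_models_from_config_path

    # build a ModelSpec from the JSON config and materialize flow/ae/clip/t5
    models = load_models_from_config_path("configs/config-dev.json")
    flow, ae, clip, t5 = models.flow, models.ae, models.clip, models.t5
    print(type(flow).__name__, models.config.version)  # e.g. Flux flux-dev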