├── .dockerignore ├── .github └── workflows │ ├── ci.yaml │ └── push.yaml ├── .gitignore ├── LICENSE ├── README.md ├── bfl_predictor.py ├── cog.yaml.template ├── diffusers_predictor.py ├── feature-extractor └── preprocessor_config.json ├── flux ├── __init__.py ├── __main__.py ├── cli.py ├── math.py ├── model.py ├── modules │ ├── autoencoder.py │ ├── conditioner.py │ ├── image_embedders.py │ ├── layers.py │ └── quantize.py ├── sampling.py └── util.py ├── fp8 ├── __init__.py ├── configs │ ├── config-1-flux-dev-fp8-h100.json │ ├── config-1-flux-dev-h100.json │ ├── config-1-flux-schnell-fp8-h100.json │ └── config-1-flux-schnell-h100.json ├── float8_quantize.py ├── flux_pipeline.py ├── image_encoder.py ├── lora_loading.py ├── modules │ ├── autoencoder.py │ ├── conditioner.py │ └── flux_model.py └── util.py ├── lora_loading_patch.py ├── model-cog-configs ├── canny-dev.yaml ├── depth-dev.yaml ├── dev-lora.yaml ├── dev.yaml ├── fill-dev.yaml ├── hotswap-lora.yaml ├── redux-dev.yaml ├── redux-schnell.yaml ├── schnell-lora.yaml ├── schnell.yaml └── test.yaml ├── predict.py ├── ruff.toml ├── safe-push-configs ├── canny-dev.yaml ├── depth-dev.yaml ├── dev-lora.yaml ├── dev.yaml ├── fill-dev.yaml ├── hotswap-lora.yaml ├── redux-dev.yaml ├── redux-schnell.yaml ├── schnell-lora.yaml └── schnell.yaml ├── samples.py ├── save_fp8_quantized.py ├── script ├── prod-deploy-all.sh ├── push.sh ├── select.sh └── update-schema.sh └── weights.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | /model-cache 20 | /safety-cache 21 | /falcon-cache 22 | 23 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.11" 19 | 20 | - name: Install ruff 21 | run: | 22 | pip install ruff 23 | 24 | - name: Run ruff linter 25 | run: | 26 | ruff check 27 | 28 | - name: Run ruff formatter 29 | run: | 30 | ruff format --diff 31 | -------------------------------------------------------------------------------- /.github/workflows/push.yaml: -------------------------------------------------------------------------------- 1 | name: Push Models 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | models: 7 | description: 'Comma-separated list of models to push (schnell,dev,fill-dev,canny-dev,depth-dev,redux-dev,redux-schnell,schnell-lora,dev-lora,hotswap-lora) or "all"' 8 | type: string 9 | default: 'all' 10 | 11 | jobs: 12 | prepare-matrix: 13 | runs-on: ubuntu-latest 14 | outputs: 15 | matrix: ${{ steps.set-matrix.outputs.matrix }} 16 | steps: 17 | - name: Install jq 18 | run: sudo apt-get update && sudo apt-get install -y jq 19 | 20 | - id: set-matrix 21 | run: | 22 | if [ "${{ inputs.models }}" = "all" ]; then 23 | echo 
"matrix={\"model\":[\"schnell\",\"dev\",\"fill-dev\",\"canny-dev\",\"depth-dev\",\"redux-dev\",\"redux-schnell\",\"schnell-lora\",\"dev-lora\",\"hotswap-lora\"]}" >> $GITHUB_OUTPUT 24 | else 25 | # Convert comma-separated string to JSON array 26 | MODELS=$(echo "${{ inputs.models }}" | jq -R -s -c 'split(",")') 27 | echo "matrix={\"model\":$MODELS}" >> $GITHUB_OUTPUT 28 | fi 29 | 30 | cog-safe-push: 31 | # runs-on: ubuntu-latest-4-cores 32 | needs: prepare-matrix 33 | runs-on: depot-ubuntu-22.04-4 34 | strategy: 35 | matrix: ${{fromJson(needs.prepare-matrix.outputs.matrix)}} 36 | fail-fast: false # Continue with other models if one fails 37 | 38 | steps: 39 | - uses: actions/checkout@v3 40 | 41 | - name: Set up Python 42 | uses: actions/setup-python@v4 43 | with: 44 | python-version: "3.12" 45 | 46 | - name: Install Cog 47 | run: | 48 | sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/download/v0.15.8/cog_$(uname -s)_$(uname -m)" 49 | sudo chmod +x /usr/local/bin/cog 50 | 51 | - name: cog login 52 | run: | 53 | echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin 54 | 55 | - name: Install cog-safe-push 56 | run: | 57 | pip install git+https://github.com/replicate/cog-safe-push.git 58 | 59 | - name: Push selected models 60 | env: 61 | ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} 62 | REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} 63 | run: | 64 | echo "===" 65 | echo "===" 66 | echo "=== Pushing ${{ matrix.model }}" 67 | echo "===" 68 | echo "===" 69 | ./script/select.sh ${{ matrix.model }} 70 | cog-safe-push -vv 71 | if [ "${{ matrix.model }}" != "hotswap-lora" ]; then 72 | cog push r8.im/black-forest-labs/flux-${{ matrix.model }} 73 | fi -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .cog 2 | __pycache__ 3 | wandb 4 | ft* 5 | *.ipynb 6 | output* 7 | training_out* 8 | trained_model.tar 9 | __*.zip 10 | **-cache 11 | model-cache 12 | falcon-cache 13 | 14 | cog.yaml 15 | *.png 16 | *.webp 17 | *.jpg 18 | cog-safe-push.yaml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023, Replicate, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cog-flux 2 | 3 | This is a [Cog](https://cog.run) inference model for FLUX.1 [schnell] and FLUX.1 [dev] by [Black Forest Labs](https://blackforestlabs.ai/). It powers the following Replicate models: 4 | 5 | * https://replicate.com/black-forest-labs/flux-schnell 6 | * https://replicate.com/black-forest-labs/flux-dev 7 | 8 | ## Features 9 | 10 | * Compilation with `torch.compile` 11 | * Optional fp8 quantization based on [aredden/flux-fp8-api](https://github.com/aredden/flux-fp8-api), using fast CuDNN attention from Pytorch nightlies 12 | * NSFW checking with [CompVis](https://huggingface.co/CompVis/stable-diffusion-safety-checker) and [Falcons.ai](https://huggingface.co/Falconsai/nsfw_image_detection) safety checkers 13 | * img2img support 14 | 15 | ## Getting started 16 | 17 | If you just want to use the models, you can run [FLUX.1 [schnell]](https://replicate.com/black-forest-labs/flux-schnell) and [FLUX.1 [dev]](https://replicate.com/black-forest-labs/flux-dev) on Replicate with an API or in the browser. 18 | 19 | The code in this repo can be used as a template for customizations on FLUX.1, or to run the models on your own hardware. 20 | 21 | First you need to select which model to run: 22 | 23 | ```shell 24 | script/select.sh {dev,schnell} 25 | ``` 26 | 27 | Then you can run a single prediction on the model using: 28 | 29 | ```shell 30 | cog predict -i prompt="a cat in a hat" 31 | ``` 32 | 33 | The [Cog getting started guide](https://cog.run/getting-started/) explains what Cog is and how it works. 34 | 35 | To deploy it to Replicate, run: 36 | 37 | ```shell 38 | cog login 39 | cog push r8.im// 40 | ``` 41 | 42 | Learn more on [the deploy a custom model guide in the Replicate documentation](https://replicate.com/docs/guides/deploy-a-custom-model). 43 | 44 | ## Contributing 45 | 46 | Pull requests and issues are welcome! If you see a novel technique or feature you think will make FLUX.1 inference better or faster, let us know and we'll do our best to integrate it. 47 | 48 | ## Rough, partial roadmap 49 | 50 | * Serialize quantized model instead of quantizing on the fly 51 | * Use row-wise quantization 52 | * Port quantization and compilation code over to https://github.com/replicate/flux-fine-tuner 53 | 54 | ## License 55 | 56 | The code in this repository is licensed under the [Apache-2.0 License](LICENSE). 57 | 58 | FLUX.1 [dev] falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 59 | 60 | FLUX.1 [schnell] falls under the [Apache-2.0 License](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md). 
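
## Example: calling the hosted models from Python

If you just want to hit the hosted models from a script, the [Replicate Python client](https://github.com/replicate/replicate-python) can call them directly. This is a minimal sketch (not part of this repo's code), assuming `pip install replicate` and a `REPLICATE_API_TOKEN` in your environment:

```python
# Minimal sketch: run the hosted flux-schnell model via the Replicate Python client.
# Assumes the `replicate` package is installed and REPLICATE_API_TOKEN is exported.
import replicate

output = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "a cat in a hat"},
)

# The call returns the model's generated image(s); printing each item shows its URL.
for item in output:
    print(item)
```

The same pattern should work for `black-forest-labs/flux-dev`.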
61 | -------------------------------------------------------------------------------- /cog.yaml.template: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | cuda: "12.6" 8 | 9 | python_version: "3.11" 10 | 11 | python_packages: 12 | - "numpy<2" 13 | - "einops==0.8.0" 14 | - "fire==0.6.0" 15 | - "huggingface-hub==0.25.0" 16 | - "safetensors==0.4.3" 17 | - "sentencepiece==0.2.0" 18 | - "transformers==4.43.3" 19 | - "tokenizers==0.19.1" 20 | - "protobuf==5.27.2" 21 | - "diffusers==0.32.2" 22 | - "loguru==0.7.2" 23 | - "pybase64==1.4.0" 24 | - "pydash==8.0.3" 25 | - "opencv-python-headless==4.10.0.84" 26 | - "torch==2.7.0" 27 | - "torchvision==0.22" 28 | - "redis==6.0.0" 29 | 30 | 31 | # commands run after the environment is setup 32 | run: 33 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 34 | -------------------------------------------------------------------------------- /diffusers_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | from dataclasses import dataclass 5 | from typing import List 6 | 7 | import numpy as np 8 | import torch 9 | import logging 10 | from PIL import Image 11 | from pathlib import Path 12 | from diffusers.pipelines import ( 13 | FluxPipeline, 14 | FluxInpaintPipeline, 15 | FluxImg2ImgPipeline, 16 | ) 17 | 18 | from weights import WeightsDownloadCache 19 | 20 | from lora_loading_patch import load_lora_into_transformer 21 | 22 | MODEL_URL_DEV = ( 23 | "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/files.tar" 24 | ) 25 | MODEL_URL_SCHNELL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-schnell/slim.tar" 26 | 27 | FLUX_DEV_PATH = "./model-cache/FLUX.1-dev" 28 | FLUX_SCHNELL_PATH = "./model-cache/FLUX.1-schnell" 29 | MODEL_CACHE = "./model-cache/" 30 | 31 | MAX_IMAGE_SIZE = 1440 32 | 33 | 34 | @dataclass 35 | class FluxConfig: 36 | url: str 37 | path: str 38 | download_path: str # this only exists b/c flux-dev needs a different donwload_path from "path" based on how we're storing weights. 
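    # The two fields below are per-model sampling defaults; see CONFIGS just under this
    # dataclass (schnell: 4 steps / 256-token max sequence length, dev: 28 steps / 512).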
39 | num_steps: int 40 | max_sequence_length: int 41 | 42 | 43 | CONFIGS = { 44 | "flux-schnell": FluxConfig( 45 | MODEL_URL_SCHNELL, FLUX_SCHNELL_PATH, FLUX_SCHNELL_PATH, 4, 256 46 | ), 47 | "flux-dev": FluxConfig(MODEL_URL_DEV, FLUX_DEV_PATH, MODEL_CACHE, 28, 512), 48 | } 49 | 50 | # Suppress diffusers nsfw warnings 51 | logging.getLogger("diffusers").setLevel(logging.CRITICAL) 52 | logging.getLogger("transformers").setLevel(logging.CRITICAL) 53 | 54 | 55 | @dataclass 56 | class LoadedLoRAs: 57 | main: str | None 58 | extra: str | None 59 | 60 | 61 | from diffusers import AutoencoderKL 62 | from transformers import CLIPTextModel, T5EncoderModel, CLIPTokenizer, T5TokenizerFast 63 | 64 | 65 | @dataclass 66 | class ModelHolster: 67 | vae: AutoencoderKL 68 | text_encoder: CLIPTextModel 69 | text_encoder_2: T5EncoderModel 70 | tokenizer: CLIPTokenizer 71 | tokenizer_2: T5TokenizerFast 72 | 73 | 74 | class DiffusersFlux: 75 | """ 76 | Wrapper to map diffusers flux pipeline to the methods we need to serve these models in predict.py 77 | """ 78 | 79 | def __init__( 80 | self, 81 | model_name: str, 82 | weights_cache: WeightsDownloadCache, 83 | shared_models: ModelHolster | None = None, 84 | ) -> None: # pyright: ignore 85 | """Load the model into memory to make running multiple predictions efficient""" 86 | start = time.time() 87 | 88 | # Don't pull weights 89 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 90 | 91 | config = CONFIGS[model_name] 92 | model_path = config.path 93 | 94 | self.default_num_steps = config.num_steps 95 | self.max_sequence_length = config.max_sequence_length 96 | 97 | # dependency injection hell yeah it's java time baybee 98 | self.weights_cache = weights_cache 99 | 100 | if not os.path.exists(model_path): # noqa: PTH110 101 | print("Model path not found, downloading models") 102 | # TODO: download everything separately; it will suck less. 
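            # download_base_weights (defined at the bottom of this file) shells out to
            # `pget -xf`, which fetches the weights tarball and extracts it into the dest path.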
103 | download_base_weights(config.url, config.download_path) 104 | 105 | print("Loading pipeline") 106 | if shared_models: 107 | txt2img_pipe = FluxPipeline.from_pretrained( 108 | model_path, 109 | vae=shared_models.vae, 110 | text_encoder=shared_models.text_encoder, 111 | text_encoder_2=shared_models.text_encoder_2, 112 | tokenizer=shared_models.tokenizer, 113 | tokenizer_2=shared_models.tokenizer_2, 114 | torch_dtype=torch.bfloat16, 115 | ).to("cuda") 116 | else: 117 | txt2img_pipe = FluxPipeline.from_pretrained( 118 | model_path, 119 | torch_dtype=torch.bfloat16, 120 | ).to("cuda") 121 | txt2img_pipe.__class__.load_lora_into_transformer = classmethod( 122 | load_lora_into_transformer 123 | ) 124 | self.txt2img_pipe = txt2img_pipe 125 | 126 | # Load img2img pipelines 127 | img2img_pipe = FluxImg2ImgPipeline( 128 | transformer=txt2img_pipe.transformer, 129 | scheduler=txt2img_pipe.scheduler, 130 | vae=txt2img_pipe.vae, 131 | text_encoder=txt2img_pipe.text_encoder, 132 | text_encoder_2=txt2img_pipe.text_encoder_2, 133 | tokenizer=txt2img_pipe.tokenizer, 134 | tokenizer_2=txt2img_pipe.tokenizer_2, 135 | ).to("cuda") 136 | img2img_pipe.__class__.load_lora_into_transformer = classmethod( 137 | load_lora_into_transformer 138 | ) 139 | 140 | self.img2img_pipe = img2img_pipe 141 | 142 | # Load inpainting pipelines 143 | inpaint_pipe = FluxInpaintPipeline( 144 | transformer=txt2img_pipe.transformer, 145 | scheduler=txt2img_pipe.scheduler, 146 | vae=txt2img_pipe.vae, 147 | text_encoder=txt2img_pipe.text_encoder, 148 | text_encoder_2=txt2img_pipe.text_encoder_2, 149 | tokenizer=txt2img_pipe.tokenizer, 150 | tokenizer_2=txt2img_pipe.tokenizer_2, 151 | ).to("cuda") 152 | inpaint_pipe.__class__.load_lora_into_transformer = classmethod( 153 | load_lora_into_transformer 154 | ) 155 | 156 | self.inpaint_pipe = inpaint_pipe 157 | 158 | self.loaded_lora_urls = LoadedLoRAs(main=None, extra=None) 159 | self.lora_scale = 1.0 160 | print("setup took: ", time.time() - start) 161 | 162 | def get_models(self): 163 | return ModelHolster( 164 | vae=self.txt2img_pipe.vae, 165 | text_encoder=self.txt2img_pipe.text_encoder, 166 | text_encoder_2=self.txt2img_pipe.text_encoder_2, 167 | tokenizer=self.txt2img_pipe.tokenizer, 168 | tokenizer_2=self.txt2img_pipe.tokenizer_2, 169 | ) 170 | 171 | def handle_loras(self, lora_weights, lora_scale, extra_lora, extra_lora_scale): 172 | # all pipes share the same weights, can do this to any of them 173 | pipe = self.txt2img_pipe 174 | 175 | if lora_weights: 176 | start_time = time.time() 177 | if extra_lora: 178 | self.lora_scale = 1.0 179 | self.load_multiple_loras(lora_weights, extra_lora) 180 | pipe.set_adapters( 181 | ["main", "extra"], adapter_weights=[lora_scale, extra_lora_scale] 182 | ) 183 | else: 184 | self.load_single_lora(lora_weights) 185 | pipe.set_adapters(["main"], adapter_weights=[lora_scale]) 186 | self.lora_scale = lora_scale 187 | print(f"Loaded LoRAs in {time.time() - start_time:.2f}s") 188 | else: 189 | pipe.unload_lora_weights() 190 | self.loaded_lora_urls = LoadedLoRAs(main=None, extra=None) 191 | self.lora_scale = 1.0 192 | 193 | @torch.inference_mode() 194 | def predict( # pyright: ignore 195 | self, 196 | prompt: str, 197 | num_outputs: int = 1, 198 | num_inference_steps: int | None = None, 199 | legacy_image_path: Path = None, 200 | legacy_mask_path: Path = None, 201 | width: int | None = None, 202 | height: int | None = None, 203 | guidance: float = 3.5, 204 | prompt_strength: float = 0.8, 205 | seed: int | None = None, 206 | ) -> List[Path]: 207 | 
"""Run a single prediction on the model""" 208 | if seed is None or seed < 0: 209 | seed = int.from_bytes(os.urandom(2), "big") 210 | print(f"Using seed: {seed}") 211 | 212 | is_img2img_mode = legacy_image_path is not None and legacy_mask_path is None 213 | is_inpaint_mode = legacy_image_path is not None and legacy_mask_path is not None 214 | 215 | flux_kwargs = {} 216 | print(f"Prompt: {prompt}") 217 | 218 | if is_img2img_mode or is_inpaint_mode: 219 | input_image = Image.open(legacy_image_path).convert("RGB") 220 | original_width, original_height = input_image.size 221 | 222 | width = original_width 223 | height = original_height 224 | print(f"Input image size: {width}x{height}") 225 | 226 | # Calculate the scaling factor if the image exceeds max_image_size 227 | scale = min(MAX_IMAGE_SIZE / width, MAX_IMAGE_SIZE / height, 1) 228 | if scale < 1: 229 | width = int(width * scale) 230 | height = int(height * scale) 231 | 232 | # Calculate dimensions that are multiples of 16 233 | target_width = make_multiple_of_16(width) 234 | target_height = make_multiple_of_16(height) 235 | target_size = (target_width, target_height) 236 | 237 | print( 238 | f"[!] Resizing input image from {original_width}x{original_height} to {target_width}x{target_height}" 239 | ) 240 | 241 | # We're using highest quality settings; if you want to go fast, you're not running this code. 242 | input_image = input_image.resize(target_size, Image.LANCZOS) 243 | flux_kwargs["image"] = input_image 244 | 245 | # Set width and height to match the resized input image 246 | flux_kwargs["width"], flux_kwargs["height"] = target_size 247 | 248 | if is_img2img_mode: 249 | print("[!] img2img mode") 250 | pipe = self.img2img_pipe 251 | else: # is_inpaint_mode 252 | print("[!] inpaint mode") 253 | mask_image = Image.open(legacy_mask_path).convert("RGB") 254 | mask_image = mask_image.resize(target_size, Image.NEAREST) 255 | flux_kwargs["mask_image"] = mask_image 256 | pipe = self.inpaint_pipe 257 | 258 | flux_kwargs["strength"] = prompt_strength 259 | 260 | else: # is_txt2img_mode 261 | print("[!] txt2img mode") 262 | pipe = self.txt2img_pipe 263 | flux_kwargs["width"] = width 264 | flux_kwargs["height"] = height 265 | 266 | max_sequence_length = self.max_sequence_length 267 | 268 | generator = torch.Generator(device="cuda").manual_seed(seed) 269 | 270 | if self.loaded_lora_urls.main is not None: 271 | # this sets lora scale for prompt encoding, weirdly enough. it does not actually do anything with attention processing anymore. 
272 | flux_kwargs["joint_attention_kwargs"] = {"scale": self.lora_scale} 273 | 274 | common_args = { 275 | "prompt": [prompt] * num_outputs, 276 | "guidance_scale": guidance, 277 | "generator": generator, 278 | "num_inference_steps": num_inference_steps 279 | if num_inference_steps 280 | else self.default_num_steps, 281 | "max_sequence_length": max_sequence_length, 282 | "output_type": "pil", 283 | } 284 | 285 | output = pipe(**common_args, **flux_kwargs) 286 | 287 | return output.images, [np.array(img) for img in output.images] 288 | 289 | def load_single_lora(self, lora_url: str): 290 | # If no change, skip 291 | if lora_url == self.loaded_lora_urls.main: 292 | print("Weights already loaded") 293 | return 294 | 295 | pipe = self.txt2img_pipe 296 | pipe.unload_lora_weights() 297 | lora_path = self.weights_cache.ensure(lora_url) 298 | pipe.load_lora_weights(lora_path, adapter_name="main") 299 | self.loaded_lora_urls = LoadedLoRAs(main=lora_url, extra=None) 300 | pipe = pipe.to("cuda") 301 | 302 | def load_multiple_loras(self, main_lora_url: str, extra_lora_url: str): 303 | pipe = self.txt2img_pipe 304 | 305 | # If no change, skip 306 | if ( 307 | main_lora_url == self.loaded_lora_urls.main 308 | and extra_lora_url == self.loaded_lora_urls.extra 309 | ): 310 | print("Weights already loaded") 311 | return 312 | 313 | # We always need to load both? 314 | pipe.unload_lora_weights() 315 | 316 | main_lora_path = self.weights_cache.ensure(main_lora_url) 317 | pipe.load_lora_weights(main_lora_path, adapter_name="main") 318 | 319 | extra_lora_path = self.weights_cache.ensure(extra_lora_url) 320 | pipe.load_lora_weights(extra_lora_path, adapter_name="extra") 321 | 322 | self.loaded_lora_urls = LoadedLoRAs(main=main_lora_url, extra=extra_lora_url) 323 | pipe = pipe.to("cuda") 324 | 325 | 326 | def download_base_weights(url: str, dest: Path): 327 | start = time.time() 328 | print("downloading url: ", url) 329 | print("downloading to: ", dest) 330 | subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) 331 | print("downloading took: ", time.time() - start) 332 | 333 | 334 | def make_multiple_of_16(n): 335 | # Rounds up to the next multiple of 16, or returns n if already a multiple of 16 336 | return ((n + 15) // 16) * 16 337 | -------------------------------------------------------------------------------- /feature-extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 | "do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } -------------------------------------------------------------------------------- /flux/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import version as __version__ # type: ignore 3 | from ._version import version_tuple 4 | except ImportError: 5 | __version__ = "unknown (no version information available)" 6 | version_tuple = (0, 0, "unknown", "noinfo") 7 | 8 | from pathlib import Path 9 | 10 | PACKAGE = __package__.replace("_", "-") 11 | PACKAGE_ROOT = Path(__file__).parent 12 | -------------------------------------------------------------------------------- /flux/__main__.py: 
-------------------------------------------------------------------------------- 1 | from .cli import app 2 | 3 | if __name__ == "__main__": 4 | app() 5 | -------------------------------------------------------------------------------- /flux/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from dataclasses import dataclass 4 | from glob import glob 5 | 6 | import torch 7 | from einops import rearrange 8 | from fire import Fire 9 | from PIL import Image 10 | 11 | from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack 12 | from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5 13 | 14 | 15 | @dataclass 16 | class SamplingOptions: 17 | prompt: str 18 | width: int 19 | height: int 20 | num_steps: int 21 | guidance: float 22 | seed: int | None 23 | 24 | 25 | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: 26 | # TODO: document 27 | while (prompt := input("Next prompt: ")).startswith("/"): 28 | if prompt.startswith("/w"): 29 | if prompt.count(" ") != 1: 30 | print(f"Got invalid command '{prompt}'") 31 | continue 32 | _, width = prompt.split() 33 | options.width = 16 * (int(width) // 16) 34 | print(f"Setting width to {options.width}") 35 | elif prompt.startswith("/h"): 36 | if prompt.count(" ") != 1: 37 | print(f"Got invalid command '{prompt}'") 38 | continue 39 | _, height = prompt.split() 40 | options.height = 16 * (int(height) // 16) 41 | print(f"Setting width to {options.height}") 42 | elif prompt.startswith("/g"): 43 | if prompt.count(" ") != 1: 44 | print(f"Got invalid command '{prompt}'") 45 | continue 46 | _, guidance = prompt.split() 47 | options.guidance = float(guidance) 48 | print(f"Setting guidance to {options.guidance}") 49 | elif prompt.startswith("/s"): 50 | if prompt.count(" ") != 1: 51 | print(f"Got invalid command '{prompt}'") 52 | continue 53 | _, seed = prompt.split() 54 | options.seed = int(seed) 55 | print(f"Setting seed to {options.seed}") 56 | elif prompt.startswith("/n"): 57 | if prompt.count(" ") != 1: 58 | print(f"Got invalid command '{prompt}'") 59 | continue 60 | _, steps = prompt.split() 61 | options.num_steps = int(steps) 62 | print(f"Setting seed to {options.num_steps}") 63 | elif prompt.startswith("/q"): 64 | print("Quitting") 65 | return None 66 | else: 67 | print(f"Got invalid command '{prompt}'") 68 | if prompt != "": 69 | options.prompt = prompt 70 | return options 71 | 72 | 73 | @torch.inference_mode() 74 | def main( 75 | name: str = "flux-schnell", 76 | width: int = 1360, 77 | height: int = 768, 78 | quantize_flow: bool = False, 79 | seed: int | None = None, 80 | prompt: str = ( 81 | "a photo of a forest with mist swirling around the tree trunks. The word " 82 | '"FLUX" is painted over it in big, red brush strokes with visible texture' 83 | ), 84 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 85 | num_steps: int | None = None, 86 | loop: bool = False, 87 | guidance: float = 3.5, 88 | offload: bool = False, 89 | output_dir: str = "output", 90 | ): 91 | """ 92 | Sample the flux model. Either interactively (set `--loop`) or run for a 93 | single image. 
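    The CLI is exposed via `python -m flux` (see flux/__main__.py), so for example
    `python -m flux --name flux-schnell --loop` should start an interactive session.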
94 | 95 | Args: 96 | name: Name of the model to load 97 | height: height of the sample in pixels (should be a multiple of 16) 98 | width: width of the sample in pixels (should be a multiple of 16) 99 | quantize_flow: quantizes some layers of the flow model 100 | seed: Set a seed for sampling 101 | output_name: where to save the output image, `{idx}` will be replaced 102 | by the index of the sample 103 | prompt: Prompt used for sampling 104 | device: Pytorch device 105 | num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) 106 | loop: start an interactive session and sample multiple times 107 | guidance: guidance value used for guidance distillation 108 | """ 109 | if name not in configs: 110 | available = ", ".join(configs.keys()) 111 | raise ValueError(f"Got unknown model name: {name}, chose from {available}") 112 | 113 | torch_device = torch.device(device) 114 | if num_steps is None: 115 | num_steps = 4 if name == "flux-schnell" else 50 116 | 117 | # allow for packing and conversion to latent space 118 | height = 16 * (height // 16) 119 | width = 16 * (width // 16) 120 | 121 | output_name = os.path.join(output_dir, "img_{idx}.png") 122 | if not os.path.exists(output_dir): 123 | os.makedirs(output_dir) 124 | idx = 0 125 | elif len(os.listdir(output_dir)) > 0: 126 | fns = glob(output_name.format(idx="*")) 127 | idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 128 | else: 129 | idx = 0 130 | 131 | # init all components 132 | t5 = load_t5(torch_device) 133 | clip = load_clip(torch_device) 134 | model = load_flow_model(name, device="cpu" if offload else torch_device, quantize=quantize_flow) 135 | ae = load_ae(name, device="cpu" if offload else torch_device) 136 | 137 | rng = torch.Generator(device="cpu") 138 | opts = SamplingOptions( 139 | prompt=prompt, 140 | width=width, 141 | height=height, 142 | num_steps=num_steps, 143 | guidance=guidance, 144 | seed=seed, 145 | ) 146 | 147 | if loop: 148 | opts = parse_prompt(opts) 149 | 150 | while opts is not None: 151 | if opts.seed is None: 152 | opts.seed = rng.seed() 153 | print(f"Generating '{opts.prompt}' with seed {opts.seed}") 154 | t0 = time.perf_counter() 155 | 156 | # prepare input 157 | x = get_noise( 158 | 1, 159 | opts.height, 160 | opts.width, 161 | device=torch_device, 162 | dtype=torch.bfloat16, 163 | seed=opts.seed, 164 | ) 165 | opts.seed = None 166 | if offload: 167 | ae = ae.cpu() 168 | torch.cuda.empty_cache() 169 | t5, clip = t5.to(torch_device), clip.to(torch_device) 170 | inp = prepare(t5, clip, x, prompt=opts.prompt) 171 | timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) 172 | 173 | # offload TEs to CPU, load model to gpu 174 | if offload: 175 | t5, clip = t5.cpu(), clip.cpu() 176 | torch.cuda.empty_cache() 177 | model = model.to(torch_device) 178 | 179 | # denoise initial noise 180 | x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) 181 | 182 | # offload model, load autoencoder to gpu 183 | if offload: 184 | model.cpu() 185 | torch.cuda.empty_cache() 186 | ae.decoder.to(x.device) 187 | 188 | # decode latents to pixel space 189 | x = unpack(x.float(), opts.height, opts.width) 190 | with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): 191 | x = ae.decode(x) 192 | t1 = time.perf_counter() 193 | 194 | fn = output_name.format(idx=idx) 195 | print(f"Done in {t1 - t0:.1f}s. 
Saving {fn}") 196 | # bring into PIL format and save 197 | x = rearrange(x[0], "c h w -> h w c").clamp(-1, 1) 198 | img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) 199 | img.save(fn) 200 | idx += 1 201 | 202 | if loop: 203 | print("-" * 80) 204 | opts = parse_prompt(opts) 205 | else: 206 | opts = None 207 | 208 | 209 | def app(): 210 | Fire(main) 211 | 212 | 213 | if __name__ == "__main__": 214 | app() 215 | -------------------------------------------------------------------------------- /flux/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor 4 | from torch.nn.attention import SDPBackend, sdpa_kernel 5 | 6 | 7 | 8 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 9 | q, k = apply_rope(q, k, pe) 10 | # Only enable flash attention backend 11 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION): 12 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 13 | x = rearrange(x, "B H L D -> B L (H D)") 14 | 15 | return x 16 | 17 | 18 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 19 | assert dim % 2 == 0 20 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 21 | omega = 1.0 / (theta**scale) 22 | out = torch.einsum("...n,d->...nd", pos, omega) 23 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 24 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 25 | return out.float() 26 | 27 | 28 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 29 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 30 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 31 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 32 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 33 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 34 | -------------------------------------------------------------------------------- /flux/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from flux.modules.layers import ( 7 | DoubleStreamBlock, 8 | EmbedND, 9 | LastLayer, 10 | MLPEmbedder, 11 | SingleStreamBlock, 12 | timestep_embedding, 13 | ) 14 | 15 | 16 | @dataclass 17 | class FluxParams: 18 | in_channels: int 19 | out_channels: int 20 | vec_in_dim: int 21 | context_in_dim: int 22 | hidden_size: int 23 | mlp_ratio: float 24 | num_heads: int 25 | depth: int 26 | depth_single_blocks: int 27 | axes_dim: list[int] 28 | theta: int 29 | qkv_bias: bool 30 | guidance_embed: bool 31 | 32 | 33 | class Flux(nn.Module): 34 | """ 35 | Transformer model for flow matching on sequences. 
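    Takes packed image tokens (`img`, `img_ids`) and text tokens (`txt`, `txt_ids`), plus a
    timestep, a conditioning vector `y`, and an optional guidance value for guidance-distilled
    checkpoints; see `forward` below.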
36 | """ 37 | 38 | def __init__(self, params: FluxParams): 39 | super().__init__() 40 | 41 | self.params = params 42 | self.in_channels = params.in_channels 43 | self.out_channels = params.out_channels 44 | if params.hidden_size % params.num_heads != 0: 45 | raise ValueError( 46 | f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" 47 | ) 48 | pe_dim = params.hidden_size // params.num_heads 49 | if sum(params.axes_dim) != pe_dim: 50 | raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") 51 | self.hidden_size = params.hidden_size 52 | self.num_heads = params.num_heads 53 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) 54 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 55 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 56 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 57 | self.guidance_in = ( 58 | MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 59 | ) 60 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 61 | 62 | self.double_blocks = nn.ModuleList( 63 | [ 64 | DoubleStreamBlock( 65 | self.hidden_size, 66 | self.num_heads, 67 | mlp_ratio=params.mlp_ratio, 68 | qkv_bias=params.qkv_bias, 69 | ) 70 | for _ in range(params.depth) 71 | ] 72 | ) 73 | 74 | self.single_blocks = nn.ModuleList( 75 | [ 76 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 77 | for _ in range(params.depth_single_blocks) 78 | ] 79 | ) 80 | 81 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 82 | 83 | def forward( 84 | self, 85 | img: Tensor, 86 | img_ids: Tensor, 87 | txt: Tensor, 88 | txt_ids: Tensor, 89 | timesteps: Tensor, 90 | y: Tensor, 91 | guidance: Tensor | None = None, 92 | ) -> Tensor: 93 | if img.ndim != 3 or txt.ndim != 3: 94 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 95 | 96 | # running on sequences img 97 | img = self.img_in(img) 98 | vec = self.time_in(timestep_embedding(timesteps, 256)) 99 | if self.params.guidance_embed: 100 | if guidance is None: 101 | raise ValueError("Didn't get guidance strength for guidance distilled model.") 102 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 103 | vec = vec + self.vector_in(y) 104 | txt = self.txt_in(txt) 105 | 106 | ids = torch.cat((txt_ids, img_ids), dim=1) 107 | pe = self.pe_embedder(ids) 108 | 109 | for block in self.double_blocks: 110 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe) 111 | 112 | img = torch.cat((txt, img), 1) 113 | for block in self.single_blocks: 114 | img = block(img, vec=vec, pe=pe) 115 | img = img[:, txt.shape[1] :, ...] 
116 | 117 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 118 | return img 119 | -------------------------------------------------------------------------------- /flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import Tensor, nn 6 | 7 | 8 | @dataclass 9 | class AutoEncoderParams: 10 | resolution: int 11 | in_channels: int 12 | ch: int 13 | out_ch: int 14 | ch_mult: list[int] 15 | num_res_blocks: int 16 | z_channels: int 17 | scale_factor: float 18 | shift_factor: float 19 | 20 | 21 | def swish(x: Tensor) -> Tensor: 22 | return x * torch.sigmoid(x) 23 | 24 | 25 | class AttnBlock(nn.Module): 26 | def __init__(self, in_channels: int): 27 | super().__init__() 28 | self.in_channels = in_channels 29 | 30 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 31 | 32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | 37 | def attention(self, h_: Tensor) -> Tensor: 38 | h_ = self.norm(h_) 39 | q = self.q(h_) 40 | k = self.k(h_) 41 | v = self.v(h_) 42 | 43 | b, c, h, w = q.shape 44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 48 | 49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | return x + self.proj_out(self.attention(x)) 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__(self, in_channels: int, out_channels: int): 57 | super().__init__() 58 | self.in_channels = in_channels 59 | out_channels = in_channels if out_channels is None else out_channels 60 | self.out_channels = out_channels 61 | 62 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 63 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 64 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 65 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 66 | if self.in_channels != self.out_channels: 67 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 68 | 69 | def forward(self, x): 70 | h = x 71 | h = self.norm1(h) 72 | h = swish(h) 73 | h = self.conv1(h) 74 | 75 | h = self.norm2(h) 76 | h = swish(h) 77 | h = self.conv2(h) 78 | 79 | if self.in_channels != self.out_channels: 80 | x = self.nin_shortcut(x) 81 | 82 | return x + h 83 | 84 | 85 | class Downsample(nn.Module): 86 | def __init__(self, in_channels: int): 87 | super().__init__() 88 | # no asymmetric padding in torch conv, must do it ourselves 89 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 90 | 91 | def forward(self, x: Tensor): 92 | pad = (0, 1, 0, 1) 93 | x = nn.functional.pad(x, pad, mode="constant", value=0) 94 | x = self.conv(x) 95 | return x 96 | 97 | 98 | class Upsample(nn.Module): 99 | def __init__(self, in_channels: int): 100 | super().__init__() 101 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, 
stride=1, padding=1) 102 | 103 | def forward(self, x: Tensor): 104 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 105 | x = self.conv(x) 106 | return x 107 | 108 | 109 | class Encoder(nn.Module): 110 | def __init__( 111 | self, 112 | resolution: int, 113 | in_channels: int, 114 | ch: int, 115 | ch_mult: list[int], 116 | num_res_blocks: int, 117 | z_channels: int, 118 | ): 119 | super().__init__() 120 | self.ch = ch 121 | self.num_resolutions = len(ch_mult) 122 | self.num_res_blocks = num_res_blocks 123 | self.resolution = resolution 124 | self.in_channels = in_channels 125 | # downsampling 126 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 127 | 128 | curr_res = resolution 129 | in_ch_mult = (1,) + tuple(ch_mult) 130 | self.in_ch_mult = in_ch_mult 131 | self.down = nn.ModuleList() 132 | block_in = self.ch 133 | for i_level in range(self.num_resolutions): 134 | block = nn.ModuleList() 135 | attn = nn.ModuleList() 136 | block_in = ch * in_ch_mult[i_level] 137 | block_out = ch * ch_mult[i_level] 138 | for _ in range(self.num_res_blocks): 139 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 140 | block_in = block_out 141 | down = nn.Module() 142 | down.block = block 143 | down.attn = attn 144 | if i_level != self.num_resolutions - 1: 145 | down.downsample = Downsample(block_in) 146 | curr_res = curr_res // 2 147 | self.down.append(down) 148 | 149 | # middle 150 | self.mid = nn.Module() 151 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 152 | self.mid.attn_1 = AttnBlock(block_in) 153 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 154 | 155 | # end 156 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 157 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 158 | 159 | def forward(self, x: Tensor) -> Tensor: 160 | # downsampling 161 | hs = [self.conv_in(x)] 162 | for i_level in range(self.num_resolutions): 163 | for i_block in range(self.num_res_blocks): 164 | h = self.down[i_level].block[i_block](hs[-1]) 165 | if len(self.down[i_level].attn) > 0: 166 | h = self.down[i_level].attn[i_block](h) 167 | hs.append(h) 168 | if i_level != self.num_resolutions - 1: 169 | hs.append(self.down[i_level].downsample(hs[-1])) 170 | 171 | # middle 172 | h = hs[-1] 173 | h = self.mid.block_1(h) 174 | h = self.mid.attn_1(h) 175 | h = self.mid.block_2(h) 176 | # end 177 | h = self.norm_out(h) 178 | h = swish(h) 179 | h = self.conv_out(h) 180 | return h 181 | 182 | 183 | class Decoder(nn.Module): 184 | def __init__( 185 | self, 186 | ch: int, 187 | out_ch: int, 188 | ch_mult: list[int], 189 | num_res_blocks: int, 190 | in_channels: int, 191 | resolution: int, 192 | z_channels: int, 193 | ): 194 | super().__init__() 195 | self.ch = ch 196 | self.num_resolutions = len(ch_mult) 197 | self.num_res_blocks = num_res_blocks 198 | self.resolution = resolution 199 | self.in_channels = in_channels 200 | self.ffactor = 2 ** (self.num_resolutions - 1) 201 | 202 | # compute in_ch_mult, block_in and curr_res at lowest res 203 | block_in = ch * ch_mult[self.num_resolutions - 1] 204 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 205 | self.z_shape = (1, z_channels, curr_res, curr_res) 206 | 207 | # z to block_in 208 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 209 | 210 | # middle 211 | self.mid = nn.Module() 212 | self.mid.block_1 = 
ResnetBlock(in_channels=block_in, out_channels=block_in) 213 | self.mid.attn_1 = AttnBlock(block_in) 214 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 215 | 216 | # upsampling 217 | self.up = nn.ModuleList() 218 | for i_level in reversed(range(self.num_resolutions)): 219 | block = nn.ModuleList() 220 | attn = nn.ModuleList() 221 | block_out = ch * ch_mult[i_level] 222 | for _ in range(self.num_res_blocks + 1): 223 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 224 | block_in = block_out 225 | up = nn.Module() 226 | up.block = block 227 | up.attn = attn 228 | if i_level != 0: 229 | up.upsample = Upsample(block_in) 230 | curr_res = curr_res * 2 231 | self.up.insert(0, up) # prepend to get consistent order 232 | 233 | # end 234 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 235 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 236 | 237 | def forward(self, z: Tensor) -> Tensor: 238 | # z to block_in 239 | h = self.conv_in(z) 240 | 241 | # middle 242 | h = self.mid.block_1(h) 243 | h = self.mid.attn_1(h) 244 | h = self.mid.block_2(h) 245 | 246 | # upsampling 247 | for i_level in reversed(range(self.num_resolutions)): 248 | for i_block in range(self.num_res_blocks + 1): 249 | h = self.up[i_level].block[i_block](h) 250 | if len(self.up[i_level].attn) > 0: 251 | h = self.up[i_level].attn[i_block](h) 252 | if i_level != 0: 253 | h = self.up[i_level].upsample(h) 254 | 255 | # end 256 | h = self.norm_out(h) 257 | h = swish(h) 258 | h = self.conv_out(h) 259 | return h 260 | 261 | 262 | class DiagonalGaussian(nn.Module): 263 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 264 | super().__init__() 265 | self.sample = sample 266 | self.chunk_dim = chunk_dim 267 | 268 | def forward(self, z: Tensor) -> Tensor: 269 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 270 | if self.sample: 271 | std = torch.exp(0.5 * logvar) 272 | return mean + std * torch.randn_like(mean) 273 | else: 274 | return mean 275 | 276 | 277 | class AutoEncoder(nn.Module): 278 | def __init__(self, params: AutoEncoderParams): 279 | super().__init__() 280 | self.encoder = Encoder( 281 | resolution=params.resolution, 282 | in_channels=params.in_channels, 283 | ch=params.ch, 284 | ch_mult=params.ch_mult, 285 | num_res_blocks=params.num_res_blocks, 286 | z_channels=params.z_channels, 287 | ) 288 | self.decoder = Decoder( 289 | resolution=params.resolution, 290 | in_channels=params.in_channels, 291 | ch=params.ch, 292 | out_ch=params.out_ch, 293 | ch_mult=params.ch_mult, 294 | num_res_blocks=params.num_res_blocks, 295 | z_channels=params.z_channels, 296 | ) 297 | self.reg = DiagonalGaussian() 298 | 299 | self.scale_factor = params.scale_factor 300 | self.shift_factor = params.shift_factor 301 | 302 | def encode(self, x: Tensor) -> Tensor: 303 | z = self.reg(self.encoder(x)) 304 | z = self.scale_factor * (z - self.shift_factor) 305 | return z 306 | 307 | def decode(self, z: Tensor) -> Tensor: 308 | z = z / self.scale_factor + self.shift_factor 309 | return self.decoder(z) 310 | 311 | def forward(self, x: Tensor) -> Tensor: 312 | return self.decode(self.encode(x)) 313 | -------------------------------------------------------------------------------- /flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | from transformers import CLIPTextModel, T5EncoderModel, CLIPTokenizer, T5Tokenizer 3 | 4 | 5 | 
class HFEmbedder(nn.Module): 6 | def __init__(self, version: str, max_length: int, is_clip=False, **hf_kwargs): 7 | super().__init__() 8 | self.is_clip = is_clip 9 | self.max_length = max_length 10 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 11 | 12 | if self.is_clip: 13 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version + "/tokenizer", max_length=max_length) 14 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version + "/model", **hf_kwargs) 15 | else: 16 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version + "/tokenizer", max_length=max_length) 17 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version + "/model", **hf_kwargs) 18 | 19 | self.hf_module = self.hf_module.eval().requires_grad_(False) 20 | 21 | def forward(self, text: list[str]) -> Tensor: 22 | batch_encoding = self.tokenizer( 23 | text, 24 | truncation=True, 25 | max_length=self.max_length, 26 | return_length=False, 27 | return_overflowing_tokens=False, 28 | padding="max_length", 29 | return_tensors="pt", 30 | ) 31 | 32 | outputs = self.hf_module( 33 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 34 | attention_mask=None, 35 | output_hidden_states=False, 36 | ) 37 | return outputs[self.output_key] 38 | 39 | 40 | class PreLoadedHFEmbedder(nn.Module): 41 | """ 42 | Does the same thing as the HFEmbedder, but lets you share the tokenizer & hf module. Could also just share the HFEmbedder but here we are. 43 | """ 44 | def __init__(self, is_clip: bool, max_length: int, tokenizer, hf_module): 45 | super().__init__() 46 | self.is_clip = is_clip 47 | self.max_length = max_length 48 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 49 | 50 | self.tokenizer = tokenizer 51 | self.hf_module = hf_module 52 | 53 | self.hf_module = self.hf_module.eval().requires_grad_(False) 54 | 55 | def forward(self, text: list[str]) -> Tensor: 56 | batch_encoding = self.tokenizer( 57 | text, 58 | truncation=True, 59 | max_length=self.max_length, 60 | return_length=False, 61 | return_overflowing_tokens=False, 62 | padding="max_length", 63 | return_tensors="pt", 64 | ) 65 | 66 | outputs = self.hf_module( 67 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 68 | attention_mask=None, 69 | output_hidden_states=False, 70 | ) 71 | return outputs[self.output_key] 72 | -------------------------------------------------------------------------------- /flux/modules/image_embedders.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Protocol 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from einops import rearrange, repeat 8 | from PIL import Image 9 | from safetensors.torch import load_file as load_sft 10 | from torch import nn 11 | from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel 12 | 13 | # hack to avoid circular imports 14 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 15 | if len(missing) > 0 and len(unexpected) > 0: 16 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 17 | print("\n" + "-" * 79 + "\n") 18 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 19 | elif len(missing) > 0: 20 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 21 | elif len(unexpected) > 0: 22 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 23 | 
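# The encoders below turn an input image into the conditioning signal for the corresponding
# FLUX variants: depth and canny operate on image tensors scaled to [-1, 1], while redux
# takes a PIL image and projects SigLIP features into the text-embedding space.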
24 | 25 | class ImageEncoder(Protocol): 26 | def __call__(self, img: torch.Tensor) -> torch.Tensor: ... 27 | 28 | 29 | class DepthImageEncoder(ImageEncoder): 30 | depth_model_name = "LiheYoung/depth-anything-large-hf" 31 | 32 | def __init__(self, device, depth_model_path=depth_model_name): 33 | self.device = device 34 | self.depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_path).to(device) 35 | self.processor = AutoProcessor.from_pretrained(depth_model_path) 36 | 37 | def __call__(self, img: torch.Tensor) -> torch.Tensor: 38 | hw = img.shape[-2:] 39 | 40 | img = torch.clamp(img, -1.0, 1.0) 41 | img_byte = ((img + 1.0) * 127.5).byte() 42 | 43 | img = self.processor(img_byte, return_tensors="pt")["pixel_values"] 44 | depth = self.depth_model(img.to(self.device)).predicted_depth 45 | depth = repeat(depth, "b h w -> b 3 h w") 46 | depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True) 47 | 48 | depth = depth / 127.5 - 1.0 49 | return depth 50 | 51 | 52 | class CannyImageEncoder(ImageEncoder): 53 | def __init__( 54 | self, 55 | device, 56 | min_t: int = 50, 57 | max_t: int = 200, 58 | ): 59 | self.device = device 60 | self.min_t = min_t 61 | self.max_t = max_t 62 | 63 | def __call__(self, img: torch.Tensor) -> torch.Tensor: 64 | assert img.shape[0] == 1, "Only batch size 1 is supported" 65 | 66 | img = rearrange(img[0], "c h w -> h w c") 67 | img = torch.clamp(img, -1.0, 1.0) 68 | img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8) 69 | 70 | # Apply Canny edge detection 71 | canny = cv2.Canny(img_np, self.min_t, self.max_t) 72 | 73 | # Convert back to torch tensor and reshape 74 | canny = torch.from_numpy(canny).float() / 127.5 - 1.0 75 | canny = rearrange(canny, "h w -> 1 1 h w") 76 | canny = repeat(canny, "b 1 ... 
-> b 3 ...") 77 | return canny.to(self.device) 78 | 79 | 80 | class ReduxImageEncoder(nn.Module): 81 | siglip_model_name = "google/siglip-so400m-patch14-384" 82 | 83 | def __init__( 84 | self, 85 | device, 86 | redux_dim: int = 1152, 87 | txt_in_features: int = 4096, 88 | redux_path: str | None = os.getenv("FLUX_REDUX"), 89 | siglip_path: str | None = siglip_model_name, 90 | dtype=torch.bfloat16, 91 | ) -> None: 92 | assert redux_path is not None, "Redux path must be provided" 93 | 94 | super().__init__() 95 | 96 | self.redux_dim = redux_dim 97 | self.device = device if isinstance(device, torch.device) else torch.device(device) 98 | self.dtype = dtype 99 | 100 | with self.device: 101 | self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) 102 | self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) 103 | 104 | sd = load_sft(redux_path, device=str(device)) 105 | missing, unexpected = self.load_state_dict(sd, strict=False, assign=True) 106 | print_load_warning(missing, unexpected) 107 | 108 | self.siglip = SiglipVisionModel.from_pretrained(siglip_path).to(dtype=dtype) 109 | self.normalize = SiglipImageProcessor.from_pretrained(siglip_path) 110 | 111 | 112 | def __call__(self, x: Image.Image) -> torch.Tensor: 113 | imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True) 114 | 115 | _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state 116 | 117 | projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x))) 118 | 119 | return projected_x -------------------------------------------------------------------------------- /flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | from flux.math import attention, rope 9 | 10 | 11 | class EmbedND(nn.Module): 12 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 13 | super().__init__() 14 | self.dim = dim 15 | self.theta = theta 16 | self.axes_dim = axes_dim 17 | 18 | def forward(self, ids: Tensor) -> Tensor: 19 | n_axes = ids.shape[-1] 20 | emb = torch.cat( 21 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 22 | dim=-3, 23 | ) 24 | 25 | return emb.unsqueeze(1) 26 | 27 | 28 | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): 29 | """ 30 | Create sinusoidal timestep embeddings. 31 | :param t: a 1-D Tensor of N indices, one per batch element. 32 | These may be fractional. 33 | :param dim: the dimension of the output. 34 | :param max_period: controls the minimum frequency of the embeddings. 35 | :return: an (N, D) Tensor of positional embeddings. 
36 | """ 37 | t = time_factor * t 38 | half = dim // 2 39 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 40 | t.device 41 | ) 42 | 43 | args = t[:, None].float() * freqs[None] 44 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 45 | if dim % 2: 46 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 47 | if torch.is_floating_point(t): 48 | embedding = embedding.to(t) 49 | return embedding 50 | 51 | 52 | class MLPEmbedder(nn.Module): 53 | def __init__(self, in_dim: int, hidden_dim: int): 54 | super().__init__() 55 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 56 | self.silu = nn.SiLU() 57 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | return self.out_layer(self.silu(self.in_layer(x))) 61 | 62 | 63 | class RMSNorm(torch.nn.Module): 64 | def __init__(self, dim: int): 65 | super().__init__() 66 | self.scale = nn.Parameter(torch.ones(dim)) 67 | 68 | def forward(self, x: Tensor): 69 | x_dtype = x.dtype 70 | x = x.float() 71 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 72 | return (x * rrms).to(dtype=x_dtype) * self.scale 73 | 74 | 75 | class QKNorm(torch.nn.Module): 76 | def __init__(self, dim: int): 77 | super().__init__() 78 | self.query_norm = RMSNorm(dim) 79 | self.key_norm = RMSNorm(dim) 80 | 81 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 82 | q = self.query_norm(q) 83 | k = self.key_norm(k) 84 | return q.to(v), k.to(v) 85 | 86 | 87 | class SelfAttention(nn.Module): 88 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 89 | super().__init__() 90 | self.num_heads = num_heads 91 | head_dim = dim // num_heads 92 | 93 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 94 | self.norm = QKNorm(head_dim) 95 | self.proj = nn.Linear(dim, dim) 96 | 97 | def forward(self, x: Tensor, pe: Tensor) -> Tensor: 98 | qkv = self.qkv(x) 99 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 100 | q, k = self.norm(q, k, v) 101 | x = attention(q, k, v, pe=pe) 102 | x = self.proj(x) 103 | return x 104 | 105 | 106 | @dataclass 107 | class ModulationOut: 108 | shift: Tensor 109 | scale: Tensor 110 | gate: Tensor 111 | 112 | 113 | class Modulation(nn.Module): 114 | def __init__(self, dim: int, double: bool): 115 | super().__init__() 116 | self.is_double = double 117 | self.multiplier = 6 if double else 3 118 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 119 | 120 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 121 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 122 | 123 | return ( 124 | ModulationOut(*out[:3]), 125 | ModulationOut(*out[3:]) if self.is_double else None, 126 | ) 127 | 128 | 129 | class DoubleStreamBlock(nn.Module): 130 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 131 | super().__init__() 132 | 133 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 134 | self.num_heads = num_heads 135 | self.hidden_size = hidden_size 136 | self.img_mod = Modulation(hidden_size, double=True) 137 | self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 138 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 139 | 140 | self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 141 | self.img_mlp = 
nn.Sequential( 142 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 143 | nn.GELU(approximate="tanh"), 144 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 145 | ) 146 | 147 | self.txt_mod = Modulation(hidden_size, double=True) 148 | self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 149 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 150 | 151 | self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 152 | self.txt_mlp = nn.Sequential( 153 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 154 | nn.GELU(approximate="tanh"), 155 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 156 | ) 157 | 158 | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: 159 | img_mod1, img_mod2 = self.img_mod(vec) 160 | txt_mod1, txt_mod2 = self.txt_mod(vec) 161 | 162 | # prepare image for attention 163 | img_modulated = self.img_norm1(img) 164 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 165 | img_qkv = self.img_attn.qkv(img_modulated) 166 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 167 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 168 | 169 | # prepare txt for attention 170 | txt_modulated = self.txt_norm1(txt) 171 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 172 | txt_qkv = self.txt_attn.qkv(txt_modulated) 173 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 174 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 175 | 176 | # run actual attention 177 | q = torch.cat((txt_q, img_q), dim=2) 178 | k = torch.cat((txt_k, img_k), dim=2) 179 | v = torch.cat((txt_v, img_v), dim=2) 180 | 181 | attn = attention(q, k, v, pe=pe) 182 | txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] 183 | 184 | # calculate the img bloks 185 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 186 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) 187 | 188 | # calculate the txt bloks 189 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 190 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) 191 | return img, txt 192 | 193 | 194 | class SingleStreamBlock(nn.Module): 195 | """ 196 | A DiT block with parallel linear layers as described in 197 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
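Added note (not in the original docstring): "parallel" here means the attention and MLP branches share fused projections: linear1 maps hidden_size to 3 * hidden_size + mlp_hidden_dim (QKV plus the MLP input) in a single matmul, and linear2 maps hidden_size + mlp_hidden_dim back to hidden_size (attention projection plus MLP output), so one modulated pre-norm feeds both branches.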
198 | """ 199 | 200 | def __init__( 201 | self, 202 | hidden_size: int, 203 | num_heads: int, 204 | mlp_ratio: float = 4.0, 205 | qk_scale: float | None = None, 206 | ): 207 | super().__init__() 208 | self.hidden_dim = hidden_size 209 | self.num_heads = num_heads 210 | head_dim = hidden_size // num_heads 211 | self.scale = qk_scale or head_dim**-0.5 212 | 213 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 214 | # qkv and mlp_in 215 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 216 | # proj and mlp_out 217 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 218 | 219 | self.norm = QKNorm(head_dim) 220 | 221 | self.hidden_size = hidden_size 222 | self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 223 | 224 | self.mlp_act = nn.GELU(approximate="tanh") 225 | self.modulation = Modulation(hidden_size, double=False) 226 | 227 | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: 228 | mod, _ = self.modulation(vec) 229 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift 230 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 231 | 232 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 233 | q, k = self.norm(q, k, v) 234 | 235 | # compute attention 236 | attn = attention(q, k, v, pe=pe) 237 | # compute activation in mlp stream, cat again and run second linear layer 238 | output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) 239 | return x + mod.gate * output 240 | 241 | 242 | class LastLayer(nn.Module): 243 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 244 | super().__init__() 245 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 246 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 247 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 248 | 249 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 250 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 251 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 252 | x = self.linear(x) 253 | return x 254 | -------------------------------------------------------------------------------- /flux/modules/quantize.py: -------------------------------------------------------------------------------- 1 | # Code is based on: 2 | # https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py 3 | 4 | # Subject to the followring copyright notice / license: 5 | # Copyright 2023 Meta 6 | # 7 | # Redistribution and use in source and binary forms, with or without 8 | # modification, are permitted provided that the following conditions are met: 9 | # 10 | # 1. Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 12 | # 13 | # 2. Redistributions in binary form must reproduce the above copyright notice, 14 | # this list of conditions and the following disclaimer in the documentation 15 | # and/or other materials provided with the distribution. 16 | # 17 | # 3. Neither the name of the copyright holder nor the names of its contributors 18 | # may be used to endorse or promote products derived from this software 19 | # without specific prior written permission. 
20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | 34 | import torch 35 | from torch import Tensor, nn 36 | from torch.nn import functional as F 37 | 38 | 39 | def replace_linear_weight_only_int8_per_channel(module): 40 | for name, child in module.named_children(): 41 | if isinstance(child, nn.Linear): 42 | setattr( 43 | module, 44 | name, 45 | WeightOnlyInt8Linear( 46 | child.in_features, 47 | child.out_features, 48 | dtype=child.weight.dtype, 49 | bias=child.bias is not None, 50 | ), 51 | ) 52 | else: 53 | replace_linear_weight_only_int8_per_channel(child) 54 | 55 | 56 | class WeightOnlyInt8Linear(torch.nn.Module): 57 | __constants__ = ["in_features", "out_features"] 58 | in_features: int 59 | out_features: int 60 | weight: Tensor 61 | scales: Tensor 62 | bias: Tensor | None 63 | 64 | def __init__( 65 | self, 66 | in_features: int, 67 | out_features: int, 68 | dtype: torch.dtype, 69 | bias: bool = False, 70 | ) -> None: 71 | super().__init__() 72 | self.in_features = in_features 73 | self.out_features = out_features 74 | self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) 75 | self.register_buffer("scales", torch.ones(out_features, dtype=dtype)) 76 | 77 | if bias: 78 | self.register_buffer("bias", torch.zeros((out_features), dtype=dtype)) 79 | else: 80 | self.bias = None 81 | 82 | def forward(self, input: Tensor) -> Tensor: 83 | pre_bias = F.linear(input, self.weight.to(input)) * self.scales 84 | return (pre_bias + self.bias.to(input)) if self.bias is not None else pre_bias 85 | -------------------------------------------------------------------------------- /flux/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional 3 | 4 | from PIL import Image 5 | import torch 6 | from einops import rearrange, repeat 7 | from torch import Tensor 8 | from tqdm.auto import tqdm 9 | 10 | from .model import Flux 11 | from .modules.conditioner import HFEmbedder 12 | 13 | def get_noise( 14 | num_samples: int, 15 | height: int, 16 | width: int, 17 | device: torch.device, 18 | dtype: torch.dtype, 19 | seed: int, 20 | ): 21 | return torch.randn( 22 | num_samples, 23 | 16, 24 | # allow for packing 25 | 2 * math.ceil(height / 16), 26 | 2 * math.ceil(width / 16), 27 | device=device, 28 | dtype=dtype, 29 | generator=torch.Generator(device=device).manual_seed(seed), 30 | ) 31 | 32 | 33 | def prepare( 34 | t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str] 35 | ) -> dict[str, Tensor]: 36 | bs, c, h, w = img.shape 37 | if bs == 1 and not isinstance(prompt, str): 38 | bs = len(prompt) 39 | 40 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 41 | if
img.shape[0] == 1 and bs > 1: 42 | img = repeat(img, "1 ... -> bs ...", bs=bs) 43 | 44 | img_ids = torch.zeros(h // 2, w // 2, 3) 45 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 46 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 47 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 48 | 49 | if isinstance(prompt, str): 50 | prompt = [prompt] 51 | txt = t5(prompt) 52 | if txt.shape[0] == 1 and bs > 1: 53 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 54 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 55 | 56 | vec = clip(prompt) 57 | if vec.shape[0] == 1 and bs > 1: 58 | vec = repeat(vec, "1 ... -> bs ...", bs=bs) 59 | 60 | return { 61 | "img": img, 62 | "img_ids": img_ids.to(img.device), 63 | "txt": txt.to(img.device), 64 | "txt_ids": txt_ids.to(img.device), 65 | "vec": vec.to(img.device), 66 | } 67 | 68 | 69 | def prepare_redux( 70 | t5: HFEmbedder, 71 | clip: HFEmbedder, 72 | img: Tensor, 73 | prompt: str | list[str], 74 | encoder: "ReduxImageEncoder", 75 | img_cond_path: str, 76 | ) -> dict[str, Tensor]: 77 | bs, _, h, w = img.shape 78 | if bs == 1 and not isinstance(prompt, str): 79 | bs = len(prompt) 80 | 81 | img_cond = Image.open(img_cond_path).convert("RGB") 82 | with torch.no_grad(): 83 | img_cond = encoder(img_cond) 84 | 85 | img_cond = img_cond.to(torch.bfloat16) 86 | if img_cond.shape[0] == 1 and bs > 1: 87 | img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) 88 | 89 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 90 | if img.shape[0] == 1 and bs > 1: 91 | img = repeat(img, "1 ... -> bs ...", bs=bs) 92 | 93 | img_ids = torch.zeros(h // 2, w // 2, 3) 94 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 95 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 96 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 97 | 98 | if isinstance(prompt, str): 99 | prompt = [prompt] 100 | txt = t5(prompt) 101 | txt = torch.cat((txt, img_cond.to(txt)), dim=-2) 102 | if txt.shape[0] == 1 and bs > 1: 103 | txt = repeat(txt, "1 ... -> bs ...", bs=bs) 104 | txt_ids = torch.zeros(bs, txt.shape[1], 3) 105 | 106 | vec = clip(prompt) 107 | if vec.shape[0] == 1 and bs > 1: 108 | vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) 109 | 110 | return { 111 | "img": img, 112 | "img_ids": img_ids.to(img.device), 113 | "txt": txt.to(img.device), 114 | "txt_ids": txt_ids.to(img.device), 115 | "vec": vec.to(img.device), 116 | } 117 | 118 | 119 | def time_shift(mu: float, sigma: float, t: Tensor): 120 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 121 | 122 | 123 | def get_lin_function( 124 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 125 | ) -> Callable[[float], float]: 126 | m = (y2 - y1) / (x2 - x1) 127 | b = y1 - m * x1 128 | return lambda x: m * x + b 129 | 130 | 131 | def get_schedule( 132 | num_steps: int, 133 | image_seq_len: int, 134 | base_shift: float = 0.5, 135 | max_shift: float = 1.15, 136 | shift: bool = True, 137 | ) -> list[float]: 138 | # extra step for zero 139 | timesteps = torch.linspace(1, 0, num_steps + 1) 140 | 141 | # shifting the schedule to favor high timesteps for higher signal images 142 | if shift: 143 | # eastimate mu based on linear estimation between two points 144 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 145 | timesteps = time_shift(mu, 1.0, timesteps) 146 | 147 | return timesteps.tolist() 148 | 149 | 150 | def denoise_single_item( 151 | model: Flux, 152 | img: Tensor, 153 | img_ids: Tensor, 154 | txt: Tensor, 155 | txt_ids: Tensor, 156 | vec: Tensor, 157 | timesteps: list[float], 158 | guidance: float = 4.0, 159 | img_cond: Tensor | None = None, 160 | compile_run: bool = False, 161 | image_latents: Optional[Tensor] = None, 162 | mask: Optional[Tensor] = None, 163 | noise: Optional[Tensor] = None 164 | ): 165 | img = img.unsqueeze(0) 166 | img_ids = img_ids.unsqueeze(0) 167 | txt = txt.unsqueeze(0) 168 | txt_ids = txt_ids.unsqueeze(0) 169 | vec = vec.unsqueeze(0) 170 | guidance_vec = torch.full((1,), guidance, device=img.device, dtype=img.dtype) 171 | 172 | if compile_run: 173 | torch._dynamo.mark_dynamic( 174 | img, 1, min=256, max=8100 175 | ) # needs at least torch 2.4 176 | torch._dynamo.mark_dynamic(img_ids, 1, min=256, max=8100) 177 | torch._dynamo.mark_dynamic(img_cond, 1, min=256, max=8100) 178 | model = model.to(memory_format=torch.channels_last) 179 | model = torch.compile(model) 180 | 181 | for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])): 182 | t_vec = torch.full((1,), t_curr, dtype=img.dtype, device=img.device) 183 | 184 | pred = model( 185 | img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img, 186 | img_ids=img_ids, 187 | txt=txt, 188 | txt_ids=txt_ids, 189 | y=vec, 190 | timesteps=t_vec, 191 | guidance=guidance_vec, 192 | ) 193 | 194 | img = img + (t_prev - t_curr) * pred.squeeze(0) 195 | if mask is not None: 196 | if t_prev != timesteps[-1]: 197 | proper_noise_latents = t_prev * noise + (1.0 - t_prev) * image_latents 198 | else: 199 | proper_noise_latents = image_latents 200 | 201 | img = (1 - mask) * proper_noise_latents + mask * img 202 | 203 | return img, model 204 | 205 | 206 | def denoise( 207 | model: Flux, 208 | # model input 209 | img: Tensor, 210 | img_ids: Tensor, 211 | txt: Tensor, 212 | txt_ids: Tensor, 213 | vec: Tensor, 214 | # sampling parameters 215 | timesteps: list[float], 216 | guidance: float = 4.0, 217 | img_cond: Tensor | None = None, 218 | compile_run: bool = False, 219 | image_latents: Optional[Tensor] = None, 220 | mask: Optional[Tensor] = None, 221 | noise: Optional[Tensor] = None 222 | ): 223 | batch_size = img.shape[0] 224 | output_imgs = [] 225 | 226 | for i in range(batch_size): 227 | denoised_img, model = 
denoise_single_item( 228 | model=model, 229 | img=img[i], 230 | img_ids=img_ids[i], 231 | txt=txt[i], 232 | txt_ids=txt_ids[i], 233 | vec=vec[i], 234 | timesteps=timesteps, 235 | guidance=guidance, 236 | img_cond=None if img_cond is None else img_cond, 237 | compile_run=compile_run, 238 | image_latents=image_latents, 239 | mask=mask, 240 | noise=None if noise is None else noise[i] 241 | ) 242 | compile_run = False 243 | output_imgs.append(denoised_img) 244 | 245 | return torch.cat(output_imgs), model 246 | 247 | 248 | def unpack(x: Tensor, height: int, width: int) -> Tensor: 249 | return rearrange( 250 | x, 251 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 252 | h=math.ceil(height / 16), 253 | w=math.ceil(width / 16), 254 | ph=2, 255 | pw=2, 256 | ) 257 | -------------------------------------------------------------------------------- /flux/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | import subprocess 4 | import time 5 | 6 | import torch 7 | from safetensors.torch import load_file as load_sft 8 | 9 | from flux.model import Flux, FluxParams 10 | from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams 11 | from flux.modules.conditioner import HFEmbedder 12 | from flux.modules.image_embedders import DepthImageEncoder, ReduxImageEncoder 13 | from flux.modules.quantize import replace_linear_weight_only_int8_per_channel 14 | from huggingface_hub import hf_hub_download 15 | from pathlib import Path 16 | 17 | 18 | @dataclass 19 | class ModelSpec: 20 | params: FluxParams 21 | ae_params: AutoEncoderParams 22 | ckpt_path: str | None 23 | ckpt_url: str | None 24 | ae_path: str | None 25 | ae_url: str | None 26 | 27 | T5_URL = "https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar" 28 | T5_CACHE = "./model-cache/t5" 29 | CLIP_URL = "https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar" 30 | CLIP_CACHE = "./model-cache/clip" 31 | SCHNELL_CACHE = "./model-cache/schnell/schnell.sft" 32 | SCHNELL_URL = "https://weights.replicate.delivery/default/official-models/flux/schnell/schnell.sft" 33 | DEV_CACHE = "./model-cache/dev/dev.sft" 34 | DEV_URL = "https://weights.replicate.delivery/default/official-models/flux/dev/dev.sft" 35 | DEV_CANNY_CACHE = "./model-cache/dev-canny/dev-canny.safetensors" 36 | DEV_CANNY_URL = "https://weights.replicate.delivery/default/black-forest-labs/ctrl-n-fill/flux1-canny-dev.safetensors" 37 | DEV_DEPTH_CACHE = "./model-cache/dev-depth/dev-depth.safetensors" 38 | DEV_DEPTH_URL = "https://weights.replicate.delivery/default/black-forest-labs/ctrl-n-fill/flux1-depth-dev.safetensors" 39 | DEV_INPAINTING_CACHE = "./model-cache/dev-inpainting/dev-inpainting.safetensors" 40 | DEV_INPAINTING_URL = "https://weights.replicate.delivery/default/black-forest-labs/ctrl-n-fill/flux1-fill-dev.safetensors" 41 | AE_CACHE = "./model-cache/ae/ae.sft" 42 | AE_URL = "https://weights.replicate.delivery/default/official-models/flux/ae/ae.sft" 43 | SIGLIP_URL = "https://weights.replicate.delivery/default/google/siglip-so400m-patch14-384/model-bf16.tar" 44 | SIGLIP_CACHE = "./model-cache/siglip" 45 | REDUX_URL = "https://weights.replicate.delivery/default/black-forest-labs/ctrl-n-fill/flux1-redux-dev.safetensors" 46 | REDUX_CACHE = "./model-cache/redux/flux1-redux-dev.safetensors" 47 | DEPTH_URL = "https://weights.replicate.delivery/default/liheyoung/depth-anything-large/flux-depth-model.tar" 48 | DEPTH_CACHE = 
"./model-cache/depth" 49 | 50 | configs = { 51 | "flux-dev": ModelSpec( 52 | ckpt_path=DEV_CACHE, 53 | ckpt_url=DEV_URL, 54 | params=FluxParams( 55 | in_channels=64, 56 | out_channels=64, 57 | vec_in_dim=768, 58 | context_in_dim=4096, 59 | hidden_size=3072, 60 | mlp_ratio=4.0, 61 | num_heads=24, 62 | depth=19, 63 | depth_single_blocks=38, 64 | axes_dim=[16, 56, 56], 65 | theta=10_000, 66 | qkv_bias=True, 67 | guidance_embed=True, 68 | ), 69 | ae_path=AE_CACHE, 70 | ae_url=AE_URL, 71 | ae_params=AutoEncoderParams( 72 | resolution=256, 73 | in_channels=3, 74 | ch=128, 75 | out_ch=3, 76 | ch_mult=[1, 2, 4, 4], 77 | num_res_blocks=2, 78 | z_channels=16, 79 | scale_factor=0.3611, 80 | shift_factor=0.1159, 81 | ), 82 | ), 83 | "flux-schnell": ModelSpec( 84 | ckpt_path=SCHNELL_CACHE, 85 | ckpt_url=SCHNELL_URL, 86 | params=FluxParams( 87 | in_channels=64, 88 | out_channels=64, 89 | vec_in_dim=768, 90 | context_in_dim=4096, 91 | hidden_size=3072, 92 | mlp_ratio=4.0, 93 | num_heads=24, 94 | depth=19, 95 | depth_single_blocks=38, 96 | axes_dim=[16, 56, 56], 97 | theta=10_000, 98 | qkv_bias=True, 99 | guidance_embed=False, 100 | ), 101 | ae_path=AE_CACHE, 102 | ae_url=AE_URL, 103 | ae_params=AutoEncoderParams( 104 | resolution=256, 105 | in_channels=3, 106 | ch=128, 107 | out_ch=3, 108 | ch_mult=[1, 2, 4, 4], 109 | num_res_blocks=2, 110 | z_channels=16, 111 | scale_factor=0.3611, 112 | shift_factor=0.1159, 113 | ), 114 | ), 115 | "flux-canny-dev": ModelSpec( 116 | ckpt_path=DEV_CANNY_CACHE, 117 | ckpt_url=DEV_CANNY_URL, 118 | params=FluxParams( 119 | in_channels=128, 120 | out_channels=64, 121 | vec_in_dim=768, 122 | context_in_dim=4096, 123 | hidden_size=3072, 124 | mlp_ratio=4.0, 125 | num_heads=24, 126 | depth=19, 127 | depth_single_blocks=38, 128 | axes_dim=[16, 56, 56], 129 | theta=10_000, 130 | qkv_bias=True, 131 | guidance_embed=True, 132 | ), 133 | ae_path=AE_CACHE, 134 | ae_url=AE_URL, 135 | ae_params=AutoEncoderParams( 136 | resolution=256, 137 | in_channels=3, 138 | ch=128, 139 | out_ch=3, 140 | ch_mult=[1, 2, 4, 4], 141 | num_res_blocks=2, 142 | z_channels=16, 143 | scale_factor=0.3611, 144 | shift_factor=0.1159, 145 | ), 146 | ), 147 | "flux-depth-dev": ModelSpec( 148 | ckpt_path=DEV_DEPTH_CACHE, 149 | ckpt_url=DEV_DEPTH_URL, 150 | params=FluxParams( 151 | in_channels=128, 152 | out_channels=64, 153 | vec_in_dim=768, 154 | context_in_dim=4096, 155 | hidden_size=3072, 156 | mlp_ratio=4.0, 157 | num_heads=24, 158 | depth=19, 159 | depth_single_blocks=38, 160 | axes_dim=[16, 56, 56], 161 | theta=10_000, 162 | qkv_bias=True, 163 | guidance_embed=True, 164 | ), 165 | ae_path=AE_CACHE, 166 | ae_url=AE_URL, 167 | ae_params=AutoEncoderParams( 168 | resolution=256, 169 | in_channels=3, 170 | ch=128, 171 | out_ch=3, 172 | ch_mult=[1, 2, 4, 4], 173 | num_res_blocks=2, 174 | z_channels=16, 175 | scale_factor=0.3611, 176 | shift_factor=0.1159, 177 | ), 178 | ), 179 | "flux-fill-dev": ModelSpec( 180 | ckpt_path=DEV_INPAINTING_CACHE, 181 | ckpt_url=DEV_INPAINTING_URL, 182 | params=FluxParams( 183 | in_channels=384, 184 | out_channels=64, 185 | vec_in_dim=768, 186 | context_in_dim=4096, 187 | hidden_size=3072, 188 | mlp_ratio=4.0, 189 | num_heads=24, 190 | depth=19, 191 | depth_single_blocks=38, 192 | axes_dim=[16, 56, 56], 193 | theta=10_000, 194 | qkv_bias=True, 195 | guidance_embed=True, 196 | ), 197 | ae_path=AE_CACHE, 198 | ae_url=AE_URL, 199 | ae_params=AutoEncoderParams( 200 | resolution=256, 201 | in_channels=3, 202 | ch=128, 203 | out_ch=3, 204 | ch_mult=[1, 2, 4, 4], 205 | num_res_blocks=2, 206 
| z_channels=16, 207 | scale_factor=0.3611, 208 | shift_factor=0.1159, 209 | ), 210 | ), 211 | } 212 | 213 | 214 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 215 | if len(missing) > 0 and len(unexpected) > 0: 216 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 217 | print("\n" + "-" * 79 + "\n") 218 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 219 | elif len(missing) > 0: 220 | print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 221 | elif len(unexpected) > 0: 222 | print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) 223 | 224 | 225 | def load_flow_model(name: str, device: str | torch.device = "cuda", quantize: bool = False): 226 | # Loading Flux 227 | print("Init model") 228 | ckpt_path = configs[name].ckpt_path 229 | ckpt_url = configs[name].ckpt_url 230 | 231 | if not os.path.exists(ckpt_path): 232 | download_weights(ckpt_url, ckpt_path) 233 | 234 | with torch.device("meta" if ckpt_path is not None else device): 235 | model = Flux(configs[name].params).to(torch.bfloat16) 236 | 237 | if quantize: 238 | replace_linear_weight_only_int8_per_channel(model) 239 | 240 | if quantize and ckpt_path is not None: 241 | ckpt_path = Path(ckpt_path).stem + "_quantized.sft" 242 | print(f"Quantized checkpoint path: {ckpt_path}") 243 | 244 | print("Loading checkpoint") 245 | # load_sft doesn't support torch.device 246 | if ckpt_path is not None: 247 | sd = load_sft(ckpt_path, device=str(device)) 248 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 249 | print_load_warning(missing, unexpected) 250 | return model 251 | 252 | 253 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 254 | # max length 64, 128, 256 and 512 should work (if your sequence is short enough) 255 | if not os.path.exists(T5_CACHE): 256 | download_weights(T5_URL, T5_CACHE) 257 | device = torch.device(device) 258 | return HFEmbedder(T5_CACHE, max_length=max_length, torch_dtype=torch.bfloat16).to(device) 259 | 260 | 261 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 262 | if not os.path.exists(CLIP_CACHE): 263 | download_weights(CLIP_URL, CLIP_CACHE) 264 | device = torch.device(device) 265 | return HFEmbedder(CLIP_CACHE, max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device) 266 | 267 | 268 | def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder: 269 | # Loading the autoencoder 270 | print("Init AE") 271 | with torch.device("meta" if configs[name].ae_path is not None else device): 272 | ae = AutoEncoder(configs[name].ae_params) 273 | 274 | ae_path = configs[name].ae_path 275 | ae_url = configs[name].ae_url 276 | if not os.path.exists(ae_path): 277 | download_weights(ae_url, ae_path) 278 | 279 | if configs[name].ae_path is not None: 280 | sd = load_sft(configs[name].ae_path, device=str(device)) 281 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 282 | print_load_warning(missing, unexpected) 283 | return ae 284 | 285 | def load_redux(device: str | torch.device = "cuda") -> ReduxImageEncoder: 286 | if not os.path.exists(SIGLIP_CACHE): 287 | download_weights(SIGLIP_URL, SIGLIP_CACHE) 288 | if not os.path.exists(REDUX_CACHE): 289 | download_weights(REDUX_URL, REDUX_CACHE) 290 | 291 | return ReduxImageEncoder(device, redux_path = REDUX_CACHE, siglip_path=SIGLIP_CACHE, dtype=torch.bfloat16) 292 | 293 | def load_depth_encoder(device: str | torch.device = "cuda") -> 
DepthImageEncoder: 294 | if not os.path.exists(DEPTH_CACHE): 295 | download_weights(DEPTH_URL, DEPTH_CACHE) 296 | 297 | return DepthImageEncoder(device, DEPTH_CACHE) 298 | 299 | 300 | def download_ckpt_from_hf( 301 | repo_id: str, 302 | ckpt_name: str = "flux.safetensors", 303 | ae_name: str | None = None, 304 | **kwargs, 305 | ) -> tuple[Path, Path | None]: 306 | ckpt_path = hf_hub_download(repo_id, ckpt_name, **kwargs) 307 | ae_path = hf_hub_download(repo_id, ae_name, **kwargs) if ae_name else None 308 | return Path(ckpt_path).resolve(), Path(ae_path).resolve() if ae_path else None 309 | 310 | 311 | def download_weights(url: str, dest: Path): 312 | start = time.time() 313 | print("downloading url: ", url) 314 | print("downloading to: ", dest) 315 | if url.endswith("tar"): 316 | subprocess.check_call(["pget", "--log-level=WARNING", "-x", url, dest], close_fds=False) 317 | else: 318 | subprocess.check_call(["pget", "--log-level=WARNING", url, dest], close_fds=False) 319 | print("downloading took: ", time.time() - start) 320 | -------------------------------------------------------------------------------- /fp8/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-flux/a7efe0293062da0df0057b1d11b9ce3dbdd299c8/fp8/__init__.py -------------------------------------------------------------------------------- /fp8/configs/config-1-flux-dev-fp8-h100.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "./model-cache/dev-fp8/dev-fp8.safetensors", 38 | "flux_url": "https://weights.replicate.delivery/default/official-models/flux/dev/dev-fp8.safetensors", 39 | "ae_path": "./model-cache/ae/ae.sft", 40 | "repo_id": "n/a", 41 | "repo_flow": "n/a", 42 | "repo_ae": "n/a", 43 | "text_enc_max_length": 512, 44 | "text_enc_path": "./model-cache/t5", 45 | "text_enc_device": "cuda:0", 46 | "ae_device": "cuda:0", 47 | "flux_device": "cuda:0", 48 | "prequantized_flow": true, 49 | "flow_dtype": "bfloat16", 50 | "ae_dtype": "bfloat16", 51 | "text_enc_dtype": "bfloat16", 52 | "flow_quantization_dtype": "qfloat8", 53 | "text_enc_quantization_dtype": "qfloat8", 54 | "compile_whole_model": true, 55 | "compile_extras": true, 56 | "compile_blocks": true, 57 | "offload_text_encoder": false, 58 | "offload_vae": false, 59 | "offload_flow": false 60 | } -------------------------------------------------------------------------------- /fp8/configs/config-1-flux-dev-h100.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-dev", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | 
"theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": true 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "./model-cache/dev/dev.sft", 38 | "ae_path": "./model-cache/ae/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "flux_url": "n/a", 43 | "text_enc_max_length": 512, 44 | "text_enc_path": "./model-cache/t5", 45 | "text_enc_device": "cuda:0", 46 | "ae_device": "cuda:0", 47 | "flux_device": "cuda:0", 48 | "flow_dtype": "bfloat16", 49 | "ae_dtype": "bfloat16", 50 | "text_enc_dtype": "bfloat16", 51 | "flow_quantization_dtype": "qfloat8", 52 | "text_enc_quantization_dtype": "qfloat8", 53 | "compile_whole_model": true, 54 | "compile_extras": true, 55 | "compile_blocks": true, 56 | "offload_text_encoder": false, 57 | "offload_vae": false, 58 | "offload_flow": false 59 | } -------------------------------------------------------------------------------- /fp8/configs/config-1-flux-schnell-fp8-h100.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-schnell", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": false 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "./model-cache/schnell-fp8/schnell-fp8.safetensors", 38 | "flux_url": "https://weights.replicate.delivery/default/official-models/flux/schnell/schnell-fp8.safetensors", 39 | "ae_path": "./model-cache/ae/ae.sft", 40 | "repo_id": "n/a", 41 | "repo_flow": "n/a", 42 | "repo_ae": "n/a", 43 | "text_enc_max_length": 256, 44 | "text_enc_path": "./model-cache/t5", 45 | "text_enc_device": "cuda:0", 46 | "ae_device": "cuda:0", 47 | "flux_device": "cuda:0", 48 | "prequantized_flow": true, 49 | "flow_dtype": "bfloat16", 50 | "ae_dtype": "bfloat16", 51 | "text_enc_dtype": "bfloat16", 52 | "flow_quantization_dtype": "qfloat8", 53 | "text_enc_quantization_dtype": "qfloat8", 54 | "compile_whole_model": true, 55 | "compile_extras": true, 56 | "compile_blocks": true, 57 | "offload_text_encoder": false, 58 | "offload_vae": false, 59 | "offload_flow": false 60 | } -------------------------------------------------------------------------------- /fp8/configs/config-1-flux-schnell-h100.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "flux-schnell", 3 | "params": { 4 | "in_channels": 64, 5 | "vec_in_dim": 768, 6 | "context_in_dim": 4096, 7 | "hidden_size": 3072, 8 | "mlp_ratio": 4.0, 9 | "num_heads": 24, 10 | "depth": 19, 11 | "depth_single_blocks": 38, 12 | "axes_dim": [ 13 | 16, 14 | 56, 15 | 56 16 | ], 17 | "theta": 10000, 18 | "qkv_bias": true, 19 | "guidance_embed": false 20 | }, 21 | "ae_params": { 22 | "resolution": 256, 23 | "in_channels": 3, 24 | "ch": 128, 25 | "out_ch": 3, 26 | "ch_mult": [ 
27 | 1, 28 | 2, 29 | 4, 30 | 4 31 | ], 32 | "num_res_blocks": 2, 33 | "z_channels": 16, 34 | "scale_factor": 0.3611, 35 | "shift_factor": 0.1159 36 | }, 37 | "ckpt_path": "./model-cache/schnell/schnell.sft", 38 | "ae_path": "./model-cache/ae/ae.sft", 39 | "repo_id": "black-forest-labs/FLUX.1-dev", 40 | "repo_flow": "flux1-dev.sft", 41 | "repo_ae": "ae.sft", 42 | "flux_url": "n/a", 43 | "text_enc_max_length": 256, 44 | "text_enc_path": "./model-cache/t5", 45 | "text_enc_device": "cuda:0", 46 | "ae_device": "cuda:0", 47 | "flux_device": "cuda:0", 48 | "flow_dtype": "bfloat16", 49 | "ae_dtype": "bfloat16", 50 | "text_enc_dtype": "bfloat16", 51 | "flow_quantization_dtype": "qfloat8", 52 | "text_enc_quantization_dtype": "qfloat8", 53 | "compile_whole_model": true, 54 | "compile_extras": true, 55 | "compile_blocks": true, 56 | "offload_text_encoder": false, 57 | "offload_vae": false, 58 | "offload_flow": false 59 | } -------------------------------------------------------------------------------- /fp8/float8_quantize.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import math 6 | from torch.compiler import is_compiling 7 | from torch import __version__ 8 | from torch.version import cuda 9 | 10 | from fp8.modules.flux_model import Modulation 11 | 12 | IS_TORCH_2_4 = __version__ < (2, 4, 9) 13 | LT_TORCH_2_4 = __version__ < (2, 4) 14 | if LT_TORCH_2_4: 15 | if not hasattr(torch, "_scaled_mm"): 16 | raise RuntimeError( 17 | "This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later." 18 | ) 19 | CUDA_VERSION = float(cuda) if cuda else 0 20 | if CUDA_VERSION < 12.4: 21 | raise RuntimeError( 22 | f"This version of PyTorch is not supported. Please upgrade to PyTorch 2.4 with CUDA 12.4 or later got torch version {__version__} and CUDA version {cuda}." 
23 | ) 24 | try: 25 | from cublas_ops import CublasLinear 26 | except ImportError: 27 | CublasLinear = type(None) 28 | 29 | 30 | class F8Linear(nn.Module): 31 | def __init__( 32 | self, 33 | in_features: int, 34 | out_features: int, 35 | bias: bool = True, 36 | device=None, 37 | dtype=torch.float16, 38 | float8_dtype=torch.float8_e4m3fn, 39 | float_weight: torch.Tensor = None, 40 | float_bias: torch.Tensor = None, 41 | num_scale_trials: int = 12, 42 | input_float8_dtype=torch.float8_e5m2, 43 | ) -> None: 44 | super().__init__() 45 | self.in_features = in_features 46 | self.out_features = out_features 47 | self.float8_dtype = float8_dtype 48 | self.input_float8_dtype = input_float8_dtype 49 | self.input_scale_initialized = False 50 | self.weight_initialized = False 51 | self.max_value = torch.finfo(self.float8_dtype).max 52 | self.input_max_value = torch.finfo(self.input_float8_dtype).max 53 | factory_kwargs = {"dtype": dtype, "device": device} 54 | if float_weight is None: 55 | self.weight = nn.Parameter( 56 | torch.empty((out_features, in_features), **factory_kwargs) 57 | ) 58 | else: 59 | self.weight = nn.Parameter( 60 | float_weight, requires_grad=float_weight.requires_grad 61 | ) 62 | if float_bias is None: 63 | if bias: 64 | self.bias = nn.Parameter( 65 | torch.empty(out_features, **factory_kwargs), 66 | ) 67 | else: 68 | self.register_parameter("bias", None) 69 | else: 70 | self.bias = nn.Parameter(float_bias, requires_grad=float_bias.requires_grad) 71 | self.num_scale_trials = num_scale_trials 72 | self.input_amax_trials = torch.zeros( 73 | num_scale_trials, requires_grad=False, device=device, dtype=torch.float32 74 | ) 75 | self.trial_index = 0 76 | self.register_buffer("scale", None) 77 | self.register_buffer( 78 | "input_scale", 79 | None, 80 | ) 81 | self.register_buffer( 82 | "float8_data", 83 | None, 84 | ) 85 | self.scale_reciprocal = self.register_buffer("scale_reciprocal", None) 86 | self.input_scale_reciprocal = self.register_buffer( 87 | "input_scale_reciprocal", None 88 | ) 89 | 90 | def _load_from_state_dict( 91 | self, 92 | state_dict, 93 | prefix, 94 | local_metadata, # noqa 95 | strict, # noqa 96 | missing_keys, # noqa 97 | unexpected_keys, # noqa 98 | error_msgs, # noqa 99 | ): 100 | sd = {k.replace(prefix, ""): v for k, v in state_dict.items()} 101 | if "weight" in sd: 102 | if ( 103 | "float8_data" not in sd 104 | or sd["float8_data"] is None 105 | and sd["weight"].shape == (self.out_features, self.in_features) 106 | ): 107 | # Initialize as if it's an F8Linear that needs to be quantized 108 | self._parameters["weight"] = nn.Parameter( 109 | sd["weight"], requires_grad=False 110 | ) 111 | if "bias" in sd: 112 | self._parameters["bias"] = nn.Parameter( 113 | sd["bias"], requires_grad=False 114 | ) 115 | self.quantize_weight() 116 | elif sd["float8_data"].shape == ( 117 | self.out_features, 118 | self.in_features, 119 | ) and sd["weight"] == torch.zeros_like(sd["weight"]): 120 | w = sd["weight"] 121 | # Set the init values as if it's already quantized float8_data 122 | self._buffers["float8_data"] = sd["float8_data"] 123 | self._parameters["weight"] = nn.Parameter( 124 | torch.zeros( 125 | 1, 126 | dtype=w.dtype, 127 | device=w.device, 128 | requires_grad=False, 129 | ) 130 | ) 131 | if "bias" in sd: 132 | self._parameters["bias"] = nn.Parameter( 133 | sd["bias"], requires_grad=False 134 | ) 135 | self.weight_initialized = True 136 | 137 | # Check if scales and reciprocals are initialized 138 | if all( 139 | key in sd 140 | for key in [ 141 | "scale", 142 | 
"input_scale", 143 | "scale_reciprocal", 144 | "input_scale_reciprocal", 145 | ] 146 | ): 147 | self.scale = sd["scale"].float() 148 | self.input_scale = sd["input_scale"].float() 149 | self.scale_reciprocal = sd["scale_reciprocal"].float() 150 | self.input_scale_reciprocal = sd["input_scale_reciprocal"].float() 151 | self.input_scale_initialized = True 152 | self.trial_index = self.num_scale_trials 153 | elif "scale" in sd and "scale_reciprocal" in sd: 154 | self.scale = sd["scale"].float() 155 | self.input_scale = ( 156 | sd["input_scale"].float() if "input_scale" in sd else None 157 | ) 158 | self.scale_reciprocal = sd["scale_reciprocal"].float() 159 | self.input_scale_reciprocal = ( 160 | sd["input_scale_reciprocal"].float() 161 | if "input_scale_reciprocal" in sd 162 | else None 163 | ) 164 | self.input_scale_initialized = "input_scale" in sd 165 | self.trial_index = ( 166 | self.num_scale_trials if "input_scale" in sd else 0 167 | ) 168 | self.input_amax_trials = torch.zeros( 169 | self.num_scale_trials, 170 | requires_grad=False, 171 | dtype=torch.float32, 172 | device=self.weight.device, 173 | ) 174 | self.input_scale_initialized = False 175 | self.trial_index = 0 176 | else: 177 | # If scales are not initialized, reset trials 178 | self.input_scale_initialized = False 179 | self.trial_index = 0 180 | self.input_amax_trials = torch.zeros( 181 | self.num_scale_trials, requires_grad=False, dtype=torch.float32 182 | ) 183 | else: 184 | raise RuntimeError( 185 | f"Weight tensor not found or has incorrect shape in state dict: {sd.keys()}" 186 | ) 187 | else: 188 | raise RuntimeError( 189 | "Weight tensor not found or has incorrect shape in state dict" 190 | ) 191 | 192 | def quantize_weight(self): 193 | if self.weight_initialized: 194 | return 195 | amax = torch.max(torch.abs(self.weight.data)).float() 196 | self.scale = self.amax_to_scale(amax, self.max_value) 197 | self.float8_data = self.to_fp8_saturated( 198 | self.weight.data, self.scale, self.max_value 199 | ).to(self.float8_dtype) 200 | self.scale_reciprocal = self.scale.reciprocal() 201 | self.weight.data = torch.zeros( 202 | 1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False 203 | ) 204 | self.weight_initialized = True 205 | 206 | def set_weight_tensor(self, tensor: torch.Tensor): 207 | self.weight.data = tensor 208 | self.weight_initialized = False 209 | self.quantize_weight() 210 | 211 | def amax_to_scale(self, amax, max_val): 212 | return (max_val / torch.clamp(amax, min=1e-12)).clamp(max=max_val) 213 | 214 | def to_fp8_saturated(self, x, scale, max_val): 215 | return (x * scale).clamp(-max_val, max_val) 216 | 217 | def quantize_input(self, x: torch.Tensor): 218 | if self.input_scale_initialized: 219 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 220 | self.input_float8_dtype 221 | ) 222 | if self.trial_index < self.num_scale_trials: 223 | amax = torch.max(torch.abs(x)).float() 224 | 225 | self.input_amax_trials[self.trial_index] = amax 226 | self.trial_index += 1 227 | self.input_scale = self.amax_to_scale( 228 | self.input_amax_trials[: self.trial_index].max(), self.input_max_value 229 | ) 230 | self.input_scale_reciprocal = self.input_scale.reciprocal() 231 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 232 | self.input_float8_dtype 233 | ) 234 | self.input_scale = self.amax_to_scale( 235 | self.input_amax_trials.max(), self.input_max_value 236 | ) 237 | self.input_scale_reciprocal = self.input_scale.reciprocal() 238 | 
self.input_scale_initialized = True 239 | return self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 240 | self.input_float8_dtype 241 | ) 242 | 243 | def reset_parameters(self) -> None: 244 | if self.weight_initialized: 245 | self.weight = nn.Parameter( 246 | torch.empty( 247 | (self.out_features, self.in_features), 248 | **{ 249 | "dtype": self.weight.dtype, 250 | "device": self.weight.device, 251 | }, 252 | ) 253 | ) 254 | self.weight_initialized = False 255 | self.input_scale_initialized = False 256 | self.trial_index = 0 257 | self.input_amax_trials.zero_() 258 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 259 | if self.bias is not None: 260 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) # noqa 261 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 262 | init.uniform_(self.bias, -bound, bound) 263 | self.quantize_weight() 264 | self.max_value = torch.finfo(self.float8_dtype).max 265 | self.input_max_value = torch.finfo(self.input_float8_dtype).max 266 | 267 | def forward(self, x: torch.Tensor) -> torch.Tensor: 268 | if self.input_scale_initialized or is_compiling(): 269 | x = self.to_fp8_saturated(x, self.input_scale, self.input_max_value).to( 270 | self.input_float8_dtype 271 | ) 272 | else: 273 | x = self.quantize_input(x) 274 | 275 | prev_dims = x.shape[:-1] 276 | x = x.view(-1, self.in_features) 277 | 278 | device = x.device 279 | if x.device.type != 'cpu' and torch.cuda.get_device_capability(x.device) >= (8, 9): 280 | # float8 matmul, much faster than float16 matmul w/ float32 accumulate on ADA devices! 281 | out = torch._scaled_mm( # noqa 282 | x, 283 | self.float8_data.T, 284 | scale_a=self.input_scale_reciprocal, 285 | scale_b=self.scale_reciprocal, 286 | bias=self.bias, 287 | out_dtype=self.weight.dtype, 288 | use_fast_accum=True, 289 | ) 290 | else: 291 | # Plain matrix multiplication for non-ADA devices 292 | # Assuming x is in float8 and self.float8_data is in float8 as well 293 | # Convert to float32, perform the multiplication, and then apply scaling and bias if necessary 294 | 295 | # Convert float8 to float32 for the multiplication 296 | x_float32 = x.to(torch.float32) 297 | float8_data_float32 = self.float8_data.T.to(torch.float32) 298 | 299 | # Regular matrix multiplication 300 | out = torch.matmul(x_float32, float8_data_float32) 301 | 302 | # Scale the output accordingly 303 | out = out * (self.input_scale_reciprocal * self.scale_reciprocal) 304 | 305 | # Add bias if it exists 306 | if self.bias is not None: 307 | out += self.bias 308 | out = out.to(self.weight.dtype) 309 | 310 | if IS_TORCH_2_4: 311 | out = out[0] 312 | return out.view(*prev_dims, self.out_features) 313 | 314 | @classmethod 315 | def from_linear( 316 | cls, 317 | linear: nn.Linear, 318 | float8_dtype=torch.float8_e4m3fn, 319 | input_float8_dtype=torch.float8_e5m2, 320 | ) -> "F8Linear": 321 | f8_lin = cls( 322 | in_features=linear.in_features, 323 | out_features=linear.out_features, 324 | bias=linear.bias is not None, 325 | device=linear.weight.device, 326 | dtype=linear.weight.dtype, 327 | float8_dtype=float8_dtype, 328 | float_weight=linear.weight.data, 329 | float_bias=(linear.bias.data if linear.bias is not None else None), 330 | input_float8_dtype=input_float8_dtype, 331 | ) 332 | f8_lin.quantize_weight() 333 | return f8_lin 334 | 335 | 336 | @torch.inference_mode() 337 | def recursive_swap_linears( 338 | model: nn.Module, 339 | float8_dtype=torch.float8_e4m3fn, 340 | input_float8_dtype=torch.float8_e5m2, 341 | quantize_modulation: bool = True, 342 | 
ignore_keys: list[str] = [], 343 | ) -> None: 344 | """ 345 | Recursively swaps all nn.Linear modules in the given model with F8Linear modules. 346 | 347 | This function traverses the model's structure and replaces each nn.Linear 348 | instance with an F8Linear instance, which uses 8-bit floating point 349 | quantization for weights. The original linear layer's weights are deleted 350 | after conversion to save memory. 351 | 352 | Args: 353 | model (nn.Module): The PyTorch model to modify. 354 | 355 | Note: 356 | This function modifies the model in-place. After calling this function, 357 | all linear layers in the model will be using 8-bit quantization. 358 | """ 359 | for name, child in model.named_children(): 360 | if name in ignore_keys: 361 | continue 362 | if isinstance(child, Modulation) and not quantize_modulation: 363 | continue 364 | if isinstance(child, nn.Linear) and not isinstance( 365 | child, (F8Linear, CublasLinear) 366 | ): 367 | setattr( 368 | model, 369 | name, 370 | F8Linear.from_linear( 371 | child, 372 | float8_dtype=float8_dtype, 373 | input_float8_dtype=input_float8_dtype, 374 | ), 375 | ) 376 | del child 377 | else: 378 | recursive_swap_linears( 379 | child, 380 | float8_dtype=float8_dtype, 381 | input_float8_dtype=input_float8_dtype, 382 | quantize_modulation=quantize_modulation, 383 | ignore_keys=ignore_keys, 384 | ) 385 | 386 | 387 | @torch.inference_mode() 388 | def swap_to_cublaslinear(model: nn.Module): 389 | if not isinstance(CublasLinear, type(torch.nn.Module)): 390 | return 391 | for name, child in model.named_children(): 392 | if isinstance(child, nn.Linear) and not isinstance( 393 | child, (F8Linear, CublasLinear) 394 | ): 395 | cublas_lin = CublasLinear( 396 | child.in_features, 397 | child.out_features, 398 | bias=child.bias is not None, 399 | dtype=child.weight.dtype, 400 | device=child.weight.device, 401 | ) 402 | cublas_lin.weight.data = child.weight.clone().detach() 403 | cublas_lin.bias.data = child.bias.clone().detach() 404 | setattr(model, name, cublas_lin) 405 | del child 406 | else: 407 | swap_to_cublaslinear(child) 408 | 409 | 410 | @torch.inference_mode() 411 | def quantize_flow_transformer_and_dispatch_float8( 412 | flow_model: nn.Module, 413 | device=torch.device("cuda"), 414 | float8_dtype=torch.float8_e4m3fn, 415 | input_float8_dtype=torch.float8_e5m2, 416 | offload_flow=False, 417 | swap_linears_with_cublaslinear=False, 418 | flow_dtype=torch.float16, 419 | quantize_modulation: bool = True, 420 | quantize_flow_embedder_layers: bool = True, 421 | ) -> nn.Module: 422 | """ 423 | Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device. 424 | 425 | Iteratively pushes each module to device, evals, replaces linear layers with F8Linear except for final_layer, and quantizes. 426 | 427 | Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory. 428 | 429 | After dispatching, if offload_flow is True, offloads the model to cpu. 430 | 431 | if swap_linears_with_cublaslinear is true, and flow_dtype == torch.float16, then swap all linears with cublaslinears for 2x performance boost on consumer GPUs. 432 | Otherwise will skip the cublaslinear swap. 433 | 434 | For added extra precision, you can set quantize_flow_embedder_layers to False, 435 | this helps maintain the output quality of the flow transformer moreso than fully quantizing, 436 | at the expense of ~512MB more VRAM usage. 
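Illustrative call (added; assumes `flow` is an already-loaded BFL-format Flux transformer): flow = quantize_flow_transformer_and_dispatch_float8(flow, device=torch.device("cuda"), offload_flow=False, quantize_modulation=True, quantize_flow_embedder_layers=False) keeps the embedder layers unquantized for slightly better output quality at the cost of roughly 512MB of extra VRAM.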
437 | 438 | For added extra precision, you can set quantize_modulation to False, 439 | this helps maintain the output quality of the flow transformer moreso than fully quantizing, 440 | at the expense of ~2GB more VRAM usage, but- has a much higher impact on image quality than the embedder layers. 441 | """ 442 | for module in flow_model.double_blocks: 443 | module.to(device) 444 | module.eval() 445 | recursive_swap_linears( 446 | module, 447 | float8_dtype=float8_dtype, 448 | input_float8_dtype=input_float8_dtype, 449 | quantize_modulation=quantize_modulation, 450 | ) 451 | torch.cuda.empty_cache() 452 | for module in flow_model.single_blocks: 453 | module.to(device) 454 | module.eval() 455 | recursive_swap_linears( 456 | module, 457 | float8_dtype=float8_dtype, 458 | input_float8_dtype=input_float8_dtype, 459 | quantize_modulation=quantize_modulation, 460 | ) 461 | torch.cuda.empty_cache() 462 | to_gpu_extras = [ 463 | "vector_in", 464 | "img_in", 465 | "txt_in", 466 | "time_in", 467 | "guidance_in", 468 | "final_layer", 469 | "pe_embedder", 470 | ] 471 | for module in to_gpu_extras: 472 | m_extra = getattr(flow_model, module) 473 | if m_extra is None: 474 | continue 475 | m_extra.to(device) 476 | m_extra.eval() 477 | if isinstance(m_extra, nn.Linear) and not isinstance( 478 | m_extra, (F8Linear, CublasLinear) 479 | ): 480 | if quantize_flow_embedder_layers: 481 | setattr( 482 | flow_model, 483 | module, 484 | F8Linear.from_linear( 485 | m_extra, 486 | float8_dtype=float8_dtype, 487 | input_float8_dtype=input_float8_dtype, 488 | ), 489 | ) 490 | del m_extra 491 | elif module != "final_layer": 492 | if quantize_flow_embedder_layers: 493 | recursive_swap_linears( 494 | m_extra, 495 | float8_dtype=float8_dtype, 496 | input_float8_dtype=input_float8_dtype, 497 | quantize_modulation=quantize_modulation, 498 | ) 499 | torch.cuda.empty_cache() 500 | if ( 501 | swap_linears_with_cublaslinear 502 | and flow_dtype == torch.float16 503 | and isinstance(CublasLinear, type(torch.nn.Linear)) 504 | ): 505 | swap_to_cublaslinear(flow_model) 506 | elif swap_linears_with_cublaslinear and flow_dtype != torch.float16: 507 | logger.warning("Skipping cublas linear swap because flow_dtype is not float16") 508 | if offload_flow: 509 | flow_model.to("cpu") 510 | torch.cuda.empty_cache() 511 | return flow_model 512 | -------------------------------------------------------------------------------- /fp8/image_encoder.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class ImageEncoder: 8 | @torch.inference_mode() 9 | def encode_pil(self, img: torch.Tensor) -> Image: 10 | if img.ndim == 2: 11 | img = ( 12 | img[None] 13 | .repeat_interleave(3, dim=0) 14 | .permute(1, 2, 0) 15 | .contiguous() 16 | .clamp(0, 255) 17 | .type(torch.uint8) 18 | ) 19 | elif img.ndim == 3: 20 | if img.shape[0] == 3: 21 | img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8) 22 | elif img.shape[2] == 3: 23 | img = img.contiguous().clamp(0, 255).type(torch.uint8) 24 | else: 25 | raise ValueError(f"Unsupported image shape: {img.shape}") 26 | else: 27 | raise ValueError(f"Unsupported image num dims: {img.ndim}") 28 | 29 | img = img.cpu().numpy().astype(np.uint8) 30 | return Image.fromarray(img) 31 | 32 | @torch.inference_mode() 33 | def encode_torch(self, img: torch.Tensor, quality=95): 34 | im = self.encode_pil(img) 35 | iob = io.BytesIO() 36 | im.save(iob, format="JPEG", quality=quality) 37 | 
iob.seek(0) 38 | return iob 39 | -------------------------------------------------------------------------------- /fp8/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor, nn 4 | from pydantic import BaseModel 5 | 6 | 7 | class AutoEncoderParams(BaseModel): 8 | resolution: int 9 | in_channels: int 10 | ch: int 11 | out_ch: int 12 | ch_mult: list[int] 13 | num_res_blocks: int 14 | z_channels: int 15 | scale_factor: float 16 | shift_factor: float 17 | 18 | 19 | def swish(x: Tensor) -> Tensor: 20 | return x * torch.sigmoid(x) 21 | 22 | 23 | class AttnBlock(nn.Module): 24 | def __init__(self, in_channels: int): 25 | super().__init__() 26 | self.in_channels = in_channels 27 | 28 | self.norm = nn.GroupNorm( 29 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 30 | ) 31 | 32 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 33 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | 37 | def attention(self, h_: Tensor) -> Tensor: 38 | h_ = self.norm(h_) 39 | q = self.q(h_) 40 | k = self.k(h_) 41 | v = self.v(h_) 42 | 43 | b, c, h, w = q.shape 44 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 45 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 46 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 47 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 48 | 49 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | return x + self.proj_out(self.attention(x)) 53 | 54 | 55 | class ResnetBlock(nn.Module): 56 | def __init__(self, in_channels: int, out_channels: int): 57 | super().__init__() 58 | self.in_channels = in_channels 59 | out_channels = in_channels if out_channels is None else out_channels 60 | self.out_channels = out_channels 61 | 62 | self.norm1 = nn.GroupNorm( 63 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 64 | ) 65 | self.conv1 = nn.Conv2d( 66 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 67 | ) 68 | self.norm2 = nn.GroupNorm( 69 | num_groups=32, num_channels=out_channels, eps=1e-6, affine=True 70 | ) 71 | self.conv2 = nn.Conv2d( 72 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 73 | ) 74 | if self.in_channels != self.out_channels: 75 | self.nin_shortcut = nn.Conv2d( 76 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 77 | ) 78 | 79 | def forward(self, x): 80 | h = x 81 | h = self.norm1(h) 82 | h = swish(h) 83 | h = self.conv1(h) 84 | 85 | h = self.norm2(h) 86 | h = swish(h) 87 | h = self.conv2(h) 88 | 89 | if self.in_channels != self.out_channels: 90 | x = self.nin_shortcut(x) 91 | 92 | return x + h 93 | 94 | 95 | class Downsample(nn.Module): 96 | def __init__(self, in_channels: int): 97 | super().__init__() 98 | # no asymmetric padding in torch conv, must do it ourselves 99 | self.conv = nn.Conv2d( 100 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 101 | ) 102 | 103 | def forward(self, x: Tensor): 104 | pad = (0, 1, 0, 1) 105 | x = nn.functional.pad(x, pad, mode="constant", value=0) 106 | x = self.conv(x) 107 | return x 108 | 109 | 110 | class Upsample(nn.Module): 111 | def __init__(self, in_channels: int): 112 | super().__init__() 113 | self.conv = nn.Conv2d( 114 | in_channels, in_channels, 
kernel_size=3, stride=1, padding=1 115 | ) 116 | 117 | def forward(self, x: Tensor): 118 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 119 | x = self.conv(x) 120 | return x 121 | 122 | 123 | class Encoder(nn.Module): 124 | def __init__( 125 | self, 126 | resolution: int, 127 | in_channels: int, 128 | ch: int, 129 | ch_mult: list[int], 130 | num_res_blocks: int, 131 | z_channels: int, 132 | ): 133 | super().__init__() 134 | self.ch = ch 135 | self.num_resolutions = len(ch_mult) 136 | self.num_res_blocks = num_res_blocks 137 | self.resolution = resolution 138 | self.in_channels = in_channels 139 | # downsampling 140 | self.conv_in = nn.Conv2d( 141 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 142 | ) 143 | 144 | curr_res = resolution 145 | in_ch_mult = (1,) + tuple(ch_mult) 146 | self.in_ch_mult = in_ch_mult 147 | self.down = nn.ModuleList() 148 | block_in = self.ch 149 | for i_level in range(self.num_resolutions): 150 | block = nn.ModuleList() 151 | attn = nn.ModuleList() 152 | block_in = ch * in_ch_mult[i_level] 153 | block_out = ch * ch_mult[i_level] 154 | for _ in range(self.num_res_blocks): 155 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 156 | block_in = block_out 157 | down = nn.Module() 158 | down.block = block 159 | down.attn = attn 160 | if i_level != self.num_resolutions - 1: 161 | down.downsample = Downsample(block_in) 162 | curr_res = curr_res // 2 163 | self.down.append(down) 164 | 165 | # middle 166 | self.mid = nn.Module() 167 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 168 | self.mid.attn_1 = AttnBlock(block_in) 169 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 170 | 171 | # end 172 | self.norm_out = nn.GroupNorm( 173 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 174 | ) 175 | self.conv_out = nn.Conv2d( 176 | block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 177 | ) 178 | 179 | def forward(self, x: Tensor) -> Tensor: 180 | # downsampling 181 | hs = [self.conv_in(x)] 182 | for i_level in range(self.num_resolutions): 183 | for i_block in range(self.num_res_blocks): 184 | h = self.down[i_level].block[i_block](hs[-1]) 185 | if len(self.down[i_level].attn) > 0: 186 | h = self.down[i_level].attn[i_block](h) 187 | hs.append(h) 188 | if i_level != self.num_resolutions - 1: 189 | hs.append(self.down[i_level].downsample(hs[-1])) 190 | 191 | # middle 192 | h = hs[-1] 193 | h = self.mid.block_1(h) 194 | h = self.mid.attn_1(h) 195 | h = self.mid.block_2(h) 196 | # end 197 | h = self.norm_out(h) 198 | h = swish(h) 199 | h = self.conv_out(h) 200 | return h 201 | 202 | 203 | class Decoder(nn.Module): 204 | def __init__( 205 | self, 206 | ch: int, 207 | out_ch: int, 208 | ch_mult: list[int], 209 | num_res_blocks: int, 210 | in_channels: int, 211 | resolution: int, 212 | z_channels: int, 213 | ): 214 | super().__init__() 215 | self.ch = ch 216 | self.num_resolutions = len(ch_mult) 217 | self.num_res_blocks = num_res_blocks 218 | self.resolution = resolution 219 | self.in_channels = in_channels 220 | self.ffactor = 2 ** (self.num_resolutions - 1) 221 | 222 | # compute in_ch_mult, block_in and curr_res at lowest res 223 | block_in = ch * ch_mult[self.num_resolutions - 1] 224 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 225 | self.z_shape = (1, z_channels, curr_res, curr_res) 226 | 227 | # z to block_in 228 | self.conv_in = nn.Conv2d( 229 | z_channels, block_in, kernel_size=3, stride=1, padding=1 230 | ) 231 | 232 | # middle 
233 | self.mid = nn.Module() 234 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 235 | self.mid.attn_1 = AttnBlock(block_in) 236 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 237 | 238 | # upsampling 239 | self.up = nn.ModuleList() 240 | for i_level in reversed(range(self.num_resolutions)): 241 | block = nn.ModuleList() 242 | attn = nn.ModuleList() 243 | block_out = ch * ch_mult[i_level] 244 | for _ in range(self.num_res_blocks + 1): 245 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 246 | block_in = block_out 247 | up = nn.Module() 248 | up.block = block 249 | up.attn = attn 250 | if i_level != 0: 251 | up.upsample = Upsample(block_in) 252 | curr_res = curr_res * 2 253 | self.up.insert(0, up) # prepend to get consistent order 254 | 255 | # end 256 | self.norm_out = nn.GroupNorm( 257 | num_groups=32, num_channels=block_in, eps=1e-6, affine=True 258 | ) 259 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 260 | 261 | def forward(self, z: Tensor) -> Tensor: 262 | # z to block_in 263 | h = self.conv_in(z) 264 | 265 | # middle 266 | h = self.mid.block_1(h) 267 | h = self.mid.attn_1(h) 268 | h = self.mid.block_2(h) 269 | 270 | # upsampling 271 | for i_level in reversed(range(self.num_resolutions)): 272 | for i_block in range(self.num_res_blocks + 1): 273 | h = self.up[i_level].block[i_block](h) 274 | if len(self.up[i_level].attn) > 0: 275 | h = self.up[i_level].attn[i_block](h) 276 | if i_level != 0: 277 | h = self.up[i_level].upsample(h) 278 | 279 | # end 280 | h = self.norm_out(h) 281 | h = swish(h) 282 | h = self.conv_out(h) 283 | return h 284 | 285 | 286 | class DiagonalGaussian(nn.Module): 287 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 288 | super().__init__() 289 | self.sample = sample 290 | self.chunk_dim = chunk_dim 291 | 292 | def forward(self, z: Tensor) -> Tensor: 293 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 294 | if self.sample: 295 | std = torch.exp(0.5 * logvar) 296 | return mean + std * torch.randn_like(mean) 297 | else: 298 | return mean 299 | 300 | 301 | class AutoEncoder(nn.Module): 302 | def __init__(self, params: AutoEncoderParams): 303 | super().__init__() 304 | self.encoder = Encoder( 305 | resolution=params.resolution, 306 | in_channels=params.in_channels, 307 | ch=params.ch, 308 | ch_mult=params.ch_mult, 309 | num_res_blocks=params.num_res_blocks, 310 | z_channels=params.z_channels, 311 | ) 312 | self.decoder = Decoder( 313 | resolution=params.resolution, 314 | in_channels=params.in_channels, 315 | ch=params.ch, 316 | out_ch=params.out_ch, 317 | ch_mult=params.ch_mult, 318 | num_res_blocks=params.num_res_blocks, 319 | z_channels=params.z_channels, 320 | ) 321 | self.reg = DiagonalGaussian() 322 | 323 | self.scale_factor = params.scale_factor 324 | self.shift_factor = params.shift_factor 325 | 326 | def encode(self, x: Tensor) -> Tensor: 327 | z = self.reg(self.encoder(x)) 328 | z = self.scale_factor * (z - self.shift_factor) 329 | return z 330 | 331 | def decode(self, z: Tensor) -> Tensor: 332 | z = z / self.scale_factor + self.shift_factor 333 | return self.decoder(z) 334 | 335 | def forward(self, x: Tensor) -> Tensor: 336 | return self.decode(self.encode(x)) 337 | -------------------------------------------------------------------------------- /fp8/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor, 
nn 5 | from transformers import ( 6 | CLIPTextModel, 7 | CLIPTokenizer, 8 | T5EncoderModel, 9 | T5Tokenizer, 10 | __version__, 11 | ) 12 | from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig 13 | 14 | CACHE_DIR = os.environ.get("HF_HOME", "~/.cache/huggingface") 15 | 16 | 17 | def auto_quantization_config( 18 | quantization_dtype: str, 19 | ) -> QuantoConfig | BitsAndBytesConfig: 20 | if quantization_dtype == "qfloat8": 21 | return QuantoConfig(weights="float8") 22 | elif quantization_dtype == "qint4": 23 | return BitsAndBytesConfig( 24 | load_in_4bit=True, 25 | bnb_4bit_compute_dtype=torch.bfloat16, 26 | bnb_4bit_quant_type="nf4", 27 | ) 28 | elif quantization_dtype == "qint8": 29 | return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False) 30 | elif quantization_dtype == "qint2": 31 | return QuantoConfig(weights="int2") 32 | else: 33 | raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}") 34 | 35 | 36 | class HFEmbedder(nn.Module): 37 | def __init__( 38 | self, 39 | version: str, 40 | max_length: int, 41 | device: torch.device | int, 42 | quantization_dtype: str | None = None, 43 | offloading_device: torch.device | int | None = torch.device("cpu"), 44 | **hf_kwargs, 45 | ): 46 | super().__init__() 47 | self.offloading_device = ( 48 | offloading_device 49 | if isinstance(offloading_device, torch.device) 50 | else torch.device(offloading_device) 51 | ) 52 | self.device = ( 53 | device if isinstance(device, torch.device) else torch.device(device) 54 | ) 55 | self.is_clip = version.startswith("openai") 56 | self.max_length = max_length 57 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 58 | 59 | auto_quant_config = ( 60 | auto_quantization_config(quantization_dtype) if quantization_dtype else None 61 | ) 62 | 63 | # BNB will move to cuda:0 by default if not specified 64 | if isinstance(auto_quant_config, BitsAndBytesConfig): 65 | hf_kwargs["device_map"] = {"": self.device.index} 66 | if auto_quant_config is not None: 67 | hf_kwargs["quantization_config"] = auto_quant_config 68 | 69 | if self.is_clip: 70 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( 71 | version, max_length=max_length 72 | ) 73 | 74 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( 75 | version, 76 | **hf_kwargs, 77 | ) 78 | 79 | else: 80 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( 81 | os.path.join(version, 'tokenizer'), max_length=max_length 82 | ) 83 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( 84 | os.path.join(version, 'model'), 85 | **hf_kwargs, 86 | ) 87 | 88 | def offload(self): 89 | self.hf_module.to(device=self.offloading_device) 90 | torch.cuda.empty_cache() 91 | 92 | def cuda(self): 93 | self.hf_module.to(device=self.device) 94 | 95 | def forward(self, text: list[str]) -> Tensor: 96 | batch_encoding = self.tokenizer( 97 | text, 98 | truncation=True, 99 | max_length=self.max_length, 100 | return_length=False, 101 | return_overflowing_tokens=False, 102 | padding="max_length", 103 | return_tensors="pt", 104 | ) 105 | outputs = self.hf_module( 106 | input_ids=batch_encoding["input_ids"].to(self.hf_module.device), 107 | attention_mask=None, 108 | output_hidden_states=False, 109 | ) 110 | return outputs[self.output_key] 111 | 112 | 113 | if __name__ == "__main__": 114 | model = HFEmbedder( 115 | "city96/t5-v1_1-xxl-encoder-bf16", 116 | max_length=512, 117 | device=0, 118 | quantization_dtype="qfloat8", 119 | ) 120 | o = model(["hello"]) 121 | print(o) 
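    # Illustrative usage sketch (added for clarity, not part of the original module):
    # the VRAM-saving pattern HFEmbedder enables is to keep the encoder parked off-GPU
    # between text-encoding passes, assuming a CUDA device is available:
    #
    #     model.cuda()                  # move hf_module onto the configured device
    #     emb = model(["a cool dog"])   # T5 -> last_hidden_state, CLIP -> pooler_output
    #     model.offload()               # move weights to offloading_device and free VRAM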
122 | -------------------------------------------------------------------------------- /fp8/util.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | import os 4 | from pathlib import Path 5 | import subprocess 6 | import time 7 | from typing import Any, Literal, Optional 8 | 9 | import torch 10 | from fp8.modules.autoencoder import AutoEncoder, AutoEncoderParams 11 | from fp8.modules.conditioner import HFEmbedder 12 | from fp8.modules.flux_model import Flux, FluxParams 13 | from safetensors.torch import load_file as load_sft 14 | 15 | try: 16 | from enum import StrEnum 17 | except ImportError:  # StrEnum was added in Python 3.11; fall back to a shim 18 | from enum import Enum 19 | 20 | class StrEnum(str, Enum): 21 | pass 22 | 23 | 24 | from pydantic import BaseModel, Field, validator 25 | from loguru import logger 26 | 27 | 28 | class ModelVersion(StrEnum): 29 | flux_dev = "flux-dev" 30 | flux_schnell = "flux-schnell" 31 | 32 | 33 | class QuantizationDtype(StrEnum): 34 | qfloat8 = "qfloat8" 35 | qint2 = "qint2" 36 | qint4 = "qint4" 37 | qint8 = "qint8" 38 | 39 | 40 | class ModelSpec(BaseModel): 41 | class Config: 42 | arbitrary_types_allowed = True 43 | use_enum_values = True 44 | version: ModelVersion 45 | params: FluxParams 46 | ae_params: AutoEncoderParams 47 | ckpt_path: str | None 48 | ae_path: str | None 49 | repo_id: str | None 50 | repo_flow: str | None 51 | repo_ae: str | None 52 | text_enc_max_length: int = 512 53 | text_enc_path: str | None 54 | text_enc_device: str | torch.device | None = "cuda:0" 55 | ae_device: str | torch.device | None = "cuda:0" 56 | flux_device: str | torch.device | None = "cuda:0" 57 | flux_url: str | None 58 | flow_dtype: str = "float16" 59 | ae_dtype: str = "bfloat16" 60 | text_enc_dtype: str = "bfloat16" 61 | # unused / deprecated 62 | num_to_quant: Optional[int] = 20 63 | quantize_extras: bool = False 64 | compile_whole_model: bool = False 65 | compile_extras: bool = False 66 | compile_blocks: bool = False 67 | flow_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8 68 | text_enc_quantization_dtype: Optional[QuantizationDtype] = QuantizationDtype.qfloat8 69 | ae_quantization_dtype: Optional[QuantizationDtype] = None 70 | clip_quantization_dtype: Optional[QuantizationDtype] = None 71 | offload_text_encoder: bool = False 72 | offload_vae: bool = False 73 | offload_flow: bool = False 74 | prequantized_flow: bool = False 75 | 76 | # Improved precision by not quantizing the modulation linear layers 77 | quantize_modulation: bool = True 78 | # Improved precision by not quantizing the flow embedder layers 79 | quantize_flow_embedder_layers: bool = False 80 | 81 | def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]: 82 | flow = load_flow_model(config) 83 | ae = load_autoencoder(config) 84 | clip, t5 = load_text_encoders(config) 85 | return flow, ae, clip, t5 86 | 87 | 88 | def parse_device(device: str | torch.device | None) -> torch.device: 89 | if isinstance(device, str): 90 | return torch.device(device) 91 | elif isinstance(device, torch.device): 92 | return device 93 | else: 94 | return torch.device("cuda:0") 95 | 96 | 97 | def into_dtype(dtype: str) -> torch.dtype: 98 | if isinstance(dtype, torch.dtype): 99 | return dtype 100 | if dtype == "float16": 101 | return torch.float16 102 | elif dtype == "bfloat16": 103 | return torch.bfloat16 104 | elif dtype == "float32": 105 | return torch.float32 106 | else: 107 | raise ValueError(f"Invalid dtype: {dtype}") 108 | 109 | 110 |
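# Illustrative examples (added for clarity, not exhaustive) of how the small
# helpers above and below normalize config values:
#
#     parse_device("cuda:1")   # -> torch.device("cuda:1")
#     parse_device(None)       # -> torch.device("cuda:0")
#     into_dtype("bfloat16")   # -> torch.bfloat16
#     into_device(1)           # -> torch.device("cuda:1")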
def into_device(device: str | torch.device | None) -> torch.device: 111 | if isinstance(device, str): 112 | return torch.device(device) 113 | elif isinstance(device, torch.device): 114 | return device 115 | elif isinstance(device, int): 116 | return torch.device(f"cuda:{device}") 117 | else: 118 | return torch.device("cuda:0") 119 | 120 | 121 | def load_config( 122 | name: ModelVersion = ModelVersion.flux_dev, 123 | flux_path: str | None = None, 124 | ae_path: str | None = None, 125 | text_enc_path: str | None = None, 126 | text_enc_device: str | torch.device | None = None, 127 | ae_device: str | torch.device | None = None, 128 | flux_device: str | torch.device | None = None, 129 | flow_dtype: str = "float16", 130 | ae_dtype: str = "bfloat16", 131 | text_enc_dtype: str = "bfloat16", 132 | num_to_quant: Optional[int] = 20, 133 | compile_extras: bool = False, 134 | compile_blocks: bool = False, 135 | offload_text_enc: bool = False, 136 | offload_ae: bool = False, 137 | offload_flow: bool = False, 138 | quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None, 139 | quant_ae: bool = False, 140 | prequantized_flow: bool = False, 141 | quantize_modulation: bool = True, 142 | quantize_flow_embedder_layers: bool = False, 143 | ) -> ModelSpec: 144 | """ 145 | Load a model configuration using the passed arguments. 146 | """ 147 | text_enc_device = str(parse_device(text_enc_device)) 148 | ae_device = str(parse_device(ae_device)) 149 | flux_device = str(parse_device(flux_device)) 150 | return ModelSpec( 151 | version=name, 152 | repo_id=( 153 | "black-forest-labs/FLUX.1-dev" 154 | if name == ModelVersion.flux_dev 155 | else "black-forest-labs/FLUX.1-schnell" 156 | ), 157 | repo_flow=( 158 | "flux1-dev.sft" if name == ModelVersion.flux_dev else "flux1-schnell.sft" 159 | ), 160 | repo_ae="ae.sft", 161 | ckpt_path=flux_path, 162 | params=FluxParams( 163 | in_channels=64, 164 | vec_in_dim=768, 165 | context_in_dim=4096, 166 | hidden_size=3072, 167 | mlp_ratio=4.0, 168 | num_heads=24, 169 | depth=19, 170 | depth_single_blocks=38, 171 | axes_dim=[16, 56, 56], 172 | theta=10_000, 173 | qkv_bias=True, 174 | guidance_embed=name == ModelVersion.flux_dev, 175 | ), 176 | ae_path=ae_path, 177 | ae_params=AutoEncoderParams( 178 | resolution=256, 179 | in_channels=3, 180 | ch=128, 181 | out_ch=3, 182 | ch_mult=[1, 2, 4, 4], 183 | num_res_blocks=2, 184 | z_channels=16, 185 | scale_factor=0.3611, 186 | shift_factor=0.1159, 187 | ), 188 | text_enc_path=text_enc_path, 189 | text_enc_device=text_enc_device, 190 | ae_device=ae_device, 191 | flux_device=flux_device, 192 | flow_dtype=flow_dtype, 193 | ae_dtype=ae_dtype, 194 | text_enc_dtype=text_enc_dtype, 195 | text_enc_max_length=512 if name == ModelVersion.flux_dev else 256, 196 | num_to_quant=num_to_quant, 197 | compile_extras=compile_extras, 198 | compile_blocks=compile_blocks, 199 | offload_flow=offload_flow, 200 | offload_text_encoder=offload_text_enc, 201 | offload_vae=offload_ae, 202 | text_enc_quantization_dtype={ 203 | "float8": QuantizationDtype.qfloat8, 204 | "qint2": QuantizationDtype.qint2, 205 | "qint4": QuantizationDtype.qint4, 206 | "qint8": QuantizationDtype.qint8, 207 | }.get(quant_text_enc, None), 208 | ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None, 209 | prequantized_flow=prequantized_flow, 210 | quantize_modulation=quantize_modulation, 211 | quantize_flow_embedder_layers=quantize_flow_embedder_layers, 212 | ) 213 | 214 | 215 | def load_config_from_path(path: str) -> ModelSpec: 216 | path_path = Path(path) 
217 | if not path_path.exists(): 218 | raise ValueError(f"Path {path} does not exist") 219 | if not path_path.is_file(): 220 | raise ValueError(f"Path {path} is not a file") 221 | return ModelSpec(**json.loads(path_path.read_text())) 222 | 223 | 224 | def print_load_warning(missing: list[str], unexpected: list[str]) -> None: 225 | if len(missing) > 0 and len(unexpected) > 0: 226 | logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 227 | logger.warning("\n" + "-" * 79 + "\n") 228 | logger.warning( 229 | f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected) 230 | ) 231 | elif len(missing) > 0: 232 | logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) 233 | elif len(unexpected) > 0: 234 | logger.warning( 235 | f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected) 236 | ) 237 | 238 | def download_weights(url: str, dest: Path): 239 | start = time.time() 240 | print("downloading url: ", url) 241 | print("downloading to: ", dest) 242 | if url.endswith("tar"): 243 | subprocess.check_call(["pget", "--log-level=WARNING", "-x", url, dest], close_fds=False) 244 | else: 245 | subprocess.check_call(["pget", "--log-level=WARNING", url, dest], close_fds=False) 246 | print("downloading took: ", time.time() - start) 247 | 248 | 249 | def load_flow_model(config: ModelSpec) -> Flux: 250 | ckpt_path = config.ckpt_path 251 | if not os.path.exists(ckpt_path): 252 | flux_url = config.flux_url 253 | download_weights(flux_url, ckpt_path) 254 | 255 | FluxClass = Flux 256 | 257 | with torch.device("meta"): 258 | model = FluxClass(config, dtype=into_dtype(config.flow_dtype)) 259 | if not config.prequantized_flow: 260 | model.type(into_dtype(config.flow_dtype)) 261 | 262 | if ckpt_path is not None: 263 | # load_sft doesn't support torch.device 264 | sd = load_sft(ckpt_path, device="cpu") 265 | missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) 266 | print_load_warning(missing, unexpected) 267 | if not config.prequantized_flow: 268 | model.type(into_dtype(config.flow_dtype)) 269 | return model 270 | 271 | 272 | def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]: 273 | clip = HFEmbedder( 274 | "openai/clip-vit-large-patch14", 275 | max_length=77, 276 | torch_dtype=into_dtype(config.text_enc_dtype), 277 | device=into_device(config.text_enc_device).index or 0, 278 | quantization_dtype=config.clip_quantization_dtype, 279 | ) 280 | t5 = HFEmbedder( 281 | config.text_enc_path, 282 | max_length=config.text_enc_max_length, 283 | torch_dtype=into_dtype(config.text_enc_dtype), 284 | device=into_device(config.text_enc_device).index or 0, 285 | quantization_dtype=config.text_enc_quantization_dtype, 286 | ) 287 | return clip, t5 288 | 289 | 290 | def load_autoencoder(config: ModelSpec) -> AutoEncoder: 291 | ckpt_path = config.ae_path 292 | with torch.device("meta" if ckpt_path is not None else config.ae_device): 293 | ae = AutoEncoder(config.ae_params).to(into_dtype(config.ae_dtype)) 294 | 295 | if ckpt_path is not None: 296 | sd = load_sft(ckpt_path, device=str(config.ae_device)) 297 | missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) 298 | print_load_warning(missing, unexpected) 299 | ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype)) 300 | if config.ae_quantization_dtype is not None: 301 | from fp8.float8_quantize import recursive_swap_linears 302 | 303 | recursive_swap_linears(ae) 304 | if config.offload_vae: 305 | ae.to("cpu") 306 | 
torch.cuda.empty_cache() 307 | return ae 308 | 309 | @dataclass 310 | class LoadedModels(): 311 | flow: Optional[Flux] 312 | ae: Any 313 | clip: Any 314 | t5: Any 315 | config: Optional[ModelSpec] 316 | 317 | 318 | def load_models_from_config_path( 319 | path: str, 320 | ) -> LoadedModels: 321 | config = load_config_from_path(path) 322 | clip, t5 = load_text_encoders(config) 323 | return LoadedModels( 324 | flow=load_flow_model(config), 325 | ae=load_autoencoder(config), 326 | clip=clip, 327 | t5=t5, 328 | config=config, 329 | ) 330 | 331 | 332 | def load_models_from_config(config: ModelSpec, shared_models: LoadedModels = None) -> LoadedModels: 333 | if shared_models: 334 | clip = shared_models.clip 335 | t5 = shared_models.t5 336 | ae = shared_models.ae 337 | else: 338 | clip, t5 = load_text_encoders(config) 339 | ae = load_autoencoder(config) 340 | 341 | flow = load_flow_model(config) 342 | 343 | return LoadedModels( 344 | flow=flow, 345 | ae=ae, 346 | clip=clip, 347 | t5=t5, 348 | config=config, 349 | ) 350 | -------------------------------------------------------------------------------- /lora_loading_patch.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | from diffusers.utils import ( 3 | convert_unet_state_dict_to_peft, 4 | get_peft_kwargs, 5 | is_peft_version, 6 | get_adapter_name, 7 | logging, 8 | ) 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | # TODO: 99% sure this patch functionality has been merged into diffusers and we don't need it anymore. 14 | # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers 15 | def load_lora_into_transformer( 16 | cls, 17 | state_dict, 18 | network_alphas, 19 | transformer, 20 | adapter_name=None, 21 | _pipeline=None, 22 | low_cpu_mem_usage=False, 23 | ): 24 | """ 25 | This will load the LoRA layers specified in `state_dict` into `transformer`. 26 | 27 | Parameters: 28 | state_dict (`dict`): 29 | A standard state dict containing the lora layer parameters. The keys can either be indexed directly 30 | into the unet or prefixed with an additional `unet` which can be used to distinguish between text 31 | encoder lora layers. 32 | network_alphas (`Dict[str, float]`): 33 | The value of the network alpha used for stable learning and preventing underflow. This value has the 34 | same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this 35 | link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). 36 | transformer (`SD3Transformer2DModel`): 37 | The Transformer model to load the LoRA layers into. 38 | adapter_name (`str`, *optional*): 39 | Adapter name to be used for referencing the loaded adapter model. If not specified, it will use 40 | `default_{i}` where i is the total number of adapters being loaded. 
41 | """ 42 | from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict 43 | 44 | keys = list(state_dict.keys()) 45 | 46 | transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] 47 | state_dict = { 48 | k.replace(f"{cls.transformer_name}.", ""): v 49 | for k, v in state_dict.items() 50 | if k in transformer_keys 51 | } 52 | 53 | if len(state_dict.keys()) > 0: 54 | # check with first key if is not in peft format 55 | first_key = next(iter(state_dict.keys())) 56 | if "lora_A" not in first_key: 57 | state_dict = convert_unet_state_dict_to_peft(state_dict) 58 | 59 | if adapter_name in getattr(transformer, "peft_config", {}): 60 | raise ValueError( 61 | f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." 62 | ) 63 | 64 | rank = {} 65 | for key, val in state_dict.items(): 66 | if "lora_B" in key: 67 | rank[key] = val.shape[1] 68 | 69 | if network_alphas is not None and len(network_alphas) >= 1: 70 | prefix = cls.transformer_name 71 | alpha_keys = [ 72 | k 73 | for k in network_alphas.keys() 74 | if k.startswith(prefix) and k.split(".")[0] == prefix 75 | ] 76 | network_alphas = { 77 | k.replace(f"{prefix}.", ""): v 78 | for k, v in network_alphas.items() 79 | if k in alpha_keys 80 | } 81 | 82 | lora_config_kwargs = get_peft_kwargs( 83 | rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict 84 | ) 85 | if "use_dora" in lora_config_kwargs: 86 | if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): 87 | raise ValueError( 88 | "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 89 | ) 90 | else: 91 | lora_config_kwargs.pop("use_dora") 92 | lora_config = LoraConfig(**lora_config_kwargs) 93 | 94 | # adapter_name 95 | if adapter_name is None: 96 | adapter_name = get_adapter_name(transformer) 97 | 98 | # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks 99 | # otherwise loading LoRA weights will lead to an error 100 | is_model_cpu_offload, is_sequential_cpu_offload = ( 101 | cls._optionally_disable_offloading(_pipeline) 102 | ) 103 | 104 | inject_adapter_in_model( 105 | lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True 106 | ) 107 | incompatible_keys = set_peft_model_state_dict( 108 | transformer, state_dict, adapter_name, low_cpu_mem_usage=True 109 | ) 110 | 111 | if incompatible_keys is not None: 112 | # check only for unexpected keys 113 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) 114 | if unexpected_keys: 115 | logger.warning( 116 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 117 | f" {unexpected_keys}. " 118 | ) 119 | 120 | # Offload back. 
121 | if is_model_cpu_offload: 122 | _pipeline.enable_model_cpu_offload() 123 | elif is_sequential_cpu_offload: 124 | _pipeline.enable_sequential_cpu_offload() 125 | # Unsafe code /> 126 | -------------------------------------------------------------------------------- /model-cog-configs/canny-dev.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:CannyDevPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/depth-dev.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:DepthDevPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/dev-lora.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:DevLoraPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/dev.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:DevPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/fill-dev.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:FillDevPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/hotswap-lora.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | python_packages: 3 | - "accelerate==1.3.0" 4 | - "peft==0.14.0" 5 | 6 | predict: "predict.py:HotswapPredictor" 7 | -------------------------------------------------------------------------------- /model-cog-configs/redux-dev.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:DevReduxPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/redux-schnell.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:SchnellReduxPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/schnell-lora.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:SchnellLoraPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/schnell.yaml: -------------------------------------------------------------------------------- 1 | predict: "predict.py:SchnellPredictor" 2 | -------------------------------------------------------------------------------- /model-cog-configs/test.yaml: -------------------------------------------------------------------------------- 1 | 2 | predict: "predict.py:TestPredictor" -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 
2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | "flux/", 30 | "fp8/", 31 | ] 32 | 33 | # Same as Black. 34 | line-length = 88 35 | indent-width = 4 36 | 37 | # Assume Python 3.8 38 | target-version = "py38" 39 | 40 | [lint] 41 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 42 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 43 | # McCabe complexity (`C901`) by default. 44 | select = ["E4", "E7", "E9", "F", "W", "I", "N", "UP", "A", "ICN", "PT", "Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TID", "TCH", "ARG", "PTH", "ERA", "FLY"] 45 | ignore = ["E402", "F403", "F405", "PT011", "SIM117", "SIM102", "ERA001", "RSE102", "I001"] 46 | 47 | # Allow fix for all enabled rules (when `--fix`) is provided. 48 | fixable = ["ALL"] 49 | unfixable = [] 50 | 51 | # Allow unused variables when underscore-prefixed. 52 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 53 | 54 | [format] 55 | # Like Black, use double quotes for strings. 56 | quote-style = "double" 57 | 58 | # Like Black, indent with spaces, rather than tabs. 59 | indent-style = "space" 60 | 61 | # Like Black, respect magic trailing commas. 62 | skip-magic-trailing-comma = false 63 | 64 | # Like Black, automatically detect the appropriate line ending. 65 | line-ending = "auto" 66 | 67 | # Enable auto-formatting of code examples in docstrings. Markdown, 68 | # reStructuredText code/literal blocks and doctests are all supported. 69 | # 70 | # This is currently disabled by default, but it is planned for this 71 | # to be opt-out in the future. 72 | docstring-code-format = false 73 | 74 | # Set the line length limit used when formatting code snippets in 75 | # docstrings. 76 | # 77 | # This only has an effect when the `docstring-code-format` setting is 78 | # enabled. 
79 | docstring-code-line-length = "dynamic" -------------------------------------------------------------------------------- /safe-push-configs/canny-dev.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-canny-dev-internal-model 2 | test_model: replicate/test-flux-canny-dev 3 | test_hardware: cpu 4 | predict: 5 | compare_outputs: false 6 | predict_timeout: 300 7 | test_cases: 8 | 9 | # basic 10 | - inputs: 11 | prompt: A bright red bird 12 | control_image: https://replicate.delivery/pbxt/IMPLYODUwdmHTsnLKi5YiFccIAK6g9l5KK1FNyCtpGS1g0UN/1200.jpeg 13 | seed: 2 14 | #match_prompt: A red bird # flaky for some reason 15 | -------------------------------------------------------------------------------- /safe-push-configs/depth-dev.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-depth-dev-internal-model 2 | test_model: replicate/test-flux-depth-dev 3 | test_hardware: cpu 4 | predict: 5 | compare_outputs: false 6 | predict_timeout: 300 7 | test_cases: 8 | 9 | # basic 10 | - inputs: 11 | prompt: A stormtrooper giving a lecture at a university 12 | control_image: https://replicate.delivery/pbxt/IKFvJn5EpLuDDsFysOP4B1J9HvKDbMBCwZUK9n6p9mIPoQwG/sd.png 13 | seed: 36414 14 | match_prompt: An image of storm trooper 15 | -------------------------------------------------------------------------------- /safe-push-configs/dev-lora.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-dev-lora-internal-model 2 | test_model: replicate/test-flux-dev-lora 3 | predict: 4 | compare_outputs: false # TODO(andreas): why doesn't this work? 5 | predict_timeout: 600 6 | test_cases: 7 | 8 | # monalisa ~= a person 9 | - inputs: 10 | prompt: a photo of MNALSA woman with pink hair at a rave 11 | num_outputs: 1 12 | num_inference_steps: 28 13 | seed: 8888 14 | output_format: jpg 15 | go_fast: true 16 | lora_weights: fofr/flux-mona-lisa 17 | lora_scale: 0.9 18 | #match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/dev_flux-mona-lisa_go_fast.jpg 19 | 20 | # monalisa ~= 4 persons 21 | - inputs: 22 | prompt: a photo of MNALSA woman with pink hair at a rave 23 | num_outputs: 4 24 | num_inference_steps: 28 25 | seed: 8888 26 | output_format: jpg 27 | go_fast: true 28 | lora_weights: fofr/flux-mona-lisa 29 | lora_scale: 0.9 30 | match_prompt: Four images of a woman at a rave with pink hair who looks like the Mona Lisa 31 | 32 | 33 | # same but slower 34 | - inputs: 35 | prompt: a photo of MNALSA woman with pink hair at a rave 36 | num_outputs: 1 37 | num_inference_steps: 28 38 | seed: 8888 39 | output_format: jpg 40 | go_fast: false 41 | lora_weights: fofr/flux-mona-lisa 42 | lora_scale: 0.9 43 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/dev_flux-mona-lisa_go_slow.jpg 44 | 45 | # no lora! 
46 | - inputs: 47 | prompt: a photo of MNALSA woman with pink hair at a rave 48 | num_outputs: 1 49 | num_inference_steps: 28 50 | seed: 8888 51 | output_format: jpg 52 | go_fast: true 53 | lora_scale: 0.9 54 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/dev_no-lora_go_fast.jpg 55 | 56 | # aesthetic lora 57 | - inputs: 58 | prompt: a smart person, sftsrv style 59 | lora_weights: aramintak/flux-softserve-anime 60 | num_outputs: 1 61 | num_inference_steps: 28 62 | seed: 8888 63 | output_format: jpg 64 | go_fast: true 65 | lora_scale: 0.9 66 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/dev_flux-softserve-anime_go_fast.jpg 67 | 68 | # non-replicate weights 69 | - inputs: 70 | prompt: a coca cola can "sacred elixir" arcana in the style of TOK a trtcrd, tarot style 71 | num_outputs: 1 72 | num_inference_steps: 28 73 | seed: 8888 74 | output_format: jpg 75 | go_fast: true 76 | lora_weights: huggingface.co/multimodalart/flux-tarot-v1 77 | lora_scale: 0.9 78 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/dev_flux-tarot-v1_go_fast.jpg 79 | 80 | # non-replicate weights kohya 81 | - inputs: 82 | prompt: cy04, a book titled "Did I Leave The Oven On?", an illustration of a man sitting at work, looking worried, thought bubble above his head with an oven in it 83 | num_outputs: 1 84 | num_inference_steps: 28 85 | output_format: jpg 86 | go_fast: true 87 | lora_weights: huggingface.co/Purz/choose-your-own-adventure 88 | lora_scale: 0.9 89 | match_prompt: A drawing of a man thinking about an oven 90 | 91 | # non-replicate weights no mlp trained 92 | - inputs: 93 | prompt: photo of a boy ANIMESTYLE 94 | num_outputs: 1 95 | num_inference_steps: 28 96 | output_format: jpg 97 | go_fast: true 98 | lora_weights: https://storage.googleapis.com/replicate-models-public-test/flux-loras/fixed_lora.safetensors 99 | lora_scale: 0.9 100 | match_prompt: An anime drawing of a boy 101 | 102 | - inputs: 103 | prompt: A portrait photo of MNALSA woman sitting at a party table with a selection of bad 70s food 104 | num_outputs: 1 105 | num_inference_steps: 28 106 | seed: 8888 107 | output_format: jpg 108 | go_fast: true 109 | lora_weights: fofr/flux-bad-70s-food 110 | lora_scale: 0.85 111 | extra_lora: fofr/flux-mona-lisa 112 | extra_lora_scale: 0.9 113 | match_prompt: An image of a woman at a dinner table who looks like the Mona Lisa 114 | 115 | fuzz: 116 | fixed_inputs: 117 | lora_weights: fofr/flux-90s-power-rangers 118 | extra_lora: fofr/flux-80s-cyberpunk 119 | iterations: 10 120 | prompt: | 121 | For the extra_lora input, here is a list of loras you can use: 122 | * fofr/flux-handwriting 123 | * fofr/flux-my-subconscious 124 | * aramintak/flux-softserve-anime 125 | * davisbrown/flux-half-illustration 126 | * andreasjansson/flux-shapes 127 | * https://civitai.com/api/download/models/735262 128 | * huggingface.co/multimodalart/flux-tarot-v1 129 | -------------------------------------------------------------------------------- /safe-push-configs/dev.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-dev-internal-model 2 | test_model: replicate/test-flux-dev 3 | predict: 4 | compare_outputs: false # TODO(andreas): why doesn't this work? 
5 | predict_timeout: 600 6 | test_cases: 7 | 8 | # basic 9 | - inputs: 10 | prompt: A formula one car 11 | num_outputs: 1 12 | num_inference_steps: 28 13 | guidance: 3.5 14 | seed: 5259 15 | output_format: jpg 16 | match_prompt: A 1024x1024px jpg image of a formula one car 17 | 18 | # 4 outputs 19 | - inputs: 20 | prompt: A formula one car 21 | num_outputs: 4 22 | num_inference_steps: 20 23 | guidance: 1.0 24 | seed: 5259 25 | output_format: png 26 | match_prompt: Four png images 27 | 28 | # disable safety checker 29 | - inputs: 30 | prompt: A formula one car 31 | num_outputs: 1 32 | num_inference_steps: 20 33 | guidance: 1.0 34 | seed: 5259 35 | output_format: png 36 | disable_safety_checker: true 37 | match_prompt: A 1024x1024px png image of a formula one car 38 | 39 | # img2img 40 | - inputs: 41 | prompt: A formula one car 42 | num_outputs: 1 43 | num_inference_steps: 20 44 | guidance: 1.0 45 | seed: 5259 46 | output_format: png 47 | image: "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg" 48 | prompt_strength: 0.9 49 | match_prompt: A 1024x640px png image of a formula one car 50 | 51 | # aspect ratio 52 | - inputs: 53 | prompt: A formula one car 54 | num_outputs: 1 55 | num_inference_steps: 20 56 | guidance: 1.0 57 | seed: 5259 58 | output_format: png 59 | aspect_ratio: "3:2" 60 | match_prompt: A 1216x832px png image of a formula one car 61 | 62 | # go slow 63 | - inputs: 64 | prompt: A formula one car 65 | num_outputs: 1 66 | num_inference_steps: 28 67 | guidance: 3.5 68 | seed: 5259 69 | output_format: jpg 70 | go_fast: false 71 | match_prompt: A 1024x1024px jpg image of a formula one car 72 | 73 | # slow img2img, 4 outputs 74 | - inputs: 75 | prompt: A journey to the middle of the earth 76 | num_outputs: 4 77 | num_inference_steps: 20 78 | guidance: 3.5 79 | seed: 5259 80 | output_format: jpg 81 | go_fast: false 82 | image: "https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg" 83 | match_prompt: 4 jpg images 84 | 85 | fuzz: 86 | iterations: 10 87 | -------------------------------------------------------------------------------- /safe-push-configs/fill-dev.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-dev-fill-internal-model 2 | test_model: replicate/test-flux-dev-fill 3 | test_hardware: cpu 4 | predict: 5 | compare_outputs: false 6 | predict_timeout: 300 7 | test_cases: 8 | 9 | # basic 10 | - inputs: 11 | image: https://replicate.delivery/mgxm/f8c9cb3a-8ee8-41a7-9ef6-c65b37acc8af/desktop.png 12 | mask: https://replicate.delivery/mgxm/188d0097-6a6f-4488-a058-b0b7a66e5677/desktop-mask.png 13 | prompt: A herd of sheep grazing on a hill 14 | seed: 2 15 | match_prompt: An image of sheep grazing on a hill 16 | 17 | - inputs: 18 | prompt: a photo of MNALSA woman in front of a hill 19 | image: https://replicate.delivery/mgxm/f8c9cb3a-8ee8-41a7-9ef6-c65b37acc8af/desktop.png 20 | mask: https://replicate.delivery/mgxm/188d0097-6a6f-4488-a058-b0b7a66e5677/desktop-mask.png 21 | num_outputs: 1 22 | num_inference_steps: 28 23 | guidance: 2.5 24 | seed: 8888 25 | output_format: jpg 26 | lora_weights: fofr/flux-mona-lisa 27 | lora_scale: 2.0 28 | 29 | - inputs: 30 | image: https://replicate.delivery/mgxm/f8c9cb3a-8ee8-41a7-9ef6-c65b37acc8af/desktop.png 31 | mask: https://replicate.delivery/mgxm/188d0097-6a6f-4488-a058-b0b7a66e5677/desktop-mask.png 32 | prompt: A herd of sheep grazing on a hill 33 | seed: 2 34 | match_prompt: An image of sheep grazing on a hill 35 | 
-------------------------------------------------------------------------------- /safe-push-configs/hotswap-lora.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-hotswap-lora-internal-model 2 | test_model: replicate/test-flux-hotswap-lora 3 | predict: 4 | compare_outputs: false 5 | predict_timeout: 600 6 | test_cases: 7 | # same but slower 8 | - inputs: 9 | prompt: a photo of MNALSA woman with pink hair at a rave 10 | num_outputs: 1 11 | num_inference_steps: 28 12 | guidance_scale: 2.5 13 | seed: 8888 14 | model: dev 15 | output_format: jpg 16 | go_fast: false 17 | replicate_weights: fofr/flux-mona-lisa 18 | lora_scale: 0.9 19 | 20 | # monalisa ~= a person 21 | - inputs: 22 | prompt: a photo of MNALSA woman with pink hair at a rave 23 | num_outputs: 1 24 | num_inference_steps: 28 25 | model: dev 26 | guidance_scale: 2.5 27 | seed: 8888 28 | output_format: jpg 29 | go_fast: true 30 | replicate_weights: fofr/flux-mona-lisa 31 | lora_scale: 0.9 32 | match_prompt: An image of a woman at a rave with pink hair who looks like the Mona Lisa 33 | 34 | - inputs: 35 | prompt: a photo of MNALSA woman with pink hair at a rave 36 | num_outputs: 1 37 | num_inference_steps: 4 38 | model: schnell 39 | guidance_scale: 2.5 40 | seed: 8888 41 | output_format: jpg 42 | go_fast: true 43 | replicate_weights: fofr/flux-mona-lisa 44 | lora_scale: 0.9 45 | match_prompt: An image of a woman at a rave with pink hair who looks like the Mona Lisa 46 | 47 | # monalisa ~= 4 persons 48 | - inputs: 49 | prompt: a photo of MNALSA woman with pink hair at a rave 50 | num_outputs: 4 51 | num_inference_steps: 4 52 | model: schnell 53 | guidance_scale: 2.5 54 | seed: 8888 55 | output_format: jpg 56 | go_fast: true 57 | replicate_weights: fofr/flux-mona-lisa 58 | lora_scale: 0.9 59 | match_prompt: Four images of a woman at a rave with pink hair who looks like the Mona Lisa 60 | 61 | # monalisa ~= 4 persons 62 | - inputs: 63 | prompt: a photo of MNALSA woman with pink hair at a rave 64 | num_outputs: 4 65 | num_inference_steps: 28 66 | model: dev 67 | guidance_scale: 2.5 68 | seed: 8888 69 | output_format: jpg 70 | go_fast: true 71 | replicate_weights: fofr/flux-mona-lisa 72 | lora_scale: 0.9 73 | match_prompt: Four images of a woman at a rave with pink hair who looks like the Mona Lisa 74 | 75 | # same but slower 76 | - inputs: 77 | prompt: a photo of MNALSA woman with pink hair at a rave 78 | num_outputs: 1 79 | num_inference_steps: 28 80 | guidance_scale: 2.5 81 | seed: 8888 82 | model: dev 83 | output_format: jpg 84 | go_fast: false 85 | replicate_weights: fofr/flux-mona-lisa 86 | lora_scale: 0.9 87 | 88 | # same but slower 89 | - inputs: 90 | prompt: a photo of MNALSA woman with pink hair at a rave 91 | num_outputs: 1 92 | num_inference_steps: 4 93 | guidance_scale: 2.5 94 | seed: 8888 95 | model: schnell 96 | output_format: jpg 97 | go_fast: false 98 | replicate_weights: fofr/flux-mona-lisa 99 | lora_scale: 0.9 100 | 101 | # no lora! 102 | - inputs: 103 | prompt: a photo of MNALSA woman with pink hair at a rave 104 | num_outputs: 1 105 | num_inference_steps: 28 106 | guidance_scale: 2.5 107 | seed: 8888 108 | output_format: jpg 109 | go_fast: true 110 | model: dev 111 | lora_scale: 0.9 112 | 113 | # no lora! 
114 | - inputs: 115 | prompt: a photo of MNALSA woman with pink hair at a rave 116 | num_outputs: 1 117 | num_inference_steps: 4 118 | guidance_scale: 2.5 119 | seed: 8888 120 | output_format: jpg 121 | go_fast: true 122 | model: schnell 123 | lora_scale: 0.9 124 | 125 | # non-replicate weights 126 | - inputs: 127 | prompt: a coca cola can "sacred elixir" arcana in the style of TOK a trtcrd, tarot style 128 | num_outputs: 1 129 | num_inference_steps: 28 130 | guidance_scale: 2.5 131 | seed: 8888 132 | model: dev 133 | output_format: jpg 134 | go_fast: true 135 | replicate_weights: huggingface.co/multimodalart/flux-tarot-v1 136 | lora_scale: 0.9 137 | 138 | # inpainting 139 | - inputs: 140 | prompt: a green cat sitting on a park bench 141 | image: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/inpainting-img.png 142 | mask: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/inpainting-mask.png 143 | aspect_ratio: "1:1" 144 | prompt_strength: 1.0 145 | model: dev 146 | num_outputs: 1 147 | num_inference_steps: 28 148 | go_fast: True 149 | megapixels: "1" 150 | replicate_weights: fofr/flux-80s-cyberpunk 151 | lora_scale: 1.1 152 | 153 | # custom height / width 154 | - inputs: 155 | prompt: a photo of MNALSA woman with pink hair at a rave 156 | height: 768 157 | width: 768 158 | aspect_ratio: custom 159 | num_outputs: 1 160 | num_inference_steps: 28 161 | model: dev 162 | guidance_scale: 2.5 163 | seed: 8888 164 | output_format: jpg 165 | go_fast: true 166 | replicate_weights: fofr/flux-mona-lisa 167 | lora_scale: 0.9 168 | match_prompt: A 768x768 image of a woman at a rave with pink hair who looks like the Mona Lisa 169 | 170 | # multi-lora 171 | - inputs: 172 | prompt: A portrait photo of MNALSA woman sitting at a party table with a selection of bad 70s food 173 | num_outputs: 1 174 | num_inference_steps: 28 175 | model: dev 176 | guidance_scale: 2.5 177 | seed: 8888 178 | output_format: jpg 179 | go_fast: true 180 | replicate_weights: fofr/flux-bad-70s-food 181 | lora_scale: 0.85 182 | extra_lora: fofr/flux-mona-lisa 183 | extra_lora_scale: 0.9 184 | match_prompt: An image of a woman at a dinner table who looks like the Mona Lisa 185 | -------------------------------------------------------------------------------- /safe-push-configs/redux-dev.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-redux-dev-internal-model 2 | test_model: replicate/test-flux-redux-dev 3 | test_hardware: cpu 4 | predict: 5 | compare_outputs: false 6 | predict_timeout: 300 7 | test_cases: 8 | 9 | # basic 10 | - inputs: 11 | redux_image: https://replicate.delivery/yhqm/eGnzS3AVsry6VCfRTZEvI7aIe7Vdp2WzCgzfeWLHhGbz2P0bC/out-0.webp 12 | seed: 2 13 | match_prompt: An astronaut hatching from an egg on the moon -------------------------------------------------------------------------------- /safe-push-configs/redux-schnell.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-redux-schnell-internal-model 2 | test_model: replicate/test-flux-redux-schnell 3 | test_hardware: cpu 4 | predict: 5 | compare_outputs: false 6 | predict_timeout: 300 7 | test_cases: 8 | 9 | # basic 10 | - inputs: 11 | redux_image: https://replicate.delivery/yhqm/eGnzS3AVsry6VCfRTZEvI7aIe7Vdp2WzCgzfeWLHhGbz2P0bC/out-0.webp 12 | seed: 2 13 | match_prompt: An astronaut hatching from an egg on the moon -------------------------------------------------------------------------------- 
/safe-push-configs/schnell-lora.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-schnell-lora-internal-model 2 | test_model: replicate/test-flux-schnell-lora 3 | predict: 4 | compare_outputs: false # TODO(andreas): why doesn't this work? 5 | predict_timeout: 600 6 | test_cases: 7 | # monalisa ~= a person 8 | - inputs: 9 | prompt: a photo of MNALSA woman with pink hair at a rave 10 | num_outputs: 1 11 | guidance_scale: 2.5 12 | seed: 8888 13 | output_format: jpg 14 | go_fast: true 15 | lora_weights: fofr/flux-mona-lisa 16 | lora_scale: 0.9 17 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/schnell_flux-mona-lisa_go_fast.jpg 18 | 19 | # monalisa ~= 4 persons 20 | - inputs: 21 | prompt: a photo of MNALSA woman with pink hair at a rave 22 | num_outputs: 4 23 | guidance_scale: 2.5 24 | seed: 8888 25 | output_format: jpg 26 | go_fast: true 27 | lora_weights: fofr/flux-mona-lisa 28 | lora_scale: 0.9 29 | match_prompt: Four images of a woman at a rave with pink hair who looks like the Mona Lisa 30 | 31 | 32 | # same but slower 33 | - inputs: 34 | prompt: a photo of MNALSA woman with pink hair at a rave 35 | num_outputs: 1 36 | guidance_scale: 2.5 37 | seed: 8888 38 | output_format: jpg 39 | go_fast: false 40 | lora_weights: fofr/flux-mona-lisa 41 | lora_scale: 0.9 42 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/schnell_flux-mona-lisa_go_slow.jpg 43 | 44 | # no lora! 45 | - inputs: 46 | prompt: a photo of MNALSA woman with pink hair at a rave 47 | num_outputs: 1 48 | guidance_scale: 2.5 49 | seed: 8888 50 | output_format: jpg 51 | go_fast: true 52 | lora_scale: 0.9 53 | # match_prompt: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/schnell_no-lora_go_fast.jpg 54 | 55 | # aesthetic lora 56 | - inputs: 57 | prompt: a smart person, sftsrv style 58 | lora_weights: aramintak/flux-softserve-anime 59 | num_outputs: 1 60 | guidance_scale: 2.5 61 | seed: 8888 62 | output_format: jpg 63 | go_fast: true 64 | lora_scale: 0.9 65 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/schnell_flux-softserve-anime_go_fast.jpg 66 | 67 | # non-replicate weights 68 | - inputs: 69 | prompt: a coca cola can "sacred elixir" arcana in the style of TOK a trtcrd, tarot style 70 | num_outputs: 1 71 | guidance_scale: 2.5 72 | seed: 8888 73 | output_format: jpg 74 | go_fast: true 75 | lora_weights: huggingface.co/multimodalart/flux-tarot-v1 76 | lora_scale: 0.9 77 | # match_url: https://storage.googleapis.com/replicate-models-public-test/flux-lora-imgs/schnell_flux-tarot-v1_go_fast.jpg 78 | 79 | # non-replicate weights kohya 80 | - inputs: 81 | prompt: cy04, a book titled "Did I Leave The Oven On?", an illustration of a man sitting at work, looking worried, thought bubble above his head with an oven in it 82 | num_outputs: 1 83 | guidance_scale: 2.5 84 | output_format: jpg 85 | go_fast: true 86 | lora_weights: huggingface.co/Purz/choose-your-own-adventure 87 | lora_scale: 0.9 88 | match_prompt: A drawing of a man thinking about an oven 89 | 90 | # non-replicate weights no mlp trained 91 | - inputs: 92 | prompt: photo of a boy ANIMESTYLE 93 | num_outputs: 1 94 | num_inference_steps: 4 95 | output_format: jpg 96 | go_fast: true 97 | lora_weights: https://storage.googleapis.com/replicate-models-public-test/flux-loras/fixed_lora.safetensors 98 | lora_scale: 0.9 99 | match_prompt: An anime drawing of a boy 100 | 101 | 
fuzz: 102 | fixed_inputs: 103 | lora_weights: huggingface.co/multimodalart/flux-tarot-v1 104 | iterations: 10 105 | -------------------------------------------------------------------------------- /safe-push-configs/schnell.yaml: -------------------------------------------------------------------------------- 1 | model: replicate/flux-schnell-internal-model 2 | test_model: replicate/test-flux-schnell 3 | predict: 4 | compare_outputs: false # TODO(andreas): why doesn't this work? 5 | predict_timeout: 600 6 | test_cases: 7 | 8 | # basic 9 | - inputs: 10 | prompt: A formula one car 11 | num_outputs: 1 12 | num_inference_steps: 4 13 | seed: 5259 14 | output_format: jpg 15 | match_prompt: A 1024x1024px jpg image of a formula one car 16 | 17 | # 4 outputs 18 | - inputs: 19 | prompt: A formula one car 20 | num_outputs: 4 21 | num_inference_steps: 4 22 | seed: 5259 23 | output_format: png 24 | match_prompt: Four png images 25 | 26 | # disable safety checker 27 | - inputs: 28 | prompt: A formula one car 29 | num_outputs: 1 30 | num_inference_steps: 4 31 | seed: 5259 32 | output_format: png 33 | disable_safety_checker: true 34 | match_prompt: A 1024x1024px png image of a formula one car 35 | 36 | # aspect ratio 37 | - inputs: 38 | prompt: A formula one car 39 | num_outputs: 1 40 | num_inference_steps: 1 41 | seed: 5259 42 | output_format: png 43 | aspect_ratio: "3:2" 44 | match_prompt: A 1216x832px png image of a formula one car 45 | 46 | # go slow 47 | - inputs: 48 | prompt: A formula one car 49 | num_outputs: 1 50 | num_inference_steps: 4 51 | seed: 5259 52 | output_format: jpg 53 | go_fast: false 54 | match_prompt: A 1024x1024px jpg image of a formula one car 55 | 56 | # slow 4 outputs 57 | - inputs: 58 | prompt: A formula one car 59 | num_outputs: 4 60 | num_inference_steps: 1 61 | seed: 5259 62 | output_format: jpg 63 | go_fast: false 64 | match_prompt: 4 jpg images 65 | 66 | fuzz: 67 | iterations: 10 68 | -------------------------------------------------------------------------------- /samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | A handy utility for verifying image generation locally. 3 | To set up, first run a local cog server using: 4 | cog run -p 5000 python -m cog.server.http 5 | Then, in a separate terminal, generate samples 6 | python samples.py 7 | """ 8 | 9 | import base64 10 | import sys 11 | import time 12 | from pathlib import Path 13 | import requests 14 | 15 | 16 | def gen(output_fn, **kwargs): 17 | st = time.time() 18 | print("Generating", output_fn) 19 | url = "http://localhost:5000/predictions" 20 | response = requests.post(url, json={"input": kwargs}) 21 | data = response.json() 22 | print("Generated in: ", time.time() - st) 23 | 24 | try: 25 | datauri = data["output"][0] 26 | base64_encoded_data = datauri.split(",")[1] 27 | data = base64.b64decode(base64_encoded_data) 28 | except Exception: 29 | print("Error!") 30 | print("input:", kwargs) 31 | print(data["logs"]) 32 | sys.exit(1) 33 | 34 | Path(output_fn).write_bytes(data) 35 | 36 | 37 | def test_fp8_and_bf16(): 38 | """ 39 | runs generations in fp8 and bf16 on the same node! wow! 
40 | """ 41 | gen( 42 | "float8_dog.png", 43 | prompt="a cool dog", 44 | aspect_ratio="1:1", 45 | num_outputs=1, 46 | output_format="png", 47 | disable_safety_checker=True, 48 | seed=123, 49 | float_8=True, 50 | ) 51 | 52 | gen( 53 | "bf16_dog.png", 54 | prompt="a cool dog", 55 | aspect_ratio="1:1", 56 | num_outputs=1, 57 | output_format="png", 58 | disable_safety_checker=True, 59 | seed=123, 60 | float_8=False, 61 | ) 62 | 63 | gen( 64 | "float8_dog_2.png", 65 | prompt="a cool dog", 66 | aspect_ratio="2:3", 67 | num_outputs=1, 68 | output_format="png", 69 | disable_safety_checker=True, 70 | seed=1231, 71 | float_8=True, 72 | ) 73 | 74 | gen( 75 | "bf16_dog_2.png", 76 | prompt="a cool dog", 77 | aspect_ratio="2:3", 78 | num_outputs=1, 79 | output_format="png", 80 | disable_safety_checker=True, 81 | seed=1231, 82 | float_8=False, 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | test_fp8_and_bf16() 88 | -------------------------------------------------------------------------------- /save_fp8_quantized.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from predict import DevPredictor, SchnellPredictor 4 | from safetensors.torch import save_file 5 | 6 | """ 7 | Code to prequantize and save fp8 weights for Dev or Schnell. Pattern should work for other models. 8 | Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models. 9 | in practice, this just means eliminating the '-fp8' suffix on the model names. 10 | """ 11 | 12 | 13 | def generate_dev_img(p, img_name="cool_dog_1234.png"): 14 | p.predict("a cool dog", "1:1", None, 0, 1, 28, 3, 1234, "png", 100, True, True, "1") 15 | os.system(f"mv out-0.png {img_name}") 16 | 17 | 18 | def save_dev_fp8(): 19 | p = DevPredictor() 20 | p.setup() 21 | 22 | fp8_weights_path = "model-cache/dev-fp8" 23 | if not os.path.exists(fp8_weights_path): # noqa: PTH110 24 | os.makedirs(fp8_weights_path) # noqa: PTH103 25 | 26 | generate_dev_img(p) 27 | print( 28 | "scale initialized: ", 29 | p.fp8_model.fp8_pipe.model.double_blocks[0].img_mod.lin.input_scale_initialized, 30 | ) 31 | sd = p.fp8_model.fp8_pipe.model.state_dict() 32 | to_trim = "_orig_mod." 33 | sd_to_save = {k[len(to_trim) :]: v for k, v in sd.items()} 34 | save_file(sd_to_save, fp8_weights_path + "/" + "dev-fp8.safetensors") 35 | 36 | 37 | def test_dev_fp8(): 38 | p = DevPredictor() 39 | p.setup() 40 | generate_dev_img(p, "cool_dog_1234_loaded_from_compiled.png") 41 | 42 | 43 | def generate_schnell_img(p, img_name="fast_dog_1234.png"): 44 | p.predict("a cool dog", "1:1", 1, 4, 1234, "png", 100, True, True, "1") 45 | os.system(f"mv out-0.png {img_name}") 46 | 47 | 48 | def save_schnell_fp8(): 49 | p = SchnellPredictor() 50 | p.setup() 51 | 52 | fp8_weights_path = "model-cache/schnell-fp8" 53 | if not os.path.exists(fp8_weights_path): # noqa: PTH110 54 | os.makedirs(fp8_weights_path) # noqa: PTH103 55 | 56 | generate_schnell_img(p) 57 | print( 58 | "scale initialized: ", 59 | p.fp8_model.fp8_pipe.model.double_blocks[0].img_mod.lin.input_scale_initialized, 60 | ) 61 | sd = p.fp8_model.fp8_pipe.model.state_dict() 62 | to_trim = "_orig_mod." 
63 | sd_to_save = {k[len(to_trim) :]: v for k, v in sd.items()} 64 | save_file(sd_to_save, fp8_weights_path + "/" + "schnell-fp8.safetensors") 65 | 66 | 67 | def test_schnell_fp8(): 68 | p = SchnellPredictor() 69 | p.setup() 70 | generate_schnell_img(p, "fast_dog_1234_loaded_from_compiled.png") 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser( 75 | description="Prequantize and save fp8 weights for flux dev and/or schnell" 76 | ) 77 | parser.add_argument("flux_model", help="schnell, dev, or all") 78 | args = parser.parse_args() 79 | if args.flux_model == "dev" or args.flux_model == "all": 80 | save_dev_fp8() 81 | if args.flux_model == "schnell" or args.flux_model == "all": 82 | save_schnell_fp8() 83 | if args.flux_model not in ("dev", "schnell", "all"): 84 | print(f"Unknown model '{args.flux_model}', running the schnell fp8 load test instead") 85 | # test_dev_fp8() 86 | test_schnell_fp8() 87 | -------------------------------------------------------------------------------- /script/prod-deploy-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./script/push.sh dev prod 4 | ./script/push.sh schnell prod -------------------------------------------------------------------------------- /script/push.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if both arguments are provided 4 | if [ $# -ne 2 ]; then 5 | echo "Usage: $0 <model_name> <environment>" 6 | echo "Environment should be either 'test' or 'prod'" 7 | exit 1 8 | fi 9 | 10 | MODEL_NAME="$1" 11 | ENVIRONMENT="$2" 12 | 13 | # Validate environment argument 14 | if [ "$ENVIRONMENT" != "test" ] && [ "$ENVIRONMENT" != "prod" ]; then 15 | echo "Invalid environment. Please use 'test' or 'prod'." 16 | exit 1 17 | fi 18 | 19 | ./script/select.sh "$MODEL_NAME" 20 | 21 | if [ $? -ne 0 ]; then 22 | echo "Couldn't select a model, double check you're passing a valid name." 23 | exit 1 24 | fi 25 | 26 | v=$(cog --version) 27 | 28 | # async cog? 29 | if [[ $v == *"0.9."* ]]; then 30 | echo "Sync cog found, pushing model" 31 | else 32 | echo "Nope! switch to sync cog and rebuild" 33 | exit 1 34 | fi 35 | 36 | # Conditional cog push based on environment 37 | if [ "$ENVIRONMENT" == "test" ]; then 38 | echo "Pushing to test environment" 39 | cog push r8.im/replicate-internal/flux-$MODEL_NAME 40 | elif [ "$ENVIRONMENT" == "prod" ]; then 41 | echo "Pushing to production environment" 42 | cog push r8.im/replicate/flux-$MODEL_NAME-internal-model 43 | fi -------------------------------------------------------------------------------- /script/select.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 == --list ]]; then 4 | ls model-cog-configs | sed 's/.yaml//' 5 | exit 0 6 | fi 7 | 8 | if [ -z "$1" ]; then 9 | echo "Usage:" 10 | echo " ./script/select.sh <model>" 11 | echo 12 | echo "To see all models: ./script/select.sh --list" 13 | exit 1 14 | fi 15 | 16 | yq eval-all 'select(fileIndex == 0) *+ select(fileIndex == 1)' cog.yaml.template "model-cog-configs/$1.yaml" > cog.yaml 17 | 18 | cp "safe-push-configs/$1.yaml" cog-safe-push.yaml 19 | -------------------------------------------------------------------------------- /script/update-schema.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME="$1" 4 | 5 | ./script/select.sh "$MODEL_NAME" 6 | 7 | if [ $? -ne 0 ]; then 8 | echo "Couldn't select a model, double check you're passing a valid name." 9 | exit 1 10 | fi 11 | 12 | v=$(cog --version) 13 | 14 | # async cog?
15 | if [[ $v == *"0.9."* ]]; then 16 | echo "Sync cog found, pushing model" 17 | else 18 | echo "Nope! switch to sync cog and rebuild" 19 | exit -1 20 | fi 21 | 22 | # Conditional cog push based on environment 23 | echo "Pushing image to prod to update schema" 24 | date +%s > the_time.txt 25 | cog push r8.im/black-forest-labs/flux-$MODEL_NAME -------------------------------------------------------------------------------- /weights.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import os 4 | import re 5 | import shutil 6 | import subprocess 7 | import tarfile 8 | import tempfile 9 | import time 10 | from collections import deque 11 | from io import BytesIO 12 | from pathlib import Path 13 | from contextlib import contextmanager 14 | 15 | from cog import Secret 16 | from huggingface_hub import HfApi, hf_hub_download, login, logout 17 | 18 | DEFAULT_CACHE_BASE_DIR = Path("/src/weights-cache") 19 | 20 | from dotenv import load_dotenv 21 | 22 | load_dotenv() 23 | 24 | 25 | class WeightsDownloadCache: 26 | def __init__( 27 | self, min_disk_free: int = 10 * (2**30), base_dir: Path = DEFAULT_CACHE_BASE_DIR 28 | ): 29 | self.min_disk_free = min_disk_free 30 | self.base_dir = base_dir 31 | self.hits = 0 32 | self.misses = 0 33 | 34 | # Least Recently Used (LRU) cache for paths 35 | self.lru_paths = deque() 36 | base_dir.mkdir(parents=True, exist_ok=True) 37 | 38 | def ensure( 39 | self, 40 | url: str, 41 | hf_api_token: Secret | None = None, 42 | civitai_api_token: Secret | None = None, 43 | ) -> Path: 44 | path = self._weights_path(url) 45 | 46 | if path in self.lru_paths: 47 | # here we remove to re-add to the end of the LRU (marking it as recently used) 48 | self.hits += 1 49 | self.lru_paths.remove(path) 50 | elif not Path.exists( 51 | path 52 | ): # local dev; sometimes we'll have a lora already downloaded 53 | self.misses += 1 54 | 55 | while not self._has_enough_space() and len(self.lru_paths) > 0: 56 | self._remove_least_recent() 57 | 58 | download_weights( 59 | url, 60 | path, 61 | hf_api_token=hf_api_token, 62 | civitai_api_token=civitai_api_token, 63 | ) 64 | 65 | self.lru_paths.append(path) # Add file to end of cache 66 | return path 67 | 68 | def cache_info(self) -> str: 69 | return f"CacheInfo(hits={self.hits}, misses={self.misses}, base_dir='{self.base_dir}', currsize={len(self.lru_paths)})" 70 | 71 | def _remove_least_recent(self) -> None: 72 | oldest = self.lru_paths.popleft() 73 | print("removing oldest", oldest) 74 | oldest.unlink() 75 | 76 | def _has_enough_space(self) -> bool: 77 | disk_usage = shutil.disk_usage(self.base_dir) 78 | 79 | free = disk_usage.free 80 | print(f"{free=}") # TODO(andreas): remove debug 81 | 82 | return free >= self.min_disk_free 83 | 84 | def _weights_path(self, url: str) -> Path: 85 | hashed_url = hashlib.sha256(url.encode()).hexdigest() 86 | short_hash = hashed_url[:16] # Use the first 16 characters of the hash 87 | return self.base_dir / short_hash 88 | 89 | 90 | def download_weights( 91 | url: str, 92 | path: Path, 93 | hf_api_token: str | None = None, 94 | civitai_api_token: str | None = None, 95 | ): 96 | download_url = make_download_url(url, civitai_api_token=civitai_api_token) 97 | download_weights_url(download_url, path, hf_api_token=hf_api_token) 98 | 99 | 100 | @contextmanager 101 | def logged_in_to_huggingface( 102 | token: Secret | None = None, add_to_git_credential: bool = False 103 | ): 104 | """Context manager for temporary Hugging Face login.""" 105 | try: 
106 | if token is not None: 107 | print("Attempting to log in to HuggingFace using the provided token...") 108 | # Log in at the start of the context 109 | login( 110 | token=token.get_secret_value(), 111 | add_to_git_credential=add_to_git_credential, 112 | ) 113 | print("Login to HuggingFace successful!") 114 | yield 115 | finally: 116 | # Always log out at the end, even if an exception occurs 117 | logout() 118 | print("Logged out of HuggingFace.") 119 | 120 | 121 | def download_weights_url(url: str, path: Path, hf_api_token: str | None = None): 122 | path = Path(path) 123 | 124 | print("Downloading weights") 125 | start_time = time.time() 126 | 127 | if m := re.match( 128 | r"^(?:https?://)?huggingface\.co/([^/]+)/([^/]+)(?:/([^/]+\.safetensors))?/?$", 129 | url, 130 | ): 131 | if m.group(3) is None: 132 | owner, model_name = m.group(1), m.group(2) 133 | lora_weights = None 134 | else: 135 | owner, model_name, lora_weights = m.groups() 136 | 137 | # Use HuggingFace Hub download directly 138 | try: 139 | with logged_in_to_huggingface(hf_api_token): 140 | if lora_weights is None: 141 | repo_id = f"{owner}/{model_name}" 142 | files = HfApi().list_repo_files(repo_id) 143 | sft_files = [file for file in files if ".safetensors" in file] 144 | if len(sft_files) == 1: 145 | lora_weights = sft_files[0] 146 | else: 147 | raise ValueError( 148 | f"No *.safetensors file was explicitly specified for the HuggingFace repo {repo_id} and exactly one could not be inferred. Found {len(sft_files)}: {sft_files}" 149 | ) 150 | 151 | safetensors_path = hf_hub_download( 152 | repo_id=f"{owner}/{model_name}", 153 | filename=lora_weights, 154 | ) 155 | # Copy the downloaded file to the desired path 156 | shutil.copy(Path(safetensors_path), path) 157 | print(f"Downloaded {lora_weights} from HuggingFace to {path}") 158 | except Exception as e: 159 | raise ValueError(f"Failed to download from HuggingFace: {e}") 160 | elif url.startswith("data:"): 161 | download_data_url(url, path) 162 | elif url.endswith(".tar"): 163 | download_safetensors_tarball(url, path) 164 | elif ( 165 | url.endswith(".safetensors") 166 | or "://civitai.com/api/download" in url 167 | or ".safetensors?"
in url 168 | ): 169 | download_safetensors(url, path) 170 | elif url.endswith("/_weights"): 171 | download_safetensors_tarball(url, path) 172 | else: 173 | raise ValueError("URL must end with either .tar or .safetensors") 174 | 175 | print(f"Downloaded weights in {time.time() - start_time:.2f}s") 176 | 177 | 178 | def find_safetensors(directory: Path) -> list[Path]: 179 | safetensors_paths = [] 180 | for root, _, files in os.walk(directory): 181 | root = Path(root) 182 | for filename in files: 183 | path = root / filename 184 | if path.suffix == ".safetensors": 185 | safetensors_paths.append(path) 186 | return safetensors_paths 187 | 188 | 189 | def download_safetensors_tarball(url: str, path: Path): 190 | with tempfile.TemporaryDirectory() as temp_dir: 191 | temp_dir = Path(temp_dir) 192 | extract_dir = temp_dir / "weights" 193 | 194 | try: 195 | subprocess.run( 196 | ["pget", "--log-level=WARNING", "-x", url, extract_dir], check=True 197 | ) 198 | except subprocess.CalledProcessError as e: 199 | raise RuntimeError(f"Failed to download tarball: {e}") 200 | 201 | safetensors_paths = find_safetensors(extract_dir) 202 | if not safetensors_paths: 203 | raise ValueError("No .safetensors file found in tarball") 204 | if len(safetensors_paths) > 1: 205 | raise ValueError("Multiple .safetensors files found in tarball") 206 | safetensors_path = safetensors_paths[0] 207 | 208 | shutil.move(safetensors_path, path) 209 | 210 | 211 | def download_data_url(url: str, path: Path): 212 | _, encoded = url.split(",", 1) 213 | data = base64.b64decode(encoded) 214 | 215 | with tempfile.TemporaryDirectory() as temp_dir: 216 | with tarfile.open(fileobj=BytesIO(data), mode="r:*") as tar: 217 | tar.extractall(path=temp_dir) 218 | 219 | safetensors_paths = find_safetensors(Path(temp_dir)) 220 | if not safetensors_paths: 221 | raise ValueError("No .safetensors file found in data URI") 222 | if len(safetensors_paths) > 1: 223 | raise ValueError("Multiple .safetensors files found in data URI") 224 | safetensors_path = safetensors_paths[0] 225 | 226 | shutil.move(safetensors_path, path) 227 | 228 | 229 | def download_safetensors(url: str, path: Path): 230 | try: 231 | # don't want to leak civitai api key 232 | output_redirect = subprocess.PIPE 233 | if "token=" in url: 234 | # print url without token 235 | print(f"downloading weights from {url.split('token=')[0]}token=***") 236 | else: 237 | print(f"downloading weights from {url}") 238 | 239 | result = subprocess.run( 240 | ["pget", url, str(path)], 241 | check=False, 242 | stdout=output_redirect, 243 | stderr=output_redirect, 244 | text=True, 245 | ) 246 | 247 | if result.returncode != 0: 248 | error_output = result.stderr or "" 249 | if "401" in error_output: 250 | raise RuntimeError( 251 | "Authorization to download weights failed. Please check to see if an API key is needed and if so pass in with the URL." 252 | ) 253 | if "404" in error_output: 254 | if "civitai" in url: 255 | raise RuntimeError( 256 | "Model not found on CivitAI at that URL. Double check the CivitAI model ID; the id on the download link can be different than the id to browse to the model page." 257 | ) 258 | raise RuntimeError( 259 | "Weights not found at the supplied URL. Please check the URL." 
260 | ) 261 | raise RuntimeError(f"Failed to download safetensors file: {error_output}") 262 | 263 | except subprocess.CalledProcessError as e: 264 | raise RuntimeError(f"Failed to download safetensors file: {e}") 265 | 266 | 267 | def make_download_url(url: str, civitai_api_token: Secret | None = None) -> str: 268 | if url.startswith("data:"): 269 | return url 270 | if m := re.match( 271 | r"^(?:https?://)?huggingface\.co/([^/]+)/([^/]+)(?:/([^/]+\.safetensors))?/?$", 272 | url, 273 | ): 274 | if len(m.groups()) not in [2, 3]: 275 | raise ValueError( 276 | "Invalid HuggingFace URL. Expected format: huggingface.co/<owner>/<model-name>[/<weights-file>.safetensors]" 277 | ) 278 | return url 279 | if m := re.match(r"^(?:https?://)?civitai\.com/models/(\d+)(?:/[^/?]+)?/?$", url): 280 | model_id = m.groups()[0] 281 | return make_civitai_download_url(model_id, civitai_api_token) 282 | if m := re.match(r"^((?:https?://)?civitai\.com/api/download/models/.*)$", url): 283 | return url 284 | if m := re.match(r"^(https?://.*\.safetensors(\?.*)?)$", url): 285 | return url # URL with query parameters, keep the whole url 286 | if m := re.match(r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/?$", url): 287 | owner, model_name = m.groups() 288 | return make_replicate_model_download_url(owner, model_name) 289 | if m := re.match( 290 | r"^(?:https?://replicate.com/)?([^/]+)/([^/]+)/(?:versions/)?([^/]+)/?$", url 291 | ): 292 | owner, model_name, version_id = m.groups() 293 | return make_replicate_version_download_url(owner, model_name, version_id) 294 | if m := re.match(r"^(https?://replicate.delivery/.*\.tar)$", url): 295 | return m.groups()[0] 296 | 297 | if "huggingface.co" in url: 298 | raise ValueError( 299 | "Failed to parse HuggingFace URL. Expected huggingface.co/<owner>/<model-name>[/<weights-file>.safetensors]" 300 | ) 301 | if "civitai.com" in url: 302 | raise ValueError( 303 | "Failed to parse CivitAI URL. Expected civitai.com/models/<model-id>[/<model-name>]" 304 | ) 305 | raise ValueError( 306 | """Failed to parse URL. Expected either: 307 | * Replicate model in the format <owner>/<model-name> or <owner>/<model-name>/<version-id> 308 | * HuggingFace URL in the format huggingface.co/<owner>/<model-name>[/<weights-file>.safetensors] 309 | * CivitAI URL in the format civitai.com/models/<model-id>[/<model-name>] 310 | * Arbitrary .safetensors URLs from the Internet""" 311 | ) 312 | 313 | 314 | def make_replicate_model_download_url(owner: str, model_name: str) -> str: 315 | return f"https://replicate.com/{owner}/{model_name}/_weights" 316 | 317 | 318 | def make_replicate_version_download_url( 319 | owner: str, model_name: str, version_id: str 320 | ) -> str: 321 | return f"https://replicate.com/{owner}/{model_name}/versions/{version_id}/_weights" 322 | 323 | 324 | def make_civitai_download_url( 325 | model_id: str, civitai_api_token: str | None = None 326 | ) -> str: 327 | if civitai_api_token is None: 328 | return f"https://civitai.com/api/download/models/{model_id}?type=Model&format=SafeTensor" 329 | return f"https://civitai.com/api/download/models/{model_id}?type=Model&format=SafeTensor&token={civitai_api_token.get_secret_value()}" 330 | --------------------------------------------------------------------------------
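Note: a minimal sketch, not part of the repo, of how the weights.py helpers compose: make_download_url normalizes a user-facing reference into a concrete download URL, and WeightsDownloadCache.ensure downloads it into the LRU cache under /src/weights-cache. It assumes the repo's Python dependencies are installed, that it runs inside the cog container, and that network access is available for the ensure() call; the CivitAI model id below is a placeholder.
# Sketch (not part of the repo): exercise URL normalization and the weights cache.
from weights import WeightsDownloadCache, make_download_url

for ref in [
    "fofr/flux-mona-lisa",                         # Replicate model -> https://replicate.com/.../_weights
    "huggingface.co/multimodalart/flux-tarot-v1",  # HuggingFace repo -> returned unchanged
    "civitai.com/models/123456",                   # placeholder id -> CivitAI download API URL
]:
    print(ref, "->", make_download_url(ref))

# ensure() hashes the URL to a file under /src/weights-cache, downloads it if it
# is not already cached, and evicts least-recently-used files when free disk
# space drops below min_disk_free.
cache = WeightsDownloadCache()
lora_path = cache.ensure("huggingface.co/multimodalart/flux-tarot-v1")
print(cache.cache_info(), lora_path)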