├── teaser.jpg ├── requirements.txt ├── cache_latent_codes.sh ├── cache_prompt_embeds.sh ├── train_2k.sh ├── train_4k.sh ├── inference_2k.ipynb ├── inference_2k_schnell.ipynb ├── inference_4k_lora_conversion_schnell.ipynb ├── inference_4k_lora_conversion.ipynb ├── .gitignore ├── cache_latent_codes.py ├── attention_processor.py ├── cache_prompt_embeds.py ├── README.md ├── LICENSE ├── inference_4k_schnell.ipynb ├── inference_4k.ipynb ├── transformer_flux.py ├── train_2k.py └── pipeline_flux.py /teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Huage001/URAE/HEAD/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0 2 | transformers 3 | diffusers==0.31.0 4 | sentencepiece 5 | accelerate 6 | opencv-python 7 | peft 8 | wandb 9 | jupyter 10 | patch_conv 11 | prodigyopt -------------------------------------------------------------------------------- /cache_latent_codes.sh: -------------------------------------------------------------------------------- 1 | export NUM_WORKERS=1 2 | export DATA_DIR="path/to/data" 3 | export MODEL_NAME="black-forest-labs/FLUX.1-dev" 4 | torchrun --nproc_per_node=$NUM_WORKERS cache_latent_codes.py \ 5 | --data_root=$DATA_DIR \ 6 | --num_worker=$NUM_WORKERS \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --mixed_precision='bf16' \ 9 | --output_dir=$DATA_DIR \ 10 | --resolution=4096 -------------------------------------------------------------------------------- /cache_prompt_embeds.sh: -------------------------------------------------------------------------------- 1 | export NUM_WORKERS=1 2 | export DATA_DIR="/path/to/data" 3 | export MODEL_NAME="black-forest-labs/FLUX.1-dev" 4 | torchrun --nproc_per_node=$NUM_WORKERS cache_prompt_embeds.py \ 5 | --data_root=$DATA_DIR \ 6 | --batch_size=256 \ 7 | --num_worker=$NUM_WORKERS \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --mixed_precision='bf16' \ 10 | --output_dir=$DATA_DIR \ 11 | --column="prompt" \ 12 | --max_sequence_length=512 -------------------------------------------------------------------------------- /train_2k.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="black-forest-labs/FLUX.1-dev" 2 | export DATA_DIR="/path/to/data" 3 | export OUTPUT_DIR="/path/to/ckpt" 4 | export PRECISION="bf16" 5 | 6 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision $PRECISION train_2k.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --dataset_root=$DATA_DIR \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --mixed_precision=$PRECISION \ 11 | --dataloader_num_workers=4 \ 12 | --train_batch_size=1 \ 13 | --gradient_accumulation_steps=1 \ 14 | --optimizer="prodigy" \ 15 | --learning_rate=1. 
\ 16 | --report_to="wandb" \ 17 | --lr_scheduler="constant" \ 18 | --lr_warmup_steps=0 \ 19 | --max_train_steps=2000 \ 20 | --seed="0" \ 21 | --real_prompt_ratio=0.2 \ 22 | --checkpointing_steps=1000 \ 23 | --gradient_checkpointing 24 | -------------------------------------------------------------------------------- /train_4k.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="black-forest-labs/FLUX.1-dev" 2 | export DATA_DIR="/path/to/data" 3 | export OUTPUT_DIR="/path/to/ckpt" 4 | export PRECISION="bf16" 5 | 6 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision $PRECISION train_4k.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --dataset_root=$DATA_DIR \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --mixed_precision=$PRECISION \ 11 | --dataloader_num_workers=4 \ 12 | --train_batch_size=1 \ 13 | --gradient_accumulation_steps=1 \ 14 | --optimizer="prodigy" \ 15 | --learning_rate=1. \ 16 | --report_to="wandb" \ 17 | --lr_scheduler="constant" \ 18 | --lr_warmup_steps=0 \ 19 | --max_train_steps=10000 \ 20 | --seed="0" \ 21 | --real_prompt_ratio=0.2 \ 22 | --checkpointing_steps=1000 \ 23 | --gradient_checkpointing \ 24 | --ntk_factor=10 \ 25 | --proportional_attention \ 26 | --pretrained_lora="ckpt/urae_2k_adapter.safetensors" 27 | -------------------------------------------------------------------------------- /inference_2k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n", 23 | "device = torch.device('cuda')\n", 24 | "dtype = torch.bfloat16\n", 25 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n", 26 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n", 27 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 28 | "pipe.scheduler.config.time_shift = 10\n", 29 | "pipe = pipe.to(device)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n", 39 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 40 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 50 | "height = 2048\n", 51 | "width = 2048\n", 52 | "image = pipe(\n", 53 | " prompt,\n", 54 | " height=height,\n", 55 | " width=width,\n", 56 | " guidance_scale=3.5,\n", 57 | " num_inference_steps=28,\n", 58 | " max_sequence_length=512,\n", 59 | " generator=torch.manual_seed(8888),\n", 60 | " ntk_factor=10,\n", 61 | " proportional_attention=True\n", 62 | ").images[0]\n", 63 | "image" 64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "Python 3 (ipykernel)", 70 | "language": "python", 71 | "name": "python3" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 2 76 | } 77 | -------------------------------------------------------------------------------- /inference_2k_schnell.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n", 23 | "device = torch.device('cuda')\n", 24 | "dtype = torch.bfloat16\n", 25 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n", 26 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)\n", 27 | "pipe.transformer = transformer\n", 28 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 29 | "pipe.scheduler.config.time_shift = 10\n", 30 | "pipe = pipe.to(device)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n", 40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 41 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 51 | "height = 2048\n", 52 | "width = 2048\n", 53 | "image = pipe(\n", 54 | " prompt,\n", 55 | " height=height,\n", 56 | " width=width,\n", 57 | " guidance_scale=0,\n", 58 | " num_inference_steps=4,\n", 59 | " max_sequence_length=256,\n", 60 | " generator=torch.manual_seed(8888),\n", 61 | " ntk_factor=10,\n", 62 | " proportional_attention=True\n", 63 | ").images[0]\n", 64 | "image" 65 | ] 66 | } 67 | ], 68 | "metadata": { 69 | "kernelspec": { 70 | "display_name": "Python 3 (ipykernel)", 71 | "language": "python", 72 | "name": "python3" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 2 77 | } 78 | -------------------------------------------------------------------------------- /inference_4k_lora_conversion_schnell.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel\n", 14 | "from patch_conv import convert_model" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n", 24 | "device = torch.device('cuda')\n", 25 | "dtype = torch.bfloat16\n", 26 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n", 27 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n", 28 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 29 | "pipe.scheduler.config.time_shift = 10\n", 30 | "pipe.enable_model_cpu_offload()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors'):\n", 40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter_lora_conversion_schnell.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 41 | "pipe.load_lora_weights(\"ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n", 49 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "pipe.vae = convert_model(pipe.vae, splits=4)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 68 | "height = 4096\n", 69 | "width = 4096\n", 70 | "image = pipe(\n", 71 | " prompt,\n", 72 | " height=height,\n", 73 | " width=width,\n", 74 | " guidance_scale=0,\n", 75 | " num_inference_steps=4,\n", 76 | " max_sequence_length=256,\n", 77 | " generator=torch.manual_seed(8888),\n", 78 | " ntk_factor=10,\n", 79 | " proportional_attention=True\n", 80 | ").images[0]\n", 81 | "image" 82 | ] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "Python 3 (ipykernel)", 88 | "language": "python", 89 | "name": "python3" 90 | } 91 | }, 92 | "nbformat": 4, 93 | "nbformat_minor": 2 94 | } 95 | -------------------------------------------------------------------------------- /inference_4k_lora_conversion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel\n", 14 | "from patch_conv import convert_model" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n", 24 | "device = torch.device('cuda')\n", 25 | "dtype = torch.bfloat16\n", 26 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n", 27 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n", 28 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 29 | "pipe.scheduler.config.time_shift = 10\n", 30 | "pipe.enable_model_cpu_offload()" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_dev.safetensors'):\n", 40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter_lora_conversion_dev.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 41 | "pipe.load_lora_weights(\"ckpt/urae_4k_adapter_lora_conversion_dev.safetensors\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n", 49 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "pipe.vae = convert_model(pipe.vae, splits=4)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 68 | "height = 4096\n", 69 | "width = 4096\n", 70 | "image = pipe(\n", 71 | " prompt,\n", 72 | " height=height,\n", 73 | " width=width,\n", 74 | " guidance_scale=3.5,\n", 75 | " num_inference_steps=28,\n", 76 | " max_sequence_length=512,\n", 77 | " generator=torch.manual_seed(0),\n", 78 | " ntk_factor=10,\n", 79 | " proportional_attention=True\n", 80 | ").images[0]\n", 81 | "image" 82 | ] 83 | } 84 | ], 85 | "metadata": { 86 | "kernelspec": { 87 | "display_name": "Python 3 (ipykernel)", 88 | "language": "python", 89 | "name": "python3" 90 | }, 91 | "language_info": { 92 | "codemirror_mode": { 93 | "name": "ipython", 94 | "version": 3 95 | }, 96 | "file_extension": ".py", 97 | "mimetype": "text/x-python", 98 | "name": "python", 99 | "nbconvert_exporter": "python", 100 | "pygments_lexer": "ipython3", 101 | "version": "3.12.9" 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 2 106 | } 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /cache_latent_codes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import math 6 | from safetensors.torch import save_file 7 | import torch 8 | import torch.nn as nn 9 | import tqdm 10 | import cv2 11 | from diffusers import AutoencoderKL 12 | from patch_conv import PatchConv2d 13 | 14 | 15 | def parse_args(input_args=None): 16 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 17 | parser.add_argument( 18 | "--pretrained_model_name_or_path", 19 | type=str, 20 | default=None, 21 | required=True, 22 | help="Path to pretrained model or model identifier from huggingface.co/models.", 23 | ) 24 | parser.add_argument( 25 | "--revision", 26 | type=str, 27 | default=None, 28 | required=False, 29 | help="Revision of pretrained model identifier from huggingface.co/models.", 30 | ) 31 | parser.add_argument( 32 | "--variant", 33 | type=str, 34 | default=None, 35 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 36 | ) 37 | parser.add_argument( 38 | "--data_root", 39 | type=str, 40 | default=None 41 | ) 42 | parser.add_argument( 43 | "--cache_dir", 44 | type=str, 45 | default=None, 46 | help="The directory where the downloaded models and datasets will be stored.", 47 | ) 48 | parser.add_argument( 49 | "--output_dir", 50 | type=str, 51 | default="flux-lora", 52 | help="The output directory where the model predictions and checkpoints will be written.", 53 | ) 54 | parser.add_argument( 55 | "--mixed_precision", 56 | type=str, 57 | default=None, 58 | choices=["no", "fp16", "bf16"], 59 | help=( 60 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 61 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 62 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 63 | ), 64 | ) 65 | parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank") 66 | parser.add_argument( 67 | "--resolution", 68 | type=int, 69 | default=None, 70 | help="Image resolution for resizing. If None, the original resolution will be used.", 71 | ) 72 | parser.add_argument( 73 | "--num_workers", 74 | type=int, 75 | default=1, 76 | help="Number of workers", 77 | ) 78 | 79 | if input_args is not None: 80 | args = parser.parse_args(input_args) 81 | else: 82 | args = parser.parse_args() 83 | 84 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0)) 85 | 86 | return args 87 | 88 | 89 | def resize(img, base_size=4096): 90 | all_sizes = np.array([ 91 | [base_size, base_size], 92 | [base_size, base_size * 3 // 4], 93 | [base_size, base_size // 2], 94 | [base_size, base_size // 4] 95 | ]) 96 | width, height = img.size 97 | if width < height: 98 | size = all_sizes[np.argmin(np.abs(all_sizes[:, 0] / all_sizes[:, 1] - height / width))][::-1] 99 | else: 100 | size = all_sizes[np.argmin(np.abs(all_sizes[:, 0] / all_sizes[:, 1] - width / height))] 101 | return img.resize((size[0], size[1])) 102 | 103 | 104 | def convert_model(model: nn.Module, splits: int = 4, sequential: bool=True) -> nn.Module: 105 | """ 106 | Convert the convolutions in the model to PatchConv2d. 
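Wrapping each spatial Conv2d (kernel size > 1 in both dimensions) in PatchConv2d from the patch_conv package makes the convolution run on smaller slices of the input rather than the full feature map, which lowers peak GPU memory when the VAE encodes 4K images; 1x1 convolutions and modules under 'downsamplers' are left unchanged.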
107 | """ 108 | if isinstance(model, PatchConv2d): 109 | return model 110 | elif isinstance(model, nn.Conv2d) and model.kernel_size[0] > 1 and model.kernel_size[1] > 1: 111 | return PatchConv2d(splits=splits, conv2d=model) 112 | else: 113 | for name, module in model.named_modules(): 114 | if isinstance(module, (nn.Conv2d, PatchConv2d)): 115 | continue 116 | if 'downsamplers' in name: 117 | continue 118 | for subname, submodule in module.named_children(): 119 | if isinstance(submodule, nn.Conv2d) and submodule.kernel_size[0] > 1 and submodule.kernel_size[1] > 1: 120 | setattr(module, subname, PatchConv2d(splits=splits, sequential=sequential, conv2d=submodule)) 121 | return model 122 | 123 | 124 | def main(args): 125 | if torch.cuda.is_available(): 126 | torch.cuda.set_device(args.local_rank) 127 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 128 | if args.mixed_precision == 'fp16': 129 | dtype = torch.float16 130 | elif args.mixed_precision == 'bf16': 131 | dtype = torch.bfloat16 132 | else: 133 | dtype = torch.float32 134 | 135 | vae = AutoencoderKL.from_pretrained( 136 | args.pretrained_model_name_or_path, 137 | subfolder="vae", 138 | revision=args.revision, 139 | variant=args.variant, 140 | ).to(device, dtype) 141 | 142 | if args.resolution is not None and args.resolution > 3072: 143 | vae = convert_model(vae, splits=4) 144 | 145 | all_info = sorted([item for item in os.listdir(args.data_root) if item.endswith('.jpg')]) 146 | 147 | os.makedirs(args.output_dir, exist_ok=True) 148 | 149 | work_load = math.ceil(len(all_info) / args.num_workers) 150 | for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)))): 151 | output_path = os.path.join(args.output_dir, f"{all_info[idx][:all_info[idx].rfind('.')]}_latent_code.safetensors") 152 | img = cv2.cvtColor(cv2.imread(os.path.join(args.data_root, all_info[idx])), cv2.COLOR_BGR2RGB) 153 | img = Image.fromarray(img) 154 | if args.resolution is not None: 155 | img = resize(img, args.resolution) 156 | img = torch.from_numpy((np.array(img) / 127.5) - 1) 157 | img = img.permute(2, 0, 1) 158 | with torch.no_grad(): 159 | img = img.unsqueeze(0) 160 | data = vae.encode(img.to(device, vae.dtype)).latent_dist 161 | mean = data.mean[0].cpu().data 162 | std = data.std[0].cpu().data 163 | save_file( 164 | {'mean': mean, 'std': std}, 165 | output_path 166 | ) 167 | 168 | 169 | if __name__ == '__main__': 170 | main(parse_args()) -------------------------------------------------------------------------------- /attention_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from diffusers.models.attention_processor import Attention 6 | from typing import Optional 7 | from diffusers.models.embeddings import apply_rotary_emb 8 | 9 | 10 | class FluxAttnProcessor2_0: 11 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 12 | 13 | def __init__(self, train_seq_len=512 + 64 * 64): 14 | if not hasattr(F, "scaled_dot_product_attention"): 15 | raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 16 | self.train_seq_len = train_seq_len 17 | 18 | def __call__( 19 | self, 20 | attn: Attention, 21 | hidden_states: torch.FloatTensor, 22 | encoder_hidden_states: torch.FloatTensor = None, 23 | attention_mask: Optional[torch.FloatTensor] = None, 24 | 
image_rotary_emb: Optional[torch.Tensor] = None, 25 | proportional_attention=False 26 | ) -> torch.FloatTensor: 27 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 28 | 29 | # `sample` projections. 30 | query = attn.to_q(hidden_states) 31 | key = attn.to_k(hidden_states) 32 | value = attn.to_v(hidden_states) 33 | 34 | inner_dim = key.shape[-1] 35 | head_dim = inner_dim // attn.heads 36 | 37 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 38 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 39 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 40 | 41 | if attn.norm_q is not None: 42 | query = attn.norm_q(query) 43 | if attn.norm_k is not None: 44 | key = attn.norm_k(key) 45 | 46 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 47 | if encoder_hidden_states is not None: 48 | # `context` projections. 49 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 50 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 51 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 52 | 53 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 54 | batch_size, -1, attn.heads, head_dim 55 | ).transpose(1, 2) 56 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 57 | batch_size, -1, attn.heads, head_dim 58 | ).transpose(1, 2) 59 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 60 | batch_size, -1, attn.heads, head_dim 61 | ).transpose(1, 2) 62 | 63 | if attn.norm_added_q is not None: 64 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 65 | if attn.norm_added_k is not None: 66 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 67 | 68 | # attention 69 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 70 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 71 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 72 | 73 | if image_rotary_emb is not None: 74 | query = apply_rotary_emb(query, image_rotary_emb) 75 | key = apply_rotary_emb(key, image_rotary_emb) 76 | 77 | if proportional_attention: 78 | attention_scale = math.sqrt(math.log(key.size(2), self.train_seq_len) / head_dim) 79 | else: 80 | attention_scale = math.sqrt(1 / head_dim) 81 | 82 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale) 83 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 84 | hidden_states = hidden_states.to(query.dtype) 85 | 86 | if encoder_hidden_states is not None: 87 | encoder_hidden_states, hidden_states = ( 88 | hidden_states[:, : encoder_hidden_states.shape[1]], 89 | hidden_states[:, encoder_hidden_states.shape[1] :], 90 | ) 91 | 92 | # linear proj 93 | hidden_states = attn.to_out[0](hidden_states) 94 | # dropout 95 | hidden_states = attn.to_out[1](hidden_states) 96 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 97 | 98 | return hidden_states, encoder_hidden_states 99 | else: 100 | return hidden_states 101 | 102 | 103 | class FluxAttnAdaptationProcessor2_0(nn.Module): 104 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 105 | 106 | def __init__(self, rank=16, dim=3072, to_out=False, train_seq_len=512 + 64 * 64): 107 | 
super().__init__() 108 | if not hasattr(F, "scaled_dot_product_attention"): 109 | raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 110 | self.to_q_a = nn.Linear(dim, rank, bias=False) 111 | self.to_q_b = nn.Linear(rank, dim, bias=False) 112 | self.to_q_b.weight.data = torch.zeros_like(self.to_q_b.weight.data) 113 | self.to_k_a = nn.Linear(dim, rank, bias=False) 114 | self.to_k_b = nn.Linear(rank, dim, bias=False) 115 | self.to_k_b.weight.data = torch.zeros_like(self.to_k_b.weight.data) 116 | self.to_v_a = nn.Linear(dim, rank, bias=False) 117 | self.to_v_b = nn.Linear(rank, dim, bias=False) 118 | self.to_v_b.weight.data = torch.zeros_like(self.to_v_b.weight.data) 119 | if to_out: 120 | self.to_out_a = nn.Linear(dim, rank, bias=False) 121 | self.to_out_b = nn.Linear(rank, dim, bias=False) 122 | self.to_out_b.weight.data = torch.zeros_like(self.to_out_b.weight.data) 123 | self.train_seq_len = train_seq_len 124 | 125 | def __call__( 126 | self, 127 | attn: Attention, 128 | hidden_states: torch.FloatTensor, 129 | encoder_hidden_states: torch.FloatTensor = None, 130 | attention_mask: Optional[torch.FloatTensor] = None, 131 | image_rotary_emb: Optional[torch.Tensor] = None, 132 | proportional_attention=False 133 | ) -> torch.FloatTensor: 134 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 135 | 136 | # `sample` projections. 137 | query = attn.to_q(hidden_states) + self.to_q_b(self.to_q_a(hidden_states)) 138 | key = attn.to_k(hidden_states) + self.to_k_b(self.to_k_a(hidden_states)) 139 | value = attn.to_v(hidden_states) + self.to_v_b(self.to_v_a(hidden_states)) 140 | 141 | inner_dim = key.shape[-1] 142 | head_dim = inner_dim // attn.heads 143 | 144 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 145 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 146 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 147 | 148 | if attn.norm_q is not None: 149 | query = attn.norm_q(query) 150 | if attn.norm_k is not None: 151 | key = attn.norm_k(key) 152 | 153 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 154 | if encoder_hidden_states is not None: 155 | # `context` projections. 
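# Note: only the image-stream projections (to_q/to_k/to_v, plus to_out when enabled) carry the low-rank adapter terms above; the context/text projections below are computed with the unmodified base-model weights.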
156 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 157 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 158 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 159 | 160 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 161 | batch_size, -1, attn.heads, head_dim 162 | ).transpose(1, 2) 163 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 164 | batch_size, -1, attn.heads, head_dim 165 | ).transpose(1, 2) 166 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 167 | batch_size, -1, attn.heads, head_dim 168 | ).transpose(1, 2) 169 | 170 | if attn.norm_added_q is not None: 171 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 172 | if attn.norm_added_k is not None: 173 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 174 | 175 | # attention 176 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 177 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 178 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 179 | 180 | if image_rotary_emb is not None: 181 | query = apply_rotary_emb(query, image_rotary_emb) 182 | key = apply_rotary_emb(key, image_rotary_emb) 183 | 184 | if proportional_attention: 185 | attention_scale = math.sqrt(math.log(key.size(2), self.train_seq_len) / head_dim) 186 | else: 187 | attention_scale = math.sqrt(1 / head_dim) 188 | 189 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale) 190 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 191 | hidden_states = hidden_states.to(query.dtype) 192 | 193 | if encoder_hidden_states is not None: 194 | encoder_hidden_states, hidden_states = ( 195 | hidden_states[:, : encoder_hidden_states.shape[1]], 196 | hidden_states[:, encoder_hidden_states.shape[1] :], 197 | ) 198 | 199 | # linear proj 200 | hidden_states = attn.to_out[0](hidden_states) + self.to_out_b(self.to_out_a(hidden_states)) 201 | # dropout 202 | hidden_states = attn.to_out[1](hidden_states) 203 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 204 | 205 | return hidden_states, encoder_hidden_states 206 | else: 207 | return hidden_states 208 | -------------------------------------------------------------------------------- /cache_prompt_embeds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import math 5 | from safetensors.torch import save_file 6 | import torch 7 | import tqdm 8 | from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, T5EncoderModel 9 | 10 | 11 | def parse_args(input_args=None): 12 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 13 | parser.add_argument( 14 | "--pretrained_model_name_or_path", 15 | type=str, 16 | default=None, 17 | required=True, 18 | help="Path to pretrained model or model identifier from huggingface.co/models.", 19 | ) 20 | parser.add_argument( 21 | "--revision", 22 | type=str, 23 | default=None, 24 | required=False, 25 | help="Revision of pretrained model identifier from huggingface.co/models.", 26 | ) 27 | parser.add_argument( 28 | "--variant", 29 | type=str, 30 | default=None, 31 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 32 | ) 33 | parser.add_argument( 34 | "--data_root", 35 | type=str, 36 | default=None 37 | ) 38 | parser.add_argument( 39 | "--cache_dir", 40 | type=str, 41 | default=None, 42 | help="The directory where the downloaded models and datasets will be stored.", 43 | ) 44 | parser.add_argument( 45 | "--column", 46 | type=str, 47 | default="prompt" 48 | ) 49 | parser.add_argument( 50 | "--max_sequence_length", 51 | type=int, 52 | default=512, 53 | help="Maximum sequence length to use with with the T5 text encoder", 54 | ) 55 | parser.add_argument( 56 | "--output_dir", 57 | type=str, 58 | default="flux-lora", 59 | help="The output directory where the model predictions and checkpoints will be written.", 60 | ) 61 | parser.add_argument( 62 | "--batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 63 | ) 64 | parser.add_argument( 65 | "--mixed_precision", 66 | type=str, 67 | default=None, 68 | choices=["no", "fp16", "bf16"], 69 | help=( 70 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 71 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 72 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 73 | ), 74 | ) 75 | parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank") 76 | parser.add_argument( 77 | "--num_workers", 78 | type=int, 79 | default=1, 80 | help="Number of workers", 81 | ) 82 | 83 | if input_args is not None: 84 | args = parser.parse_args(input_args) 85 | else: 86 | args = parser.parse_args() 87 | 88 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0)) 89 | 90 | return args 91 | 92 | 93 | def tokenize_prompt(tokenizer, prompt, max_sequence_length): 94 | text_inputs = tokenizer( 95 | prompt, 96 | padding="max_length", 97 | max_length=max_sequence_length, 98 | truncation=True, 99 | return_length=False, 100 | return_overflowing_tokens=False, 101 | return_tensors="pt", 102 | ) 103 | text_input_ids = text_inputs.input_ids 104 | return text_input_ids 105 | 106 | 107 | def _encode_prompt_with_t5( 108 | text_encoder, 109 | tokenizer, 110 | max_sequence_length=512, 111 | prompt=None, 112 | num_images_per_prompt=1, 113 | device=None, 114 | text_input_ids=None, 115 | ): 116 | prompt = [prompt] if isinstance(prompt, str) else prompt 117 | batch_size = len(prompt) 118 | 119 | if tokenizer is not None: 120 | text_inputs = tokenizer( 121 | prompt, 122 | padding="max_length", 123 | max_length=max_sequence_length, 124 | truncation=True, 125 | return_length=False, 126 | return_overflowing_tokens=False, 127 | return_tensors="pt", 128 | ) 129 | text_input_ids = text_inputs.input_ids 130 | else: 131 | if text_input_ids is None: 132 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 133 | 134 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 135 | 136 | dtype = text_encoder.dtype 137 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 138 | 139 | _, seq_len, _ = prompt_embeds.shape 140 | 141 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 142 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 143 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 144 | 145 | return prompt_embeds 146 | 147 | 148 | def _encode_prompt_with_clip( 149 | text_encoder, 150 | tokenizer, 151 | prompt: str, 152 
| device=None, 153 | text_input_ids=None, 154 | num_images_per_prompt: int = 1, 155 | ): 156 | prompt = [prompt] if isinstance(prompt, str) else prompt 157 | batch_size = len(prompt) 158 | 159 | if tokenizer is not None: 160 | text_inputs = tokenizer( 161 | prompt, 162 | padding="max_length", 163 | max_length=77, 164 | truncation=True, 165 | return_overflowing_tokens=False, 166 | return_length=False, 167 | return_tensors="pt", 168 | ) 169 | 170 | text_input_ids = text_inputs.input_ids 171 | else: 172 | if text_input_ids is None: 173 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 174 | 175 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 176 | 177 | # Use pooled output of CLIPTextModel 178 | prompt_embeds = prompt_embeds.pooler_output 179 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 180 | 181 | # duplicate text embeddings for each generation per prompt, using mps friendly method 182 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 183 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 184 | 185 | return prompt_embeds 186 | 187 | 188 | def encode_prompt( 189 | text_encoders, 190 | tokenizers, 191 | prompt: str, 192 | max_sequence_length, 193 | device=None, 194 | num_images_per_prompt: int = 1, 195 | text_input_ids_list=None, 196 | ): 197 | prompt = [prompt] if isinstance(prompt, str) else prompt 198 | 199 | pooled_prompt_embeds = _encode_prompt_with_clip( 200 | text_encoder=text_encoders[0], 201 | tokenizer=tokenizers[0], 202 | prompt=prompt, 203 | device=device if device is not None else text_encoders[0].device, 204 | num_images_per_prompt=num_images_per_prompt, 205 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, 206 | ) 207 | 208 | prompt_embeds = _encode_prompt_with_t5( 209 | text_encoder=text_encoders[1], 210 | tokenizer=tokenizers[1], 211 | max_sequence_length=max_sequence_length, 212 | prompt=prompt, 213 | num_images_per_prompt=num_images_per_prompt, 214 | device=device if device is not None else text_encoders[1].device, 215 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, 216 | ) 217 | 218 | return prompt_embeds, pooled_prompt_embeds 219 | 220 | 221 | def main(args): 222 | if torch.cuda.is_available(): 223 | torch.cuda.set_device(args.local_rank) 224 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 225 | if args.mixed_precision == 'fp16': 226 | dtype = torch.float16 227 | elif args.mixed_precision == 'bf16': 228 | dtype = torch.bfloat16 229 | else: 230 | dtype = torch.float32 231 | tokenizer_one = CLIPTokenizer.from_pretrained( 232 | args.pretrained_model_name_or_path, 233 | subfolder="tokenizer", 234 | revision=args.revision, 235 | variant=args.variant, 236 | cache_dir=args.cache_dir 237 | ) 238 | tokenizer_two = T5TokenizerFast.from_pretrained( 239 | args.pretrained_model_name_or_path, 240 | subfolder="tokenizer_2", 241 | revision=args.revision, 242 | variant=args.variant, 243 | cache_dir=args.cache_dir 244 | ) 245 | 246 | text_encoder_one = CLIPTextModel.from_pretrained( 247 | args.pretrained_model_name_or_path, 248 | revision=args.revision, 249 | subfolder="text_encoder", 250 | variant=args.variant, 251 | cache_dir=args.cache_dir 252 | ).to(device, dtype) 253 | text_encoder_two = T5EncoderModel.from_pretrained( 254 | args.pretrained_model_name_or_path, 255 | revision=args.revision, 256 | subfolder="text_encoder_2", 257 | 
variant=args.variant, 258 | cache_dir=args.cache_dir 259 | ).to(device, dtype) 260 | tokenizers = [tokenizer_one, tokenizer_two] 261 | text_encoders = [text_encoder_one, text_encoder_two] 262 | 263 | all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.json' in i] 264 | 265 | os.makedirs(args.output_dir, exist_ok=True) 266 | 267 | work_load = math.ceil(len(all_info) / args.num_workers) 268 | for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)): 269 | texts = [] 270 | for item in all_info[idx:idx + args.batch_size]: 271 | with open(item) as f: 272 | prompt = json.load(f)[args.column] 273 | texts.append(prompt) 274 | if not isinstance(prompt, str): 275 | print(prompt, item) 276 | paths = [item[:item.rfind('.')] + f'_{args.column}_embed.safetensors' for item in all_info[idx:idx + args.batch_size]] 277 | with torch.no_grad(): 278 | prompt_embeds, pooled_prompt_embeds = encode_prompt( 279 | text_encoders, tokenizers, texts, args.max_sequence_length 280 | ) 281 | prompt_embeds = prompt_embeds.cpu().data 282 | pooled_prompt_embeds = pooled_prompt_embeds.cpu().data 283 | for path, prompt_embed, pooled_prompt_embed in zip(paths, prompt_embeds.unbind(), pooled_prompt_embeds.unbind()): 284 | save_file( 285 | {'caption_feature_t5': prompt_embed, 'caption_feature_clip': pooled_prompt_embed}, 286 | path 287 | ) 288 | 289 | 290 | if __name__ == '__main__': 291 | main(parse_args()) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICML 2025] URAE: ~~Your Free FLUX Pro Ultra~~ 2 | 3 | 4 |
5 |  6 | arXiv 7 | HuggingFace 8 | HuggingFace 9 | HuggingFace 10 |  11 | > ***U*ltra-*R*esolution *A*daptation with *E*ase** 12 | >
13 | > [Ruonan Yu*](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en), 14 | > [Songhua Liu*](http://121.37.94.87/), 15 | > [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ&hl=en), 16 | > and 17 | > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) 18 | >
19 | > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore 20 | >
21 |  22 | ## 🪶Features 23 |  24 | * **Easy-to-Use High-Quality and High-Resolution Generation😊**: Ultra-Resolution Adaptation with Ease, or URAE in short, generates high-resolution images with FLUX, requiring only minimal code modifications. 25 | * **Easy Training🚀**: URAE tames light-weight adapters with only a handful of synthetic images from [FLUX1.1 Pro Ultra](https://blackforestlabs.ai/ultra-home/). 26 |  27 | ## 🔥News 28 |  29 | **[2025/05/01]** URAE has been accepted to ICML 2025! 🎉 30 |  31 | **[2025/03/20]** We release the models and code for both training and inference of URAE. 32 |  33 | ## Introduction 34 |  35 | Text-to-image diffusion models have achieved remarkable progress in recent years. However, training models for high-resolution image generation remains challenging, particularly when training data and computational resources are limited. In this paper, we explore this practical problem from two key perspectives: data and parameter efficiency, and propose a set of key guidelines for ultra-resolution adaptation termed *URAE*. For data efficiency, we theoretically and empirically demonstrate that synthetic data generated by some teacher models can significantly promote training convergence. For parameter efficiency, we find that tuning minor components of the weight matrices outperforms widely-used low-rank adapters when synthetic data are unavailable, offering substantial performance gains while maintaining efficiency. Additionally, for models leveraging guidance distillation, such as FLUX, we show that disabling classifier-free guidance, *i.e.*, setting the guidance scale to 1 during adaptation, is crucial for satisfactory performance. Extensive experiments validate that URAE achieves comparable 2K-generation performance to state-of-the-art closed-source models like FLUX1.1 [Pro] Ultra with only 3K samples and 2K iterations, while setting new benchmarks for 4K-resolution generation. 36 |  37 | ## Quick Start 38 |  39 | * If you have not already, install [PyTorch](https://pytorch.org/get-started/locally/), [diffusers](https://huggingface.co/docs/diffusers/index), [transformers](https://huggingface.co/docs/transformers/index), and [peft](https://huggingface.co/docs/peft/index). 40 |  41 | * Clone this repo to your project directory: 42 |  43 | ``` bash 44 | git clone https://github.com/Huage001/URAE.git 45 | cd URAE 46 | ``` 47 |  48 | * **You only need minimal modifications!** 49 |  50 | ```diff 51 | import torch 52 | - from diffusers import FluxPipeline 53 | + from pipeline_flux import FluxPipeline 54 | + from transformer_flux import FluxTransformer2DModel 55 |  56 | bfl_repo = "black-forest-labs/FLUX.1-dev" 57 | + transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=torch.bfloat16) 58 | - pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=torch.bfloat16) 59 | + pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=torch.bfloat16) 60 | + pipe.scheduler.config.use_dynamic_shifting = False 61 | + pipe.scheduler.config.time_shift = 10 62 | pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU.
Remove this if you have enough GPU power 63 |  64 | + pipe.load_lora_weights("Huage001/URAE", weight_name="urae_2k_adapter.safetensors") 65 |  66 | prompt = "An astronaut riding a green horse" 67 | image = pipe( 68 | prompt, 69 | - height=1024, 70 | - width=1024, 71 | + height=2048, 72 | + width=2048, 73 | guidance_scale=3.5, 74 | num_inference_steps=50, 75 | max_sequence_length=512, 76 | generator=torch.Generator("cpu").manual_seed(0) 77 | ).images[0] 78 | image.save("flux-urae.png") 79 | ``` 80 | ⚠️ **FLUX requires at least 28GB of GPU memory to operate at 2K resolution.** A 48GB GPU is recommended for the full functionality of URAE, including both 2K and 4K generation. We are actively integrating model lightweighting strategies into URAE! If you have a good idea, feel free to submit a PR! 81 |  82 | * Don't want to run the code? No worries! Try the models on Hugging Face Spaces! 83 |  84 | * [URAE w. FLUX1.schnell](https://huggingface.co/spaces/Yuanshi/URAE) (Faster) 85 | * [URAE w. FLUX1.dev](https://huggingface.co/spaces/Yuanshi/URAE_dev) (Higher Quality) 86 |  87 | ## Installation 88 |  89 | * Clone this repo to your project directory: 90 |  91 | ``` bash 92 | git clone https://github.com/Huage001/URAE.git 93 | cd URAE 94 | ``` 95 |  96 | * URAE has been tested on ``torch==2.5.1`` and ``diffusers==0.31.0``, but it should also be compatible with similar versions. You can set up a new environment if you wish and install the packages listed in ``requirements.txt``: 97 |  98 | ```bash 99 | conda create -n URAE python=3.12 100 | conda activate URAE 101 | pip install -r requirements.txt 102 | ``` 103 |  104 | ## Inference 105 |  106 | #### 2K Resolution, or Use Cases Similar to FLUX Pro Ultra 107 |  108 | * We use **LoRA adapters** to adapt FLUX 1.dev to 2K resolution. You can try the corresponding URAE model in `inference_2k.ipynb`. 109 |  110 | * We also support FLUX 1.schnell for a **faster inference process**. Please refer to `inference_2k_schnell.ipynb`. 111 |  112 | #### 4K Resolution or Higher 113 |  114 | * Instead of LoRA, we use **minor-component adapters** at the 4K stage. You can try the models for FLUX 1.dev and FLUX 1.schnell in `inference_4k.ipynb` and `inference_4k_schnell.ipynb`. 115 | * Alternatively, although these adapters are different from LoRA, they can also be converted to the LoRA format, so that the interfaces provided by `peft` can be applied for more convenient usage. **If you only want to try the models instead of understanding the principle,** `inference_4k_lora_conversion.ipynb` **and** `inference_4k_lora_conversion_schnell.ipynb` **are what you want!** Their outputs should be equivalent to those of the above counterparts. 116 |  117 | 🚧 The 4K model is still in beta and its performance may not be stable. A more recommended use case for the current 4K model is to integrate it with training-free high-resolution generation pipelines based on coarse-to-fine strategies, such as [SDEdit](https://arxiv.org/abs/2108.01073) (refer to [this repo](https://github.com/Huage001/CLEAR/blob/main/inference_t2i_highres.ipynb) for a sample usage) and [I-Max](https://github.com/PRIS-CV/I-Max), and load the 4K adapter at the high-resolution stages. 118 |  119 |  120 | ## Training 121 |  122 | 1. Prepare training data. 123 |  124 | * To train the 2K model, we collect 3,000 images generated by [FLUX1.1 Pro Ultra](https://blackforestlabs.ai/ultra-home/). You can use your own API key and follow the official instructions to acquire such images.
125 |  126 | * To train the 4K model, we use ~16,000 images with resolutions greater than 4K from [LAION-High-Resolution](https://huggingface.co/datasets/laion/laion-high-resolution). 127 |  128 | * The training data folder should be organized in the following format: 129 |  130 | ``` 131 | train_data 132 | ├── image_0.jpg 133 | ├── image_0.json 134 | ├── image_1.jpg 135 | ├── image_1.json 136 | ├── ... 137 | ``` 138 |  139 | Each json file should contain a dictionary with the entries `prompt` and/or `generated_prompt`, which specify the original caption and a caption with detailed descriptions generated by a VLM such as GPT-4o or Llama 3. 140 |  141 | * The T5 text encoder and the VAE can take a large amount of GPU memory, which can trigger OOM errors when training at high resolutions. Therefore, we pre-cache the T5 and VAE features instead of computing them online: 142 |  143 | ```bash 144 | bash cache_prompt_embeds.sh 145 | bash cache_latent_codes.sh 146 | ``` 147 |  148 | Make sure to modify these bash files beforehand to configure the number of parallel processes (`$NUM_WORKERS`), the training data folder (`$DATA_DIR`), the target resolution (`--resolution`), and the key of the prompt in the json files (`--column`, which can be `prompt` or `generated_prompt`). 149 |  150 | * The final format of the training data folder should be: 151 |  152 | ``` 153 | train_data 154 | ├── image_0_generated_prompt_embed.safetensors 155 | ├── image_0.jpg 156 | ├── image_0.json 157 | ├── image_0_latent_code.safetensors 158 | ├── image_0_prompt_embed.safetensors 159 | ├── image_1_generated_prompt_embed.safetensors 160 | ├── image_1.jpg 161 | ├── image_1.json 162 | ├── image_1_latent_code.safetensors 163 | ├── image_1_prompt_embed.safetensors 164 | ├── ... 165 | ``` 166 |  167 | 2. Let's start training! For the 2K model, make sure to modify `train_2k.sh` to configure the training data folder (`$DATA_DIR`), the output folder (`$OUTPUT_DIR`), and the number of GPUs (`--num_processes`). 168 |  169 | ```bash 170 | bash train_2k.sh 171 | ``` 172 |  173 | 3. The 4K model is based on the 2K LoRA trained previously. Make sure to modify `train_4k.sh` to configure its path (`--pretrained_lora`) in addition to the aforementioned items. 174 |  175 | ```bash 176 | bash train_4k.sh 177 | ``` 178 |  179 | ⚠️ Currently, if the training images have different resolutions, `--train_batch_size` can only be 1 because we did not customize the data sampler to handle this case. Nevertheless, in practice, at such large resolutions the maximum batch size per GPU (before gradient accumulation) can only be 1 even on 80G GPUs 😂. 180 |  181 | ## Acknowledgement 182 |  183 | * [FLUX](https://blackforestlabs.ai/announcing-black-forest-labs/) for the source models. 184 | * [diffusers](https://github.com/huggingface/diffusers), [CLEAR](https://github.com/Huage001/CLEAR), and [I-Max](https://github.com/PRIS-CV/I-Max) for the code base. 185 | * [patch_conv](https://github.com/mit-han-lab/patch_conv) for the solution to VAE OOM errors at high resolutions. 186 | * [@Xinyin Ma](https://github.com/horseee) for valuable discussions. 187 | * NUS IT's Research Computing group, under grant number NUSREC-HPC-00001.
188 | 189 | ## Citation 190 | 191 | If you finds this repo is helpful, please consider citing: 192 | 193 | ```bib 194 | @article{yu2025urae, 195 | title = {Ultra-Resolution Adaptation with Ease}, 196 | author = {Yu, Ruonan and Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao}, 197 | journal = {International Conference on Machine Learning}, 198 | year = {2025}, 199 | } 200 | ``` 201 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /inference_4k_schnell.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel\n", 14 | "from attention_processor import FluxAttnAdaptationProcessor2_0\n", 15 | "from safetensors.torch import load_file, save_file\n", 16 | "from patch_conv import convert_model" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n", 26 | "device = torch.device('cuda')\n", 27 | "dtype = torch.bfloat16\n", 28 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\")\n", 29 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n", 30 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 31 | "pipe.scheduler.config.time_shift = 10\n", 32 | "pipe.enable_model_cpu_offload()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "* 4K model is based on 2K LoRA" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n", 49 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 50 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")\n", 51 | "pipe.fuse_lora()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "* Substitute original attention processors" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "rank = 16\n", 68 | "attn_processors = {}\n", 69 | "for k in pipe.transformer.attn_processors.keys():\n", 70 | " attn_processors[k] = FluxAttnAdaptationProcessor2_0(rank=rank, to_out='single' not in k).to(device, dtype)\n", 71 | "pipe.transformer.set_attn_processor(attn_processors)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "* If no cached major components, compute them via SVD and save them to cache_path\n", 79 | "* If you don't want to save cached major components, simply set `cache_path = None`" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "cache_path = 'ckpt/_urae_4k_adapter_schnell.safetensors'\n", 89 | "if cache_path is not None and os.path.exists(cache_path):\n", 90 | " pipe.transformer.to(dtype=dtype)\n", 91 | " pipe.transformer.load_state_dict(load_file(cache_path), strict=False)\n", 92 | "else:\n", 93 | " with torch.no_grad():\n", 94 | " for idx in range(len(pipe.transformer.transformer_blocks)):\n", 95 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data.to(device)\n", 96 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 97 | " pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data = (\n", 98 | " matrix_u[:, :-rank] @ 
torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 99 | " ).to('cpu')\n", 100 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n", 101 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 102 | " ).to('cpu')\n", 103 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n", 104 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 105 | " ).to('cpu')\n", 106 | "\n", 107 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data.to(device)\n", 108 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 109 | " pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data = (\n", 110 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 111 | " ).to('cpu')\n", 112 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n", 113 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 114 | " ).to('cpu')\n", 115 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n", 116 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 117 | " ).to('cpu')\n", 118 | "\n", 119 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data.to(device)\n", 120 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 121 | " pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data = (\n", 122 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 123 | " ).to('cpu')\n", 124 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n", 125 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 126 | " ).to('cpu')\n", 127 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n", 128 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 129 | " ).to('cpu')\n", 130 | "\n", 131 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data.to(device)\n", 132 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 133 | " pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data = (\n", 134 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 135 | " ).to('cpu')\n", 136 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_b.weight.data = (\n", 137 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 138 | " ).to('cpu')\n", 139 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_a.weight.data = (\n", 140 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 141 | " ).to('cpu')\n", 142 | " for idx in range(len(pipe.transformer.single_transformer_blocks)):\n", 143 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data.to(device)\n", 144 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 145 | " pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data = (\n", 146 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 147 | " ).to('cpu')\n", 148 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n", 149 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 150 | " ).to('cpu')\n", 151 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n", 152 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 153 | " ).to('cpu')\n", 
154 | "\n", 155 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data.to(device)\n", 156 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 157 | " pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data = (\n", 158 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 159 | " ).to('cpu')\n", 160 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n", 161 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 162 | " ).to('cpu')\n", 163 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n", 164 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 165 | " ).to('cpu')\n", 166 | "\n", 167 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data.to(device)\n", 168 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 169 | " pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data = (\n", 170 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 171 | " ).to('cpu')\n", 172 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n", 173 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 174 | " ).to('cpu')\n", 175 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n", 176 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 177 | " ).to('cpu')\n", 178 | " pipe.transformer.to(dtype=dtype)\n", 179 | " if cache_path is not None:\n", 180 | " state_dict = pipe.transformer.state_dict()\n", 181 | " attn_state_dict = {}\n", 182 | " for k in state_dict.keys():\n", 183 | " if 'base_layer' in k:\n", 184 | " attn_state_dict[k] = state_dict[k]\n", 185 | " save_file(attn_state_dict, cache_path)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "* Download pre-trained 4k adapter" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "if not os.path.exists('ckpt/urae_4k_adapter.safetensors'):\n", 202 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "* Optionally, you can convert the minor-component adapter into a LoRA for easier use" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "lora_conversion = True\n", 219 | "if lora_conversion and not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors'):\n", 220 | " cur = pipe.transformer.state_dict()\n", 221 | " tgt = load_file('ckpt/urae_4k_adapter.safetensors')\n", 222 | " ref = load_file('ckpt/urae_2k_adapter.safetensors')\n", 223 | " new_ckpt = {}\n", 224 | " for k in tgt.keys():\n", 225 | " if 'to_k_a' in k:\n", 226 | " k_ = 'transformer.' + k.replace('.processor.to_k_a', '.to_k.lora_A')\n", 227 | " elif 'to_k_b' in k:\n", 228 | " k_ = 'transformer.' + k.replace('.processor.to_k_b', '.to_k.lora_B')\n", 229 | " elif 'to_q_a' in k:\n", 230 | " k_ = 'transformer.' + k.replace('.processor.to_q_a', '.to_q.lora_A')\n", 231 | " elif 'to_q_b' in k:\n", 232 | " k_ = 'transformer.' 
+ k.replace('.processor.to_q_b', '.to_q.lora_B')\n", 233 | " elif 'to_v_a' in k:\n", 234 | " k_ = 'transformer.' + k.replace('.processor.to_v_a', '.to_v.lora_A')\n", 235 | " elif 'to_v_b' in k:\n", 236 | " k_ = 'transformer.' + k.replace('.processor.to_v_b', '.to_v.lora_B')\n", 237 | " elif 'to_out_a' in k:\n", 238 | " k_ = 'transformer.' + k.replace('.processor.to_out_a', '.to_out.0.lora_A')\n", 239 | " elif 'to_out_b' in k:\n", 240 | " k_ = 'transformer.' + k.replace('.processor.to_out_b', '.to_out.0.lora_B')\n", 241 | " else:\n", 242 | " print(k)\n", 243 | " assert False\n", 244 | " if '_a.' in k and '_b.' not in k:\n", 245 | " new_ckpt[k_] = torch.cat([-cur[k], tgt[k], ref[k_]], dim=0)\n", 246 | " elif '_b.' in k and '_a.' not in k:\n", 247 | " new_ckpt[k_] = torch.cat([cur[k], tgt[k], ref[k_]], dim=1)\n", 248 | " else:\n", 249 | " print(k)\n", 250 | " assert False\n", 251 | " save_file(new_ckpt, 'ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors')" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "* Load state_dict of 4k adapter" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "m, u = pipe.transformer.load_state_dict(load_file('ckpt/urae_4k_adapter.safetensors'), strict=False)\n", 268 | "assert len(u) == 0" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n", 276 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "pipe.vae = convert_model(pipe.vae, splits=4)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "* Everything ready. Let's generate!" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 302 | "height = 4096\n", 303 | "width = 4096\n", 304 | "image = pipe(\n", 305 | " prompt,\n", 306 | " height=height,\n", 307 | " width=width,\n", 308 | " guidance_scale=0,\n", 309 | " num_inference_steps=4,\n", 310 | " max_sequence_length=256,\n", 311 | " generator=torch.manual_seed(8888),\n", 312 | " ntk_factor=10,\n", 313 | " proportional_attention=True\n", 314 | ").images[0]\n", 315 | "image" 316 | ] 317 | } 318 | ], 319 | "metadata": { 320 | "kernelspec": { 321 | "display_name": "Python 3 (ipykernel)", 322 | "language": "python", 323 | "name": "python3" 324 | }, 325 | "language_info": { 326 | "codemirror_mode": { 327 | "name": "ipython", 328 | "version": 3 329 | }, 330 | "file_extension": ".py", 331 | "mimetype": "text/x-python", 332 | "name": "python", 333 | "nbconvert_exporter": "python", 334 | "pygments_lexer": "ipython3", 335 | "version": "3.12.9" 336 | } 337 | }, 338 | "nbformat": 4, 339 | "nbformat_minor": 2 340 | } 341 | -------------------------------------------------------------------------------- /inference_4k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "from huggingface_hub import hf_hub_download\n", 12 | "from pipeline_flux import FluxPipeline\n", 13 | "from transformer_flux import FluxTransformer2DModel\n", 14 | "from attention_processor import FluxAttnAdaptationProcessor2_0\n", 15 | "from safetensors.torch import load_file, save_file\n", 16 | "from patch_conv import convert_model" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n", 26 | "device = torch.device('cuda')\n", 27 | "dtype = torch.bfloat16\n", 28 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\")\n", 29 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n", 30 | "pipe.scheduler.config.use_dynamic_shifting = False\n", 31 | "pipe.scheduler.config.time_shift = 10\n", 32 | "pipe.enable_model_cpu_offload()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "* 4K model is based on 2K LoRA" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n", 49 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n", 50 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")\n", 51 | "pipe.fuse_lora()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "* Substitute original attention processors" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "rank = 16\n", 68 | "attn_processors = {}\n", 69 | "for k in pipe.transformer.attn_processors.keys():\n", 70 | " attn_processors[k] = FluxAttnAdaptationProcessor2_0(rank=rank, to_out='single' not in k)\n", 71 | "pipe.transformer.set_attn_processor(attn_processors)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | 
"metadata": {}, 77 | "source": [ 78 | "* If no cached major components, compute them via SVD and save them to cache_path\n", 79 | "* If you don't want to save cached major components, simply set `cache_path = None`" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "cache_path = 'ckpt/_urae_4k_adapter_dev.safetensors'\n", 89 | "if cache_path is not None and os.path.exists(cache_path):\n", 90 | " pipe.transformer.to(dtype=dtype)\n", 91 | " pipe.transformer.load_state_dict(load_file(cache_path), strict=False)\n", 92 | "else:\n", 93 | " with torch.no_grad():\n", 94 | " for idx in range(len(pipe.transformer.transformer_blocks)):\n", 95 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data.to(device)\n", 96 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 97 | " pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data = (\n", 98 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 99 | " ).to('cpu')\n", 100 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n", 101 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 102 | " ).to('cpu')\n", 103 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n", 104 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 105 | " ).to('cpu')\n", 106 | "\n", 107 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data.to(device)\n", 108 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 109 | " pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data = (\n", 110 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 111 | " ).to('cpu')\n", 112 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n", 113 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 114 | " ).to('cpu')\n", 115 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n", 116 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 117 | " ).to('cpu')\n", 118 | "\n", 119 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data.to(device)\n", 120 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 121 | " pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data = (\n", 122 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 123 | " ).to('cpu')\n", 124 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n", 125 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 126 | " ).to('cpu')\n", 127 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n", 128 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 129 | " ).to('cpu')\n", 130 | "\n", 131 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data.to(device)\n", 132 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 133 | " pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data = (\n", 134 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 135 | " ).to('cpu')\n", 136 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_b.weight.data = (\n", 137 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 138 | " ).to('cpu')\n", 139 | " 
pipe.transformer.transformer_blocks[idx].attn.processor.to_out_a.weight.data = (\n", 140 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 141 | " ).to('cpu')\n", 142 | " for idx in range(len(pipe.transformer.single_transformer_blocks)):\n", 143 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data.to(device)\n", 144 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 145 | " pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data = (\n", 146 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 147 | " ).to('cpu')\n", 148 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n", 149 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 150 | " ).to('cpu')\n", 151 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n", 152 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 153 | " ).to('cpu')\n", 154 | "\n", 155 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data.to(device)\n", 156 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 157 | " pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data = (\n", 158 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 159 | " ).to('cpu')\n", 160 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n", 161 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 162 | " ).to('cpu')\n", 163 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n", 164 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 165 | " ).to('cpu')\n", 166 | "\n", 167 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data.to(device)\n", 168 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n", 169 | " pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data = (\n", 170 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n", 171 | " ).to('cpu')\n", 172 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n", 173 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n", 174 | " ).to('cpu')\n", 175 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n", 176 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n", 177 | " ).to('cpu')\n", 178 | " pipe.transformer.to(dtype=dtype)\n", 179 | " if cache_path is not None:\n", 180 | " state_dict = pipe.transformer.state_dict()\n", 181 | " attn_state_dict = {}\n", 182 | " for k in state_dict.keys():\n", 183 | " if 'base_layer' in k:\n", 184 | " attn_state_dict[k] = state_dict[k]\n", 185 | " save_file(attn_state_dict, cache_path)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "* Download pre-trained 4k adapter" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "if not os.path.exists('ckpt/urae_4k_adapter.safetensors'):\n", 202 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "* Optionally, you can convert the 
minor-component adapter into a LoRA for easier use" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "lora_conversion = True\n", 219 | "if lora_conversion and not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_dev.safetensors'):\n", 220 | " cur = pipe.transformer.state_dict()\n", 221 | " tgt = load_file('ckpt/urae_4k_adapter.safetensors')\n", 222 | " ref = load_file('ckpt/urae_2k_adapter.safetensors')\n", 223 | " new_ckpt = {}\n", 224 | " for k in tgt.keys():\n", 225 | " if 'to_k_a' in k:\n", 226 | " k_ = 'transformer.' + k.replace('.processor.to_k_a', '.to_k.lora_A')\n", 227 | " elif 'to_k_b' in k:\n", 228 | " k_ = 'transformer.' + k.replace('.processor.to_k_b', '.to_k.lora_B')\n", 229 | " elif 'to_q_a' in k:\n", 230 | " k_ = 'transformer.' + k.replace('.processor.to_q_a', '.to_q.lora_A')\n", 231 | " elif 'to_q_b' in k:\n", 232 | " k_ = 'transformer.' + k.replace('.processor.to_q_b', '.to_q.lora_B')\n", 233 | " elif 'to_v_a' in k:\n", 234 | " k_ = 'transformer.' + k.replace('.processor.to_v_a', '.to_v.lora_A')\n", 235 | " elif 'to_v_b' in k:\n", 236 | " k_ = 'transformer.' + k.replace('.processor.to_v_b', '.to_v.lora_B')\n", 237 | " elif 'to_out_a' in k:\n", 238 | " k_ = 'transformer.' + k.replace('.processor.to_out_a', '.to_out.0.lora_A')\n", 239 | " elif 'to_out_b' in k:\n", 240 | " k_ = 'transformer.' + k.replace('.processor.to_out_b', '.to_out.0.lora_B')\n", 241 | " else:\n", 242 | " print(k)\n", 243 | " assert False\n", 244 | " if '_a.' in k and '_b.' not in k:\n", 245 | " new_ckpt[k_] = torch.cat([-cur[k], tgt[k], ref[k_]], dim=0)\n", 246 | " elif '_b.' in k and '_a.' not in k:\n", 247 | " new_ckpt[k_] = torch.cat([cur[k], tgt[k], ref[k_]], dim=1)\n", 248 | " else:\n", 249 | " print(k)\n", 250 | " assert False\n", 251 | " save_file(new_ckpt, 'ckpt/urae_4k_adapter_lora_conversion_dev.safetensors')" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "* Load state_dict of 4k adapter" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "m, u = pipe.transformer.load_state_dict(load_file('ckpt/urae_4k_adapter.safetensors'), strict=False)\n", 268 | "assert len(u) == 0" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n", 276 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "pipe.vae = convert_model(pipe.vae, splits=4)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "* Everything ready. Let's generate!\n", 293 | "* 4K generation using FLUX-1.dev can take a while, e.g., 5min on H100." 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. 
The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n", 303 | "height = 4096\n", 304 | "width = 4096\n", 305 | "image = pipe(\n", 306 | " prompt,\n", 307 | " height=height,\n", 308 | " width=width,\n", 309 | " guidance_scale=3.5,\n", 310 | " num_inference_steps=28,\n", 311 | " max_sequence_length=512,\n", 312 | " generator=torch.manual_seed(8888),\n", 313 | " ntk_factor=10,\n", 314 | " proportional_attention=True\n", 315 | ").images[0]\n", 316 | "image" 317 | ] 318 | } 319 | ], 320 | "metadata": { 321 | "kernelspec": { 322 | "display_name": "Python 3 (ipykernel)", 323 | "language": "python", 324 | "name": "python3" 325 | }, 326 | "language_info": { 327 | "codemirror_mode": { 328 | "name": "ipython", 329 | "version": 3 330 | }, 331 | "file_extension": ".py", 332 | "mimetype": "text/x-python", 333 | "name": "python", 334 | "nbconvert_exporter": "python", 335 | "pygments_lexer": "ipython3", 336 | "version": "3.12.9" 337 | } 338 | }, 339 | "nbformat": 4, 340 | "nbformat_minor": 2 341 | } 342 | -------------------------------------------------------------------------------- /transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union, List 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.attention import FeedForward 11 | from diffusers.models.attention_processor import ( 12 | Attention, 13 | AttentionProcessor 14 | ) 15 | from diffusers.models.modeling_utils import ModelMixin 16 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 17 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 18 | from diffusers.utils.torch_utils import maybe_allow_in_graph 19 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed 20 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 21 | from attention_processor import FluxAttnProcessor2_0 22 | 23 | 24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 25 | 26 | 27 | @maybe_allow_in_graph 28 | class FluxSingleTransformerBlock(nn.Module): 29 | r""" 30 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 31 | 32 | Reference: https://arxiv.org/abs/2403.03206 33 | 34 | Parameters: 35 | dim (`int`): The number of channels in the input and output. 36 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 37 | attention_head_dim (`int`): The number of channels in each head. 38 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 39 | processing of `context` conditions. 
40 | """ 41 | 42 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 43 | super().__init__() 44 | self.mlp_hidden_dim = int(dim * mlp_ratio) 45 | 46 | self.norm = AdaLayerNormZeroSingle(dim) 47 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 48 | self.act_mlp = nn.GELU(approximate="tanh") 49 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 50 | 51 | processor = FluxAttnProcessor2_0() 52 | self.attn = Attention( 53 | query_dim=dim, 54 | cross_attention_dim=None, 55 | dim_head=attention_head_dim, 56 | heads=num_attention_heads, 57 | out_dim=dim, 58 | bias=True, 59 | processor=processor, 60 | qk_norm="rms_norm", 61 | eps=1e-6, 62 | pre_only=True, 63 | ) 64 | 65 | def forward( 66 | self, 67 | hidden_states: torch.FloatTensor, 68 | temb: torch.FloatTensor, 69 | image_rotary_emb=None, 70 | joint_attention_kwargs=None 71 | ): 72 | residual = hidden_states 73 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 74 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 75 | joint_attention_kwargs = joint_attention_kwargs or {} 76 | attn_output = self.attn( 77 | hidden_states=norm_hidden_states, 78 | image_rotary_emb=image_rotary_emb, 79 | **joint_attention_kwargs, 80 | ) 81 | 82 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 83 | gate = gate.unsqueeze(1) 84 | hidden_states = gate * self.proj_out(hidden_states) 85 | hidden_states = residual + hidden_states 86 | if hidden_states.dtype == torch.float16: 87 | hidden_states = hidden_states.clip(-65504, 65504) 88 | 89 | return hidden_states 90 | 91 | 92 | @maybe_allow_in_graph 93 | class FluxTransformerBlock(nn.Module): 94 | r""" 95 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 96 | 97 | Reference: https://arxiv.org/abs/2403.03206 98 | 99 | Parameters: 100 | dim (`int`): The number of channels in the input and output. 101 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 102 | attention_head_dim (`int`): The number of channels in each head. 103 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 104 | processing of `context` conditions. 105 | """ 106 | 107 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): 108 | super().__init__() 109 | 110 | self.norm1 = AdaLayerNormZero(dim) 111 | 112 | self.norm1_context = AdaLayerNormZero(dim) 113 | 114 | if hasattr(F, "scaled_dot_product_attention"): 115 | processor = FluxAttnProcessor2_0() 116 | else: 117 | raise ValueError( 118 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
119 | ) 120 | self.attn = Attention( 121 | query_dim=dim, 122 | cross_attention_dim=None, 123 | added_kv_proj_dim=dim, 124 | dim_head=attention_head_dim, 125 | heads=num_attention_heads, 126 | out_dim=dim, 127 | context_pre_only=False, 128 | bias=True, 129 | processor=processor, 130 | qk_norm=qk_norm, 131 | eps=eps, 132 | ) 133 | 134 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 135 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 136 | 137 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 138 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 139 | 140 | # let chunk size default to None 141 | self._chunk_size = None 142 | self._chunk_dim = 0 143 | 144 | def forward( 145 | self, 146 | hidden_states: torch.FloatTensor, 147 | encoder_hidden_states: torch.FloatTensor, 148 | temb: torch.FloatTensor, 149 | image_rotary_emb=None, 150 | joint_attention_kwargs=None, 151 | ): 152 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 153 | 154 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 155 | encoder_hidden_states, emb=temb 156 | ) 157 | joint_attention_kwargs = joint_attention_kwargs or {} 158 | # Attention. 159 | attn_output, context_attn_output = self.attn( 160 | hidden_states=norm_hidden_states, 161 | encoder_hidden_states=norm_encoder_hidden_states, 162 | image_rotary_emb=image_rotary_emb, 163 | **joint_attention_kwargs, 164 | ) 165 | 166 | # Process attention outputs for the `hidden_states`. 167 | attn_output = gate_msa.unsqueeze(1) * attn_output 168 | hidden_states = hidden_states + attn_output 169 | 170 | norm_hidden_states = self.norm2(hidden_states) 171 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 172 | 173 | ff_output = self.ff(norm_hidden_states) 174 | ff_output = gate_mlp.unsqueeze(1) * ff_output 175 | 176 | hidden_states = hidden_states + ff_output 177 | 178 | # Process attention outputs for the `encoder_hidden_states`. 
179 | 180 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 181 | encoder_hidden_states = encoder_hidden_states + context_attn_output 182 | 183 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 184 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 185 | 186 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 187 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 188 | if encoder_hidden_states.dtype == torch.float16: 189 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 190 | 191 | return encoder_hidden_states, hidden_states 192 | 193 | 194 | class FluxPosEmbed(nn.Module): 195 | # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 196 | def __init__(self, theta: int, axes_dim: List[int]): 197 | super().__init__() 198 | self.theta = theta 199 | self.axes_dim = axes_dim 200 | 201 | def forward(self, ids: torch.Tensor, ntk_factor=1) -> torch.Tensor: 202 | n_axes = ids.shape[-1] 203 | cos_out = [] 204 | sin_out = [] 205 | pos = ids.float() 206 | is_mps = ids.device.type == "mps" 207 | freqs_dtype = torch.float32 if is_mps else torch.float64 208 | for i in range(n_axes): 209 | cos, sin = get_1d_rotary_pos_embed( 210 | self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype, 211 | ntk_factor=ntk_factor 212 | ) 213 | cos_out.append(cos) 214 | sin_out.append(sin) 215 | freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) 216 | freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) 217 | return freqs_cos, freqs_sin 218 | 219 | 220 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 221 | """ 222 | The Transformer model introduced in Flux. 223 | 224 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 225 | 226 | Parameters: 227 | patch_size (`int`): Patch size to turn the input data into small patches. 228 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 229 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 230 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 231 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 232 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 233 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 234 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 235 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 
236 | """ 237 | 238 | _supports_gradient_checkpointing = True 239 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 240 | 241 | @register_to_config 242 | def __init__( 243 | self, 244 | patch_size: int = 1, 245 | in_channels: int = 64, 246 | num_layers: int = 19, 247 | num_single_layers: int = 38, 248 | attention_head_dim: int = 128, 249 | num_attention_heads: int = 24, 250 | joint_attention_dim: int = 4096, 251 | pooled_projection_dim: int = 768, 252 | guidance_embeds: bool = False, 253 | axes_dims_rope: Tuple[int] = (16, 56, 56), 254 | ): 255 | super().__init__() 256 | self.out_channels = in_channels 257 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 258 | 259 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 260 | 261 | text_time_guidance_cls = ( 262 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 263 | ) 264 | self.time_text_embed = text_time_guidance_cls( 265 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 266 | ) 267 | 268 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 269 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 270 | 271 | self.transformer_blocks = nn.ModuleList( 272 | [ 273 | FluxTransformerBlock( 274 | dim=self.inner_dim, 275 | num_attention_heads=self.config.num_attention_heads, 276 | attention_head_dim=self.config.attention_head_dim, 277 | ) 278 | for i in range(self.config.num_layers) 279 | ] 280 | ) 281 | 282 | self.single_transformer_blocks = nn.ModuleList( 283 | [ 284 | FluxSingleTransformerBlock( 285 | dim=self.inner_dim, 286 | num_attention_heads=self.config.num_attention_heads, 287 | attention_head_dim=self.config.attention_head_dim, 288 | ) 289 | for i in range(self.config.num_single_layers) 290 | ] 291 | ) 292 | 293 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 294 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 295 | 296 | self.gradient_checkpointing = False 297 | 298 | @property 299 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 300 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 301 | r""" 302 | Returns: 303 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 304 | indexed by its weight name. 305 | """ 306 | # set recursively 307 | processors = {} 308 | 309 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 310 | if hasattr(module, "get_processor"): 311 | processors[f"{name}.processor"] = module.get_processor() 312 | 313 | for sub_name, child in module.named_children(): 314 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 315 | 316 | return processors 317 | 318 | for name, module in self.named_children(): 319 | fn_recursive_add_processors(name, module, processors) 320 | 321 | return processors 322 | 323 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 324 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 325 | r""" 326 | Sets the attention processor to use to compute attention. 
327 | 328 | Parameters: 329 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 330 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 331 | for **all** `Attention` layers. 332 | 333 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 334 | processor. This is strongly recommended when setting trainable attention processors. 335 | 336 | """ 337 | count = len(self.attn_processors.keys()) 338 | 339 | if isinstance(processor, dict) and len(processor) != count: 340 | raise ValueError( 341 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 342 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 343 | ) 344 | 345 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 346 | if hasattr(module, "set_processor"): 347 | if not isinstance(processor, dict): 348 | module.set_processor(processor) 349 | else: 350 | module.set_processor(processor.pop(f"{name}.processor")) 351 | 352 | for sub_name, child in module.named_children(): 353 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 354 | 355 | for name, module in self.named_children(): 356 | fn_recursive_attn_processor(name, module, processor) 357 | 358 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 359 | def unfuse_qkv_projections(self): 360 | """Disables the fused QKV projection if enabled. 361 | 362 | 363 | 364 | This API is 🧪 experimental. 365 | 366 | 367 | 368 | """ 369 | if self.original_attn_processors is not None: 370 | self.set_attn_processor(self.original_attn_processors) 371 | 372 | def _set_gradient_checkpointing(self, module, value=False): 373 | if hasattr(module, "gradient_checkpointing"): 374 | module.gradient_checkpointing = value 375 | 376 | def forward( 377 | self, 378 | hidden_states: torch.Tensor, 379 | encoder_hidden_states: torch.Tensor = None, 380 | pooled_projections: torch.Tensor = None, 381 | timestep: torch.LongTensor = None, 382 | img_ids: torch.Tensor = None, 383 | txt_ids: torch.Tensor = None, 384 | guidance: torch.Tensor = None, 385 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 386 | controlnet_block_samples=None, 387 | controlnet_single_block_samples=None, 388 | return_dict: bool = True, 389 | ntk_factor: float = 1, 390 | controlnet_blocks_repeat: bool = False, 391 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 392 | """ 393 | The [`FluxTransformer2DModel`] forward method. 394 | 395 | Args: 396 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 397 | Input `hidden_states`. 398 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 399 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 400 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 401 | from the embeddings of input conditions. 402 | timestep ( `torch.LongTensor`): 403 | Used to indicate denoising step. 404 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 405 | A list of tensors that if specified are added to the residuals of transformer blocks. 
406 | joint_attention_kwargs (`dict`, *optional*): 407 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 408 | `self.processor` in 409 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 410 | return_dict (`bool`, *optional*, defaults to `True`): 411 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 412 | tuple. 413 | 414 | Returns: 415 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 416 | `tuple` where the first element is the sample tensor. 417 | """ 418 | 419 | if txt_ids.ndim == 3: 420 | logger.warning( 421 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 422 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 423 | ) 424 | txt_ids = txt_ids[0] 425 | if img_ids.ndim == 3: 426 | logger.warning( 427 | "Passing `img_ids` 3d torch.Tensor is deprecated." 428 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 429 | ) 430 | img_ids = img_ids[0] 431 | 432 | if joint_attention_kwargs is not None: 433 | joint_attention_kwargs = joint_attention_kwargs.copy() 434 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 435 | else: 436 | lora_scale = 1.0 437 | 438 | if USE_PEFT_BACKEND: 439 | # weight the lora layers by setting `lora_scale` for each PEFT layer 440 | scale_lora_layers(self, lora_scale) 441 | else: 442 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 443 | logger.warning( 444 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 445 | ) 446 | hidden_states = self.x_embedder(hidden_states) 447 | 448 | timestep = timestep.to(hidden_states.dtype) * 1000 449 | if guidance is not None: 450 | guidance = guidance.to(hidden_states.dtype) * 1000 451 | else: 452 | guidance = None 453 | temb = ( 454 | self.time_text_embed(timestep, pooled_projections) 455 | if guidance is None 456 | else self.time_text_embed(timestep, guidance, pooled_projections) 457 | ) 458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 459 | 460 | ids = torch.cat((txt_ids, img_ids), dim=0) 461 | image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor) 462 | 463 | for index_block, block in enumerate(self.transformer_blocks): 464 | if self.training and self.gradient_checkpointing: 465 | 466 | def create_custom_forward(module, return_dict=None): 467 | def custom_forward(*inputs): 468 | if return_dict is not None: 469 | return module(*inputs, return_dict=return_dict) 470 | else: 471 | return module(*inputs) 472 | 473 | return custom_forward 474 | 475 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 476 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 477 | create_custom_forward(block), 478 | hidden_states, 479 | encoder_hidden_states, 480 | temb, 481 | image_rotary_emb, 482 | joint_attention_kwargs, 483 | **ckpt_kwargs, 484 | ) 485 | 486 | else: 487 | encoder_hidden_states, hidden_states = block( 488 | hidden_states=hidden_states, 489 | encoder_hidden_states=encoder_hidden_states, 490 | temb=temb, 491 | image_rotary_emb=image_rotary_emb, 492 | joint_attention_kwargs=joint_attention_kwargs, 493 | ) 494 | 495 | # controlnet residual 496 | if controlnet_block_samples is not None: 497 | interval_control = len(self.transformer_blocks) / 
len(controlnet_block_samples) 498 | interval_control = int(np.ceil(interval_control)) 499 | # For Xlabs ControlNet. 500 | if controlnet_blocks_repeat: 501 | hidden_states = ( 502 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 503 | ) 504 | else: 505 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 506 | 507 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 508 | 509 | for index_block, block in enumerate(self.single_transformer_blocks): 510 | if self.training and self.gradient_checkpointing: 511 | 512 | def create_custom_forward(module, return_dict=None): 513 | def custom_forward(*inputs): 514 | if return_dict is not None: 515 | return module(*inputs, return_dict=return_dict) 516 | else: 517 | return module(*inputs) 518 | 519 | return custom_forward 520 | 521 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 522 | hidden_states = torch.utils.checkpoint.checkpoint( 523 | create_custom_forward(block), 524 | hidden_states, 525 | temb, 526 | image_rotary_emb, 527 | joint_attention_kwargs, 528 | **ckpt_kwargs, 529 | ) 530 | 531 | else: 532 | hidden_states = block( 533 | hidden_states=hidden_states, 534 | temb=temb, 535 | image_rotary_emb=image_rotary_emb, 536 | joint_attention_kwargs=joint_attention_kwargs, 537 | ) 538 | 539 | # controlnet residual 540 | if controlnet_single_block_samples is not None: 541 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 542 | interval_control = int(np.ceil(interval_control)) 543 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 544 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 545 | + controlnet_single_block_samples[index_block // interval_control] 546 | ) 547 | 548 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
549 | 550 | hidden_states = self.norm_out(hidden_states, temb) 551 | output = self.proj_out(hidden_states) 552 | 553 | if USE_PEFT_BACKEND: 554 | # remove `lora_scale` from each PEFT layer 555 | unscale_lora_layers(self, lora_scale) 556 | 557 | if not return_dict: 558 | return (output,) 559 | 560 | return Transformer2DModelOutput(sample=output) 561 | -------------------------------------------------------------------------------- /train_2k.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import os 5 | import random 6 | import shutil 7 | from pathlib import Path 8 | from safetensors.torch import load_file 9 | import torch 10 | import torch.utils.checkpoint 11 | import transformers 12 | from accelerate import Accelerator 13 | from accelerate.logging import get_logger 14 | from accelerate.utils import ProjectConfiguration, set_seed 15 | from peft import LoraConfig, set_peft_model_state_dict 16 | from peft.utils import get_peft_model_state_dict 17 | from torch.utils.data import Dataset 18 | from tqdm.auto import tqdm 19 | 20 | import diffusers 21 | from diffusers import ( 22 | AutoencoderKL, 23 | FlowMatchEulerDiscreteScheduler 24 | ) 25 | from diffusers.optimization import get_scheduler 26 | from diffusers.training_utils import ( 27 | cast_training_params, 28 | compute_density_for_timestep_sampling, 29 | compute_loss_weighting_for_sd3, 30 | free_memory, 31 | ) 32 | from diffusers.utils import ( 33 | check_min_version, 34 | convert_unet_state_dict_to_peft, 35 | is_wandb_available, 36 | ) 37 | from diffusers.utils.torch_utils import is_compiled_module 38 | from transformer_flux import FluxTransformer2DModel 39 | from pipeline_flux import FluxPipeline 40 | from attention_processor import FluxAttnProcessor2_0 41 | 42 | 43 | if is_wandb_available(): 44 | import wandb 45 | 46 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 47 | check_min_version("0.31.0.dev0") 48 | 49 | logger = get_logger(__name__) 50 | 51 | 52 | def parse_args(input_args=None): 53 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 54 | parser.add_argument( 55 | "--pretrained_model_name_or_path", 56 | type=str, 57 | default=None, 58 | required=True, 59 | help="Path to pretrained model or model identifier from huggingface.co/models.", 60 | ) 61 | parser.add_argument( 62 | "--revision", 63 | type=str, 64 | default=None, 65 | required=False, 66 | help="Revision of pretrained model identifier from huggingface.co/models.", 67 | ) 68 | parser.add_argument( 69 | "--variant", 70 | type=str, 71 | default=None, 72 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 73 | ) 74 | parser.add_argument( 75 | "--dataset_root", 76 | type=str, 77 | default=None, 78 | help=( 79 | "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," 80 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 81 | " or to a folder containing files that 🤗 Datasets can understand." 
82 | ), 83 | ) 84 | 85 | parser.add_argument( 86 | "--cache_dir", 87 | type=str, 88 | default=None, 89 | help="The directory where the downloaded models and datasets will be stored.", 90 | ) 91 | 92 | parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") 93 | 94 | parser.add_argument( 95 | "--real_prompt_ratio", 96 | type=float, 97 | default=0.5 98 | ) 99 | parser.add_argument( 100 | "--max_sequence_length", 101 | type=int, 102 | default=512, 103 | help="Maximum sequence length to use with with the T5 text encoder", 104 | ) 105 | parser.add_argument( 106 | "--rank", 107 | type=int, 108 | default=16, 109 | help=("The dimension of the LoRA update matrices."), 110 | ) 111 | parser.add_argument( 112 | "--output_dir", 113 | type=str, 114 | default="urae_2k", 115 | help="The output directory where the model predictions and checkpoints will be written.", 116 | ) 117 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 118 | parser.add_argument( 119 | "--proportional_attention", 120 | action='store_true', 121 | help="Dynamic attention scale with respect to the resolution", 122 | ) 123 | parser.add_argument( 124 | "--ntk_factor", 125 | type=float, 126 | default=1., 127 | help="NTK factor for RePE" 128 | ) 129 | parser.add_argument( 130 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 131 | ) 132 | parser.add_argument( 133 | "--max_train_steps", 134 | type=int, 135 | default=10000 136 | ) 137 | parser.add_argument( 138 | "--checkpointing_steps", 139 | type=int, 140 | default=1000, 141 | help=( 142 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 143 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 144 | " training using `--resume_from_checkpoint`." 145 | ), 146 | ) 147 | parser.add_argument( 148 | "--checkpoints_total_limit", 149 | type=int, 150 | default=None, 151 | help=("Max number of checkpoints to store."), 152 | ) 153 | parser.add_argument( 154 | "--resume_from_checkpoint", 155 | type=str, 156 | default=None, 157 | help=( 158 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 159 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 160 | ), 161 | ) 162 | parser.add_argument( 163 | "--gradient_accumulation_steps", 164 | type=int, 165 | default=1, 166 | help="Number of updates steps to accumulate before performing a backward/update pass.", 167 | ) 168 | parser.add_argument( 169 | "--gradient_checkpointing", 170 | action="store_true", 171 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 172 | ) 173 | parser.add_argument( 174 | "--learning_rate", 175 | type=float, 176 | default=1e-4, 177 | help="Initial learning rate (after the potential warmup period) to use.", 178 | ) 179 | 180 | parser.add_argument( 181 | "--guidance_scale", 182 | type=float, 183 | default=1, 184 | help="the FLUX.1 dev variant is a guidance distilled model", 185 | ) 186 | 187 | parser.add_argument( 188 | "--scale_lr", 189 | action="store_true", 190 | default=False, 191 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 192 | ) 193 | parser.add_argument( 194 | "--lr_scheduler", 195 | type=str, 196 | default="constant", 197 | help=( 198 | 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 199 | ' "constant", "constant_with_warmup"]' 200 | ), 201 | ) 202 | parser.add_argument( 203 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 204 | ) 205 | parser.add_argument( 206 | "--lr_num_cycles", 207 | type=int, 208 | default=1, 209 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 210 | ) 211 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 212 | parser.add_argument( 213 | "--dataloader_num_workers", 214 | type=int, 215 | default=0, 216 | help=( 217 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 218 | ), 219 | ) 220 | parser.add_argument( 221 | "--weighting_scheme", 222 | type=str, 223 | default="none", 224 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], 225 | help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), 226 | ) 227 | parser.add_argument( 228 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." 229 | ) 230 | parser.add_argument( 231 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." 232 | ) 233 | parser.add_argument( 234 | "--mode_scale", 235 | type=float, 236 | default=1.29, 237 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", 238 | ) 239 | parser.add_argument( 240 | "--optimizer", 241 | type=str, 242 | default="AdamW", 243 | help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), 244 | ) 245 | 246 | parser.add_argument( 247 | "--use_8bit_adam", 248 | action="store_true", 249 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", 250 | ) 251 | 252 | parser.add_argument( 253 | "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." 254 | ) 255 | parser.add_argument( 256 | "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." 257 | ) 258 | parser.add_argument( 259 | "--prodigy_beta3", 260 | type=float, 261 | default=None, 262 | help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " 263 | "uses the value of square root of beta2. Ignored if optimizer is adamW", 264 | ) 265 | parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") 266 | parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") 267 | 268 | parser.add_argument( 269 | "--adam_epsilon", 270 | type=float, 271 | default=1e-08, 272 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.", 273 | ) 274 | 275 | parser.add_argument( 276 | "--prodigy_use_bias_correction", 277 | type=bool, 278 | default=True, 279 | help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", 280 | ) 281 | parser.add_argument( 282 | "--prodigy_safeguard_warmup", 283 | type=bool, 284 | default=True, 285 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" 286 | "Ignored if optimizer is adamW", 287 | ) 288 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 289 | parser.add_argument( 290 | "--logging_dir", 291 | type=str, 292 | default="logs", 293 | help=( 294 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 295 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 296 | ), 297 | ) 298 | parser.add_argument( 299 | "--allow_tf32", 300 | action="store_true", 301 | help=( 302 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 303 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 304 | ), 305 | ) 306 | parser.add_argument( 307 | "--report_to", 308 | type=str, 309 | default="tensorboard", 310 | help=( 311 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 312 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 313 | ), 314 | ) 315 | parser.add_argument( 316 | "--mixed_precision", 317 | type=str, 318 | default=None, 319 | choices=["no", "fp16", "bf16"], 320 | help=( 321 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 322 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 323 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 324 | ), 325 | ) 326 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 327 | 328 | if input_args is not None: 329 | args = parser.parse_args(input_args) 330 | else: 331 | args = parser.parse_args() 332 | 333 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 334 | if env_local_rank != -1 and env_local_rank != args.local_rank: 335 | args.local_rank = env_local_rank 336 | 337 | return args 338 | 339 | 340 | class CustomImageDataset(Dataset): 341 | def __init__(self, img_dir, real_prompt_ratio=0.5): 342 | self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i] 343 | self.real_prompt_ratio = real_prompt_ratio 344 | 345 | def __len__(self): 346 | return len(self.images) 347 | 348 | def __getitem__(self, idx): 349 | try: 350 | batch = {} 351 | f = load_file(self.images[idx][:self.images[idx].rfind('.')] + '_latent_code.safetensors') 352 | batch['latent_codes_mean'] = f['mean'] 353 | batch['latent_codes_std'] = f['std'] 354 | prompt_embeds_path = self.images[idx][:self.images[idx].rfind('.')] + '_prompt_embed.safetensors' 355 | generated_prompt_embeds_path = self.images[idx][:self.images[idx].rfind('.')] + '_generated_prompt_embed.safetensors' 356 | if (not os.path.exists(generated_prompt_embeds_path) or random.random() < self.real_prompt_ratio) and os.path.exists(prompt_embeds_path): 357 | f = load_file(prompt_embeds_path) 358 | else: 359 | f = load_file(generated_prompt_embeds_path) 360 | batch['prompt_embeds_t5'] = f['caption_feature_t5'] 361 | batch['prompt_embeds_clip'] = f['caption_feature_clip'] 362 | return batch 363 | except Exception as e: 364 | print(e) 365 | return self.__getitem__(random.randint(0, len(self.images) - 1)) 366 | 367 | 368 | def main(args): 369 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16": 370 | # due to pytorch#99272, MPS does not yet support bfloat16. 371 | raise ValueError( 372 | "Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead." 373 | ) 374 | 375 | logging_dir = Path(args.output_dir, args.logging_dir) 376 | 377 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 378 | accelerator = Accelerator( 379 | gradient_accumulation_steps=args.gradient_accumulation_steps, 380 | mixed_precision=args.mixed_precision, 381 | log_with=args.report_to, 382 | project_config=accelerator_project_config 383 | ) 384 | 385 | # Disable AMP for MPS. 386 | if torch.backends.mps.is_available(): 387 | accelerator.native_amp = False 388 | 389 | if args.report_to == "wandb": 390 | if not is_wandb_available(): 391 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 392 | 393 | # Make one log on every process with the configuration for debugging. 394 | logging.basicConfig( 395 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 396 | datefmt="%m/%d/%Y %H:%M:%S", 397 | level=logging.INFO, 398 | ) 399 | logger.info(accelerator.state, main_process_only=False) 400 | if accelerator.is_local_main_process: 401 | transformers.utils.logging.set_verbosity_warning() 402 | diffusers.utils.logging.set_verbosity_info() 403 | else: 404 | transformers.utils.logging.set_verbosity_error() 405 | diffusers.utils.logging.set_verbosity_error() 406 | 407 | # If passed along, set the training seed now. 408 | if args.seed is not None: 409 | set_seed(args.seed) 410 | 411 | # Handle the repository creation 412 | if accelerator.is_main_process: 413 | if args.output_dir is not None: 414 | os.makedirs(args.output_dir, exist_ok=True) 415 | 416 | # Load scheduler and models 417 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 418 | args.pretrained_model_name_or_path, subfolder="scheduler" 419 | ) 420 | noise_scheduler_copy = copy.deepcopy(noise_scheduler) 421 | vae = AutoencoderKL.from_pretrained( 422 | args.pretrained_model_name_or_path, 423 | subfolder="vae", 424 | revision=args.revision, 425 | variant=args.variant, 426 | ) 427 | transformer = FluxTransformer2DModel.from_pretrained( 428 | args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant 429 | ) 430 | 431 | attn_processors = {} 432 | for k in transformer.attn_processors.keys(): 433 | attn_processors[k] = FluxAttnProcessor2_0() 434 | transformer.set_attn_processor(attn_processors) 435 | 436 | # We only train the additional adapter LoRA layers 437 | transformer.requires_grad_(False) 438 | vae.requires_grad_(False) 439 | 440 | # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision 441 | # as these weights are only used for inference, keeping weights in full precision is not required. 442 | weight_dtype = torch.float32 443 | if accelerator.mixed_precision == "fp16": 444 | weight_dtype = torch.float16 445 | elif accelerator.mixed_precision == "bf16": 446 | weight_dtype = torch.bfloat16 447 | 448 | if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: 449 | # due to pytorch#99272, MPS does not yet support bfloat16. 450 | raise ValueError( 451 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
452 | ) 453 | 454 | vae.to(accelerator.device, dtype=weight_dtype) 455 | transformer.to(accelerator.device, dtype=weight_dtype) 456 | 457 | if args.gradient_checkpointing: 458 | transformer.enable_gradient_checkpointing() 459 | 460 | # now we will add new LoRA weights to the attention layers 461 | transformer_lora_config = LoraConfig( 462 | r=args.rank, 463 | lora_alpha=args.rank, 464 | init_lora_weights="gaussian", 465 | target_modules=["to_k", "to_q", "to_v", "to_out.0"], 466 | ) 467 | transformer.add_adapter(transformer_lora_config) 468 | 469 | def unwrap_model(model): 470 | model = accelerator.unwrap_model(model) 471 | model = model._orig_mod if is_compiled_module(model) else model 472 | return model 473 | 474 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 475 | def save_model_hook(models, weights, output_dir): 476 | if accelerator.is_main_process: 477 | transformer_lora_layers_to_save = None 478 | text_encoder_one_lora_layers_to_save = None 479 | 480 | for model in models: 481 | if isinstance(model, type(unwrap_model(transformer))): 482 | transformer_lora_layers_to_save = get_peft_model_state_dict(model) 483 | else: 484 | raise ValueError(f"unexpected save model: {model.__class__}") 485 | 486 | # make sure to pop weight so that corresponding model is not saved again 487 | weights.pop() 488 | 489 | FluxPipeline.save_lora_weights( 490 | output_dir, 491 | transformer_lora_layers=transformer_lora_layers_to_save, 492 | text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, 493 | ) 494 | 495 | def load_model_hook(models, input_dir): 496 | transformer_ = None 497 | 498 | while len(models) > 0: 499 | model = models.pop() 500 | 501 | if isinstance(model, type(unwrap_model(transformer))): 502 | transformer_ = model 503 | else: 504 | raise ValueError(f"unexpected save model: {model.__class__}") 505 | 506 | lora_state_dict = FluxPipeline.lora_state_dict(input_dir) 507 | 508 | transformer_state_dict = { 509 | f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") 510 | } 511 | transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) 512 | incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") 513 | if incompatible_keys is not None: 514 | # check only for unexpected keys 515 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) 516 | if unexpected_keys: 517 | logger.warning( 518 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 519 | f" {unexpected_keys}. " 520 | ) 521 | # Make sure the trainable params are in float32. This is again needed since the base models 522 | # are in `weight_dtype`. 
More details: 523 | # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 524 | if args.mixed_precision == "fp16": 525 | models = [transformer_] 526 | # only upcast trainable parameters (LoRA) into fp32 527 | cast_training_params(models) 528 | 529 | accelerator.register_save_state_pre_hook(save_model_hook) 530 | accelerator.register_load_state_pre_hook(load_model_hook) 531 | 532 | # Enable TF32 for faster training on Ampere GPUs, 533 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 534 | if args.allow_tf32 and torch.cuda.is_available(): 535 | torch.backends.cuda.matmul.allow_tf32 = True 536 | 537 | if args.scale_lr: 538 | args.learning_rate = ( 539 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 540 | ) 541 | 542 | # Make sure the trainable params are in float32. 543 | if args.mixed_precision == "fp16": 544 | models = [transformer] 545 | # only upcast trainable parameters (LoRA) into fp32 546 | cast_training_params(models, dtype=torch.float32) 547 | 548 | transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) 549 | 550 | # Optimization parameters 551 | transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} 552 | 553 | params_to_optimize = [transformer_parameters_with_lr] 554 | 555 | # Optimizer creation 556 | if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): 557 | logger.warning( 558 | f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." 559 | "Defaulting to adamW" 560 | ) 561 | args.optimizer = "adamw" 562 | 563 | if args.use_8bit_adam and not args.optimizer.lower() == "adamw": 564 | logger.warning( 565 | f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " 566 | f"set to {args.optimizer.lower()}" 567 | ) 568 | 569 | if args.optimizer.lower() == "adamw": 570 | if args.use_8bit_adam: 571 | try: 572 | import bitsandbytes as bnb 573 | except ImportError: 574 | raise ImportError( 575 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 576 | ) 577 | 578 | optimizer_class = bnb.optim.AdamW8bit 579 | else: 580 | optimizer_class = torch.optim.AdamW 581 | 582 | optimizer = optimizer_class( 583 | params_to_optimize, 584 | betas=(args.adam_beta1, args.adam_beta2), 585 | weight_decay=args.adam_weight_decay, 586 | eps=args.adam_epsilon, 587 | ) 588 | 589 | if args.optimizer.lower() == "prodigy": 590 | try: 591 | import prodigyopt 592 | except ImportError: 593 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") 594 | 595 | optimizer_class = prodigyopt.Prodigy 596 | 597 | if args.learning_rate <= 0.1: 598 | logger.warning( 599 | "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" 600 | ) 601 | 602 | optimizer = optimizer_class( 603 | params_to_optimize, 604 | lr=args.learning_rate, 605 | betas=(args.adam_beta1, args.adam_beta2), 606 | beta3=args.prodigy_beta3, 607 | weight_decay=args.adam_weight_decay, 608 | eps=args.adam_epsilon, 609 | decouple=args.prodigy_decouple, 610 | use_bias_correction=args.prodigy_use_bias_correction, 611 | safeguard_warmup=args.prodigy_safeguard_warmup, 612 | ) 613 | 614 | # Dataset and DataLoaders creation: 615 | train_dataloader = torch.utils.data.DataLoader( 616 | CustomImageDataset( 617 | args.dataset_root, 618 | real_prompt_ratio=args.real_prompt_ratio 619 | ), 620 | batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, shuffle=True 621 | ) 622 | 623 | vae_config_shift_factor = vae.config.shift_factor 624 | vae_config_scaling_factor = vae.config.scaling_factor 625 | vae_config_block_out_channels = vae.config.block_out_channels 626 | 627 | del vae 628 | free_memory() 629 | 630 | # Scheduler and math around the number of training steps. 631 | lr_scheduler = get_scheduler( 632 | args.lr_scheduler, 633 | optimizer=optimizer, 634 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 635 | num_training_steps=args.max_train_steps * accelerator.num_processes, 636 | num_cycles=args.lr_num_cycles, 637 | power=args.lr_power, 638 | ) 639 | 640 | # Prepare everything with our `accelerator`. 641 | guidance_embeds = transformer.config.guidance_embeds 642 | transformer, train_dataloader, optimizer, lr_scheduler = accelerator.prepare( 643 | transformer, train_dataloader, optimizer, lr_scheduler 644 | ) 645 | 646 | # We need to initialize the trackers we use, and also store our configuration. 647 | # The trackers initializes automatically on the main process. 648 | if accelerator.is_main_process: 649 | tracker_name = "flux-hr" 650 | accelerator.init_trackers(tracker_name, config=vars(args)) 651 | 652 | # Train! 653 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 654 | 655 | logger.info("***** Running training *****") 656 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 657 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 658 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 659 | logger.info(f" Total optimization steps = {args.max_train_steps}") 660 | global_step = 0 661 | 662 | # Potentially load in the weights and states from a previous save 663 | if args.resume_from_checkpoint: 664 | if args.resume_from_checkpoint != "latest": 665 | path = os.path.basename(args.resume_from_checkpoint) 666 | else: 667 | # Get the mos recent checkpoint 668 | dirs = os.listdir(args.output_dir) 669 | dirs = [d for d in dirs if d.startswith("checkpoint")] 670 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 671 | path = dirs[-1] if len(dirs) > 0 else None 672 | 673 | if path is None: 674 | accelerator.print( 675 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
676 | ) 677 | args.resume_from_checkpoint = None 678 | initial_global_step = 0 679 | else: 680 | accelerator.print(f"Resuming from checkpoint {path}") 681 | accelerator.load_state(os.path.join(args.output_dir, path)) 682 | global_step = int(path.split("-")[1]) 683 | 684 | initial_global_step = global_step 685 | 686 | else: 687 | initial_global_step = 0 688 | 689 | progress_bar = tqdm( 690 | range(0, args.max_train_steps), 691 | initial=initial_global_step, 692 | desc="Steps", 693 | # Only show the progress bar once on each machine. 694 | disable=not accelerator.is_local_main_process, 695 | ) 696 | 697 | def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): 698 | sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) 699 | schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) 700 | timesteps = timesteps.to(accelerator.device) 701 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 702 | 703 | sigma = sigmas[step_indices].flatten() 704 | while len(sigma.shape) < n_dim: 705 | sigma = sigma.unsqueeze(-1) 706 | return sigma 707 | 708 | transformer.train() 709 | loader_iter = iter(train_dataloader) 710 | while True: 711 | try: 712 | batch = loader_iter.__next__() 713 | except StopIteration: 714 | loader_iter = iter(train_dataloader) 715 | batch = loader_iter.__next__() 716 | models_to_accumulate = [transformer] 717 | 718 | with accelerator.accumulate(models_to_accumulate): 719 | 720 | prompt_embeds = batch['prompt_embeds_t5'].to(dtype=weight_dtype) 721 | pooled_prompt_embeds = batch['prompt_embeds_clip'].to(dtype=weight_dtype) 722 | text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=accelerator.device, dtype=weight_dtype) 723 | 724 | with torch.no_grad(): 725 | mean = batch['latent_codes_mean'].to(dtype=weight_dtype) 726 | std = batch['latent_codes_std'].to(dtype=weight_dtype) 727 | sample = torch.randn_like(mean) 728 | model_input = mean + std * sample 729 | model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor 730 | 731 | vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) 732 | 733 | latent_image_ids = FluxPipeline._prepare_latent_image_ids( 734 | model_input.shape[0], 735 | model_input.shape[2], 736 | model_input.shape[3], 737 | accelerator.device, 738 | weight_dtype, 739 | ) 740 | # Sample noise that we'll add to the latents 741 | noise = torch.randn_like(model_input) 742 | bsz = model_input.shape[0] 743 | 744 | # Sample a random timestep for each image 745 | # for weighting schemes where we sample timesteps non-uniformly 746 | u = compute_density_for_timestep_sampling( 747 | weighting_scheme=args.weighting_scheme, 748 | batch_size=bsz, 749 | logit_mean=args.logit_mean, 750 | logit_std=args.logit_std, 751 | mode_scale=args.mode_scale, 752 | ) 753 | indices = (u * noise_scheduler.config.num_train_timesteps).long() 754 | timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) 755 | 756 | # Add noise according to flow matching. 
757 | # zt = (1 - texp) * x + texp * z1 758 | sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) 759 | noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise 760 | 761 | packed_noisy_model_input = FluxPipeline._pack_latents( 762 | noisy_model_input, 763 | batch_size=model_input.shape[0], 764 | num_channels_latents=model_input.shape[1], 765 | height=model_input.shape[2], 766 | width=model_input.shape[3], 767 | ) 768 | 769 | # handle guidance 770 | if guidance_embeds: 771 | guidance = torch.tensor([args.guidance_scale], device=accelerator.device) 772 | guidance = guidance.expand(model_input.shape[0]) 773 | else: 774 | guidance = None 775 | 776 | model_pred = transformer( 777 | hidden_states=packed_noisy_model_input, 778 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) 779 | timestep=timesteps / 1000, 780 | guidance=guidance, 781 | pooled_projections=pooled_prompt_embeds, 782 | encoder_hidden_states=prompt_embeds, 783 | txt_ids=text_ids, 784 | img_ids=latent_image_ids, 785 | return_dict=False, 786 | ntk_factor=args.ntk_factor, 787 | joint_attention_kwargs={'proportional_attention': args.proportional_attention} 788 | )[0] 789 | model_pred = FluxPipeline._unpack_latents( 790 | model_pred, 791 | height=int(model_input.shape[2] * vae_scale_factor / 2), 792 | width=int(model_input.shape[3] * vae_scale_factor / 2), 793 | vae_scale_factor=vae_scale_factor, 794 | ) 795 | 796 | # these weighting schemes use a uniform timestep sampling 797 | # and instead post-weight the loss 798 | weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) 799 | 800 | # flow matching loss 801 | target = noise - model_input 802 | 803 | # Compute regular loss.
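# The objective is the flow-matching velocity loss: with x_t = (1 - sigma) * x_0 + sigma * noise constructed above,
# the target velocity is d x_t / d sigma = noise - x_0 (the `target` tensor), and the loss below is the
# `weighting`-scaled MSE between the transformer's prediction and that target (uniform weighting under the
# default "none" scheme).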
804 | loss = (weighting.float() * (model_pred.float() - target.float()) ** 2).mean() 805 | 806 | accelerator.backward(loss) 807 | if accelerator.sync_gradients: 808 | params_to_clip = ( 809 | transformer.parameters() 810 | ) 811 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 812 | 813 | optimizer.step() 814 | lr_scheduler.step() 815 | optimizer.zero_grad() 816 | 817 | # Checks if the accelerator has performed an optimization step behind the scenes 818 | if accelerator.sync_gradients: 819 | progress_bar.update(1) 820 | global_step += 1 821 | 822 | if global_step % args.checkpointing_steps == 0: 823 | if accelerator.is_main_process: 824 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 825 | if args.checkpoints_total_limit is not None: 826 | checkpoints = os.listdir(args.output_dir) 827 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 828 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 829 | 830 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 831 | if len(checkpoints) >= args.checkpoints_total_limit: 832 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 833 | removing_checkpoints = checkpoints[0:num_to_remove] 834 | 835 | logger.info( 836 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 837 | ) 838 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 839 | 840 | for removing_checkpoint in removing_checkpoints: 841 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 842 | shutil.rmtree(removing_checkpoint) 843 | 844 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 845 | logger.info(f"Saving state to {save_path}...") 846 | accelerator.save_state(save_path) 847 | 848 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 849 | progress_bar.set_postfix(**logs) 850 | accelerator.log(logs, step=global_step) 851 | 852 | if global_step >= args.max_train_steps: 853 | break 854 | 855 | accelerator.wait_for_everyone() 856 | accelerator.end_training() 857 | 858 | 859 | if __name__ == "__main__": 860 | args = parse_args() 861 | main(args) -------------------------------------------------------------------------------- /pipeline_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
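# This file is the repository's adapted copy of diffusers' FluxPipeline. It mirrors the stock pipeline, but
# `__call__` (further below) exposes two extra arguments for ultra-resolution sampling: `ntk_factor`
# (RoPE base scaling, default 10.0) and `proportional_attention` (resolution-dependent attention scaling,
# default True); both map onto the extra `ntk_factor` / `joint_attention_kwargs['proportional_attention']`
# inputs of the customized FluxTransformer2DModel in transformer_flux.py.
#
# A minimal usage sketch, assuming a pipeline object `pipe` has already been assembled with this repo's
# transformer (the 2048x2048 resolution and prompt are only examples):
#
#     image = pipe(
#         "a photo of an astronaut riding a horse",
#         height=2048, width=2048,
#         num_inference_steps=28, guidance_scale=3.5,
#         ntk_factor=10.0, proportional_attention=True,
#     ).images[0]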
14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 21 | 22 | from diffusers.image_processor import VaeImageProcessor 23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin 24 | from diffusers.models.autoencoders import AutoencoderKL 25 | from diffusers.models.transformers import FluxTransformer2DModel 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import ( 28 | USE_PEFT_BACKEND, 29 | is_torch_xla_available, 30 | logging, 31 | replace_example_docstring, 32 | scale_lora_layers, 33 | unscale_lora_layers, 34 | ) 35 | from diffusers.utils.torch_utils import randn_tensor 36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 38 | 39 | 40 | if is_torch_xla_available(): 41 | import torch_xla.core.xla_model as xm 42 | 43 | XLA_AVAILABLE = True 44 | else: 45 | XLA_AVAILABLE = False 46 | 47 | 48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 49 | 50 | EXAMPLE_DOC_STRING = """ 51 | Examples: 52 | ```py 53 | >>> import torch 54 | >>> from diffusers import FluxPipeline 55 | 56 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 57 | >>> pipe.to("cuda") 58 | >>> prompt = "A cat holding a sign that says hello world" 59 | >>> # Depending on the variant being used, the pipeline call will slightly vary. 60 | >>> # Refer to the pipeline documentation for more details. 61 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] 62 | >>> image.save("flux.png") 63 | ``` 64 | """ 65 | 66 | 67 | def calculate_shift( 68 | image_seq_len, 69 | base_seq_len: int = 256, 70 | max_seq_len: int = 4096, 71 | base_shift: float = 0.5, 72 | max_shift: float = 1.16, 73 | ): 74 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 75 | b = base_shift - m * base_seq_len 76 | mu = image_seq_len * m + b 77 | return mu 78 | 79 | 80 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 81 | def retrieve_timesteps( 82 | scheduler, 83 | num_inference_steps: Optional[int] = None, 84 | device: Optional[Union[str, torch.device]] = None, 85 | timesteps: Optional[List[int]] = None, 86 | sigmas: Optional[List[float]] = None, 87 | **kwargs, 88 | ): 89 | r""" 90 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 91 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 92 | 93 | Args: 94 | scheduler (`SchedulerMixin`): 95 | The scheduler to get timesteps from. 96 | num_inference_steps (`int`): 97 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 98 | must be `None`. 99 | device (`str` or `torch.device`, *optional*): 100 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 101 | timesteps (`List[int]`, *optional*): 102 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 103 | `num_inference_steps` and `sigmas` must be `None`. 104 | sigmas (`List[float]`, *optional*): 105 | Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed, 106 | `num_inference_steps` and `timesteps` must be `None`. 107 | 108 | Returns: 109 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 110 | second element is the number of inference steps. 111 | """ 112 | if timesteps is not None and sigmas is not None: 113 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 114 | if timesteps is not None: 115 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 116 | if not accepts_timesteps: 117 | raise ValueError( 118 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 119 | f" timestep schedules. Please check whether you are using the correct scheduler." 120 | ) 121 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 122 | timesteps = scheduler.timesteps 123 | num_inference_steps = len(timesteps) 124 | elif sigmas is not None: 125 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 126 | if not accept_sigmas: 127 | raise ValueError( 128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 129 | f" sigmas schedules. Please check whether you are using the correct scheduler." 130 | ) 131 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 132 | timesteps = scheduler.timesteps 133 | num_inference_steps = len(timesteps) 134 | else: 135 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | return timesteps, num_inference_steps 138 | 139 | 140 | class FluxPipeline( 141 | DiffusionPipeline, 142 | FluxLoraLoaderMixin, 143 | FromSingleFileMixin, 144 | TextualInversionLoaderMixin, 145 | ): 146 | r""" 147 | The Flux pipeline for text-to-image generation. 148 | 149 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 150 | 151 | Args: 152 | transformer ([`FluxTransformer2DModel`]): 153 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 154 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 155 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 156 | vae ([`AutoencoderKL`]): 157 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 158 | text_encoder ([`CLIPTextModel`]): 159 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 160 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 161 | text_encoder_2 ([`T5EncoderModel`]): 162 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 163 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 164 | tokenizer (`CLIPTokenizer`): 165 | Tokenizer of class 166 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 167 | tokenizer_2 (`T5TokenizerFast`): 168 | Second Tokenizer of class 169 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
170 | """ 171 | 172 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 173 | _optional_components = [] 174 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 175 | 176 | def __init__( 177 | self, 178 | scheduler: FlowMatchEulerDiscreteScheduler, 179 | vae: AutoencoderKL, 180 | text_encoder: CLIPTextModel, 181 | tokenizer: CLIPTokenizer, 182 | text_encoder_2: T5EncoderModel, 183 | tokenizer_2: T5TokenizerFast, 184 | transformer: FluxTransformer2DModel, 185 | ): 186 | super().__init__() 187 | 188 | self.register_modules( 189 | vae=vae, 190 | text_encoder=text_encoder, 191 | text_encoder_2=text_encoder_2, 192 | tokenizer=tokenizer, 193 | tokenizer_2=tokenizer_2, 194 | transformer=transformer, 195 | scheduler=scheduler, 196 | ) 197 | self.vae_scale_factor = ( 198 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 199 | ) 200 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 201 | self.tokenizer_max_length = ( 202 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 203 | ) 204 | self.default_sample_size = 64 205 | 206 | def _get_t5_prompt_embeds( 207 | self, 208 | prompt: Union[str, List[str]] = None, 209 | num_images_per_prompt: int = 1, 210 | max_sequence_length: int = 512, 211 | device: Optional[torch.device] = None, 212 | dtype: Optional[torch.dtype] = None, 213 | ): 214 | device = device or self._execution_device 215 | dtype = dtype or self.text_encoder.dtype 216 | 217 | prompt = [prompt] if isinstance(prompt, str) else prompt 218 | batch_size = len(prompt) 219 | 220 | if isinstance(self, TextualInversionLoaderMixin): 221 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) 222 | 223 | text_inputs = self.tokenizer_2( 224 | prompt, 225 | padding="max_length", 226 | max_length=max_sequence_length, 227 | truncation=True, 228 | return_length=False, 229 | return_overflowing_tokens=False, 230 | return_tensors="pt", 231 | ) 232 | text_input_ids = text_inputs.input_ids 233 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 234 | 235 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 236 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 237 | logger.warning( 238 | "The following part of your input was truncated because `max_sequence_length` is set to " 239 | f" {max_sequence_length} tokens: {removed_text}" 240 | ) 241 | 242 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] 243 | 244 | dtype = self.text_encoder_2.dtype 245 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 246 | 247 | _, seq_len, _ = prompt_embeds.shape 248 | 249 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 250 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 251 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 252 | 253 | return prompt_embeds 254 | 255 | def _get_clip_prompt_embeds( 256 | self, 257 | prompt: Union[str, List[str]], 258 | num_images_per_prompt: int = 1, 259 | device: Optional[torch.device] = None, 260 | ): 261 | device = device or self._execution_device 262 | 263 | prompt = [prompt] if isinstance(prompt, str) else prompt 264 | batch_size = len(prompt) 265 | 266 | if isinstance(self, TextualInversionLoaderMixin): 
267 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 268 | 269 | text_inputs = self.tokenizer( 270 | prompt, 271 | padding="max_length", 272 | max_length=self.tokenizer_max_length, 273 | truncation=True, 274 | return_overflowing_tokens=False, 275 | return_length=False, 276 | return_tensors="pt", 277 | ) 278 | 279 | text_input_ids = text_inputs.input_ids 280 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 281 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 282 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 283 | logger.warning( 284 | "The following part of your input was truncated because CLIP can only handle sequences up to" 285 | f" {self.tokenizer_max_length} tokens: {removed_text}" 286 | ) 287 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 288 | 289 | # Use pooled output of CLIPTextModel 290 | prompt_embeds = prompt_embeds.pooler_output 291 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 292 | 293 | # duplicate text embeddings for each generation per prompt, using mps friendly method 294 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 295 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 296 | 297 | return prompt_embeds 298 | 299 | def encode_prompt( 300 | self, 301 | prompt: Union[str, List[str]], 302 | prompt_2: Union[str, List[str]], 303 | device: Optional[torch.device] = None, 304 | num_images_per_prompt: int = 1, 305 | prompt_embeds: Optional[torch.FloatTensor] = None, 306 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 307 | max_sequence_length: int = 512, 308 | lora_scale: Optional[float] = None, 309 | ): 310 | r""" 311 | 312 | Args: 313 | prompt (`str` or `List[str]`, *optional*): 314 | prompt to be encoded 315 | prompt_2 (`str` or `List[str]`, *optional*): 316 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 317 | used in all text-encoders 318 | device: (`torch.device`): 319 | torch device 320 | num_images_per_prompt (`int`): 321 | number of images that should be generated per prompt 322 | prompt_embeds (`torch.FloatTensor`, *optional*): 323 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 324 | provided, text embeddings will be generated from `prompt` input argument. 325 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 326 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 327 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 328 | lora_scale (`float`, *optional*): 329 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
330 | """ 331 | device = device or self._execution_device 332 | 333 | # set lora scale so that monkey patched LoRA 334 | # function of text encoder can correctly access it 335 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 336 | self._lora_scale = lora_scale 337 | 338 | # dynamically adjust the LoRA scale 339 | if self.text_encoder is not None and USE_PEFT_BACKEND: 340 | scale_lora_layers(self.text_encoder, lora_scale) 341 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 342 | scale_lora_layers(self.text_encoder_2, lora_scale) 343 | 344 | prompt = [prompt] if isinstance(prompt, str) else prompt 345 | 346 | if prompt_embeds is None: 347 | prompt_2 = prompt_2 or prompt 348 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 349 | 350 | # We only use the pooled prompt output from the CLIPTextModel 351 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 352 | prompt=prompt, 353 | device=device, 354 | num_images_per_prompt=num_images_per_prompt, 355 | ) 356 | prompt_embeds = self._get_t5_prompt_embeds( 357 | prompt=prompt_2, 358 | num_images_per_prompt=num_images_per_prompt, 359 | max_sequence_length=max_sequence_length, 360 | device=device, 361 | ) 362 | 363 | if self.text_encoder is not None: 364 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 365 | # Retrieve the original scale by scaling back the LoRA layers 366 | unscale_lora_layers(self.text_encoder, lora_scale) 367 | 368 | if self.text_encoder_2 is not None: 369 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 370 | # Retrieve the original scale by scaling back the LoRA layers 371 | unscale_lora_layers(self.text_encoder_2, lora_scale) 372 | 373 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 374 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 375 | 376 | return prompt_embeds, pooled_prompt_embeds, text_ids 377 | 378 | def check_inputs( 379 | self, 380 | prompt, 381 | prompt_2, 382 | height, 383 | width, 384 | prompt_embeds=None, 385 | pooled_prompt_embeds=None, 386 | callback_on_step_end_tensor_inputs=None, 387 | max_sequence_length=None, 388 | ): 389 | if height % 8 != 0 or width % 8 != 0: 390 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 391 | 392 | if callback_on_step_end_tensor_inputs is not None and not all( 393 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 394 | ): 395 | raise ValueError( 396 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 397 | ) 398 | 399 | if prompt is not None and prompt_embeds is not None: 400 | raise ValueError( 401 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 402 | " only forward one of the two." 403 | ) 404 | elif prompt_2 is not None and prompt_embeds is not None: 405 | raise ValueError( 406 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 407 | " only forward one of the two." 408 | ) 409 | elif prompt is None and prompt_embeds is None: 410 | raise ValueError( 411 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
412 | ) 413 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 414 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 415 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 416 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 417 | 418 | if prompt_embeds is not None and pooled_prompt_embeds is None: 419 | raise ValueError( 420 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 421 | ) 422 | 423 | if max_sequence_length is not None and max_sequence_length > 512: 424 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 425 | 426 | @staticmethod 427 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 428 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 429 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 430 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 431 | 432 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 433 | 434 | latent_image_ids = latent_image_ids.reshape( 435 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 436 | ) 437 | 438 | return latent_image_ids.to(device=device, dtype=dtype) 439 | 440 | @staticmethod 441 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 442 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 443 | latents = latents.permute(0, 2, 4, 1, 3, 5) 444 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 445 | 446 | return latents 447 | 448 | @staticmethod 449 | def _unpack_latents(latents, height, width, vae_scale_factor): 450 | batch_size, num_patches, channels = latents.shape 451 | 452 | height = height // vae_scale_factor 453 | width = width // vae_scale_factor 454 | 455 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 456 | latents = latents.permute(0, 3, 1, 4, 2, 5) 457 | 458 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 459 | 460 | return latents 461 | 462 | def enable_vae_slicing(self): 463 | r""" 464 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 465 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 466 | """ 467 | self.vae.enable_slicing() 468 | 469 | def disable_vae_slicing(self): 470 | r""" 471 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 472 | computing decoding in one step. 473 | """ 474 | self.vae.disable_slicing() 475 | 476 | def enable_vae_tiling(self): 477 | r""" 478 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 479 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 480 | processing larger images. 481 | """ 482 | self.vae.enable_tiling() 483 | 484 | def disable_vae_tiling(self): 485 | r""" 486 | Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to 487 | computing decoding in one step. 488 | """ 489 | self.vae.disable_tiling() 490 | 491 | def prepare_latents( 492 | self, 493 | batch_size, 494 | num_channels_latents, 495 | height, 496 | width, 497 | dtype, 498 | device, 499 | generator, 500 | latents=None, 501 | ): 502 | height = 2 * (int(height) // self.vae_scale_factor) 503 | width = 2 * (int(width) // self.vae_scale_factor) 504 | 505 | shape = (batch_size, num_channels_latents, height, width) 506 | 507 | if latents is not None: 508 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 509 | return latents.to(device=device, dtype=dtype), latent_image_ids 510 | 511 | if isinstance(generator, list) and len(generator) != batch_size: 512 | raise ValueError( 513 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 514 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 515 | ) 516 | 517 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 518 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) 519 | 520 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 521 | 522 | return latents, latent_image_ids 523 | 524 | @property 525 | def guidance_scale(self): 526 | return self._guidance_scale 527 | 528 | @property 529 | def joint_attention_kwargs(self): 530 | return self._joint_attention_kwargs 531 | 532 | @property 533 | def num_timesteps(self): 534 | return self._num_timesteps 535 | 536 | @property 537 | def interrupt(self): 538 | return self._interrupt 539 | 540 | @torch.no_grad() 541 | @replace_example_docstring(EXAMPLE_DOC_STRING) 542 | def __call__( 543 | self, 544 | prompt: Union[str, List[str]] = None, 545 | prompt_2: Optional[Union[str, List[str]]] = None, 546 | height: Optional[int] = None, 547 | width: Optional[int] = None, 548 | num_inference_steps: int = 28, 549 | timesteps: List[int] = None, 550 | guidance_scale: float = 3.5, 551 | num_images_per_prompt: Optional[int] = 1, 552 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 553 | latents: Optional[torch.FloatTensor] = None, 554 | prompt_embeds: Optional[torch.FloatTensor] = None, 555 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 556 | output_type: Optional[str] = "pil", 557 | return_dict: bool = True, 558 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 559 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 560 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 561 | max_sequence_length: int = 512, 562 | ntk_factor: float = 10.0, 563 | proportional_attention: bool = True 564 | ): 565 | r""" 566 | Function invoked when calling the pipeline for generation. 567 | 568 | Args: 569 | prompt (`str` or `List[str]`, *optional*): 570 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 571 | instead. 572 | prompt_2 (`str` or `List[str]`, *optional*): 573 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 574 | will be used instead 575 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 576 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 
577 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
578 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
579 |             num_inference_steps (`int`, *optional*, defaults to 28):
580 |                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
581 |                 expense of slower inference.
582 |             timesteps (`List[int]`, *optional*):
583 |                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
584 |                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
585 |                 passed will be used. Must be in descending order.
586 |             guidance_scale (`float`, *optional*, defaults to 3.5):
587 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
588 |                 `guidance_scale` is defined as `w` of equation 2 of [Imagen
589 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
590 |                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
591 |                 usually at the expense of lower image quality.
592 |             num_images_per_prompt (`int`, *optional*, defaults to 1):
593 |                 The number of images to generate per prompt.
594 |             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
595 |                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
596 |                 to make generation deterministic.
597 |             latents (`torch.FloatTensor`, *optional*):
598 |                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
599 |                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
600 |                 tensor will be generated by sampling using the supplied random `generator`.
601 |             prompt_embeds (`torch.FloatTensor`, *optional*):
602 |                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
603 |                 provided, text embeddings will be generated from the `prompt` input argument.
604 |             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
605 |                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
606 |                 If not provided, pooled text embeddings will be generated from the `prompt` input argument.
607 |             output_type (`str`, *optional*, defaults to `"pil"`):
608 |                 The output format of the generated image. Choose between
609 |                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
610 |             return_dict (`bool`, *optional*, defaults to `True`):
611 |                 Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
612 |             joint_attention_kwargs (`dict`, *optional*):
613 |                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
614 |                 `self.processor` in
615 |                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
616 |             callback_on_step_end (`Callable`, *optional*):
617 |                 A function that is called at the end of each denoising step during inference. The function is called
618 |                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
619 |                 callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
620 |                 `callback_on_step_end_tensor_inputs`.
621 | callback_on_step_end_tensor_inputs (`List`, *optional*): 622 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 623 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 624 | `._callback_tensor_inputs` attribute of your pipeline class. 625 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 626 | 627 | Examples: 628 | 629 | Returns: 630 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 631 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 632 | images. 633 | """ 634 | 635 | height = height or self.default_sample_size * self.vae_scale_factor 636 | width = width or self.default_sample_size * self.vae_scale_factor 637 | 638 | # 1. Check inputs. Raise error if not correct 639 | self.check_inputs( 640 | prompt, 641 | prompt_2, 642 | height, 643 | width, 644 | prompt_embeds=prompt_embeds, 645 | pooled_prompt_embeds=pooled_prompt_embeds, 646 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 647 | max_sequence_length=max_sequence_length, 648 | ) 649 | 650 | self._guidance_scale = guidance_scale 651 | if joint_attention_kwargs is None: 652 | joint_attention_kwargs = {'proportional_attention': proportional_attention} 653 | else: 654 | joint_attention_kwargs = {**joint_attention_kwargs, 'proportional_attention': proportional_attention} 655 | self._joint_attention_kwargs = joint_attention_kwargs 656 | self._interrupt = False 657 | 658 | # 2. Define call parameters 659 | if prompt is not None and isinstance(prompt, str): 660 | batch_size = 1 661 | elif prompt is not None and isinstance(prompt, list): 662 | batch_size = len(prompt) 663 | else: 664 | batch_size = prompt_embeds.shape[0] 665 | 666 | device = self._execution_device 667 | 668 | lora_scale = ( 669 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 670 | ) 671 | ( 672 | prompt_embeds, 673 | pooled_prompt_embeds, 674 | text_ids, 675 | ) = self.encode_prompt( 676 | prompt=prompt, 677 | prompt_2=prompt_2, 678 | prompt_embeds=prompt_embeds, 679 | pooled_prompt_embeds=pooled_prompt_embeds, 680 | device=device, 681 | num_images_per_prompt=num_images_per_prompt, 682 | max_sequence_length=max_sequence_length, 683 | lora_scale=lora_scale, 684 | ) 685 | 686 | # 4. Prepare latent variables 687 | num_channels_latents = self.transformer.config.in_channels // 4 688 | latents, latent_image_ids = self.prepare_latents( 689 | batch_size * num_images_per_prompt, 690 | num_channels_latents, 691 | height, 692 | width, 693 | prompt_embeds.dtype, 694 | device, 695 | generator, 696 | latents, 697 | ) 698 | 699 | # 5. 
Prepare timesteps 700 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 701 | image_seq_len = latents.shape[1] 702 | mu = calculate_shift( 703 | image_seq_len, 704 | self.scheduler.config.base_image_seq_len, 705 | self.scheduler.config.max_image_seq_len, 706 | self.scheduler.config.base_shift, 707 | self.scheduler.config.max_shift, 708 | ) 709 | timesteps, num_inference_steps = retrieve_timesteps( 710 | self.scheduler, 711 | num_inference_steps, 712 | device, 713 | timesteps, 714 | sigmas, 715 | mu=mu, 716 | ) 717 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 718 | self._num_timesteps = len(timesteps) 719 | 720 | # handle guidance 721 | if self.transformer.config.guidance_embeds: 722 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 723 | guidance = guidance.expand(latents.shape[0]) 724 | else: 725 | guidance = None 726 | 727 | # 6. Denoising loop 728 | with self.progress_bar(total=num_inference_steps) as progress_bar: 729 | for i, t in enumerate(timesteps): 730 | if self.interrupt: 731 | continue 732 | 733 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 734 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 735 | 736 | noise_pred = self.transformer( 737 | hidden_states=latents, 738 | timestep=timestep / 1000, 739 | guidance=guidance, 740 | pooled_projections=pooled_prompt_embeds, 741 | encoder_hidden_states=prompt_embeds, 742 | txt_ids=text_ids, 743 | img_ids=latent_image_ids, 744 | joint_attention_kwargs=self.joint_attention_kwargs, 745 | return_dict=False, 746 | ntk_factor=ntk_factor 747 | )[0] 748 | 749 | # compute the previous noisy sample x_t -> x_t-1 750 | latents_dtype = latents.dtype 751 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 752 | 753 | if latents.dtype != latents_dtype: 754 | #if torch.backends.mps.is_available(): 755 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 756 | latents = latents.to(latents_dtype) 757 | 758 | if callback_on_step_end is not None: 759 | callback_kwargs = {} 760 | for k in callback_on_step_end_tensor_inputs: 761 | callback_kwargs[k] = locals()[k] 762 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 763 | 764 | latents = callback_outputs.pop("latents", latents) 765 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 766 | 767 | # call the callback, if provided 768 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 769 | progress_bar.update() 770 | 771 | if XLA_AVAILABLE: 772 | xm.mark_step() 773 | 774 | if output_type == "latent": 775 | image = latents 776 | 777 | else: 778 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 779 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 780 | image = self.vae.decode(latents, return_dict=False)[0] 781 | image = self.image_processor.postprocess(image, output_type=output_type) 782 | 783 | # Offload all models 784 | self.maybe_free_model_hooks() 785 | 786 | if not return_dict: 787 | return (image,) 788 | 789 | return FluxPipelineOutput(images=image) 790 | --------------------------------------------------------------------------------
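Note on the `__call__` signature above: it adds two arguments that the docstring does not describe, `ntk_factor` and `proportional_attention`. As the denoising loop shows, the former is forwarded to the transformer on every step, and the latter is merged into `joint_attention_kwargs` for the attention processors. The snippet below is a minimal illustrative sketch (not part of the repository) of invoking a pipeline built from this file; it assumes `pipe` is an already-constructed `FluxPipeline` instance on a CUDA device, and the prompt text, 4096x4096 resolution, output path, and seed are placeholder assumptions.

import torch

# Illustrative only: `pipe` is assumed to be a FluxPipeline (from pipeline_flux.py)
# that has already been constructed and moved to a CUDA device.
prompt = "A placeholder prompt describing the desired image."

result = pipe(
    prompt=prompt,
    height=4096,                  # target resolution in pixels (placeholder)
    width=4096,
    num_inference_steps=28,       # signature default
    guidance_scale=3.5,           # signature default
    max_sequence_length=512,      # check_inputs rejects values above 512
    ntk_factor=10.0,              # forwarded to the transformer at each denoising step
    proportional_attention=True,  # merged into joint_attention_kwargs for the attention processors
    generator=torch.Generator(device="cuda").manual_seed(0),
)
result.images[0].save("output.png")  # FluxPipelineOutput.images holds the generated PIL images

For reference, assuming the usual Flux VAE scale factor of 16, a 4096x4096 request is packed by `prepare_latents` and `_pack_latents` into an `image_seq_len` of (4096 / 16)^2 = 65,536 tokens, which is the value `calculate_shift` uses to set `mu` for the scheduler.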