├── teaser.jpg
├── requirements.txt
├── cache_latent_codes.sh
├── cache_prompt_embeds.sh
├── train_2k.sh
├── train_4k.sh
├── inference_2k.ipynb
├── inference_2k_schnell.ipynb
├── inference_4k_lora_conversion_schnell.ipynb
├── inference_4k_lora_conversion.ipynb
├── .gitignore
├── cache_latent_codes.py
├── attention_processor.py
├── cache_prompt_embeds.py
├── README.md
├── LICENSE
├── inference_4k_schnell.ipynb
├── inference_4k.ipynb
├── transformer_flux.py
├── train_2k.py
└── pipeline_flux.py
/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Huage001/URAE/HEAD/teaser.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.1.0
2 | transformers
3 | diffusers==0.31.0
4 | sentencepiece
5 | accelerate
6 | opencv-python
7 | peft
8 | wandb
9 | jupyter
10 | patch_conv
11 | prodigyopt
--------------------------------------------------------------------------------
/cache_latent_codes.sh:
--------------------------------------------------------------------------------
1 | export NUM_WORKERS=1
2 | export DATA_DIR="path/to/data"
3 | export MODEL_NAME="black-forest-labs/FLUX.1-dev"
4 | torchrun --nproc_per_node=$NUM_WORKERS cache_latent_codes.py \
5 | --data_root=$DATA_DIR \
6 |   --num_workers=$NUM_WORKERS \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --mixed_precision='bf16' \
9 | --output_dir=$DATA_DIR \
10 | --resolution=4096
--------------------------------------------------------------------------------
/cache_prompt_embeds.sh:
--------------------------------------------------------------------------------
1 | export NUM_WORKERS=1
2 | export DATA_DIR="/path/to/data"
3 | export MODEL_NAME="black-forest-labs/FLUX.1-dev"
4 | torchrun --nproc_per_node=$NUM_WORKERS cache_prompt_embeds.py \
5 | --data_root=$DATA_DIR \
6 | --batch_size=256 \
7 |   --num_workers=$NUM_WORKERS \
8 | --pretrained_model_name_or_path=$MODEL_NAME \
9 | --mixed_precision='bf16' \
10 | --output_dir=$DATA_DIR \
11 | --column="prompt" \
12 | --max_sequence_length=512
--------------------------------------------------------------------------------
/train_2k.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="black-forest-labs/FLUX.1-dev"
2 | export DATA_DIR="/path/to/data"
3 | export OUTPUT_DIR="/path/to/ckpt"
4 | export PRECISION="bf16"
5 |
6 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision $PRECISION train_2k.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --dataset_root=$DATA_DIR \
9 | --output_dir=$OUTPUT_DIR \
10 | --mixed_precision=$PRECISION \
11 | --dataloader_num_workers=4 \
12 | --train_batch_size=1 \
13 | --gradient_accumulation_steps=1 \
14 | --optimizer="prodigy" \
15 | --learning_rate=1. \
16 | --report_to="wandb" \
17 | --lr_scheduler="constant" \
18 | --lr_warmup_steps=0 \
19 | --max_train_steps=2000 \
20 | --seed="0" \
21 | --real_prompt_ratio=0.2 \
22 | --checkpointing_steps=1000 \
23 | --gradient_checkpointing
24 |
--------------------------------------------------------------------------------
/train_4k.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="black-forest-labs/FLUX.1-dev"
2 | export DATA_DIR="/path/to/data"
3 | export OUTPUT_DIR="/path/to/ckpt"
4 | export PRECISION="bf16"
5 |
6 | accelerate launch --num_processes 8 --multi_gpu --mixed_precision $PRECISION train_4k.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --dataset_root=$DATA_DIR \
9 | --output_dir=$OUTPUT_DIR \
10 | --mixed_precision=$PRECISION \
11 | --dataloader_num_workers=4 \
12 | --train_batch_size=1 \
13 | --gradient_accumulation_steps=1 \
14 | --optimizer="prodigy" \
15 | --learning_rate=1. \
16 | --report_to="wandb" \
17 | --lr_scheduler="constant" \
18 | --lr_warmup_steps=0 \
19 | --max_train_steps=10000 \
20 | --seed="0" \
21 | --real_prompt_ratio=0.2 \
22 | --checkpointing_steps=1000 \
23 | --gradient_checkpointing \
24 | --ntk_factor=10 \
25 | --proportional_attention \
26 | --pretrained_lora="ckpt/urae_2k_adapter.safetensors"
27 |
--------------------------------------------------------------------------------
/inference_2k.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
23 | "device = torch.device('cuda')\n",
24 | "dtype = torch.bfloat16\n",
25 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n",
26 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n",
27 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
28 | "pipe.scheduler.config.time_shift = 10\n",
29 | "pipe = pipe.to(device)"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n",
39 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
40 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
50 | "height = 2048\n",
51 | "width = 2048\n",
52 | "image = pipe(\n",
53 | " prompt,\n",
54 | " height=height,\n",
55 | " width=width,\n",
56 | " guidance_scale=3.5,\n",
57 | " num_inference_steps=28,\n",
58 | " max_sequence_length=512,\n",
59 | " generator=torch.manual_seed(8888),\n",
60 | " ntk_factor=10,\n",
61 | " proportional_attention=True\n",
62 | ").images[0]\n",
63 | "image"
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "kernelspec": {
69 | "display_name": "Python 3 (ipykernel)",
70 | "language": "python",
71 | "name": "python3"
72 | }
73 | },
74 | "nbformat": 4,
75 | "nbformat_minor": 2
76 | }
77 |
--------------------------------------------------------------------------------
/inference_2k_schnell.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n",
23 | "device = torch.device('cuda')\n",
24 | "dtype = torch.bfloat16\n",
25 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n",
26 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)\n",
27 | "pipe.transformer = transformer\n",
28 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
29 | "pipe.scheduler.config.time_shift = 10\n",
30 | "pipe = pipe.to(device)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n",
40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
41 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
51 | "height = 2048\n",
52 | "width = 2048\n",
53 | "image = pipe(\n",
54 | " prompt,\n",
55 | " height=height,\n",
56 | " width=width,\n",
57 | " guidance_scale=0,\n",
58 | " num_inference_steps=4,\n",
59 | " max_sequence_length=256,\n",
60 | " generator=torch.manual_seed(8888),\n",
61 | " ntk_factor=10,\n",
62 | " proportional_attention=True\n",
63 | ").images[0]\n",
64 | "image"
65 | ]
66 | }
67 | ],
68 | "metadata": {
69 | "kernelspec": {
70 | "display_name": "Python 3 (ipykernel)",
71 | "language": "python",
72 | "name": "python3"
73 | }
74 | },
75 | "nbformat": 4,
76 | "nbformat_minor": 2
77 | }
78 |
--------------------------------------------------------------------------------
/inference_4k_lora_conversion_schnell.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel\n",
14 | "from patch_conv import convert_model"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n",
24 | "device = torch.device('cuda')\n",
25 | "dtype = torch.bfloat16\n",
26 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n",
27 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n",
28 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
29 | "pipe.scheduler.config.time_shift = 10\n",
30 | "pipe.enable_model_cpu_offload()"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "if not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors'):\n",
40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter_lora_conversion_schnell.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
41 | "pipe.load_lora_weights(\"ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n",
49 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "pipe.vae = convert_model(pipe.vae, splits=4)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
68 | "height = 4096\n",
69 | "width = 4096\n",
70 | "image = pipe(\n",
71 | " prompt,\n",
72 | " height=height,\n",
73 | " width=width,\n",
74 | " guidance_scale=0,\n",
75 | " num_inference_steps=4,\n",
76 | " max_sequence_length=256,\n",
77 | " generator=torch.manual_seed(8888),\n",
78 | " ntk_factor=10,\n",
79 | " proportional_attention=True\n",
80 | ").images[0]\n",
81 | "image"
82 | ]
83 | }
84 | ],
85 | "metadata": {
86 | "kernelspec": {
87 | "display_name": "Python 3 (ipykernel)",
88 | "language": "python",
89 | "name": "python3"
90 | }
91 | },
92 | "nbformat": 4,
93 | "nbformat_minor": 2
94 | }
95 |
--------------------------------------------------------------------------------
/inference_4k_lora_conversion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel\n",
14 | "from patch_conv import convert_model"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
24 | "device = torch.device('cuda')\n",
25 | "dtype = torch.bfloat16\n",
26 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\n",
27 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n",
28 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
29 | "pipe.scheduler.config.time_shift = 10\n",
30 | "pipe.enable_model_cpu_offload()"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "if not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_dev.safetensors'):\n",
40 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter_lora_conversion_dev.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
41 | "pipe.load_lora_weights(\"ckpt/urae_4k_adapter_lora_conversion_dev.safetensors\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n",
49 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "pipe.vae = convert_model(pipe.vae, splits=4)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
68 | "height = 4096\n",
69 | "width = 4096\n",
70 | "image = pipe(\n",
71 | " prompt,\n",
72 | " height=height,\n",
73 | " width=width,\n",
74 | " guidance_scale=3.5,\n",
75 | " num_inference_steps=28,\n",
76 | " max_sequence_length=512,\n",
77 | " generator=torch.manual_seed(0),\n",
78 | " ntk_factor=10,\n",
79 | " proportional_attention=True\n",
80 | ").images[0]\n",
81 | "image"
82 | ]
83 | }
84 | ],
85 | "metadata": {
86 | "kernelspec": {
87 | "display_name": "Python 3 (ipykernel)",
88 | "language": "python",
89 | "name": "python3"
90 | },
91 | "language_info": {
92 | "codemirror_mode": {
93 | "name": "ipython",
94 | "version": 3
95 | },
96 | "file_extension": ".py",
97 | "mimetype": "text/x-python",
98 | "name": "python",
99 | "nbconvert_exporter": "python",
100 | "pygments_lexer": "ipython3",
101 | "version": "3.12.9"
102 | }
103 | },
104 | "nbformat": 4,
105 | "nbformat_minor": 2
106 | }
107 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/cache_latent_codes.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import math
6 | from safetensors.torch import save_file
7 | import torch
8 | import torch.nn as nn
9 | import tqdm
10 | import cv2
11 | from diffusers import AutoencoderKL
12 | from patch_conv import PatchConv2d
13 |
14 |
15 | def parse_args(input_args=None):
16 |     parser = argparse.ArgumentParser(description="Cache VAE latent codes for training images.")
17 | parser.add_argument(
18 | "--pretrained_model_name_or_path",
19 | type=str,
20 | default=None,
21 | required=True,
22 | help="Path to pretrained model or model identifier from huggingface.co/models.",
23 | )
24 | parser.add_argument(
25 | "--revision",
26 | type=str,
27 | default=None,
28 | required=False,
29 | help="Revision of pretrained model identifier from huggingface.co/models.",
30 | )
31 | parser.add_argument(
32 | "--variant",
33 | type=str,
34 | default=None,
35 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
36 | )
37 | parser.add_argument(
38 | "--data_root",
39 | type=str,
40 | default=None
41 | )
42 | parser.add_argument(
43 | "--cache_dir",
44 | type=str,
45 | default=None,
46 | help="The directory where the downloaded models and datasets will be stored.",
47 | )
48 | parser.add_argument(
49 | "--output_dir",
50 | type=str,
51 | default="flux-lora",
52 | help="The output directory where the model predictions and checkpoints will be written.",
53 | )
54 | parser.add_argument(
55 | "--mixed_precision",
56 | type=str,
57 | default=None,
58 | choices=["no", "fp16", "bf16"],
59 | help=(
60 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
61 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
62 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
63 | ),
64 | )
65 | parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank")
66 | parser.add_argument(
67 | "--resolution",
68 | type=int,
69 | default=None,
70 | help="Image resolution for resizing. If None, the original resolution will be used.",
71 | )
72 | parser.add_argument(
73 | "--num_workers",
74 | type=int,
75 | default=1,
76 | help="Number of workers",
77 | )
78 |
79 | if input_args is not None:
80 | args = parser.parse_args(input_args)
81 | else:
82 | args = parser.parse_args()
83 |
84 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
85 |
86 | return args
87 |
88 |
89 | def resize(img, base_size=4096):
90 | all_sizes = np.array([
91 | [base_size, base_size],
92 | [base_size, base_size * 3 // 4],
93 | [base_size, base_size // 2],
94 | [base_size, base_size // 4]
95 | ])
96 | width, height = img.size
97 | if width < height:
98 | size = all_sizes[np.argmin(np.abs(all_sizes[:, 0] / all_sizes[:, 1] - height / width))][::-1]
99 | else:
100 | size = all_sizes[np.argmin(np.abs(all_sizes[:, 0] / all_sizes[:, 1] - width / height))]
101 | return img.resize((size[0], size[1]))
102 |
103 |
104 | def convert_model(model: nn.Module, splits: int = 4, sequential: bool=True) -> nn.Module:
105 | """
106 | Convert the convolutions in the model to PatchConv2d.
107 | """
108 | if isinstance(model, PatchConv2d):
109 | return model
110 | elif isinstance(model, nn.Conv2d) and model.kernel_size[0] > 1 and model.kernel_size[1] > 1:
111 | return PatchConv2d(splits=splits, conv2d=model)
112 | else:
113 | for name, module in model.named_modules():
114 | if isinstance(module, (nn.Conv2d, PatchConv2d)):
115 | continue
116 | if 'downsamplers' in name:
117 | continue
118 | for subname, submodule in module.named_children():
119 | if isinstance(submodule, nn.Conv2d) and submodule.kernel_size[0] > 1 and submodule.kernel_size[1] > 1:
120 | setattr(module, subname, PatchConv2d(splits=splits, sequential=sequential, conv2d=submodule))
121 | return model
122 |
123 |
124 | def main(args):
125 | if torch.cuda.is_available():
126 | torch.cuda.set_device(args.local_rank)
127 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
128 | if args.mixed_precision == 'fp16':
129 | dtype = torch.float16
130 | elif args.mixed_precision == 'bf16':
131 | dtype = torch.bfloat16
132 | else:
133 | dtype = torch.float32
134 |
135 | vae = AutoencoderKL.from_pretrained(
136 | args.pretrained_model_name_or_path,
137 | subfolder="vae",
138 | revision=args.revision,
139 | variant=args.variant,
140 | ).to(device, dtype)
141 |
142 | if args.resolution is not None and args.resolution > 3072:
143 | vae = convert_model(vae, splits=4)
144 |
145 | all_info = sorted([item for item in os.listdir(args.data_root) if item.endswith('.jpg')])
146 |
147 | os.makedirs(args.output_dir, exist_ok=True)
148 |
149 | work_load = math.ceil(len(all_info) / args.num_workers)
150 | for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)))):
151 | output_path = os.path.join(args.output_dir, f"{all_info[idx][:all_info[idx].rfind('.')]}_latent_code.safetensors")
152 | img = cv2.cvtColor(cv2.imread(os.path.join(args.data_root, all_info[idx])), cv2.COLOR_BGR2RGB)
153 | img = Image.fromarray(img)
154 | if args.resolution is not None:
155 | img = resize(img, args.resolution)
156 | img = torch.from_numpy((np.array(img) / 127.5) - 1)
157 | img = img.permute(2, 0, 1)
158 | with torch.no_grad():
159 | img = img.unsqueeze(0)
160 | data = vae.encode(img.to(device, vae.dtype)).latent_dist
161 | mean = data.mean[0].cpu().data
162 | std = data.std[0].cpu().data
163 | save_file(
164 | {'mean': mean, 'std': std},
165 | output_path
166 | )
167 |
168 |
169 | if __name__ == '__main__':
170 | main(parse_args())
--------------------------------------------------------------------------------
/attention_processor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from diffusers.models.attention_processor import Attention
6 | from typing import Optional
7 | from diffusers.models.embeddings import apply_rotary_emb
8 |
9 |
10 | class FluxAttnProcessor2_0:
11 | """Attention processor used typically in processing the SD3-like self-attention projections."""
12 |
13 | def __init__(self, train_seq_len=512 + 64 * 64):
14 | if not hasattr(F, "scaled_dot_product_attention"):
15 | raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
16 | self.train_seq_len = train_seq_len
17 |
18 | def __call__(
19 | self,
20 | attn: Attention,
21 | hidden_states: torch.FloatTensor,
22 | encoder_hidden_states: torch.FloatTensor = None,
23 | attention_mask: Optional[torch.FloatTensor] = None,
24 | image_rotary_emb: Optional[torch.Tensor] = None,
25 | proportional_attention=False
26 | ) -> torch.FloatTensor:
27 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
28 |
29 | # `sample` projections.
30 | query = attn.to_q(hidden_states)
31 | key = attn.to_k(hidden_states)
32 | value = attn.to_v(hidden_states)
33 |
34 | inner_dim = key.shape[-1]
35 | head_dim = inner_dim // attn.heads
36 |
37 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
38 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
39 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
40 |
41 | if attn.norm_q is not None:
42 | query = attn.norm_q(query)
43 | if attn.norm_k is not None:
44 | key = attn.norm_k(key)
45 |
46 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
47 | if encoder_hidden_states is not None:
48 | # `context` projections.
49 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
50 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
51 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
52 |
53 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
54 | batch_size, -1, attn.heads, head_dim
55 | ).transpose(1, 2)
56 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
57 | batch_size, -1, attn.heads, head_dim
58 | ).transpose(1, 2)
59 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
60 | batch_size, -1, attn.heads, head_dim
61 | ).transpose(1, 2)
62 |
63 | if attn.norm_added_q is not None:
64 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
65 | if attn.norm_added_k is not None:
66 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
67 |
68 | # attention
69 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
70 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
71 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
72 |
73 | if image_rotary_emb is not None:
74 | query = apply_rotary_emb(query, image_rotary_emb)
75 | key = apply_rotary_emb(key, image_rotary_emb)
76 |
77 | if proportional_attention:
78 | attention_scale = math.sqrt(math.log(key.size(2), self.train_seq_len) / head_dim)
79 | else:
80 | attention_scale = math.sqrt(1 / head_dim)
81 |
82 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale)
83 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
84 | hidden_states = hidden_states.to(query.dtype)
85 |
86 | if encoder_hidden_states is not None:
87 | encoder_hidden_states, hidden_states = (
88 | hidden_states[:, : encoder_hidden_states.shape[1]],
89 | hidden_states[:, encoder_hidden_states.shape[1] :],
90 | )
91 |
92 | # linear proj
93 | hidden_states = attn.to_out[0](hidden_states)
94 | # dropout
95 | hidden_states = attn.to_out[1](hidden_states)
96 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
97 |
98 | return hidden_states, encoder_hidden_states
99 | else:
100 | return hidden_states
101 |
102 |
103 | class FluxAttnAdaptationProcessor2_0(nn.Module):
104 |     """Attention processor with low-rank (LoRA-style) adaptation branches added to the q/k/v (and optionally output) projections."""
105 |
106 | def __init__(self, rank=16, dim=3072, to_out=False, train_seq_len=512 + 64 * 64):
107 | super().__init__()
108 | if not hasattr(F, "scaled_dot_product_attention"):
109 |             raise ImportError("FluxAttnAdaptationProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
110 | self.to_q_a = nn.Linear(dim, rank, bias=False)
111 | self.to_q_b = nn.Linear(rank, dim, bias=False)
112 | self.to_q_b.weight.data = torch.zeros_like(self.to_q_b.weight.data)
113 | self.to_k_a = nn.Linear(dim, rank, bias=False)
114 | self.to_k_b = nn.Linear(rank, dim, bias=False)
115 | self.to_k_b.weight.data = torch.zeros_like(self.to_k_b.weight.data)
116 | self.to_v_a = nn.Linear(dim, rank, bias=False)
117 | self.to_v_b = nn.Linear(rank, dim, bias=False)
118 | self.to_v_b.weight.data = torch.zeros_like(self.to_v_b.weight.data)
119 | if to_out:
120 | self.to_out_a = nn.Linear(dim, rank, bias=False)
121 | self.to_out_b = nn.Linear(rank, dim, bias=False)
122 | self.to_out_b.weight.data = torch.zeros_like(self.to_out_b.weight.data)
123 | self.train_seq_len = train_seq_len
124 |
125 | def __call__(
126 | self,
127 | attn: Attention,
128 | hidden_states: torch.FloatTensor,
129 | encoder_hidden_states: torch.FloatTensor = None,
130 | attention_mask: Optional[torch.FloatTensor] = None,
131 | image_rotary_emb: Optional[torch.Tensor] = None,
132 | proportional_attention=False
133 | ) -> torch.FloatTensor:
134 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
135 |
136 | # `sample` projections.
137 | query = attn.to_q(hidden_states) + self.to_q_b(self.to_q_a(hidden_states))
138 | key = attn.to_k(hidden_states) + self.to_k_b(self.to_k_a(hidden_states))
139 | value = attn.to_v(hidden_states) + self.to_v_b(self.to_v_a(hidden_states))
140 |
141 | inner_dim = key.shape[-1]
142 | head_dim = inner_dim // attn.heads
143 |
144 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
145 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
146 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
147 |
148 | if attn.norm_q is not None:
149 | query = attn.norm_q(query)
150 | if attn.norm_k is not None:
151 | key = attn.norm_k(key)
152 |
153 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
154 | if encoder_hidden_states is not None:
155 | # `context` projections.
156 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
157 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
158 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
159 |
160 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
161 | batch_size, -1, attn.heads, head_dim
162 | ).transpose(1, 2)
163 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
164 | batch_size, -1, attn.heads, head_dim
165 | ).transpose(1, 2)
166 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
167 | batch_size, -1, attn.heads, head_dim
168 | ).transpose(1, 2)
169 |
170 | if attn.norm_added_q is not None:
171 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
172 | if attn.norm_added_k is not None:
173 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
174 |
175 | # attention
176 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
177 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
178 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
179 |
180 | if image_rotary_emb is not None:
181 | query = apply_rotary_emb(query, image_rotary_emb)
182 | key = apply_rotary_emb(key, image_rotary_emb)
183 |
184 | if proportional_attention:
185 | attention_scale = math.sqrt(math.log(key.size(2), self.train_seq_len) / head_dim)
186 | else:
187 | attention_scale = math.sqrt(1 / head_dim)
188 |
189 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale)
190 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
191 | hidden_states = hidden_states.to(query.dtype)
192 |
193 | if encoder_hidden_states is not None:
194 | encoder_hidden_states, hidden_states = (
195 | hidden_states[:, : encoder_hidden_states.shape[1]],
196 | hidden_states[:, encoder_hidden_states.shape[1] :],
197 | )
198 |
199 | # linear proj
200 | hidden_states = attn.to_out[0](hidden_states) + self.to_out_b(self.to_out_a(hidden_states))
201 | # dropout
202 | hidden_states = attn.to_out[1](hidden_states)
203 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
204 |
205 | return hidden_states, encoder_hidden_states
206 | else:
207 | return hidden_states
208 |
--------------------------------------------------------------------------------
/cache_prompt_embeds.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import math
5 | from safetensors.torch import save_file
6 | import torch
7 | import tqdm
8 | from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, T5EncoderModel
9 |
10 |
11 | def parse_args(input_args=None):
12 |     parser = argparse.ArgumentParser(description="Cache text prompt embeddings for training images.")
13 | parser.add_argument(
14 | "--pretrained_model_name_or_path",
15 | type=str,
16 | default=None,
17 | required=True,
18 | help="Path to pretrained model or model identifier from huggingface.co/models.",
19 | )
20 | parser.add_argument(
21 | "--revision",
22 | type=str,
23 | default=None,
24 | required=False,
25 | help="Revision of pretrained model identifier from huggingface.co/models.",
26 | )
27 | parser.add_argument(
28 | "--variant",
29 | type=str,
30 | default=None,
31 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
32 | )
33 | parser.add_argument(
34 | "--data_root",
35 | type=str,
36 | default=None
37 | )
38 | parser.add_argument(
39 | "--cache_dir",
40 | type=str,
41 | default=None,
42 | help="The directory where the downloaded models and datasets will be stored.",
43 | )
44 | parser.add_argument(
45 | "--column",
46 | type=str,
47 | default="prompt"
48 | )
49 | parser.add_argument(
50 | "--max_sequence_length",
51 | type=int,
52 | default=512,
53 |         help="Maximum sequence length to use with the T5 text encoder",
54 | )
55 | parser.add_argument(
56 | "--output_dir",
57 | type=str,
58 | default="flux-lora",
59 | help="The output directory where the model predictions and checkpoints will be written.",
60 | )
61 | parser.add_argument(
62 | "--batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
63 | )
64 | parser.add_argument(
65 | "--mixed_precision",
66 | type=str,
67 | default=None,
68 | choices=["no", "fp16", "bf16"],
69 | help=(
70 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
71 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
72 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
73 | ),
74 | )
75 | parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank")
76 | parser.add_argument(
77 | "--num_workers",
78 | type=int,
79 | default=1,
80 | help="Number of workers",
81 | )
82 |
83 | if input_args is not None:
84 | args = parser.parse_args(input_args)
85 | else:
86 | args = parser.parse_args()
87 |
88 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
89 |
90 | return args
91 |
92 |
93 | def tokenize_prompt(tokenizer, prompt, max_sequence_length):
94 | text_inputs = tokenizer(
95 | prompt,
96 | padding="max_length",
97 | max_length=max_sequence_length,
98 | truncation=True,
99 | return_length=False,
100 | return_overflowing_tokens=False,
101 | return_tensors="pt",
102 | )
103 | text_input_ids = text_inputs.input_ids
104 | return text_input_ids
105 |
106 |
107 | def _encode_prompt_with_t5(
108 | text_encoder,
109 | tokenizer,
110 | max_sequence_length=512,
111 | prompt=None,
112 | num_images_per_prompt=1,
113 | device=None,
114 | text_input_ids=None,
115 | ):
116 | prompt = [prompt] if isinstance(prompt, str) else prompt
117 | batch_size = len(prompt)
118 |
119 | if tokenizer is not None:
120 | text_inputs = tokenizer(
121 | prompt,
122 | padding="max_length",
123 | max_length=max_sequence_length,
124 | truncation=True,
125 | return_length=False,
126 | return_overflowing_tokens=False,
127 | return_tensors="pt",
128 | )
129 | text_input_ids = text_inputs.input_ids
130 | else:
131 | if text_input_ids is None:
132 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
133 |
134 | prompt_embeds = text_encoder(text_input_ids.to(device))[0]
135 |
136 | dtype = text_encoder.dtype
137 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
138 |
139 | _, seq_len, _ = prompt_embeds.shape
140 |
141 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
142 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
143 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
144 |
145 | return prompt_embeds
146 |
147 |
148 | def _encode_prompt_with_clip(
149 | text_encoder,
150 | tokenizer,
151 | prompt: str,
152 | device=None,
153 | text_input_ids=None,
154 | num_images_per_prompt: int = 1,
155 | ):
156 | prompt = [prompt] if isinstance(prompt, str) else prompt
157 | batch_size = len(prompt)
158 |
159 | if tokenizer is not None:
160 | text_inputs = tokenizer(
161 | prompt,
162 | padding="max_length",
163 | max_length=77,
164 | truncation=True,
165 | return_overflowing_tokens=False,
166 | return_length=False,
167 | return_tensors="pt",
168 | )
169 |
170 | text_input_ids = text_inputs.input_ids
171 | else:
172 | if text_input_ids is None:
173 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
174 |
175 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
176 |
177 | # Use pooled output of CLIPTextModel
178 | prompt_embeds = prompt_embeds.pooler_output
179 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
180 |
181 | # duplicate text embeddings for each generation per prompt, using mps friendly method
182 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
183 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
184 |
185 | return prompt_embeds
186 |
187 |
188 | def encode_prompt(
189 | text_encoders,
190 | tokenizers,
191 | prompt: str,
192 | max_sequence_length,
193 | device=None,
194 | num_images_per_prompt: int = 1,
195 | text_input_ids_list=None,
196 | ):
197 | prompt = [prompt] if isinstance(prompt, str) else prompt
198 |
199 | pooled_prompt_embeds = _encode_prompt_with_clip(
200 | text_encoder=text_encoders[0],
201 | tokenizer=tokenizers[0],
202 | prompt=prompt,
203 | device=device if device is not None else text_encoders[0].device,
204 | num_images_per_prompt=num_images_per_prompt,
205 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
206 | )
207 |
208 | prompt_embeds = _encode_prompt_with_t5(
209 | text_encoder=text_encoders[1],
210 | tokenizer=tokenizers[1],
211 | max_sequence_length=max_sequence_length,
212 | prompt=prompt,
213 | num_images_per_prompt=num_images_per_prompt,
214 | device=device if device is not None else text_encoders[1].device,
215 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
216 | )
217 |
218 | return prompt_embeds, pooled_prompt_embeds
219 |
220 |
221 | def main(args):
222 | if torch.cuda.is_available():
223 | torch.cuda.set_device(args.local_rank)
224 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
225 | if args.mixed_precision == 'fp16':
226 | dtype = torch.float16
227 | elif args.mixed_precision == 'bf16':
228 | dtype = torch.bfloat16
229 | else:
230 | dtype = torch.float32
231 | tokenizer_one = CLIPTokenizer.from_pretrained(
232 | args.pretrained_model_name_or_path,
233 | subfolder="tokenizer",
234 | revision=args.revision,
235 | variant=args.variant,
236 | cache_dir=args.cache_dir
237 | )
238 | tokenizer_two = T5TokenizerFast.from_pretrained(
239 | args.pretrained_model_name_or_path,
240 | subfolder="tokenizer_2",
241 | revision=args.revision,
242 | variant=args.variant,
243 | cache_dir=args.cache_dir
244 | )
245 |
246 | text_encoder_one = CLIPTextModel.from_pretrained(
247 | args.pretrained_model_name_or_path,
248 | revision=args.revision,
249 | subfolder="text_encoder",
250 | variant=args.variant,
251 | cache_dir=args.cache_dir
252 | ).to(device, dtype)
253 | text_encoder_two = T5EncoderModel.from_pretrained(
254 | args.pretrained_model_name_or_path,
255 | revision=args.revision,
256 | subfolder="text_encoder_2",
257 | variant=args.variant,
258 | cache_dir=args.cache_dir
259 | ).to(device, dtype)
260 | tokenizers = [tokenizer_one, tokenizer_two]
261 | text_encoders = [text_encoder_one, text_encoder_two]
262 |
263 | all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.json' in i]
264 |
265 | os.makedirs(args.output_dir, exist_ok=True)
266 |
267 | work_load = math.ceil(len(all_info) / args.num_workers)
268 | for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)):
269 | texts = []
270 | for item in all_info[idx:idx + args.batch_size]:
271 | with open(item) as f:
272 | prompt = json.load(f)[args.column]
273 | texts.append(prompt)
274 | if not isinstance(prompt, str):
275 | print(prompt, item)
276 | paths = [item[:item.rfind('.')] + f'_{args.column}_embed.safetensors' for item in all_info[idx:idx + args.batch_size]]
277 | with torch.no_grad():
278 | prompt_embeds, pooled_prompt_embeds = encode_prompt(
279 | text_encoders, tokenizers, texts, args.max_sequence_length
280 | )
281 | prompt_embeds = prompt_embeds.cpu().data
282 | pooled_prompt_embeds = pooled_prompt_embeds.cpu().data
283 | for path, prompt_embed, pooled_prompt_embed in zip(paths, prompt_embeds.unbind(), pooled_prompt_embeds.unbind()):
284 | save_file(
285 | {'caption_feature_t5': prompt_embed, 'caption_feature_clip': pooled_prompt_embed},
286 | path
287 | )
288 |
289 |
290 | if __name__ == '__main__':
291 | main(parse_args())
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICML 2025] URAE: ~~Your Free FLUX Pro Ultra~~
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | > ***U*ltra-*R*esolution *A*daptation with *E*ase**
12 | >
13 | > [Ruonan Yu*](https://scholar.google.com/citations?user=UHP95egAAAAJ&hl=en),
14 | > [Songhua Liu*](http://121.37.94.87/),
15 | > [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ&hl=en),
16 | > and
17 | > [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
18 | >
19 | > [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
20 | >
21 |
22 | ## 🪶Features
23 |
24 | * **Easy-to-Use High-Quality and High-Resolution Generation😊**: Ultra-Resolution Adaptation with Ease, or URAE for short, generates high-resolution images with FLUX using minimal code modifications.
25 | * **Easy Training🚀**: URAE tames lightweight adapters with only a handful of synthetic images from [FLUX1.1 Pro Ultra](https://blackforestlabs.ai/ultra-home/).
26 |
27 | ## 🔥News
28 |
29 | **[2025/05/01]** URAE is accepted to ICML 2025! 🎉
30 |
31 | **[2025/03/20]** We release the models and code for both training and inference of URAE.
32 |
33 | ## Introduction
34 |
35 | Text-to-image diffusion models have achieved remarkable progress in recent years. However, training models for high-resolution image generation remains challenging, particularly when training data and computational resources are limited. In this paper, we explore this practical problem from two key perspectives: data and parameter efficiency, and propose a set of key guidelines for ultra-resolution adaptation termed *URAE*. For data efficiency, we theoretically and empirically demonstrate that synthetic data generated by some teacher models can significantly promote training convergence. For parameter efficiency, we find that tuning minor components of the weight matrices outperforms widely-used low-rank adapters when synthetic data are unavailable, offering substantial performance gains while maintaining efficiency. Additionally, for models leveraging guidance distillation, such as FLUX, we show that disabling classifier-free guidance, *i.e.*, setting the guidance scale to 1 during adaptation, is crucial for satisfactory performance. Extensive experiments validate that URAE achieves comparable 2K-generation performance to state-of-the-art closed-source models like FLUX1.1 [Pro] Ultra with only 3K samples and 2K iterations, while setting new benchmarks for 4K-resolution generation.
36 |
37 | ## Quick Start
38 |
39 | * If you have not already, install [PyTorch](https://pytorch.org/get-started/locally/), [diffusers](https://huggingface.co/docs/diffusers/index), [transformers](https://huggingface.co/docs/transformers/index), and [peft](https://huggingface.co/docs/peft/index).
40 |
41 | * Clone this repo to your project directory:
42 |
43 | ``` bash
44 | git clone https://github.com/Huage001/URAE.git
45 | cd URAE
46 | ```
47 |
48 | * **You only need minimal modifications!**
49 |
50 | ```diff
51 | import torch
52 | - from diffusers import FluxPipeline
53 | + from pipeline_flux import FluxPipeline
54 | + from transformer_flux import FluxTransformer2DModel
55 |
56 | bfl_repo = "black-forest-labs/FLUX.1-dev"
57 | + transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=torch.bfloat16)
58 | - pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=torch.bfloat16)
59 | + pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=torch.bfloat16)
60 | + pipe.scheduler.config.use_dynamic_shifting = False
61 | + pipe.scheduler.config.time_shift = 10
62 |   pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
63 |
64 | + pipe.load_lora_weights("Huage001/URAE", weight_name="urae_2k_adapter.safetensors")
65 |
66 | prompt = "An astronaut riding a green horse"
67 | image = pipe(
68 | prompt,
69 | - height=1024,
70 | - width=1024,
71 | + height=2048,
72 | + width=2048,
73 | guidance_scale=3.5,
74 | num_inference_steps=50,
75 | max_sequence_length=512,
76 | generator=torch.Generator("cpu").manual_seed(0)
77 | ).images[0]
78 | image.save("flux-urae.png")
79 | ```
80 | ⚠️ **FLUX requires at least 28GB of GPU memory to operate at a 2K resolution.** A 48GB GPU is recommended for the full functionalities of URAE, including both 2K and 4K. We are actively integrating model lightweighting strategies into URAE! If you have a good idea, feel free to submit a PR!
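
   If GPU memory is tight, the inference notebooks in this repo combine a few memory-saving options. A minimal sketch (assuming an already-constructed `pipe` as above and the `patch_conv` package from `requirements.txt`):

   ```python
   # Memory-saving options used in the notebooks of this repo (sketch).
   pipe.enable_model_cpu_offload()  # keep idle sub-models on the CPU

   # For 4K decoding, run the VAE convolutions patch by patch to avoid OOM.
   from patch_conv import convert_model
   pipe.vae = convert_model(pipe.vae, splits=4)
   # If decoding still runs out of memory, `pipe.vae.enable_tiling()` is an alternative.
   ```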
81 |
82 |   * Do not want to run the code? No worries! Try the models on Hugging Face Spaces!
83 |
84 | * [URAE w. FLUX1.schnell](https://huggingface.co/spaces/Yuanshi/URAE) (Faster)
85 | * [URAE w. FLUX1.dev](https://huggingface.co/spaces/Yuanshi/URAE_dev) (Higher Quality)
86 |
87 | ## Installation
88 |
89 | * Clone this repo to your project directory:
90 |
91 | ``` bash
92 | git clone https://github.com/Huage001/URAE.git
93 | cd URAE
94 | ```
95 |
96 | * URAE has been tested on ``torch==2.5.1`` and ``diffusers==0.31.0``, but it should also be compatible with similar versions. You can set up a new environment if you wish and install the packages listed in ``requirements.txt``:
97 |
98 | ```bash
99 | conda create -n URAE python=3.12
100 | conda activate URAE
101 | pip install -r requirements.txt
102 | ```
103 |
104 | ## Inference
105 |
106 | #### 2K Resolution, or Use Cases Similar to FLUX Pro Ultra
107 |
108 | * We use **LoRA adapters** to adapt FLUX 1.dev to 2K resolution. You can try the corresponding URAE model in `inference_2k.ipynb`.
109 |
110 | * We also support FLUX 1.schnell for a **faster inference process**. Please refer to `inference_2k_schnell.ipynb`; a minimal sketch is also included below.
111 |
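A minimal sketch of the schnell setting, following `inference_2k_schnell.ipynb` (fewer steps, `guidance_scale=0`, and `max_sequence_length=256`; the 2K adapter and the `ntk_factor`/`proportional_attention` settings follow the notebook):

```python
import torch
from pipeline_flux import FluxPipeline
from transformer_flux import FluxTransformer2DModel

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)
pipe.scheduler.config.use_dynamic_shifting = False
pipe.scheduler.config.time_shift = 10
pipe = pipe.to("cuda")

# 2K LoRA adapter from the Huage001/URAE Hugging Face repo
pipe.load_lora_weights("Huage001/URAE", weight_name="urae_2k_adapter.safetensors")

image = pipe(
    "An astronaut riding a green horse",
    height=2048,
    width=2048,
    guidance_scale=0,          # guidance is not used for FLUX.1-schnell
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.manual_seed(0),
    ntk_factor=10,
    proportional_attention=True,
).images[0]
image.save("flux-urae-2k-schnell.png")
```
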
112 | #### 4K Resolution or Higher
113 |
114 | * Instead of LoRA, we use **minor-component adapters** at the 4K stage. You can try the models for FLUX 1.dev and FLUX 1.schnell in `inference_4k.ipynb` and `inference_4k_schnell.ipynb`.
115 | * Alternatively, although these adapters are different from LoRA, they can also be converted to the LoRA format, so that the interfaces provided by `peft` can be used for more convenient loading, as sketched below. **If you only want to try the models rather than understand the principle,** `inference_4k_lora_conversion.ipynb` **and** `inference_4k_lora_conversion_schnell.ipynb` **are what you want!** Their outputs should be equivalent to those of the notebooks above.
116 |
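A minimal sketch of the LoRA-conversion route for FLUX 1.dev, following `inference_4k_lora_conversion.ipynb` (the adapter filename comes from the `Huage001/URAE` Hugging Face repo; `patch_conv` keeps 4K VAE decoding within memory):

```python
import torch
from huggingface_hub import hf_hub_download
from patch_conv import convert_model
from pipeline_flux import FluxPipeline
from transformer_flux import FluxTransformer2DModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)
pipe.scheduler.config.use_dynamic_shifting = False
pipe.scheduler.config.time_shift = 10
pipe.enable_model_cpu_offload()

# 4K minor-component adapter, already converted to the standard LoRA format
adapter_path = hf_hub_download(repo_id="Huage001/URAE", filename="urae_4k_adapter_lora_conversion_dev.safetensors", local_dir="ckpt")
pipe.load_lora_weights(adapter_path)

# Patch-wise convolution for the VAE to avoid OOM when decoding 4K latents;
# if it still OOMs, `pipe.vae.enable_tiling()` is an alternative.
pipe.vae = convert_model(pipe.vae, splits=4)

image = pipe(
    "An astronaut riding a green horse",
    height=4096,
    width=4096,
    guidance_scale=3.5,
    num_inference_steps=28,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
    ntk_factor=10,
    proportional_attention=True,
).images[0]
image.save("flux-urae-4k.png")
```
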
117 | 🚧 The 4K model is still in beta and its performance may not be stable. A recommended use case for the current 4K model is to integrate it with training-free high-resolution generation pipelines based on coarse-to-fine strategies, such as [SDEdit](https://arxiv.org/abs/2108.01073) (refer to [this repo](https://github.com/Huage001/CLEAR/blob/main/inference_t2i_highres.ipynb) for a sample usage) and [I-Max](https://github.com/PRIS-CV/I-Max), loading the 4K adapter at the high-resolution stages.
118 |
119 |
120 | ## Training
121 |
122 | 1. Prepare training data.
123 |
124 |    * To train the 2K model, we collected 3,000 images generated by [FLUX1.1 Pro Ultra](https://blackforestlabs.ai/ultra-home/). You can use your own API access and follow the official instructions to acquire such images.
125 |
126 | * To train the 4K model, we use ~16,000 images with resolution greater than 4K from [LAION-High-Resolution](https://huggingface.co/datasets/laion/laion-high-resolution).
127 |
128 |    * The training data folder should be organized in the following format:
129 |
130 | ```
131 | train_data
132 | ├── image_0.jpg
133 | ├── image_0.json
134 | ├── image_1.jpg
135 | ├── image_1.json
136 | ├── ...
137 | ```
138 |
139 |        Each json file should contain a dictionary with entries `prompt` and/or `generated_prompt`, holding the original caption and a more detailed caption generated by a VLM such as GPT-4o or Llama 3, respectively (see the sketch at the end of this list).
140 |
141 |    * The T5 text encoder and the VAE can take a large amount of GPU memory, which can trigger OOM errors when training at high resolutions. Therefore, we pre-cache the T5 prompt embeddings and VAE latent codes instead of computing them online:
142 |
143 | ```bash
144 | bash cache_prompt_embeds.sh
145 | bash cache_latent_codes.sh
146 | ```
147 |
148 |      Make sure to modify these bash files beforehand to configure the number of parallel processes (`$NUM_WORKERS`), the training data folder (`$DATA_DIR`), the target resolution (`--resolution`), and the prompt key in the json files (`--column`, either `prompt` or `generated_prompt`).
149 |
150 | * The final format of the training data folder should be:
151 |
152 | ```
153 | train_data
154 | ├── image_0_generated_prompt_embed.safetensors
155 | ├── image_0.jpg
156 | ├── image_0.json
157 | ├── image_0_latent_code.safetensors
158 | ├── image_0_prompt_embed.safetensors
159 | ├── image_1_generated_prompt_embed.safetensors
160 | ├── image_1.jpg
161 | ├── image_1.json
162 | ├── image_1_latent_code.safetensors
163 | ├── image_1_prompt_embed.safetensors
164 | ├── ...
165 | ```
166 |
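   For concreteness, here is a sketch of what one sample looks like on disk and how the cached features can be read back. The safetensors keys follow `cache_latent_codes.py` and `cache_prompt_embeds.py`, and the file names follow the listing above:

   ```python
   import json
   import torch
   from safetensors.torch import load_file

   # Caption file paired with image_0.jpg; the `generated_prompt` entry is optional.
   with open("train_data/image_0.json", "w") as f:
       json.dump({
           "prompt": "a serene woman in a flowing azure dress on a sunlit cliff",
           "generated_prompt": "A more detailed, VLM-generated description of the same image ...",
       }, f)

   # After running the caching scripts, each sample also has cached features.
   latent = load_file("train_data/image_0_latent_code.safetensors")    # keys: 'mean', 'std'
   embeds = load_file("train_data/image_0_prompt_embed.safetensors")   # keys: 'caption_feature_t5', 'caption_feature_clip'

   # A latent code can be sampled from the cached Gaussian parameters.
   latent_code = latent["mean"] + latent["std"] * torch.randn_like(latent["std"])
   ```
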
167 | 2. Let's start training! For the 2K model, make sure to modify `train_2k.sh` to configure the training data folder (`$DATA_DIR`), the output folder (`$OUTPUT_DIR`), and the number of GPUs (`--num_processes`).
168 |
169 | ```bash
170 | bash train_2k.sh
171 | ```
172 |
173 | 3. The 4K model is based on the 2K LoRA trained previously. Make sure to modify `train_4k.sh` to configure its path (`--pretrained_lora`) in addition to the aforementioned items.
174 |
175 | ```bash
176 | bash train_4k.sh
177 | ```
178 |
179 | ⚠️ Currently, if the training images have different resolutions, `--train_batch_size` can only be 1 because we did not customize the data sampler to handle this case. Nevertheless, in practice, at such large resolutions, the maximum batch size per GPU (before gradient accumulation) is only 1 even on 80GB GPUs 😂.
180 |
181 | ## Acknowledgement
182 |
183 | * [FLUX](https://blackforestlabs.ai/announcing-black-forest-labs/) for the source models.
184 | * [diffusers](https://github.com/huggingface/diffusers), [CLEAR](https://github.com/Huage001/CLEAR), and [I-Max](https://github.com/PRIS-CV/I-Max) for the code base.
185 | * [patch_conv](https://github.com/mit-han-lab/patch_conv) for the solution to the VAE OOM error at high resolutions.
186 | * [@Xinyin Ma](https://github.com/horseee) for valuable discussions.
187 | * NUS IT’s Research Computing group for computing resources under grant NUSREC-HPC-00001.
188 |
189 | ## Citation
190 |
191 | If you find this repo helpful, please consider citing:
192 |
193 | ```bib
194 | @article{yu2025urae,
195 | title = {Ultra-Resolution Adaptation with Ease},
196 | author = {Yu, Ruonan and Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
197 | journal = {International Conference on Machine Learning},
198 | year = {2025},
199 | }
200 | ```
201 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/inference_4k_schnell.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel\n",
14 | "from attention_processor import FluxAttnAdaptationProcessor2_0\n",
15 | "from safetensors.torch import load_file, save_file\n",
16 | "from patch_conv import convert_model"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "bfl_repo=\"black-forest-labs/FLUX.1-schnell\"\n",
26 | "device = torch.device('cuda')\n",
27 | "dtype = torch.bfloat16\n",
28 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\")\n",
29 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n",
30 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
31 | "pipe.scheduler.config.time_shift = 10\n",
32 | "pipe.enable_model_cpu_offload()"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "* 4K model is based on 2K LoRA"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n",
49 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
50 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")\n",
51 | "pipe.fuse_lora()"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "* Substitute original attention processors"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "rank = 16\n",
68 | "attn_processors = {}\n",
69 | "for k in pipe.transformer.attn_processors.keys():\n",
70 | " attn_processors[k] = FluxAttnAdaptationProcessor2_0(rank=rank, to_out='single' not in k).to(device, dtype)\n",
71 | "pipe.transformer.set_attn_processor(attn_processors)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "* If no cached major components, compute them via SVD and save them to cache_path\n",
79 | "* If you don't want to save cached major components, simply set `cache_path = None`"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "cache_path = 'ckpt/_urae_4k_adapter_schnell.safetensors'\n",
89 | "if cache_path is not None and os.path.exists(cache_path):\n",
90 | " pipe.transformer.to(dtype=dtype)\n",
91 | " pipe.transformer.load_state_dict(load_file(cache_path), strict=False)\n",
92 | "else:\n",
93 | " with torch.no_grad():\n",
94 | " for idx in range(len(pipe.transformer.transformer_blocks)):\n",
95 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data.to(device)\n",
96 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
97 | " pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data = (\n",
98 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
99 | " ).to('cpu')\n",
100 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n",
101 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
102 | " ).to('cpu')\n",
103 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n",
104 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
105 | " ).to('cpu')\n",
106 | "\n",
107 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data.to(device)\n",
108 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
109 | " pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data = (\n",
110 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
111 | " ).to('cpu')\n",
112 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n",
113 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
114 | " ).to('cpu')\n",
115 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n",
116 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
117 | " ).to('cpu')\n",
118 | "\n",
119 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data.to(device)\n",
120 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
121 | " pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data = (\n",
122 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
123 | " ).to('cpu')\n",
124 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n",
125 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
126 | " ).to('cpu')\n",
127 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n",
128 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
129 | " ).to('cpu')\n",
130 | "\n",
131 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data.to(device)\n",
132 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
133 | " pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data = (\n",
134 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
135 | " ).to('cpu')\n",
136 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_b.weight.data = (\n",
137 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
138 | " ).to('cpu')\n",
139 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_a.weight.data = (\n",
140 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
141 | " ).to('cpu')\n",
142 | " for idx in range(len(pipe.transformer.single_transformer_blocks)):\n",
143 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data.to(device)\n",
144 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
145 | " pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data = (\n",
146 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
147 | " ).to('cpu')\n",
148 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n",
149 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
150 | " ).to('cpu')\n",
151 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n",
152 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
153 | " ).to('cpu')\n",
154 | "\n",
155 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data.to(device)\n",
156 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
157 | " pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data = (\n",
158 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
159 | " ).to('cpu')\n",
160 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n",
161 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
162 | " ).to('cpu')\n",
163 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n",
164 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
165 | " ).to('cpu')\n",
166 | "\n",
167 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data.to(device)\n",
168 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
169 | " pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data = (\n",
170 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
171 | " ).to('cpu')\n",
172 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n",
173 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
174 | " ).to('cpu')\n",
175 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n",
176 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
177 | " ).to('cpu')\n",
178 | " pipe.transformer.to(dtype=dtype)\n",
179 | " if cache_path is not None:\n",
180 | " state_dict = pipe.transformer.state_dict()\n",
181 | " attn_state_dict = {}\n",
182 | " for k in state_dict.keys():\n",
183 | " if 'base_layer' in k:\n",
184 | " attn_state_dict[k] = state_dict[k]\n",
185 | " save_file(attn_state_dict, cache_path)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "* Download pre-trained 4k adapter"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "if not os.path.exists('ckpt/urae_4k_adapter.safetensors'):\n",
202 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "* Optionally, you can convert the minor-component adapter into a LoRA for easier use"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "lora_conversion = True\n",
219 | "if lora_conversion and not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors'):\n",
220 | " cur = pipe.transformer.state_dict()\n",
221 | " tgt = load_file('ckpt/urae_4k_adapter.safetensors')\n",
222 | " ref = load_file('ckpt/urae_2k_adapter.safetensors')\n",
223 | " new_ckpt = {}\n",
224 | " for k in tgt.keys():\n",
225 | " if 'to_k_a' in k:\n",
226 | " k_ = 'transformer.' + k.replace('.processor.to_k_a', '.to_k.lora_A')\n",
227 | " elif 'to_k_b' in k:\n",
228 | " k_ = 'transformer.' + k.replace('.processor.to_k_b', '.to_k.lora_B')\n",
229 | " elif 'to_q_a' in k:\n",
230 | " k_ = 'transformer.' + k.replace('.processor.to_q_a', '.to_q.lora_A')\n",
231 | " elif 'to_q_b' in k:\n",
232 | " k_ = 'transformer.' + k.replace('.processor.to_q_b', '.to_q.lora_B')\n",
233 | " elif 'to_v_a' in k:\n",
234 | " k_ = 'transformer.' + k.replace('.processor.to_v_a', '.to_v.lora_A')\n",
235 | " elif 'to_v_b' in k:\n",
236 | " k_ = 'transformer.' + k.replace('.processor.to_v_b', '.to_v.lora_B')\n",
237 | " elif 'to_out_a' in k:\n",
238 | " k_ = 'transformer.' + k.replace('.processor.to_out_a', '.to_out.0.lora_A')\n",
239 | " elif 'to_out_b' in k:\n",
240 | " k_ = 'transformer.' + k.replace('.processor.to_out_b', '.to_out.0.lora_B')\n",
241 | " else:\n",
242 | " print(k)\n",
243 | " assert False\n",
244 | " if '_a.' in k and '_b.' not in k:\n",
245 | " new_ckpt[k_] = torch.cat([-cur[k], tgt[k], ref[k_]], dim=0)\n",
246 | " elif '_b.' in k and '_a.' not in k:\n",
247 | " new_ckpt[k_] = torch.cat([cur[k], tgt[k], ref[k_]], dim=1)\n",
248 | " else:\n",
249 | " print(k)\n",
250 | " assert False\n",
251 | " save_file(new_ckpt, 'ckpt/urae_4k_adapter_lora_conversion_schnell.safetensors')"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "* Load state_dict of 4k adapter"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "m, u = pipe.transformer.load_state_dict(load_file('ckpt/urae_4k_adapter.safetensors'), strict=False)\n",
268 | "assert len(u) == 0"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "metadata": {},
274 | "source": [
275 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n",
276 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "pipe.vae = convert_model(pipe.vae, splits=4)"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "* Everything ready. Let's generate!"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [],
300 | "source": [
301 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
302 | "height = 4096\n",
303 | "width = 4096\n",
304 | "image = pipe(\n",
305 | " prompt,\n",
306 | " height=height,\n",
307 | " width=width,\n",
308 | " guidance_scale=0,\n",
309 | " num_inference_steps=4,\n",
310 | " max_sequence_length=256,\n",
311 | " generator=torch.manual_seed(8888),\n",
312 | " ntk_factor=10,\n",
313 | " proportional_attention=True\n",
314 | ").images[0]\n",
315 | "image"
316 | ]
317 | }
318 | ],
319 | "metadata": {
320 | "kernelspec": {
321 | "display_name": "Python 3 (ipykernel)",
322 | "language": "python",
323 | "name": "python3"
324 | },
325 | "language_info": {
326 | "codemirror_mode": {
327 | "name": "ipython",
328 | "version": 3
329 | },
330 | "file_extension": ".py",
331 | "mimetype": "text/x-python",
332 | "name": "python",
333 | "nbconvert_exporter": "python",
334 | "pygments_lexer": "ipython3",
335 | "version": "3.12.9"
336 | }
337 | },
338 | "nbformat": 4,
339 | "nbformat_minor": 2
340 | }
341 |
--------------------------------------------------------------------------------
/inference_4k.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import torch\n",
11 | "from huggingface_hub import hf_hub_download\n",
12 | "from pipeline_flux import FluxPipeline\n",
13 | "from transformer_flux import FluxTransformer2DModel\n",
14 | "from attention_processor import FluxAttnAdaptationProcessor2_0\n",
15 | "from safetensors.torch import load_file, save_file\n",
16 | "from patch_conv import convert_model"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
26 | "device = torch.device('cuda')\n",
27 | "dtype = torch.bfloat16\n",
28 | "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\")\n",
29 | "pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=transformer, torch_dtype=dtype)\n",
30 | "pipe.scheduler.config.use_dynamic_shifting = False\n",
31 | "pipe.scheduler.config.time_shift = 10\n",
32 | "pipe.enable_model_cpu_offload()"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "* 4K model is based on 2K LoRA"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "if not os.path.exists('ckpt/urae_2k_adapter.safetensors'):\n",
49 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_2k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)\n",
50 | "pipe.load_lora_weights(\"ckpt/urae_2k_adapter.safetensors\")\n",
51 | "pipe.fuse_lora()"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "* Substitute original attention processors"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "rank = 16\n",
68 | "attn_processors = {}\n",
69 | "for k in pipe.transformer.attn_processors.keys():\n",
70 | " attn_processors[k] = FluxAttnAdaptationProcessor2_0(rank=rank, to_out='single' not in k)\n",
71 | "pipe.transformer.set_attn_processor(attn_processors)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "* If no cached major components, compute them via SVD and save them to cache_path\n",
79 | "* If you don't want to save cached major components, simply set `cache_path = None`"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "cache_path = 'ckpt/_urae_4k_adapter_dev.safetensors'\n",
89 | "if cache_path is not None and os.path.exists(cache_path):\n",
90 | " pipe.transformer.to(dtype=dtype)\n",
91 | " pipe.transformer.load_state_dict(load_file(cache_path), strict=False)\n",
92 | "else:\n",
93 | " with torch.no_grad():\n",
94 | " for idx in range(len(pipe.transformer.transformer_blocks)):\n",
95 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data.to(device)\n",
96 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
97 | " pipe.transformer.transformer_blocks[idx].attn.to_q.weight.data = (\n",
98 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
99 | " ).to('cpu')\n",
100 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n",
101 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
102 | " ).to('cpu')\n",
103 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n",
104 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
105 | " ).to('cpu')\n",
106 | "\n",
107 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data.to(device)\n",
108 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
109 | " pipe.transformer.transformer_blocks[idx].attn.to_k.weight.data = (\n",
110 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
111 | " ).to('cpu')\n",
112 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n",
113 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
114 | " ).to('cpu')\n",
115 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n",
116 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
117 | " ).to('cpu')\n",
118 | "\n",
119 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data.to(device)\n",
120 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
121 | " pipe.transformer.transformer_blocks[idx].attn.to_v.weight.data = (\n",
122 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
123 | " ).to('cpu')\n",
124 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n",
125 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
126 | " ).to('cpu')\n",
127 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n",
128 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
129 | " ).to('cpu')\n",
130 | "\n",
131 | " matrix_w = pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data.to(device)\n",
132 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
133 | " pipe.transformer.transformer_blocks[idx].attn.to_out[0].weight.data = (\n",
134 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
135 | " ).to('cpu')\n",
136 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_b.weight.data = (\n",
137 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
138 | " ).to('cpu')\n",
139 | " pipe.transformer.transformer_blocks[idx].attn.processor.to_out_a.weight.data = (\n",
140 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
141 | " ).to('cpu')\n",
142 | " for idx in range(len(pipe.transformer.single_transformer_blocks)):\n",
143 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data.to(device)\n",
144 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
145 | " pipe.transformer.single_transformer_blocks[idx].attn.to_q.weight.data = (\n",
146 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
147 | " ).to('cpu')\n",
148 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_b.weight.data = (\n",
149 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
150 | " ).to('cpu')\n",
151 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_q_a.weight.data = (\n",
152 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
153 | " ).to('cpu')\n",
154 | "\n",
155 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data.to(device)\n",
156 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
157 | " pipe.transformer.single_transformer_blocks[idx].attn.to_k.weight.data = (\n",
158 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
159 | " ).to('cpu')\n",
160 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_b.weight.data = (\n",
161 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
162 | " ).to('cpu')\n",
163 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_k_a.weight.data = (\n",
164 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
165 | " ).to('cpu')\n",
166 | "\n",
167 | " matrix_w = pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data.to(device)\n",
168 | " matrix_u, matrix_s, matrix_v = torch.linalg.svd(matrix_w)\n",
169 | " pipe.transformer.single_transformer_blocks[idx].attn.to_v.weight.data = (\n",
170 | " matrix_u[:, :-rank] @ torch.diag(matrix_s[:-rank]) @ matrix_v[:-rank, :]\n",
171 | " ).to('cpu')\n",
172 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_b.weight.data = (\n",
173 | " matrix_u[:, -rank:] @ torch.diag(torch.sqrt(matrix_s[-rank:]))\n",
174 | " ).to('cpu')\n",
175 | " pipe.transformer.single_transformer_blocks[idx].attn.processor.to_v_a.weight.data = (\n",
176 | " torch.diag(torch.sqrt(matrix_s[-rank:])) @ matrix_v[-rank:, :]\n",
177 | " ).to('cpu')\n",
178 | " pipe.transformer.to(dtype=dtype)\n",
179 | " if cache_path is not None:\n",
180 | " state_dict = pipe.transformer.state_dict()\n",
181 | " attn_state_dict = {}\n",
182 | " for k in state_dict.keys():\n",
183 | " if 'base_layer' in k:\n",
184 | " attn_state_dict[k] = state_dict[k]\n",
185 | " save_file(attn_state_dict, cache_path)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "* Download pre-trained 4k adapter"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "if not os.path.exists('ckpt/urae_4k_adapter.safetensors'):\n",
202 | " hf_hub_download(repo_id=\"Huage001/URAE\", filename='urae_4k_adapter.safetensors', local_dir='ckpt', local_dir_use_symlinks=False)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "* Optionally, you can convert the minor-component adapter into a LoRA for easier use"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "lora_conversion = True\n",
219 | "if lora_conversion and not os.path.exists('ckpt/urae_4k_adapter_lora_conversion_dev.safetensors'):\n",
220 | " cur = pipe.transformer.state_dict()\n",
221 | " tgt = load_file('ckpt/urae_4k_adapter.safetensors')\n",
222 | " ref = load_file('ckpt/urae_2k_adapter.safetensors')\n",
223 | " new_ckpt = {}\n",
224 | " for k in tgt.keys():\n",
225 | " if 'to_k_a' in k:\n",
226 | " k_ = 'transformer.' + k.replace('.processor.to_k_a', '.to_k.lora_A')\n",
227 | " elif 'to_k_b' in k:\n",
228 | " k_ = 'transformer.' + k.replace('.processor.to_k_b', '.to_k.lora_B')\n",
229 | " elif 'to_q_a' in k:\n",
230 | " k_ = 'transformer.' + k.replace('.processor.to_q_a', '.to_q.lora_A')\n",
231 | " elif 'to_q_b' in k:\n",
232 | " k_ = 'transformer.' + k.replace('.processor.to_q_b', '.to_q.lora_B')\n",
233 | " elif 'to_v_a' in k:\n",
234 | " k_ = 'transformer.' + k.replace('.processor.to_v_a', '.to_v.lora_A')\n",
235 | " elif 'to_v_b' in k:\n",
236 | " k_ = 'transformer.' + k.replace('.processor.to_v_b', '.to_v.lora_B')\n",
237 | " elif 'to_out_a' in k:\n",
238 | " k_ = 'transformer.' + k.replace('.processor.to_out_a', '.to_out.0.lora_A')\n",
239 | " elif 'to_out_b' in k:\n",
240 | " k_ = 'transformer.' + k.replace('.processor.to_out_b', '.to_out.0.lora_B')\n",
241 | " else:\n",
242 | " print(k)\n",
243 | " assert False\n",
244 | " if '_a.' in k and '_b.' not in k:\n",
245 | " new_ckpt[k_] = torch.cat([-cur[k], tgt[k], ref[k_]], dim=0)\n",
246 | " elif '_b.' in k and '_a.' not in k:\n",
247 | " new_ckpt[k_] = torch.cat([cur[k], tgt[k], ref[k_]], dim=1)\n",
248 | " else:\n",
249 | " print(k)\n",
250 | " assert False\n",
251 | " save_file(new_ckpt, 'ckpt/urae_4k_adapter_lora_conversion_dev.safetensors')"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {},
257 | "source": [
258 | "* Load state_dict of 4k adapter"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "m, u = pipe.transformer.load_state_dict(load_file('ckpt/urae_4k_adapter.safetensors'), strict=False)\n",
268 | "assert len(u) == 0"
269 | ]
270 | },
271 | {
272 | "cell_type": "markdown",
273 | "metadata": {},
274 | "source": [
275 | "* Use patch-wise convolution for VAE to avoid OOM error when decoding\n",
276 | "* If still OOM, try replacing the following line with `pipe.vae.enable_tiling()`"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {},
283 | "outputs": [],
284 | "source": [
285 | "pipe.vae = convert_model(pipe.vae, splits=4)"
286 | ]
287 | },
288 | {
289 | "cell_type": "markdown",
290 | "metadata": {},
291 | "source": [
292 | "* Everything ready. Let's generate!\n",
293 |     "* 4K generation using FLUX.1-dev can take a while, e.g., 5 min on an H100."
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "metadata": {},
300 | "outputs": [],
301 | "source": [
302 | "prompt = \"A serene woman in a flowing azure dress, gracefully perched on a sunlit cliff overlooking a tranquil sea, her hair gently tousled by the breeze. The scene is infused with a sense of peace, evoking a dreamlike atmosphere, reminiscent of Impressionist paintings.\"\n",
303 | "height = 4096\n",
304 | "width = 4096\n",
305 | "image = pipe(\n",
306 | " prompt,\n",
307 | " height=height,\n",
308 | " width=width,\n",
309 | " guidance_scale=3.5,\n",
310 | " num_inference_steps=28,\n",
311 | " max_sequence_length=512,\n",
312 | " generator=torch.manual_seed(8888),\n",
313 | " ntk_factor=10,\n",
314 | " proportional_attention=True\n",
315 | ").images[0]\n",
316 | "image"
317 | ]
318 | }
319 | ],
320 | "metadata": {
321 | "kernelspec": {
322 | "display_name": "Python 3 (ipykernel)",
323 | "language": "python",
324 | "name": "python3"
325 | },
326 | "language_info": {
327 | "codemirror_mode": {
328 | "name": "ipython",
329 | "version": 3
330 | },
331 | "file_extension": ".py",
332 | "mimetype": "text/x-python",
333 | "name": "python",
334 | "nbconvert_exporter": "python",
335 | "pygments_lexer": "ipython3",
336 | "version": "3.12.9"
337 | }
338 | },
339 | "nbformat": 4,
340 | "nbformat_minor": 2
341 | }
342 |
--------------------------------------------------------------------------------
/transformer_flux.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union, List
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import (
12 | Attention,
13 | AttentionProcessor
14 | )
15 | from diffusers.models.modeling_utils import ModelMixin
16 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
17 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
18 | from diffusers.utils.torch_utils import maybe_allow_in_graph
19 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed
20 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
21 | from attention_processor import FluxAttnProcessor2_0
22 |
23 |
24 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25 |
26 |
27 | @maybe_allow_in_graph
28 | class FluxSingleTransformerBlock(nn.Module):
29 | r"""
30 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
31 |
32 | Reference: https://arxiv.org/abs/2403.03206
33 |
34 | Parameters:
35 | dim (`int`): The number of channels in the input and output.
36 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
37 | attention_head_dim (`int`): The number of channels in each head.
38 |         mlp_ratio (`float`, defaults to 4.0): The ratio of the hidden dimension of the MLP to the
39 |             input/output dimension `dim`.
40 | """
41 |
42 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
43 | super().__init__()
44 | self.mlp_hidden_dim = int(dim * mlp_ratio)
45 |
46 | self.norm = AdaLayerNormZeroSingle(dim)
47 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
48 | self.act_mlp = nn.GELU(approximate="tanh")
49 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
50 |
51 | processor = FluxAttnProcessor2_0()
52 | self.attn = Attention(
53 | query_dim=dim,
54 | cross_attention_dim=None,
55 | dim_head=attention_head_dim,
56 | heads=num_attention_heads,
57 | out_dim=dim,
58 | bias=True,
59 | processor=processor,
60 | qk_norm="rms_norm",
61 | eps=1e-6,
62 | pre_only=True,
63 | )
64 |
65 | def forward(
66 | self,
67 | hidden_states: torch.FloatTensor,
68 | temb: torch.FloatTensor,
69 | image_rotary_emb=None,
70 | joint_attention_kwargs=None
71 | ):
72 | residual = hidden_states
73 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
74 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
75 | joint_attention_kwargs = joint_attention_kwargs or {}
76 | attn_output = self.attn(
77 | hidden_states=norm_hidden_states,
78 | image_rotary_emb=image_rotary_emb,
79 | **joint_attention_kwargs,
80 | )
81 |
82 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
83 | gate = gate.unsqueeze(1)
84 | hidden_states = gate * self.proj_out(hidden_states)
85 | hidden_states = residual + hidden_states
86 | if hidden_states.dtype == torch.float16:
87 | hidden_states = hidden_states.clip(-65504, 65504)
88 |
89 | return hidden_states
90 |
91 |
92 | @maybe_allow_in_graph
93 | class FluxTransformerBlock(nn.Module):
94 | r"""
95 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
96 |
97 | Reference: https://arxiv.org/abs/2403.03206
98 |
99 | Parameters:
100 | dim (`int`): The number of channels in the input and output.
101 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
102 | attention_head_dim (`int`): The number of channels in each head.
103 |         qk_norm (`str`, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
104 |         eps (`float`, defaults to 1e-6): The epsilon used by the query/key normalization layers.
105 | """
106 |
107 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
108 | super().__init__()
109 |
110 | self.norm1 = AdaLayerNormZero(dim)
111 |
112 | self.norm1_context = AdaLayerNormZero(dim)
113 |
114 | if hasattr(F, "scaled_dot_product_attention"):
115 | processor = FluxAttnProcessor2_0()
116 | else:
117 | raise ValueError(
118 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
119 | )
120 | self.attn = Attention(
121 | query_dim=dim,
122 | cross_attention_dim=None,
123 | added_kv_proj_dim=dim,
124 | dim_head=attention_head_dim,
125 | heads=num_attention_heads,
126 | out_dim=dim,
127 | context_pre_only=False,
128 | bias=True,
129 | processor=processor,
130 | qk_norm=qk_norm,
131 | eps=eps,
132 | )
133 |
134 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
135 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
136 |
137 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
138 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
139 |
140 | # let chunk size default to None
141 | self._chunk_size = None
142 | self._chunk_dim = 0
143 |
144 | def forward(
145 | self,
146 | hidden_states: torch.FloatTensor,
147 | encoder_hidden_states: torch.FloatTensor,
148 | temb: torch.FloatTensor,
149 | image_rotary_emb=None,
150 | joint_attention_kwargs=None,
151 | ):
152 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
153 |
154 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
155 | encoder_hidden_states, emb=temb
156 | )
157 | joint_attention_kwargs = joint_attention_kwargs or {}
158 | # Attention.
159 | attn_output, context_attn_output = self.attn(
160 | hidden_states=norm_hidden_states,
161 | encoder_hidden_states=norm_encoder_hidden_states,
162 | image_rotary_emb=image_rotary_emb,
163 | **joint_attention_kwargs,
164 | )
165 |
166 | # Process attention outputs for the `hidden_states`.
167 | attn_output = gate_msa.unsqueeze(1) * attn_output
168 | hidden_states = hidden_states + attn_output
169 |
170 | norm_hidden_states = self.norm2(hidden_states)
171 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
172 |
173 | ff_output = self.ff(norm_hidden_states)
174 | ff_output = gate_mlp.unsqueeze(1) * ff_output
175 |
176 | hidden_states = hidden_states + ff_output
177 |
178 | # Process attention outputs for the `encoder_hidden_states`.
179 |
180 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
181 | encoder_hidden_states = encoder_hidden_states + context_attn_output
182 |
183 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
184 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
185 |
186 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
187 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
188 | if encoder_hidden_states.dtype == torch.float16:
189 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
190 |
191 | return encoder_hidden_states, hidden_states
192 |
193 |
194 | class FluxPosEmbed(nn.Module):
195 | # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
196 | def __init__(self, theta: int, axes_dim: List[int]):
197 | super().__init__()
198 | self.theta = theta
199 | self.axes_dim = axes_dim
200 |
201 |     def forward(self, ids: torch.Tensor, ntk_factor=1) -> Tuple[torch.Tensor, torch.Tensor]:  # ntk_factor: NTK-aware RoPE scaling for high-resolution inference
202 | n_axes = ids.shape[-1]
203 | cos_out = []
204 | sin_out = []
205 | pos = ids.float()
206 | is_mps = ids.device.type == "mps"
207 | freqs_dtype = torch.float32 if is_mps else torch.float64
208 | for i in range(n_axes):
209 | cos, sin = get_1d_rotary_pos_embed(
210 | self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype,
211 | ntk_factor=ntk_factor
212 | )
213 | cos_out.append(cos)
214 | sin_out.append(sin)
215 | freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
216 | freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
217 | return freqs_cos, freqs_sin
218 |
219 |
220 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
221 | """
222 | The Transformer model introduced in Flux.
223 |
224 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
225 |
226 | Parameters:
227 | patch_size (`int`): Patch size to turn the input data into small patches.
228 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
229 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
230 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
231 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
232 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
233 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
234 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
235 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
236 | """
237 |
238 | _supports_gradient_checkpointing = True
239 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
240 |
241 | @register_to_config
242 | def __init__(
243 | self,
244 | patch_size: int = 1,
245 | in_channels: int = 64,
246 | num_layers: int = 19,
247 | num_single_layers: int = 38,
248 | attention_head_dim: int = 128,
249 | num_attention_heads: int = 24,
250 | joint_attention_dim: int = 4096,
251 | pooled_projection_dim: int = 768,
252 | guidance_embeds: bool = False,
253 | axes_dims_rope: Tuple[int] = (16, 56, 56),
254 | ):
255 | super().__init__()
256 | self.out_channels = in_channels
257 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
258 |
259 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
260 |
261 | text_time_guidance_cls = (
262 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
263 | )
264 | self.time_text_embed = text_time_guidance_cls(
265 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
266 | )
267 |
268 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
269 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
270 |
271 | self.transformer_blocks = nn.ModuleList(
272 | [
273 | FluxTransformerBlock(
274 | dim=self.inner_dim,
275 | num_attention_heads=self.config.num_attention_heads,
276 | attention_head_dim=self.config.attention_head_dim,
277 | )
278 | for i in range(self.config.num_layers)
279 | ]
280 | )
281 |
282 | self.single_transformer_blocks = nn.ModuleList(
283 | [
284 | FluxSingleTransformerBlock(
285 | dim=self.inner_dim,
286 | num_attention_heads=self.config.num_attention_heads,
287 | attention_head_dim=self.config.attention_head_dim,
288 | )
289 | for i in range(self.config.num_single_layers)
290 | ]
291 | )
292 |
293 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
294 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
295 |
296 | self.gradient_checkpointing = False
297 |
298 | @property
299 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
300 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
301 | r"""
302 | Returns:
303 | `dict` of attention processors: A dictionary containing all attention processors used in the model with
304 | indexed by its weight name.
305 | """
306 | # set recursively
307 | processors = {}
308 |
309 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
310 | if hasattr(module, "get_processor"):
311 | processors[f"{name}.processor"] = module.get_processor()
312 |
313 | for sub_name, child in module.named_children():
314 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
315 |
316 | return processors
317 |
318 | for name, module in self.named_children():
319 | fn_recursive_add_processors(name, module, processors)
320 |
321 | return processors
322 |
323 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
324 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
325 | r"""
326 | Sets the attention processor to use to compute attention.
327 |
328 | Parameters:
329 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
330 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
331 | for **all** `Attention` layers.
332 |
333 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
334 | processor. This is strongly recommended when setting trainable attention processors.
335 |
336 | """
337 | count = len(self.attn_processors.keys())
338 |
339 | if isinstance(processor, dict) and len(processor) != count:
340 | raise ValueError(
341 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
342 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
343 | )
344 |
345 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
346 | if hasattr(module, "set_processor"):
347 | if not isinstance(processor, dict):
348 | module.set_processor(processor)
349 | else:
350 | module.set_processor(processor.pop(f"{name}.processor"))
351 |
352 | for sub_name, child in module.named_children():
353 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
354 |
355 | for name, module in self.named_children():
356 | fn_recursive_attn_processor(name, module, processor)
357 |
358 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
359 | def unfuse_qkv_projections(self):
360 | """Disables the fused QKV projection if enabled.
361 |
362 |
363 |
364 | This API is 🧪 experimental.
365 |
366 |
367 |
368 | """
369 | if self.original_attn_processors is not None:
370 | self.set_attn_processor(self.original_attn_processors)
371 |
372 | def _set_gradient_checkpointing(self, module, value=False):
373 | if hasattr(module, "gradient_checkpointing"):
374 | module.gradient_checkpointing = value
375 |
376 | def forward(
377 | self,
378 | hidden_states: torch.Tensor,
379 | encoder_hidden_states: torch.Tensor = None,
380 | pooled_projections: torch.Tensor = None,
381 | timestep: torch.LongTensor = None,
382 | img_ids: torch.Tensor = None,
383 | txt_ids: torch.Tensor = None,
384 | guidance: torch.Tensor = None,
385 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
386 | controlnet_block_samples=None,
387 | controlnet_single_block_samples=None,
388 | return_dict: bool = True,
389 | ntk_factor: float = 1,
390 | controlnet_blocks_repeat: bool = False,
391 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
392 | """
393 | The [`FluxTransformer2DModel`] forward method.
394 |
395 | Args:
396 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
397 | Input `hidden_states`.
398 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
399 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
400 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
401 | from the embeddings of input conditions.
402 | timestep ( `torch.LongTensor`):
403 | Used to indicate denoising step.
404 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
405 | A list of tensors that if specified are added to the residuals of transformer blocks.
406 | joint_attention_kwargs (`dict`, *optional*):
407 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
408 | `self.processor` in
409 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
410 | return_dict (`bool`, *optional*, defaults to `True`):
411 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
412 | tuple.
413 |
414 | Returns:
415 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
416 | `tuple` where the first element is the sample tensor.
417 | """
418 |
419 | if txt_ids.ndim == 3:
420 | logger.warning(
421 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
422 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
423 | )
424 | txt_ids = txt_ids[0]
425 | if img_ids.ndim == 3:
426 | logger.warning(
427 | "Passing `img_ids` 3d torch.Tensor is deprecated."
428 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
429 | )
430 | img_ids = img_ids[0]
431 |
432 | if joint_attention_kwargs is not None:
433 | joint_attention_kwargs = joint_attention_kwargs.copy()
434 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
435 | else:
436 | lora_scale = 1.0
437 |
438 | if USE_PEFT_BACKEND:
439 | # weight the lora layers by setting `lora_scale` for each PEFT layer
440 | scale_lora_layers(self, lora_scale)
441 | else:
442 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
443 | logger.warning(
444 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
445 | )
446 | hidden_states = self.x_embedder(hidden_states)
447 |
448 | timestep = timestep.to(hidden_states.dtype) * 1000
449 | if guidance is not None:
450 | guidance = guidance.to(hidden_states.dtype) * 1000
451 | else:
452 | guidance = None
453 | temb = (
454 | self.time_text_embed(timestep, pooled_projections)
455 | if guidance is None
456 | else self.time_text_embed(timestep, guidance, pooled_projections)
457 | )
458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459 |
460 | ids = torch.cat((txt_ids, img_ids), dim=0)
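        # NOTE (editorial, illustrative): `ntk_factor` is forwarded to the rotary position embedding so
        # its base frequency can be rescaled (NTK-aware interpolation) when generating at resolutions
        # beyond the training distribution; ntk_factor=1 reproduces the original FLUX behavior.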
461 | image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor)
462 |
463 | for index_block, block in enumerate(self.transformer_blocks):
464 | if self.training and self.gradient_checkpointing:
465 |
466 | def create_custom_forward(module, return_dict=None):
467 | def custom_forward(*inputs):
468 | if return_dict is not None:
469 | return module(*inputs, return_dict=return_dict)
470 | else:
471 | return module(*inputs)
472 |
473 | return custom_forward
474 |
475 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
476 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
477 | create_custom_forward(block),
478 | hidden_states,
479 | encoder_hidden_states,
480 | temb,
481 | image_rotary_emb,
482 | joint_attention_kwargs,
483 | **ckpt_kwargs,
484 | )
485 |
486 | else:
487 | encoder_hidden_states, hidden_states = block(
488 | hidden_states=hidden_states,
489 | encoder_hidden_states=encoder_hidden_states,
490 | temb=temb,
491 | image_rotary_emb=image_rotary_emb,
492 | joint_attention_kwargs=joint_attention_kwargs,
493 | )
494 |
495 | # controlnet residual
496 | if controlnet_block_samples is not None:
497 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
498 | interval_control = int(np.ceil(interval_control))
499 | # For Xlabs ControlNet.
500 | if controlnet_blocks_repeat:
501 | hidden_states = (
502 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
503 | )
504 | else:
505 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
506 |
507 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
508 |
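        # NOTE (editorial, illustrative): the single transformer blocks operate on the concatenated
        # [text, image] token sequence; the text tokens are sliced off again after this loop.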
509 | for index_block, block in enumerate(self.single_transformer_blocks):
510 | if self.training and self.gradient_checkpointing:
511 |
512 | def create_custom_forward(module, return_dict=None):
513 | def custom_forward(*inputs):
514 | if return_dict is not None:
515 | return module(*inputs, return_dict=return_dict)
516 | else:
517 | return module(*inputs)
518 |
519 | return custom_forward
520 |
521 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
522 | hidden_states = torch.utils.checkpoint.checkpoint(
523 | create_custom_forward(block),
524 | hidden_states,
525 | temb,
526 | image_rotary_emb,
527 | joint_attention_kwargs,
528 | **ckpt_kwargs,
529 | )
530 |
531 | else:
532 | hidden_states = block(
533 | hidden_states=hidden_states,
534 | temb=temb,
535 | image_rotary_emb=image_rotary_emb,
536 | joint_attention_kwargs=joint_attention_kwargs,
537 | )
538 |
539 | # controlnet residual
540 | if controlnet_single_block_samples is not None:
541 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
542 | interval_control = int(np.ceil(interval_control))
543 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
544 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
545 | + controlnet_single_block_samples[index_block // interval_control]
546 | )
547 |
548 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
549 |
550 | hidden_states = self.norm_out(hidden_states, temb)
551 | output = self.proj_out(hidden_states)
552 |
553 | if USE_PEFT_BACKEND:
554 | # remove `lora_scale` from each PEFT layer
555 | unscale_lora_layers(self, lora_scale)
556 |
557 | if not return_dict:
558 | return (output,)
559 |
560 | return Transformer2DModelOutput(sample=output)
561 |
--------------------------------------------------------------------------------
/train_2k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import logging
4 | import os
5 | import random
6 | import shutil
7 | from pathlib import Path
8 | from safetensors.torch import load_file
9 | import torch
10 | import torch.utils.checkpoint
11 | import transformers
12 | from accelerate import Accelerator
13 | from accelerate.logging import get_logger
14 | from accelerate.utils import ProjectConfiguration, set_seed
15 | from peft import LoraConfig, set_peft_model_state_dict
16 | from peft.utils import get_peft_model_state_dict
17 | from torch.utils.data import Dataset
18 | from tqdm.auto import tqdm
19 |
20 | import diffusers
21 | from diffusers import (
22 | AutoencoderKL,
23 | FlowMatchEulerDiscreteScheduler
24 | )
25 | from diffusers.optimization import get_scheduler
26 | from diffusers.training_utils import (
27 | cast_training_params,
28 | compute_density_for_timestep_sampling,
29 | compute_loss_weighting_for_sd3,
30 | free_memory,
31 | )
32 | from diffusers.utils import (
33 | check_min_version,
34 | convert_unet_state_dict_to_peft,
35 | is_wandb_available,
36 | )
37 | from diffusers.utils.torch_utils import is_compiled_module
38 | from transformer_flux import FluxTransformer2DModel
39 | from pipeline_flux import FluxPipeline
40 | from attention_processor import FluxAttnProcessor2_0
41 |
42 |
43 | if is_wandb_available():
44 | import wandb
45 |
46 | # Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
47 | check_min_version("0.31.0.dev0")
48 |
49 | logger = get_logger(__name__)
50 |
51 |
52 | def parse_args(input_args=None):
53 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
54 | parser.add_argument(
55 | "--pretrained_model_name_or_path",
56 | type=str,
57 | default=None,
58 | required=True,
59 | help="Path to pretrained model or model identifier from huggingface.co/models.",
60 | )
61 | parser.add_argument(
62 | "--revision",
63 | type=str,
64 | default=None,
65 | required=False,
66 | help="Revision of pretrained model identifier from huggingface.co/models.",
67 | )
68 | parser.add_argument(
69 | "--variant",
70 | type=str,
71 | default=None,
72 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
73 | )
74 | parser.add_argument(
75 | "--dataset_root",
76 | type=str,
77 | default=None,
78 | help=(
79 | "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
80 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
81 | " or to a folder containing files that 🤗 Datasets can understand."
82 | ),
83 | )
84 |
85 | parser.add_argument(
86 | "--cache_dir",
87 | type=str,
88 | default=None,
89 | help="The directory where the downloaded models and datasets will be stored.",
90 | )
91 |
92 | parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
93 |
94 | parser.add_argument(
95 | "--real_prompt_ratio",
96 | type=float,
97 | default=0.5
98 | )
99 | parser.add_argument(
100 | "--max_sequence_length",
101 | type=int,
102 | default=512,
103 |         help="Maximum sequence length to use with the T5 text encoder",
104 | )
105 | parser.add_argument(
106 | "--rank",
107 | type=int,
108 | default=16,
109 | help=("The dimension of the LoRA update matrices."),
110 | )
111 | parser.add_argument(
112 | "--output_dir",
113 | type=str,
114 | default="urae_2k",
115 | help="The output directory where the model predictions and checkpoints will be written.",
116 | )
117 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
118 | parser.add_argument(
119 | "--proportional_attention",
120 | action='store_true',
121 | help="Dynamic attention scale with respect to the resolution",
122 | )
123 | parser.add_argument(
124 | "--ntk_factor",
125 | type=float,
126 | default=1.,
127 |         help="NTK factor for RoPE (rotary position embedding)"
128 | )
129 | parser.add_argument(
130 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
131 | )
132 | parser.add_argument(
133 | "--max_train_steps",
134 | type=int,
135 | default=10000
136 | )
137 | parser.add_argument(
138 | "--checkpointing_steps",
139 | type=int,
140 | default=1000,
141 | help=(
142 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
143 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
144 | " training using `--resume_from_checkpoint`."
145 | ),
146 | )
147 | parser.add_argument(
148 | "--checkpoints_total_limit",
149 | type=int,
150 | default=None,
151 | help=("Max number of checkpoints to store."),
152 | )
153 | parser.add_argument(
154 | "--resume_from_checkpoint",
155 | type=str,
156 | default=None,
157 | help=(
158 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
159 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
160 | ),
161 | )
162 | parser.add_argument(
163 | "--gradient_accumulation_steps",
164 | type=int,
165 | default=1,
166 | help="Number of updates steps to accumulate before performing a backward/update pass.",
167 | )
168 | parser.add_argument(
169 | "--gradient_checkpointing",
170 | action="store_true",
171 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
172 | )
173 | parser.add_argument(
174 | "--learning_rate",
175 | type=float,
176 | default=1e-4,
177 | help="Initial learning rate (after the potential warmup period) to use.",
178 | )
179 |
180 | parser.add_argument(
181 | "--guidance_scale",
182 | type=float,
183 | default=1,
184 |         help="Guidance scale to embed as a condition; the FLUX.1-dev variant is a guidance-distilled model",
185 | )
186 |
187 | parser.add_argument(
188 | "--scale_lr",
189 | action="store_true",
190 | default=False,
191 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
192 | )
193 | parser.add_argument(
194 | "--lr_scheduler",
195 | type=str,
196 | default="constant",
197 | help=(
198 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
199 | ' "constant", "constant_with_warmup"]'
200 | ),
201 | )
202 | parser.add_argument(
203 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
204 | )
205 | parser.add_argument(
206 | "--lr_num_cycles",
207 | type=int,
208 | default=1,
209 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
210 | )
211 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
212 | parser.add_argument(
213 | "--dataloader_num_workers",
214 | type=int,
215 | default=0,
216 | help=(
217 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
218 | ),
219 | )
220 | parser.add_argument(
221 | "--weighting_scheme",
222 | type=str,
223 | default="none",
224 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
225 | help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
226 | )
227 | parser.add_argument(
228 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
229 | )
230 | parser.add_argument(
231 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
232 | )
233 | parser.add_argument(
234 | "--mode_scale",
235 | type=float,
236 | default=1.29,
237 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
238 | )
239 | parser.add_argument(
240 | "--optimizer",
241 | type=str,
242 | default="AdamW",
243 | help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
244 | )
245 |
246 | parser.add_argument(
247 | "--use_8bit_adam",
248 | action="store_true",
249 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
250 | )
251 |
252 | parser.add_argument(
253 | "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
254 | )
255 | parser.add_argument(
256 | "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
257 | )
258 | parser.add_argument(
259 | "--prodigy_beta3",
260 | type=float,
261 | default=None,
262 | help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
263 | "uses the value of square root of beta2. Ignored if optimizer is adamW",
264 | )
265 | parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
266 | parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
267 |
268 | parser.add_argument(
269 | "--adam_epsilon",
270 | type=float,
271 | default=1e-08,
272 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
273 | )
274 |
275 | parser.add_argument(
276 | "--prodigy_use_bias_correction",
277 | type=bool,
278 | default=True,
279 | help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
280 | )
281 | parser.add_argument(
282 | "--prodigy_safeguard_warmup",
283 | type=bool,
284 | default=True,
285 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
286 | "Ignored if optimizer is adamW",
287 | )
288 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
289 | parser.add_argument(
290 | "--logging_dir",
291 | type=str,
292 | default="logs",
293 | help=(
294 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
295 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
296 | ),
297 | )
298 | parser.add_argument(
299 | "--allow_tf32",
300 | action="store_true",
301 | help=(
302 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
303 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
304 | ),
305 | )
306 | parser.add_argument(
307 | "--report_to",
308 | type=str,
309 | default="tensorboard",
310 | help=(
311 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
312 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
313 | ),
314 | )
315 | parser.add_argument(
316 | "--mixed_precision",
317 | type=str,
318 | default=None,
319 | choices=["no", "fp16", "bf16"],
320 | help=(
321 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
322 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
323 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
324 | ),
325 | )
326 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
327 |
328 | if input_args is not None:
329 | args = parser.parse_args(input_args)
330 | else:
331 | args = parser.parse_args()
332 |
333 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
334 | if env_local_rank != -1 and env_local_rank != args.local_rank:
335 | args.local_rank = env_local_rank
336 |
337 | return args
338 |
339 |
340 | class CustomImageDataset(Dataset):
341 | def __init__(self, img_dir, real_prompt_ratio=0.5):
342 | self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
343 | self.real_prompt_ratio = real_prompt_ratio
344 |
345 | def __len__(self):
346 | return len(self.images)
347 |
348 | def __getitem__(self, idx):
349 | try:
350 | batch = {}
351 | f = load_file(self.images[idx][:self.images[idx].rfind('.')] + '_latent_code.safetensors')
352 | batch['latent_codes_mean'] = f['mean']
353 | batch['latent_codes_std'] = f['std']
354 | prompt_embeds_path = self.images[idx][:self.images[idx].rfind('.')] + '_prompt_embed.safetensors'
355 | generated_prompt_embeds_path = self.images[idx][:self.images[idx].rfind('.')] + '_generated_prompt_embed.safetensors'
356 | if (not os.path.exists(generated_prompt_embeds_path) or random.random() < self.real_prompt_ratio) and os.path.exists(prompt_embeds_path):
357 | f = load_file(prompt_embeds_path)
358 | else:
359 | f = load_file(generated_prompt_embeds_path)
360 | batch['prompt_embeds_t5'] = f['caption_feature_t5']
361 | batch['prompt_embeds_clip'] = f['caption_feature_clip']
362 | return batch
363 | except Exception as e:
364 | print(e)
365 | return self.__getitem__(random.randint(0, len(self.images) - 1))
366 |
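# Usage sketch (editorial, illustrative only): assuming the caching scripts have been run so that each
# image `foo.jpg` sits next to `foo_latent_code.safetensors`, `foo_prompt_embed.safetensors` and
# optionally `foo_generated_prompt_embed.safetensors`, a sample can be inspected like this:
#
#     dataset = CustomImageDataset("/path/to/data", real_prompt_ratio=0.2)
#     sample = dataset[0]
#     # sample['latent_codes_mean'] / sample['latent_codes_std'] parameterize the cached VAE posterior,
#     # sample['prompt_embeds_t5'] / sample['prompt_embeds_clip'] hold the cached text embeddings.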
367 |
368 | def main(args):
369 | if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
370 | # due to pytorch#99272, MPS does not yet support bfloat16.
371 | raise ValueError(
372 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
373 | )
374 |
375 | logging_dir = Path(args.output_dir, args.logging_dir)
376 |
377 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
378 | accelerator = Accelerator(
379 | gradient_accumulation_steps=args.gradient_accumulation_steps,
380 | mixed_precision=args.mixed_precision,
381 | log_with=args.report_to,
382 | project_config=accelerator_project_config
383 | )
384 |
385 | # Disable AMP for MPS.
386 | if torch.backends.mps.is_available():
387 | accelerator.native_amp = False
388 |
389 | if args.report_to == "wandb":
390 | if not is_wandb_available():
391 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
392 |
393 | # Make one log on every process with the configuration for debugging.
394 | logging.basicConfig(
395 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
396 | datefmt="%m/%d/%Y %H:%M:%S",
397 | level=logging.INFO,
398 | )
399 | logger.info(accelerator.state, main_process_only=False)
400 | if accelerator.is_local_main_process:
401 | transformers.utils.logging.set_verbosity_warning()
402 | diffusers.utils.logging.set_verbosity_info()
403 | else:
404 | transformers.utils.logging.set_verbosity_error()
405 | diffusers.utils.logging.set_verbosity_error()
406 |
407 | # If passed along, set the training seed now.
408 | if args.seed is not None:
409 | set_seed(args.seed)
410 |
411 | # Handle the repository creation
412 | if accelerator.is_main_process:
413 | if args.output_dir is not None:
414 | os.makedirs(args.output_dir, exist_ok=True)
415 |
416 | # Load scheduler and models
417 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
418 | args.pretrained_model_name_or_path, subfolder="scheduler"
419 | )
420 | noise_scheduler_copy = copy.deepcopy(noise_scheduler)
421 | vae = AutoencoderKL.from_pretrained(
422 | args.pretrained_model_name_or_path,
423 | subfolder="vae",
424 | revision=args.revision,
425 | variant=args.variant,
426 | )
427 | transformer = FluxTransformer2DModel.from_pretrained(
428 | args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
429 | )
430 |
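    # NOTE (editorial, illustrative): every attention processor is replaced with the project's custom
    # FluxAttnProcessor2_0 (attention_processor.py) so that extra arguments routed through
    # `joint_attention_kwargs` -- e.g. `proportional_attention` below -- reach the attention layers.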
431 | attn_processors = {}
432 | for k in transformer.attn_processors.keys():
433 | attn_processors[k] = FluxAttnProcessor2_0()
434 | transformer.set_attn_processor(attn_processors)
435 |
436 | # We only train the additional adapter LoRA layers
437 | transformer.requires_grad_(False)
438 | vae.requires_grad_(False)
439 |
440 |     # For mixed precision training we cast all non-trainable weights (vae and transformer) to half-precision,
441 |     # as these weights are only used for inference and keeping them in full precision is not required.
442 | weight_dtype = torch.float32
443 | if accelerator.mixed_precision == "fp16":
444 | weight_dtype = torch.float16
445 | elif accelerator.mixed_precision == "bf16":
446 | weight_dtype = torch.bfloat16
447 |
448 | if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
449 | # due to pytorch#99272, MPS does not yet support bfloat16.
450 | raise ValueError(
451 | "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
452 | )
453 |
454 | vae.to(accelerator.device, dtype=weight_dtype)
455 | transformer.to(accelerator.device, dtype=weight_dtype)
456 |
457 | if args.gradient_checkpointing:
458 | transformer.enable_gradient_checkpointing()
459 |
460 | # now we will add new LoRA weights to the attention layers
461 | transformer_lora_config = LoraConfig(
462 | r=args.rank,
463 | lora_alpha=args.rank,
464 | init_lora_weights="gaussian",
465 | target_modules=["to_k", "to_q", "to_v", "to_out.0"],
466 | )
467 | transformer.add_adapter(transformer_lora_config)
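    # NOTE (editorial, illustrative): with r == lora_alpha the effective LoRA scaling (lora_alpha / r)
    # is 1, and only the attention projections (to_q / to_k / to_v / to_out.0) receive adapters while
    # the frozen base FLUX weights stay untouched.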
468 |
469 | def unwrap_model(model):
470 | model = accelerator.unwrap_model(model)
471 | model = model._orig_mod if is_compiled_module(model) else model
472 | return model
473 |
474 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
475 | def save_model_hook(models, weights, output_dir):
476 | if accelerator.is_main_process:
477 | transformer_lora_layers_to_save = None
478 | text_encoder_one_lora_layers_to_save = None
479 |
480 | for model in models:
481 | if isinstance(model, type(unwrap_model(transformer))):
482 | transformer_lora_layers_to_save = get_peft_model_state_dict(model)
483 | else:
484 | raise ValueError(f"unexpected save model: {model.__class__}")
485 |
486 | # make sure to pop weight so that corresponding model is not saved again
487 | weights.pop()
488 |
489 | FluxPipeline.save_lora_weights(
490 | output_dir,
491 | transformer_lora_layers=transformer_lora_layers_to_save,
492 | text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
493 | )
494 |
495 | def load_model_hook(models, input_dir):
496 | transformer_ = None
497 |
498 | while len(models) > 0:
499 | model = models.pop()
500 |
501 | if isinstance(model, type(unwrap_model(transformer))):
502 | transformer_ = model
503 | else:
504 | raise ValueError(f"unexpected save model: {model.__class__}")
505 |
506 | lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
507 |
508 | transformer_state_dict = {
509 | f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
510 | }
511 | transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
512 | incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
513 | if incompatible_keys is not None:
514 | # check only for unexpected keys
515 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
516 | if unexpected_keys:
517 | logger.warning(
518 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
519 | f" {unexpected_keys}. "
520 | )
521 | # Make sure the trainable params are in float32. This is again needed since the base models
522 | # are in `weight_dtype`. More details:
523 | # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
524 | if args.mixed_precision == "fp16":
525 | models = [transformer_]
526 | # only upcast trainable parameters (LoRA) into fp32
527 | cast_training_params(models)
528 |
529 | accelerator.register_save_state_pre_hook(save_model_hook)
530 | accelerator.register_load_state_pre_hook(load_model_hook)
531 |
532 | # Enable TF32 for faster training on Ampere GPUs,
533 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
534 | if args.allow_tf32 and torch.cuda.is_available():
535 | torch.backends.cuda.matmul.allow_tf32 = True
536 |
537 | if args.scale_lr:
538 | args.learning_rate = (
539 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
540 | )
541 |
542 | # Make sure the trainable params are in float32.
543 | if args.mixed_precision == "fp16":
544 | models = [transformer]
545 | # only upcast trainable parameters (LoRA) into fp32
546 | cast_training_params(models, dtype=torch.float32)
547 |
548 | transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
549 |
550 | # Optimization parameters
551 | transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
552 |
553 | params_to_optimize = [transformer_parameters_with_lr]
554 |
555 | # Optimizer creation
556 | if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
557 | logger.warning(
558 |             f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include ['AdamW', 'prodigy']. "
559 |             "Defaulting to AdamW."
560 | )
561 | args.optimizer = "adamw"
562 |
563 | if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
564 | logger.warning(
565 | f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
566 | f"set to {args.optimizer.lower()}"
567 | )
568 |
569 | if args.optimizer.lower() == "adamw":
570 | if args.use_8bit_adam:
571 | try:
572 | import bitsandbytes as bnb
573 | except ImportError:
574 | raise ImportError(
575 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
576 | )
577 |
578 | optimizer_class = bnb.optim.AdamW8bit
579 | else:
580 | optimizer_class = torch.optim.AdamW
581 |
582 | optimizer = optimizer_class(
583 | params_to_optimize,
584 | betas=(args.adam_beta1, args.adam_beta2),
585 | weight_decay=args.adam_weight_decay,
586 | eps=args.adam_epsilon,
587 | )
588 |
589 | if args.optimizer.lower() == "prodigy":
590 | try:
591 | import prodigyopt
592 | except ImportError:
593 | raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
594 |
595 | optimizer_class = prodigyopt.Prodigy
596 |
597 | if args.learning_rate <= 0.1:
598 | logger.warning(
599 | "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
600 | )
601 |
602 | optimizer = optimizer_class(
603 | params_to_optimize,
604 | lr=args.learning_rate,
605 | betas=(args.adam_beta1, args.adam_beta2),
606 | beta3=args.prodigy_beta3,
607 | weight_decay=args.adam_weight_decay,
608 | eps=args.adam_epsilon,
609 | decouple=args.prodigy_decouple,
610 | use_bias_correction=args.prodigy_use_bias_correction,
611 | safeguard_warmup=args.prodigy_safeguard_warmup,
612 | )
613 |
614 | # Dataset and DataLoaders creation:
615 | train_dataloader = torch.utils.data.DataLoader(
616 | CustomImageDataset(
617 | args.dataset_root,
618 | real_prompt_ratio=args.real_prompt_ratio
619 | ),
620 | batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, shuffle=True
621 | )
622 |
623 | vae_config_shift_factor = vae.config.shift_factor
624 | vae_config_scaling_factor = vae.config.scaling_factor
625 | vae_config_block_out_channels = vae.config.block_out_channels
626 |
627 | del vae
628 | free_memory()
629 |
630 | # Scheduler and math around the number of training steps.
631 | lr_scheduler = get_scheduler(
632 | args.lr_scheduler,
633 | optimizer=optimizer,
634 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
635 | num_training_steps=args.max_train_steps * accelerator.num_processes,
636 | num_cycles=args.lr_num_cycles,
637 | power=args.lr_power,
638 | )
639 |
640 | # Prepare everything with our `accelerator`.
641 | guidance_embeds = transformer.config.guidance_embeds
642 | transformer, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
643 | transformer, train_dataloader, optimizer, lr_scheduler
644 | )
645 |
646 | # We need to initialize the trackers we use, and also store our configuration.
647 |     # The trackers initialize automatically on the main process.
648 | if accelerator.is_main_process:
649 | tracker_name = "flux-hr"
650 | accelerator.init_trackers(tracker_name, config=vars(args))
651 |
652 | # Train!
653 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
654 |
655 | logger.info("***** Running training *****")
656 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
657 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
658 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
659 | logger.info(f" Total optimization steps = {args.max_train_steps}")
660 | global_step = 0
661 |
662 | # Potentially load in the weights and states from a previous save
663 | if args.resume_from_checkpoint:
664 | if args.resume_from_checkpoint != "latest":
665 | path = os.path.basename(args.resume_from_checkpoint)
666 | else:
667 |             # Get the most recent checkpoint
668 | dirs = os.listdir(args.output_dir)
669 | dirs = [d for d in dirs if d.startswith("checkpoint")]
670 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
671 | path = dirs[-1] if len(dirs) > 0 else None
672 |
673 | if path is None:
674 | accelerator.print(
675 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
676 | )
677 | args.resume_from_checkpoint = None
678 | initial_global_step = 0
679 | else:
680 | accelerator.print(f"Resuming from checkpoint {path}")
681 | accelerator.load_state(os.path.join(args.output_dir, path))
682 | global_step = int(path.split("-")[1])
683 |
684 | initial_global_step = global_step
685 |
686 | else:
687 | initial_global_step = 0
688 |
689 | progress_bar = tqdm(
690 | range(0, args.max_train_steps),
691 | initial=initial_global_step,
692 | desc="Steps",
693 | # Only show the progress bar once on each machine.
694 | disable=not accelerator.is_local_main_process,
695 | )
696 |
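    # NOTE (editorial, illustrative): helper that looks up the flow-matching sigma for each sampled
    # timestep and reshapes it to n_dim dimensions so it broadcasts over latents of shape (B, C, H, W).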
697 | def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
698 | sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
699 | schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
700 | timesteps = timesteps.to(accelerator.device)
701 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
702 |
703 | sigma = sigmas[step_indices].flatten()
704 | while len(sigma.shape) < n_dim:
705 | sigma = sigma.unsqueeze(-1)
706 | return sigma
707 |
708 | transformer.train()
709 | loader_iter = iter(train_dataloader)
710 | while True:
711 | try:
712 |             batch = next(loader_iter)
713 |         except StopIteration:
714 |             loader_iter = iter(train_dataloader)
715 |             batch = next(loader_iter)
716 | models_to_accumulate = [transformer]
717 |
718 | with accelerator.accumulate(models_to_accumulate):
719 |
720 | prompt_embeds = batch['prompt_embeds_t5'].to(dtype=weight_dtype)
721 | pooled_prompt_embeds = batch['prompt_embeds_clip'].to(dtype=weight_dtype)
722 | text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=accelerator.device, dtype=weight_dtype)
723 |
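            # NOTE (editorial, illustrative): the cached VAE posterior (mean, std) is re-sampled here via
            # the reparameterization trick, so every step sees a fresh latent draw; the shift/scale
            # normalization matches what `vae.encode(...).latent_dist.sample()` followed by the usual
            # scaling would produce, without keeping the VAE in memory.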
724 | with torch.no_grad():
725 | mean = batch['latent_codes_mean'].to(dtype=weight_dtype)
726 | std = batch['latent_codes_std'].to(dtype=weight_dtype)
727 | sample = torch.randn_like(mean)
728 | model_input = mean + std * sample
729 | model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
730 |
731 | vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
732 |
733 | latent_image_ids = FluxPipeline._prepare_latent_image_ids(
734 | model_input.shape[0],
735 | model_input.shape[2],
736 | model_input.shape[3],
737 | accelerator.device,
738 | weight_dtype,
739 | )
740 | # Sample noise that we'll add to the latents
741 | noise = torch.randn_like(model_input)
742 | bsz = model_input.shape[0]
743 |
744 | # Sample a random timestep for each image
745 | # for weighting schemes where we sample timesteps non-uniformly
746 | u = compute_density_for_timestep_sampling(
747 | weighting_scheme=args.weighting_scheme,
748 | batch_size=bsz,
749 | logit_mean=args.logit_mean,
750 | logit_std=args.logit_std,
751 | mode_scale=args.mode_scale,
752 | )
753 | indices = (u * noise_scheduler.config.num_train_timesteps).long()
754 | timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
755 |
756 | # Add noise according to flow matching.
757 | # zt = (1 - texp) * x + texp * z1
758 | sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
759 | noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
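            # NOTE (editorial, illustrative): under this linear interpolation the velocity d(z_t)/d(sigma)
            # equals (noise - model_input), which is exactly the regression target computed below.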
760 |
761 | packed_noisy_model_input = FluxPipeline._pack_latents(
762 | noisy_model_input,
763 | batch_size=model_input.shape[0],
764 | num_channels_latents=model_input.shape[1],
765 | height=model_input.shape[2],
766 | width=model_input.shape[3],
767 | )
768 |
769 | # handle guidance
770 | if guidance_embeds:
771 | guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
772 | guidance = guidance.expand(model_input.shape[0])
773 | else:
774 | guidance = None
775 |
776 | model_pred = transformer(
777 | hidden_states=packed_noisy_model_input,
778 |                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
779 | timestep=timesteps / 1000,
780 | guidance=guidance,
781 | pooled_projections=pooled_prompt_embeds,
782 | encoder_hidden_states=prompt_embeds,
783 | txt_ids=text_ids,
784 | img_ids=latent_image_ids,
785 | return_dict=False,
786 | ntk_factor=args.ntk_factor,
787 | joint_attention_kwargs={'proportional_attention': args.proportional_attention}
788 | )[0]
789 | model_pred = FluxPipeline._unpack_latents(
790 | model_pred,
791 | height=int(model_input.shape[2] * vae_scale_factor / 2),
792 | width=int(model_input.shape[3] * vae_scale_factor / 2),
793 | vae_scale_factor=vae_scale_factor,
794 | )
795 |
796 | # these weighting schemes use a uniform timestep sampling
797 | # and instead post-weight the loss
798 | weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
799 |
800 | # flow matching loss
801 | target = noise - model_input
802 |
803 | # Compute regular loss.
804 | loss = (weighting.float() * (model_pred.float() - target.float()) ** 2).mean()
805 |
806 | accelerator.backward(loss)
807 | if accelerator.sync_gradients:
808 | params_to_clip = (
809 | transformer.parameters()
810 | )
811 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
812 |
813 | optimizer.step()
814 | lr_scheduler.step()
815 | optimizer.zero_grad()
816 |
817 | # Checks if the accelerator has performed an optimization step behind the scenes
818 | if accelerator.sync_gradients:
819 | progress_bar.update(1)
820 | global_step += 1
821 |
822 | if global_step % args.checkpointing_steps == 0:
823 | if accelerator.is_main_process:
824 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
825 | if args.checkpoints_total_limit is not None:
826 | checkpoints = os.listdir(args.output_dir)
827 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
828 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
829 |
830 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
831 | if len(checkpoints) >= args.checkpoints_total_limit:
832 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
833 | removing_checkpoints = checkpoints[0:num_to_remove]
834 |
835 | logger.info(
836 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
837 | )
838 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
839 |
840 | for removing_checkpoint in removing_checkpoints:
841 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
842 | shutil.rmtree(removing_checkpoint)
843 |
844 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
845 | logger.info(f"Saving state to {save_path}...")
846 | accelerator.save_state(save_path)
847 |
848 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
849 | progress_bar.set_postfix(**logs)
850 | accelerator.log(logs, step=global_step)
851 |
852 | if global_step >= args.max_train_steps:
853 | break
854 |
855 | accelerator.wait_for_everyone()
856 | accelerator.end_training()
857 |
858 |
859 | if __name__ == "__main__":
860 | args = parse_args()
861 | main(args)
--------------------------------------------------------------------------------
/pipeline_flux.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import Any, Callable, Dict, List, Optional, Union
17 |
18 | import numpy as np
19 | import torch
20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21 |
22 | from diffusers.image_processor import VaeImageProcessor
23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24 | from diffusers.models.autoencoders import AutoencoderKL
25 | from diffusers.models.transformers import FluxTransformer2DModel
26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27 | from diffusers.utils import (
28 | USE_PEFT_BACKEND,
29 | is_torch_xla_available,
30 | logging,
31 | replace_example_docstring,
32 | scale_lora_layers,
33 | unscale_lora_layers,
34 | )
35 | from diffusers.utils.torch_utils import randn_tensor
36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38 |
39 |
40 | if is_torch_xla_available():
41 | import torch_xla.core.xla_model as xm
42 |
43 | XLA_AVAILABLE = True
44 | else:
45 | XLA_AVAILABLE = False
46 |
47 |
48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49 |
50 | EXAMPLE_DOC_STRING = """
51 | Examples:
52 | ```py
53 | >>> import torch
54 | >>> from diffusers import FluxPipeline
55 |
56 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57 | >>> pipe.to("cuda")
58 | >>> prompt = "A cat holding a sign that says hello world"
59 | >>> # Depending on the variant being used, the pipeline call will slightly vary.
60 | >>> # Refer to the pipeline documentation for more details.
61 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62 | >>> image.save("flux.png")
63 | ```
64 | """
65 |
66 |
67 | def calculate_shift(
68 | image_seq_len,
69 | base_seq_len: int = 256,
70 | max_seq_len: int = 4096,
71 | base_shift: float = 0.5,
72 | max_shift: float = 1.16,
73 | ):
74 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75 | b = base_shift - m * base_seq_len
76 | mu = image_seq_len * m + b
77 | return mu
78 |
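# NOTE (editorial, illustrative): with the default constants, a packed latent sequence of 256 tokens
# maps to mu = base_shift = 0.5 and 4096 tokens (a 1024x1024 image: (1024 / 8 / 2) ** 2) maps to
# mu = max_shift = 1.16; longer sequences extrapolate linearly beyond that.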
79 |
80 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81 | def retrieve_timesteps(
82 | scheduler,
83 | num_inference_steps: Optional[int] = None,
84 | device: Optional[Union[str, torch.device]] = None,
85 | timesteps: Optional[List[int]] = None,
86 | sigmas: Optional[List[float]] = None,
87 | **kwargs,
88 | ):
89 | r"""
90 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92 |
93 | Args:
94 | scheduler (`SchedulerMixin`):
95 | The scheduler to get timesteps from.
96 | num_inference_steps (`int`):
97 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98 | must be `None`.
99 | device (`str` or `torch.device`, *optional*):
100 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101 | timesteps (`List[int]`, *optional*):
102 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103 | `num_inference_steps` and `sigmas` must be `None`.
104 | sigmas (`List[float]`, *optional*):
105 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106 | `num_inference_steps` and `timesteps` must be `None`.
107 |
108 | Returns:
109 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110 | second element is the number of inference steps.
111 | """
112 | if timesteps is not None and sigmas is not None:
113 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114 | if timesteps is not None:
115 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116 | if not accepts_timesteps:
117 | raise ValueError(
118 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119 | f" timestep schedules. Please check whether you are using the correct scheduler."
120 | )
121 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122 | timesteps = scheduler.timesteps
123 | num_inference_steps = len(timesteps)
124 | elif sigmas is not None:
125 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126 | if not accept_sigmas:
127 | raise ValueError(
128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129 | f" sigmas schedules. Please check whether you are using the correct scheduler."
130 | )
131 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132 | timesteps = scheduler.timesteps
133 | num_inference_steps = len(timesteps)
134 | else:
135 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136 | timesteps = scheduler.timesteps
137 | return timesteps, num_inference_steps
138 |
139 |
140 | class FluxPipeline(
141 | DiffusionPipeline,
142 | FluxLoraLoaderMixin,
143 | FromSingleFileMixin,
144 | TextualInversionLoaderMixin,
145 | ):
146 | r"""
147 | The Flux pipeline for text-to-image generation.
148 |
149 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
150 |
151 | Args:
152 | transformer ([`FluxTransformer2DModel`]):
153 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
154 | scheduler ([`FlowMatchEulerDiscreteScheduler`]):
155 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
156 | vae ([`AutoencoderKL`]):
157 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
158 | text_encoder ([`CLIPTextModel`]):
159 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
160 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
161 | text_encoder_2 ([`T5EncoderModel`]):
162 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
163 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
164 | tokenizer (`CLIPTokenizer`):
165 | Tokenizer of class
166 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
167 | tokenizer_2 (`T5TokenizerFast`):
168 | Second Tokenizer of class
169 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
170 | """
171 |
172 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
173 | _optional_components = []
174 | _callback_tensor_inputs = ["latents", "prompt_embeds"]
175 |
176 | def __init__(
177 | self,
178 | scheduler: FlowMatchEulerDiscreteScheduler,
179 | vae: AutoencoderKL,
180 | text_encoder: CLIPTextModel,
181 | tokenizer: CLIPTokenizer,
182 | text_encoder_2: T5EncoderModel,
183 | tokenizer_2: T5TokenizerFast,
184 | transformer: FluxTransformer2DModel,
185 | ):
186 | super().__init__()
187 |
188 | self.register_modules(
189 | vae=vae,
190 | text_encoder=text_encoder,
191 | text_encoder_2=text_encoder_2,
192 | tokenizer=tokenizer,
193 | tokenizer_2=tokenizer_2,
194 | transformer=transformer,
195 | scheduler=scheduler,
196 | )
197 | self.vae_scale_factor = (
198 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
199 | )
200 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
201 | self.tokenizer_max_length = (
202 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
203 | )
204 | self.default_sample_size = 64
205 |
206 | def _get_t5_prompt_embeds(
207 | self,
208 | prompt: Union[str, List[str]] = None,
209 | num_images_per_prompt: int = 1,
210 | max_sequence_length: int = 512,
211 | device: Optional[torch.device] = None,
212 | dtype: Optional[torch.dtype] = None,
213 | ):
214 | device = device or self._execution_device
215 | dtype = dtype or self.text_encoder.dtype
216 |
217 | prompt = [prompt] if isinstance(prompt, str) else prompt
218 | batch_size = len(prompt)
219 |
220 | if isinstance(self, TextualInversionLoaderMixin):
221 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
222 |
223 | text_inputs = self.tokenizer_2(
224 | prompt,
225 | padding="max_length",
226 | max_length=max_sequence_length,
227 | truncation=True,
228 | return_length=False,
229 | return_overflowing_tokens=False,
230 | return_tensors="pt",
231 | )
232 | text_input_ids = text_inputs.input_ids
233 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
234 |
235 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
236 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
237 | logger.warning(
238 | "The following part of your input was truncated because `max_sequence_length` is set to "
239 | f" {max_sequence_length} tokens: {removed_text}"
240 | )
241 |
242 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
243 |
244 | dtype = self.text_encoder_2.dtype
245 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
246 |
247 | _, seq_len, _ = prompt_embeds.shape
248 |
249 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
250 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
251 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
252 |
253 | return prompt_embeds
254 |
255 | def _get_clip_prompt_embeds(
256 | self,
257 | prompt: Union[str, List[str]],
258 | num_images_per_prompt: int = 1,
259 | device: Optional[torch.device] = None,
260 | ):
261 | device = device or self._execution_device
262 |
263 | prompt = [prompt] if isinstance(prompt, str) else prompt
264 | batch_size = len(prompt)
265 |
266 | if isinstance(self, TextualInversionLoaderMixin):
267 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
268 |
269 | text_inputs = self.tokenizer(
270 | prompt,
271 | padding="max_length",
272 | max_length=self.tokenizer_max_length,
273 | truncation=True,
274 | return_overflowing_tokens=False,
275 | return_length=False,
276 | return_tensors="pt",
277 | )
278 |
279 | text_input_ids = text_inputs.input_ids
280 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
281 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
282 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
283 | logger.warning(
284 | "The following part of your input was truncated because CLIP can only handle sequences up to"
285 | f" {self.tokenizer_max_length} tokens: {removed_text}"
286 | )
287 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
288 |
289 | # Use pooled output of CLIPTextModel
290 | prompt_embeds = prompt_embeds.pooler_output
291 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
292 |
293 | # duplicate text embeddings for each generation per prompt, using mps friendly method
294 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
295 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
296 |
297 | return prompt_embeds
298 |
299 | def encode_prompt(
300 | self,
301 | prompt: Union[str, List[str]],
302 | prompt_2: Union[str, List[str]],
303 | device: Optional[torch.device] = None,
304 | num_images_per_prompt: int = 1,
305 | prompt_embeds: Optional[torch.FloatTensor] = None,
306 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
307 | max_sequence_length: int = 512,
308 | lora_scale: Optional[float] = None,
309 | ):
310 | r"""
311 |
312 | Args:
313 | prompt (`str` or `List[str]`, *optional*):
314 | prompt to be encoded
315 | prompt_2 (`str` or `List[str]`, *optional*):
316 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
317 | used in all text-encoders
318 | device: (`torch.device`):
319 | torch device
320 | num_images_per_prompt (`int`):
321 | number of images that should be generated per prompt
322 | prompt_embeds (`torch.FloatTensor`, *optional*):
323 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
324 | provided, text embeddings will be generated from `prompt` input argument.
325 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
326 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
327 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
328 | lora_scale (`float`, *optional*):
329 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
330 | """
331 | device = device or self._execution_device
332 |
333 | # set lora scale so that monkey patched LoRA
334 | # function of text encoder can correctly access it
335 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
336 | self._lora_scale = lora_scale
337 |
338 | # dynamically adjust the LoRA scale
339 | if self.text_encoder is not None and USE_PEFT_BACKEND:
340 | scale_lora_layers(self.text_encoder, lora_scale)
341 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
342 | scale_lora_layers(self.text_encoder_2, lora_scale)
343 |
344 | prompt = [prompt] if isinstance(prompt, str) else prompt
345 |
346 | if prompt_embeds is None:
347 | prompt_2 = prompt_2 or prompt
348 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
349 |
350 | # We only use the pooled prompt output from the CLIPTextModel
351 | pooled_prompt_embeds = self._get_clip_prompt_embeds(
352 | prompt=prompt,
353 | device=device,
354 | num_images_per_prompt=num_images_per_prompt,
355 | )
356 | prompt_embeds = self._get_t5_prompt_embeds(
357 | prompt=prompt_2,
358 | num_images_per_prompt=num_images_per_prompt,
359 | max_sequence_length=max_sequence_length,
360 | device=device,
361 | )
362 |
363 | if self.text_encoder is not None:
364 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
365 | # Retrieve the original scale by scaling back the LoRA layers
366 | unscale_lora_layers(self.text_encoder, lora_scale)
367 |
368 | if self.text_encoder_2 is not None:
369 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
370 | # Retrieve the original scale by scaling back the LoRA layers
371 | unscale_lora_layers(self.text_encoder_2, lora_scale)
372 |
373 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
374 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
375 |
376 | return prompt_embeds, pooled_prompt_embeds, text_ids
377 |
378 | def check_inputs(
379 | self,
380 | prompt,
381 | prompt_2,
382 | height,
383 | width,
384 | prompt_embeds=None,
385 | pooled_prompt_embeds=None,
386 | callback_on_step_end_tensor_inputs=None,
387 | max_sequence_length=None,
388 | ):
389 | if height % 8 != 0 or width % 8 != 0:
390 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
391 |
392 | if callback_on_step_end_tensor_inputs is not None and not all(
393 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
394 | ):
395 | raise ValueError(
396 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
397 | )
398 |
399 | if prompt is not None and prompt_embeds is not None:
400 | raise ValueError(
401 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
402 | " only forward one of the two."
403 | )
404 | elif prompt_2 is not None and prompt_embeds is not None:
405 | raise ValueError(
406 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
407 | " only forward one of the two."
408 | )
409 | elif prompt is None and prompt_embeds is None:
410 | raise ValueError(
411 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
412 | )
413 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
414 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
415 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
416 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
417 |
418 | if prompt_embeds is not None and pooled_prompt_embeds is None:
419 | raise ValueError(
420 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
421 | )
422 |
423 | if max_sequence_length is not None and max_sequence_length > 512:
424 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
425 |
426 | @staticmethod
427 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
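        # NOTE (editorial, illustrative): each packed latent token receives a 3-component id
        # (0, patch_row, patch_col); together with the all-zero text ids these coordinates drive the
        # rotary position embedding inside the transformer.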
428 | latent_image_ids = torch.zeros(height // 2, width // 2, 3)
429 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
430 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
431 |
432 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
433 |
434 | latent_image_ids = latent_image_ids.reshape(
435 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
436 | )
437 |
438 | return latent_image_ids.to(device=device, dtype=dtype)
439 |
440 | @staticmethod
441 | def _pack_latents(latents, batch_size, num_channels_latents, height, width):
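        # NOTE (editorial, illustrative): e.g. a latent of shape (B, 16, 128, 128) (a 1024x1024 image
        # after the 8x VAE) becomes (B, 64 * 64, 64): every 2x2 spatial patch of the 16-channel latent
        # is flattened into one token of dimension 16 * 4 = 64.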
442 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
443 | latents = latents.permute(0, 2, 4, 1, 3, 5)
444 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
445 |
446 | return latents
447 |
448 | @staticmethod
449 | def _unpack_latents(latents, height, width, vae_scale_factor):
450 | batch_size, num_patches, channels = latents.shape
451 |
452 | height = height // vae_scale_factor
453 | width = width // vae_scale_factor
454 |
455 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
456 | latents = latents.permute(0, 3, 1, 4, 2, 5)
457 |
458 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
459 |
460 | return latents
461 |
462 | def enable_vae_slicing(self):
463 | r"""
464 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
465 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
466 | """
467 | self.vae.enable_slicing()
468 |
469 | def disable_vae_slicing(self):
470 | r"""
471 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
472 | computing decoding in one step.
473 | """
474 | self.vae.disable_slicing()
475 |
476 | def enable_vae_tiling(self):
477 | r"""
478 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
479 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
480 | processing larger images.
481 | """
482 | self.vae.enable_tiling()
483 |
484 | def disable_vae_tiling(self):
485 | r"""
486 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
487 | computing decoding in one step.
488 | """
489 | self.vae.disable_tiling()
490 |
491 | def prepare_latents(
492 | self,
493 | batch_size,
494 | num_channels_latents,
495 | height,
496 | width,
497 | dtype,
498 | device,
499 | generator,
500 | latents=None,
501 | ):
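    |         # `height`/`width` arrive in pixels; here `vae_scale_factor` folds in both the VAE's 8x downsampling
    |         # and the 2x2 packing, so 2 * (pixels // vae_scale_factor) is the unpacked latent grid (pixels / 8).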
502 | height = 2 * (int(height) // self.vae_scale_factor)
503 | width = 2 * (int(width) // self.vae_scale_factor)
504 |
505 | shape = (batch_size, num_channels_latents, height, width)
506 |
507 | if latents is not None:
508 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
509 | return latents.to(device=device, dtype=dtype), latent_image_ids
510 |
511 | if isinstance(generator, list) and len(generator) != batch_size:
512 | raise ValueError(
513 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
514 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
515 | )
516 |
517 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519 |
520 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
521 |
522 | return latents, latent_image_ids
523 |
524 | @property
525 | def guidance_scale(self):
526 | return self._guidance_scale
527 |
528 | @property
529 | def joint_attention_kwargs(self):
530 | return self._joint_attention_kwargs
531 |
532 | @property
533 | def num_timesteps(self):
534 | return self._num_timesteps
535 |
536 | @property
537 | def interrupt(self):
538 | return self._interrupt
539 |
540 | @torch.no_grad()
541 | @replace_example_docstring(EXAMPLE_DOC_STRING)
542 | def __call__(
543 | self,
544 | prompt: Union[str, List[str]] = None,
545 | prompt_2: Optional[Union[str, List[str]]] = None,
546 | height: Optional[int] = None,
547 | width: Optional[int] = None,
548 | num_inference_steps: int = 28,
549 | timesteps: List[int] = None,
550 | guidance_scale: float = 3.5,
551 | num_images_per_prompt: Optional[int] = 1,
552 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
553 | latents: Optional[torch.FloatTensor] = None,
554 | prompt_embeds: Optional[torch.FloatTensor] = None,
555 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
556 | output_type: Optional[str] = "pil",
557 | return_dict: bool = True,
558 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
559 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
560 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
561 | max_sequence_length: int = 512,
562 | ntk_factor: float = 10.0,
563 | proportional_attention: bool = True
564 | ):
565 | r"""
566 | Function invoked when calling the pipeline for generation.
567 |
568 | Args:
569 | prompt (`str` or `List[str]`, *optional*):
570 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
571 |                 instead.
572 | prompt_2 (`str` or `List[str]`, *optional*):
573 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
574 |                 will be used instead.
575 |             height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
576 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
577 |             width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
578 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
579 |             num_inference_steps (`int`, *optional*, defaults to 28):
580 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
581 | expense of slower inference.
582 | timesteps (`List[int]`, *optional*):
583 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
584 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
585 | passed will be used. Must be in descending order.
586 |             guidance_scale (`float`, *optional*, defaults to 3.5):
587 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
588 |                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
589 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
590 |                 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text
591 |                 `prompt`, usually at the expense of lower image quality.
592 | num_images_per_prompt (`int`, *optional*, defaults to 1):
593 | The number of images to generate per prompt.
594 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
595 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
596 | to make generation deterministic.
597 | latents (`torch.FloatTensor`, *optional*):
598 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
599 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
600 |                 tensor will be generated by sampling using the supplied random `generator`.
601 | prompt_embeds (`torch.FloatTensor`, *optional*):
602 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
603 | provided, text embeddings will be generated from `prompt` input argument.
604 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
605 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
606 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
607 | output_type (`str`, *optional*, defaults to `"pil"`):
608 |                 The output format of the generated image. Choose between
609 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
610 | return_dict (`bool`, *optional*, defaults to `True`):
611 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
612 | joint_attention_kwargs (`dict`, *optional*):
613 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
614 | `self.processor` in
615 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
616 | callback_on_step_end (`Callable`, *optional*):
617 |                 A function that is called at the end of each denoising step during inference. The function is called
618 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
619 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
620 | `callback_on_step_end_tensor_inputs`.
621 | callback_on_step_end_tensor_inputs (`List`, *optional*):
622 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
623 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
624 | `._callback_tensor_inputs` attribute of your pipeline class.
625 |             max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
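    |             ntk_factor (`float`, *optional*, defaults to 10.0):
    |                 NTK-aware scaling factor forwarded to the transformer so its rotary position embeddings stay
    |                 well-behaved on token grids larger than those seen during training.
    |             proportional_attention (`bool`, *optional*, defaults to `True`):
    |                 Forwarded to the attention processors through `joint_attention_kwargs`; when enabled, the
    |                 attention scale is adjusted with the sequence length for high-resolution generation.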
626 |
627 | Examples:
628 |
629 | Returns:
630 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
631 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
632 | images.
633 | """
634 |
635 | height = height or self.default_sample_size * self.vae_scale_factor
636 | width = width or self.default_sample_size * self.vae_scale_factor
637 |
638 | # 1. Check inputs. Raise error if not correct
639 | self.check_inputs(
640 | prompt,
641 | prompt_2,
642 | height,
643 | width,
644 | prompt_embeds=prompt_embeds,
645 | pooled_prompt_embeds=pooled_prompt_embeds,
646 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
647 | max_sequence_length=max_sequence_length,
648 | )
649 |
650 | self._guidance_scale = guidance_scale
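    |         # Thread `proportional_attention` to every attention layer via `joint_attention_kwargs`; the repo's
    |         # custom attention processors read this flag to rescale attention with the sequence length.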
651 | if joint_attention_kwargs is None:
652 | joint_attention_kwargs = {'proportional_attention': proportional_attention}
653 | else:
654 | joint_attention_kwargs = {**joint_attention_kwargs, 'proportional_attention': proportional_attention}
655 | self._joint_attention_kwargs = joint_attention_kwargs
656 | self._interrupt = False
657 |
658 | # 2. Define call parameters
659 | if prompt is not None and isinstance(prompt, str):
660 | batch_size = 1
661 | elif prompt is not None and isinstance(prompt, list):
662 | batch_size = len(prompt)
663 | else:
664 | batch_size = prompt_embeds.shape[0]
665 |
666 | device = self._execution_device
667 |
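    |         # 3. Encode the input prompt(s) into CLIP (pooled) and T5 (per-token) embeddings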
668 | lora_scale = (
669 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
670 | )
671 | (
672 | prompt_embeds,
673 | pooled_prompt_embeds,
674 | text_ids,
675 | ) = self.encode_prompt(
676 | prompt=prompt,
677 | prompt_2=prompt_2,
678 | prompt_embeds=prompt_embeds,
679 | pooled_prompt_embeds=pooled_prompt_embeds,
680 | device=device,
681 | num_images_per_prompt=num_images_per_prompt,
682 | max_sequence_length=max_sequence_length,
683 | lora_scale=lora_scale,
684 | )
685 |
686 | # 4. Prepare latent variables
687 | num_channels_latents = self.transformer.config.in_channels // 4
688 | latents, latent_image_ids = self.prepare_latents(
689 | batch_size * num_images_per_prompt,
690 | num_channels_latents,
691 | height,
692 | width,
693 | prompt_embeds.dtype,
694 | device,
695 | generator,
696 | latents,
697 | )
698 |
699 | # 5. Prepare timesteps
700 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
701 | image_seq_len = latents.shape[1]
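    |         # Resolution-dependent timestep shifting: `mu` grows linearly with the number of image tokens
    |         # (interpolating between `base_shift` and `max_shift`), shifting the sigma schedule toward higher
    |         # noise levels for larger images.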
702 | mu = calculate_shift(
703 | image_seq_len,
704 | self.scheduler.config.base_image_seq_len,
705 | self.scheduler.config.max_image_seq_len,
706 | self.scheduler.config.base_shift,
707 | self.scheduler.config.max_shift,
708 | )
709 | timesteps, num_inference_steps = retrieve_timesteps(
710 | self.scheduler,
711 | num_inference_steps,
712 | device,
713 | timesteps,
714 | sigmas,
715 | mu=mu,
716 | )
717 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
718 | self._num_timesteps = len(timesteps)
719 |
720 | # handle guidance
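    |         # Guidance-distilled checkpoints (e.g. FLUX.1-dev) consume `guidance_scale` as an embedding rather
    |         # than via a second, unconditional forward pass; FLUX.1-schnell has no guidance embedding.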
721 | if self.transformer.config.guidance_embeds:
722 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
723 | guidance = guidance.expand(latents.shape[0])
724 | else:
725 | guidance = None
726 |
727 | # 6. Denoising loop
728 | with self.progress_bar(total=num_inference_steps) as progress_bar:
729 | for i, t in enumerate(timesteps):
730 | if self.interrupt:
731 | continue
732 |
733 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
734 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
735 |
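    |                 # One transformer forward per step: scheduler timesteps (~1000..0) are normalized to [0, 1],
    |                 # and `ntk_factor` is passed through so RoPE is rescaled for the enlarged token grid.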
736 | noise_pred = self.transformer(
737 | hidden_states=latents,
738 | timestep=timestep / 1000,
739 | guidance=guidance,
740 | pooled_projections=pooled_prompt_embeds,
741 | encoder_hidden_states=prompt_embeds,
742 | txt_ids=text_ids,
743 | img_ids=latent_image_ids,
744 | joint_attention_kwargs=self.joint_attention_kwargs,
745 | return_dict=False,
746 | ntk_factor=ntk_factor
747 | )[0]
748 |
749 | # compute the previous noisy sample x_t -> x_t-1
750 | latents_dtype = latents.dtype
751 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
752 |
753 |                 if latents.dtype != latents_dtype:
754 |                     # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug
755 |                     # (https://github.com/pytorch/pytorch/pull/99272), so always cast back to the original dtype
756 |                     latents = latents.to(latents_dtype)
757 |
758 | if callback_on_step_end is not None:
759 | callback_kwargs = {}
760 | for k in callback_on_step_end_tensor_inputs:
761 | callback_kwargs[k] = locals()[k]
762 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
763 |
764 | latents = callback_outputs.pop("latents", latents)
765 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
766 |
767 |                 # update the progress bar according to the scheduler order
768 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
769 | progress_bar.update()
770 |
771 | if XLA_AVAILABLE:
772 | xm.mark_step()
773 |
774 | if output_type == "latent":
775 | image = latents
776 |
777 | else:
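    |                 # Unpack the token sequence back into a (B, C, H/8, W/8) latent grid, undo the VAE's
    |                 # scaling/shift normalization, then decode to pixels.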
778 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
779 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
780 | image = self.vae.decode(latents, return_dict=False)[0]
781 | image = self.image_processor.postprocess(image, output_type=output_type)
782 |
783 | # Offload all models
784 | self.maybe_free_model_hooks()
785 |
786 | if not return_dict:
787 | return (image,)
788 |
789 | return FluxPipelineOutput(images=image)
790 |
--------------------------------------------------------------------------------