├── .gitignore ├── .python-version ├── README.md ├── assets └── banner.jpg ├── inference.ipynb ├── inference.py ├── pyproject.toml ├── requirements.txt ├── src └── flux │ ├── __init__.py │ ├── math.py │ ├── model.py │ ├── modules │ ├── autoencoder.py │ ├── conditioner.py │ └── layers.py │ ├── pipeline.py │ ├── sampling.py │ └── util.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | Makefile 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | weights/ 25 | 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache/ 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | 167 | .DS_Store 168 | wandb 169 | trace 170 | tmp* 171 | logs -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLUX-Krea 2 | 3 | ![banner](assets/banner.jpg) 4 | 5 | --- 6 | 7 | This is the official repository for `FLUX.1 Krea [dev]` (AKA `flux-krea`). 8 | 9 | The code in this repository and the weights hosted on Hugging Face are the open version of [Krea 1](https://www.krea.ai/krea-1), our first image model trained in collaboration with [Black Forest Labs](https://bfl.ai/) to offer superior aesthetic control and image quality. 10 | 11 | The repository contains [inference code](https://github.com/krea-ai/flux-krea/blob/main/inference.py) and a [Jupyter Notebook](https://github.com/krea-ai/flux-krea/blob/main/inference.ipynb) to run the model; you can download the weights and inspect the model card [here](https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev). 12 | 13 | 14 | ## Usage 15 | 16 | ### With `pip` 17 | 18 | ``` 19 | git clone https://github.com/krea-ai/flux-krea.git 20 | cd flux-krea 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### With [`uv`](https://github.com/astral-sh/uv) 25 | 26 | ``` 27 | git clone https://github.com/krea-ai/flux-krea.git 28 | cd flux-krea 29 | uv sync 30 | ``` 31 | 32 | ### Live Demo 33 | 34 | Generate on [krea.ai](https://www.krea.ai/apps/image/flux-krea) 35 | 36 | ## Running the model 37 | 38 | ```bash 39 | python inference.py --prompt "a cute cat" --seed 42 40 | ``` 41 | 42 | Check `inference.ipynb` for a full example. It may take a few minutes to download the model weights on your first attempt. 43 | 44 | **Recommended inference settings** 45 | 46 | - **Resolution** - between `1024` and `1280` pixels. 47 | 48 | - **Number of inference steps** - between `28` and `32` steps 49 | 50 | - **CFG Guidance** - between `3.5` and `5.0` 51 | 52 | ## How was it made? 53 | 54 | Krea 1 was created as a research collaboration between [Krea](https://www.krea.ai) and [Black Forest Labs](https://bfl.ai). 55 | 56 | `FLUX.1 Krea [dev]` is a 12B-parameter rectified-flow model _distilled_ from Krea 1.
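In practice, sampling from a rectified-flow model means integrating a learned velocity field from pure noise (`t = 1`) down to a clean image (`t = 0`). Below is a minimal sketch of the Euler integration performed by `denoise` in `src/flux/sampling.py`; here `model_fn` is only a stand-in for the full text- and guidance-conditioned `Flux` forward pass, not part of the actual API.

```python
import torch

def euler_sample(model_fn, img: torch.Tensor, timesteps: list[float]) -> torch.Tensor:
    """Integrate the learned velocity field from t=1 (noise) to t=0 (image).

    model_fn(img, t) stands in for the conditioned Flux forward pass used in
    denoise() (src/flux/sampling.py); timesteps is a decreasing schedule such
    as the one returned by get_schedule().
    """
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        pred = model_fn(img, t_curr)          # predicted velocity at t_curr
        img = img + (t_prev - t_curr) * pred  # explicit Euler step toward t_prev
    return img
```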
This model is a CFG-distilled model and fully compatible with the [FLUX.1 [dev]](https://github.com/black-forest-labs/flux) architecture. 57 | 58 | In a nutshell, we ran a large-scale post-training of the pre-trained weights provided by Black Forest Labs. 59 | 60 | For more details on the development of this model, [read our technical blog post](https://krea.ai/blog/flux-krea-open-source-release). 61 | 62 | ## Acknowledgements 63 | 64 | We would like to thank the Black Forest Labs team for providing the base model weights. None of this would be possible without their contribution. The post-training work would not be possible without the hard work of our data, infrastructure, and product team who put together a solid foundation for our post-training pipelines. 65 | 66 | If you are interested in building large-scale image/video/3D/world models, or the engineering and data infrastructure around it... 67 | 68 | > 69 | > [We are hiring.](https://www.krea.ai/careers) 70 | > 71 | 72 | ### Citation 73 | 74 | ```bib 75 | @misc{flux1kreadev2025, 76 | author={Sangwu Lee, Titus Ebbecke, Erwann Millon, Will Beddow, Le Zhuo, Iker García-Ferrero, Liam Esparraguera, Mihai Petrescu, Gian Saß, Gabriel Menezes, Victor Perez}, 77 | title={FLUX.1 Krea [dev]}, 78 | year={2025}, 79 | howpublished={\url{https://github.com/krea-ai/flux-krea}}, 80 | } 81 | ``` 82 | -------------------------------------------------------------------------------- /assets/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/krea-ai/flux-krea/88238a00c24afa49a19b0d685f491753af8293f7/assets/banner.jpg -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "653b9ecd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import torch\n", 12 | "from src.flux.util import load_ae, load_clip, load_t5, load_flow_model\n", 13 | "from src.flux.pipeline import Sampler\n", 14 | "\n", 15 | "device = \"cuda\"\n", 16 | "model = load_flow_model(\"flux-krea-dev\", device=\"cpu\")\n", 17 | "ae = load_ae(\"flux-krea-dev\")\n", 18 | "clip = load_clip()\n", 19 | "t5 = load_t5()\n", 20 | "\n", 21 | "ae = ae.to(device=device, dtype=torch.bfloat16)\n", 22 | "clip = clip.to(device=device, dtype=torch.bfloat16)\n", 23 | "t5 = t5.to(device=device, dtype=torch.bfloat16)\n", 24 | "model = model.to(device, dtype=torch.bfloat16)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "6905ce9b", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "sampler = Sampler(\n", 35 | " model=model,\n", 36 | " ae=ae,\n", 37 | " clip=clip,\n", 38 | " t5=t5,\n", 39 | " device=device,\n", 40 | " dtype=torch.bfloat16,\n", 41 | ")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "7508f5d3", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "image = sampler(\n", 52 | " prompt=\"a cute cat\",\n", 53 | " width=1024,\n", 54 | " height=1024,\n", 55 | " guidance=4.5,\n", 56 | " num_steps=28,\n", 57 | " seed=42,\n", 58 | ")\n", 59 | "\n", 60 | "# enjoy a cute cat\n", 61 | "image" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "58d17741", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [] 71 | } 72 | ], 73 | "metadata": { 74 | 
"kernelspec": { 75 | "display_name": ".venv", 76 | "language": "python", 77 | "name": "python3" 78 | }, 79 | "language_info": { 80 | "codemirror_mode": { 81 | "name": "ipython", 82 | "version": 3 83 | }, 84 | "file_extension": ".py", 85 | "mimetype": "text/x-python", 86 | "name": "python", 87 | "nbconvert_exporter": "python", 88 | "pygments_lexer": "ipython3", 89 | "version": "3.12.3" 90 | } 91 | }, 92 | "nbformat": 4, 93 | "nbformat_minor": 5 94 | } 95 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import click 3 | import torch 4 | from pathlib import Path 5 | from src.flux.util import load_ae, load_clip, load_t5, load_flow_model 6 | from src.flux.pipeline import Sampler 7 | 8 | 9 | @click.command() 10 | @click.option('--prompt', '-p', required=True, help='Text prompt for image generation') 11 | @click.option('--width', '-w', default=1024, type=int, help='Image width (default: 1024)') 12 | @click.option('--height', '-h', default=1024, type=int, help='Image height (default: 1024)') 13 | @click.option('--guidance', '-g', default=4.5, type=float, help='Guidance scale (default: 4.5)') 14 | @click.option('--num-steps', '-s', default=28, type=int, help='Number of sampling steps (default: 28)') 15 | @click.option('--seed', default=42, type=int, help='Random seed (default: 42)') 16 | @click.option('--output', '-o', default='output.png', help='Output image path (default: output.png)') 17 | @click.option('--device', default='cuda', help='Device to use (default: cuda)') 18 | def generate(prompt, width, height, guidance, num_steps, seed, output, device): 19 | torch_dtype = torch.bfloat16 20 | click.echo("Loading models...") 21 | 22 | # Load models 23 | click.echo("Loading AE...") 24 | ae = load_ae("flux-krea-dev") 25 | 26 | click.echo("Loading CLIP...") 27 | clip = load_clip() 28 | 29 | click.echo("Loading T5...") 30 | t5 = load_t5() 31 | 32 | click.echo("Loading MMDiT...") 33 | model = load_flow_model("flux-krea-dev", device="cpu") 34 | model = model.to(device=device, dtype=torch_dtype) 35 | 36 | # Move models to device with specified dtype 37 | ae = ae.to(device=device, dtype=torch_dtype) 38 | clip = clip.to(device=device, dtype=torch_dtype) 39 | t5 = t5.to(device=device, dtype=torch_dtype) 40 | 41 | # Create sampler 42 | sampler = Sampler( 43 | model=model, 44 | ae=ae, 45 | clip=clip, 46 | t5=t5, 47 | device=device, 48 | dtype=torch_dtype, 49 | ) 50 | 51 | click.echo(f"Generating image with prompt: '{prompt}'") 52 | click.echo(f"Parameters: {width}x{height}, guidance={guidance}, steps={num_steps}, seed={seed}") 53 | 54 | # Generate image 55 | image = sampler( 56 | prompt=prompt, 57 | width=width, 58 | height=height, 59 | guidance=guidance, 60 | num_steps=num_steps, 61 | seed=seed, 62 | ) 63 | 64 | # Save image 65 | outpath = Path(output) 66 | image.save(outpath) 67 | 68 | click.echo(f"Image saved to: {outpath.absolute()}") 69 | 70 | 71 | if __name__ == '__main__': 72 | generate() 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "kflux" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "click>=8.2.1", 9 | "einops>=0.8.1", 10 | "huggingface-hub>=0.34.3", 11 | "jupyter>=1.1.1", 12 | 
"safetensors>=0.5.3", 13 | "sentencepiece>=0.2.0", 14 | "tokenizers>=0.21.4", 15 | "torch>=2.6.0", 16 | "torchvision>=0.22.1", 17 | "tqdm>=4.67.1", 18 | "transformers>=4.54.1", 19 | ] 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.6.0 2 | torchvision 3 | transformers 4 | safetensors 5 | einops 6 | tqdm 7 | tokenizers 8 | sentencepiece 9 | huggingface-hub 10 | click 11 | jupyter 12 | -------------------------------------------------------------------------------- /src/flux/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | version_tuple = (0, 1, 0, "final") 3 | 4 | from pathlib import Path 5 | 6 | PACKAGE = "kflux" 7 | PACKAGE_ROOT = Path(__file__).parent 8 | -------------------------------------------------------------------------------- /src/flux/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor 4 | 5 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, scale: float | None = None) -> Tensor: 6 | q, k = apply_rope(q, k, pe) 7 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=scale) 8 | x = rearrange(x, "B H L D -> B L (H D)") 9 | 10 | return x 11 | 12 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 13 | assert dim % 2 == 0 14 | 15 | b, l = pos.shape 16 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 17 | omega = 1.0 / ((theta) ** scale) 18 | 19 | out = torch.einsum("...n,d->...nd", pos, omega) 20 | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) 21 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 22 | return out.float() 23 | 24 | 25 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 26 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 27 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 28 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 29 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 30 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 31 | -------------------------------------------------------------------------------- /src/flux/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from .modules.layers import DoubleStreamBlock, SingleStreamBlock, EmbedND, LastLayer, MLPEmbedder, TimestepEmbedding 6 | 7 | 8 | @dataclass 9 | class FluxParams: 10 | in_channels: int 11 | vec_in_dim: int 12 | context_in_dim: int 13 | hidden_size: int 14 | mlp_ratio: float 15 | num_heads: int 16 | depth: int 17 | depth_single_blocks: int 18 | axes_dim: list[int] 19 | theta: int 20 | qkv_bias: bool 21 | guidance_embed: bool 22 | 23 | 24 | class Flux(nn.Module): 25 | def __init__(self, params: FluxParams): 26 | super().__init__() 27 | 28 | self.params = params 29 | self.in_channels = params.in_channels 30 | self.out_channels = self.in_channels 31 | pe_dim = params.hidden_size // params.num_heads 32 | 33 | self.hidden_size = params.hidden_size 34 | self.num_heads = params.num_heads 35 | self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, 
axes_dim=params.axes_dim) 36 | self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) 37 | self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) 38 | self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) 39 | self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() 40 | self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) 41 | self.timestep_embedder = TimestepEmbedding(dim=256, max_period=10_000, time_factor=1_000.0) 42 | 43 | self.double_blocks = nn.ModuleList( 44 | [ 45 | DoubleStreamBlock( 46 | self.hidden_size, 47 | self.num_heads, 48 | mlp_ratio=params.mlp_ratio, 49 | qkv_bias=params.qkv_bias, 50 | ) 51 | for _ in range(params.depth) 52 | ] 53 | ) 54 | 55 | self.single_blocks = nn.ModuleList( 56 | [ 57 | SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) 58 | for _ in range(params.depth_single_blocks) 59 | ] 60 | ) 61 | 62 | self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) 63 | 64 | def forward( 65 | self, 66 | img: Tensor, 67 | img_ids: Tensor, 68 | txt: Tensor, 69 | txt_ids: Tensor, 70 | timesteps: Tensor, 71 | y: Tensor, 72 | guidance: Tensor | None = None, 73 | ) -> Tensor: 74 | # running on sequences img 75 | img = self.img_in(img) 76 | vec = self.time_in(self.timestep_embedder(timesteps)) 77 | if self.params.guidance_embed: 78 | vec = vec + self.guidance_in(self.timestep_embedder(guidance)) 79 | 80 | vec = vec + self.vector_in(y) 81 | txt = self.txt_in(txt) 82 | 83 | ids = torch.cat((txt_ids, img_ids), dim=1) 84 | pe = self.pe_embedder(ids) 85 | for block in self.double_blocks: 86 | img, txt = block( 87 | img=img, 88 | txt=txt, 89 | vec=vec, 90 | pe=pe, 91 | ) 92 | img = torch.cat((txt, img), 1) 93 | for block in self.single_blocks: 94 | img = block(img, vec=vec, pe=pe) 95 | 96 | b, s, c = txt.shape 97 | img = img[:, s:, ...] 
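        # txt and img tokens were concatenated before the single-stream blocks;
        # slicing off the leading txt tokens leaves only image tokens for the final layer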
98 | img = self.final_layer(img, vec) 99 | 100 | return img -------------------------------------------------------------------------------- /src/flux/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | 8 | 9 | @dataclass 10 | class AutoEncoderParams: 11 | resolution: int 12 | in_channels: int 13 | ch: int 14 | out_ch: int 15 | ch_mult: list[int] 16 | num_res_blocks: int 17 | z_channels: int 18 | scale_factor: float 19 | shift_factor: float 20 | 21 | 22 | def swish(x: Tensor) -> Tensor: 23 | return x * torch.sigmoid(x) 24 | 25 | 26 | class AttnBlock(nn.Module): 27 | def __init__(self, in_channels: int): 28 | super().__init__() 29 | self.in_channels = in_channels 30 | 31 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 32 | 33 | self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) 34 | self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) 35 | self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) 36 | self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) 37 | 38 | def attention(self, h_: Tensor) -> Tensor: 39 | h_ = self.norm(h_) 40 | q = self.q(h_) 41 | k = self.k(h_) 42 | v = self.v(h_) 43 | 44 | b, c, h, w = q.shape 45 | q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() 46 | k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 47 | v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() 48 | h_ = nn.functional.scaled_dot_product_attention(q, k, v) 49 | 50 | return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) 51 | 52 | def forward(self, x: Tensor) -> Tensor: 53 | return x + self.proj_out(self.attention(x)) 54 | 55 | 56 | class ResnetBlock(nn.Module): 57 | def __init__(self, in_channels: int, out_channels: int): 58 | super().__init__() 59 | self.in_channels = in_channels 60 | out_channels = in_channels if out_channels is None else out_channels 61 | self.out_channels = out_channels 62 | 63 | self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 64 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 65 | self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) 66 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 67 | if self.in_channels != self.out_channels: 68 | self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 69 | 70 | def forward(self, x): 71 | h = x 72 | h = self.norm1(h) 73 | h = swish(h) 74 | h = self.conv1(h) 75 | 76 | h = self.norm2(h) 77 | h = swish(h) 78 | h = self.conv2(h) 79 | 80 | if self.in_channels != self.out_channels: 81 | x = self.nin_shortcut(x) 82 | 83 | return x + h 84 | 85 | 86 | class Downsample(nn.Module): 87 | def __init__(self, in_channels: int): 88 | super().__init__() 89 | # no asymmetric padding in torch conv, must do it ourselves 90 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 91 | 92 | def forward(self, x: Tensor): 93 | pad = (0, 1, 0, 1) 94 | x = nn.functional.pad(x, pad, mode="constant", value=0) 95 | x = self.conv(x) 96 | return x 97 | 98 | 99 | class Upsample(nn.Module): 100 | def __init__(self, in_channels: int): 101 | super().__init__() 102 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, 
padding=1) 103 | 104 | def forward(self, x: Tensor): 105 | x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 106 | x = self.conv(x) 107 | return x 108 | 109 | 110 | class Encoder(nn.Module): 111 | def __init__( 112 | self, 113 | resolution: int, 114 | in_channels: int, 115 | ch: int, 116 | ch_mult: list[int], 117 | num_res_blocks: int, 118 | z_channels: int, 119 | ): 120 | super().__init__() 121 | self.ch = ch 122 | self.num_resolutions = len(ch_mult) 123 | self.num_res_blocks = num_res_blocks 124 | self.resolution = resolution 125 | self.in_channels = in_channels 126 | # downsampling 127 | self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 128 | 129 | curr_res = resolution 130 | in_ch_mult = (1,) + tuple(ch_mult) 131 | self.in_ch_mult = in_ch_mult 132 | self.down = nn.ModuleList() 133 | block_in = self.ch 134 | for i_level in range(self.num_resolutions): 135 | block = nn.ModuleList() 136 | attn = nn.ModuleList() 137 | block_in = ch * in_ch_mult[i_level] 138 | block_out = ch * ch_mult[i_level] 139 | for _ in range(self.num_res_blocks): 140 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 141 | block_in = block_out 142 | down = nn.Module() 143 | down.block = block 144 | down.attn = attn 145 | if i_level != self.num_resolutions - 1: 146 | down.downsample = Downsample(block_in) 147 | curr_res = curr_res // 2 148 | self.down.append(down) 149 | 150 | # middle 151 | self.mid = nn.Module() 152 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) 153 | self.mid.attn_1 = AttnBlock(block_in) 154 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 155 | 156 | # end 157 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 158 | self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) 159 | 160 | def forward(self, x: Tensor) -> Tensor: 161 | # downsampling 162 | hs = [self.conv_in(x)] 163 | for i_level in range(self.num_resolutions): 164 | for i_block in range(self.num_res_blocks): 165 | h = self.down[i_level].block[i_block](hs[-1]) 166 | if len(self.down[i_level].attn) > 0: 167 | h = self.down[i_level].attn[i_block](h) 168 | hs.append(h) 169 | if i_level != self.num_resolutions - 1: 170 | hs.append(self.down[i_level].downsample(hs[-1])) 171 | 172 | # middle 173 | h = hs[-1] 174 | h = self.mid.block_1(h) 175 | h = self.mid.attn_1(h) 176 | h = self.mid.block_2(h) 177 | # end 178 | h = self.norm_out(h) 179 | h = swish(h) 180 | h = self.conv_out(h) 181 | return h 182 | 183 | 184 | class Decoder(nn.Module): 185 | def __init__( 186 | self, 187 | ch: int, 188 | out_ch: int, 189 | ch_mult: list[int], 190 | num_res_blocks: int, 191 | in_channels: int, 192 | resolution: int, 193 | z_channels: int, 194 | ): 195 | super().__init__() 196 | self.ch = ch 197 | self.num_resolutions = len(ch_mult) 198 | self.num_res_blocks = num_res_blocks 199 | self.resolution = resolution 200 | self.in_channels = in_channels 201 | self.ffactor = 2 ** (self.num_resolutions - 1) 202 | 203 | # compute in_ch_mult, block_in and curr_res at lowest res 204 | block_in = ch * ch_mult[self.num_resolutions - 1] 205 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 206 | self.z_shape = (1, z_channels, curr_res, curr_res) 207 | 208 | # z to block_in 209 | self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 210 | 211 | # middle 212 | self.mid = nn.Module() 213 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 
out_channels=block_in) 214 | self.mid.attn_1 = AttnBlock(block_in) 215 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) 216 | 217 | # upsampling 218 | self.up = nn.ModuleList() 219 | for i_level in reversed(range(self.num_resolutions)): 220 | block = nn.ModuleList() 221 | attn = nn.ModuleList() 222 | block_out = ch * ch_mult[i_level] 223 | for _ in range(self.num_res_blocks + 1): 224 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) 225 | block_in = block_out 226 | up = nn.Module() 227 | up.block = block 228 | up.attn = attn 229 | if i_level != 0: 230 | up.upsample = Upsample(block_in) 231 | curr_res = curr_res * 2 232 | self.up.insert(0, up) # prepend to get consistent order 233 | 234 | # end 235 | self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 236 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 237 | 238 | def forward(self, z: Tensor) -> Tensor: 239 | # z to block_in 240 | h = self.conv_in(z) 241 | 242 | # middle 243 | h = self.mid.block_1(h) 244 | h = self.mid.attn_1(h) 245 | h = self.mid.block_2(h) 246 | 247 | # upsampling 248 | for i_level in reversed(range(self.num_resolutions)): 249 | for i_block in range(self.num_res_blocks + 1): 250 | h = self.up[i_level].block[i_block](h) 251 | if len(self.up[i_level].attn) > 0: 252 | h = self.up[i_level].attn[i_block](h) 253 | if i_level != 0: 254 | h = self.up[i_level].upsample(h) 255 | 256 | # end 257 | h = self.norm_out(h) 258 | h = swish(h) 259 | h = self.conv_out(h) 260 | return h 261 | 262 | 263 | 264 | class DiagonalGaussian(nn.Module): 265 | def __init__(self, sample: bool = True, chunk_dim: int = 1): 266 | super().__init__() 267 | self.sample = sample 268 | self.chunk_dim = chunk_dim 269 | 270 | def forward(self, z: Tensor) -> Tensor: 271 | mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) 272 | if self.sample: 273 | std = torch.exp(0.5 * logvar) 274 | return mean + std * torch.randn_like(mean) 275 | else: 276 | return mean 277 | 278 | 279 | class AutoEncoder(nn.Module): 280 | def __init__(self, params: AutoEncoderParams): 281 | super().__init__() 282 | self.encoder = Encoder( 283 | resolution=params.resolution, 284 | in_channels=params.in_channels, 285 | ch=params.ch, 286 | ch_mult=params.ch_mult, 287 | num_res_blocks=params.num_res_blocks, 288 | z_channels=params.z_channels, 289 | ) 290 | self.decoder = Decoder( 291 | resolution=params.resolution, 292 | in_channels=params.in_channels, 293 | ch=params.ch, 294 | out_ch=params.out_ch, 295 | ch_mult=params.ch_mult, 296 | num_res_blocks=params.num_res_blocks, 297 | z_channels=params.z_channels, 298 | ) 299 | self.reg = DiagonalGaussian() 300 | 301 | self.scale_factor = params.scale_factor 302 | self.shift_factor = params.shift_factor 303 | 304 | def encode(self, x: Tensor) -> Tensor: 305 | z = self.reg(self.encoder(x)) 306 | z = self.scale_factor * (z - self.shift_factor) 307 | return z 308 | 309 | def decode(self, z: Tensor) -> Tensor: 310 | z = z / self.scale_factor + self.shift_factor 311 | return self.decoder(z) 312 | 313 | def forward(self, x: Tensor) -> Tensor: 314 | return self.decode(self.encode(x)) 315 | -------------------------------------------------------------------------------- /src/flux/modules/conditioner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer 4 | 5 | class 
HFEmbedder(nn.Module): 6 | def __init__(self, version: str, max_length: int, torch_dtype = torch.bfloat16, device: str | torch.device = "cuda", **hf_kwargs): 7 | super().__init__() 8 | self.is_clip = version.startswith("openai") 9 | self.max_length = max_length 10 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 11 | 12 | if self.is_clip: 13 | self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) 14 | self.hf_module: CLIPTextModel = self.load_clip(version, device=device, torch_dtype=torch_dtype, **hf_kwargs) 15 | else: 16 | self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) 17 | self.hf_module: T5EncoderModel = self.load_t5(version, device=device, torch_dtype=torch_dtype, **hf_kwargs) 18 | 19 | self.hf_module = self.hf_module.eval().requires_grad_(False) 20 | self.device = self.hf_module.device 21 | self.hf_module.compile() 22 | 23 | 24 | def load_t5(self, version: str, device: str | torch.device = "cuda", torch_dtype = torch.bfloat16, **hf_kwargs): 25 | t5 = T5EncoderModel.from_pretrained(version, torch_dtype=torch_dtype, **hf_kwargs) 26 | return t5.to(device) 27 | 28 | def load_clip(self, version: str, device: str | torch.device = "cuda", torch_dtype = torch.bfloat16, **hf_kwargs): 29 | clip = CLIPTextModel.from_pretrained(version, torch_dtype=torch_dtype, **hf_kwargs) 30 | return clip.to(device) 31 | 32 | def forward(self, text: list[str]) -> Tensor: 33 | batch_encoding = self.tokenizer( 34 | text, 35 | truncation=True, 36 | max_length=self.max_length, 37 | return_length=False, 38 | return_overflowing_tokens=False, 39 | padding="max_length", 40 | return_tensors="pt", 41 | ) 42 | 43 | input_ids = batch_encoding["input_ids"].to(self.hf_module.device) 44 | 45 | outputs = self.hf_module( 46 | input_ids=input_ids, 47 | attention_mask=None, 48 | output_hidden_states=False, 49 | ) 50 | 51 | return outputs[self.output_key] -------------------------------------------------------------------------------- /src/flux/modules/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from torch import Tensor, nn 8 | 9 | from ..math import attention, rope 10 | 11 | 12 | class EmbedND(nn.Module): 13 | def __init__(self, dim: int, theta: int, axes_dim: list[int]): 14 | super().__init__() 15 | self.dim = dim 16 | self.theta = theta 17 | self.axes_dim = axes_dim 18 | 19 | def forward(self, ids: Tensor) -> Tensor: 20 | n_axes = ids.shape[-1] 21 | emb = torch.cat( 22 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 23 | dim=-3, 24 | ) 25 | 26 | return emb.unsqueeze(1) 27 | 28 | 29 | class TimestepEmbedding(nn.Module): 30 | def __init__(self, dim, max_period: int = 10_000, 31 | time_factor: float = 1_000.0, device: str | torch.device = "cuda"): 32 | super().__init__() 33 | half = dim // 2 34 | freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=device) / half) 35 | self.register_buffer("freqs", freqs, persistent=False) 36 | self.time_factor = time_factor 37 | self.dim = dim 38 | 39 | def forward(self, t: torch.Tensor): 40 | t = t * self.time_factor 41 | args = t[:, None] * self.freqs 42 | sin, cos = torch.sin(args), torch.cos(args) 43 | emb = torch.cat((cos, sin), dim=-1) 44 | return emb 45 | 46 | def timestep_embedding(t: Tensor, dim, max_period=10000, 
time_factor: float = 1000.0): 47 | t = time_factor * t 48 | half = dim // 2 49 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 50 | t.device 51 | ) 52 | 53 | args = t[:, None].float() * freqs[None] 54 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 55 | if dim % 2: 56 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 57 | if torch.is_floating_point(t): 58 | embedding = embedding.to(t) 59 | return embedding 60 | 61 | 62 | class MLPEmbedder(nn.Module): 63 | def __init__(self, in_dim: int, hidden_dim: int): 64 | super().__init__() 65 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) 66 | self.silu = nn.SiLU() 67 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | return self.out_layer(self.silu(self.in_layer(x))) 71 | 72 | 73 | class RMSNorm(torch.nn.Module): 74 | def __init__(self, dim: int): 75 | super().__init__() 76 | self.scale = nn.Parameter(torch.ones(dim)) 77 | 78 | def forward(self, x: Tensor): 79 | x_dtype = x.dtype 80 | x = x.float() 81 | rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) 82 | return (x * rrms).to(dtype=x_dtype) * self.scale 83 | 84 | class QKNorm(torch.nn.Module): 85 | def __init__(self, dim: int): 86 | super().__init__() 87 | self.query_norm = RMSNorm(dim) 88 | self.key_norm = RMSNorm(dim) 89 | 90 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: 91 | q = self.query_norm(q) 92 | k = self.key_norm(k) 93 | return q.to(v), k.to(v) 94 | 95 | class SelfAttention(nn.Module): 96 | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): 97 | super().__init__() 98 | self.num_heads = num_heads 99 | head_dim = dim // num_heads 100 | 101 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 102 | self.norm = QKNorm(head_dim) 103 | self.proj = nn.Linear(dim, dim) 104 | 105 | def forward(): 106 | pass 107 | 108 | 109 | @dataclass 110 | class ModulationOut: 111 | shift: Tensor 112 | scale: Tensor 113 | gate: Tensor 114 | 115 | 116 | class Modulation(nn.Module): 117 | def __init__(self, dim: int, double: bool): 118 | super().__init__() 119 | self.is_double = double 120 | self.multiplier = 6 if double else 3 121 | self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) 122 | 123 | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: 124 | out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) 125 | 126 | return ( 127 | ModulationOut(*out[:3]), 128 | ModulationOut(*out[3:]) if self.is_double else None, 129 | ) 130 | 131 | 132 | 133 | 134 | class DoubleStreamBlock(nn.Module): 135 | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): 136 | super().__init__() 137 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 138 | self.num_heads = num_heads 139 | self.hidden_size = hidden_size 140 | self.head_dim = hidden_size // num_heads 141 | 142 | self.img_mod = Modulation(hidden_size, double=True) 143 | self.img_norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 144 | self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 145 | 146 | self.img_norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 147 | self.img_mlp = nn.Sequential( 148 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 149 | nn.GELU(approximate="tanh"), 150 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 151 | ) 
152 | 153 | self.txt_mod = Modulation(hidden_size, double=True) 154 | self.txt_norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 155 | self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) 156 | 157 | self.txt_norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 158 | self.txt_mlp = nn.Sequential( 159 | nn.Linear(hidden_size, mlp_hidden_dim, bias=True), 160 | nn.GELU(approximate="tanh"), 161 | nn.Linear(mlp_hidden_dim, hidden_size, bias=True), 162 | ) 163 | 164 | def forward( 165 | self, 166 | img: Tensor, 167 | txt: Tensor, 168 | vec: Tensor, 169 | pe: Tensor, 170 | ) -> tuple[Tensor, Tensor]: 171 | img_mod1, img_mod2 = self.img_mod(vec) 172 | txt_mod1, txt_mod2 = self.txt_mod(vec) 173 | 174 | # prepare image for attention 175 | img_modulated = self.img_norm1(img) 176 | img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift 177 | img_qkv = self.img_attn.qkv(img_modulated) 178 | img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads, D=self.head_dim) 179 | img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) 180 | 181 | # prepare txt for attention 182 | txt_modulated = self.txt_norm1(txt) 183 | txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift 184 | txt_qkv = self.txt_attn.qkv(txt_modulated) 185 | txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads, D=self.head_dim) 186 | txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) 187 | 188 | # run actual attention 189 | q = torch.cat((txt_q, img_q), dim=2) 190 | k = torch.cat((txt_k, img_k), dim=2) 191 | v = torch.cat((txt_v, img_v), dim=2) 192 | 193 | b, h, l, d = q.shape 194 | 195 | attn1 = attention(q, k, v, pe=pe) 196 | txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] 197 | 198 | # calculate the img bloks 199 | img = img + img_mod1.gate * self.img_attn.proj(img_attn) 200 | img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) 201 | 202 | # calculate the txt bloks 203 | txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) 204 | txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) 205 | return img, txt 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | class SingleStreamBlock(nn.Module): 216 | """ 217 | A DiT block with parallel linear layers as described in 218 | https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
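    A single pre-norm feeds the fused linear1, which produces both the qkv and the
    MLP hidden activations; the attention output and the activated MLP stream are
    then concatenated and projected back to hidden_size by linear2.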
219 | """ 220 | 221 | def __init__( 222 | self, 223 | hidden_size: int, 224 | num_heads: int, 225 | mlp_ratio: float = 4.0, 226 | qk_scale: float | None = None, 227 | ): 228 | super().__init__() 229 | self.hidden_dim = hidden_size 230 | self.num_heads = num_heads 231 | self.head_dim = hidden_size // num_heads 232 | self.scale = qk_scale or self.head_dim**-0.5 233 | 234 | self.mlp_hidden_dim = int(hidden_size * mlp_ratio) 235 | # qkv and mlp_in 236 | self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) 237 | # proj and mlp_out 238 | self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) 239 | 240 | self.norm = QKNorm(self.head_dim) 241 | 242 | self.hidden_size = hidden_size 243 | self.pre_norm = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 244 | 245 | self.mlp_act = nn.GELU(approximate="tanh") 246 | self.modulation = Modulation(hidden_size, double=False) 247 | 248 | def forward( 249 | self, 250 | x: Tensor, 251 | vec: Tensor, 252 | pe: Tensor, 253 | ) -> Tensor: 254 | mod, _ = self.modulation(vec) 255 | x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift 256 | qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) 257 | 258 | q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) 259 | q, k = self.norm(q, k, v) 260 | 261 | 262 | # compute attention 263 | attn_1 = attention(q, k, v, pe=pe) 264 | 265 | # compute activation in mlp stream, cat again and run second linear layer 266 | output = self.linear2(torch.cat((attn_1, self.mlp_act(mlp)), 2)) 267 | output = x + mod.gate * output 268 | return output 269 | 270 | 271 | class LastLayer(nn.Module): 272 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 273 | super().__init__() 274 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 275 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 276 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) 277 | 278 | def forward(self, x: Tensor, vec: Tensor) -> Tensor: 279 | shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) 280 | x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] 281 | x = self.linear(x) 282 | return x 283 | 284 | -------------------------------------------------------------------------------- /src/flux/pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from PIL import Image 4 | 5 | from src.flux.sampling import denoise, get_noise, get_schedule, prepare, unpack 6 | from src.flux.util import ( 7 | load_ae, 8 | load_clip, 9 | load_flow_model, 10 | load_t5, 11 | ) 12 | 13 | class Pipeline: 14 | def __init__( 15 | self, 16 | model_type, 17 | device, 18 | sigma: float = 1.0, 19 | y1: float = 0.5, 20 | y2: float = 1.15, 21 | dtype: torch.dtype = torch.bfloat16, 22 | ): 23 | self.device = torch.device(device) 24 | self.clip = load_clip(self.device) 25 | self.t5 = load_t5(self.device, max_length=512) 26 | self.ae = load_ae(model_type, device=self.device) 27 | self.model = load_flow_model(model_type, device=self.device) 28 | 29 | self.y1 = y1 30 | self.y2 = y2 31 | self.sigma = sigma 32 | self.dtype = dtype 33 | 34 | def __call__( 35 | self, 36 | prompt: str, 37 | width: int = 1024, 38 | height: int = 1024, 39 | guidance: float = 4.5, 40 | num_steps: int = 32, 41 | seed: int = 42, 42 | ): 43 | width = 16 * (width // 
16) 44 | height = 16 * (height // 16) 45 | 46 | x = get_noise(1, height, width, device=self.device, dtype=self.dtype, seed=seed) 47 | b, c, h, w = x.shape 48 | timesteps = get_schedule( 49 | num_steps, 50 | (w // 2) * (h // 2), 51 | base_shift = self.y1, 52 | max_shift = self.y2, 53 | sigma = self.sigma, 54 | ) 55 | 56 | with torch.no_grad(): 57 | inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=[prompt]) 58 | x = denoise( 59 | self.model, 60 | **inp, 61 | timesteps=timesteps, 62 | guidance=guidance, 63 | ) 64 | 65 | x = unpack(x, height, width) 66 | x = self.ae.decode(x) 67 | 68 | x1 = x.clamp(-1, 1) 69 | x1 = rearrange(x1[-1], "c h w -> h w c") 70 | img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) 71 | 72 | return img 73 | 74 | 75 | class Sampler(Pipeline): 76 | def __init__(self, clip, t5, ae, model, device, dtype, y1: float = 0.5, y2: float = 1.15, sigma: float = 1.0): 77 | self.clip = clip 78 | self.t5 = t5 79 | self.ae = ae 80 | self.model = model 81 | self.model.eval() 82 | self.device = device 83 | self.y1 = y1 84 | self.y2 = y2 85 | self.sigma = sigma 86 | self.dtype = dtype -------------------------------------------------------------------------------- /src/flux/sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from torch import Tensor 7 | from tqdm import tqdm 8 | from .model import Flux 9 | from .modules.conditioner import HFEmbedder 10 | 11 | def get_noise( 12 | num_samples: int, 13 | height: int, 14 | width: int, 15 | device: torch.device, 16 | dtype: torch.dtype, 17 | seed: int = None, 18 | ): 19 | generator = torch.Generator(device=device).manual_seed(seed) if seed else None 20 | return torch.randn( 21 | num_samples, 22 | 16, 23 | # allow for packing 24 | 2 * math.ceil(height / 16), 25 | 2 * math.ceil(width / 16), 26 | device=device, 27 | dtype=dtype, 28 | generator=generator 29 | ) 30 | 31 | 32 | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: list[str]) -> dict[str, Tensor]: 33 | bs, c, h, w = img.shape 34 | img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) 35 | 36 | img_ids = torch.zeros(h // 2, w // 2, 3) 37 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] 38 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] 39 | img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) 40 | 41 | txt = t5(prompt) 42 | txt_ids = torch.zeros(bs, t5.max_length, 3) 43 | vec = clip(prompt) 44 | 45 | outputs = { 46 | "img": img, 47 | "img_ids": img_ids.to(img.device), 48 | "txt": txt.to(img.device), 49 | "txt_ids": txt_ids.to(img.device), 50 | "vec": vec.to(img.device), 51 | } 52 | 53 | return outputs 54 | 55 | def time_shift(mu: float, sigma: float, t: Tensor): 56 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 57 | 58 | def get_lin_function( 59 | x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 60 | ) -> Callable[[float], float]: 61 | m = (y2 - y1) / (x2 - x1) 62 | b = y1 - m * x1 63 | return lambda x: m * x + b 64 | 65 | 66 | def get_schedule( 67 | num_steps: int, 68 | image_seq_len: int, 69 | base_shift: float = 0.5, 70 | max_shift: float = 1.15, 71 | sigma: float = 1.0, 72 | shift: bool = True, 73 | ) -> list[float]: 74 | # extra step for zerod 75 | timesteps = torch.linspace(1, 0, num_steps + 1) 76 | 77 | # shifting the schedule to favor high timesteps for higher signal images 78 | if shift: 79 | # 
eastimate mu based on linear estimation between two points 80 | mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) 81 | timesteps = time_shift(mu, sigma, timesteps) 82 | 83 | return timesteps.tolist() 84 | 85 | 86 | def denoise( 87 | model: Flux, 88 | # model input 89 | img: Tensor, 90 | img_ids: Tensor, 91 | txt: Tensor, 92 | txt_ids: Tensor, 93 | vec: Tensor, 94 | # sampling parameters 95 | timesteps: list[float], 96 | guidance: float = 4.0, 97 | ): 98 | b, *_ = img.shape 99 | guidance = torch.full((b,), guidance, device=img.device, dtype=img.dtype) 100 | for tcurr, tprev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1): 101 | tvec = torch.full((b,), tcurr, dtype=img.dtype, device=img.device) 102 | pred = model( 103 | img=img, 104 | img_ids=img_ids, 105 | txt=txt, 106 | txt_ids=txt_ids, 107 | y=vec, 108 | timesteps=tvec, 109 | guidance=guidance, 110 | ) 111 | img = img + (tprev - tcurr) * pred 112 | 113 | return img 114 | 115 | def unpack(x: Tensor, height: int, width: int, highres: bool = False) -> Tensor: 116 | return rearrange( 117 | x, 118 | "b (h w) (c ph pw) -> b c (h ph) (w pw)", 119 | h=math.ceil(height / (32 if highres else 16)), 120 | w=math.ceil(width / (32 if highres else 16)), 121 | ph=2, 122 | pw=2, 123 | ) 124 | -------------------------------------------------------------------------------- /src/flux/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from huggingface_hub import hf_hub_download 6 | from safetensors.torch import load_file 7 | 8 | from .model import Flux, FluxParams 9 | from .modules.autoencoder import AutoEncoder, AutoEncoderParams 10 | from .modules.conditioner import HFEmbedder 11 | 12 | @dataclass 13 | class ModelSpec: 14 | params: FluxParams 15 | ae_params: AutoEncoderParams 16 | ckpt_path: str | None = None 17 | ae_path: str | None = None 18 | repo_id: str | None = None 19 | repo_flow: str | None = None 20 | repo_ae: str | None = None 21 | repo_id_ae: str | None = None 22 | 23 | 24 | configs = { 25 | "flux-krea-dev": ModelSpec( 26 | params=FluxParams( 27 | in_channels=64, 28 | vec_in_dim=768, 29 | context_in_dim=4096, 30 | hidden_size=3072, 31 | mlp_ratio=4.0, 32 | num_heads=24, 33 | depth=19, 34 | depth_single_blocks=38, 35 | axes_dim=[16, 56, 56], 36 | theta=10_000, 37 | qkv_bias=True, 38 | guidance_embed=True, 39 | ), 40 | ae_params=AutoEncoderParams( 41 | resolution=256, 42 | in_channels=3, 43 | ch=128, 44 | out_ch=3, 45 | ch_mult=[1, 2, 4, 4], 46 | num_res_blocks=2, 47 | z_channels=16, 48 | scale_factor=0.3611, 49 | shift_factor=0.1159, 50 | ), 51 | ckpt_path=os.getenv("FLUX"), 52 | ae_path=os.getenv("AE"), 53 | repo_id="black-forest-labs/FLUX.1-Krea-dev", 54 | repo_id_ae="black-forest-labs/FLUX.1-Krea-dev", 55 | repo_ae="ae.safetensors", 56 | repo_flow="flux1-krea-dev.safetensors" 57 | ), 58 | } 59 | 60 | def load_from_repo_id(repo_id, checkpoint_name): 61 | ckpt_path = hf_hub_download(repo_id, checkpoint_name) 62 | sd = load_file(ckpt_path, device='cpu') 63 | return sd 64 | 65 | def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): 66 | ckpt_path = configs[name].ckpt_path 67 | if ( 68 | ckpt_path is None 69 | and configs[name].repo_id is not None 70 | and configs[name].repo_flow is not None 71 | and hf_download 72 | ): 73 | ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) 74 | 75 | if ckpt_path is not None: 76 | sd = 
load_file(ckpt_path, device=str(device)) 77 | config = configs[name].params 78 | config.in_channels = sd["img_in.weight"].shape[1] 79 | 80 | print("Initialising model") 81 | with torch.device("meta"): 82 | model = Flux(config) 83 | model = model.to(dtype=torch.bfloat16) 84 | 85 | if ckpt_path is not None: 86 | print(f"Loading flow checkpoint to model from {ckpt_path}") 87 | model.load_state_dict(sd, strict=False, assign=True) 88 | 89 | return model 90 | 91 | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: 92 | embedder = HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16, device=device) 93 | return embedder 94 | 95 | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: 96 | embedder = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, device=device) 97 | return embedder 98 | 99 | def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: 100 | ckpt_path = configs[name].ae_path 101 | if ( 102 | ckpt_path is None 103 | and configs[name].repo_id is not None 104 | and configs[name].repo_ae is not None 105 | and hf_download 106 | ): 107 | ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) 108 | 109 | with torch.device("meta"): 110 | ae = AutoEncoder(configs[name].ae_params) 111 | ae = ae.to_empty(device=device) 112 | ae = ae.to(dtype=torch.bfloat16) 113 | 114 | if ckpt_path is not None: 115 | print(f"Loading AE checkpoint from path {ckpt_path}") 116 | sd = load_file(ckpt_path, device=str(device)) 117 | ae.load_state_dict(sd, strict=False, assign=True) 118 | 119 | return ae --------------------------------------------------------------------------------
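A note on local checkpoints: `load_flow_model` and `load_ae` only fall back to the Hugging Face download when no local path is configured; the `configs` entry in `src/flux/util.py` reads the `FLUX` and `AE` environment variables for this purpose. A minimal sketch follows — the paths are placeholders, not files shipped with the repository.

```python
import os

# Placeholder paths -- point these at your own local checkpoint files.
# They must be set before src.flux.util is first imported, because the
# configs dict calls os.getenv("FLUX") / os.getenv("AE") at import time.
os.environ["FLUX"] = "/path/to/flux1-krea-dev.safetensors"
os.environ["AE"] = "/path/to/ae.safetensors"

from src.flux.util import load_ae, load_flow_model

model = load_flow_model("flux-krea-dev", device="cpu")  # uses $FLUX instead of downloading
ae = load_ae("flux-krea-dev")                           # uses $AE instead of downloading
```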