├── LICENSE
├── README.md
├── img
│   ├── CustomAnyID_img.png
│   ├── CustomID_img.png
│   ├── controlnet_img.png
│   ├── logo.png
│   └── man1.jpg
├── infer_customID.ipynb
├── pretrained_ckpt
└── src
    ├── customID
    │   ├── __init__.py
    │   ├── attention_processor.py
    │   ├── attention_processor_ori.py
    │   ├── model.py
    │   ├── pipeline_flux.py
    │   ├── resampler.py
    │   ├── transformer_flux.py
    │   ├── transformer_flux_ori.py
    │   └── utils.py
    └── utils
        └── insightface_package.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 DamoCV
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## FLUX-customID: Realistically Customize Your Personal ID to Perfection
4 |
5 | This repository is the official implementation of FLUX-customID. Given a single face image, it generates identity-preserving images at a quality comparable to real photographs. Our base model is FLUX.1-dev, which ensures high-quality generation.
6 |
7 | ## News
8 | - 🌟**2024-11-13**: Released the code and weights for FLUX-customID.
9 |
10 | ## Gallery
11 | Here are some samples generated by our method.
12 |
13 |
14 |
15 | ## Quick Start
16 |
17 | ### 1. Setup Repository and Environment
18 |
19 | ```
20 | conda create -n customID python=3.10 -y
21 | conda activate customID
22 | conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=11.8 -c pytorch -c nvidia -y
23 | pip install -i https://mirrors.cloud.tencent.com/pypi/simple diffusers==0.31.0 transformers onnxruntime-gpu insightface sentencepiece matplotlib imageio tqdm numpy einops accelerate peft
24 | ```
25 |
26 | ### 2. Prepare Pretrained Checkpoints
27 |
28 | ```
29 | git clone https://github.com/damo-cv/FLUX-customID.git
30 | cd FLUX-customID
31 |
32 | mkdir pretrained_ckpt
33 | cd pretrained_ckpt
34 |
35 | # Download CLIP
36 | export HF_ENDPOINT=https://hf-mirror.com
37 | pip install -U "huggingface_hub[cli]"
38 |
39 | huggingface-cli download \
40 | --resume-download "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" \
41 | --cache-dir your_dir/
42 |
43 | ln -s your_dir/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/snapshots/de081ac0a0ca8dc9d1533eed1ae884bb8ae1404b pretrained_ckpt/openclip-vit-h-14
44 |
45 | # Download FLUX.1-dev
46 | huggingface-cli download \
47 | --resume-download "black-forest-labs/FLUX.1-dev" \
48 | --cache-dir your_dir/
49 |
50 | ln -s your_dir/models--black-forest-labs--FLUX.1-dev/snapshots/303875135fff3f05b6aa893d544f28833a237d58 pretrained_ckpt/flux.1-dev
51 |
52 | # Download FLUX-customID
53 | # Download our trained checkpoint from https://huggingface.co/Damo-vision/FLUX-customID and place FLUX-customID.pt in the folder pretrained_ckpt/
54 | ```
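
Optionally, you can sanity-check the layout of `pretrained_ckpt/` with a small Python snippet. This is a minimal sketch: the three paths are the ones expected by the notebook and the `CustomIDModel` defaults, and the first two are the symlinks created above.

```
import os

# Paths the inference code reads from (see infer_customID.ipynb and src/customID/model.py).
for path in [
    "pretrained_ckpt/openclip-vit-h-14",   # CLIP image encoder (symlink)
    "pretrained_ckpt/flux.1-dev",          # FLUX.1-dev base model (symlink)
    "pretrained_ckpt/FLUX-customID.pt",    # trained FLUX-customID weights
]:
    print(f"{path}: {'found' if os.path.exists(path) else 'MISSING'}")
```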
55 |
56 | ### 3. Quick Inference
57 | ```
58 | run infer_customID.ipynb
59 | ```
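
If you prefer a plain script to the notebook, the cells of `infer_customID.ipynb` reduce to roughly the following sketch (model paths, prompt, and sampling settings are the notebook defaults; adjust them to your setup):

```
import torch
from src.customID.pipeline_flux import FluxPipeline
from src.customID.transformer_flux import FluxTransformer2DModel
from src.customID.model import CustomIDModel

device, dtype = "cuda:0", torch.bfloat16
model_path = "pretrained_ckpt/flux.1-dev"  # or "black-forest-labs/FLUX.1-dev"

# Build the FLUX pipeline with the customized transformer, then wrap it with CustomIDModel.
transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype).to(device)
pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype).to(device)
customID_model = CustomIDModel(pipe, "pretrained_ckpt/FLUX-customID.pt", device, dtype, 64)  # 64 ID tokens

# Generate three 1024x1024 samples conditioned on the reference face.
images = customID_model.generate(
    pil_image="img/man1.jpg",
    prompt="A man wearing a classic leather jacket leans against a vintage motorcycle, surrounded by autumn leaves swirling in the breeze.",
    num_samples=3,
    height=1024,
    width=1024,
    seed=2024,
    num_inference_steps=28,
    guidance_scale=3.5,
)
for i, image in enumerate(images):
    image.save(f"sample_{i}.png")
```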
60 |
61 | ## Preview for CustomAnyID
62 | We would like to announce that we are currently working on a related project, **CustomAnyID**. Below are some preliminary experimental results:
63 |
64 |
65 |
66 | ## Preview for ControlNet
67 | We would also like to preview our ControlNet model. Below are some preliminary experimental results:
68 |
69 |
70 |
71 | ## Contact Us
72 | Dongyang Li: [yingtian.ldy@alibaba-inc.com](mailto:yingtian.ldy@alibaba-inc.com)
73 |
74 | ## Acknowledgements
75 | Parts of the code are based on [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [PhotoMaker](https://github.com/TencentARC/PhotoMaker).
76 |
--------------------------------------------------------------------------------
/img/CustomAnyID_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/CustomAnyID_img.png
--------------------------------------------------------------------------------
/img/CustomID_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/CustomID_img.png
--------------------------------------------------------------------------------
/img/controlnet_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/controlnet_img.png
--------------------------------------------------------------------------------
/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/logo.png
--------------------------------------------------------------------------------
/img/man1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/man1.jpg
--------------------------------------------------------------------------------
/infer_customID.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "from PIL import Image\n",
11 | "import torch\n",
12 | "from src.customID.pipeline_flux import FluxPipeline\n",
13 | "from src.customID.transformer_flux import FluxTransformer2DModel\n",
14 | "from src.customID.model import CustomIDModel\n",
15 | "\n",
16 | "def image_grid(imgs, rows, cols):\n",
17 | " assert len(imgs) == rows*cols\n",
18 | " w, h = imgs[0].size\n",
19 | " grid = Image.new('RGB', size=(cols*w, rows*h))\n",
20 | " grid_w, grid_h = grid.size\n",
21 | " \n",
22 | " for i, img in enumerate(imgs):\n",
23 | " grid.paste(img, box=(i%cols*w, i//cols*h))\n",
24 | " return grid\n",
25 | "\n",
26 | "_DEVICE = \"cuda:0\"\n",
27 | "_DTYPE=torch.bfloat16\n",
28 | "model_path = \"pretrained_ckpt/flux.1-dev\" #you can also use `black-forest-labs/FLUX.1-dev`\n",
29 | "transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder=\"transformer\", torch_dtype=_DTYPE).to(_DEVICE)\n",
30 | "pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=_DTYPE).to(_DEVICE)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "num_token=64\n",
40 | "trained_ckpt = \"pretrained_ckpt/FLUX-customID.pt\"\n",
41 | "customID_model = CustomIDModel(pipe, trained_ckpt, _DEVICE, _DTYPE, num_token)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "num_samples=3\n",
51 | "gs= 3.5\n",
52 | "_seed=2024\n",
53 | "h=1024\n",
54 | "w=1024\n",
55 | "img_path = \"img/man1.jpg\"\n",
56 | "p=\"A man wearing a classic leather jacket leans against a vintage motorcycle, surrounded by autumn leaves swirling in the breeze.\"\n",
57 | "images = customID_model.generate(pil_image=img_path,\n",
58 | " prompt=p,\n",
59 | " num_samples=num_samples,\n",
60 | " height=h,\n",
61 | " width=w,\n",
62 | " seed=_seed,\n",
63 | " num_inference_steps=28,\n",
64 | " guidance_scale=gs)\n",
65 | "grid = image_grid(images, 1, num_samples)\n",
66 | "grid"
67 | ]
68 | }
69 | ],
70 | "metadata": {
71 | "kernelspec": {
72 | "display_name": "pt20",
73 | "language": "python",
74 | "name": "python3"
75 | },
76 | "language_info": {
77 | "codemirror_mode": {
78 | "name": "ipython",
79 | "version": 3
80 | },
81 | "file_extension": ".py",
82 | "mimetype": "text/x-python",
83 | "name": "python",
84 | "nbconvert_exporter": "python",
85 | "pygments_lexer": "ipython3",
86 | "version": "3.10.15"
87 | }
88 | },
89 | "nbformat": 4,
90 | "nbformat_minor": 2
91 | }
92 |
--------------------------------------------------------------------------------
/pretrained_ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/pretrained_ckpt
--------------------------------------------------------------------------------
/src/customID/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/src/customID/__init__.py
--------------------------------------------------------------------------------
/src/customID/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: LiDongyang(yingtian.ldy@alibaba-inc.com | amo5lee@aliyun.com)
3 | Date: 2024-10
4 | Description: Customized Image Generation Model Based on Facial ID.
5 | """
6 | import os
7 | from typing import List
8 | # import math
9 | import torch
10 | from PIL import Image
11 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
12 | import pdb
13 | from .utils import is_torch2_available, get_generator
14 | from .attention_processor import CATRefFluxAttnProcessor2_0
15 | from .resampler import PerceiverAttention, FeedForward
16 | from src.utils.insightface_package import FaceAnalysis2, analyze_faces
17 | import cv2
18 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
19 |
20 | class FacePerceiverResampler(torch.nn.Module):
21 | def __init__(
22 | self,
23 | *,
24 | dim=768,
25 | depth=4,
26 | dim_head=64,
27 | heads=16,
28 | embedding_dim=1280,
29 | output_dim=768,
30 | ff_mult=4,
31 | ):
32 | super().__init__()
33 |
34 | self.proj_in = torch.nn.Linear(embedding_dim, dim)
35 | self.proj_out = torch.nn.Linear(dim, output_dim)
36 | self.norm_out = torch.nn.LayerNorm(output_dim)
37 | self.layers = torch.nn.ModuleList([])
38 | for _ in range(depth):
39 | self.layers.append(
40 | torch.nn.ModuleList(
41 | [
42 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
43 | FeedForward(dim=dim, mult=ff_mult),
44 | ]
45 | )
46 | )
47 |
48 | def forward(self, latents, x):
49 | x = self.proj_in(x)
50 | for attn, ff in self.layers:
51 | latents = attn(x, latents) + latents
52 | latents = ff(latents) + latents
53 | latents = self.proj_out(latents)
54 | return self.norm_out(latents)
55 |
56 | class ProjPlusModel(torch.nn.Module):
57 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4, output_dim=3072):
58 | super().__init__()
59 |
60 | self.cross_attention_dim = cross_attention_dim
61 | self.num_tokens = num_tokens
62 |
63 | self.proj = torch.nn.Sequential(
64 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
65 | torch.nn.GELU(),
66 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
67 | )
68 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
69 |
70 | self.perceiver_resampler = FacePerceiverResampler(
71 | dim=cross_attention_dim,
72 | depth=4,
73 | dim_head=64,
74 | heads=cross_attention_dim // 64,
75 | embedding_dim=clip_embeddings_dim,
76 | output_dim=output_dim,
77 | ff_mult=4,
78 | )
79 | self.prj_out_clip = torch.nn.Sequential(
80 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim*2),
81 | torch.nn.GELU(),
82 | torch.nn.Linear(clip_embeddings_dim*2, output_dim),
83 | )
84 |
85 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
86 |
87 | x = self.proj(id_embeds)
88 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
89 | x = self.norm(x)
90 | out = self.perceiver_resampler(x, clip_embeds)
91 | if shortcut:
92 | out = x + scale * out
93 | return torch.cat([out, self.prj_out_clip(clip_embeds)], dim=1)
94 |
95 | class CustomIDModel:
96 | def __init__(self, sd_pipe, trained_ckpt, device, dtype, num_tokens=4, image_encoder_path="pretrained_ckpt/openclip-vit-h-14"):
97 | self.device = device
98 | self.dtype = dtype
99 | self.trained_ckpt = trained_ckpt
100 | self.num_tokens = num_tokens
101 | self.pipe = sd_pipe
102 | self.image_encoder_path = image_encoder_path
103 |
104 | # load image encoder
105 | self.clip_image_processor = CLIPImageProcessor()
106 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
107 | self.device)
108 |
109 | self.set_id_adapter()
110 |
111 | # image proj model
112 | self.image_proj_model = self.init_proj()
113 | self.image_proj_model.to(self.device)
114 | if self.trained_ckpt is not None:
115 | self.load_id_adapter()
116 |
117 | self.face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
118 | self.face_detector.prepare(ctx_id=0, det_size=(640, 640))
119 |
120 | def init_proj(self):
121 | image_proj_model = ProjPlusModel(
122 | cross_attention_dim=self.image_encoder.config.hidden_size,
123 | id_embeddings_dim=512,
124 | clip_embeddings_dim=self.image_encoder.config.hidden_size,
125 | num_tokens=self.num_tokens,
126 | output_dim=self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim,
127 | ).to(self.device)
128 | return image_proj_model
129 |
130 | def set_id_adapter(self):
131 | # init adapter modules
132 | attn_procs = {}
133 | transformer_sd = self.pipe.transformer.state_dict()
134 | for name in self.pipe.transformer.attn_processors.keys():
135 | if name.startswith("transformer_blocks"):
136 | attn_procs[name] = CATRefFluxAttnProcessor2_0(self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim,
137 | self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim,
138 | self.pipe.transformer.config.attention_head_dim,
139 | self.num_tokens+256,  # sequence length of ID branch: num_tokens ID tokens + 256 CLIP patch tokens
140 | ).to(self.device, dtype=self.dtype)
141 | elif name.startswith("single_transformer_blocks"):
142 | attn_procs[name] = CATRefFluxAttnProcessor2_0(self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim,
143 | self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim,
144 | self.pipe.transformer.config.attention_head_dim,
145 | self.num_tokens+256,
146 | ).to(self.device, dtype=self.dtype)
147 | self.pipe.transformer.set_attn_processor(attn_procs)
148 |
149 | def load_id_adapter(self):
150 | state_dict = torch.load(self.trained_ckpt, map_location=torch.device('cpu'))
151 | self.image_proj_model.load_state_dict(state_dict["img_prj_state"], strict=True)
152 | m,u = self.pipe.transformer.load_state_dict(state_dict["attn_processor_state"], strict=False)
153 | assert len(u)==0
154 |
155 | @torch.inference_mode()
156 | def get_image_embeds(self, pil_image):
157 | image_ = cv2.imread(pil_image)
158 | faces = analyze_faces(self.face_detector, image_)
159 | faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0).to(self.device)
160 | faceid_embeds = faceid_embeds.unsqueeze(0)
161 |
162 | #clip
163 | face_image = Image.open(pil_image)
164 | if isinstance(face_image, Image.Image):
165 | pil_image = [face_image]
166 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
167 | clip_image = clip_image.to(self.device, dtype=self.dtype)
168 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
169 |
170 | ip_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds[:,1:,:])
171 | assert ip_tokens.shape[1] == self.num_tokens+256
172 |
173 | return ip_tokens.to(self.device, dtype=self.dtype)
174 |
175 | def generate(
176 | self,
177 | pil_image=None,
178 | prompt=None,
179 | num_samples=4,
180 | height=1024,
181 | width=1024,
182 | seed=None,
183 | num_inference_steps=30,
184 | guidance_scale=3.5,
185 | ):
186 |
187 | ip_tokens = self.get_image_embeds(pil_image=pil_image)
188 |
189 | bs_embed, seq_len, _ = ip_tokens.shape
190 | ip_tokens = ip_tokens.repeat(1, num_samples, 1)
191 | ip_tokens = ip_tokens.view(bs_embed * num_samples, seq_len, -1)
192 | ip_tokens = ip_tokens.to(self.device).to(self.dtype)
193 |
194 | ip_token_ids = self.pipe._prepare_latent_image_ids(
195 | 1,
196 | 1*2,
197 | (self.num_tokens+256)*2,
198 | self.device,
199 | self.dtype,
200 | )
201 | images = self.pipe(
202 | prompt,
203 | ip_token=ip_tokens,
204 | ip_token_ids=ip_token_ids,
205 | num_images_per_prompt=num_samples,
206 | height=height,
207 | width=width,
208 | output_type="pil",
209 | num_inference_steps=num_inference_steps,
210 | generator=torch.Generator(self.device).manual_seed(seed),
211 | guidance_scale=guidance_scale,
212 | ).images
213 |
214 | return images
--------------------------------------------------------------------------------
/src/customID/pipeline_flux.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import Any, Callable, Dict, List, Optional, Union
17 |
18 | import numpy as np
19 | import torch
20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21 |
22 | from diffusers.image_processor import VaeImageProcessor
23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
24 | from diffusers.models.autoencoders import AutoencoderKL
25 | from diffusers.models.transformers import FluxTransformer2DModel
26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27 | from diffusers.utils import (
28 | USE_PEFT_BACKEND,
29 | is_torch_xla_available,
30 | logging,
31 | replace_example_docstring,
32 | scale_lora_layers,
33 | unscale_lora_layers,
34 | )
35 | from diffusers.utils.torch_utils import randn_tensor
36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38 |
39 |
40 | if is_torch_xla_available():
41 | import torch_xla.core.xla_model as xm
42 |
43 | XLA_AVAILABLE = True
44 | else:
45 | XLA_AVAILABLE = False
46 |
47 |
48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49 |
50 | EXAMPLE_DOC_STRING = """
51 | Examples:
52 | ```py
53 | >>> import torch
54 | >>> from diffusers import FluxPipeline
55 |
56 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
57 | >>> pipe.to("cuda")
58 | >>> prompt = "A cat holding a sign that says hello world"
59 | >>> # Depending on the variant being used, the pipeline call will slightly vary.
60 | >>> # Refer to the pipeline documentation for more details.
61 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
62 | >>> image.save("flux.png")
63 | ```
64 | """
65 |
66 |
67 | def calculate_shift(
68 | image_seq_len,
69 | base_seq_len: int = 256,
70 | max_seq_len: int = 4096,
71 | base_shift: float = 0.5,
72 | max_shift: float = 1.16,
73 | ):
74 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
75 | b = base_shift - m * base_seq_len
76 | mu = image_seq_len * m + b
77 | return mu
78 |
79 |
80 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81 | def retrieve_timesteps(
82 | scheduler,
83 | num_inference_steps: Optional[int] = None,
84 | device: Optional[Union[str, torch.device]] = None,
85 | timesteps: Optional[List[int]] = None,
86 | sigmas: Optional[List[float]] = None,
87 | **kwargs,
88 | ):
89 | """
90 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92 |
93 | Args:
94 | scheduler (`SchedulerMixin`):
95 | The scheduler to get timesteps from.
96 | num_inference_steps (`int`):
97 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
98 | must be `None`.
99 | device (`str` or `torch.device`, *optional*):
100 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
101 | timesteps (`List[int]`, *optional*):
102 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
103 | `num_inference_steps` and `sigmas` must be `None`.
104 | sigmas (`List[float]`, *optional*):
105 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
106 | `num_inference_steps` and `timesteps` must be `None`.
107 |
108 | Returns:
109 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110 | second element is the number of inference steps.
111 | """
112 | if timesteps is not None and sigmas is not None:
113 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
114 | if timesteps is not None:
115 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
116 | if not accepts_timesteps:
117 | raise ValueError(
118 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
119 | f" timestep schedules. Please check whether you are using the correct scheduler."
120 | )
121 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
122 | timesteps = scheduler.timesteps
123 | num_inference_steps = len(timesteps)
124 | elif sigmas is not None:
125 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
126 | if not accept_sigmas:
127 | raise ValueError(
128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129 | f" sigmas schedules. Please check whether you are using the correct scheduler."
130 | )
131 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
132 | timesteps = scheduler.timesteps
133 | num_inference_steps = len(timesteps)
134 | else:
135 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
136 | timesteps = scheduler.timesteps
137 | return timesteps, num_inference_steps
138 |
139 |
140 | class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
141 | r"""
142 | The Flux pipeline for text-to-image generation.
143 |
144 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
145 |
146 | Args:
147 | transformer ([`FluxTransformer2DModel`]):
148 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
149 | scheduler ([`FlowMatchEulerDiscreteScheduler`]):
150 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
151 | vae ([`AutoencoderKL`]):
152 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
153 | text_encoder ([`CLIPTextModel`]):
154 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
155 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
156 | text_encoder_2 ([`T5EncoderModel`]):
157 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
158 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
159 | tokenizer (`CLIPTokenizer`):
160 | Tokenizer of class
161 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
162 | tokenizer_2 (`T5TokenizerFast`):
163 | Second Tokenizer of class
164 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
165 | """
166 |
167 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
168 | _optional_components = []
169 | _callback_tensor_inputs = ["latents", "prompt_embeds"]
170 |
171 | def __init__(
172 | self,
173 | scheduler: FlowMatchEulerDiscreteScheduler,
174 | vae: AutoencoderKL,
175 | text_encoder: CLIPTextModel,
176 | tokenizer: CLIPTokenizer,
177 | text_encoder_2: T5EncoderModel,
178 | tokenizer_2: T5TokenizerFast,
179 | transformer: FluxTransformer2DModel,
180 | ):
181 | super().__init__()
182 |
183 | self.register_modules(
184 | vae=vae,
185 | text_encoder=text_encoder,
186 | text_encoder_2=text_encoder_2,
187 | tokenizer=tokenizer,
188 | tokenizer_2=tokenizer_2,
189 | transformer=transformer,
190 | scheduler=scheduler,
191 | )
192 | self.vae_scale_factor = (
193 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
194 | )
195 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
196 | self.tokenizer_max_length = (
197 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
198 | )
199 | self.default_sample_size = 64
200 |
201 | def _get_t5_prompt_embeds(
202 | self,
203 | prompt: Union[str, List[str]] = None,
204 | num_images_per_prompt: int = 1,
205 | max_sequence_length: int = 512,
206 | device: Optional[torch.device] = None,
207 | dtype: Optional[torch.dtype] = None,
208 | ):
209 | device = device or self._execution_device
210 | dtype = dtype or self.text_encoder.dtype
211 |
212 | prompt = [prompt] if isinstance(prompt, str) else prompt
213 | batch_size = len(prompt)
214 |
215 | text_inputs = self.tokenizer_2(
216 | prompt,
217 | padding="max_length",
218 | max_length=max_sequence_length,
219 | truncation=True,
220 | return_length=False,
221 | return_overflowing_tokens=False,
222 | return_tensors="pt",
223 | )
224 | text_input_ids = text_inputs.input_ids
225 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
226 |
227 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
228 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
229 | logger.warning(
230 | "The following part of your input was truncated because `max_sequence_length` is set to "
231 | f" {max_sequence_length} tokens: {removed_text}"
232 | )
233 |
234 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
235 |
236 | dtype = self.text_encoder_2.dtype
237 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
238 |
239 | _, seq_len, _ = prompt_embeds.shape
240 |
241 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
242 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
243 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
244 |
245 | return prompt_embeds
246 |
247 | def _get_clip_prompt_embeds(
248 | self,
249 | prompt: Union[str, List[str]],
250 | num_images_per_prompt: int = 1,
251 | device: Optional[torch.device] = None,
252 | ):
253 | device = device or self._execution_device
254 |
255 | prompt = [prompt] if isinstance(prompt, str) else prompt
256 | batch_size = len(prompt)
257 |
258 | text_inputs = self.tokenizer(
259 | prompt,
260 | padding="max_length",
261 | max_length=self.tokenizer_max_length,
262 | truncation=True,
263 | return_overflowing_tokens=False,
264 | return_length=False,
265 | return_tensors="pt",
266 | )
267 |
268 | text_input_ids = text_inputs.input_ids
269 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
270 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
271 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
272 | logger.warning(
273 | "The following part of your input was truncated because CLIP can only handle sequences up to"
274 | f" {self.tokenizer_max_length} tokens: {removed_text}"
275 | )
276 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
277 |
278 | # Use pooled output of CLIPTextModel
279 | prompt_embeds = prompt_embeds.pooler_output
280 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
281 |
282 | # duplicate text embeddings for each generation per prompt, using mps friendly method
283 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
284 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
285 |
286 | return prompt_embeds
287 |
288 | def encode_prompt(
289 | self,
290 | prompt: Union[str, List[str]],
291 | prompt_2: Union[str, List[str]],
292 | device: Optional[torch.device] = None,
293 | num_images_per_prompt: int = 1,
294 | prompt_embeds: Optional[torch.FloatTensor] = None,
295 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
296 | max_sequence_length: int = 512,
297 | lora_scale: Optional[float] = None,
298 | ):
299 | r"""
300 |
301 | Args:
302 | prompt (`str` or `List[str]`, *optional*):
303 | prompt to be encoded
304 | prompt_2 (`str` or `List[str]`, *optional*):
305 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
306 | used in all text-encoders
307 | device: (`torch.device`):
308 | torch device
309 | num_images_per_prompt (`int`):
310 | number of images that should be generated per prompt
311 | prompt_embeds (`torch.FloatTensor`, *optional*):
312 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
313 | provided, text embeddings will be generated from `prompt` input argument.
314 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
315 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
316 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
317 | lora_scale (`float`, *optional*):
318 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
319 | """
320 | device = device or self._execution_device
321 |
322 | # set lora scale so that monkey patched LoRA
323 | # function of text encoder can correctly access it
324 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
325 | self._lora_scale = lora_scale
326 |
327 | # dynamically adjust the LoRA scale
328 | if self.text_encoder is not None and USE_PEFT_BACKEND:
329 | scale_lora_layers(self.text_encoder, lora_scale)
330 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
331 | scale_lora_layers(self.text_encoder_2, lora_scale)
332 |
333 | prompt = [prompt] if isinstance(prompt, str) else prompt
334 |
335 | if prompt_embeds is None:
336 | prompt_2 = prompt_2 or prompt
337 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
338 |
339 | # We only use the pooled prompt output from the CLIPTextModel
340 | pooled_prompt_embeds = self._get_clip_prompt_embeds(
341 | prompt=prompt,
342 | device=device,
343 | num_images_per_prompt=num_images_per_prompt,
344 | )
345 | prompt_embeds = self._get_t5_prompt_embeds(
346 | prompt=prompt_2,
347 | num_images_per_prompt=num_images_per_prompt,
348 | max_sequence_length=max_sequence_length,
349 | device=device,
350 | )
351 |
352 | if self.text_encoder is not None:
353 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
354 | # Retrieve the original scale by scaling back the LoRA layers
355 | unscale_lora_layers(self.text_encoder, lora_scale)
356 |
357 | if self.text_encoder_2 is not None:
358 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
359 | # Retrieve the original scale by scaling back the LoRA layers
360 | unscale_lora_layers(self.text_encoder_2, lora_scale)
361 |
362 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
363 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
364 |
365 | return prompt_embeds, pooled_prompt_embeds, text_ids
366 |
367 | def check_inputs(
368 | self,
369 | prompt,
370 | prompt_2,
371 | height,
372 | width,
373 | prompt_embeds=None,
374 | pooled_prompt_embeds=None,
375 | callback_on_step_end_tensor_inputs=None,
376 | max_sequence_length=None,
377 | ):
378 | if height % 8 != 0 or width % 8 != 0:
379 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
380 |
381 | if callback_on_step_end_tensor_inputs is not None and not all(
382 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
383 | ):
384 | raise ValueError(
385 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
386 | )
387 |
388 | if prompt is not None and prompt_embeds is not None:
389 | raise ValueError(
390 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
391 | " only forward one of the two."
392 | )
393 | elif prompt_2 is not None and prompt_embeds is not None:
394 | raise ValueError(
395 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
396 | " only forward one of the two."
397 | )
398 | elif prompt is None and prompt_embeds is None:
399 | raise ValueError(
400 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
401 | )
402 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
403 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
404 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
405 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
406 |
407 | if prompt_embeds is not None and pooled_prompt_embeds is None:
408 | raise ValueError(
409 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
410 | )
411 |
412 | if max_sequence_length is not None and max_sequence_length > 512:
413 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
414 |
415 | @staticmethod
416 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
417 | latent_image_ids = torch.zeros(height // 2, width // 2, 3)
418 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
419 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
420 |
421 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
422 |
423 | latent_image_ids = latent_image_ids.reshape(
424 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
425 | )
426 |
427 | return latent_image_ids.to(device=device, dtype=dtype)
428 |
429 | @staticmethod
430 | def _pack_latents(latents, batch_size, num_channels_latents, height, width):
431 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
432 | latents = latents.permute(0, 2, 4, 1, 3, 5)
433 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
434 |
435 | return latents
436 |
437 | @staticmethod
438 | def _unpack_latents(latents, height, width, vae_scale_factor):
439 | batch_size, num_patches, channels = latents.shape
440 |
441 | height = height // vae_scale_factor
442 | width = width // vae_scale_factor
443 |
444 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
445 | latents = latents.permute(0, 3, 1, 4, 2, 5)
446 |
447 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
448 |
449 | return latents
450 |
451 | def enable_vae_slicing(self):
452 | r"""
453 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
454 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
455 | """
456 | self.vae.enable_slicing()
457 |
458 | def disable_vae_slicing(self):
459 | r"""
460 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
461 | computing decoding in one step.
462 | """
463 | self.vae.disable_slicing()
464 |
465 | def enable_vae_tiling(self):
466 | r"""
467 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
468 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
469 | processing larger images.
470 | """
471 | self.vae.enable_tiling()
472 |
473 | def disable_vae_tiling(self):
474 | r"""
475 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
476 | computing decoding in one step.
477 | """
478 | self.vae.disable_tiling()
479 |
480 | def prepare_latents(
481 | self,
482 | batch_size,
483 | num_channels_latents,
484 | height,
485 | width,
486 | dtype,
487 | device,
488 | generator,
489 | latents=None,
490 | ):
491 | height = 2 * (int(height) // self.vae_scale_factor)
492 | width = 2 * (int(width) // self.vae_scale_factor)
493 |
494 | shape = (batch_size, num_channels_latents, height, width)
495 |
496 | if latents is not None:
497 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
498 | return latents.to(device=device, dtype=dtype), latent_image_ids
499 |
500 | if isinstance(generator, list) and len(generator) != batch_size:
501 | raise ValueError(
502 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
503 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
504 | )
505 |
506 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
507 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
508 |
509 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
510 |
511 | return latents, latent_image_ids
512 |
513 | @property
514 | def guidance_scale(self):
515 | return self._guidance_scale
516 |
517 | @property
518 | def joint_attention_kwargs(self):
519 | return self._joint_attention_kwargs
520 |
521 | @property
522 | def num_timesteps(self):
523 | return self._num_timesteps
524 |
525 | @property
526 | def interrupt(self):
527 | return self._interrupt
528 |
529 | @torch.no_grad()
530 | @replace_example_docstring(EXAMPLE_DOC_STRING)
531 | def __call__(
532 | self,
533 | prompt: Union[str, List[str]] = None,
534 | prompt_2: Optional[Union[str, List[str]]] = None,
535 | height: Optional[int] = None,
536 | width: Optional[int] = None,
537 | num_inference_steps: int = 28,
538 | timesteps: List[int] = None,
539 | guidance_scale: float = 7.0,
540 | num_images_per_prompt: Optional[int] = 1,
541 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
542 | latents: Optional[torch.FloatTensor] = None,
543 | prompt_embeds: Optional[torch.FloatTensor] = None,
544 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
545 | output_type: Optional[str] = "pil",
546 | return_dict: bool = True,
547 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
548 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
549 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
550 | max_sequence_length: int = 512,
551 | ip_token = None,
552 | ip_token_ids=None,
553 | ):
554 | r"""
555 | Function invoked when calling the pipeline for generation.
556 |
557 | Args:
558 | prompt (`str` or `List[str]`, *optional*):
559 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
560 | instead.
561 | prompt_2 (`str` or `List[str]`, *optional*):
562 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
563 | will be used instead.
564 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
565 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
566 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
567 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
568 | num_inference_steps (`int`, *optional*, defaults to 28):
569 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
570 | expense of slower inference.
571 | timesteps (`List[int]`, *optional*):
572 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
573 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
574 | passed will be used. Must be in descending order.
575 | guidance_scale (`float`, *optional*, defaults to 7.0):
576 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
577 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
578 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
579 | 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
580 | usually at the expense of lower image quality.
581 | num_images_per_prompt (`int`, *optional*, defaults to 1):
582 | The number of images to generate per prompt.
583 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
584 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
585 | to make generation deterministic.
586 | latents (`torch.FloatTensor`, *optional*):
587 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
588 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
589 | tensor will be generated by sampling using the supplied random `generator`.
590 | prompt_embeds (`torch.FloatTensor`, *optional*):
591 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
592 | provided, text embeddings will be generated from `prompt` input argument.
593 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
594 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
595 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
596 | output_type (`str`, *optional*, defaults to `"pil"`):
597 | The output format of the generated image. Choose between
598 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
599 | return_dict (`bool`, *optional*, defaults to `True`):
600 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
601 | joint_attention_kwargs (`dict`, *optional*):
602 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
603 | `self.processor` in
604 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
605 | callback_on_step_end (`Callable`, *optional*):
606 | A function that is called at the end of each denoising step during inference. The function is called
607 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
608 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
609 | `callback_on_step_end_tensor_inputs`.
610 | callback_on_step_end_tensor_inputs (`List`, *optional*):
611 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
612 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
613 | `._callback_tensor_inputs` attribute of your pipeline class.
614 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
615 |
616 | Examples:
617 |
618 | Returns:
619 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
620 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
621 | images.
622 | """
623 |
624 | height = height or self.default_sample_size * self.vae_scale_factor
625 | width = width or self.default_sample_size * self.vae_scale_factor
626 |
627 | # 1. Check inputs. Raise error if not correct
628 | self.check_inputs(
629 | prompt,
630 | prompt_2,
631 | height,
632 | width,
633 | prompt_embeds=prompt_embeds,
634 | pooled_prompt_embeds=pooled_prompt_embeds,
635 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
636 | max_sequence_length=max_sequence_length,
637 | )
638 |
639 | self._guidance_scale = guidance_scale
640 | self._joint_attention_kwargs = joint_attention_kwargs
641 | self._interrupt = False
642 |
643 | # 2. Define call parameters
644 | if prompt is not None and isinstance(prompt, str):
645 | batch_size = 1
646 | elif prompt is not None and isinstance(prompt, list):
647 | batch_size = len(prompt)
648 | else:
649 | batch_size = prompt_embeds.shape[0]
650 |
651 | device = self._execution_device
652 |
653 | lora_scale = (
654 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
655 | )
656 | (
657 | prompt_embeds,
658 | pooled_prompt_embeds,
659 | text_ids,
660 | ) = self.encode_prompt(
661 | prompt=prompt,
662 | prompt_2=prompt_2,
663 | prompt_embeds=prompt_embeds,
664 | pooled_prompt_embeds=pooled_prompt_embeds,
665 | device=device,
666 | num_images_per_prompt=num_images_per_prompt,
667 | max_sequence_length=max_sequence_length,
668 | lora_scale=lora_scale,
669 | )
670 |
671 | # 4. Prepare latent variables
672 | num_channels_latents = self.transformer.config.in_channels // 4
673 | latents, latent_image_ids = self.prepare_latents(
674 | batch_size * num_images_per_prompt,
675 | num_channels_latents,
676 | height,
677 | width,
678 | prompt_embeds.dtype,
679 | device,
680 | generator,
681 | latents,
682 | )
683 |
684 | # 5. Prepare timesteps
685 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
686 | image_seq_len = latents.shape[1]
687 | mu = calculate_shift(
688 | image_seq_len,
689 | self.scheduler.config.base_image_seq_len,
690 | self.scheduler.config.max_image_seq_len,
691 | self.scheduler.config.base_shift,
692 | self.scheduler.config.max_shift,
693 | )
694 | timesteps, num_inference_steps = retrieve_timesteps(
695 | self.scheduler,
696 | num_inference_steps,
697 | device,
698 | timesteps,
699 | sigmas,
700 | mu=mu,
701 | )
702 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
703 | self._num_timesteps = len(timesteps)
704 |
705 | # handle guidance
706 | if self.transformer.config.guidance_embeds:
707 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
708 | guidance = guidance.expand(latents.shape[0])
709 | else:
710 | guidance = None
711 |
712 | # 6. Denoising loop
713 | with self.progress_bar(total=num_inference_steps) as progress_bar:
714 | for i, t in enumerate(timesteps):
715 | if self.interrupt:
716 | continue
717 |
718 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
719 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
720 |
721 | noise_pred = self.transformer(
722 | hidden_states=latents,
723 | timestep=timestep / 1000,
724 | guidance=guidance,
725 | pooled_projections=pooled_prompt_embeds,
726 | encoder_hidden_states=prompt_embeds,
727 | txt_ids=text_ids,
728 | img_ids=latent_image_ids,
729 | ip_token_ids=ip_token_ids,
730 | joint_attention_kwargs=self.joint_attention_kwargs,
731 | return_dict=False,
732 | ip_token=ip_token,
733 | )[0]
734 |
735 | # compute the previous noisy sample x_t -> x_t-1
736 | latents_dtype = latents.dtype
737 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
738 |
739 | if latents.dtype != latents_dtype:
740 | if torch.backends.mps.is_available():
741 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
742 | latents = latents.to(latents_dtype)
743 |
744 | if callback_on_step_end is not None:
745 | callback_kwargs = {}
746 | for k in callback_on_step_end_tensor_inputs:
747 | callback_kwargs[k] = locals()[k]
748 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
749 |
750 | latents = callback_outputs.pop("latents", latents)
751 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
752 |
753 | # call the callback, if provided
754 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
755 | progress_bar.update()
756 |
757 | if XLA_AVAILABLE:
758 | xm.mark_step()
759 |
760 | if output_type == "latent":
761 | image = latents
762 |
763 | else:
764 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
765 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
766 | image = self.vae.decode(latents, return_dict=False)[0]
767 | image = self.image_processor.postprocess(image, output_type=output_type)
768 |
769 | # Offload all models
770 | self.maybe_free_model_hooks()
771 |
772 | if not return_dict:
773 | return (image,)
774 |
775 | return FluxPipelineOutput(images=image)
776 |
--------------------------------------------------------------------------------
/src/customID/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 |
11 |
12 | # FFN
13 | def FeedForward(dim, mult=4):
14 | inner_dim = int(dim * mult)
15 | return nn.Sequential(
16 | nn.LayerNorm(dim),
17 | nn.Linear(dim, inner_dim, bias=False),
18 | nn.GELU(),
19 | nn.Linear(inner_dim, dim, bias=False),
20 | )
21 |
22 |
23 | def reshape_tensor(x, heads):
24 | bs, length, width = x.shape
25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26 | x = x.view(bs, length, heads, -1)
27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28 | x = x.transpose(1, 2)
29 | # reshape after transpose: shape stays (bs, n_heads, length, dim_per_head), tensor made contiguous
30 | x = x.reshape(bs, heads, length, -1)
31 | return x
32 |
33 |
34 | class PerceiverAttention(nn.Module):
35 | def __init__(self, *, dim, dim_head=64, heads=8):
36 | super().__init__()
37 | self.scale = dim_head**-0.5
38 | self.dim_head = dim_head
39 | self.heads = heads
40 | inner_dim = dim_head * heads
41 |
42 | self.norm1 = nn.LayerNorm(dim)
43 | self.norm2 = nn.LayerNorm(dim)
44 |
45 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
48 |
49 | def forward(self, x, latents):
50 | """
51 | Args:
52 | x (torch.Tensor): image features
53 | shape (b, n1, D)
54 | latents (torch.Tensor): latent features
55 | shape (b, n2, D)
56 | """
57 | x = self.norm1(x)
58 | latents = self.norm2(latents)
59 |
60 | b, l, _ = latents.shape
61 |
62 | q = self.to_q(latents)
63 | kv_input = torch.cat((x, latents), dim=-2)
64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65 |
66 | q = reshape_tensor(q, self.heads)
67 | k = reshape_tensor(k, self.heads)
68 | v = reshape_tensor(v, self.heads)
69 |
70 | # attention
71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74 | out = weight @ v
75 |
76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77 |
78 | return self.to_out(out)
79 |
80 |
81 | class Resampler(nn.Module):
82 | def __init__(
83 | self,
84 | dim=1024,
85 | depth=8,
86 | dim_head=64,
87 | heads=16,
88 | num_queries=8,
89 | embedding_dim=768,
90 | output_dim=1024,
91 | ff_mult=4,
92 | max_seq_len: int = 257, # CLIP tokens + CLS token
93 | apply_pos_emb: bool = False,
94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95 | ):
96 | super().__init__()
97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98 |
99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100 |
101 | self.proj_in = nn.Linear(embedding_dim, dim)
102 |
103 | self.proj_out = nn.Linear(dim, output_dim)
104 | self.norm_out = nn.LayerNorm(output_dim)
105 |
106 | self.to_latents_from_mean_pooled_seq = (
107 | nn.Sequential(
108 | nn.LayerNorm(dim),
109 | nn.Linear(dim, dim * num_latents_mean_pooled),
110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111 | )
112 | if num_latents_mean_pooled > 0
113 | else None
114 | )
115 |
116 | self.layers = nn.ModuleList([])
117 | for _ in range(depth):
118 | self.layers.append(
119 | nn.ModuleList(
120 | [
121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122 | FeedForward(dim=dim, mult=ff_mult),
123 | ]
124 | )
125 | )
126 |
127 | def forward(self, x):
128 | if self.pos_emb is not None:
129 | n, device = x.shape[1], x.device
130 | pos_emb = self.pos_emb(torch.arange(n, device=device))
131 | x = x + pos_emb
132 |
133 | latents = self.latents.repeat(x.size(0), 1, 1)
134 |
135 | x = self.proj_in(x)
136 |
137 | if self.to_latents_from_mean_pooled_seq:
138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140 | latents = torch.cat((meanpooled_latents, latents), dim=-2)
141 |
142 | for attn, ff in self.layers:
143 | latents = attn(x, latents) + latents
144 | latents = ff(latents) + latents
145 |
146 | latents = self.proj_out(latents)
147 | return self.norm_out(latents)
148 |
149 |
150 | def masked_mean(t, *, dim, mask=None):
151 | if mask is None:
152 | return t.mean(dim=dim)
153 |
154 | denom = mask.sum(dim=dim, keepdim=True)
155 | mask = rearrange(mask, "b n -> b n 1")
156 | masked_t = t.masked_fill(~mask, 0.0)
157 |
158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
159 |
--------------------------------------------------------------------------------
/src/customID/transformer_flux.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Any, Dict, Optional, Tuple, Union
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | from diffusers.configuration_utils import ConfigMixin, register_to_config
24 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25 | from diffusers.models.attention import FeedForward
26 | from diffusers.models.attention_processor import (
27 | Attention,
28 | AttentionProcessor,
29 | FluxAttnProcessor2_0,
30 | FusedFluxAttnProcessor2_0,
31 | )
32 | from diffusers.models.modeling_utils import ModelMixin
33 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35 | from diffusers.utils.torch_utils import maybe_allow_in_graph
36 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
38 |
39 |
40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41 |
42 |
43 | @maybe_allow_in_graph
44 | class FluxSingleTransformerBlock(nn.Module):
45 | r"""
46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
47 |
48 | Reference: https://arxiv.org/abs/2403.03206
49 |
50 | Parameters:
51 | dim (`int`): The number of channels in the input and output.
52 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
53 | attention_head_dim (`int`): The number of channels in each head.
54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
55 | processing of `context` conditions.
56 | """
57 |
58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
59 | super().__init__()
60 | self.mlp_hidden_dim = int(dim * mlp_ratio)
61 |
62 | self.norm = AdaLayerNormZeroSingle(dim)
63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
64 | self.act_mlp = nn.GELU(approximate="tanh")
65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
66 |
67 | processor = FluxAttnProcessor2_0()
68 | self.attn = Attention(
69 | query_dim=dim,
70 | cross_attention_dim=None,
71 | dim_head=attention_head_dim,
72 | heads=num_attention_heads,
73 | out_dim=dim,
74 | bias=True,
75 | processor=processor,
76 | qk_norm="rms_norm",
77 | eps=1e-6,
78 | pre_only=True,
79 | )
80 |
81 | def forward(
82 | self,
83 | hidden_states: torch.FloatTensor,
84 | temb: torch.FloatTensor,
85 | image_rotary_emb=None,
86 | ip_token=None,
87 | ):
88 | residual = hidden_states
89 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
90 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
91 |
92 | attn_output = self.attn(
93 | hidden_states=norm_hidden_states,
94 | image_rotary_emb=image_rotary_emb,
95 | ip_token=ip_token,
96 | )
97 |
98 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
99 | gate = gate.unsqueeze(1)
100 | hidden_states = gate * self.proj_out(hidden_states)
101 | hidden_states = residual + hidden_states
102 | if hidden_states.dtype == torch.float16:
103 | hidden_states = hidden_states.clip(-65504, 65504)
104 |
105 | return hidden_states
106 |
107 |
108 | @maybe_allow_in_graph
109 | class FluxTransformerBlock(nn.Module):
110 | r"""
111 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
112 |
113 | Reference: https://arxiv.org/abs/2403.03206
114 |
115 | Parameters:
116 | dim (`int`): The number of channels in the input and output.
117 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
118 | attention_head_dim (`int`): The number of channels in each head.
119 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
120 | processing of `context` conditions.
121 | """
122 |
123 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
124 | super().__init__()
125 |
126 | self.norm1 = AdaLayerNormZero(dim)
127 |
128 | self.norm1_context = AdaLayerNormZero(dim)
129 |
130 | if hasattr(F, "scaled_dot_product_attention"):
131 | processor = FluxAttnProcessor2_0()
132 | else:
133 | raise ValueError(
134 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
135 | )
136 | self.attn = Attention(
137 | query_dim=dim,
138 | cross_attention_dim=None,
139 | added_kv_proj_dim=dim,
140 | dim_head=attention_head_dim,
141 | heads=num_attention_heads,
142 | out_dim=dim,
143 | context_pre_only=False,
144 | bias=True,
145 | processor=processor,
146 | qk_norm=qk_norm,
147 | eps=eps,
148 | )
149 |
150 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
151 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
152 |
153 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
154 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
155 |
156 | # let chunk size default to None
157 | self._chunk_size = None
158 | self._chunk_dim = 0
159 |
160 | def forward(
161 | self,
162 | hidden_states: torch.FloatTensor,
163 | encoder_hidden_states: torch.FloatTensor,
164 | temb: torch.FloatTensor,
165 | image_rotary_emb=None,
166 | ip_token=None,
167 | ):
168 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
169 |
170 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
171 | encoder_hidden_states, emb=temb
172 | )
173 |
174 | # Attention.
175 | attn_output, context_attn_output = self.attn(
176 | hidden_states=norm_hidden_states,
177 | encoder_hidden_states=norm_encoder_hidden_states,
178 | image_rotary_emb=image_rotary_emb,
179 | ip_token=ip_token,
180 | )
181 |
182 | # Process attention outputs for the `hidden_states`.
183 | attn_output = gate_msa.unsqueeze(1) * attn_output
184 | hidden_states = hidden_states + attn_output
185 |
186 | norm_hidden_states = self.norm2(hidden_states)
187 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
188 |
189 | ff_output = self.ff(norm_hidden_states)
190 | ff_output = gate_mlp.unsqueeze(1) * ff_output
191 |
192 | hidden_states = hidden_states + ff_output
193 |
194 | # Process attention outputs for the `encoder_hidden_states`.
195 |
196 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
197 | encoder_hidden_states = encoder_hidden_states + context_attn_output
198 |
199 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
200 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
201 |
202 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
203 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
204 | if encoder_hidden_states.dtype == torch.float16:
205 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
206 |
207 | return encoder_hidden_states, hidden_states
208 |
209 |
210 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
211 | """
212 | The Transformer model introduced in Flux.
213 |
214 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
215 |
216 | Parameters:
217 | patch_size (`int`): Patch size to turn the input data into small patches.
218 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
219 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
220 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
221 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
222 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
223 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
224 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
225 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
226 | """
227 |
228 | _supports_gradient_checkpointing = True
229 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
230 |
231 | @register_to_config
232 | def __init__(
233 | self,
234 | patch_size: int = 1,
235 | in_channels: int = 64,
236 | num_layers: int = 19,
237 | num_single_layers: int = 38,
238 | attention_head_dim: int = 128,
239 | num_attention_heads: int = 24,
240 | joint_attention_dim: int = 4096,
241 | pooled_projection_dim: int = 768,
242 | guidance_embeds: bool = False,
243 | axes_dims_rope: Tuple[int] = (16, 56, 56),
244 | ):
245 | super().__init__()
246 | self.out_channels = in_channels
247 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
248 |
249 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
250 |
251 | text_time_guidance_cls = (
252 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
253 | )
254 | self.time_text_embed = text_time_guidance_cls(
255 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
256 | )
257 |
258 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
259 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
260 |
261 | self.transformer_blocks = nn.ModuleList(
262 | [
263 | FluxTransformerBlock(
264 | dim=self.inner_dim,
265 | num_attention_heads=self.config.num_attention_heads,
266 | attention_head_dim=self.config.attention_head_dim,
267 | )
268 | for i in range(self.config.num_layers)
269 | ]
270 | )
271 |
272 | self.single_transformer_blocks = nn.ModuleList(
273 | [
274 | FluxSingleTransformerBlock(
275 | dim=self.inner_dim,
276 | num_attention_heads=self.config.num_attention_heads,
277 | attention_head_dim=self.config.attention_head_dim,
278 | )
279 | for i in range(self.config.num_single_layers)
280 | ]
281 | )
282 |
283 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
284 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
285 |
286 | self.gradient_checkpointing = False
287 |
288 | @property
289 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
290 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
291 | r"""
292 | Returns:
293 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
294 |             indexed by their weight names.
295 | """
296 | # set recursively
297 | processors = {}
298 |
299 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
300 | if hasattr(module, "get_processor"):
301 | processors[f"{name}.processor"] = module.get_processor()
302 |
303 | for sub_name, child in module.named_children():
304 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
305 |
306 | return processors
307 |
308 | for name, module in self.named_children():
309 | fn_recursive_add_processors(name, module, processors)
310 |
311 | return processors
312 |
313 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
314 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
315 | r"""
316 | Sets the attention processor to use to compute attention.
317 |
318 | Parameters:
319 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
320 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
321 | for **all** `Attention` layers.
322 |
323 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
324 | processor. This is strongly recommended when setting trainable attention processors.
325 |
326 | """
327 | count = len(self.attn_processors.keys())
328 |
329 | if isinstance(processor, dict) and len(processor) != count:
330 | raise ValueError(
331 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
332 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
333 | )
334 |
335 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
336 | if hasattr(module, "set_processor"):
337 | if not isinstance(processor, dict):
338 | module.set_processor(processor)
339 | else:
340 | module.set_processor(processor.pop(f"{name}.processor"))
341 |
342 | for sub_name, child in module.named_children():
343 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
344 |
345 | for name, module in self.named_children():
346 | fn_recursive_attn_processor(name, module, processor)
347 |
348 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
349 | def fuse_qkv_projections(self):
350 | """
351 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
352 | are fused. For cross-attention modules, key and value projection matrices are fused.
353 |
354 |
355 |
356 | This API is 🧪 experimental.
357 |
358 |
359 | """
360 | self.original_attn_processors = None
361 |
362 | for _, attn_processor in self.attn_processors.items():
363 | if "Added" in str(attn_processor.__class__.__name__):
364 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
365 |
366 | self.original_attn_processors = self.attn_processors
367 |
368 | for module in self.modules():
369 | if isinstance(module, Attention):
370 | module.fuse_projections(fuse=True)
371 |
372 | self.set_attn_processor(FusedFluxAttnProcessor2_0())
373 |
374 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
375 | def unfuse_qkv_projections(self):
376 | """Disables the fused QKV projection if enabled.
377 |
378 |
379 |
380 | This API is 🧪 experimental.
381 |
382 |
383 |
384 | """
385 | if self.original_attn_processors is not None:
386 | self.set_attn_processor(self.original_attn_processors)
387 |
388 | def _set_gradient_checkpointing(self, module, value=False):
389 | if hasattr(module, "gradient_checkpointing"):
390 | module.gradient_checkpointing = value
391 |
392 | def forward(
393 | self,
394 | hidden_states: torch.Tensor,
395 | encoder_hidden_states: torch.Tensor = None,
396 | pooled_projections: torch.Tensor = None,
397 | timestep: torch.LongTensor = None,
398 | img_ids: torch.Tensor = None,
399 | txt_ids: torch.Tensor = None,
400 | guidance: torch.Tensor = None,
401 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
402 | controlnet_block_samples=None,
403 | controlnet_single_block_samples=None,
404 | return_dict: bool = True,
405 |         ip_token: torch.Tensor = None,  # added for customID: projected identity (ID) token embeddings
406 |         ip_token_ids: torch.Tensor = None,  # positional ids for the ID tokens
407 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
408 | """
409 | The [`FluxTransformer2DModel`] forward method.
410 |
411 | Args:
412 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
413 | Input `hidden_states`.
414 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
415 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
416 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
417 | from the embeddings of input conditions.
418 | timestep ( `torch.LongTensor`):
419 | Used to indicate denoising step.
420 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
421 | A list of tensors that if specified are added to the residuals of transformer blocks.
422 | joint_attention_kwargs (`dict`, *optional*):
423 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
424 | `self.processor` in
425 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
426 | return_dict (`bool`, *optional*, defaults to `True`):
427 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
428 | tuple.
429 |
430 | Returns:
431 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
432 | `tuple` where the first element is the sample tensor.
433 | """
434 |         # with open("mid_info.txt", "a") as f:  # disabled debug logging
435 |         #     f.write(f"{ip_token.abs().mean().item()}, {hidden_states.abs().mean().item()}, {encoder_hidden_states.abs().mean().item()}\n")
436 |
437 | if joint_attention_kwargs is not None:
438 | joint_attention_kwargs = joint_attention_kwargs.copy()
439 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
440 | else:
441 | lora_scale = 1.0
442 |
443 | if USE_PEFT_BACKEND:
444 | # weight the lora layers by setting `lora_scale` for each PEFT layer
445 | scale_lora_layers(self, lora_scale)
446 | else:
447 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
448 | logger.warning(
449 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
450 | )
451 | hidden_states = self.x_embedder(hidden_states)
452 |
453 | timestep = timestep.to(hidden_states.dtype) * 1000
454 | if guidance is not None:
455 | guidance = guidance.to(hidden_states.dtype) * 1000
456 | else:
457 | guidance = None
458 | temb = (
459 | self.time_text_embed(timestep, pooled_projections)
460 | if guidance is None
461 | else self.time_text_embed(timestep, guidance, pooled_projections)
462 | )
463 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
464 |
465 | if txt_ids.ndim == 3:
466 | # logger.warning(
467 | # "Passing `txt_ids` 3d torch.Tensor is deprecated."
468 | # "Please remove the batch dimension and pass it as a 2d torch Tensor"
469 | # )
470 | txt_ids = txt_ids[0]
471 | if img_ids.ndim == 3:
472 | # logger.warning(
473 | # "Passing `img_ids` 3d torch.Tensor is deprecated."
474 | # "Please remove the batch dimension and pass it as a 2d torch Tensor"
475 | # )
476 | img_ids = img_ids[0]
477 | # print(f"transformers!!!") #dylee
478 | # print(f"txt_ids, img_ids, ip_token_ids: {txt_ids.shape}, {img_ids.shape}, {ip_token_ids.shape}")
479 | ids = torch.cat((txt_ids, img_ids, ip_token_ids), dim=0)
480 | image_rotary_emb = self.pos_embed(ids)
481 |
482 | # print(f"image_rotary_emb shape {image_rotary_emb[0].shape}")
483 | # print(f"ip_token shape is {ip_token.shape}")
484 | # print(f"hidden_states shape is {hidden_states.shape}")
485 | # print(f"encoder_hidden_states shape is {encoder_hidden_states.shape}")
486 |
487 | for index_block, block in enumerate(self.transformer_blocks):
488 | if self.training and self.gradient_checkpointing:
489 |
490 | def create_custom_forward(module, return_dict=None):
491 | def custom_forward(*inputs):
492 | if return_dict is not None:
493 | return module(*inputs, return_dict=return_dict)
494 | else:
495 | return module(*inputs)
496 |
497 | return custom_forward
498 |
499 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
500 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
501 | create_custom_forward(block),
502 | hidden_states,
503 | encoder_hidden_states,
504 | temb,
505 | image_rotary_emb,
506 | ip_token,
507 | **ckpt_kwargs,
508 | )
509 |
510 | else:
511 | encoder_hidden_states, hidden_states = block(
512 | hidden_states=hidden_states,
513 | encoder_hidden_states=encoder_hidden_states,
514 | temb=temb,
515 | image_rotary_emb=image_rotary_emb,
516 | ip_token=ip_token,
517 | )
518 |
519 | # controlnet residual
520 | if controlnet_block_samples is not None:
521 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
522 | interval_control = int(np.ceil(interval_control))
523 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
524 |
525 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
526 |
527 |
528 | # ids_single = torch.cat((txt_ids, img_ids), dim=0)
529 | # image_rotary_emb_single = self.pos_embed(ids_single) #dylee
530 | for index_block, block in enumerate(self.single_transformer_blocks):
531 | if self.training and self.gradient_checkpointing:
532 |
533 | def create_custom_forward(module, return_dict=None):
534 | def custom_forward(*inputs):
535 | if return_dict is not None:
536 | return module(*inputs, return_dict=return_dict)
537 | else:
538 | return module(*inputs)
539 |
540 | return custom_forward
541 |
542 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
543 | hidden_states = torch.utils.checkpoint.checkpoint(
544 | create_custom_forward(block),
545 | hidden_states,
546 | temb,
547 | image_rotary_emb, #dylee
548 | ip_token,
549 | **ckpt_kwargs,
550 | )
551 |
552 | else:
553 | hidden_states = block(
554 | hidden_states=hidden_states,
555 | temb=temb,
556 | image_rotary_emb=image_rotary_emb, #dylee
557 | ip_token=ip_token,
558 | )
559 |
560 | # controlnet residual
561 | if controlnet_single_block_samples is not None:
562 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
563 | interval_control = int(np.ceil(interval_control))
564 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
565 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
566 | + controlnet_single_block_samples[index_block // interval_control]
567 | )
568 |
569 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
570 |
571 | hidden_states = self.norm_out(hidden_states, temb)
572 | output = self.proj_out(hidden_states)
573 |
574 | if USE_PEFT_BACKEND:
575 | # remove `lora_scale` from each PEFT layer
576 | unscale_lora_layers(self, lora_scale)
577 |
578 | if not return_dict:
579 | return (output,)
580 |
581 | return Transformer2DModelOutput(sample=output)
582 |
--------------------------------------------------------------------------------
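Compared with the stock diffusers transformer reproduced below, the forward above additionally accepts `ip_token` and `ip_token_ids`, and concatenates the ID-token ids with `txt_ids` and `img_ids` before computing the rotary embedding. A small sketch of that id bookkeeping follows; the all-zero id scheme and the token counts are assumptions for illustration only.

```
import torch

num_ip_tokens = 16                            # assumed number of ID tokens
txt_ids = torch.zeros(512, 3)                 # text token ids, as built by the pipeline
img_ids = torch.zeros(64 * 64, 3)             # latent patch ids (values filled elsewhere)
ip_token_ids = torch.zeros(num_ip_tokens, 3)  # assumed: ID tokens share position 0

# Mirrors `ids = torch.cat((txt_ids, img_ids, ip_token_ids), dim=0)` in forward();
# the result is fed to `self.pos_embed` to get rotary embeddings covering all
# three token groups (text, image latents, ID tokens).
ids = torch.cat((txt_ids, img_ids, ip_token_ids), dim=0)
print(ids.shape)  # (512 + 4096 + 16, 3)
```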
/src/customID/transformer_flux_ori.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Any, Dict, Optional, Tuple, Union
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | from ...configuration_utils import ConfigMixin, register_to_config
24 | from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
25 | from ...models.attention import FeedForward
26 | from ...models.attention_processor import (
27 | Attention,
28 | AttentionProcessor,
29 | FluxAttnProcessor2_0,
30 | FusedFluxAttnProcessor2_0,
31 | )
32 | from ...models.modeling_utils import ModelMixin
33 | from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34 | from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35 | from ...utils.torch_utils import maybe_allow_in_graph
36 | from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37 | from ..modeling_outputs import Transformer2DModelOutput
38 |
39 |
40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41 |
42 |
43 | @maybe_allow_in_graph
44 | class FluxSingleTransformerBlock(nn.Module):
45 | r"""
46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
47 |
48 | Reference: https://arxiv.org/abs/2403.03206
49 |
50 | Parameters:
51 | dim (`int`): The number of channels in the input and output.
52 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
53 | attention_head_dim (`int`): The number of channels in each head.
54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
55 | processing of `context` conditions.
56 | """
57 |
58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
59 | super().__init__()
60 | self.mlp_hidden_dim = int(dim * mlp_ratio)
61 |
62 | self.norm = AdaLayerNormZeroSingle(dim)
63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
64 | self.act_mlp = nn.GELU(approximate="tanh")
65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
66 |
67 | processor = FluxAttnProcessor2_0()
68 | self.attn = Attention(
69 | query_dim=dim,
70 | cross_attention_dim=None,
71 | dim_head=attention_head_dim,
72 | heads=num_attention_heads,
73 | out_dim=dim,
74 | bias=True,
75 | processor=processor,
76 | qk_norm="rms_norm",
77 | eps=1e-6,
78 | pre_only=True,
79 | )
80 |
81 | def forward(
82 | self,
83 | hidden_states: torch.FloatTensor,
84 | temb: torch.FloatTensor,
85 | image_rotary_emb=None,
86 | ):
87 | residual = hidden_states
88 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
89 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
90 |
91 | attn_output = self.attn(
92 | hidden_states=norm_hidden_states,
93 | image_rotary_emb=image_rotary_emb,
94 | )
95 |
96 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
97 | gate = gate.unsqueeze(1)
98 | hidden_states = gate * self.proj_out(hidden_states)
99 | hidden_states = residual + hidden_states
100 | if hidden_states.dtype == torch.float16:
101 | hidden_states = hidden_states.clip(-65504, 65504)
102 |
103 | return hidden_states
104 |
105 |
106 | @maybe_allow_in_graph
107 | class FluxTransformerBlock(nn.Module):
108 | r"""
109 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
110 |
111 | Reference: https://arxiv.org/abs/2403.03206
112 |
113 | Parameters:
114 | dim (`int`): The number of channels in the input and output.
115 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
116 | attention_head_dim (`int`): The number of channels in each head.
117 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
118 | processing of `context` conditions.
119 | """
120 |
121 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
122 | super().__init__()
123 |
124 | self.norm1 = AdaLayerNormZero(dim)
125 |
126 | self.norm1_context = AdaLayerNormZero(dim)
127 |
128 | if hasattr(F, "scaled_dot_product_attention"):
129 | processor = FluxAttnProcessor2_0()
130 | else:
131 | raise ValueError(
132 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
133 | )
134 | self.attn = Attention(
135 | query_dim=dim,
136 | cross_attention_dim=None,
137 | added_kv_proj_dim=dim,
138 | dim_head=attention_head_dim,
139 | heads=num_attention_heads,
140 | out_dim=dim,
141 | context_pre_only=False,
142 | bias=True,
143 | processor=processor,
144 | qk_norm=qk_norm,
145 | eps=eps,
146 | )
147 |
148 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
149 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
150 |
151 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
152 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
153 |
154 | # let chunk size default to None
155 | self._chunk_size = None
156 | self._chunk_dim = 0
157 |
158 | def forward(
159 | self,
160 | hidden_states: torch.FloatTensor,
161 | encoder_hidden_states: torch.FloatTensor,
162 | temb: torch.FloatTensor,
163 | image_rotary_emb=None,
164 | ):
165 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
166 |
167 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
168 | encoder_hidden_states, emb=temb
169 | )
170 |
171 | # Attention.
172 | attn_output, context_attn_output = self.attn(
173 | hidden_states=norm_hidden_states,
174 | encoder_hidden_states=norm_encoder_hidden_states,
175 | image_rotary_emb=image_rotary_emb,
176 | )
177 |
178 | # Process attention outputs for the `hidden_states`.
179 | attn_output = gate_msa.unsqueeze(1) * attn_output
180 | hidden_states = hidden_states + attn_output
181 |
182 | norm_hidden_states = self.norm2(hidden_states)
183 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
184 |
185 | ff_output = self.ff(norm_hidden_states)
186 | ff_output = gate_mlp.unsqueeze(1) * ff_output
187 |
188 | hidden_states = hidden_states + ff_output
189 |
190 | # Process attention outputs for the `encoder_hidden_states`.
191 |
192 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
193 | encoder_hidden_states = encoder_hidden_states + context_attn_output
194 |
195 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
196 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
197 |
198 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
199 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
200 | if encoder_hidden_states.dtype == torch.float16:
201 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
202 |
203 | return encoder_hidden_states, hidden_states
204 |
205 |
206 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
207 | """
208 | The Transformer model introduced in Flux.
209 |
210 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
211 |
212 | Parameters:
213 | patch_size (`int`): Patch size to turn the input data into small patches.
214 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
215 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
216 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
217 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
218 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
219 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
220 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
221 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
222 | """
223 |
224 | _supports_gradient_checkpointing = True
225 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
226 |
227 | @register_to_config
228 | def __init__(
229 | self,
230 | patch_size: int = 1,
231 | in_channels: int = 64,
232 | num_layers: int = 19,
233 | num_single_layers: int = 38,
234 | attention_head_dim: int = 128,
235 | num_attention_heads: int = 24,
236 | joint_attention_dim: int = 4096,
237 | pooled_projection_dim: int = 768,
238 | guidance_embeds: bool = False,
239 | axes_dims_rope: Tuple[int] = (16, 56, 56),
240 | ):
241 | super().__init__()
242 | self.out_channels = in_channels
243 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
244 |
245 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
246 |
247 | text_time_guidance_cls = (
248 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
249 | )
250 | self.time_text_embed = text_time_guidance_cls(
251 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
252 | )
253 |
254 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
255 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
256 |
257 | self.transformer_blocks = nn.ModuleList(
258 | [
259 | FluxTransformerBlock(
260 | dim=self.inner_dim,
261 | num_attention_heads=self.config.num_attention_heads,
262 | attention_head_dim=self.config.attention_head_dim,
263 | )
264 | for i in range(self.config.num_layers)
265 | ]
266 | )
267 |
268 | self.single_transformer_blocks = nn.ModuleList(
269 | [
270 | FluxSingleTransformerBlock(
271 | dim=self.inner_dim,
272 | num_attention_heads=self.config.num_attention_heads,
273 | attention_head_dim=self.config.attention_head_dim,
274 | )
275 | for i in range(self.config.num_single_layers)
276 | ]
277 | )
278 |
279 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
280 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
281 |
282 | self.gradient_checkpointing = False
283 |
284 | @property
285 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
286 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
287 | r"""
288 | Returns:
289 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
290 |             indexed by their weight names.
291 | """
292 | # set recursively
293 | processors = {}
294 |
295 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
296 | if hasattr(module, "get_processor"):
297 | processors[f"{name}.processor"] = module.get_processor()
298 |
299 | for sub_name, child in module.named_children():
300 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
301 |
302 | return processors
303 |
304 | for name, module in self.named_children():
305 | fn_recursive_add_processors(name, module, processors)
306 |
307 | return processors
308 |
309 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
310 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
311 | r"""
312 | Sets the attention processor to use to compute attention.
313 |
314 | Parameters:
315 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
316 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
317 | for **all** `Attention` layers.
318 |
319 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
320 | processor. This is strongly recommended when setting trainable attention processors.
321 |
322 | """
323 | count = len(self.attn_processors.keys())
324 |
325 | if isinstance(processor, dict) and len(processor) != count:
326 | raise ValueError(
327 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
328 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
329 | )
330 |
331 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
332 | if hasattr(module, "set_processor"):
333 | if not isinstance(processor, dict):
334 | module.set_processor(processor)
335 | else:
336 | module.set_processor(processor.pop(f"{name}.processor"))
337 |
338 | for sub_name, child in module.named_children():
339 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
340 |
341 | for name, module in self.named_children():
342 | fn_recursive_attn_processor(name, module, processor)
343 |
344 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
345 | def fuse_qkv_projections(self):
346 | """
347 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
348 | are fused. For cross-attention modules, key and value projection matrices are fused.
349 |
350 |
351 |
352 | This API is 🧪 experimental.
353 |
354 |
355 | """
356 | self.original_attn_processors = None
357 |
358 | for _, attn_processor in self.attn_processors.items():
359 | if "Added" in str(attn_processor.__class__.__name__):
360 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
361 |
362 | self.original_attn_processors = self.attn_processors
363 |
364 | for module in self.modules():
365 | if isinstance(module, Attention):
366 | module.fuse_projections(fuse=True)
367 |
368 | self.set_attn_processor(FusedFluxAttnProcessor2_0())
369 |
370 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
371 | def unfuse_qkv_projections(self):
372 | """Disables the fused QKV projection if enabled.
373 |
374 |
375 |
376 | This API is 🧪 experimental.
377 |
378 |
379 |
380 | """
381 | if self.original_attn_processors is not None:
382 | self.set_attn_processor(self.original_attn_processors)
383 |
384 | def _set_gradient_checkpointing(self, module, value=False):
385 | if hasattr(module, "gradient_checkpointing"):
386 | module.gradient_checkpointing = value
387 |
388 | def forward(
389 | self,
390 | hidden_states: torch.Tensor,
391 | encoder_hidden_states: torch.Tensor = None,
392 | pooled_projections: torch.Tensor = None,
393 | timestep: torch.LongTensor = None,
394 | img_ids: torch.Tensor = None,
395 | txt_ids: torch.Tensor = None,
396 | guidance: torch.Tensor = None,
397 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
398 | controlnet_block_samples=None,
399 | controlnet_single_block_samples=None,
400 | return_dict: bool = True,
401 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
402 | """
403 | The [`FluxTransformer2DModel`] forward method.
404 |
405 | Args:
406 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
407 | Input `hidden_states`.
408 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
409 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
410 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
411 | from the embeddings of input conditions.
412 | timestep ( `torch.LongTensor`):
413 | Used to indicate denoising step.
414 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
415 | A list of tensors that if specified are added to the residuals of transformer blocks.
416 | joint_attention_kwargs (`dict`, *optional*):
417 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
418 | `self.processor` in
419 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
420 | return_dict (`bool`, *optional*, defaults to `True`):
421 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
422 | tuple.
423 |
424 | Returns:
425 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
426 | `tuple` where the first element is the sample tensor.
427 | """
428 | if joint_attention_kwargs is not None:
429 | joint_attention_kwargs = joint_attention_kwargs.copy()
430 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
431 | else:
432 | lora_scale = 1.0
433 |
434 | if USE_PEFT_BACKEND:
435 | # weight the lora layers by setting `lora_scale` for each PEFT layer
436 | scale_lora_layers(self, lora_scale)
437 | else:
438 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
439 | logger.warning(
440 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
441 | )
442 | hidden_states = self.x_embedder(hidden_states)
443 |
444 | timestep = timestep.to(hidden_states.dtype) * 1000
445 | if guidance is not None:
446 | guidance = guidance.to(hidden_states.dtype) * 1000
447 | else:
448 | guidance = None
449 | temb = (
450 | self.time_text_embed(timestep, pooled_projections)
451 | if guidance is None
452 | else self.time_text_embed(timestep, guidance, pooled_projections)
453 | )
454 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
455 |
456 | if txt_ids.ndim == 3:
457 | logger.warning(
458 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
459 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
460 | )
461 | txt_ids = txt_ids[0]
462 | if img_ids.ndim == 3:
463 | logger.warning(
464 | "Passing `img_ids` 3d torch.Tensor is deprecated."
465 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
466 | )
467 | img_ids = img_ids[0]
468 | ids = torch.cat((txt_ids, img_ids), dim=0)
469 | image_rotary_emb = self.pos_embed(ids)
470 |
471 | for index_block, block in enumerate(self.transformer_blocks):
472 | if self.training and self.gradient_checkpointing:
473 |
474 | def create_custom_forward(module, return_dict=None):
475 | def custom_forward(*inputs):
476 | if return_dict is not None:
477 | return module(*inputs, return_dict=return_dict)
478 | else:
479 | return module(*inputs)
480 |
481 | return custom_forward
482 |
483 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
484 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
485 | create_custom_forward(block),
486 | hidden_states,
487 | encoder_hidden_states,
488 | temb,
489 | image_rotary_emb,
490 | **ckpt_kwargs,
491 | )
492 |
493 | else:
494 | encoder_hidden_states, hidden_states = block(
495 | hidden_states=hidden_states,
496 | encoder_hidden_states=encoder_hidden_states,
497 | temb=temb,
498 | image_rotary_emb=image_rotary_emb,
499 | )
500 |
501 | # controlnet residual
502 | if controlnet_block_samples is not None:
503 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
504 | interval_control = int(np.ceil(interval_control))
505 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
506 |
507 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
508 |
509 | for index_block, block in enumerate(self.single_transformer_blocks):
510 | if self.training and self.gradient_checkpointing:
511 |
512 | def create_custom_forward(module, return_dict=None):
513 | def custom_forward(*inputs):
514 | if return_dict is not None:
515 | return module(*inputs, return_dict=return_dict)
516 | else:
517 | return module(*inputs)
518 |
519 | return custom_forward
520 |
521 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
522 | hidden_states = torch.utils.checkpoint.checkpoint(
523 | create_custom_forward(block),
524 | hidden_states,
525 | temb,
526 | image_rotary_emb,
527 | **ckpt_kwargs,
528 | )
529 |
530 | else:
531 | hidden_states = block(
532 | hidden_states=hidden_states,
533 | temb=temb,
534 | image_rotary_emb=image_rotary_emb,
535 | )
536 |
537 | # controlnet residual
538 | if controlnet_single_block_samples is not None:
539 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
540 | interval_control = int(np.ceil(interval_control))
541 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
542 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
543 | + controlnet_single_block_samples[index_block // interval_control]
544 | )
545 |
546 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
547 |
548 | hidden_states = self.norm_out(hidden_states, temb)
549 | output = self.proj_out(hidden_states)
550 |
551 | if USE_PEFT_BACKEND:
552 | # remove `lora_scale` from each PEFT layer
553 | unscale_lora_layers(self, lora_scale)
554 |
555 | if not return_dict:
556 | return (output,)
557 |
558 | return Transformer2DModelOutput(sample=output)
559 |
--------------------------------------------------------------------------------
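`transformer_flux_ori.py` is the unmodified diffusers reference kept for comparison; the `attn_processors` / `set_attn_processor` helpers it documents are the standard hooks through which custom attention processors are installed. A sketch of that mechanism on a deliberately tiny model is shown below; the small config values are assumptions so the example stays lightweight (assumes diffusers>=0.31).

```
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0

# Tiny config so the example runs quickly; real FLUX.1-dev uses 19 + 38 blocks.
tiny = FluxTransformer2DModel(
    patch_size=1,
    in_channels=16,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=64,
    pooled_projection_dim=32,
)

print(list(tiny.attn_processors)[:2])            # processor names keyed by module path
tiny.set_attn_processor(FluxAttnProcessor2_0())  # install the same processor on every Attention layer
```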
/src/customID/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from PIL import Image
5 |
6 | attn_maps = {}
7 | def hook_fn(name):
8 | def forward_hook(module, input, output):
9 | if hasattr(module.processor, "attn_map"):
10 | attn_maps[name] = module.processor.attn_map
11 | del module.processor.attn_map
12 |
13 | return forward_hook
14 |
15 | def register_cross_attention_hook(unet):
16 | for name, module in unet.named_modules():
17 | if name.split('.')[-1].startswith('attn2'):
18 | module.register_forward_hook(hook_fn(name))
19 |
20 | return unet
21 |
22 | def upscale(attn_map, target_size):
23 | attn_map = torch.mean(attn_map, dim=0)
24 | attn_map = attn_map.permute(1,0)
25 | temp_size = None
26 |
27 | for i in range(0,5):
28 | scale = 2 ** i
29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64:
30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8))
31 | break
32 |
33 |     assert temp_size is not None, "temp_size cannot be None"
34 |
35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size)
36 |
37 | attn_map = F.interpolate(
38 | attn_map.unsqueeze(0).to(dtype=torch.float32),
39 | size=target_size,
40 | mode='bilinear',
41 | align_corners=False
42 | )[0]
43 |
44 | attn_map = torch.softmax(attn_map, dim=0)
45 | return attn_map
46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
47 |
48 | idx = 0 if instance_or_negative else 1
49 | net_attn_maps = []
50 |
51 | for name, attn_map in attn_maps.items():
52 | attn_map = attn_map.cpu() if detach else attn_map
53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
54 | attn_map = upscale(attn_map, image_size)
55 | net_attn_maps.append(attn_map)
56 |
57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0)
58 |
59 | return net_attn_maps
60 |
61 | def attnmaps2images(net_attn_maps):
62 |
63 | #total_attn_scores = 0
64 | images = []
65 |
66 | for attn_map in net_attn_maps:
67 | attn_map = attn_map.cpu().numpy()
68 | #total_attn_scores += attn_map.mean().item()
69 |
70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255
71 | normalized_attn_map = normalized_attn_map.astype(np.uint8)
72 | #print("norm: ", normalized_attn_map.shape)
73 | image = Image.fromarray(normalized_attn_map)
74 |
75 | #image = fix_save_attn_map(attn_map)
76 | images.append(image)
77 |
78 | #print(total_attn_scores)
79 | return images
80 | def is_torch2_available():
81 | return hasattr(F, "scaled_dot_product_attention")
82 |
83 | def get_generator(seed, device):
84 |
85 | if seed is not None:
86 | if isinstance(seed, list):
87 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]
88 | else:
89 | generator = torch.Generator(device).manual_seed(seed)
90 | else:
91 | generator = None
92 |
93 | return generator
--------------------------------------------------------------------------------
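`get_generator` above normalizes seeds into `torch.Generator` objects: an `int` yields a single generator, a list yields one generator per entry, and `None` disables seeding. A quick sketch (the import path is assumed):

```
from src.customID.utils import get_generator  # assumed import path

single = get_generator(42, device="cpu")             # one torch.Generator
per_image = get_generator([1, 2, 3], device="cpu")   # list of generators, one per seed
unseeded = get_generator(None, device="cpu")         # None when no seed is given
print(type(single), len(per_image), unseeded)
```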
/src/utils/insightface_package.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # pip install insightface==0.7.3
3 | from insightface.app import FaceAnalysis
4 | from insightface.data import get_image as ins_get_image
5 |
6 | ###
7 | # https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543
8 | ###
9 | class FaceAnalysis2(FaceAnalysis):
10 | # NOTE: allows setting det_size for each detection call.
11 | # The underlying model supports this, but the insightface wrapper
12 | # does not expose it, so people end up loading duplicate models
13 | # for different sizes when there is no need to.
14 | def get(self, img, max_num=0, det_size=(640, 640)):
15 | if det_size is not None:
16 | self.det_model.input_size = det_size
17 |
18 | return super().get(img, max_num)
19 |
20 | def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)):
21 | # NOTE: try detect faces, if no faces detected, lower det_size until it does
22 | detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)]
23 |
24 | for size in detection_sizes:
25 | faces = face_analysis.get(img_data, det_size=size)
26 | if len(faces) > 0:
27 | return faces
28 |
29 | return []
30 |
--------------------------------------------------------------------------------
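To close the loop, here is a hedged usage sketch for the InsightFace helpers above; the provider list, the image path, and the use of `normed_embedding` as the identity feature are assumptions for illustration, not the repository's prescribed settings.

```
import cv2
from src.utils.insightface_package import FaceAnalysis2, analyze_faces  # assumed import path

app = FaceAnalysis2(
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],  # assumed ONNX Runtime providers
    allowed_modules=["detection", "recognition"],
)
app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("path/to/face.jpg")  # placeholder path; insightface expects a BGR ndarray
faces = analyze_faces(app, img)       # retries with smaller det_size until a face is found
if faces:
    id_embedding = faces[0].normed_embedding  # 512-d ArcFace identity embedding
```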