├── LICENSE ├── README.md ├── img ├── CustomAnyID_img.png ├── CustomID_img.png ├── controlnet_img.png ├── logo.png └── man1.jpg ├── infer_customID.ipynb ├── pretrained_ckpt └── src ├── customID ├── __init__.py ├── attention_processor.py ├── attention_processor_ori.py ├── model.py ├── pipeline_flux.py ├── resampler.py ├── transformer_flux.py ├── transformer_flux_ori.py └── utils.py └── utils └── insightface_package.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 DamoCV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## FLUX-customID: Realistically Customize Your Personal ID to Perfection 4 | 5 | This repository is the official implementation of FLUX-customID. It generates images conditioned on your face image at a quality on par with real photographs. Our base model is FLUX.1-dev, which ensures high-quality image generation. 6 | 7 | ## News 8 | - 🌟**2024-11-13**: Released the code and weights for FLUX-customID. 9 | 10 | ## Gallery 11 | Here are some samples generated by our method. 12 | 13 |
14 | 15 | ## Quick Start 16 | 17 | ### 1. Setup Repository and Environment 18 | 19 | ``` 20 | conda create -n customID python=3.10 -y 21 | conda activate customID 22 | conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=11.8 -c pytorch -c nvidia -y 23 | pip install -i https://mirrors.cloud.tencent.com/pypi/simple diffusers==0.31.0 transformers onnxruntime-gpu insightface sentencepiece matplotlib imageio tqdm numpy einops accelerate peft 24 | ``` 25 | 26 | ### 2. Prepare Pretrained Checkpoints 27 | 28 | ``` 29 | git clone https://github.com/damo-cv/FLUX-customID.git 30 | cd FLUX-customID 31 | 32 | mkdir pretrained_ckpt 33 | cd pretrained_ckpt 34 | 35 | # Download CLIP 36 | export HF_ENDPOINT=https://hf-mirror.com 37 | pip install -U "huggingface_hub[cli]" 38 | 39 | huggingface-cli download \ 40 | --resume-download "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" \ 41 | --cache-dir your_dir/ 42 | 43 | ln -s your_dir/models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K/snapshots/de081ac0a0ca8dc9d1533eed1ae884bb8ae1404b openclip-vit-h-14 44 | 45 | # Download FLUX.1-dev 46 | huggingface-cli download \ 47 | --resume-download "black-forest-labs/FLUX.1-dev" \ 48 | --cache-dir your_dir/ 49 | 50 | ln -s your_dir/models--black-forest-labs--FLUX.1-dev/snapshots/303875135fff3f05b6aa893d544f28833a237d58 flux.1-dev 51 | 52 | # Download FLUX-customID 53 | # Download our trained checkpoint from https://huggingface.co/Damo-vision/FLUX-customID and place FLUX-customID.pt in the folder pretrained_ckpt/ 54 | ``` 55 | 56 | ### 3. Quick Inference 57 | ``` 58 | run infer_customID.ipynb 59 | ``` 60 | 61 | ## Preview for CustomAnyID 62 | We are currently working on a related project, **CustomAnyID**. Below are some preliminary experimental results: 63 | 64 |
65 | 66 | ## Preview for ControlNet 67 | We would also like to announce our ControlNet model. Below are some preliminary experimental results: 68 |
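## Minimal Python Example
If you prefer a plain Python script to the notebook, the quick-inference flow of `infer_customID.ipynb` boils down to roughly the sketch below. It assumes the checkpoint layout prepared in Quick Start and is run from the repository root; the prompt and output filename are only examples.
```
import torch
from src.customID.pipeline_flux import FluxPipeline
from src.customID.transformer_flux import FluxTransformer2DModel
from src.customID.model import CustomIDModel

# Expected layout after Quick Start:
#   pretrained_ckpt/openclip-vit-h-14  -> CLIP ViT-H/14 snapshot (symlink)
#   pretrained_ckpt/flux.1-dev         -> FLUX.1-dev snapshot (symlink)
#   pretrained_ckpt/FLUX-customID.pt
device, dtype = "cuda:0", torch.bfloat16
model_path = "pretrained_ckpt/flux.1-dev"  # or "black-forest-labs/FLUX.1-dev"

# Load the FLUX transformer and pipeline shipped with this repo
transformer = FluxTransformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=dtype
).to(device)
pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype).to(device)

# Wrap the pipeline with the ID adapter; 64 ID tokens matches the released checkpoint
customID_model = CustomIDModel(pipe, "pretrained_ckpt/FLUX-customID.pt", device, dtype, 64)

# `pil_image` takes the path to a face photo; the other settings mirror the notebook
images = customID_model.generate(
    pil_image="img/man1.jpg",
    prompt="A man wearing a classic leather jacket leans against a vintage motorcycle.",
    num_samples=3,
    height=1024,
    width=1024,
    seed=2024,
    num_inference_steps=28,
    guidance_scale=3.5,
)
images[0].save("output.png")  # each element is a PIL.Image
```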
69 | 70 | 71 | ## Contact Us 72 | Dongyang Li: [yingtian.ldy@alibaba-inc.com](mailto:yingtian.ldy@alibaba-inc.com) 73 | 74 | ## Acknowledgements 75 | Part of the code is implemented based on [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) and [PhotoMaker](https://github.com/TencentARC/PhotoMaker). 76 | -------------------------------------------------------------------------------- /img/CustomAnyID_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/CustomAnyID_img.png -------------------------------------------------------------------------------- /img/CustomID_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/CustomID_img.png -------------------------------------------------------------------------------- /img/controlnet_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/controlnet_img.png -------------------------------------------------------------------------------- /img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/logo.png -------------------------------------------------------------------------------- /img/man1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/img/man1.jpg -------------------------------------------------------------------------------- /infer_customID.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from PIL import Image\n", 11 | "import torch\n", 12 | "from src.customID.pipeline_flux import FluxPipeline\n", 13 | "from src.customID.transformer_flux import FluxTransformer2DModel\n", 14 | "from src.customID.model import CustomIDModel\n", 15 | "\n", 16 | "def image_grid(imgs, rows, cols):\n", 17 | " assert len(imgs) == rows*cols\n", 18 | " w, h = imgs[0].size\n", 19 | " grid = Image.new('RGB', size=(cols*w, rows*h))\n", 20 | " grid_w, grid_h = grid.size\n", 21 | " \n", 22 | " for i, img in enumerate(imgs):\n", 23 | " grid.paste(img, box=(i%cols*w, i//cols*h))\n", 24 | " return grid\n", 25 | "\n", 26 | "_DEVICE = \"cuda:0\"\n", 27 | "_DTYPE=torch.bfloat16\n", 28 | "model_path = \"pretrained_ckpt/flux.1-dev\" #you can also use `black-forest-labs/FLUX.1-dev`\n", 29 | "transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder=\"transformer\", torch_dtype=_DTYPE).to(_DEVICE)\n", 30 | "pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=_DTYPE).to(_DEVICE)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "num_token=64\n", 40 | "trained_ckpt = \"pretrained_ckpt/FLUX-customID.pt\"\n", 41 | "customID_model = CustomIDModel(pipe, trained_ckpt, _DEVICE, _DTYPE, num_token)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | 
"metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "num_samples=3\n", 51 | "gs= 3.5\n", 52 | "_seed=2024\n", 53 | "h=1024\n", 54 | "w=1024\n", 55 | "img_path = \"img/man1.jpg\"\n", 56 | "p=\"A man wearing a classic leather jacket leans against a vintage motorcycle, surrounded by autumn leaves swirling in the breeze.\"\n", 57 | "images = customID_model.generate(pil_image=img_path,\n", 58 | " prompt=p,\n", 59 | " num_samples=num_samples,\n", 60 | " height=h,\n", 61 | " width=w,\n", 62 | " seed=_seed,\n", 63 | " num_inference_steps=28,\n", 64 | " guidance_scale=gs)\n", 65 | "grid = image_grid(images, 1, num_samples)\n", 66 | "grid" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "pt20", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.10.15" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /pretrained_ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/pretrained_ckpt -------------------------------------------------------------------------------- /src/customID/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/damo-cv/FLUX-customID/591f170db746be1c7da740996c1af5f4dcd729a1/src/customID/__init__.py -------------------------------------------------------------------------------- /src/customID/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: LiDongyang(yingtian.ldy@alibaba-inc.com | amo5lee@aliyun.com) 3 | Date: 2024-10 4 | Description: Customized Image Generation Model Based on Facial ID. 
5 | """ 6 | import os 7 | from typing import List 8 | # import math 9 | import torch 10 | from PIL import Image 11 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 12 | import pdb 13 | from .utils import is_torch2_available, get_generator 14 | from .attention_processor import CATRefFluxAttnProcessor2_0 15 | from .resampler import PerceiverAttention, FeedForward 16 | from src.utils.insightface_package import FaceAnalysis2, analyze_faces 17 | import cv2 18 | USE_DAFAULT_ATTN = False # should be True for visualization_attnmap 19 | 20 | class FacePerceiverResampler(torch.nn.Module): 21 | def __init__( 22 | self, 23 | *, 24 | dim=768, 25 | depth=4, 26 | dim_head=64, 27 | heads=16, 28 | embedding_dim=1280, 29 | output_dim=768, 30 | ff_mult=4, 31 | ): 32 | super().__init__() 33 | 34 | self.proj_in = torch.nn.Linear(embedding_dim, dim) 35 | self.proj_out = torch.nn.Linear(dim, output_dim) 36 | self.norm_out = torch.nn.LayerNorm(output_dim) 37 | self.layers = torch.nn.ModuleList([]) 38 | for _ in range(depth): 39 | self.layers.append( 40 | torch.nn.ModuleList( 41 | [ 42 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 43 | FeedForward(dim=dim, mult=ff_mult), 44 | ] 45 | ) 46 | ) 47 | 48 | def forward(self, latents, x): 49 | x = self.proj_in(x) 50 | for attn, ff in self.layers: 51 | latents = attn(x, latents) + latents 52 | latents = ff(latents) + latents 53 | latents = self.proj_out(latents) 54 | return self.norm_out(latents) 55 | 56 | class ProjPlusModel(torch.nn.Module): 57 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4, output_dim=3072): 58 | super().__init__() 59 | 60 | self.cross_attention_dim = cross_attention_dim 61 | self.num_tokens = num_tokens 62 | 63 | self.proj = torch.nn.Sequential( 64 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 65 | torch.nn.GELU(), 66 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 67 | ) 68 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 69 | 70 | self.perceiver_resampler = FacePerceiverResampler( 71 | dim=cross_attention_dim, 72 | depth=4, 73 | dim_head=64, 74 | heads=cross_attention_dim // 64, 75 | embedding_dim=clip_embeddings_dim, 76 | output_dim=output_dim, 77 | ff_mult=4, 78 | ) 79 | self.prj_out_clip = torch.nn.Sequential( 80 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim*2), 81 | torch.nn.GELU(), 82 | torch.nn.Linear(clip_embeddings_dim*2, output_dim), 83 | ) 84 | 85 | def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): 86 | 87 | x = self.proj(id_embeds) 88 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 89 | x = self.norm(x) 90 | out = self.perceiver_resampler(x, clip_embeds) 91 | if shortcut: 92 | out = x + scale * out 93 | return torch.cat([out, self.prj_out_clip(clip_embeds)], dim=1) 94 | 95 | class CustomIDModel: 96 | def __init__(self, sd_pipe, trained_ckpt, device, dtype, num_tokens=4, image_encoder_path="pretrained_ckpt/openclip-vit-h-14"): 97 | self.device = device 98 | self.dtype = dtype 99 | self.trained_ckpt = trained_ckpt 100 | self.num_tokens = num_tokens 101 | self.pipe = sd_pipe 102 | self.image_encoder_path = image_encoder_path 103 | 104 | # load image encoder 105 | self.clip_image_processor = CLIPImageProcessor() 106 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 107 | self.device) 108 | 109 | self.set_id_adapter() 110 | 111 | # image proj model 112 | self.image_proj_model = self.init_proj() 113 | 
self.image_proj_model.to(self.device) 114 | if self.trained_ckpt != None: 115 | self.load_id_adapter() 116 | 117 | self.face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition']) 118 | self.face_detector.prepare(ctx_id=0, det_size=(640, 640)) 119 | 120 | def init_proj(self): 121 | image_proj_model = ProjPlusModel( 122 | cross_attention_dim=self.image_encoder.config.hidden_size, 123 | id_embeddings_dim=512, 124 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 125 | num_tokens=self.num_tokens, 126 | output_dim=self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim, 127 | ).to(self.device) 128 | return image_proj_model 129 | 130 | def set_id_adapter(self): 131 | # init adapter modules 132 | attn_procs = {} 133 | transformer_sd = self.pipe.transformer.state_dict() 134 | for name in self.pipe.transformer.attn_processors.keys(): 135 | if name.startswith("transformer_blocks"): 136 | attn_procs[name] = CATRefFluxAttnProcessor2_0(self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim, 137 | self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim, 138 | self.pipe.transformer.config.attention_head_dim, 139 | self.num_tokens+256,#! 140 | ).to(self.device, dtype=self.dtype) 141 | elif name.startswith("single_transformer_blocks"): 142 | attn_procs[name] = CATRefFluxAttnProcessor2_0(self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim, 143 | self.pipe.transformer.config.num_attention_heads * self.pipe.transformer.config.attention_head_dim, 144 | self.pipe.transformer.config.attention_head_dim, 145 | self.num_tokens+256, 146 | ).to(self.device, dtype=self.dtype) 147 | self.pipe.transformer.set_attn_processor(attn_procs) 148 | 149 | def load_id_adapter(self): 150 | state_dict = torch.load(self.trained_ckpt, map_location=torch.device('cpu')) 151 | self.image_proj_model.load_state_dict(state_dict["img_prj_state"], strict=True) 152 | m,u = self.pipe.transformer.load_state_dict(state_dict["attn_processor_state"], strict=False) 153 | assert len(u)==0 154 | 155 | @torch.inference_mode() 156 | def get_image_embeds(self, pil_image): 157 | image_ = cv2.imread(pil_image) 158 | faces = analyze_faces(self.face_detector, image_) 159 | faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0).to(self.device) 160 | faceid_embeds = faceid_embeds.unsqueeze(0) 161 | 162 | #clip 163 | face_image = Image.open(pil_image) 164 | if isinstance(face_image, Image.Image): 165 | pil_image = [face_image] 166 | clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values 167 | clip_image = clip_image.to(self.device, dtype=self.dtype) 168 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 169 | 170 | ip_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds[:,1:,:]) 171 | assert ip_tokens.shape[1] == self.num_tokens+256 172 | 173 | return ip_tokens.to(self.device, dtype=self.dtype) 174 | 175 | def generate( 176 | self, 177 | pil_image=None, 178 | prompt=None, 179 | num_samples=4, 180 | height=1024, 181 | width=1024, 182 | seed=None, 183 | num_inference_steps=30, 184 | guidance_scale=3.5, 185 | ): 186 | 187 | ip_tokens = self.get_image_embeds(pil_image=pil_image) 188 | 189 | bs_embed, seq_len, _ = ip_tokens.shape 190 | ip_tokens = ip_tokens.repeat(1, num_samples, 1) 191 | ip_tokens = 
ip_tokens.view(bs_embed * num_samples, seq_len, -1) 192 | ip_tokens = ip_tokens.to(self.device).to(self.dtype) 193 | 194 | ip_token_ids = self.pipe._prepare_latent_image_ids( 195 | 1, 196 | 1*2, 197 | (self.num_tokens+256)*2, 198 | self.device, 199 | self.dtype, 200 | ) 201 | images = self.pipe( 202 | prompt, 203 | ip_token=ip_tokens, 204 | ip_token_ids=ip_token_ids, 205 | num_images_per_prompt=num_samples, 206 | height=height, 207 | width=width, 208 | output_type="pil", 209 | num_inference_steps=num_inference_steps, 210 | generator=torch.Generator(self.device).manual_seed(seed), 211 | guidance_scale=guidance_scale, 212 | ).images 213 | 214 | return images -------------------------------------------------------------------------------- /src/customID/pipeline_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 21 | 22 | from diffusers.image_processor import VaeImageProcessor 23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin 24 | from diffusers.models.autoencoders import AutoencoderKL 25 | from diffusers.models.transformers import FluxTransformer2DModel 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import ( 28 | USE_PEFT_BACKEND, 29 | is_torch_xla_available, 30 | logging, 31 | replace_example_docstring, 32 | scale_lora_layers, 33 | unscale_lora_layers, 34 | ) 35 | from diffusers.utils.torch_utils import randn_tensor 36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 38 | 39 | 40 | if is_torch_xla_available(): 41 | import torch_xla.core.xla_model as xm 42 | 43 | XLA_AVAILABLE = True 44 | else: 45 | XLA_AVAILABLE = False 46 | 47 | 48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 49 | 50 | EXAMPLE_DOC_STRING = """ 51 | Examples: 52 | ```py 53 | >>> import torch 54 | >>> from diffusers import FluxPipeline 55 | 56 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 57 | >>> pipe.to("cuda") 58 | >>> prompt = "A cat holding a sign that says hello world" 59 | >>> # Depending on the variant being used, the pipeline call will slightly vary. 60 | >>> # Refer to the pipeline documentation for more details. 
61 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] 62 | >>> image.save("flux.png") 63 | ``` 64 | """ 65 | 66 | 67 | def calculate_shift( 68 | image_seq_len, 69 | base_seq_len: int = 256, 70 | max_seq_len: int = 4096, 71 | base_shift: float = 0.5, 72 | max_shift: float = 1.16, 73 | ): 74 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 75 | b = base_shift - m * base_seq_len 76 | mu = image_seq_len * m + b 77 | return mu 78 | 79 | 80 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 81 | def retrieve_timesteps( 82 | scheduler, 83 | num_inference_steps: Optional[int] = None, 84 | device: Optional[Union[str, torch.device]] = None, 85 | timesteps: Optional[List[int]] = None, 86 | sigmas: Optional[List[float]] = None, 87 | **kwargs, 88 | ): 89 | """ 90 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 91 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 92 | 93 | Args: 94 | scheduler (`SchedulerMixin`): 95 | The scheduler to get timesteps from. 96 | num_inference_steps (`int`): 97 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 98 | must be `None`. 99 | device (`str` or `torch.device`, *optional*): 100 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 101 | timesteps (`List[int]`, *optional*): 102 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 103 | `num_inference_steps` and `sigmas` must be `None`. 104 | sigmas (`List[float]`, *optional*): 105 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 106 | `num_inference_steps` and `timesteps` must be `None`. 107 | 108 | Returns: 109 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 110 | second element is the number of inference steps. 111 | """ 112 | if timesteps is not None and sigmas is not None: 113 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 114 | if timesteps is not None: 115 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 116 | if not accepts_timesteps: 117 | raise ValueError( 118 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 119 | f" timestep schedules. Please check whether you are using the correct scheduler." 120 | ) 121 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 122 | timesteps = scheduler.timesteps 123 | num_inference_steps = len(timesteps) 124 | elif sigmas is not None: 125 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 126 | if not accept_sigmas: 127 | raise ValueError( 128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 129 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
130 | ) 131 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 132 | timesteps = scheduler.timesteps 133 | num_inference_steps = len(timesteps) 134 | else: 135 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 136 | timesteps = scheduler.timesteps 137 | return timesteps, num_inference_steps 138 | 139 | 140 | class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): 141 | r""" 142 | The Flux pipeline for text-to-image generation. 143 | 144 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 145 | 146 | Args: 147 | transformer ([`FluxTransformer2DModel`]): 148 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 149 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 150 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 151 | vae ([`AutoencoderKL`]): 152 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 153 | text_encoder ([`CLIPTextModel`]): 154 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 155 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 156 | text_encoder_2 ([`T5EncoderModel`]): 157 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 158 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 159 | tokenizer (`CLIPTokenizer`): 160 | Tokenizer of class 161 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 162 | tokenizer_2 (`T5TokenizerFast`): 163 | Second Tokenizer of class 164 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
165 | """ 166 | 167 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 168 | _optional_components = [] 169 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 170 | 171 | def __init__( 172 | self, 173 | scheduler: FlowMatchEulerDiscreteScheduler, 174 | vae: AutoencoderKL, 175 | text_encoder: CLIPTextModel, 176 | tokenizer: CLIPTokenizer, 177 | text_encoder_2: T5EncoderModel, 178 | tokenizer_2: T5TokenizerFast, 179 | transformer: FluxTransformer2DModel, 180 | ): 181 | super().__init__() 182 | 183 | self.register_modules( 184 | vae=vae, 185 | text_encoder=text_encoder, 186 | text_encoder_2=text_encoder_2, 187 | tokenizer=tokenizer, 188 | tokenizer_2=tokenizer_2, 189 | transformer=transformer, 190 | scheduler=scheduler, 191 | ) 192 | self.vae_scale_factor = ( 193 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 194 | ) 195 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 196 | self.tokenizer_max_length = ( 197 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 198 | ) 199 | self.default_sample_size = 64 200 | 201 | def _get_t5_prompt_embeds( 202 | self, 203 | prompt: Union[str, List[str]] = None, 204 | num_images_per_prompt: int = 1, 205 | max_sequence_length: int = 512, 206 | device: Optional[torch.device] = None, 207 | dtype: Optional[torch.dtype] = None, 208 | ): 209 | device = device or self._execution_device 210 | dtype = dtype or self.text_encoder.dtype 211 | 212 | prompt = [prompt] if isinstance(prompt, str) else prompt 213 | batch_size = len(prompt) 214 | 215 | text_inputs = self.tokenizer_2( 216 | prompt, 217 | padding="max_length", 218 | max_length=max_sequence_length, 219 | truncation=True, 220 | return_length=False, 221 | return_overflowing_tokens=False, 222 | return_tensors="pt", 223 | ) 224 | text_input_ids = text_inputs.input_ids 225 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 226 | 227 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 228 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 229 | logger.warning( 230 | "The following part of your input was truncated because `max_sequence_length` is set to " 231 | f" {max_sequence_length} tokens: {removed_text}" 232 | ) 233 | 234 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] 235 | 236 | dtype = self.text_encoder_2.dtype 237 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 238 | 239 | _, seq_len, _ = prompt_embeds.shape 240 | 241 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 242 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 243 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 244 | 245 | return prompt_embeds 246 | 247 | def _get_clip_prompt_embeds( 248 | self, 249 | prompt: Union[str, List[str]], 250 | num_images_per_prompt: int = 1, 251 | device: Optional[torch.device] = None, 252 | ): 253 | device = device or self._execution_device 254 | 255 | prompt = [prompt] if isinstance(prompt, str) else prompt 256 | batch_size = len(prompt) 257 | 258 | text_inputs = self.tokenizer( 259 | prompt, 260 | padding="max_length", 261 | max_length=self.tokenizer_max_length, 262 | truncation=True, 263 | return_overflowing_tokens=False, 
264 | return_length=False, 265 | return_tensors="pt", 266 | ) 267 | 268 | text_input_ids = text_inputs.input_ids 269 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 270 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 271 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 272 | logger.warning( 273 | "The following part of your input was truncated because CLIP can only handle sequences up to" 274 | f" {self.tokenizer_max_length} tokens: {removed_text}" 275 | ) 276 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 277 | 278 | # Use pooled output of CLIPTextModel 279 | prompt_embeds = prompt_embeds.pooler_output 280 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 281 | 282 | # duplicate text embeddings for each generation per prompt, using mps friendly method 283 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 284 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 285 | 286 | return prompt_embeds 287 | 288 | def encode_prompt( 289 | self, 290 | prompt: Union[str, List[str]], 291 | prompt_2: Union[str, List[str]], 292 | device: Optional[torch.device] = None, 293 | num_images_per_prompt: int = 1, 294 | prompt_embeds: Optional[torch.FloatTensor] = None, 295 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 296 | max_sequence_length: int = 512, 297 | lora_scale: Optional[float] = None, 298 | ): 299 | r""" 300 | 301 | Args: 302 | prompt (`str` or `List[str]`, *optional*): 303 | prompt to be encoded 304 | prompt_2 (`str` or `List[str]`, *optional*): 305 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 306 | used in all text-encoders 307 | device: (`torch.device`): 308 | torch device 309 | num_images_per_prompt (`int`): 310 | number of images that should be generated per prompt 311 | prompt_embeds (`torch.FloatTensor`, *optional*): 312 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 313 | provided, text embeddings will be generated from `prompt` input argument. 314 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 315 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 316 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 317 | lora_scale (`float`, *optional*): 318 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
319 | """ 320 | device = device or self._execution_device 321 | 322 | # set lora scale so that monkey patched LoRA 323 | # function of text encoder can correctly access it 324 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 325 | self._lora_scale = lora_scale 326 | 327 | # dynamically adjust the LoRA scale 328 | if self.text_encoder is not None and USE_PEFT_BACKEND: 329 | scale_lora_layers(self.text_encoder, lora_scale) 330 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 331 | scale_lora_layers(self.text_encoder_2, lora_scale) 332 | 333 | prompt = [prompt] if isinstance(prompt, str) else prompt 334 | 335 | if prompt_embeds is None: 336 | prompt_2 = prompt_2 or prompt 337 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 338 | 339 | # We only use the pooled prompt output from the CLIPTextModel 340 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 341 | prompt=prompt, 342 | device=device, 343 | num_images_per_prompt=num_images_per_prompt, 344 | ) 345 | prompt_embeds = self._get_t5_prompt_embeds( 346 | prompt=prompt_2, 347 | num_images_per_prompt=num_images_per_prompt, 348 | max_sequence_length=max_sequence_length, 349 | device=device, 350 | ) 351 | 352 | if self.text_encoder is not None: 353 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 354 | # Retrieve the original scale by scaling back the LoRA layers 355 | unscale_lora_layers(self.text_encoder, lora_scale) 356 | 357 | if self.text_encoder_2 is not None: 358 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 359 | # Retrieve the original scale by scaling back the LoRA layers 360 | unscale_lora_layers(self.text_encoder_2, lora_scale) 361 | 362 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 363 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 364 | 365 | return prompt_embeds, pooled_prompt_embeds, text_ids 366 | 367 | def check_inputs( 368 | self, 369 | prompt, 370 | prompt_2, 371 | height, 372 | width, 373 | prompt_embeds=None, 374 | pooled_prompt_embeds=None, 375 | callback_on_step_end_tensor_inputs=None, 376 | max_sequence_length=None, 377 | ): 378 | if height % 8 != 0 or width % 8 != 0: 379 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 380 | 381 | if callback_on_step_end_tensor_inputs is not None and not all( 382 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 383 | ): 384 | raise ValueError( 385 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 386 | ) 387 | 388 | if prompt is not None and prompt_embeds is not None: 389 | raise ValueError( 390 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 391 | " only forward one of the two." 392 | ) 393 | elif prompt_2 is not None and prompt_embeds is not None: 394 | raise ValueError( 395 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 396 | " only forward one of the two." 397 | ) 398 | elif prompt is None and prompt_embeds is None: 399 | raise ValueError( 400 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
401 | ) 402 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 403 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 404 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 405 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 406 | 407 | if prompt_embeds is not None and pooled_prompt_embeds is None: 408 | raise ValueError( 409 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 410 | ) 411 | 412 | if max_sequence_length is not None and max_sequence_length > 512: 413 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 414 | 415 | @staticmethod 416 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 417 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 418 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 419 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 420 | 421 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 422 | 423 | latent_image_ids = latent_image_ids.reshape( 424 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 425 | ) 426 | 427 | return latent_image_ids.to(device=device, dtype=dtype) 428 | 429 | @staticmethod 430 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 431 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 432 | latents = latents.permute(0, 2, 4, 1, 3, 5) 433 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 434 | 435 | return latents 436 | 437 | @staticmethod 438 | def _unpack_latents(latents, height, width, vae_scale_factor): 439 | batch_size, num_patches, channels = latents.shape 440 | 441 | height = height // vae_scale_factor 442 | width = width // vae_scale_factor 443 | 444 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 445 | latents = latents.permute(0, 3, 1, 4, 2, 5) 446 | 447 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 448 | 449 | return latents 450 | 451 | def enable_vae_slicing(self): 452 | r""" 453 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 454 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 455 | """ 456 | self.vae.enable_slicing() 457 | 458 | def disable_vae_slicing(self): 459 | r""" 460 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 461 | computing decoding in one step. 462 | """ 463 | self.vae.disable_slicing() 464 | 465 | def enable_vae_tiling(self): 466 | r""" 467 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 468 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 469 | processing larger images. 470 | """ 471 | self.vae.enable_tiling() 472 | 473 | def disable_vae_tiling(self): 474 | r""" 475 | Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to 476 | computing decoding in one step. 477 | """ 478 | self.vae.disable_tiling() 479 | 480 | def prepare_latents( 481 | self, 482 | batch_size, 483 | num_channels_latents, 484 | height, 485 | width, 486 | dtype, 487 | device, 488 | generator, 489 | latents=None, 490 | ): 491 | height = 2 * (int(height) // self.vae_scale_factor) 492 | width = 2 * (int(width) // self.vae_scale_factor) 493 | 494 | shape = (batch_size, num_channels_latents, height, width) 495 | 496 | if latents is not None: 497 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 498 | return latents.to(device=device, dtype=dtype), latent_image_ids 499 | 500 | if isinstance(generator, list) and len(generator) != batch_size: 501 | raise ValueError( 502 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 503 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 504 | ) 505 | 506 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 507 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) 508 | 509 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 510 | 511 | return latents, latent_image_ids 512 | 513 | @property 514 | def guidance_scale(self): 515 | return self._guidance_scale 516 | 517 | @property 518 | def joint_attention_kwargs(self): 519 | return self._joint_attention_kwargs 520 | 521 | @property 522 | def num_timesteps(self): 523 | return self._num_timesteps 524 | 525 | @property 526 | def interrupt(self): 527 | return self._interrupt 528 | 529 | @torch.no_grad() 530 | @replace_example_docstring(EXAMPLE_DOC_STRING) 531 | def __call__( 532 | self, 533 | prompt: Union[str, List[str]] = None, 534 | prompt_2: Optional[Union[str, List[str]]] = None, 535 | height: Optional[int] = None, 536 | width: Optional[int] = None, 537 | num_inference_steps: int = 28, 538 | timesteps: List[int] = None, 539 | guidance_scale: float = 7.0, 540 | num_images_per_prompt: Optional[int] = 1, 541 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 542 | latents: Optional[torch.FloatTensor] = None, 543 | prompt_embeds: Optional[torch.FloatTensor] = None, 544 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 545 | output_type: Optional[str] = "pil", 546 | return_dict: bool = True, 547 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 548 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 549 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 550 | max_sequence_length: int = 512, 551 | ip_token = None, 552 | ip_token_ids=None, 553 | ): 554 | r""" 555 | Function invoked when calling the pipeline for generation. 556 | 557 | Args: 558 | prompt (`str` or `List[str]`, *optional*): 559 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 560 | instead. 561 | prompt_2 (`str` or `List[str]`, *optional*): 562 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 563 | will be used instead 564 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 565 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 
566 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 567 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 568 | num_inference_steps (`int`, *optional*, defaults to 50): 569 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 570 | expense of slower inference. 571 | timesteps (`List[int]`, *optional*): 572 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 573 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 574 | passed will be used. Must be in descending order. 575 | guidance_scale (`float`, *optional*, defaults to 7.0): 576 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 577 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 578 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 579 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 580 | usually at the expense of lower image quality. 581 | num_images_per_prompt (`int`, *optional*, defaults to 1): 582 | The number of images to generate per prompt. 583 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 584 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 585 | to make generation deterministic. 586 | latents (`torch.FloatTensor`, *optional*): 587 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 588 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 589 | tensor will ge generated by sampling using the supplied random `generator`. 590 | prompt_embeds (`torch.FloatTensor`, *optional*): 591 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 592 | provided, text embeddings will be generated from `prompt` input argument. 593 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 594 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 595 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 596 | output_type (`str`, *optional*, defaults to `"pil"`): 597 | The output format of the generate image. Choose between 598 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 599 | return_dict (`bool`, *optional*, defaults to `True`): 600 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 601 | joint_attention_kwargs (`dict`, *optional*): 602 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 603 | `self.processor` in 604 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 605 | callback_on_step_end (`Callable`, *optional*): 606 | A function that calls at the end of each denoising steps during the inference. The function is called 607 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 608 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 609 | `callback_on_step_end_tensor_inputs`. 
610 | callback_on_step_end_tensor_inputs (`List`, *optional*): 611 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 612 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 613 | `._callback_tensor_inputs` attribute of your pipeline class. 614 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 615 | 616 | Examples: 617 | 618 | Returns: 619 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 620 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 621 | images. 622 | """ 623 | 624 | height = height or self.default_sample_size * self.vae_scale_factor 625 | width = width or self.default_sample_size * self.vae_scale_factor 626 | 627 | # 1. Check inputs. Raise error if not correct 628 | self.check_inputs( 629 | prompt, 630 | prompt_2, 631 | height, 632 | width, 633 | prompt_embeds=prompt_embeds, 634 | pooled_prompt_embeds=pooled_prompt_embeds, 635 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 636 | max_sequence_length=max_sequence_length, 637 | ) 638 | 639 | self._guidance_scale = guidance_scale 640 | self._joint_attention_kwargs = joint_attention_kwargs 641 | self._interrupt = False 642 | 643 | # 2. Define call parameters 644 | if prompt is not None and isinstance(prompt, str): 645 | batch_size = 1 646 | elif prompt is not None and isinstance(prompt, list): 647 | batch_size = len(prompt) 648 | else: 649 | batch_size = prompt_embeds.shape[0] 650 | 651 | device = self._execution_device 652 | 653 | lora_scale = ( 654 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 655 | ) 656 | ( 657 | prompt_embeds, 658 | pooled_prompt_embeds, 659 | text_ids, 660 | ) = self.encode_prompt( 661 | prompt=prompt, 662 | prompt_2=prompt_2, 663 | prompt_embeds=prompt_embeds, 664 | pooled_prompt_embeds=pooled_prompt_embeds, 665 | device=device, 666 | num_images_per_prompt=num_images_per_prompt, 667 | max_sequence_length=max_sequence_length, 668 | lora_scale=lora_scale, 669 | ) 670 | 671 | # 4. Prepare latent variables 672 | num_channels_latents = self.transformer.config.in_channels // 4 673 | latents, latent_image_ids = self.prepare_latents( 674 | batch_size * num_images_per_prompt, 675 | num_channels_latents, 676 | height, 677 | width, 678 | prompt_embeds.dtype, 679 | device, 680 | generator, 681 | latents, 682 | ) 683 | 684 | # 5. Prepare timesteps 685 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 686 | image_seq_len = latents.shape[1] 687 | mu = calculate_shift( 688 | image_seq_len, 689 | self.scheduler.config.base_image_seq_len, 690 | self.scheduler.config.max_image_seq_len, 691 | self.scheduler.config.base_shift, 692 | self.scheduler.config.max_shift, 693 | ) 694 | timesteps, num_inference_steps = retrieve_timesteps( 695 | self.scheduler, 696 | num_inference_steps, 697 | device, 698 | timesteps, 699 | sigmas, 700 | mu=mu, 701 | ) 702 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 703 | self._num_timesteps = len(timesteps) 704 | 705 | # handle guidance 706 | if self.transformer.config.guidance_embeds: 707 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 708 | guidance = guidance.expand(latents.shape[0]) 709 | else: 710 | guidance = None 711 | 712 | # 6. 
Denoising loop 713 | with self.progress_bar(total=num_inference_steps) as progress_bar: 714 | for i, t in enumerate(timesteps): 715 | if self.interrupt: 716 | continue 717 | 718 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 719 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 720 | 721 | noise_pred = self.transformer( 722 | hidden_states=latents, 723 | timestep=timestep / 1000, 724 | guidance=guidance, 725 | pooled_projections=pooled_prompt_embeds, 726 | encoder_hidden_states=prompt_embeds, 727 | txt_ids=text_ids, 728 | img_ids=latent_image_ids, 729 | ip_token_ids=ip_token_ids, 730 | joint_attention_kwargs=self.joint_attention_kwargs, 731 | return_dict=False, 732 | ip_token=ip_token, 733 | )[0] 734 | 735 | # compute the previous noisy sample x_t -> x_t-1 736 | latents_dtype = latents.dtype 737 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 738 | 739 | if latents.dtype != latents_dtype: 740 | if torch.backends.mps.is_available(): 741 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 742 | latents = latents.to(latents_dtype) 743 | 744 | if callback_on_step_end is not None: 745 | callback_kwargs = {} 746 | for k in callback_on_step_end_tensor_inputs: 747 | callback_kwargs[k] = locals()[k] 748 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 749 | 750 | latents = callback_outputs.pop("latents", latents) 751 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 752 | 753 | # call the callback, if provided 754 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 755 | progress_bar.update() 756 | 757 | if XLA_AVAILABLE: 758 | xm.mark_step() 759 | 760 | if output_type == "latent": 761 | image = latents 762 | 763 | else: 764 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 765 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 766 | image = self.vae.decode(latents, return_dict=False)[0] 767 | image = self.image_processor.postprocess(image, output_type=output_type) 768 | 769 | # Offload all models 770 | self.maybe_free_model_hooks() 771 | 772 | if not return_dict: 773 | return (image,) 774 | 775 | return FluxPipelineOutput(images=image) 776 | -------------------------------------------------------------------------------- /src/customID/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | 12 | # FFN 13 | def FeedForward(dim, mult=4): 14 | inner_dim = int(dim * mult) 15 | return nn.Sequential( 16 | nn.LayerNorm(dim), 17 | nn.Linear(dim, inner_dim, bias=False), 18 | nn.GELU(), 19 | nn.Linear(inner_dim, dim, bias=False), 20 | ) 21 | 22 | 23 | def reshape_tensor(x, heads): 24 | bs, length, width = x.shape 25 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 26 | x = x.view(bs, length, heads, -1) 27 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 28 | x = x.transpose(1, 2) 29 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 30 | x = x.reshape(bs, heads, 
length, -1) 31 | return x 32 | 33 | 34 | class PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class Resampler(nn.Module): 82 | def __init__( 83 | self, 84 | dim=1024, 85 | depth=8, 86 | dim_head=64, 87 | heads=16, 88 | num_queries=8, 89 | embedding_dim=768, 90 | output_dim=1024, 91 | ff_mult=4, 92 | max_seq_len: int = 257, # CLIP tokens + CLS token 93 | apply_pos_emb: bool = False, 94 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 95 | ): 96 | super().__init__() 97 | self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None 98 | 99 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 100 | 101 | self.proj_in = nn.Linear(embedding_dim, dim) 102 | 103 | self.proj_out = nn.Linear(dim, output_dim) 104 | self.norm_out = nn.LayerNorm(output_dim) 105 | 106 | self.to_latents_from_mean_pooled_seq = ( 107 | nn.Sequential( 108 | nn.LayerNorm(dim), 109 | nn.Linear(dim, dim * num_latents_mean_pooled), 110 | Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), 111 | ) 112 | if num_latents_mean_pooled > 0 113 | else None 114 | ) 115 | 116 | self.layers = nn.ModuleList([]) 117 | for _ in range(depth): 118 | self.layers.append( 119 | nn.ModuleList( 120 | [ 121 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 122 | FeedForward(dim=dim, mult=ff_mult), 123 | ] 124 | ) 125 | ) 126 | 127 | def forward(self, x): 128 | if self.pos_emb is not None: 129 | n, device = x.shape[1], x.device 130 | pos_emb = self.pos_emb(torch.arange(n, device=device)) 131 | x = x + pos_emb 132 | 133 | latents = self.latents.repeat(x.size(0), 1, 1) 134 | 135 | x = self.proj_in(x) 136 | 137 | if self.to_latents_from_mean_pooled_seq: 138 | meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) 139 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 140 | latents = torch.cat((meanpooled_latents, latents), dim=-2) 141 | 142 | for attn, ff in self.layers: 143 | latents = attn(x, latents) + latents 144 | latents = ff(latents) + latents 145 | 146 | latents = self.proj_out(latents) 147 | return 
self.norm_out(latents) 148 | 149 | 150 | def masked_mean(t, *, dim, mask=None): 151 | if mask is None: 152 | return t.mean(dim=dim) 153 | 154 | denom = mask.sum(dim=dim, keepdim=True) 155 | mask = rearrange(mask, "b n -> b n 1") 156 | masked_t = t.masked_fill(~mask, 0.0) 157 | 158 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 159 | -------------------------------------------------------------------------------- /src/customID/transformer_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Any, Dict, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 25 | from diffusers.models.attention import FeedForward 26 | from diffusers.models.attention_processor import ( 27 | Attention, 28 | AttentionProcessor, 29 | FluxAttnProcessor2_0, 30 | FusedFluxAttnProcessor2_0, 31 | ) 32 | from diffusers.models.modeling_utils import ModelMixin 33 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 34 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 35 | from diffusers.utils.torch_utils import maybe_allow_in_graph 36 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 37 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @maybe_allow_in_graph 44 | class FluxSingleTransformerBlock(nn.Module): 45 | r""" 46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 47 | 48 | Reference: https://arxiv.org/abs/2403.03206 49 | 50 | Parameters: 51 | dim (`int`): The number of channels in the input and output. 52 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 53 | attention_head_dim (`int`): The number of channels in each head. 54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 55 | processing of `context` conditions. 
56 | """ 57 | 58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 59 | super().__init__() 60 | self.mlp_hidden_dim = int(dim * mlp_ratio) 61 | 62 | self.norm = AdaLayerNormZeroSingle(dim) 63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 64 | self.act_mlp = nn.GELU(approximate="tanh") 65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 66 | 67 | processor = FluxAttnProcessor2_0() 68 | self.attn = Attention( 69 | query_dim=dim, 70 | cross_attention_dim=None, 71 | dim_head=attention_head_dim, 72 | heads=num_attention_heads, 73 | out_dim=dim, 74 | bias=True, 75 | processor=processor, 76 | qk_norm="rms_norm", 77 | eps=1e-6, 78 | pre_only=True, 79 | ) 80 | 81 | def forward( 82 | self, 83 | hidden_states: torch.FloatTensor, 84 | temb: torch.FloatTensor, 85 | image_rotary_emb=None, 86 | ip_token=None, 87 | ): 88 | residual = hidden_states 89 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 90 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 91 | 92 | attn_output = self.attn( 93 | hidden_states=norm_hidden_states, 94 | image_rotary_emb=image_rotary_emb, 95 | ip_token=ip_token, 96 | ) 97 | 98 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 99 | gate = gate.unsqueeze(1) 100 | hidden_states = gate * self.proj_out(hidden_states) 101 | hidden_states = residual + hidden_states 102 | if hidden_states.dtype == torch.float16: 103 | hidden_states = hidden_states.clip(-65504, 65504) 104 | 105 | return hidden_states 106 | 107 | 108 | @maybe_allow_in_graph 109 | class FluxTransformerBlock(nn.Module): 110 | r""" 111 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 112 | 113 | Reference: https://arxiv.org/abs/2403.03206 114 | 115 | Parameters: 116 | dim (`int`): The number of channels in the input and output. 117 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 118 | attention_head_dim (`int`): The number of channels in each head. 119 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 120 | processing of `context` conditions. 121 | """ 122 | 123 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): 124 | super().__init__() 125 | 126 | self.norm1 = AdaLayerNormZero(dim) 127 | 128 | self.norm1_context = AdaLayerNormZero(dim) 129 | 130 | if hasattr(F, "scaled_dot_product_attention"): 131 | processor = FluxAttnProcessor2_0() 132 | else: 133 | raise ValueError( 134 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
135 | ) 136 | self.attn = Attention( 137 | query_dim=dim, 138 | cross_attention_dim=None, 139 | added_kv_proj_dim=dim, 140 | dim_head=attention_head_dim, 141 | heads=num_attention_heads, 142 | out_dim=dim, 143 | context_pre_only=False, 144 | bias=True, 145 | processor=processor, 146 | qk_norm=qk_norm, 147 | eps=eps, 148 | ) 149 | 150 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 151 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 152 | 153 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 154 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 155 | 156 | # let chunk size default to None 157 | self._chunk_size = None 158 | self._chunk_dim = 0 159 | 160 | def forward( 161 | self, 162 | hidden_states: torch.FloatTensor, 163 | encoder_hidden_states: torch.FloatTensor, 164 | temb: torch.FloatTensor, 165 | image_rotary_emb=None, 166 | ip_token=None, 167 | ): 168 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 169 | 170 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 171 | encoder_hidden_states, emb=temb 172 | ) 173 | 174 | # Attention. 175 | attn_output, context_attn_output = self.attn( 176 | hidden_states=norm_hidden_states, 177 | encoder_hidden_states=norm_encoder_hidden_states, 178 | image_rotary_emb=image_rotary_emb, 179 | ip_token=ip_token, 180 | ) 181 | 182 | # Process attention outputs for the `hidden_states`. 183 | attn_output = gate_msa.unsqueeze(1) * attn_output 184 | hidden_states = hidden_states + attn_output 185 | 186 | norm_hidden_states = self.norm2(hidden_states) 187 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 188 | 189 | ff_output = self.ff(norm_hidden_states) 190 | ff_output = gate_mlp.unsqueeze(1) * ff_output 191 | 192 | hidden_states = hidden_states + ff_output 193 | 194 | # Process attention outputs for the `encoder_hidden_states`. 195 | 196 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 197 | encoder_hidden_states = encoder_hidden_states + context_attn_output 198 | 199 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 200 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 201 | 202 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 203 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 204 | if encoder_hidden_states.dtype == torch.float16: 205 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 206 | 207 | return encoder_hidden_states, hidden_states 208 | 209 | 210 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 211 | """ 212 | The Transformer model introduced in Flux. 213 | 214 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 215 | 216 | Parameters: 217 | patch_size (`int`): Patch size to turn the input data into small patches. 218 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 219 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 220 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 221 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
222 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 223 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 224 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 225 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 226 | """ 227 | 228 | _supports_gradient_checkpointing = True 229 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 230 | 231 | @register_to_config 232 | def __init__( 233 | self, 234 | patch_size: int = 1, 235 | in_channels: int = 64, 236 | num_layers: int = 19, 237 | num_single_layers: int = 38, 238 | attention_head_dim: int = 128, 239 | num_attention_heads: int = 24, 240 | joint_attention_dim: int = 4096, 241 | pooled_projection_dim: int = 768, 242 | guidance_embeds: bool = False, 243 | axes_dims_rope: Tuple[int] = (16, 56, 56), 244 | ): 245 | super().__init__() 246 | self.out_channels = in_channels 247 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 248 | 249 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 250 | 251 | text_time_guidance_cls = ( 252 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 253 | ) 254 | self.time_text_embed = text_time_guidance_cls( 255 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 256 | ) 257 | 258 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 259 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 260 | 261 | self.transformer_blocks = nn.ModuleList( 262 | [ 263 | FluxTransformerBlock( 264 | dim=self.inner_dim, 265 | num_attention_heads=self.config.num_attention_heads, 266 | attention_head_dim=self.config.attention_head_dim, 267 | ) 268 | for i in range(self.config.num_layers) 269 | ] 270 | ) 271 | 272 | self.single_transformer_blocks = nn.ModuleList( 273 | [ 274 | FluxSingleTransformerBlock( 275 | dim=self.inner_dim, 276 | num_attention_heads=self.config.num_attention_heads, 277 | attention_head_dim=self.config.attention_head_dim, 278 | ) 279 | for i in range(self.config.num_single_layers) 280 | ] 281 | ) 282 | 283 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 284 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 285 | 286 | self.gradient_checkpointing = False 287 | 288 | @property 289 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 290 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 291 | r""" 292 | Returns: 293 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 294 | indexed by its weight name. 
295 | """ 296 | # set recursively 297 | processors = {} 298 | 299 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 300 | if hasattr(module, "get_processor"): 301 | processors[f"{name}.processor"] = module.get_processor() 302 | 303 | for sub_name, child in module.named_children(): 304 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 305 | 306 | return processors 307 | 308 | for name, module in self.named_children(): 309 | fn_recursive_add_processors(name, module, processors) 310 | 311 | return processors 312 | 313 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 314 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 315 | r""" 316 | Sets the attention processor to use to compute attention. 317 | 318 | Parameters: 319 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 320 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 321 | for **all** `Attention` layers. 322 | 323 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 324 | processor. This is strongly recommended when setting trainable attention processors. 325 | 326 | """ 327 | count = len(self.attn_processors.keys()) 328 | 329 | if isinstance(processor, dict) and len(processor) != count: 330 | raise ValueError( 331 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 332 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 333 | ) 334 | 335 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 336 | if hasattr(module, "set_processor"): 337 | if not isinstance(processor, dict): 338 | module.set_processor(processor) 339 | else: 340 | module.set_processor(processor.pop(f"{name}.processor")) 341 | 342 | for sub_name, child in module.named_children(): 343 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 344 | 345 | for name, module in self.named_children(): 346 | fn_recursive_attn_processor(name, module, processor) 347 | 348 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 349 | def fuse_qkv_projections(self): 350 | """ 351 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 352 | are fused. For cross-attention modules, key and value projection matrices are fused. 353 | 354 | 355 | 356 | This API is 🧪 experimental. 357 | 358 | 359 | """ 360 | self.original_attn_processors = None 361 | 362 | for _, attn_processor in self.attn_processors.items(): 363 | if "Added" in str(attn_processor.__class__.__name__): 364 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 365 | 366 | self.original_attn_processors = self.attn_processors 367 | 368 | for module in self.modules(): 369 | if isinstance(module, Attention): 370 | module.fuse_projections(fuse=True) 371 | 372 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 373 | 374 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 375 | def unfuse_qkv_projections(self): 376 | """Disables the fused QKV projection if enabled. 377 | 378 | 379 | 380 | This API is 🧪 experimental. 
381 | 382 | 383 | 384 | """ 385 | if self.original_attn_processors is not None: 386 | self.set_attn_processor(self.original_attn_processors) 387 | 388 | def _set_gradient_checkpointing(self, module, value=False): 389 | if hasattr(module, "gradient_checkpointing"): 390 | module.gradient_checkpointing = value 391 | 392 | def forward( 393 | self, 394 | hidden_states: torch.Tensor, 395 | encoder_hidden_states: torch.Tensor = None, 396 | pooled_projections: torch.Tensor = None, 397 | timestep: torch.LongTensor = None, 398 | img_ids: torch.Tensor = None, 399 | txt_ids: torch.Tensor = None, 400 | guidance: torch.Tensor = None, 401 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 402 | controlnet_block_samples=None, 403 | controlnet_single_block_samples=None, 404 | return_dict: bool = True, 405 | ip_token: torch.Tensor = None, #add param 406 | ip_token_ids: torch.Tensor = None, 407 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 408 | """ 409 | The [`FluxTransformer2DModel`] forward method. 410 | 411 | Args: 412 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 413 | Input `hidden_states`. 414 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 415 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 416 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 417 | from the embeddings of input conditions. 418 | timestep ( `torch.LongTensor`): 419 | Used to indicate denoising step. 420 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 421 | A list of tensors that if specified are added to the residuals of transformer blocks. 422 | joint_attention_kwargs (`dict`, *optional*): 423 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 424 | `self.processor` in 425 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 426 | return_dict (`bool`, *optional*, defaults to `True`): 427 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 428 | tuple. 429 | 430 | Returns: 431 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 432 | `tuple` where the first element is the sample tensor. 433 | """ 434 | # with open(f"mid_info.txt", 'a') as f: 435 | # f.write(f"{ip_token.abs().mean().item() ,{hidden_states.abs().mean().item() ,{encoder_hidden_states.abs().mean().item()}}}\n") 436 | 437 | if joint_attention_kwargs is not None: 438 | joint_attention_kwargs = joint_attention_kwargs.copy() 439 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 440 | else: 441 | lora_scale = 1.0 442 | 443 | if USE_PEFT_BACKEND: 444 | # weight the lora layers by setting `lora_scale` for each PEFT layer 445 | scale_lora_layers(self, lora_scale) 446 | else: 447 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 448 | logger.warning( 449 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
450 | ) 451 | hidden_states = self.x_embedder(hidden_states) 452 | 453 | timestep = timestep.to(hidden_states.dtype) * 1000 454 | if guidance is not None: 455 | guidance = guidance.to(hidden_states.dtype) * 1000 456 | else: 457 | guidance = None 458 | temb = ( 459 | self.time_text_embed(timestep, pooled_projections) 460 | if guidance is None 461 | else self.time_text_embed(timestep, guidance, pooled_projections) 462 | ) 463 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 464 | 465 | if txt_ids.ndim == 3: 466 | # logger.warning( 467 | # "Passing `txt_ids` 3d torch.Tensor is deprecated." 468 | # "Please remove the batch dimension and pass it as a 2d torch Tensor" 469 | # ) 470 | txt_ids = txt_ids[0] 471 | if img_ids.ndim == 3: 472 | # logger.warning( 473 | # "Passing `img_ids` 3d torch.Tensor is deprecated." 474 | # "Please remove the batch dimension and pass it as a 2d torch Tensor" 475 | # ) 476 | img_ids = img_ids[0] 477 | # print(f"transformers!!!") #dylee 478 | # print(f"txt_ids, img_ids, ip_token_ids: {txt_ids.shape}, {img_ids.shape}, {ip_token_ids.shape}") 479 | ids = torch.cat((txt_ids, img_ids, ip_token_ids), dim=0) 480 | image_rotary_emb = self.pos_embed(ids) 481 | 482 | # print(f"image_rotary_emb shape {image_rotary_emb[0].shape}") 483 | # print(f"ip_token shape is {ip_token.shape}") 484 | # print(f"hidden_states shape is {hidden_states.shape}") 485 | # print(f"encoder_hidden_states shape is {encoder_hidden_states.shape}") 486 | 487 | for index_block, block in enumerate(self.transformer_blocks): 488 | if self.training and self.gradient_checkpointing: 489 | 490 | def create_custom_forward(module, return_dict=None): 491 | def custom_forward(*inputs): 492 | if return_dict is not None: 493 | return module(*inputs, return_dict=return_dict) 494 | else: 495 | return module(*inputs) 496 | 497 | return custom_forward 498 | 499 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 500 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 501 | create_custom_forward(block), 502 | hidden_states, 503 | encoder_hidden_states, 504 | temb, 505 | image_rotary_emb, 506 | ip_token, 507 | **ckpt_kwargs, 508 | ) 509 | 510 | else: 511 | encoder_hidden_states, hidden_states = block( 512 | hidden_states=hidden_states, 513 | encoder_hidden_states=encoder_hidden_states, 514 | temb=temb, 515 | image_rotary_emb=image_rotary_emb, 516 | ip_token=ip_token, 517 | ) 518 | 519 | # controlnet residual 520 | if controlnet_block_samples is not None: 521 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 522 | interval_control = int(np.ceil(interval_control)) 523 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 524 | 525 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 526 | 527 | 528 | # ids_single = torch.cat((txt_ids, img_ids), dim=0) 529 | # image_rotary_emb_single = self.pos_embed(ids_single) #dylee 530 | for index_block, block in enumerate(self.single_transformer_blocks): 531 | if self.training and self.gradient_checkpointing: 532 | 533 | def create_custom_forward(module, return_dict=None): 534 | def custom_forward(*inputs): 535 | if return_dict is not None: 536 | return module(*inputs, return_dict=return_dict) 537 | else: 538 | return module(*inputs) 539 | 540 | return custom_forward 541 | 542 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 543 | hidden_states = 
torch.utils.checkpoint.checkpoint( 544 | create_custom_forward(block), 545 | hidden_states, 546 | temb, 547 | image_rotary_emb, #dylee 548 | ip_token, 549 | **ckpt_kwargs, 550 | ) 551 | 552 | else: 553 | hidden_states = block( 554 | hidden_states=hidden_states, 555 | temb=temb, 556 | image_rotary_emb=image_rotary_emb, #dylee 557 | ip_token=ip_token, 558 | ) 559 | 560 | # controlnet residual 561 | if controlnet_single_block_samples is not None: 562 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 563 | interval_control = int(np.ceil(interval_control)) 564 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 565 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 566 | + controlnet_single_block_samples[index_block // interval_control] 567 | ) 568 | 569 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 570 | 571 | hidden_states = self.norm_out(hidden_states, temb) 572 | output = self.proj_out(hidden_states) 573 | 574 | if USE_PEFT_BACKEND: 575 | # remove `lora_scale` from each PEFT layer 576 | unscale_lora_layers(self, lora_scale) 577 | 578 | if not return_dict: 579 | return (output,) 580 | 581 | return Transformer2DModelOutput(sample=output) 582 | -------------------------------------------------------------------------------- /src/customID/transformer_flux_ori.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Any, Dict, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from ...configuration_utils import ConfigMixin, register_to_config 24 | from ...loaders import FromOriginalModelMixin, PeftAdapterMixin 25 | from ...models.attention import FeedForward 26 | from ...models.attention_processor import ( 27 | Attention, 28 | AttentionProcessor, 29 | FluxAttnProcessor2_0, 30 | FusedFluxAttnProcessor2_0, 31 | ) 32 | from ...models.modeling_utils import ModelMixin 33 | from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 34 | from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 35 | from ...utils.torch_utils import maybe_allow_in_graph 36 | from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 37 | from ..modeling_outputs import Transformer2DModelOutput 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @maybe_allow_in_graph 44 | class FluxSingleTransformerBlock(nn.Module): 45 | r""" 46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 
47 | 48 | Reference: https://arxiv.org/abs/2403.03206 49 | 50 | Parameters: 51 | dim (`int`): The number of channels in the input and output. 52 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 53 | attention_head_dim (`int`): The number of channels in each head. 54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 55 | processing of `context` conditions. 56 | """ 57 | 58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 59 | super().__init__() 60 | self.mlp_hidden_dim = int(dim * mlp_ratio) 61 | 62 | self.norm = AdaLayerNormZeroSingle(dim) 63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 64 | self.act_mlp = nn.GELU(approximate="tanh") 65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 66 | 67 | processor = FluxAttnProcessor2_0() 68 | self.attn = Attention( 69 | query_dim=dim, 70 | cross_attention_dim=None, 71 | dim_head=attention_head_dim, 72 | heads=num_attention_heads, 73 | out_dim=dim, 74 | bias=True, 75 | processor=processor, 76 | qk_norm="rms_norm", 77 | eps=1e-6, 78 | pre_only=True, 79 | ) 80 | 81 | def forward( 82 | self, 83 | hidden_states: torch.FloatTensor, 84 | temb: torch.FloatTensor, 85 | image_rotary_emb=None, 86 | ): 87 | residual = hidden_states 88 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 89 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 90 | 91 | attn_output = self.attn( 92 | hidden_states=norm_hidden_states, 93 | image_rotary_emb=image_rotary_emb, 94 | ) 95 | 96 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 97 | gate = gate.unsqueeze(1) 98 | hidden_states = gate * self.proj_out(hidden_states) 99 | hidden_states = residual + hidden_states 100 | if hidden_states.dtype == torch.float16: 101 | hidden_states = hidden_states.clip(-65504, 65504) 102 | 103 | return hidden_states 104 | 105 | 106 | @maybe_allow_in_graph 107 | class FluxTransformerBlock(nn.Module): 108 | r""" 109 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 110 | 111 | Reference: https://arxiv.org/abs/2403.03206 112 | 113 | Parameters: 114 | dim (`int`): The number of channels in the input and output. 115 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 116 | attention_head_dim (`int`): The number of channels in each head. 117 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 118 | processing of `context` conditions. 119 | """ 120 | 121 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): 122 | super().__init__() 123 | 124 | self.norm1 = AdaLayerNormZero(dim) 125 | 126 | self.norm1_context = AdaLayerNormZero(dim) 127 | 128 | if hasattr(F, "scaled_dot_product_attention"): 129 | processor = FluxAttnProcessor2_0() 130 | else: 131 | raise ValueError( 132 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
133 | ) 134 | self.attn = Attention( 135 | query_dim=dim, 136 | cross_attention_dim=None, 137 | added_kv_proj_dim=dim, 138 | dim_head=attention_head_dim, 139 | heads=num_attention_heads, 140 | out_dim=dim, 141 | context_pre_only=False, 142 | bias=True, 143 | processor=processor, 144 | qk_norm=qk_norm, 145 | eps=eps, 146 | ) 147 | 148 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 149 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 150 | 151 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 152 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 153 | 154 | # let chunk size default to None 155 | self._chunk_size = None 156 | self._chunk_dim = 0 157 | 158 | def forward( 159 | self, 160 | hidden_states: torch.FloatTensor, 161 | encoder_hidden_states: torch.FloatTensor, 162 | temb: torch.FloatTensor, 163 | image_rotary_emb=None, 164 | ): 165 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 166 | 167 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 168 | encoder_hidden_states, emb=temb 169 | ) 170 | 171 | # Attention. 172 | attn_output, context_attn_output = self.attn( 173 | hidden_states=norm_hidden_states, 174 | encoder_hidden_states=norm_encoder_hidden_states, 175 | image_rotary_emb=image_rotary_emb, 176 | ) 177 | 178 | # Process attention outputs for the `hidden_states`. 179 | attn_output = gate_msa.unsqueeze(1) * attn_output 180 | hidden_states = hidden_states + attn_output 181 | 182 | norm_hidden_states = self.norm2(hidden_states) 183 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 184 | 185 | ff_output = self.ff(norm_hidden_states) 186 | ff_output = gate_mlp.unsqueeze(1) * ff_output 187 | 188 | hidden_states = hidden_states + ff_output 189 | 190 | # Process attention outputs for the `encoder_hidden_states`. 191 | 192 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 193 | encoder_hidden_states = encoder_hidden_states + context_attn_output 194 | 195 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 196 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 197 | 198 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 199 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 200 | if encoder_hidden_states.dtype == torch.float16: 201 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 202 | 203 | return encoder_hidden_states, hidden_states 204 | 205 | 206 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 207 | """ 208 | The Transformer model introduced in Flux. 209 | 210 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 211 | 212 | Parameters: 213 | patch_size (`int`): Patch size to turn the input data into small patches. 214 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 215 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 216 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 217 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
218 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 219 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 220 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 221 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 222 | """ 223 | 224 | _supports_gradient_checkpointing = True 225 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 226 | 227 | @register_to_config 228 | def __init__( 229 | self, 230 | patch_size: int = 1, 231 | in_channels: int = 64, 232 | num_layers: int = 19, 233 | num_single_layers: int = 38, 234 | attention_head_dim: int = 128, 235 | num_attention_heads: int = 24, 236 | joint_attention_dim: int = 4096, 237 | pooled_projection_dim: int = 768, 238 | guidance_embeds: bool = False, 239 | axes_dims_rope: Tuple[int] = (16, 56, 56), 240 | ): 241 | super().__init__() 242 | self.out_channels = in_channels 243 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 244 | 245 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 246 | 247 | text_time_guidance_cls = ( 248 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 249 | ) 250 | self.time_text_embed = text_time_guidance_cls( 251 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 252 | ) 253 | 254 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 255 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 256 | 257 | self.transformer_blocks = nn.ModuleList( 258 | [ 259 | FluxTransformerBlock( 260 | dim=self.inner_dim, 261 | num_attention_heads=self.config.num_attention_heads, 262 | attention_head_dim=self.config.attention_head_dim, 263 | ) 264 | for i in range(self.config.num_layers) 265 | ] 266 | ) 267 | 268 | self.single_transformer_blocks = nn.ModuleList( 269 | [ 270 | FluxSingleTransformerBlock( 271 | dim=self.inner_dim, 272 | num_attention_heads=self.config.num_attention_heads, 273 | attention_head_dim=self.config.attention_head_dim, 274 | ) 275 | for i in range(self.config.num_single_layers) 276 | ] 277 | ) 278 | 279 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 280 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 281 | 282 | self.gradient_checkpointing = False 283 | 284 | @property 285 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 286 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 287 | r""" 288 | Returns: 289 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 290 | indexed by its weight name. 
291 | """ 292 | # set recursively 293 | processors = {} 294 | 295 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 296 | if hasattr(module, "get_processor"): 297 | processors[f"{name}.processor"] = module.get_processor() 298 | 299 | for sub_name, child in module.named_children(): 300 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 301 | 302 | return processors 303 | 304 | for name, module in self.named_children(): 305 | fn_recursive_add_processors(name, module, processors) 306 | 307 | return processors 308 | 309 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 310 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 311 | r""" 312 | Sets the attention processor to use to compute attention. 313 | 314 | Parameters: 315 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 316 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 317 | for **all** `Attention` layers. 318 | 319 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 320 | processor. This is strongly recommended when setting trainable attention processors. 321 | 322 | """ 323 | count = len(self.attn_processors.keys()) 324 | 325 | if isinstance(processor, dict) and len(processor) != count: 326 | raise ValueError( 327 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 328 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 329 | ) 330 | 331 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 332 | if hasattr(module, "set_processor"): 333 | if not isinstance(processor, dict): 334 | module.set_processor(processor) 335 | else: 336 | module.set_processor(processor.pop(f"{name}.processor")) 337 | 338 | for sub_name, child in module.named_children(): 339 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 340 | 341 | for name, module in self.named_children(): 342 | fn_recursive_attn_processor(name, module, processor) 343 | 344 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 345 | def fuse_qkv_projections(self): 346 | """ 347 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 348 | are fused. For cross-attention modules, key and value projection matrices are fused. 349 | 350 | 351 | 352 | This API is 🧪 experimental. 353 | 354 | 355 | """ 356 | self.original_attn_processors = None 357 | 358 | for _, attn_processor in self.attn_processors.items(): 359 | if "Added" in str(attn_processor.__class__.__name__): 360 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 361 | 362 | self.original_attn_processors = self.attn_processors 363 | 364 | for module in self.modules(): 365 | if isinstance(module, Attention): 366 | module.fuse_projections(fuse=True) 367 | 368 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 369 | 370 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 371 | def unfuse_qkv_projections(self): 372 | """Disables the fused QKV projection if enabled. 373 | 374 | 375 | 376 | This API is 🧪 experimental. 
377 | 378 | 379 | 380 | """ 381 | if self.original_attn_processors is not None: 382 | self.set_attn_processor(self.original_attn_processors) 383 | 384 | def _set_gradient_checkpointing(self, module, value=False): 385 | if hasattr(module, "gradient_checkpointing"): 386 | module.gradient_checkpointing = value 387 | 388 | def forward( 389 | self, 390 | hidden_states: torch.Tensor, 391 | encoder_hidden_states: torch.Tensor = None, 392 | pooled_projections: torch.Tensor = None, 393 | timestep: torch.LongTensor = None, 394 | img_ids: torch.Tensor = None, 395 | txt_ids: torch.Tensor = None, 396 | guidance: torch.Tensor = None, 397 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 398 | controlnet_block_samples=None, 399 | controlnet_single_block_samples=None, 400 | return_dict: bool = True, 401 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 402 | """ 403 | The [`FluxTransformer2DModel`] forward method. 404 | 405 | Args: 406 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 407 | Input `hidden_states`. 408 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 409 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 410 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 411 | from the embeddings of input conditions. 412 | timestep ( `torch.LongTensor`): 413 | Used to indicate denoising step. 414 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 415 | A list of tensors that if specified are added to the residuals of transformer blocks. 416 | joint_attention_kwargs (`dict`, *optional*): 417 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 418 | `self.processor` in 419 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 420 | return_dict (`bool`, *optional*, defaults to `True`): 421 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 422 | tuple. 423 | 424 | Returns: 425 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 426 | `tuple` where the first element is the sample tensor. 427 | """ 428 | if joint_attention_kwargs is not None: 429 | joint_attention_kwargs = joint_attention_kwargs.copy() 430 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 431 | else: 432 | lora_scale = 1.0 433 | 434 | if USE_PEFT_BACKEND: 435 | # weight the lora layers by setting `lora_scale` for each PEFT layer 436 | scale_lora_layers(self, lora_scale) 437 | else: 438 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 439 | logger.warning( 440 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
441 | ) 442 | hidden_states = self.x_embedder(hidden_states) 443 | 444 | timestep = timestep.to(hidden_states.dtype) * 1000 445 | if guidance is not None: 446 | guidance = guidance.to(hidden_states.dtype) * 1000 447 | else: 448 | guidance = None 449 | temb = ( 450 | self.time_text_embed(timestep, pooled_projections) 451 | if guidance is None 452 | else self.time_text_embed(timestep, guidance, pooled_projections) 453 | ) 454 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 455 | 456 | if txt_ids.ndim == 3: 457 | logger.warning( 458 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 459 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 460 | ) 461 | txt_ids = txt_ids[0] 462 | if img_ids.ndim == 3: 463 | logger.warning( 464 | "Passing `img_ids` 3d torch.Tensor is deprecated." 465 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 466 | ) 467 | img_ids = img_ids[0] 468 | ids = torch.cat((txt_ids, img_ids), dim=0) 469 | image_rotary_emb = self.pos_embed(ids) 470 | 471 | for index_block, block in enumerate(self.transformer_blocks): 472 | if self.training and self.gradient_checkpointing: 473 | 474 | def create_custom_forward(module, return_dict=None): 475 | def custom_forward(*inputs): 476 | if return_dict is not None: 477 | return module(*inputs, return_dict=return_dict) 478 | else: 479 | return module(*inputs) 480 | 481 | return custom_forward 482 | 483 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 484 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 485 | create_custom_forward(block), 486 | hidden_states, 487 | encoder_hidden_states, 488 | temb, 489 | image_rotary_emb, 490 | **ckpt_kwargs, 491 | ) 492 | 493 | else: 494 | encoder_hidden_states, hidden_states = block( 495 | hidden_states=hidden_states, 496 | encoder_hidden_states=encoder_hidden_states, 497 | temb=temb, 498 | image_rotary_emb=image_rotary_emb, 499 | ) 500 | 501 | # controlnet residual 502 | if controlnet_block_samples is not None: 503 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 504 | interval_control = int(np.ceil(interval_control)) 505 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 506 | 507 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 508 | 509 | for index_block, block in enumerate(self.single_transformer_blocks): 510 | if self.training and self.gradient_checkpointing: 511 | 512 | def create_custom_forward(module, return_dict=None): 513 | def custom_forward(*inputs): 514 | if return_dict is not None: 515 | return module(*inputs, return_dict=return_dict) 516 | else: 517 | return module(*inputs) 518 | 519 | return custom_forward 520 | 521 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 522 | hidden_states = torch.utils.checkpoint.checkpoint( 523 | create_custom_forward(block), 524 | hidden_states, 525 | temb, 526 | image_rotary_emb, 527 | **ckpt_kwargs, 528 | ) 529 | 530 | else: 531 | hidden_states = block( 532 | hidden_states=hidden_states, 533 | temb=temb, 534 | image_rotary_emb=image_rotary_emb, 535 | ) 536 | 537 | # controlnet residual 538 | if controlnet_single_block_samples is not None: 539 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 540 | interval_control = int(np.ceil(interval_control)) 541 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
= ( 542 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 543 | + controlnet_single_block_samples[index_block // interval_control] 544 | ) 545 | 546 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 547 | 548 | hidden_states = self.norm_out(hidden_states, temb) 549 | output = self.proj_out(hidden_states) 550 | 551 | if USE_PEFT_BACKEND: 552 | # remove `lora_scale` from each PEFT layer 553 | unscale_lora_layers(self, lora_scale) 554 | 555 | if not return_dict: 556 | return (output,) 557 | 558 | return Transformer2DModelOutput(sample=output) 559 | -------------------------------------------------------------------------------- /src/customID/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from PIL import Image 5 | 6 | attn_maps = {} 7 | def hook_fn(name): 8 | def forward_hook(module, input, output): 9 | if hasattr(module.processor, "attn_map"): 10 | attn_maps[name] = module.processor.attn_map 11 | del module.processor.attn_map 12 | 13 | return forward_hook 14 | 15 | def register_cross_attention_hook(unet): 16 | for name, module in unet.named_modules(): 17 | if name.split('.')[-1].startswith('attn2'): 18 | module.register_forward_hook(hook_fn(name)) 19 | 20 | return unet 21 | 22 | def upscale(attn_map, target_size): 23 | attn_map = torch.mean(attn_map, dim=0) 24 | attn_map = attn_map.permute(1,0) 25 | temp_size = None 26 | 27 | for i in range(0,5): 28 | scale = 2 ** i 29 | if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: 30 | temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) 31 | break 32 | 33 | assert temp_size is not None, "temp_size cannot is None" 34 | 35 | attn_map = attn_map.view(attn_map.shape[0], *temp_size) 36 | 37 | attn_map = F.interpolate( 38 | attn_map.unsqueeze(0).to(dtype=torch.float32), 39 | size=target_size, 40 | mode='bilinear', 41 | align_corners=False 42 | )[0] 43 | 44 | attn_map = torch.softmax(attn_map, dim=0) 45 | return attn_map 46 | def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): 47 | 48 | idx = 0 if instance_or_negative else 1 49 | net_attn_maps = [] 50 | 51 | for name, attn_map in attn_maps.items(): 52 | attn_map = attn_map.cpu() if detach else attn_map 53 | attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() 54 | attn_map = upscale(attn_map, image_size) 55 | net_attn_maps.append(attn_map) 56 | 57 | net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) 58 | 59 | return net_attn_maps 60 | 61 | def attnmaps2images(net_attn_maps): 62 | 63 | #total_attn_scores = 0 64 | images = [] 65 | 66 | for attn_map in net_attn_maps: 67 | attn_map = attn_map.cpu().numpy() 68 | #total_attn_scores += attn_map.mean().item() 69 | 70 | normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 71 | normalized_attn_map = normalized_attn_map.astype(np.uint8) 72 | #print("norm: ", normalized_attn_map.shape) 73 | image = Image.fromarray(normalized_attn_map) 74 | 75 | #image = fix_save_attn_map(attn_map) 76 | images.append(image) 77 | 78 | #print(total_attn_scores) 79 | return images 80 | def is_torch2_available(): 81 | return hasattr(F, "scaled_dot_product_attention") 82 | 83 | def get_generator(seed, device): 84 | 85 | if seed is not None: 86 | if isinstance(seed, list): 87 | generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] 88 | else: 89 | generator = 
torch.Generator(device).manual_seed(seed) 90 | else: 91 | generator = None 92 | 93 | return generator -------------------------------------------------------------------------------- /src/utils/insightface_package.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # pip install insightface==0.7.3 3 | from insightface.app import FaceAnalysis 4 | from insightface.data import get_image as ins_get_image 5 | 6 | ### 7 | # https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543 8 | ### 9 | class FaceAnalysis2(FaceAnalysis): 10 | # NOTE: allows setting det_size for each detection call. 11 | # the model allows it but the wrapping code from insightface 12 | # doesn't show it, and people end up loading duplicate models 13 | # for different sizes where there is absolutely no need to 14 | def get(self, img, max_num=0, det_size=(640, 640)): 15 | if det_size is not None: 16 | self.det_model.input_size = det_size 17 | 18 | return super().get(img, max_num) 19 | 20 | def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)): 21 | # NOTE: try detect faces, if no faces detected, lower det_size until it does 22 | detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)] 23 | 24 | for size in detection_sizes: 25 | faces = face_analysis.get(img_data, det_size=size) 26 | if len(faces) > 0: 27 | return faces 28 | 29 | return [] 30 | --------------------------------------------------------------------------------
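
For reference, below is a minimal usage sketch of the helpers defined in `src/utils/insightface_package.py`. It is illustrative only: the import path, the `buffalo_l` model-pack name, the ONNX provider list, and the image path are assumptions to adapt to your environment, while `FaceAnalysis2`, `analyze_faces`, `prepare`, and `normed_embedding` come from the file above and from `insightface==0.7.3`.

```python
import cv2  # insightface expects a BGR uint8 array, which is what cv2.imread returns

# Assumes src/utils/ is on PYTHONPATH; adjust the import to your project layout.
from insightface_package import FaceAnalysis2, analyze_faces

# Build the detector/recognizer once. "buffalo_l" is insightface's default model
# pack; the provider list is an assumption -- drop CUDA if you run on CPU only.
face_app = FaceAnalysis2(
    name="buffalo_l",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
face_app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("path/to/your_face.jpg")  # placeholder path
faces = analyze_faces(face_app, img)       # retries with smaller det_size if no face is found

if faces:
    face = faces[0]                          # first detected face
    id_embedding = face.normed_embedding     # 512-d ArcFace identity embedding
    x1, y1, x2, y2 = face.bbox.astype(int)   # detection box, useful for cropping the ID image
```

The per-call `det_size` override in `FaceAnalysis2.get` is what lets `analyze_faces` walk down through smaller detection sizes with a single loaded model instead of instantiating one model per size.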