├── images ├── 0.jpg ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── flux1.jpg ├── flux2.jpg ├── flux3.jpg ├── alibaba.png ├── alimama.png └── alibabaalimama.png ├── main.py ├── readme.md ├── controlnet_flux.py ├── transformer_flux.py └── pipeline_flux_controlnet_inpaint.py /images/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/0.jpg -------------------------------------------------------------------------------- /images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/1.jpg -------------------------------------------------------------------------------- /images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/2.jpg -------------------------------------------------------------------------------- /images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/3.jpg -------------------------------------------------------------------------------- /images/flux1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/flux1.jpg -------------------------------------------------------------------------------- /images/flux2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/flux2.jpg -------------------------------------------------------------------------------- /images/flux3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/flux3.jpg -------------------------------------------------------------------------------- /images/alibaba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/alibaba.png -------------------------------------------------------------------------------- /images/alimama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/alimama.png -------------------------------------------------------------------------------- /images/alibabaalimama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/friendmine/FLUX-Controlnet-Inpainting/main/images/alibabaalimama.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils import load_image, check_min_version 3 | from controlnet_flux import FluxControlNetModel 4 | from transformer_flux import FluxTransformer2DModel 5 | from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline 6 | 7 | check_min_version("0.30.2") 8 | 9 | # Set image path , mask path and prompt 10 | 
image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png' 11 | mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg' 12 | prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it' 13 | 14 | # Build pipeline 15 | controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16) 16 | transformer = FluxTransformer2DModel.from_pretrained( 17 | "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16 18 | ) 19 | pipe = FluxControlNetInpaintingPipeline.from_pretrained( 20 | "black-forest-labs/FLUX.1-dev", 21 | controlnet=controlnet, 22 | transformer=transformer, 23 | torch_dtype=torch.bfloat16 24 | ).to("cuda") 25 | pipe.transformer.to(torch.bfloat16) 26 | pipe.controlnet.to(torch.bfloat16) 27 | 28 | # Load image and mask 29 | size = (768, 768) 30 | image = load_image(image_path).convert("RGB").resize(size) 31 | mask = load_image(mask_path).convert("RGB").resize(size) 32 | generator = torch.Generator(device="cuda").manual_seed(24) 33 | 34 | # Inpaint 35 | result = pipe( 36 | prompt=prompt, 37 | height=size[1], 38 | width=size[0], 39 | control_image=image, 40 | control_mask=mask, 41 | num_inference_steps=28, 42 | generator=generator, 43 | controlnet_conditioning_scale=0.9, 44 | guidance_scale=3.5, 45 | negative_prompt="", 46 | true_guidance_scale=3.5 47 | ).images[0] 48 | 49 | result.save('flux_inpaint.png') 50 | print("Successfully inpainted image") 51 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 
2 | ![alibaba](images/alibaba.png) 3 | 
4 | 5 | This repository provides an Inpainting ControlNet checkpoint for the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) model, released by researchers from the AlimamaCreative Team. 6 | 7 | ## News 8 | 9 | 🎉 Thanks to @comfyanonymous, ComfyUI now supports inference for the Alimama inpainting ControlNet. The workflow can be downloaded from [here](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/alimama-flux-controlnet-inpaint.json). 10 | 11 | ComfyUI Usage Tips: 12 | 13 | * Using the `t5xxl-FP16` and `flux1-dev-fp8` models for 28-step inference, the GPU memory usage is 27GB. The inference time with `cfg=3.5` is 27 seconds, while with `cfg=1` (no classifier-free guidance) it is 15 seconds. `Hyper-FLUX-lora` can be used to accelerate inference. 14 | * You can try adjusting (lowering) the parameters `control-strength`, `control-end-percent`, and `cfg` to achieve better results. 15 | * The following examples use `control-strength` = 0.9 & `control-end-percent` = 1.0 & `cfg` = 3.5 16 | 17 | | Input | Output | Prompt | 18 | |------------------------------|------------------------------|-------------| 19 | | ![Image1](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_1.png) | ![Image2](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_1.png) | The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Elon Musk, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal. | 20 | | ![Image3](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_2.png) | ![Image4](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_2.png) | The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cat on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. | 21 | | ![Image5](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_3.png) | ![Image6](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_3.png) | A woman with blonde hair is sitting on a table wearing a red and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene. | 22 | | ![Image7](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_4.png) | ![Image8](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_4.png) | The image depicts a beautiful young woman sitting at a desk, reading a book. 
She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a red pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits. | 23 | 24 | 25 | ## Model Cards 26 | 27 | 28 | 29 | The model weights have been uploaded to [Hugging Face](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha). 30 | 31 | 32 | * The model was trained on 12M images from LAION-2B and internal sources at a resolution of 768x768. Inference performs best at this resolution, with other sizes yielding suboptimal results. 33 | 34 | * The recommended `controlnet_conditioning_scale` is 0.9 - 0.95 (a minimal usage sketch with these settings is shown after the comparison section below). 35 | 36 | * **Please note: This is only an alpha version released during the training process. We will release an updated version when it is ready.** 37 | 38 | ## Showcase 39 | 40 | ![flux1](images/flux1.jpg) 41 | ![flux2](images/flux2.jpg) 42 | ![flux3](images/flux3.jpg) 43 | 44 | ## Comparison with SDXL-Inpainting 45 | 46 | The examples below compare our model with [SDXL-Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1). 47 | 48 | From left to right: Input image | Masked image | SDXL inpainting | Ours 49 | 50 | ![0](images/0.jpg) 51 | *The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.* 52 | 53 | ![0](images/1.jpg) 54 | A woman with blonde hair is sitting on a table wearing a blue and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene. 55 | 56 | ![0](images/2.jpg) 57 | The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cup of coffee on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. There are several cups and a cake on the table in the background. The man sitting at the table appears to be typing on the laptop. 58 | 59 | ![0](images/3.jpg) 60 | The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Naruto, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal. 
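For reference, the recommended settings above map directly onto the pipeline call in this repository's `main.py`. The sketch below is a condensed version of that script, not a separate API: the local `image.png` / `mask.png` paths and the prompt are placeholders, and it assumes `controlnet_flux.py`, `transformer_flux.py`, and `pipeline_flux_controlnet_inpaint.py` from this repo are importable.

```python
import torch
from diffusers.utils import load_image
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

# Build the pipeline: FLUX.1-dev base model + the alpha inpainting ControlNet.
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# The checkpoint was trained at 768x768; other resolutions tend to give weaker results.
size = (768, 768)
image = load_image("image.png").convert("RGB").resize(size)  # placeholder input image
mask = load_image("mask.png").convert("RGB").resize(size)    # placeholder inpainting mask

result = pipe(
    prompt="your prompt here",             # placeholder prompt
    height=size[1],
    width=size[0],
    control_image=image,
    control_mask=mask,
    num_inference_steps=28,
    controlnet_conditioning_scale=0.9,     # recommended range: 0.9 - 0.95
    guidance_scale=3.5,
    negative_prompt="",
    true_guidance_scale=3.5,
    generator=torch.Generator(device="cuda").manual_seed(24),
).images[0]
result.save("flux_inpaint.png")
```

As in `main.py`, all modules are kept in `bfloat16`; the full installation and command-line workflow is described in the next section.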
61 | 62 | ## Using with Diffusers 63 | Step1: install diffusers 64 | ``` Shell 65 | pip install diffusers==0.30.2 66 | ``` 67 | 68 | Step2: clone repo from github 69 | ``` Shell 70 | git clone https://github.com/alimama-creative/FLUX-Controlnet-Inpainting.git 71 | ``` 72 | 73 | Step3: modify the image_path, mask_path, prompt and run 74 | ``` Shell 75 | python main.py 76 | ``` 77 | ## LICENSE 78 | Our weights fall under the [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) Non-Commercial License. 79 | -------------------------------------------------------------------------------- /controlnet_flux.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.loaders import PeftAdapterMixin 9 | from diffusers.models.modeling_utils import ModelMixin 10 | from diffusers.models.attention_processor import AttentionProcessor 11 | from diffusers.utils import ( 12 | USE_PEFT_BACKEND, 13 | is_torch_version, 14 | logging, 15 | scale_lora_layers, 16 | unscale_lora_layers, 17 | ) 18 | from diffusers.models.controlnet import BaseOutput, zero_module 19 | from diffusers.models.embeddings import ( 20 | CombinedTimestepGuidanceTextProjEmbeddings, 21 | CombinedTimestepTextProjEmbeddings, 22 | ) 23 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 24 | from transformer_flux import ( 25 | EmbedND, 26 | FluxSingleTransformerBlock, 27 | FluxTransformerBlock, 28 | ) 29 | 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | @dataclass 35 | class FluxControlNetOutput(BaseOutput): 36 | controlnet_block_samples: Tuple[torch.Tensor] 37 | controlnet_single_block_samples: Tuple[torch.Tensor] 38 | 39 | 40 | class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): 41 | _supports_gradient_checkpointing = True 42 | 43 | @register_to_config 44 | def __init__( 45 | self, 46 | patch_size: int = 1, 47 | in_channels: int = 64, 48 | num_layers: int = 19, 49 | num_single_layers: int = 38, 50 | attention_head_dim: int = 128, 51 | num_attention_heads: int = 24, 52 | joint_attention_dim: int = 4096, 53 | pooled_projection_dim: int = 768, 54 | guidance_embeds: bool = False, 55 | axes_dims_rope: List[int] = [16, 56, 56], 56 | extra_condition_channels: int = 1 * 4, 57 | ): 58 | super().__init__() 59 | self.out_channels = in_channels 60 | self.inner_dim = num_attention_heads * attention_head_dim 61 | 62 | self.pos_embed = EmbedND( 63 | dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope 64 | ) 65 | text_time_guidance_cls = ( 66 | CombinedTimestepGuidanceTextProjEmbeddings 67 | if guidance_embeds 68 | else CombinedTimestepTextProjEmbeddings 69 | ) 70 | self.time_text_embed = text_time_guidance_cls( 71 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim 72 | ) 73 | 74 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) 75 | self.x_embedder = nn.Linear(in_channels, self.inner_dim) 76 | 77 | self.transformer_blocks = nn.ModuleList( 78 | [ 79 | FluxTransformerBlock( 80 | dim=self.inner_dim, 81 | num_attention_heads=num_attention_heads, 82 | attention_head_dim=attention_head_dim, 83 | ) 84 | for _ in range(num_layers) 85 | ] 86 | ) 87 | 88 | self.single_transformer_blocks = nn.ModuleList( 89 | [ 90 | 
FluxSingleTransformerBlock( 91 | dim=self.inner_dim, 92 | num_attention_heads=num_attention_heads, 93 | attention_head_dim=attention_head_dim, 94 | ) 95 | for _ in range(num_single_layers) 96 | ] 97 | ) 98 | 99 | # controlnet_blocks 100 | self.controlnet_blocks = nn.ModuleList([]) 101 | for _ in range(len(self.transformer_blocks)): 102 | self.controlnet_blocks.append( 103 | zero_module(nn.Linear(self.inner_dim, self.inner_dim)) 104 | ) 105 | 106 | self.controlnet_single_blocks = nn.ModuleList([]) 107 | for _ in range(len(self.single_transformer_blocks)): 108 | self.controlnet_single_blocks.append( 109 | zero_module(nn.Linear(self.inner_dim, self.inner_dim)) 110 | ) 111 | 112 | self.controlnet_x_embedder = zero_module( 113 | torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim) 114 | ) 115 | 116 | self.gradient_checkpointing = False 117 | 118 | @property 119 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 120 | def attn_processors(self): 121 | r""" 122 | Returns: 123 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 124 | indexed by its weight name. 125 | """ 126 | # set recursively 127 | processors = {} 128 | 129 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 130 | if hasattr(module, "get_processor"): 131 | processors[f"{name}.processor"] = module.get_processor() 132 | 133 | for sub_name, child in module.named_children(): 134 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 135 | 136 | return processors 137 | 138 | for name, module in self.named_children(): 139 | fn_recursive_add_processors(name, module, processors) 140 | 141 | return processors 142 | 143 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 144 | def set_attn_processor(self, processor): 145 | r""" 146 | Sets the attention processor to use to compute attention. 147 | 148 | Parameters: 149 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 150 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 151 | for **all** `Attention` layers. 152 | 153 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 154 | processor. This is strongly recommended when setting trainable attention processors. 155 | 156 | """ 157 | count = len(self.attn_processors.keys()) 158 | 159 | if isinstance(processor, dict) and len(processor) != count: 160 | raise ValueError( 161 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 162 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
163 | ) 164 | 165 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 166 | if hasattr(module, "set_processor"): 167 | if not isinstance(processor, dict): 168 | module.set_processor(processor) 169 | else: 170 | module.set_processor(processor.pop(f"{name}.processor")) 171 | 172 | for sub_name, child in module.named_children(): 173 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 174 | 175 | for name, module in self.named_children(): 176 | fn_recursive_attn_processor(name, module, processor) 177 | 178 | def _set_gradient_checkpointing(self, module, value=False): 179 | if hasattr(module, "gradient_checkpointing"): 180 | module.gradient_checkpointing = value 181 | 182 | @classmethod 183 | def from_transformer( 184 | cls, 185 | transformer, 186 | num_layers: int = 4, 187 | num_single_layers: int = 10, 188 | attention_head_dim: int = 128, 189 | num_attention_heads: int = 24, 190 | load_weights_from_transformer=True, 191 | ): 192 | config = transformer.config 193 | config["num_layers"] = num_layers 194 | config["num_single_layers"] = num_single_layers 195 | config["attention_head_dim"] = attention_head_dim 196 | config["num_attention_heads"] = num_attention_heads 197 | 198 | controlnet = cls(**config) 199 | 200 | if load_weights_from_transformer: 201 | controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) 202 | controlnet.time_text_embed.load_state_dict( 203 | transformer.time_text_embed.state_dict() 204 | ) 205 | controlnet.context_embedder.load_state_dict( 206 | transformer.context_embedder.state_dict() 207 | ) 208 | controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) 209 | controlnet.transformer_blocks.load_state_dict( 210 | transformer.transformer_blocks.state_dict(), strict=False 211 | ) 212 | controlnet.single_transformer_blocks.load_state_dict( 213 | transformer.single_transformer_blocks.state_dict(), strict=False 214 | ) 215 | 216 | controlnet.controlnet_x_embedder = zero_module( 217 | controlnet.controlnet_x_embedder 218 | ) 219 | 220 | return controlnet 221 | 222 | def forward( 223 | self, 224 | hidden_states: torch.Tensor, 225 | controlnet_cond: torch.Tensor, 226 | conditioning_scale: float = 1.0, 227 | encoder_hidden_states: torch.Tensor = None, 228 | pooled_projections: torch.Tensor = None, 229 | timestep: torch.LongTensor = None, 230 | img_ids: torch.Tensor = None, 231 | txt_ids: torch.Tensor = None, 232 | guidance: torch.Tensor = None, 233 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 234 | return_dict: bool = True, 235 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 236 | """ 237 | The [`FluxTransformer2DModel`] forward method. 238 | 239 | Args: 240 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 241 | Input `hidden_states`. 242 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 243 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 244 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 245 | from the embeddings of input conditions. 246 | timestep ( `torch.LongTensor`): 247 | Used to indicate denoising step. 248 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 249 | A list of tensors that if specified are added to the residuals of transformer blocks. 
250 | joint_attention_kwargs (`dict`, *optional*): 251 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 252 | `self.processor` in 253 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 254 | return_dict (`bool`, *optional*, defaults to `True`): 255 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 256 | tuple. 257 | 258 | Returns: 259 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 260 | `tuple` where the first element is the sample tensor. 261 | """ 262 | if joint_attention_kwargs is not None: 263 | joint_attention_kwargs = joint_attention_kwargs.copy() 264 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 265 | else: 266 | lora_scale = 1.0 267 | 268 | if USE_PEFT_BACKEND: 269 | # weight the lora layers by setting `lora_scale` for each PEFT layer 270 | scale_lora_layers(self, lora_scale) 271 | else: 272 | if ( 273 | joint_attention_kwargs is not None 274 | and joint_attention_kwargs.get("scale", None) is not None 275 | ): 276 | logger.warning( 277 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 278 | ) 279 | hidden_states = self.x_embedder(hidden_states) 280 | 281 | # add condition 282 | hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) 283 | 284 | timestep = timestep.to(hidden_states.dtype) * 1000 285 | if guidance is not None: 286 | guidance = guidance.to(hidden_states.dtype) * 1000 287 | else: 288 | guidance = None 289 | temb = ( 290 | self.time_text_embed(timestep, pooled_projections) 291 | if guidance is None 292 | else self.time_text_embed(timestep, guidance, pooled_projections) 293 | ) 294 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 295 | 296 | txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) 297 | ids = torch.cat((txt_ids, img_ids), dim=1) 298 | image_rotary_emb = self.pos_embed(ids) 299 | 300 | block_samples = () 301 | for _, block in enumerate(self.transformer_blocks): 302 | if self.training and self.gradient_checkpointing: 303 | 304 | def create_custom_forward(module, return_dict=None): 305 | def custom_forward(*inputs): 306 | if return_dict is not None: 307 | return module(*inputs, return_dict=return_dict) 308 | else: 309 | return module(*inputs) 310 | 311 | return custom_forward 312 | 313 | ckpt_kwargs: Dict[str, Any] = ( 314 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 315 | ) 316 | ( 317 | encoder_hidden_states, 318 | hidden_states, 319 | ) = torch.utils.checkpoint.checkpoint( 320 | create_custom_forward(block), 321 | hidden_states, 322 | encoder_hidden_states, 323 | temb, 324 | image_rotary_emb, 325 | **ckpt_kwargs, 326 | ) 327 | 328 | else: 329 | encoder_hidden_states, hidden_states = block( 330 | hidden_states=hidden_states, 331 | encoder_hidden_states=encoder_hidden_states, 332 | temb=temb, 333 | image_rotary_emb=image_rotary_emb, 334 | ) 335 | block_samples = block_samples + (hidden_states,) 336 | 337 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 338 | 339 | single_block_samples = () 340 | for _, block in enumerate(self.single_transformer_blocks): 341 | if self.training and self.gradient_checkpointing: 342 | 343 | def create_custom_forward(module, return_dict=None): 344 | def custom_forward(*inputs): 345 | if return_dict is not None: 346 | return module(*inputs, 
return_dict=return_dict) 347 | else: 348 | return module(*inputs) 349 | 350 | return custom_forward 351 | 352 | ckpt_kwargs: Dict[str, Any] = ( 353 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 354 | ) 355 | hidden_states = torch.utils.checkpoint.checkpoint( 356 | create_custom_forward(block), 357 | hidden_states, 358 | temb, 359 | image_rotary_emb, 360 | **ckpt_kwargs, 361 | ) 362 | 363 | else: 364 | hidden_states = block( 365 | hidden_states=hidden_states, 366 | temb=temb, 367 | image_rotary_emb=image_rotary_emb, 368 | ) 369 | single_block_samples = single_block_samples + ( 370 | hidden_states[:, encoder_hidden_states.shape[1] :], 371 | ) 372 | 373 | # controlnet block 374 | controlnet_block_samples = () 375 | for block_sample, controlnet_block in zip( 376 | block_samples, self.controlnet_blocks 377 | ): 378 | block_sample = controlnet_block(block_sample) 379 | controlnet_block_samples = controlnet_block_samples + (block_sample,) 380 | 381 | controlnet_single_block_samples = () 382 | for single_block_sample, controlnet_block in zip( 383 | single_block_samples, self.controlnet_single_blocks 384 | ): 385 | single_block_sample = controlnet_block(single_block_sample) 386 | controlnet_single_block_samples = controlnet_single_block_samples + ( 387 | single_block_sample, 388 | ) 389 | 390 | # scaling 391 | controlnet_block_samples = [ 392 | sample * conditioning_scale for sample in controlnet_block_samples 393 | ] 394 | controlnet_single_block_samples = [ 395 | sample * conditioning_scale for sample in controlnet_single_block_samples 396 | ] 397 | 398 | # 399 | controlnet_block_samples = ( 400 | None if len(controlnet_block_samples) == 0 else controlnet_block_samples 401 | ) 402 | controlnet_single_block_samples = ( 403 | None 404 | if len(controlnet_single_block_samples) == 0 405 | else controlnet_single_block_samples 406 | ) 407 | 408 | if USE_PEFT_BACKEND: 409 | # remove `lora_scale` from each PEFT layer 410 | unscale_lora_layers(self, lora_scale) 411 | 412 | if not return_dict: 413 | return (controlnet_block_samples, controlnet_single_block_samples) 414 | 415 | return FluxControlNetOutput( 416 | controlnet_block_samples=controlnet_block_samples, 417 | controlnet_single_block_samples=controlnet_single_block_samples, 418 | ) 419 | -------------------------------------------------------------------------------- /transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.attention import FeedForward 11 | from diffusers.models.attention_processor import ( 12 | Attention, 13 | FluxAttnProcessor2_0, 14 | FluxSingleAttnProcessor2_0, 15 | ) 16 | from diffusers.models.modeling_utils import ModelMixin 17 | from diffusers.models.normalization import ( 18 | AdaLayerNormContinuous, 19 | AdaLayerNormZero, 20 | AdaLayerNormZeroSingle, 21 | ) 22 | from diffusers.utils import ( 23 | USE_PEFT_BACKEND, 24 | is_torch_version, 25 | logging, 26 | scale_lora_layers, 27 | unscale_lora_layers, 28 | ) 29 | from diffusers.utils.torch_utils import maybe_allow_in_graph 30 | from diffusers.models.embeddings import ( 31 | CombinedTimestepGuidanceTextProjEmbeddings, 32 | CombinedTimestepTextProjEmbeddings, 33 | ) 34 
| from diffusers.models.modeling_outputs import Transformer2DModelOutput 35 | 36 | 37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 38 | 39 | 40 | # YiYi to-do: refactor rope related functions/classes 41 | def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: 42 | assert dim % 2 == 0, "The dimension must be even." 43 | 44 | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim 45 | omega = 1.0 / (theta**scale) 46 | 47 | batch_size, seq_length = pos.shape 48 | out = torch.einsum("...n,d->...nd", pos, omega) 49 | cos_out = torch.cos(out) 50 | sin_out = torch.sin(out) 51 | 52 | stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) 53 | out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) 54 | return out.float() 55 | 56 | 57 | # YiYi to-do: refactor rope related functions/classes 58 | class EmbedND(nn.Module): 59 | def __init__(self, dim: int, theta: int, axes_dim: List[int]): 60 | super().__init__() 61 | self.dim = dim 62 | self.theta = theta 63 | self.axes_dim = axes_dim 64 | 65 | def forward(self, ids: torch.Tensor) -> torch.Tensor: 66 | n_axes = ids.shape[-1] 67 | emb = torch.cat( 68 | [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], 69 | dim=-3, 70 | ) 71 | return emb.unsqueeze(1) 72 | 73 | 74 | @maybe_allow_in_graph 75 | class FluxSingleTransformerBlock(nn.Module): 76 | r""" 77 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 78 | 79 | Reference: https://arxiv.org/abs/2403.03206 80 | 81 | Parameters: 82 | dim (`int`): The number of channels in the input and output. 83 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 84 | attention_head_dim (`int`): The number of channels in each head. 85 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 86 | processing of `context` conditions. 
87 | """ 88 | 89 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 90 | super().__init__() 91 | self.mlp_hidden_dim = int(dim * mlp_ratio) 92 | 93 | self.norm = AdaLayerNormZeroSingle(dim) 94 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 95 | self.act_mlp = nn.GELU(approximate="tanh") 96 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 97 | 98 | processor = FluxSingleAttnProcessor2_0() 99 | self.attn = Attention( 100 | query_dim=dim, 101 | cross_attention_dim=None, 102 | dim_head=attention_head_dim, 103 | heads=num_attention_heads, 104 | out_dim=dim, 105 | bias=True, 106 | processor=processor, 107 | qk_norm="rms_norm", 108 | eps=1e-6, 109 | pre_only=True, 110 | ) 111 | 112 | def forward( 113 | self, 114 | hidden_states: torch.FloatTensor, 115 | temb: torch.FloatTensor, 116 | image_rotary_emb=None, 117 | ): 118 | residual = hidden_states 119 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 120 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 121 | 122 | attn_output = self.attn( 123 | hidden_states=norm_hidden_states, 124 | image_rotary_emb=image_rotary_emb, 125 | ) 126 | 127 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 128 | gate = gate.unsqueeze(1) 129 | hidden_states = gate * self.proj_out(hidden_states) 130 | hidden_states = residual + hidden_states 131 | if hidden_states.dtype == torch.float16: 132 | hidden_states = hidden_states.clip(-65504, 65504) 133 | 134 | return hidden_states 135 | 136 | 137 | @maybe_allow_in_graph 138 | class FluxTransformerBlock(nn.Module): 139 | r""" 140 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 141 | 142 | Reference: https://arxiv.org/abs/2403.03206 143 | 144 | Parameters: 145 | dim (`int`): The number of channels in the input and output. 146 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 147 | attention_head_dim (`int`): The number of channels in each head. 148 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 149 | processing of `context` conditions. 150 | """ 151 | 152 | def __init__( 153 | self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6 154 | ): 155 | super().__init__() 156 | 157 | self.norm1 = AdaLayerNormZero(dim) 158 | 159 | self.norm1_context = AdaLayerNormZero(dim) 160 | 161 | if hasattr(F, "scaled_dot_product_attention"): 162 | processor = FluxAttnProcessor2_0() 163 | else: 164 | raise ValueError( 165 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
166 | ) 167 | self.attn = Attention( 168 | query_dim=dim, 169 | cross_attention_dim=None, 170 | added_kv_proj_dim=dim, 171 | dim_head=attention_head_dim, 172 | heads=num_attention_heads, 173 | out_dim=dim, 174 | context_pre_only=False, 175 | bias=True, 176 | processor=processor, 177 | qk_norm=qk_norm, 178 | eps=eps, 179 | ) 180 | 181 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 182 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 183 | 184 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 185 | self.ff_context = FeedForward( 186 | dim=dim, dim_out=dim, activation_fn="gelu-approximate" 187 | ) 188 | 189 | # let chunk size default to None 190 | self._chunk_size = None 191 | self._chunk_dim = 0 192 | 193 | def forward( 194 | self, 195 | hidden_states: torch.FloatTensor, 196 | encoder_hidden_states: torch.FloatTensor, 197 | temb: torch.FloatTensor, 198 | image_rotary_emb=None, 199 | ): 200 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 201 | hidden_states, emb=temb 202 | ) 203 | 204 | ( 205 | norm_encoder_hidden_states, 206 | c_gate_msa, 207 | c_shift_mlp, 208 | c_scale_mlp, 209 | c_gate_mlp, 210 | ) = self.norm1_context(encoder_hidden_states, emb=temb) 211 | 212 | # Attention. 213 | attn_output, context_attn_output = self.attn( 214 | hidden_states=norm_hidden_states, 215 | encoder_hidden_states=norm_encoder_hidden_states, 216 | image_rotary_emb=image_rotary_emb, 217 | ) 218 | 219 | # Process attention outputs for the `hidden_states`. 220 | attn_output = gate_msa.unsqueeze(1) * attn_output 221 | hidden_states = hidden_states + attn_output 222 | 223 | norm_hidden_states = self.norm2(hidden_states) 224 | norm_hidden_states = ( 225 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 226 | ) 227 | 228 | ff_output = self.ff(norm_hidden_states) 229 | ff_output = gate_mlp.unsqueeze(1) * ff_output 230 | 231 | hidden_states = hidden_states + ff_output 232 | 233 | # Process attention outputs for the `encoder_hidden_states`. 234 | 235 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 236 | encoder_hidden_states = encoder_hidden_states + context_attn_output 237 | 238 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 239 | norm_encoder_hidden_states = ( 240 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) 241 | + c_shift_mlp[:, None] 242 | ) 243 | 244 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 245 | encoder_hidden_states = ( 246 | encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 247 | ) 248 | if encoder_hidden_states.dtype == torch.float16: 249 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 250 | 251 | return encoder_hidden_states, hidden_states 252 | 253 | 254 | class FluxTransformer2DModel( 255 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin 256 | ): 257 | """ 258 | The Transformer model introduced in Flux. 259 | 260 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 261 | 262 | Parameters: 263 | patch_size (`int`): Patch size to turn the input data into small patches. 264 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 265 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 266 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 
267 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 268 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 269 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 270 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 271 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 272 | """ 273 | 274 | _supports_gradient_checkpointing = True 275 | 276 | @register_to_config 277 | def __init__( 278 | self, 279 | patch_size: int = 1, 280 | in_channels: int = 64, 281 | num_layers: int = 19, 282 | num_single_layers: int = 38, 283 | attention_head_dim: int = 128, 284 | num_attention_heads: int = 24, 285 | joint_attention_dim: int = 4096, 286 | pooled_projection_dim: int = 768, 287 | guidance_embeds: bool = False, 288 | axes_dims_rope: List[int] = [16, 56, 56], 289 | ): 290 | super().__init__() 291 | self.out_channels = in_channels 292 | self.inner_dim = ( 293 | self.config.num_attention_heads * self.config.attention_head_dim 294 | ) 295 | 296 | self.pos_embed = EmbedND( 297 | dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope 298 | ) 299 | text_time_guidance_cls = ( 300 | CombinedTimestepGuidanceTextProjEmbeddings 301 | if guidance_embeds 302 | else CombinedTimestepTextProjEmbeddings 303 | ) 304 | self.time_text_embed = text_time_guidance_cls( 305 | embedding_dim=self.inner_dim, 306 | pooled_projection_dim=self.config.pooled_projection_dim, 307 | ) 308 | 309 | self.context_embedder = nn.Linear( 310 | self.config.joint_attention_dim, self.inner_dim 311 | ) 312 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 313 | 314 | self.transformer_blocks = nn.ModuleList( 315 | [ 316 | FluxTransformerBlock( 317 | dim=self.inner_dim, 318 | num_attention_heads=self.config.num_attention_heads, 319 | attention_head_dim=self.config.attention_head_dim, 320 | ) 321 | for i in range(self.config.num_layers) 322 | ] 323 | ) 324 | 325 | self.single_transformer_blocks = nn.ModuleList( 326 | [ 327 | FluxSingleTransformerBlock( 328 | dim=self.inner_dim, 329 | num_attention_heads=self.config.num_attention_heads, 330 | attention_head_dim=self.config.attention_head_dim, 331 | ) 332 | for i in range(self.config.num_single_layers) 333 | ] 334 | ) 335 | 336 | self.norm_out = AdaLayerNormContinuous( 337 | self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 338 | ) 339 | self.proj_out = nn.Linear( 340 | self.inner_dim, patch_size * patch_size * self.out_channels, bias=True 341 | ) 342 | 343 | self.gradient_checkpointing = False 344 | 345 | def _set_gradient_checkpointing(self, module, value=False): 346 | if hasattr(module, "gradient_checkpointing"): 347 | module.gradient_checkpointing = value 348 | 349 | def forward( 350 | self, 351 | hidden_states: torch.Tensor, 352 | encoder_hidden_states: torch.Tensor = None, 353 | pooled_projections: torch.Tensor = None, 354 | timestep: torch.LongTensor = None, 355 | img_ids: torch.Tensor = None, 356 | txt_ids: torch.Tensor = None, 357 | guidance: torch.Tensor = None, 358 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 359 | controlnet_block_samples=None, 360 | controlnet_single_block_samples=None, 361 | return_dict: bool = True, 362 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 363 | """ 364 | The [`FluxTransformer2DModel`] forward method. 
365 | 366 | Args: 367 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 368 | Input `hidden_states`. 369 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 370 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 371 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 372 | from the embeddings of input conditions. 373 | timestep ( `torch.LongTensor`): 374 | Used to indicate denoising step. 375 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 376 | A list of tensors that if specified are added to the residuals of transformer blocks. 377 | joint_attention_kwargs (`dict`, *optional*): 378 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 379 | `self.processor` in 380 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 381 | return_dict (`bool`, *optional*, defaults to `True`): 382 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 383 | tuple. 384 | 385 | Returns: 386 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 387 | `tuple` where the first element is the sample tensor. 388 | """ 389 | if joint_attention_kwargs is not None: 390 | joint_attention_kwargs = joint_attention_kwargs.copy() 391 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 392 | else: 393 | lora_scale = 1.0 394 | 395 | if USE_PEFT_BACKEND: 396 | # weight the lora layers by setting `lora_scale` for each PEFT layer 397 | scale_lora_layers(self, lora_scale) 398 | else: 399 | if ( 400 | joint_attention_kwargs is not None 401 | and joint_attention_kwargs.get("scale", None) is not None 402 | ): 403 | logger.warning( 404 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
405 | ) 406 | hidden_states = self.x_embedder(hidden_states) 407 | 408 | timestep = timestep.to(hidden_states.dtype) * 1000 409 | if guidance is not None: 410 | guidance = guidance.to(hidden_states.dtype) * 1000 411 | else: 412 | guidance = None 413 | temb = ( 414 | self.time_text_embed(timestep, pooled_projections) 415 | if guidance is None 416 | else self.time_text_embed(timestep, guidance, pooled_projections) 417 | ) 418 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 419 | 420 | txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) 421 | ids = torch.cat((txt_ids, img_ids), dim=1) 422 | image_rotary_emb = self.pos_embed(ids) 423 | 424 | for index_block, block in enumerate(self.transformer_blocks): 425 | if self.training and self.gradient_checkpointing: 426 | 427 | def create_custom_forward(module, return_dict=None): 428 | def custom_forward(*inputs): 429 | if return_dict is not None: 430 | return module(*inputs, return_dict=return_dict) 431 | else: 432 | return module(*inputs) 433 | 434 | return custom_forward 435 | 436 | ckpt_kwargs: Dict[str, Any] = ( 437 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 438 | ) 439 | ( 440 | encoder_hidden_states, 441 | hidden_states, 442 | ) = torch.utils.checkpoint.checkpoint( 443 | create_custom_forward(block), 444 | hidden_states, 445 | encoder_hidden_states, 446 | temb, 447 | image_rotary_emb, 448 | **ckpt_kwargs, 449 | ) 450 | 451 | else: 452 | encoder_hidden_states, hidden_states = block( 453 | hidden_states=hidden_states, 454 | encoder_hidden_states=encoder_hidden_states, 455 | temb=temb, 456 | image_rotary_emb=image_rotary_emb, 457 | ) 458 | 459 | # controlnet residual 460 | if controlnet_block_samples is not None: 461 | interval_control = len(self.transformer_blocks) / len( 462 | controlnet_block_samples 463 | ) 464 | interval_control = int(np.ceil(interval_control)) 465 | hidden_states = ( 466 | hidden_states 467 | + controlnet_block_samples[index_block // interval_control] 468 | ) 469 | 470 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 471 | 472 | for index_block, block in enumerate(self.single_transformer_blocks): 473 | if self.training and self.gradient_checkpointing: 474 | 475 | def create_custom_forward(module, return_dict=None): 476 | def custom_forward(*inputs): 477 | if return_dict is not None: 478 | return module(*inputs, return_dict=return_dict) 479 | else: 480 | return module(*inputs) 481 | 482 | return custom_forward 483 | 484 | ckpt_kwargs: Dict[str, Any] = ( 485 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 486 | ) 487 | hidden_states = torch.utils.checkpoint.checkpoint( 488 | create_custom_forward(block), 489 | hidden_states, 490 | temb, 491 | image_rotary_emb, 492 | **ckpt_kwargs, 493 | ) 494 | 495 | else: 496 | hidden_states = block( 497 | hidden_states=hidden_states, 498 | temb=temb, 499 | image_rotary_emb=image_rotary_emb, 500 | ) 501 | 502 | # controlnet residual 503 | if controlnet_single_block_samples is not None: 504 | interval_control = len(self.single_transformer_blocks) / len( 505 | controlnet_single_block_samples 506 | ) 507 | interval_control = int(np.ceil(interval_control)) 508 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 509 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 510 | + controlnet_single_block_samples[index_block // interval_control] 511 | ) 512 | 513 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
514 | 515 | hidden_states = self.norm_out(hidden_states, temb) 516 | output = self.proj_out(hidden_states) 517 | 518 | if USE_PEFT_BACKEND: 519 | # remove `lora_scale` from each PEFT layer 520 | unscale_lora_layers(self, lora_scale) 521 | 522 | if not return_dict: 523 | return (output,) 524 | 525 | return Transformer2DModelOutput(sample=output) 526 | -------------------------------------------------------------------------------- /pipeline_flux_controlnet_inpaint.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import ( 7 | CLIPTextModel, 8 | CLIPTokenizer, 9 | T5EncoderModel, 10 | T5TokenizerFast, 11 | ) 12 | 13 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 14 | from diffusers.loaders import FluxLoraLoaderMixin 15 | from diffusers.models.autoencoders import AutoencoderKL 16 | 17 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 18 | from diffusers.utils import ( 19 | USE_PEFT_BACKEND, 20 | is_torch_xla_available, 21 | logging, 22 | replace_example_docstring, 23 | scale_lora_layers, 24 | unscale_lora_layers, 25 | ) 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 28 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 29 | 30 | from transformer_flux import FluxTransformer2DModel 31 | from controlnet_flux import FluxControlNetModel 32 | 33 | if is_torch_xla_available(): 34 | import torch_xla.core.xla_model as xm 35 | 36 | XLA_AVAILABLE = True 37 | else: 38 | XLA_AVAILABLE = False 39 | 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | EXAMPLE_DOC_STRING = """ 44 | Examples: 45 | ```py 46 | >>> import torch 47 | >>> from diffusers.utils import load_image 48 | >>> from diffusers import FluxControlNetPipeline 49 | >>> from diffusers import FluxControlNetModel 50 | 51 | >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha" 52 | >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 53 | >>> pipe = FluxControlNetPipeline.from_pretrained( 54 | ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 55 | ... ) 56 | >>> pipe.to("cuda") 57 | >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 58 | >>> control_mask = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 59 | >>> prompt = "A girl in city, 25 years old, cool, futuristic" 60 | >>> image = pipe( 61 | ... prompt, 62 | ... control_image=control_image, 63 | ... controlnet_conditioning_scale=0.6, 64 | ... num_inference_steps=28, 65 | ... guidance_scale=3.5, 66 | ... 
).images[0] 67 | >>> image.save("flux.png") 68 | ``` 69 | """ 70 | 71 | 72 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 73 | def calculate_shift( 74 | image_seq_len, 75 | base_seq_len: int = 256, 76 | max_seq_len: int = 4096, 77 | base_shift: float = 0.5, 78 | max_shift: float = 1.16, 79 | ): 80 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 81 | b = base_shift - m * base_seq_len 82 | mu = image_seq_len * m + b 83 | return mu 84 | 85 | 86 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 87 | def retrieve_timesteps( 88 | scheduler, 89 | num_inference_steps: Optional[int] = None, 90 | device: Optional[Union[str, torch.device]] = None, 91 | timesteps: Optional[List[int]] = None, 92 | sigmas: Optional[List[float]] = None, 93 | **kwargs, 94 | ): 95 | """ 96 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 97 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 98 | 99 | Args: 100 | scheduler (`SchedulerMixin`): 101 | The scheduler to get timesteps from. 102 | num_inference_steps (`int`): 103 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 104 | must be `None`. 105 | device (`str` or `torch.device`, *optional*): 106 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 107 | timesteps (`List[int]`, *optional*): 108 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 109 | `num_inference_steps` and `sigmas` must be `None`. 110 | sigmas (`List[float]`, *optional*): 111 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 112 | `num_inference_steps` and `timesteps` must be `None`. 113 | 114 | Returns: 115 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 116 | second element is the number of inference steps. 117 | """ 118 | if timesteps is not None and sigmas is not None: 119 | raise ValueError( 120 | "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" 121 | ) 122 | if timesteps is not None: 123 | accepts_timesteps = "timesteps" in set( 124 | inspect.signature(scheduler.set_timesteps).parameters.keys() 125 | ) 126 | if not accepts_timesteps: 127 | raise ValueError( 128 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 129 | f" timestep schedules. Please check whether you are using the correct scheduler." 130 | ) 131 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 132 | timesteps = scheduler.timesteps 133 | num_inference_steps = len(timesteps) 134 | elif sigmas is not None: 135 | accept_sigmas = "sigmas" in set( 136 | inspect.signature(scheduler.set_timesteps).parameters.keys() 137 | ) 138 | if not accept_sigmas: 139 | raise ValueError( 140 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 141 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
142 | ) 143 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 144 | timesteps = scheduler.timesteps 145 | num_inference_steps = len(timesteps) 146 | else: 147 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 148 | timesteps = scheduler.timesteps 149 | return timesteps, num_inference_steps 150 | 151 | 152 | class FluxControlNetInpaintingPipeline(DiffusionPipeline, FluxLoraLoaderMixin): 153 | r""" 154 | The Flux pipeline for text-to-image generation. 155 | 156 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 157 | 158 | Args: 159 | transformer ([`FluxTransformer2DModel`]): 160 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 161 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 162 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 163 | vae ([`AutoencoderKL`]): 164 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 165 | text_encoder ([`CLIPTextModel`]): 166 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 167 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 168 | text_encoder_2 ([`T5EncoderModel`]): 169 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 170 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 171 | tokenizer (`CLIPTokenizer`): 172 | Tokenizer of class 173 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 174 | tokenizer_2 (`T5TokenizerFast`): 175 | Second Tokenizer of class 176 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
177 | """ 178 | 179 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 180 | _optional_components = [] 181 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 182 | 183 | def __init__( 184 | self, 185 | scheduler: FlowMatchEulerDiscreteScheduler, 186 | vae: AutoencoderKL, 187 | text_encoder: CLIPTextModel, 188 | tokenizer: CLIPTokenizer, 189 | text_encoder_2: T5EncoderModel, 190 | tokenizer_2: T5TokenizerFast, 191 | transformer: FluxTransformer2DModel, 192 | controlnet: FluxControlNetModel, 193 | ): 194 | super().__init__() 195 | 196 | self.register_modules( 197 | vae=vae, 198 | text_encoder=text_encoder, 199 | text_encoder_2=text_encoder_2, 200 | tokenizer=tokenizer, 201 | tokenizer_2=tokenizer_2, 202 | transformer=transformer, 203 | scheduler=scheduler, 204 | controlnet=controlnet, 205 | ) 206 | self.vae_scale_factor = ( 207 | 2 ** (len(self.vae.config.block_out_channels)) 208 | if hasattr(self, "vae") and self.vae is not None 209 | else 16 210 | ) 211 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True) 212 | self.mask_processor = VaeImageProcessor( 213 | vae_scale_factor=self.vae_scale_factor, 214 | do_resize=True, 215 | do_convert_grayscale=True, 216 | do_normalize=False, 217 | do_binarize=True, 218 | ) 219 | self.tokenizer_max_length = ( 220 | self.tokenizer.model_max_length 221 | if hasattr(self, "tokenizer") and self.tokenizer is not None 222 | else 77 223 | ) 224 | self.default_sample_size = 64 225 | 226 | @property 227 | def do_classifier_free_guidance(self): 228 | return self._guidance_scale > 1 229 | 230 | def _get_t5_prompt_embeds( 231 | self, 232 | prompt: Union[str, List[str]] = None, 233 | num_images_per_prompt: int = 1, 234 | max_sequence_length: int = 512, 235 | device: Optional[torch.device] = None, 236 | dtype: Optional[torch.dtype] = None, 237 | ): 238 | device = device or self._execution_device 239 | dtype = dtype or self.text_encoder.dtype 240 | 241 | prompt = [prompt] if isinstance(prompt, str) else prompt 242 | batch_size = len(prompt) 243 | 244 | text_inputs = self.tokenizer_2( 245 | prompt, 246 | padding="max_length", 247 | max_length=max_sequence_length, 248 | truncation=True, 249 | return_length=False, 250 | return_overflowing_tokens=False, 251 | return_tensors="pt", 252 | ) 253 | text_input_ids = text_inputs.input_ids 254 | untruncated_ids = self.tokenizer_2( 255 | prompt, padding="longest", return_tensors="pt" 256 | ).input_ids 257 | 258 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 259 | text_input_ids, untruncated_ids 260 | ): 261 | removed_text = self.tokenizer_2.batch_decode( 262 | untruncated_ids[:, self.tokenizer_max_length - 1 : -1] 263 | ) 264 | logger.warning( 265 | "The following part of your input was truncated because `max_sequence_length` is set to " 266 | f" {max_sequence_length} tokens: {removed_text}" 267 | ) 268 | 269 | prompt_embeds = self.text_encoder_2( 270 | text_input_ids.to(device), output_hidden_states=False 271 | )[0] 272 | 273 | dtype = self.text_encoder_2.dtype 274 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 275 | 276 | _, seq_len, _ = prompt_embeds.shape 277 | 278 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 279 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 280 | prompt_embeds = prompt_embeds.view( 281 | batch_size * num_images_per_prompt, seq_len, -1 282 | ) 283 | 284 | return 
prompt_embeds 285 | 286 | def _get_clip_prompt_embeds( 287 | self, 288 | prompt: Union[str, List[str]], 289 | num_images_per_prompt: int = 1, 290 | device: Optional[torch.device] = None, 291 | ): 292 | device = device or self._execution_device 293 | 294 | prompt = [prompt] if isinstance(prompt, str) else prompt 295 | batch_size = len(prompt) 296 | 297 | text_inputs = self.tokenizer( 298 | prompt, 299 | padding="max_length", 300 | max_length=self.tokenizer_max_length, 301 | truncation=True, 302 | return_overflowing_tokens=False, 303 | return_length=False, 304 | return_tensors="pt", 305 | ) 306 | 307 | text_input_ids = text_inputs.input_ids 308 | untruncated_ids = self.tokenizer( 309 | prompt, padding="longest", return_tensors="pt" 310 | ).input_ids 311 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 312 | text_input_ids, untruncated_ids 313 | ): 314 | removed_text = self.tokenizer.batch_decode( 315 | untruncated_ids[:, self.tokenizer_max_length - 1 : -1] 316 | ) 317 | logger.warning( 318 | "The following part of your input was truncated because CLIP can only handle sequences up to" 319 | f" {self.tokenizer_max_length} tokens: {removed_text}" 320 | ) 321 | prompt_embeds = self.text_encoder( 322 | text_input_ids.to(device), output_hidden_states=False 323 | ) 324 | 325 | # Use pooled output of CLIPTextModel 326 | prompt_embeds = prompt_embeds.pooler_output 327 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 328 | 329 | # duplicate text embeddings for each generation per prompt, using mps friendly method 330 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 331 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 332 | 333 | return prompt_embeds 334 | 335 | def encode_prompt( 336 | self, 337 | prompt: Union[str, List[str]], 338 | prompt_2: Union[str, List[str]], 339 | device: Optional[torch.device] = None, 340 | num_images_per_prompt: int = 1, 341 | do_classifier_free_guidance: bool = True, 342 | negative_prompt: Optional[Union[str, List[str]]] = None, 343 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 344 | prompt_embeds: Optional[torch.FloatTensor] = None, 345 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 346 | max_sequence_length: int = 512, 347 | lora_scale: Optional[float] = None, 348 | ): 349 | r""" 350 | 351 | Args: 352 | prompt (`str` or `List[str]`, *optional*): 353 | prompt to be encoded 354 | prompt_2 (`str` or `List[str]`, *optional*): 355 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 356 | used in all text-encoders 357 | device: (`torch.device`): 358 | torch device 359 | num_images_per_prompt (`int`): 360 | number of images that should be generated per prompt 361 | do_classifier_free_guidance (`bool`): 362 | whether to use classifier-free guidance or not 363 | negative_prompt (`str` or `List[str]`, *optional*): 364 | negative prompt to be encoded 365 | negative_prompt_2 (`str` or `List[str]`, *optional*): 366 | negative prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is 367 | used in all text-encoders 368 | prompt_embeds (`torch.FloatTensor`, *optional*): 369 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 370 | provided, text embeddings will be generated from `prompt` input argument. 
371 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 372 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 373 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 374 | max_sequence_length (`int`, *optional*, defaults to 512): 375 | Maximum sequence length to be used with the `prompt` by the T5 text encoder (`text_encoder_2`). Longer 376 | prompts are truncated to this length. 377 | lora_scale (`float`, *optional*): 378 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 379 | """ 380 | device = device or self._execution_device 381 | 382 | # set lora scale so that monkey patched LoRA 383 | # function of text encoder can correctly access it 384 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 385 | self._lora_scale = lora_scale 386 | 387 | # dynamically adjust the LoRA scale 388 | if self.text_encoder is not None and USE_PEFT_BACKEND: 389 | scale_lora_layers(self.text_encoder, lora_scale) 390 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 391 | scale_lora_layers(self.text_encoder_2, lora_scale) 392 | 393 | prompt = [prompt] if isinstance(prompt, str) else prompt 394 | if prompt is not None: 395 | batch_size = len(prompt) 396 | else: 397 | batch_size = prompt_embeds.shape[0] 398 | 399 | if prompt_embeds is None: 400 | prompt_2 = prompt_2 or prompt 401 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 402 | 403 | # We only use the pooled prompt output from the CLIPTextModel 404 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 405 | prompt=prompt, 406 | device=device, 407 | num_images_per_prompt=num_images_per_prompt, 408 | ) 409 | prompt_embeds = self._get_t5_prompt_embeds( 410 | prompt=prompt_2, 411 | num_images_per_prompt=num_images_per_prompt, 412 | max_sequence_length=max_sequence_length, 413 | device=device, 414 | ) 415 | 416 | if do_classifier_free_guidance: 417 | # Handle the negative prompt 418 | negative_prompt = negative_prompt or "" 419 | negative_prompt_2 = negative_prompt_2 or negative_prompt 420 | 421 | negative_pooled_prompt_embeds = self._get_clip_prompt_embeds( 422 | negative_prompt, 423 | device=device, 424 | num_images_per_prompt=num_images_per_prompt, 425 | ) 426 | negative_prompt_embeds = self._get_t5_prompt_embeds( 427 | negative_prompt_2, 428 | num_images_per_prompt=num_images_per_prompt, 429 | max_sequence_length=max_sequence_length, 430 | device=device, 431 | ) 432 | 433 | if self.text_encoder is not None: 434 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 435 | # Retrieve the original scale by scaling back the LoRA layers 436 | unscale_lora_layers(self.text_encoder, lora_scale) 437 | 438 | if self.text_encoder_2 is not None: 439 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 440 | # Retrieve the original scale by scaling back the LoRA layers 441 | unscale_lora_layers(self.text_encoder_2, lora_scale) 442 | 443 | text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to( 444 | device=device, dtype=self.text_encoder.dtype 445 | ) 446 | 447 | return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, text_ids 448 | 449 | def check_inputs( 450 | self, 451 | prompt, 452 | prompt_2, 453 | height, 454 | width, 455 | prompt_embeds=None, 456 | pooled_prompt_embeds=None, 457 | callback_on_step_end_tensor_inputs=None, 458 | max_sequence_length=None, 459 | ): 460 | if height % 8 != 
0 or width % 8 != 0: 461 | raise ValueError( 462 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 463 | ) 464 | 465 | if callback_on_step_end_tensor_inputs is not None and not all( 466 | k in self._callback_tensor_inputs 467 | for k in callback_on_step_end_tensor_inputs 468 | ): 469 | raise ValueError( 470 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 471 | ) 472 | 473 | if prompt is not None and prompt_embeds is not None: 474 | raise ValueError( 475 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 476 | " only forward one of the two." 477 | ) 478 | elif prompt_2 is not None and prompt_embeds is not None: 479 | raise ValueError( 480 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 481 | " only forward one of the two." 482 | ) 483 | elif prompt is None and prompt_embeds is None: 484 | raise ValueError( 485 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 486 | ) 487 | elif prompt is not None and ( 488 | not isinstance(prompt, str) and not isinstance(prompt, list) 489 | ): 490 | raise ValueError( 491 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 492 | ) 493 | elif prompt_2 is not None and ( 494 | not isinstance(prompt_2, str) and not isinstance(prompt_2, list) 495 | ): 496 | raise ValueError( 497 | f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}" 498 | ) 499 | 500 | if prompt_embeds is not None and pooled_prompt_embeds is None: 501 | raise ValueError( 502 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
503 | ) 504 | 505 | if max_sequence_length is not None and max_sequence_length > 512: 506 | raise ValueError( 507 | f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}" 508 | ) 509 | 510 | # Copied from diffusers.pipelines.flux.pipeline_flux._prepare_latent_image_ids 511 | @staticmethod 512 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 513 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 514 | latent_image_ids[..., 1] = ( 515 | latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 516 | ) 517 | latent_image_ids[..., 2] = ( 518 | latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 519 | ) 520 | 521 | ( 522 | latent_image_id_height, 523 | latent_image_id_width, 524 | latent_image_id_channels, 525 | ) = latent_image_ids.shape 526 | 527 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) 528 | latent_image_ids = latent_image_ids.reshape( 529 | batch_size, 530 | latent_image_id_height * latent_image_id_width, 531 | latent_image_id_channels, 532 | ) 533 | 534 | return latent_image_ids.to(device=device, dtype=dtype) 535 | 536 | # Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents 537 | @staticmethod 538 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 539 | latents = latents.view( 540 | batch_size, num_channels_latents, height // 2, 2, width // 2, 2 541 | ) 542 | latents = latents.permute(0, 2, 4, 1, 3, 5) 543 | latents = latents.reshape( 544 | batch_size, (height // 2) * (width // 2), num_channels_latents * 4 545 | ) 546 | 547 | return latents 548 | 549 | # Copied from diffusers.pipelines.flux.pipeline_flux._unpack_latents 550 | @staticmethod 551 | def _unpack_latents(latents, height, width, vae_scale_factor): 552 | batch_size, num_patches, channels = latents.shape 553 | 554 | height = height // vae_scale_factor 555 | width = width // vae_scale_factor 556 | 557 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 558 | latents = latents.permute(0, 3, 1, 4, 2, 5) 559 | 560 | latents = latents.reshape( 561 | batch_size, channels // (2 * 2), height * 2, width * 2 562 | ) 563 | 564 | return latents 565 | 566 | # Copied from diffusers.pipelines.flux.pipeline_flux.prepare_latents 567 | def prepare_latents( 568 | self, 569 | batch_size, 570 | num_channels_latents, 571 | height, 572 | width, 573 | dtype, 574 | device, 575 | generator, 576 | latents=None, 577 | ): 578 | height = 2 * (int(height) // self.vae_scale_factor) 579 | width = 2 * (int(width) // self.vae_scale_factor) 580 | 581 | shape = (batch_size, num_channels_latents, height, width) 582 | 583 | if latents is not None: 584 | latent_image_ids = self._prepare_latent_image_ids( 585 | batch_size, height, width, device, dtype 586 | ) 587 | return latents.to(device=device, dtype=dtype), latent_image_ids 588 | 589 | if isinstance(generator, list) and len(generator) != batch_size: 590 | raise ValueError( 591 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 592 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
593 | ) 594 | 595 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 596 | latents = self._pack_latents( 597 | latents, batch_size, num_channels_latents, height, width 598 | ) 599 | 600 | latent_image_ids = self._prepare_latent_image_ids( 601 | batch_size, height, width, device, dtype 602 | ) 603 | 604 | return latents, latent_image_ids 605 | 606 | # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image 607 | def prepare_image( 608 | self, 609 | image, 610 | width, 611 | height, 612 | batch_size, 613 | num_images_per_prompt, 614 | device, 615 | dtype, 616 | ): 617 | if isinstance(image, torch.Tensor): 618 | pass 619 | else: 620 | image = self.image_processor.preprocess(image, height=height, width=width) 621 | 622 | image_batch_size = image.shape[0] 623 | 624 | if image_batch_size == 1: 625 | repeat_by = batch_size 626 | else: 627 | # image batch size is the same as prompt batch size 628 | repeat_by = num_images_per_prompt 629 | 630 | image = image.repeat_interleave(repeat_by, dim=0) 631 | 632 | image = image.to(device=device, dtype=dtype) 633 | 634 | return image 635 | 636 | def prepare_image_with_mask( 637 | self, 638 | image, 639 | mask, 640 | width, 641 | height, 642 | batch_size, 643 | num_images_per_prompt, 644 | device, 645 | dtype, 646 | do_classifier_free_guidance = False, 647 | ): 648 | # Prepare image 649 | if isinstance(image, torch.Tensor): 650 | pass 651 | else: 652 | image = self.image_processor.preprocess(image, height=height, width=width) 653 | 654 | image_batch_size = image.shape[0] 655 | if image_batch_size == 1: 656 | repeat_by = batch_size 657 | else: 658 | # image batch size is the same as prompt batch size 659 | repeat_by = num_images_per_prompt 660 | image = image.repeat_interleave(repeat_by, dim=0) 661 | image = image.to(device=device, dtype=dtype) 662 | 663 | # Prepare mask 664 | if isinstance(mask, torch.Tensor): 665 | pass 666 | else: 667 | mask = self.mask_processor.preprocess(mask, height=height, width=width) 668 | mask = mask.repeat_interleave(repeat_by, dim=0) 669 | mask = mask.to(device=device, dtype=dtype) 670 | 671 | # Get masked image 672 | masked_image = image.clone() 673 | masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 674 | 675 | # Encode to latents 676 | image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample() 677 | image_latents = ( 678 | image_latents - self.vae.config.shift_factor 679 | ) * self.vae.config.scaling_factor 680 | image_latents = image_latents.to(dtype) 681 | 682 | mask = torch.nn.functional.interpolate( 683 | mask, size=(height // self.vae_scale_factor * 2, width // self.vae_scale_factor * 2) 684 | ) 685 | mask = 1 - mask 686 | 687 | control_image = torch.cat([image_latents, mask], dim=1) 688 | 689 | # Pack cond latents 690 | packed_control_image = self._pack_latents( 691 | control_image, 692 | batch_size * num_images_per_prompt, 693 | control_image.shape[1], 694 | control_image.shape[2], 695 | control_image.shape[3], 696 | ) 697 | 698 | if do_classifier_free_guidance: 699 | packed_control_image = torch.cat([packed_control_image] * 2) 700 | 701 | return packed_control_image, height, width 702 | 703 | @property 704 | def guidance_scale(self): 705 | return self._guidance_scale 706 | 707 | @property 708 | def joint_attention_kwargs(self): 709 | return self._joint_attention_kwargs 710 | 711 | @property 712 | def num_timesteps(self): 713 | return self._num_timesteps 714 | 715 | @property 716 | def interrupt(self): 717 | 
return self._interrupt 718 | 719 | @torch.no_grad() 720 | @replace_example_docstring(EXAMPLE_DOC_STRING) 721 | def __call__( 722 | self, 723 | prompt: Union[str, List[str]] = None, 724 | prompt_2: Optional[Union[str, List[str]]] = None, 725 | height: Optional[int] = None, 726 | width: Optional[int] = None, 727 | num_inference_steps: int = 28, 728 | timesteps: List[int] = None, 729 | guidance_scale: float = 7.0, 730 | true_guidance_scale: float = 3.5, 731 | negative_prompt: Optional[Union[str, List[str]]] = None, 732 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 733 | control_image: PipelineImageInput = None, 734 | control_mask: PipelineImageInput = None, 735 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 736 | num_images_per_prompt: Optional[int] = 1, 737 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 738 | latents: Optional[torch.FloatTensor] = None, 739 | prompt_embeds: Optional[torch.FloatTensor] = None, 740 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 741 | output_type: Optional[str] = "pil", 742 | return_dict: bool = True, 743 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 744 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 745 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 746 | max_sequence_length: int = 512, 747 | ): 748 | r""" 749 | Function invoked when calling the pipeline for generation. 750 | 751 | Args: 752 | prompt (`str` or `List[str]`, *optional*): 753 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` 754 | instead. 755 | prompt_2 (`str` or `List[str]`, *optional*): 756 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` 757 | will be used instead. 758 | height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): 759 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 760 | width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): 761 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 762 | num_inference_steps (`int`, *optional*, defaults to 28): 763 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 764 | expense of slower inference. 765 | timesteps (`List[int]`, *optional*): 766 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 767 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 768 | passed will be used. Must be in descending order. 769 | guidance_scale (`float`, *optional*, defaults to 7.0): 770 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 771 | For a transformer configured with `guidance_embeds` (as in FLUX.1-dev), this value is fed to the 772 | guidance embedding rather than used for classical classifier-free guidance. Higher values encourage 773 | images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 774 | For real classifier-free guidance with `negative_prompt`, set `true_guidance_scale` greater than 1. 775 | num_images_per_prompt (`int`, *optional*, defaults to 1): 776 | The number of images to generate per prompt. 
777 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 778 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 779 | to make generation deterministic. 780 | latents (`torch.FloatTensor`, *optional*): 781 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 782 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 783 | tensor will be generated by sampling using the supplied random `generator`. 784 | prompt_embeds (`torch.FloatTensor`, *optional*): 785 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 786 | provided, text embeddings will be generated from `prompt` input argument. 787 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 788 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 789 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 790 | output_type (`str`, *optional*, defaults to `"pil"`): 791 | The output format of the generated image. Choose between 792 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 793 | return_dict (`bool`, *optional*, defaults to `True`): 794 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 795 | joint_attention_kwargs (`dict`, *optional*): 796 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 797 | `self.processor` in 798 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 799 | callback_on_step_end (`Callable`, *optional*): 800 | A function that is called at the end of each denoising step during inference. The function is called 801 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 802 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 803 | `callback_on_step_end_tensor_inputs`. 804 | callback_on_step_end_tensor_inputs (`List`, *optional*): 805 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 806 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 807 | `._callback_tensor_inputs` attribute of your pipeline class. 808 | max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`. 809 | 810 | Examples: 811 | 812 | Returns: 813 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 814 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 815 | images. 816 | """ 817 | 818 | height = height or self.default_sample_size * self.vae_scale_factor 819 | width = width or self.default_sample_size * self.vae_scale_factor 820 | 821 | # 1. Check inputs. 
Raise error if not correct 822 | self.check_inputs( 823 | prompt, 824 | prompt_2, 825 | height, 826 | width, 827 | prompt_embeds=prompt_embeds, 828 | pooled_prompt_embeds=pooled_prompt_embeds, 829 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 830 | max_sequence_length=max_sequence_length, 831 | ) 832 | 833 | self._guidance_scale = true_guidance_scale 834 | self._joint_attention_kwargs = joint_attention_kwargs 835 | self._interrupt = False 836 | 837 | # 2. Define call parameters 838 | if prompt is not None and isinstance(prompt, str): 839 | batch_size = 1 840 | elif prompt is not None and isinstance(prompt, list): 841 | batch_size = len(prompt) 842 | else: 843 | batch_size = prompt_embeds.shape[0] 844 | 845 | device = self._execution_device 846 | dtype = self.transformer.dtype 847 | 848 | lora_scale = ( 849 | self.joint_attention_kwargs.get("scale", None) 850 | if self.joint_attention_kwargs is not None 851 | else None 852 | ) 853 | ( 854 | prompt_embeds, 855 | pooled_prompt_embeds, 856 | negative_prompt_embeds, 857 | negative_pooled_prompt_embeds, 858 | text_ids 859 | ) = self.encode_prompt( 860 | prompt=prompt, 861 | prompt_2=prompt_2, 862 | prompt_embeds=prompt_embeds, 863 | pooled_prompt_embeds=pooled_prompt_embeds, 864 | do_classifier_free_guidance = self.do_classifier_free_guidance, 865 | negative_prompt = negative_prompt, 866 | negative_prompt_2 = negative_prompt_2, 867 | device=device, 868 | num_images_per_prompt=num_images_per_prompt, 869 | max_sequence_length=max_sequence_length, 870 | lora_scale=lora_scale, 871 | ) 872 | 873 | # After encode_prompt, concatenate negative and positive embeddings for classifier-free guidance 874 | if self.do_classifier_free_guidance: 875 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim = 0) 876 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim = 0) 877 | text_ids = torch.cat([text_ids, text_ids], dim = 0) 878 | 879 | # 3. Prepare control image 880 | num_channels_latents = self.transformer.config.in_channels // 4 881 | if isinstance(self.controlnet, FluxControlNetModel): 882 | control_image, height, width = self.prepare_image_with_mask( 883 | image=control_image, 884 | mask=control_mask, 885 | width=width, 886 | height=height, 887 | batch_size=batch_size * num_images_per_prompt, 888 | num_images_per_prompt=num_images_per_prompt, 889 | device=device, 890 | dtype=dtype, 891 | do_classifier_free_guidance=self.do_classifier_free_guidance, 892 | ) 893 | 894 | # 4. Prepare latent variables 895 | num_channels_latents = self.transformer.config.in_channels // 4 896 | latents, latent_image_ids = self.prepare_latents( 897 | batch_size * num_images_per_prompt, 898 | num_channels_latents, 899 | height, 900 | width, 901 | prompt_embeds.dtype, 902 | device, 903 | generator, 904 | latents, 905 | ) 906 | 907 | if self.do_classifier_free_guidance: 908 | latent_image_ids = torch.cat([latent_image_ids] * 2) 909 | 910 | # 5. 
Prepare timesteps 911 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 912 | image_seq_len = latents.shape[1] 913 | mu = calculate_shift( 914 | image_seq_len, 915 | self.scheduler.config.base_image_seq_len, 916 | self.scheduler.config.max_image_seq_len, 917 | self.scheduler.config.base_shift, 918 | self.scheduler.config.max_shift, 919 | ) 920 | timesteps, num_inference_steps = retrieve_timesteps( 921 | self.scheduler, 922 | num_inference_steps, 923 | device, 924 | timesteps, 925 | sigmas, 926 | mu=mu, 927 | ) 928 | 929 | num_warmup_steps = max( 930 | len(timesteps) - num_inference_steps * self.scheduler.order, 0 931 | ) 932 | self._num_timesteps = len(timesteps) 933 | 934 | # 6. Denoising loop 935 | with self.progress_bar(total=num_inference_steps) as progress_bar: 936 | for i, t in enumerate(timesteps): 937 | if self.interrupt: 938 | continue 939 | 940 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 941 | 942 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 943 | timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) 944 | 945 | # handle guidance 946 | if self.transformer.config.guidance_embeds: 947 | guidance = torch.tensor([guidance_scale], device=device) 948 | guidance = guidance.expand(latent_model_input.shape[0]) 949 | else: 950 | guidance = None 951 | 952 | # controlnet 953 | ( 954 | controlnet_block_samples, 955 | controlnet_single_block_samples, 956 | ) = self.controlnet( 957 | hidden_states=latent_model_input, 958 | controlnet_cond=control_image, 959 | conditioning_scale=controlnet_conditioning_scale, 960 | timestep=timestep / 1000, 961 | guidance=guidance, 962 | pooled_projections=pooled_prompt_embeds, 963 | encoder_hidden_states=prompt_embeds, 964 | txt_ids=text_ids, 965 | img_ids=latent_image_ids, 966 | joint_attention_kwargs=self.joint_attention_kwargs, 967 | return_dict=False, 968 | ) 969 | 970 | noise_pred = self.transformer( 971 | hidden_states=latent_model_input, 972 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) 973 | timestep=timestep / 1000, 974 | guidance=guidance, 975 | pooled_projections=pooled_prompt_embeds, 976 | encoder_hidden_states=prompt_embeds, 977 | controlnet_block_samples=[ 978 | sample.to(dtype=self.transformer.dtype) 979 | for sample in controlnet_block_samples 980 | ], 981 | controlnet_single_block_samples=[ 982 | sample.to(dtype=self.transformer.dtype) 983 | for sample in controlnet_single_block_samples 984 | ] if controlnet_single_block_samples is not None else controlnet_single_block_samples, 985 | txt_ids=text_ids, 986 | img_ids=latent_image_ids, 987 | joint_attention_kwargs=self.joint_attention_kwargs, 988 | return_dict=False, 989 | )[0] 990 | 991 | # Apply true classifier-free guidance inside the denoising loop 992 | if self.do_classifier_free_guidance: 993 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 994 | noise_pred = noise_pred_uncond + true_guidance_scale * (noise_pred_text - noise_pred_uncond) 995 | 996 | # compute the previous noisy sample x_t -> x_t-1 997 | latents_dtype = latents.dtype 998 | latents = self.scheduler.step( 999 | noise_pred, t, latents, return_dict=False 1000 | )[0] 1001 | 1002 | if latents.dtype != latents_dtype: 1003 | if torch.backends.mps.is_available(): 1004 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 1005 | latents = latents.to(latents_dtype) 1006 | 1007 | if callback_on_step_end is not None: 1008 | callback_kwargs = {} 1009 | for k in callback_on_step_end_tensor_inputs: 1010 | callback_kwargs[k] = locals()[k] 1011 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 1012 | 1013 | latents = callback_outputs.pop("latents", latents) 1014 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 1015 | 1016 | # call the callback, if provided 1017 | if i == len(timesteps) - 1 or ( 1018 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 1019 | ): 1020 | progress_bar.update() 1021 | 1022 | if XLA_AVAILABLE: 1023 | xm.mark_step() 1024 | 1025 | if output_type == "latent": 1026 | image = latents 1027 | 1028 | else: 1029 | latents = self._unpack_latents( 1030 | latents, height, width, self.vae_scale_factor 1031 | ) 1032 | latents = ( 1033 | latents / self.vae.config.scaling_factor 1034 | ) + self.vae.config.shift_factor 1035 | latents = latents.to(self.vae.dtype) 1036 | 1037 | image = self.vae.decode(latents, return_dict=False)[0] 1038 | image = self.image_processor.postprocess(image, output_type=output_type) 1039 | 1040 | # Offload all models 1041 | self.maybe_free_model_hooks() 1042 | 1043 | if not return_dict: 1044 | return (image,) 1045 | 1046 | return FluxPipelineOutput(images=image) 1047 | --------------------------------------------------------------------------------
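For reference, the 2x2 patch packing performed by `_pack_latents` and `_unpack_latents` above (the same packing `prepare_image_with_mask` applies to the ControlNet conditioning) can be exercised in isolation. The snippet below is a minimal sketch, assuming only `torch`, a hypothetical 512x512 image, and this pipeline's `vae_scale_factor` of 16; the function names mirror the pipeline's static methods but are standalone.

import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 patch of latent pixels becomes one token
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

def unpack_latents(latents, pixel_height, pixel_width, vae_scale_factor):
    # Inverse of pack_latents; pixel_height/pixel_width are image-space sizes
    batch_size, num_patches, channels = latents.shape
    h = pixel_height // vae_scale_factor
    w = pixel_width // vae_scale_factor
    latents = latents.view(batch_size, h, w, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch_size, channels // 4, h * 2, w * 2)

latent = torch.randn(1, 16, 64, 64)           # 16 latent channels; 64 = 2 * (512 // 16)
tokens = pack_latents(latent, 1, 16, 64, 64)  # -> (1, 1024, 64): the token sequence the transformer and controlnet consume
assert tokens.shape == (1, 1024, 64)
assert torch.equal(unpack_latents(tokens, 512, 512, 16), latent)  # round trip recovers the original latents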