├── README.md
├── assets
│   ├── Arial_Unicode.ttf
│   ├── adam-jang-8pOTAtyd_Mc-unsplash.jpg
│   ├── comparison.pdf
│   ├── example1.png
│   ├── example2.png
│   ├── example3.png
│   ├── example4.png
│   ├── example5.png
│   ├── infer.png
│   ├── inpaint.png
│   ├── ipa.png
│   ├── train.png
│   └── union.png
├── controlnet_flux.py
├── infer.py
├── infer_inpaint.py
├── pipeline_flux_controlnet.py
├── pipeline_flux_controlnet_inpaint.py
└── results
    ├── result.jpg
    └── result_inpaint.jpg
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

RepText: Rendering Visual Text via Replicating

3 | 4 |
5 | Haofan Wang, 6 | Yujia Xu, 7 | Yimeng Li, 8 | Junchen Li, 9 | Chaowei Zhang, 10 | Jing Wang, 11 | Kejia Yang, 12 | Zhibo Chen 13 |
14 |
15 | Shakker Labs, Liblib AI
16 |

Corresponding author

17 |
18 | 19 | 20 | 21 | [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/Shakker-Labs/RepText) 22 | 23 |
24 | 25 | We present RepText, which aims to empower pre-trained monolingual text-to-image generation models with the ability to accurately render, or more precisely, replicate, multilingual visual text in user-specified fonts, without the need to truly understand the languages involved. Specifically, we adopt the ControlNet setting and additionally integrate language-agnostic glyph and position conditions for the rendered text to enable generating harmonized visual text, allowing users to customize text content, font, and position to their needs. To improve accuracy, a text perceptual loss is employed alongside the diffusion loss. Furthermore, to stabilize the rendering process, at the inference phase we initialize directly from a noisy glyph latent instead of random noise, and adopt region masks to restrict feature injection to the text region only, avoiding distortion of the background. We conducted extensive experiments to verify the effectiveness of RepText: our approach outperforms existing open-source methods and achieves results comparable to native multilingual closed-source models. 26 | 27 | 
28 | 29 |
30 | 31 | ## ⭐ Update 32 | - [2025/06/07] [Model Weights](https://huggingface.co/Shakker-Labs/RepText) and inference code released! 33 | - [2025/04/28] [Technical Report](https://arxiv.org/abs/2504.19724) released! 34 | 35 | ## Method 36 | 37 |
38 | 39 |
40 | 41 |
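The training code is not part of this release, but the text perceptual loss described above can be pictured as an auxiliary term added to the standard denoising objective. The sketch below is only an illustration under assumptions: `ocr_feature_extractor` (a frozen text recognizer used as a feature extractor) and the weight `lambda_text` are hypothetical placeholders, not components shipped in this repository.

```python
import torch
import torch.nn.functional as F

def reptext_training_loss(
    model_pred: torch.Tensor,          # model prediction for the current timestep
    target: torch.Tensor,              # diffusion / flow-matching target
    decoded_text_region: torch.Tensor, # decoded image crop around the rendered text, (B, 3, H, W)
    glyph_reference: torch.Tensor,     # clean rendered glyph crop, (B, 3, H, W)
    ocr_feature_extractor,             # hypothetical frozen recognizer: (B, 3, H, W) -> features
    lambda_text: float = 0.1,          # hypothetical weight for the perceptual term
) -> torch.Tensor:
    """Diffusion loss plus a text perceptual loss (illustrative sketch, not the released training code)."""
    # standard denoising objective on the model output
    diffusion_loss = F.mse_loss(model_pred, target)

    # text perceptual loss: match recognizer features of the generated text region
    # against features of the clean glyph rendering
    with torch.no_grad():
        reference_features = ocr_feature_extractor(glyph_reference)
    generated_features = ocr_feature_extractor(decoded_text_region)
    text_perceptual_loss = F.mse_loss(generated_features, reference_features)

    return diffusion_loss + lambda_text * text_perceptual_loss
```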
42 | 43 |
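At inference time, the pipeline initializes the latent from a noised encoding of the glyph image rather than pure noise, and injects ControlNet features only inside the text-region masks. The full logic lives in `pipeline_flux_controlnet.py`; the snippet below is a simplified, stand-alone sketch of those two ideas, assuming a Flux-style VAE, a `FlowMatchEulerDiscreteScheduler` whose timesteps have already been set, and a mask already resized to latent resolution (Flux latent packing is omitted for brevity).

```python
import torch

@torch.no_grad()
def init_latents_from_glyph(vae, scheduler, glyph_image: torch.Tensor,
                            start_index: int = 0, generator=None) -> torch.Tensor:
    """Sketch: start denoising from a noised glyph latent instead of pure random noise."""
    # encode the rendered glyph image (pixel values in [-1, 1]) into the VAE latent space
    glyph_latent = vae.encode(glyph_image).latent_dist.sample(generator)
    glyph_latent = (glyph_latent - vae.config.shift_factor) * vae.config.scaling_factor

    # noise the glyph latent to the starting timestep so the sampler can still refine it
    noise = torch.randn(glyph_latent.shape, generator=generator,
                        device=glyph_latent.device, dtype=glyph_latent.dtype)
    t = scheduler.timesteps[start_index].expand(glyph_latent.shape[0])
    return scheduler.scale_noise(glyph_latent, t, noise)


def masked_controlnet_injection(hidden_states: torch.Tensor,
                                controlnet_residual: torch.Tensor,
                                text_region_mask: torch.Tensor) -> torch.Tensor:
    """Sketch: add ControlNet residuals only inside the text region (mask is 1 on text, 0 elsewhere)."""
    # text_region_mask is assumed to broadcast against the hidden states, e.g. (B, seq_len, 1)
    return hidden_states + controlnet_residual * text_region_mask
```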
44 | 45 | ## Usage 46 | ```python 47 | import torch 48 | from controlnet_flux import FluxControlNetModel 49 | from pipeline_flux_controlnet import FluxControlNetPipeline 50 | 51 | from PIL import Image, ImageDraw, ImageFont 52 | import numpy as np 53 | import cv2 54 | import re 55 | import os 56 | 57 | def contains_chinese(text): 58 | if re.search(r'[\u4e00-\u9fff]', text): 59 | return True 60 | return False 61 | 62 | def canny(img): 63 | low_threshold = 50 64 | high_threshold = 100 65 | img = cv2.Canny(img, low_threshold, high_threshold) 66 | img = img[:, :, None] 67 | img = 255 - np.concatenate([img, img, img], axis=2) 68 | return img 69 | 70 | base_model = "black-forest-labs/FLUX.1-dev" 71 | controlnet_model = "Shakker-Labs/RepText" 72 | 73 | controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 74 | pipe = FluxControlNetPipeline.from_pretrained( 75 | base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 76 | ).to("cuda") 77 | 78 | ## set resolution 79 | width, height = 1024, 1024 80 | 81 | ## set font 82 | font_path = "./assets/Arial_Unicode.ttf" # use your own font 83 | font_size = 80 # it is recommended to use a font size >= 60 84 | font = ImageFont.truetype(font_path, font_size) 85 | 86 | ## set text content, position, color 87 | text_list = ["哩布哩布"] 88 | text_position_list = [(370, 200)] 89 | text_color_list = [(255, 255, 255)] 90 | 91 | ## set controlnet conditions 92 | control_image_list = [] # canny list 93 | control_position_list = [] # position list 94 | control_mask_list = [] # regional mask list 95 | control_glyph_all = np.zeros([height, width, 3], dtype=np.uint8) # all glyphs 96 | 97 | ## handle each line of text 98 | for text, text_position, text_color in zip(text_list, text_position_list, text_color_list): 99 | 100 | ### glyph image, render text to black background 101 | control_image_glyph = Image.new("RGB", (width, height), (0, 0, 0)) 102 | draw = ImageDraw.Draw(control_image_glyph) 103 | draw.text(text_position, text, font=font, fill=text_color) 104 | 105 | ### get bbox 106 | bbox = draw.textbbox(text_position, text, font=font) 107 | 108 | ### position condition 109 | control_position = np.zeros([height, width], dtype=np.uint8) 110 | control_position[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 255 111 | control_position = Image.fromarray(control_position.astype(np.uint8)) 112 | control_position_list.append(control_position) 113 | 114 | ### regional mask 115 | control_mask_np = np.zeros([height, width], dtype=np.uint8) 116 | control_mask_np[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5] = 255 117 | control_mask = Image.fromarray(control_mask_np.astype(np.uint8)) 118 | control_mask_list.append(control_mask) 119 | 120 | ### accumulate glyph 121 | control_glyph = np.array(control_image_glyph) 122 | control_glyph_all += control_glyph 123 | 124 | ### canny condition 125 | control_image = canny(cv2.cvtColor(np.array(control_image_glyph), cv2.COLOR_RGB2BGR)) 126 | control_image = Image.fromarray(cv2.cvtColor(control_image, cv2.COLOR_BGR2RGB)) 127 | control_image_list.append(control_image) 128 | 129 | control_glyph_all = Image.fromarray(control_glyph_all.astype(np.uint8)) 130 | control_glyph_all = control_glyph_all.convert("RGB") 131 | # control_glyph_all.save("./results/control_glyph.jpg") 132 | 133 | # it is recommended to use words such 'sign', 'billboard', 'banner' in your prompt 134 | # for Englith text, it helps if you add the text to the prompt 135 | prompt = "a street sign in city" 136 | for text in text_list: 137 | if not 
contains_chinese(text): 138 | prompt += f", '{text}'" 139 | prompt += ", filmfotos, film grain, reversal film photography" # optional 140 | print(prompt) 141 | 142 | generator = torch.Generator(device="cuda").manual_seed(42) 143 | 144 | image = pipe( 145 | prompt, 146 | control_image=control_image_list, # canny 147 | control_position=control_position_list, # position 148 | control_mask=control_mask_list, # regional mask 149 | control_glyph=control_glyph_all, # as init latent, optional, set to None if not used 150 | controlnet_conditioning_scale=1.0, 151 | controlnet_conditioning_step=30, 152 | width=width, 153 | height=height, 154 | num_inference_steps=30, 155 | guidance_scale=3.5, 156 | generator=generator, 157 | ).images[0] 158 | 159 | if not os.path.exists("./results"): 160 | os.makedirs("./results") 161 | image.save(f"./results/result.jpg") 162 | ``` 163 | 164 | For inpainting demo, 165 | 166 | ```python 167 | python infer_inpaint.py 168 | ``` 169 | 170 | 171 | 172 | ## Compatibility to Other Works 173 | - [FLUX.1-dev-ControlNet-Union-Pro-2.0](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0). 174 |
175 | 176 |
177 | 178 | - [FLUX.1-dev-Controlnet-Inpainting-Beta](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta). 179 |
180 | 181 |
182 | 183 | - [FLUX.1-dev-IP-Adapter](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter). 184 |
185 | 186 |
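For the inpainting combination listed above, the loading pattern follows `infer_inpaint.py` in this repository: the RepText ControlNet and the inpainting ControlNet are passed to the inpainting pipeline together, and the text-rendering conditions are supplied alongside the inpaint image and mask at call time.

```python
import torch
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetPipeline

controlnet = FluxControlNetModel.from_pretrained(
    "Shakker-Labs/RepText", torch_dtype=torch.bfloat16
)
controlnet_inpaint = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    controlnet_inpaint=controlnet_inpaint,
    torch_dtype=torch.bfloat16,
).to("cuda")
```

See `infer_inpaint.py` for the full demo, including how the canny, position, mask, and glyph conditions are built for the region to be inpainted.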
187 | 188 | ## Generated Samples 189 | 190 |
191 | 192 | 193 | 194 | 195 |
196 | 197 | ## 📑 Citation 198 | If you find RepText useful for your research and applications, please cite us using this BibTeX: 199 | ```bibtex 200 | @article{wang2025reptext, 201 | title={RepText: Rendering Visual Text via Replicating}, 202 | author={Wang, Haofan and Xu, Yujia and Li, Yimeng and Li, Junchen and Zhang, Chaowei and Wang, Jing and Yang, Kejia and Chen, Zhibo}, 203 | journal={arXiv preprint arXiv:2504.19724}, 204 | year={2025} 205 | } 206 | ``` 207 | 208 | ## 📧 Contact 209 | If you have any questions, please feel free to reach us at `haofanwang.ai@gmail.com`. 210 | -------------------------------------------------------------------------------- /assets/Arial_Unicode.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/Arial_Unicode.ttf -------------------------------------------------------------------------------- /assets/adam-jang-8pOTAtyd_Mc-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/adam-jang-8pOTAtyd_Mc-unsplash.jpg -------------------------------------------------------------------------------- /assets/comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/comparison.pdf -------------------------------------------------------------------------------- /assets/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/example1.png -------------------------------------------------------------------------------- /assets/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/example2.png -------------------------------------------------------------------------------- /assets/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/example3.png -------------------------------------------------------------------------------- /assets/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/example4.png -------------------------------------------------------------------------------- /assets/example5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/example5.png -------------------------------------------------------------------------------- /assets/infer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/infer.png -------------------------------------------------------------------------------- /assets/inpaint.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/inpaint.png -------------------------------------------------------------------------------- /assets/ipa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/ipa.png -------------------------------------------------------------------------------- /assets/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/train.png -------------------------------------------------------------------------------- /assets/union.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/assets/union.png -------------------------------------------------------------------------------- /controlnet_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from dataclasses import dataclass 16 | from typing import Any, Dict, List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import PeftAdapterMixin 23 | from diffusers.models.attention_processor import AttentionProcessor 24 | from diffusers.models.modeling_utils import ModelMixin 25 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 26 | from diffusers.models.controlnets.controlnet import zero_module 27 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 28 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 29 | from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock 30 | 31 | 32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 33 | 34 | 35 | @dataclass 36 | class FluxControlNetOutput(BaseOutput): 37 | controlnet_block_samples: Tuple[torch.Tensor] 38 | controlnet_single_block_samples: Tuple[torch.Tensor] 39 | 40 | 41 | class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): 42 | _supports_gradient_checkpointing = True 43 | 44 | @register_to_config 45 | def __init__( 46 | self, 47 | patch_size: int = 1, 48 | in_channels: int = 64, 49 | num_layers: int = 19, 50 | num_single_layers: int = 38, 51 | attention_head_dim: int = 128, 52 | num_attention_heads: int = 24, 53 | joint_attention_dim: int = 4096, 54 | pooled_projection_dim: int = 768, 55 | guidance_embeds: bool = False, 56 | axes_dims_rope: List[int] = [16, 56, 56], 57 | num_mode: int = None, 58 | extra_conditioning_channels: int = 0, 59 | extra_condition_channels: int = 0, 60 | ): 61 | super().__init__() 62 | self.out_channels = in_channels 63 | self.inner_dim = num_attention_heads * attention_head_dim 64 | 65 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 66 | text_time_guidance_cls = ( 67 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 68 | ) 69 | self.time_text_embed = text_time_guidance_cls( 70 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim 71 | ) 72 | 73 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) 74 | self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) 75 | 76 | self.transformer_blocks = nn.ModuleList( 77 | [ 78 | FluxTransformerBlock( 79 | dim=self.inner_dim, 80 | num_attention_heads=num_attention_heads, 81 | attention_head_dim=attention_head_dim, 82 | ) 83 | for i in range(num_layers) 84 | ] 85 | ) 86 | 87 | self.single_transformer_blocks = nn.ModuleList( 88 | [ 89 | FluxSingleTransformerBlock( 90 | dim=self.inner_dim, 91 | num_attention_heads=num_attention_heads, 92 | attention_head_dim=attention_head_dim, 93 | ) 94 | for i in range(num_single_layers) 95 | ] 96 | ) 97 | 98 | # controlnet_blocks 99 | self.controlnet_blocks = nn.ModuleList([]) 100 | for _ in range(len(self.transformer_blocks)): 101 | self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) 102 | 103 | self.controlnet_single_blocks = nn.ModuleList([]) 104 | for _ in range(len(self.single_transformer_blocks)): 105 | self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) 106 | 107 | self.union = num_mode is not None 108 | if self.union: 109 | 
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) 110 | 111 | # New Added! 112 | self.controlnet_x_embedder = zero_module( 113 | torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim) 114 | ) 115 | 116 | self.gradient_checkpointing = False 117 | 118 | @property 119 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 120 | def attn_processors(self): 121 | r""" 122 | Returns: 123 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 124 | indexed by its weight name. 125 | """ 126 | # set recursively 127 | processors = {} 128 | 129 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 130 | if hasattr(module, "get_processor"): 131 | processors[f"{name}.processor"] = module.get_processor() 132 | 133 | for sub_name, child in module.named_children(): 134 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 135 | 136 | return processors 137 | 138 | for name, module in self.named_children(): 139 | fn_recursive_add_processors(name, module, processors) 140 | 141 | return processors 142 | 143 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 144 | def set_attn_processor(self, processor): 145 | r""" 146 | Sets the attention processor to use to compute attention. 147 | 148 | Parameters: 149 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 150 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 151 | for **all** `Attention` layers. 152 | 153 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 154 | processor. This is strongly recommended when setting trainable attention processors. 155 | 156 | """ 157 | count = len(self.attn_processors.keys()) 158 | 159 | if isinstance(processor, dict) and len(processor) != count: 160 | raise ValueError( 161 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 162 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
163 | ) 164 | 165 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 166 | if hasattr(module, "set_processor"): 167 | if not isinstance(processor, dict): 168 | module.set_processor(processor) 169 | else: 170 | module.set_processor(processor.pop(f"{name}.processor")) 171 | 172 | for sub_name, child in module.named_children(): 173 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 174 | 175 | for name, module in self.named_children(): 176 | fn_recursive_attn_processor(name, module, processor) 177 | 178 | def _set_gradient_checkpointing(self, module, value=False): 179 | if hasattr(module, "gradient_checkpointing"): 180 | module.gradient_checkpointing = value 181 | 182 | @classmethod 183 | def from_transformer( 184 | cls, 185 | transformer, 186 | num_layers: int = 4, 187 | num_single_layers: int = 10, 188 | attention_head_dim: int = 128, 189 | num_attention_heads: int = 24, 190 | extra_condition_channels: int = 0, 191 | load_weights_from_transformer=True, 192 | ): 193 | config = transformer.config 194 | config["num_layers"] = num_layers 195 | config["num_single_layers"] = num_single_layers 196 | config["attention_head_dim"] = attention_head_dim 197 | config["num_attention_heads"] = num_attention_heads 198 | config["extra_condition_channels"] = extra_condition_channels 199 | 200 | controlnet = cls(**config) 201 | 202 | if load_weights_from_transformer: 203 | controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) 204 | controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) 205 | controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) 206 | controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) 207 | controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) 208 | controlnet.single_transformer_blocks.load_state_dict( 209 | transformer.single_transformer_blocks.state_dict(), strict=False 210 | ) 211 | 212 | controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) 213 | 214 | return controlnet 215 | 216 | def forward( 217 | self, 218 | hidden_states: torch.Tensor, 219 | controlnet_cond: torch.Tensor, 220 | controlnet_mode: torch.Tensor = None, 221 | conditioning_scale: float = 1.0, 222 | encoder_hidden_states: torch.Tensor = None, 223 | pooled_projections: torch.Tensor = None, 224 | timestep: torch.LongTensor = None, 225 | img_ids: torch.Tensor = None, 226 | txt_ids: torch.Tensor = None, 227 | guidance: torch.Tensor = None, 228 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 229 | return_dict: bool = True, 230 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 231 | """ 232 | The [`FluxTransformer2DModel`] forward method. 233 | 234 | Args: 235 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 236 | Input `hidden_states`. 237 | controlnet_cond (`torch.Tensor`): 238 | The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. 239 | controlnet_mode (`torch.Tensor`): 240 | The mode tensor of shape `(batch_size, 1)`. 241 | conditioning_scale (`float`, defaults to `1.0`): 242 | The scale factor for ControlNet outputs. 243 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 244 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
245 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 246 | from the embeddings of input conditions. 247 | timestep ( `torch.LongTensor`): 248 | Used to indicate denoising step. 249 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 250 | A list of tensors that if specified are added to the residuals of transformer blocks. 251 | joint_attention_kwargs (`dict`, *optional*): 252 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 253 | `self.processor` in 254 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 255 | return_dict (`bool`, *optional*, defaults to `True`): 256 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 257 | tuple. 258 | 259 | Returns: 260 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 261 | `tuple` where the first element is the sample tensor. 262 | """ 263 | if joint_attention_kwargs is not None: 264 | joint_attention_kwargs = joint_attention_kwargs.copy() 265 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 266 | else: 267 | lora_scale = 1.0 268 | 269 | if USE_PEFT_BACKEND: 270 | # weight the lora layers by setting `lora_scale` for each PEFT layer 271 | scale_lora_layers(self, lora_scale) 272 | else: 273 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 274 | logger.warning( 275 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 276 | ) 277 | hidden_states = self.x_embedder(hidden_states) 278 | 279 | # add 280 | hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) 281 | 282 | timestep = timestep.to(hidden_states.dtype) * 1000 283 | if guidance is not None: 284 | guidance = guidance.to(hidden_states.dtype) * 1000 285 | else: 286 | guidance = None 287 | temb = ( 288 | self.time_text_embed(timestep, pooled_projections) 289 | if guidance is None 290 | else self.time_text_embed(timestep, guidance, pooled_projections) 291 | ) 292 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 293 | 294 | if self.union: 295 | # union mode 296 | if controlnet_mode is None: 297 | raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") 298 | # union mode emb 299 | controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) 300 | encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) 301 | txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) 302 | 303 | if txt_ids.ndim == 3: 304 | logger.warning( 305 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 306 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 307 | ) 308 | txt_ids = txt_ids[0] 309 | if img_ids.ndim == 3: 310 | logger.warning( 311 | "Passing `img_ids` 3d torch.Tensor is deprecated." 
312 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 313 | ) 314 | img_ids = img_ids[0] 315 | 316 | ids = torch.cat((txt_ids, img_ids), dim=0) 317 | image_rotary_emb = self.pos_embed(ids) 318 | 319 | block_samples = () 320 | for index_block, block in enumerate(self.transformer_blocks): 321 | if self.training and self.gradient_checkpointing: 322 | 323 | def create_custom_forward(module, return_dict=None): 324 | def custom_forward(*inputs): 325 | if return_dict is not None: 326 | return module(*inputs, return_dict=return_dict) 327 | else: 328 | return module(*inputs) 329 | 330 | return custom_forward 331 | 332 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 333 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 334 | create_custom_forward(block), 335 | hidden_states, 336 | encoder_hidden_states, 337 | temb, 338 | image_rotary_emb, 339 | **ckpt_kwargs, 340 | ) 341 | 342 | else: 343 | encoder_hidden_states, hidden_states = block( 344 | hidden_states=hidden_states, 345 | encoder_hidden_states=encoder_hidden_states, 346 | temb=temb, 347 | image_rotary_emb=image_rotary_emb, 348 | ) 349 | block_samples = block_samples + (hidden_states,) 350 | 351 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 352 | 353 | single_block_samples = () 354 | for index_block, block in enumerate(self.single_transformer_blocks): 355 | if self.training and self.gradient_checkpointing: 356 | 357 | def create_custom_forward(module, return_dict=None): 358 | def custom_forward(*inputs): 359 | if return_dict is not None: 360 | return module(*inputs, return_dict=return_dict) 361 | else: 362 | return module(*inputs) 363 | 364 | return custom_forward 365 | 366 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 367 | hidden_states = torch.utils.checkpoint.checkpoint( 368 | create_custom_forward(block), 369 | hidden_states, 370 | temb, 371 | image_rotary_emb, 372 | **ckpt_kwargs, 373 | ) 374 | 375 | else: 376 | hidden_states = block( 377 | hidden_states=hidden_states, 378 | temb=temb, 379 | image_rotary_emb=image_rotary_emb, 380 | ) 381 | single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) 382 | 383 | # controlnet block 384 | controlnet_block_samples = () 385 | for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): 386 | block_sample = controlnet_block(block_sample) 387 | controlnet_block_samples = controlnet_block_samples + (block_sample,) 388 | 389 | controlnet_single_block_samples = () 390 | for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): 391 | single_block_sample = controlnet_block(single_block_sample) 392 | controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) 393 | 394 | # scaling 395 | controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] 396 | controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] 397 | 398 | controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples 399 | controlnet_single_block_samples = ( 400 | None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples 401 | ) 402 | 403 | if USE_PEFT_BACKEND: 404 | # remove `lora_scale` from each PEFT layer 405 | unscale_lora_layers(self, lora_scale) 406 | 407 | if not 
return_dict: 408 | return (controlnet_block_samples, controlnet_single_block_samples) 409 | 410 | return FluxControlNetOutput( 411 | controlnet_block_samples=controlnet_block_samples, 412 | controlnet_single_block_samples=controlnet_single_block_samples, 413 | ) 414 | 415 | 416 | class FluxMultiControlNetModel(ModelMixin): 417 | r""" 418 | `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel 419 | 420 | This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be 421 | compatible with `FluxControlNetModel`. 422 | 423 | Args: 424 | controlnets (`List[FluxControlNetModel]`): 425 | Provides additional conditioning to the unet during the denoising process. You must set multiple 426 | `FluxControlNetModel` as a list. 427 | """ 428 | 429 | def __init__(self, controlnets, union=False): 430 | super().__init__() 431 | self.nets = nn.ModuleList(controlnets) 432 | self.union = union 433 | 434 | def forward( 435 | self, 436 | hidden_states: torch.FloatTensor, 437 | controlnet_cond: List[torch.tensor], 438 | controlnet_mode: List[torch.tensor], 439 | conditioning_scale: List[float], 440 | encoder_hidden_states: torch.Tensor = None, 441 | pooled_projections: torch.Tensor = None, 442 | timestep: torch.LongTensor = None, 443 | img_ids: torch.Tensor = None, 444 | txt_ids: torch.Tensor = None, 445 | guidance: torch.Tensor = None, 446 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 447 | return_dict: bool = True, 448 | ) -> Union[FluxControlNetOutput, Tuple]: 449 | # ControlNet-Union with multiple conditions 450 | # only load one ControlNet for saving memories 451 | if len(self.nets) == 1: 452 | controlnet = self.nets[0] 453 | for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): 454 | 455 | block_samples, single_block_samples = controlnet( 456 | hidden_states=hidden_states, 457 | controlnet_cond=image, 458 | controlnet_mode=mode[:, None], 459 | conditioning_scale=scale, 460 | timestep=timestep, 461 | guidance=guidance, 462 | pooled_projections=pooled_projections, 463 | encoder_hidden_states=encoder_hidden_states, 464 | txt_ids=txt_ids, 465 | img_ids=img_ids, 466 | joint_attention_kwargs=joint_attention_kwargs, 467 | return_dict=return_dict, 468 | ) 469 | 470 | # merge samples 471 | if i == 0: 472 | control_block_samples = block_samples 473 | control_single_block_samples = single_block_samples 474 | else: 475 | if block_samples is not None: 476 | control_block_samples = [ 477 | control_block_sample + block_sample 478 | for control_block_sample, block_sample in zip(control_block_samples, block_samples) 479 | ] 480 | 481 | if single_block_samples is not None: 482 | control_single_block_samples = [ 483 | control_single_block_sample + block_sample 484 | for control_single_block_sample, block_sample in zip( 485 | control_single_block_samples, single_block_samples 486 | ) 487 | ] 488 | 489 | # Regular Multi-ControlNets 490 | # load all ControlNets into memories 491 | else: 492 | for i, (image, mode, scale, controlnet) in enumerate( 493 | zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) 494 | ): 495 | block_samples, single_block_samples = controlnet( 496 | hidden_states=hidden_states, 497 | controlnet_cond=image, 498 | controlnet_mode=mode[:, None], 499 | conditioning_scale=scale, 500 | timestep=timestep, 501 | guidance=guidance, 502 | pooled_projections=pooled_projections, 503 | encoder_hidden_states=encoder_hidden_states, 504 | txt_ids=txt_ids, 505 | 
img_ids=img_ids, 506 | joint_attention_kwargs=joint_attention_kwargs, 507 | return_dict=return_dict, 508 | ) 509 | 510 | # merge samples 511 | if i == 0: 512 | control_block_samples = block_samples 513 | control_single_block_samples = single_block_samples 514 | else: 515 | if block_samples is not None: 516 | control_block_samples = [ 517 | control_block_sample + block_sample 518 | for control_block_sample, block_sample in zip(control_block_samples, block_samples) 519 | ] 520 | 521 | if single_block_samples is not None: 522 | control_single_block_samples = [ 523 | control_single_block_sample + block_sample 524 | for control_single_block_sample, block_sample in zip( 525 | control_single_block_samples, single_block_samples 526 | ) 527 | ] 528 | 529 | return control_block_samples, control_single_block_samples 530 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from controlnet_flux import FluxControlNetModel 3 | from pipeline_flux_controlnet import FluxControlNetPipeline 4 | 5 | from PIL import Image, ImageDraw, ImageFont 6 | import numpy as np 7 | import cv2 8 | import re 9 | import os 10 | 11 | def contains_chinese(text): 12 | if re.search(r'[\u4e00-\u9fff]', text): 13 | return True 14 | return False 15 | 16 | def canny(img): 17 | low_threshold = 50 18 | high_threshold = 100 19 | img = cv2.Canny(img, low_threshold, high_threshold) 20 | img = img[:, :, None] 21 | img = 255 - np.concatenate([img, img, img], axis=2) 22 | return img 23 | 24 | 25 | if __name__ == "__main__": 26 | 27 | base_model = "black-forest-labs/FLUX.1-dev" 28 | controlnet_model = "Shakker-Labs/RepText" 29 | 30 | controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 31 | pipe = FluxControlNetPipeline.from_pretrained( 32 | base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 33 | ).to("cuda") 34 | 35 | ## set resolution 36 | width, height = 1024, 1024 37 | 38 | ## set font 39 | font_path = "./assets/Arial_Unicode.ttf" # use your own font 40 | font_size = 80 # it is recommended to use a font size >= 60 41 | font = ImageFont.truetype(font_path, font_size) 42 | 43 | ## set text content, position, color 44 | text_list = ["哩布哩布"] 45 | text_position_list = [(370, 200)] 46 | text_color_list = [(255, 255, 255)] 47 | 48 | # text_list = ["Shakker Labs"] 49 | # text_position_list = [(270, 300)] 50 | # text_color_list = [(255, 255, 255)] 51 | 52 | # text_list = ["Lovart AI", "Always Day 1"] 53 | # text_position_list = [(470, 300), (470, 400)] 54 | # text_color_list = [(255, 255, 255), (255, 255, 255)] 55 | 56 | # text_list = ["以往不谏", "来者可追"] 57 | # text_position_list = [(200, 200), (200, 300)] 58 | # text_color_list = [(255, 255, 255), (255, 255, 255)] 59 | 60 | # text_list = ["Shakker Labs", "RepText"] 61 | # text_position_list = [(200, 200), (200, 300)] 62 | # text_color_list = [(255, 255, 255), (255, 255, 255)] 63 | 64 | ## set controlnet conditions 65 | control_image_list = [] # canny list 66 | control_position_list = [] # position list 67 | control_mask_list = [] # regional mask list 68 | control_glyph_all = np.zeros([height, width, 3], dtype=np.uint8) # all glyphs 69 | 70 | ## handle each line of text 71 | for text, text_position, text_color in zip(text_list, text_position_list, text_color_list): 72 | 73 | ### glyph image, render text to black background 74 | control_image_glyph = Image.new("RGB", (width, height), (0, 0, 0)) 75 | draw = 
ImageDraw.Draw(control_image_glyph) 76 | draw.text(text_position, text, font=font, fill=text_color) 77 | 78 | ### get bbox 79 | bbox = draw.textbbox(text_position, text, font=font) 80 | 81 | ### position condition 82 | control_position = np.zeros([height, width], dtype=np.uint8) 83 | control_position[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 255 84 | control_position = Image.fromarray(control_position.astype(np.uint8)) 85 | control_position_list.append(control_position) 86 | 87 | ### regional mask 88 | control_mask_np = np.zeros([height, width], dtype=np.uint8) 89 | control_mask_np[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5] = 255 90 | control_mask = Image.fromarray(control_mask_np.astype(np.uint8)) 91 | control_mask_list.append(control_mask) 92 | 93 | ### accumulate glyph 94 | control_glyph = np.array(control_image_glyph) 95 | control_glyph_all += control_glyph 96 | 97 | ### canny condition 98 | control_image = canny(cv2.cvtColor(np.array(control_image_glyph), cv2.COLOR_RGB2BGR)) 99 | control_image = Image.fromarray(cv2.cvtColor(control_image, cv2.COLOR_BGR2RGB)) 100 | control_image_list.append(control_image) 101 | 102 | control_glyph_all = Image.fromarray(control_glyph_all.astype(np.uint8)) 103 | control_glyph_all = control_glyph_all.convert("RGB") 104 | # control_glyph_all.save("./results/control_glyph.jpg") 105 | 106 | # it is recommended to use words such 'sign', 'billboard', 'banner' in your prompt 107 | # for Englith text, it helps if you add the text to the prompt 108 | prompt = "a street sign in city" 109 | for text in text_list: 110 | if not contains_chinese(text): 111 | prompt += f", '{text}'" 112 | prompt += ", filmfotos, film grain, reversal film photography" # optional 113 | print(prompt) 114 | 115 | generator = torch.Generator(device="cuda").manual_seed(42) 116 | 117 | image = pipe( 118 | prompt, 119 | control_image=control_image_list, # canny 120 | control_position=control_position_list, # position 121 | control_mask=control_mask_list, # regional mask 122 | control_glyph=control_glyph_all, # as init latent, optional, set to None if not used 123 | controlnet_conditioning_scale=1.0, 124 | controlnet_conditioning_step=30, 125 | width=width, 126 | height=height, 127 | num_inference_steps=30, 128 | guidance_scale=3.5, 129 | generator=generator, 130 | ).images[0] 131 | 132 | if not os.path.exists("./results"): 133 | os.makedirs("./results") 134 | image.save(f"./results/result.jpg") -------------------------------------------------------------------------------- /infer_inpaint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils import load_image 3 | from controlnet_flux import FluxControlNetModel 4 | from pipeline_flux_controlnet_inpaint import FluxControlNetPipeline 5 | 6 | from PIL import Image, ImageDraw, ImageFont 7 | import numpy as np 8 | import cv2 9 | import re 10 | import os 11 | 12 | def contains_chinese(text): 13 | if re.search(r'[\u4e00-\u9fff]', text): 14 | return True 15 | return False 16 | 17 | def canny(img): 18 | low_threshold = 50 19 | high_threshold = 100 20 | img = cv2.Canny(img, low_threshold, high_threshold) 21 | img = img[:, :, None] 22 | img = 255 - np.concatenate([img, img, img], axis=2) 23 | return img 24 | 25 | def resize_img(input_image, max_side=1280, min_side=1024, size=None, 26 | pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): 27 | 28 | w, h = input_image.size 29 | if size is not None: 30 | w_resize_new, h_resize_new = size 31 | else: 32 | ratio = min_side / min(h, 
w) 33 | w, h = round(ratio*w), round(ratio*h) 34 | ratio = max_side / max(h, w) 35 | input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) 36 | w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number 37 | h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number 38 | input_image = input_image.resize([w_resize_new, h_resize_new], mode) 39 | 40 | if pad_to_max_side: 41 | res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 42 | offset_x = (max_side - w_resize_new) // 2 43 | offset_y = (max_side - h_resize_new) // 2 44 | res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) 45 | input_image = Image.fromarray(res) 46 | return input_image 47 | 48 | def extract_dwpose(img, include_body=True, include_hand=True, include_face=True): 49 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 50 | detected_map = dwprocessor(img, include_body=include_body, include_hand=include_hand, include_face=include_face) 51 | detected_map = Image.fromarray(detected_map) 52 | return detected_map 53 | 54 | if __name__ == "__main__": 55 | 56 | base_model = "black-forest-labs/FLUX.1-dev" 57 | controlnet_model = "Shakker-Labs/RepText" 58 | 59 | controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 60 | controlnet_inpaint = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16) 61 | 62 | pipe = FluxControlNetPipeline.from_pretrained( 63 | base_model, controlnet=controlnet, controlnet_inpaint=controlnet_inpaint, torch_dtype=torch.bfloat16 64 | ).to("cuda") 65 | 66 | control_image_inpaint = load_image("assets/adam-jang-8pOTAtyd_Mc-unsplash.jpg") 67 | control_image_inpaint = resize_img(control_image_inpaint) 68 | width, height = control_image_inpaint.size 69 | 70 | ## set font 71 | font_path = "./assets/Arial_Unicode.ttf" # use your own font 72 | font_size = 70 # it is recommended to use a font size >= 60 73 | font = ImageFont.truetype(font_path, font_size) 74 | 75 | ## set text content, position, color 76 | text_list = ["哩布哩布"] 77 | text_position_list = [(585, 375)] 78 | text_color_list = [(0, 255, 0)] 79 | 80 | ## set controlnet conditions 81 | control_image_list = [] # canny list 82 | control_position_list = [] # position list 83 | control_mask_list = [] # regional mask list 84 | control_glyph_all = np.zeros([height, width, 3], dtype=np.uint8) # all glyphs 85 | 86 | ## handle each line of text 87 | for text, text_position, text_color in zip(text_list, text_position_list, text_color_list): 88 | 89 | ### glyph image, render text to black background 90 | control_image_glyph = Image.new("RGB", (width, height), (0, 0, 0)) 91 | draw = ImageDraw.Draw(control_image_glyph) 92 | draw.text(text_position, text, font=font, fill=text_color) 93 | 94 | ### get bbox 95 | bbox = draw.textbbox(text_position, text, font=font) 96 | 97 | ### position condition 98 | control_position = np.zeros([height, width], dtype=np.uint8) 99 | control_position[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5] = 255 100 | control_position = Image.fromarray(control_position.astype(np.uint8)) 101 | control_position_list.append(control_position) 102 | 103 | ### regional mask 104 | control_mask_np = np.zeros([height, width], dtype=np.uint8) 105 | control_mask_np[bbox[1]-5:bbox[3]+5, bbox[0]-5:bbox[2]+5] = 255 106 | control_mask = Image.fromarray(control_mask_np.astype(np.uint8)) 107 | control_mask_list.append(control_mask) 108 | 109 | ### accumulate glyph 110 | 
control_glyph = np.array(control_image_glyph) 111 | control_glyph_all += control_glyph 112 | 113 | ### canny condition 114 | control_image = canny(cv2.cvtColor(np.array(control_image_glyph), cv2.COLOR_RGB2BGR)) 115 | control_image = Image.fromarray(cv2.cvtColor(control_image, cv2.COLOR_BGR2RGB)) 116 | control_image_list.append(control_image) 117 | 118 | control_glyph_all = Image.fromarray(control_glyph_all.astype(np.uint8)) 119 | control_glyph_all = control_glyph_all.convert("RGB") 120 | 121 | # it is recommended to use words such 'sign', 'billboard', 'banner' in your prompt 122 | # for Englith text, it helps if you add the text to the prompt 123 | prompt = "a street photo, wall" 124 | for text in text_list: 125 | if not contains_chinese(text): 126 | prompt += f", '{text}'" 127 | prompt += ", filmfotos, film grain, reversal film photography" # optional 128 | print(prompt) 129 | 130 | generator = torch.Generator(device="cuda").manual_seed(42) 131 | 132 | image = pipe( 133 | prompt, 134 | true_guidance_scale=3.5, # set 1.0 to disable negative guidance 135 | # for text rendering 136 | control_image=control_image_list, # canny 137 | control_position=control_position_list, # position 138 | control_mask=control_mask_list, # regional mask 139 | control_glyph=control_glyph_all, # as init latent, optional, set to None if not used 140 | controlnet_conditioning_scale=1.0, 141 | controlnet_conditioning_step=30, 142 | # for inpainting 143 | control_image_inpaint=control_image_inpaint, 144 | control_mask_inpaint=control_mask, 145 | controlnet_conditioning_scale_inpaint=1.0, 146 | width=width, 147 | height=height, 148 | num_inference_steps=30, 149 | guidance_scale=3.5, 150 | generator=generator, 151 | ).images[0] 152 | 153 | if not os.path.exists("./results"): 154 | os.makedirs("./results") 155 | image.save(f"./results/result_inpaint.jpg") -------------------------------------------------------------------------------- /pipeline_flux_controlnet.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import re 5 | import numpy as np 6 | import torch 7 | from transformers import ( 8 | CLIPTextModel, 9 | CLIPTokenizer, 10 | T5EncoderModel, 11 | T5TokenizerFast, 12 | ) 13 | 14 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 15 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin 16 | from diffusers.models.autoencoders import AutoencoderKL 17 | 18 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 19 | from diffusers.utils import ( 20 | USE_PEFT_BACKEND, 21 | is_torch_xla_available, 22 | logging, 23 | replace_example_docstring, 24 | scale_lora_layers, 25 | unscale_lora_layers, 26 | ) 27 | from diffusers.utils.torch_utils import randn_tensor 28 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 29 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 30 | 31 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel 32 | 33 | # New Added! 
34 | import torch.nn.functional as F 35 | import torchvision.transforms as transforms 36 | from controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel 37 | from PIL import Image 38 | 39 | if is_torch_xla_available(): 40 | import torch_xla.core.xla_model as xm 41 | 42 | XLA_AVAILABLE = True 43 | else: 44 | XLA_AVAILABLE = False 45 | 46 | 47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 48 | 49 | EXAMPLE_DOC_STRING = """ 50 | Examples: 51 | ```py 52 | >>> import torch 53 | >>> from diffusers.utils import load_image 54 | >>> from diffusers import FluxControlNetPipeline 55 | >>> from diffusers import FluxControlNetModel 56 | 57 | >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha" 58 | >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 59 | >>> pipe = FluxControlNetPipeline.from_pretrained( 60 | ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 61 | ... ) 62 | >>> pipe.to("cuda") 63 | >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 64 | >>> control_mask = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 65 | >>> prompt = "A girl in city, 25 years old, cool, futuristic" 66 | >>> image = pipe( 67 | ... prompt, 68 | ... control_image=control_image, 69 | ... controlnet_conditioning_scale=0.6, 70 | ... num_inference_steps=28, 71 | ... guidance_scale=3.5, 72 | ... ).images[0] 73 | >>> image.save("flux.png") 74 | ``` 75 | """ 76 | 77 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 78 | def calculate_shift( 79 | image_seq_len, 80 | base_seq_len: int = 256, 81 | max_seq_len: int = 4096, 82 | base_shift: float = 0.5, 83 | max_shift: float = 1.16, 84 | ): 85 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 86 | b = base_shift - m * base_seq_len 87 | mu = image_seq_len * m + b 88 | return mu 89 | 90 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 91 | def retrieve_latents( 92 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 93 | ): 94 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 95 | return encoder_output.latent_dist.sample(generator) 96 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 97 | return encoder_output.latent_dist.mode() 98 | elif hasattr(encoder_output, "latents"): 99 | return encoder_output.latents 100 | else: 101 | raise AttributeError("Could not access latents of provided encoder_output") 102 | 103 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 104 | def retrieve_timesteps( 105 | scheduler, 106 | num_inference_steps: Optional[int] = None, 107 | device: Optional[Union[str, torch.device]] = None, 108 | timesteps: Optional[List[int]] = None, 109 | sigmas: Optional[List[float]] = None, 110 | **kwargs, 111 | ): 112 | """ 113 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 114 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 115 | 116 | Args: 117 | scheduler (`SchedulerMixin`): 118 | The scheduler to get timesteps from. 119 | num_inference_steps (`int`): 120 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 121 | must be `None`. 
122 | device (`str` or `torch.device`, *optional*): 123 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 124 | timesteps (`List[int]`, *optional*): 125 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 126 | `num_inference_steps` and `sigmas` must be `None`. 127 | sigmas (`List[float]`, *optional*): 128 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 129 | `num_inference_steps` and `timesteps` must be `None`. 130 | 131 | Returns: 132 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 133 | second element is the number of inference steps. 134 | """ 135 | if timesteps is not None and sigmas is not None: 136 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 137 | if timesteps is not None: 138 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 139 | if not accepts_timesteps: 140 | raise ValueError( 141 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 142 | f" timestep schedules. Please check whether you are using the correct scheduler." 143 | ) 144 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 145 | timesteps = scheduler.timesteps 146 | num_inference_steps = len(timesteps) 147 | elif sigmas is not None: 148 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 149 | if not accept_sigmas: 150 | raise ValueError( 151 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 152 | f" sigmas schedules. Please check whether you are using the correct scheduler." 153 | ) 154 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 155 | timesteps = scheduler.timesteps 156 | num_inference_steps = len(timesteps) 157 | else: 158 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 159 | timesteps = scheduler.timesteps 160 | return timesteps, num_inference_steps 161 | 162 | 163 | class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): 164 | r""" 165 | The Flux pipeline for text-to-image generation. 166 | 167 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 168 | 169 | Args: 170 | transformer ([`FluxTransformer2DModel`]): 171 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 172 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 173 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 174 | vae ([`AutoencoderKL`]): 175 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 176 | text_encoder ([`CLIPTextModel`]): 177 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 178 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 179 | text_encoder_2 ([`T5EncoderModel`]): 180 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 181 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
182 | tokenizer (`CLIPTokenizer`): 183 | Tokenizer of class 184 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 185 | tokenizer_2 (`T5TokenizerFast`): 186 | Second Tokenizer of class 187 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 188 | """ 189 | 190 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 191 | _optional_components = [] 192 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 193 | 194 | def __init__( 195 | self, 196 | scheduler: FlowMatchEulerDiscreteScheduler, 197 | vae: AutoencoderKL, 198 | text_encoder: CLIPTextModel, 199 | tokenizer: CLIPTokenizer, 200 | text_encoder_2: T5EncoderModel, 201 | tokenizer_2: T5TokenizerFast, 202 | transformer: FluxTransformer2DModel, 203 | controlnet: Union[ 204 | FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel 205 | ], 206 | ): 207 | super().__init__() 208 | 209 | self.register_modules( 210 | vae=vae, 211 | text_encoder=text_encoder, 212 | text_encoder_2=text_encoder_2, 213 | tokenizer=tokenizer, 214 | tokenizer_2=tokenizer_2, 215 | transformer=transformer, 216 | scheduler=scheduler, 217 | controlnet=controlnet, 218 | ) 219 | self.vae_scale_factor = ( 220 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 221 | ) 222 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 223 | self.tokenizer_max_length = ( 224 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 225 | ) 226 | self.default_sample_size = 64 227 | 228 | @property 229 | def do_classifier_free_guidance(self): 230 | return self._guidance_scale > 1 231 | 232 | def _get_t5_prompt_embeds( 233 | self, 234 | prompt: Union[str, List[str]] = None, 235 | num_images_per_prompt: int = 1, 236 | max_sequence_length: int = 512, 237 | device: Optional[torch.device] = None, 238 | dtype: Optional[torch.dtype] = None, 239 | get_text_to_render: bool = False, 240 | ): 241 | device = device or self._execution_device 242 | dtype = dtype or self.text_encoder.dtype 243 | 244 | prompt = [prompt] if isinstance(prompt, str) else prompt 245 | batch_size = len(prompt) 246 | 247 | text_inputs = self.tokenizer_2( 248 | prompt, 249 | padding="max_length", 250 | max_length=max_sequence_length, 251 | truncation=True, 252 | return_length=False, 253 | return_overflowing_tokens=False, 254 | return_tensors="pt", 255 | ) 256 | 257 | if get_text_to_render: 258 | matches = re.findall(r"'[^']*'", prompt[0]) 259 | if matches == []: 260 | matches = re.findall(r'"[^"]*"', prompt[0]) 261 | 262 | text_to_render = self.tokenizer_2(matches[0], padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_tensors="pt",)['input_ids'] 263 | text_to_render_start = 0 264 | text_to_render_end = torch.where(text_to_render == 0)[1][0].item() 265 | text_to_render_ids = text_to_render[:, text_to_render_start+1:text_to_render_end-1] 266 | flat_input_ids = text_inputs['input_ids'].flatten() 267 | flat_to_render_ids = text_to_render_ids.flatten() 268 | 269 | windows = flat_input_ids.unfold(0, flat_to_render_ids.size(0), 1) 270 | matches = (windows == flat_to_render_ids).all(dim=1) 271 | 272 | # Find the starting index where the match occurs 273 | if torch.any(matches): 274 | start_index = torch.nonzero(matches).item() 275 | end_index = start_index + flat_to_render_ids.size(0) 276 | else: 
277 | raise ValueError("No match found in the input IDs.") 278 | 279 | text_input_ids = text_inputs.input_ids 280 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 281 | 282 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 283 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 284 | logger.warning( 285 | "The following part of your input was truncated because `max_sequence_length` is set to " 286 | f" {max_sequence_length} tokens: {removed_text}" 287 | ) 288 | 289 | prompt_embeds = self.text_encoder_2( 290 | text_input_ids.to(device), output_hidden_states=False, 291 | )[0] 292 | 293 | dtype = self.text_encoder_2.dtype 294 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 295 | 296 | _, seq_len, _ = prompt_embeds.shape 297 | 298 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 299 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 300 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 301 | 302 | if get_text_to_render: 303 | return prompt_embeds, start_index, end_index 304 | else: 305 | return prompt_embeds 306 | 307 | 308 | def _get_clip_prompt_embeds( 309 | self, 310 | prompt: Union[str, List[str]], 311 | num_images_per_prompt: int = 1, 312 | device: Optional[torch.device] = None, 313 | ): 314 | device = device or self._execution_device 315 | 316 | prompt = [prompt] if isinstance(prompt, str) else prompt 317 | batch_size = len(prompt) 318 | 319 | text_inputs = self.tokenizer( 320 | prompt, 321 | padding="max_length", 322 | max_length=self.tokenizer_max_length, 323 | truncation=True, 324 | return_overflowing_tokens=False, 325 | return_length=False, 326 | return_tensors="pt", 327 | ) 328 | 329 | text_input_ids = text_inputs.input_ids 330 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 331 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 332 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 333 | logger.warning( 334 | "The following part of your input was truncated because CLIP can only handle sequences up to" 335 | f" {self.tokenizer_max_length} tokens: {removed_text}" 336 | ) 337 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 338 | 339 | # Use pooled output of CLIPTextModel 340 | prompt_embeds = prompt_embeds.pooler_output 341 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 342 | 343 | # duplicate text embeddings for each generation per prompt, using mps friendly method 344 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 345 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 346 | 347 | return prompt_embeds 348 | 349 | def encode_prompt( 350 | self, 351 | prompt: Union[str, List[str]], 352 | prompt_2: Union[str, List[str]], 353 | device: Optional[torch.device] = None, 354 | num_images_per_prompt: int = 1, 355 | prompt_embeds: Optional[torch.FloatTensor] = None, 356 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 357 | max_sequence_length: int = 512, 358 | lora_scale: Optional[float] = None, 359 | get_text_to_render: Optional[bool] = False, 360 | ): 361 | r""" 362 | 363 | Args: 364 | prompt (`str` or `List[str]`, *optional*): 
365 | prompt to be encoded 366 | prompt_2 (`str` or `List[str]`, *optional*): 367 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 368 | used in all text-encoders 369 | device: (`torch.device`): 370 | torch device 371 | num_images_per_prompt (`int`): 372 | number of images that should be generated per prompt 373 | do_classifier_free_guidance (`bool`): 374 | whether to use classifier-free guidance or not 375 | negative_prompt (`str` or `List[str]`, *optional*): 376 | negative prompt to be encoded 377 | negative_prompt_2 (`str` or `List[str]`, *optional*): 378 | negative prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is 379 | used in all text-encoders 380 | prompt_embeds (`torch.FloatTensor`, *optional*): 381 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 382 | provided, text embeddings will be generated from `prompt` input argument. 383 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 384 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 385 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 386 | clip_skip (`int`, *optional*): 387 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 388 | the output of the pre-final layer will be used for computing the prompt embeddings. 389 | lora_scale (`float`, *optional*): 390 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 391 | """ 392 | device = device or self._execution_device 393 | 394 | # set lora scale so that monkey patched LoRA 395 | # function of text encoder can correctly access it 396 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 397 | self._lora_scale = lora_scale 398 | 399 | # dynamically adjust the LoRA scale 400 | if self.text_encoder is not None and USE_PEFT_BACKEND: 401 | scale_lora_layers(self.text_encoder, lora_scale) 402 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 403 | scale_lora_layers(self.text_encoder_2, lora_scale) 404 | 405 | prompt = [prompt] if isinstance(prompt, str) else prompt 406 | if prompt is not None: 407 | batch_size = len(prompt) 408 | else: 409 | batch_size = prompt_embeds.shape[0] 410 | 411 | if prompt_embeds is None: 412 | prompt_2 = prompt_2 or prompt 413 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 414 | 415 | # We only use the pooled prompt output from the CLIPTextModel 416 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 417 | prompt=prompt, 418 | device=device, 419 | num_images_per_prompt=num_images_per_prompt, 420 | ) 421 | 422 | # We only use the pooled prompt output from the CLIPTextModel 423 | if get_text_to_render: 424 | prompt_embeds, t5_start_index, t5_end_index = self._get_t5_prompt_embeds( 425 | prompt=prompt_2, 426 | num_images_per_prompt=num_images_per_prompt, 427 | max_sequence_length=max_sequence_length, 428 | device=device, 429 | get_text_to_render=True, 430 | ) 431 | else: 432 | prompt_embeds = self._get_t5_prompt_embeds( 433 | prompt=prompt_2, 434 | num_images_per_prompt=num_images_per_prompt, 435 | max_sequence_length=max_sequence_length, 436 | device=device, 437 | ) 438 | 439 | if self.text_encoder is not None: 440 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 441 | # Retrieve the original scale by scaling back the LoRA layers 442 | 
unscale_lora_layers(self.text_encoder, lora_scale) 443 | 444 | if self.text_encoder_2 is not None: 445 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 446 | # Retrieve the original scale by scaling back the LoRA layers 447 | unscale_lora_layers(self.text_encoder_2, lora_scale) 448 | 449 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to( 450 | device=device, dtype=self.text_encoder.dtype 451 | ) 452 | 453 | if get_text_to_render: 454 | return prompt_embeds, pooled_prompt_embeds, text_ids, t5_start_index, t5_end_index 455 | else: 456 | return prompt_embeds, pooled_prompt_embeds, text_ids 457 | 458 | # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image 459 | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): 460 | if isinstance(generator, list): 461 | image_latents = [ 462 | retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) 463 | for i in range(image.shape[0]) 464 | ] 465 | image_latents = torch.cat(image_latents, dim=0) 466 | else: 467 | image_latents = retrieve_latents(self.vae.encode(image), generator=generator) 468 | 469 | image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor 470 | 471 | return image_latents 472 | 473 | # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps 474 | def get_timesteps(self, num_inference_steps, strength, device): 475 | # get the original timestep using init_timestep 476 | init_timestep = min(num_inference_steps * strength, num_inference_steps) 477 | 478 | t_start = int(max(num_inference_steps - init_timestep, 0)) 479 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 480 | if hasattr(self.scheduler, "set_begin_index"): 481 | self.scheduler.set_begin_index(t_start * self.scheduler.order) 482 | 483 | return timesteps, num_inference_steps - t_start 484 | 485 | def check_inputs( 486 | self, 487 | prompt, 488 | prompt_2, 489 | height, 490 | width, 491 | prompt_embeds=None, 492 | pooled_prompt_embeds=None, 493 | callback_on_step_end_tensor_inputs=None, 494 | max_sequence_length=None, 495 | ): 496 | if height % 8 != 0 or width % 8 != 0: 497 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 498 | 499 | if callback_on_step_end_tensor_inputs is not None and not all( 500 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 501 | ): 502 | raise ValueError( 503 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 504 | ) 505 | 506 | if prompt is not None and prompt_embeds is not None: 507 | raise ValueError( 508 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 509 | " only forward one of the two." 510 | ) 511 | elif prompt_2 is not None and prompt_embeds is not None: 512 | raise ValueError( 513 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 514 | " only forward one of the two." 515 | ) 516 | elif prompt is None and prompt_embeds is None: 517 | raise ValueError( 518 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
519 | ) 520 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 521 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 522 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 523 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 524 | 525 | if prompt_embeds is not None and pooled_prompt_embeds is None: 526 | raise ValueError( 527 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 528 | ) 529 | 530 | if max_sequence_length is not None and max_sequence_length > 512: 531 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 532 | 533 | @staticmethod 534 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids 535 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 536 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 537 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 538 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 539 | 540 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 541 | 542 | latent_image_ids = latent_image_ids.reshape( 543 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 544 | ) 545 | 546 | return latent_image_ids.to(device=device, dtype=dtype) 547 | 548 | @staticmethod 549 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 550 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 551 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 552 | latents = latents.permute(0, 2, 4, 1, 3, 5) 553 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 554 | 555 | return latents 556 | 557 | @staticmethod 558 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents 559 | def _unpack_latents(latents, height, width, vae_scale_factor): 560 | batch_size, num_patches, channels = latents.shape 561 | 562 | height = height // vae_scale_factor 563 | width = width // vae_scale_factor 564 | 565 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 566 | latents = latents.permute(0, 3, 1, 4, 2, 5) 567 | 568 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 569 | 570 | return latents 571 | 572 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents 573 | def prepare_latents( 574 | self, 575 | batch_size, 576 | num_channels_latents, 577 | height, 578 | width, 579 | dtype, 580 | device, 581 | generator, 582 | latents=None, 583 | ): 584 | height = 2 * (int(height) // self.vae_scale_factor) 585 | width = 2 * (int(width) // self.vae_scale_factor) 586 | 587 | shape = (batch_size, num_channels_latents, height, width) 588 | 589 | if latents is not None: 590 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 591 | return latents.to(device=device, dtype=dtype), latent_image_ids 592 | 593 | if isinstance(generator, list) and len(generator) != batch_size: 594 | raise ValueError( 595 | f"You have passed a list of generators of length 
{len(generator)}, but requested an effective batch"
596 |                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
597 |             )
598 | 
599 |         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
600 | 
601 |         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
602 | 
603 |         latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
604 | 
605 |         return latents, latent_image_ids
606 | 
607 |     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
608 |     def prepare_latents_reptext(
609 |         self,
610 |         image,
611 |         batch_size,
612 |         num_channels_latents,
613 |         height,
614 |         width,
615 |         dtype,
616 |         device,
617 |         generator,
618 |         latents=None,
619 |     ):
620 |         height = 2 * (int(height) // self.vae_scale_factor)
621 |         width = 2 * (int(width) // self.vae_scale_factor)
622 | 
623 |         image = image.to(device=device, dtype=dtype) # torch.Size([1, 3, height, width])
624 |         image_latents = self._encode_vae_image(image=image, generator=generator)
625 | 
626 |         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
627 |             # expand init_latents for batch_size
628 |             additional_image_per_prompt = batch_size // image_latents.shape[0]
629 |             image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
630 |         elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
631 |             raise ValueError(
632 |                 f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
633 |             )
634 |         else:
635 |             image_latents = torch.cat([image_latents], dim=0)
636 | 
637 |         shape = (batch_size, num_channels_latents, height, width)
638 | 
639 |         if latents is not None:
640 |             latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
641 |             return latents.to(device=device, dtype=dtype), latent_image_ids
642 | 
643 |         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # torch.Size([1, 16, 160, 120])
644 | 
645 |         glyph_mask = (image > 0).any(dim=1, keepdim=True)
646 |         glyph_mask = glyph_mask.repeat(1, 16, 1, 1).float()
647 |         glyph_mask = F.interpolate(glyph_mask, size=(noise.shape[-2], noise.shape[-1]), mode='bilinear', align_corners=False) # torch.Size([1, 16, 160, 120])
648 |         glyph_mask[glyph_mask>0] = 1
649 |         glyph_mask[glyph_mask<0] = 0
650 | 
651 |         glyph_mask = glyph_mask > 0
652 |         result = torch.zeros_like(noise)
653 |         result[glyph_mask] = 0.10 * image_latents[glyph_mask] + 1.0 * noise[glyph_mask]
654 |         result[~glyph_mask] = noise[~glyph_mask]
655 | 
656 |         latents = self._pack_latents(result, batch_size, num_channels_latents, height, width) # pack the noisy glyph latent
657 | 
658 |         latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
659 | 
660 |         return latents, latent_image_ids
661 | 
662 |     # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
663 |     def prepare_image(
664 |         self,
665 |         image,
666 |         width,
667 |         height,
668 |         batch_size,
669 |         num_images_per_prompt,
670 |         device,
671 |         dtype,
672 |         image_position=None,
673 |         do_classifier_free_guidance=False,
674 |         guess_mode=False,
675 |     ):
676 |         # prepare image
677 |         if isinstance(image, torch.Tensor):
678 |             pass
679 |         else:
680 |             image = self.image_processor.preprocess(image, height=height, width=width)
681 |         image_batch_size = image.shape[0]
682 |         if image_batch_size == 1:
683 |             repeat_by = batch_size
684 | 
else: 685 | # image batch size is the same as prompt batch size 686 | repeat_by = num_images_per_prompt 687 | image = image.repeat_interleave(repeat_by, dim=0) 688 | image = image.to(device=device, dtype=dtype) 689 | 690 | # prepare position image 691 | if isinstance(image_position, torch.Tensor): 692 | pass 693 | else: 694 | image_position = self.image_processor.preprocess(image_position, height=height, width=width) 695 | image_batch_size = image_position.shape[0] 696 | if image_batch_size == 1: 697 | repeat_by = batch_size 698 | else: 699 | repeat_by = num_images_per_prompt 700 | image_position = image_position.repeat_interleave(repeat_by, dim=0) 701 | image_position = image_position.to(device=device, dtype=dtype) 702 | image_position = image_position.repeat(1,3,1,1) 703 | 704 | # Encode to latents 705 | image_latents = self.vae.encode(image.to(self.vae.dtype)).latent_dist.sample() 706 | image_latents = ( 707 | image_latents - self.vae.config.shift_factor 708 | ) * self.vae.config.scaling_factor 709 | image_latents = image_latents.to(dtype) 710 | 711 | position_image_latents = self.vae.encode(image_position.to(self.vae.dtype)).latent_dist.sample() 712 | position_image_latents = ( 713 | position_image_latents - self.vae.config.shift_factor 714 | ) * self.vae.config.scaling_factor 715 | position_image_latents = position_image_latents.to(dtype) 716 | 717 | control_image = torch.cat([image_latents, position_image_latents], dim=1) 718 | 719 | # Pack cond latents 720 | packed_control_image = self._pack_latents( 721 | control_image, 722 | batch_size * num_images_per_prompt, 723 | control_image.shape[1], 724 | control_image.shape[2], 725 | control_image.shape[3], 726 | ) 727 | 728 | if do_classifier_free_guidance: 729 | packed_control_image = torch.cat([packed_control_image] * 2) 730 | 731 | return packed_control_image, height, width 732 | 733 | @property 734 | def guidance_scale(self): 735 | return self._guidance_scale 736 | 737 | @property 738 | def joint_attention_kwargs(self): 739 | return self._joint_attention_kwargs 740 | 741 | @property 742 | def num_timesteps(self): 743 | return self._num_timesteps 744 | 745 | @property 746 | def interrupt(self): 747 | return self._interrupt 748 | 749 | @torch.no_grad() 750 | @replace_example_docstring(EXAMPLE_DOC_STRING) 751 | def __call__( 752 | self, 753 | prompt: Union[str, List[str]] = None, 754 | prompt_2: Optional[Union[str, List[str]]] = None, 755 | height: Optional[int] = None, 756 | width: Optional[int] = None, 757 | num_inference_steps: int = 28, 758 | timesteps: List[int] = None, 759 | guidance_scale: float = 7.0, 760 | control_guidance_start: Union[float, List[float]] = 0.0, 761 | control_guidance_end: Union[float, List[float]] = 1.0, 762 | control_image: PipelineImageInput = None, 763 | control_mode: Optional[Union[int, List[int]]] = None, 764 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 765 | controlnet_conditioning_step: int = 30, 766 | num_images_per_prompt: Optional[int] = 1, 767 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 768 | latents: Optional[torch.FloatTensor] = None, 769 | prompt_embeds: Optional[torch.FloatTensor] = None, 770 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 771 | output_type: Optional[str] = "pil", 772 | return_dict: bool = True, 773 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 774 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 775 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 776 | 
max_sequence_length: int = 512,
777 | 
778 |         control_mask: Optional[torch.FloatTensor] = None,
779 |         control_position: Optional[torch.FloatTensor] = None,
780 |         control_glyph: Optional[torch.FloatTensor] = None,
781 |     ):
782 |         r"""
783 |         Function invoked when calling the pipeline for generation.
784 | 
785 |         Args:
786 |             prompt (`str` or `List[str]`, *optional*):
787 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
788 |                 instead.
789 |             prompt_2 (`str` or `List[str]`, *optional*):
790 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
791 |                 will be used instead.
792 |             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
793 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
794 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
795 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
796 |             num_inference_steps (`int`, *optional*, defaults to 28):
797 |                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
798 |                 expense of slower inference.
799 |             timesteps (`List[int]`, *optional*):
800 |                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
801 |                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
802 |                 passed will be used. Must be in descending order.
803 |             guidance_scale (`float`, *optional*, defaults to 7.0):
804 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
805 |                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
806 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
807 |                 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
808 |                 usually at the expense of lower image quality.
809 |             control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
810 |                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
811 |                 The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
812 |                 specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
813 |                 as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
814 |                 width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
815 |                 images must be passed as a list such that each element of the list can be correctly batched for input
816 |                 to a single ControlNet.
817 |             controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
818 |                 The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
819 |                 to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
820 |                 the corresponding scale as a list.
821 |             control_mode (`int` or `List[int]`, *optional*, defaults to None):
822 |                 The control mode when applying ControlNet-Union.
823 |             num_images_per_prompt (`int`, *optional*, defaults to 1):
824 |                 The number of images to generate per prompt.
825 |             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
826 |                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
827 |                 to make generation deterministic.
828 |             latents (`torch.FloatTensor`, *optional*):
829 |                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
830 |                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
831 |                 tensor will be generated by sampling using the supplied random `generator`.
832 |             prompt_embeds (`torch.FloatTensor`, *optional*):
833 |                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
834 |                 provided, text embeddings will be generated from `prompt` input argument.
835 |             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
836 |                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
837 |                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
838 |             output_type (`str`, *optional*, defaults to `"pil"`):
839 |                 The output format of the generated image. Choose between
840 |                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
841 |             return_dict (`bool`, *optional*, defaults to `True`):
842 |                 Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
843 |             joint_attention_kwargs (`dict`, *optional*):
844 |                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
845 |                 `self.processor` in
846 |                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
847 |             callback_on_step_end (`Callable`, *optional*):
848 |                 A function that is called at the end of each denoising step during inference. The function is called
849 |                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
850 |                 callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
851 |                 `callback_on_step_end_tensor_inputs`.
852 |             callback_on_step_end_tensor_inputs (`List`, *optional*):
853 |                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
854 |                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
855 |                 `._callback_tensor_inputs` attribute of your pipeline class.
856 |             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
857 | 
858 |         Examples:
859 | 
860 |         Returns:
861 |             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
862 |             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
863 |             images. 
864 | """ 865 | 866 | height = height or self.default_sample_size * self.vae_scale_factor 867 | width = width or self.default_sample_size * self.vae_scale_factor 868 | 869 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): 870 | control_guidance_start = len(control_guidance_end) * [control_guidance_start] 871 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): 872 | control_guidance_end = len(control_guidance_start) * [control_guidance_end] 873 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): 874 | #mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 875 | mult = len(control_image) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 876 | control_guidance_start, control_guidance_end = ( 877 | mult * [control_guidance_start], 878 | mult * [control_guidance_end], 879 | ) 880 | 881 | # 1. Check inputs. Raise error if not correct 882 | self.check_inputs( 883 | prompt, 884 | prompt_2, 885 | height, 886 | width, 887 | prompt_embeds=prompt_embeds, 888 | pooled_prompt_embeds=pooled_prompt_embeds, 889 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 890 | max_sequence_length=max_sequence_length, 891 | ) 892 | 893 | self._guidance_scale = guidance_scale 894 | self._joint_attention_kwargs = joint_attention_kwargs 895 | self._interrupt = False 896 | 897 | # 2. Define call parameters 898 | if prompt is not None and isinstance(prompt, str): 899 | batch_size = 1 900 | elif prompt is not None and isinstance(prompt, list): 901 | batch_size = len(prompt) 902 | else: 903 | batch_size = prompt_embeds.shape[0] 904 | 905 | device = self._execution_device 906 | dtype = self.transformer.dtype 907 | 908 | lora_scale = ( 909 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 910 | ) 911 | 912 | ( 913 | prompt_embeds, 914 | pooled_prompt_embeds, 915 | text_ids 916 | ) = self.encode_prompt( 917 | prompt=prompt, 918 | prompt_2=prompt_2, 919 | prompt_embeds=prompt_embeds, 920 | pooled_prompt_embeds=pooled_prompt_embeds, 921 | device=device, 922 | num_images_per_prompt=num_images_per_prompt, 923 | max_sequence_length=max_sequence_length, 924 | lora_scale=lora_scale, 925 | ) 926 | 927 | # 3. Prepare control image 928 | if isinstance(self.controlnet, FluxControlNetModel): 929 | control_image_list = [] 930 | for control_image_, control_position_ in zip(control_image, control_position): 931 | control_image_, height, width = self.prepare_image( 932 | image=control_image_, 933 | image_position=control_position_, 934 | width=width, 935 | height=height, 936 | batch_size=batch_size * num_images_per_prompt, 937 | num_images_per_prompt=num_images_per_prompt, 938 | device=device, 939 | dtype=dtype, 940 | ) 941 | 942 | control_image_list.append(control_image_) 943 | 944 | # 4. Prepare latent variables 945 | num_channels_latents = self.transformer.config.in_channels // 4 946 | 947 | # 5. 
Prepare timesteps 948 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 949 | 950 | image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) 951 | 952 | mu = calculate_shift( 953 | image_seq_len, 954 | self.scheduler.config.base_image_seq_len, 955 | self.scheduler.config.max_image_seq_len, 956 | self.scheduler.config.base_shift, 957 | self.scheduler.config.max_shift, 958 | ) 959 | 960 | timesteps, num_inference_steps = retrieve_timesteps( 961 | self.scheduler, 962 | num_inference_steps, 963 | device, 964 | timesteps, 965 | sigmas, 966 | mu=mu, 967 | ) 968 | 969 | if control_glyph is not None: 970 | init_image = self.image_processor.preprocess(control_glyph, height=height, width=width) 971 | init_image = init_image.to(dtype=torch.float32) 972 | latents, latent_image_ids = self.prepare_latents_reptext( 973 | init_image, 974 | batch_size * num_images_per_prompt, 975 | num_channels_latents, 976 | height, 977 | width, 978 | prompt_embeds.dtype, 979 | device, 980 | generator, 981 | None, 982 | ) 983 | else: 984 | latents, latent_image_ids = self.prepare_latents( 985 | batch_size * num_images_per_prompt, 986 | num_channels_latents, 987 | height, 988 | width, 989 | prompt_embeds.dtype, 990 | device, 991 | generator, 992 | latents, 993 | ) 994 | 995 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 996 | self._num_timesteps = len(timesteps) 997 | 998 | # 6. Create tensor stating which controlnets to keep 999 | controlnet_keep = [] 1000 | for i in range(len(timesteps)): 1001 | keeps = [ 1002 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 1003 | for s, e in zip(control_guidance_start, control_guidance_end) 1004 | ] 1005 | controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) 1006 | 1007 | control_mask_list = [] 1008 | if control_mask is not None: 1009 | for control_mask_ in control_mask: 1010 | region_mask = torch.from_numpy(np.array(control_mask_)) / 255. # (height, width) 1011 | mask = F.interpolate(region_mask[None, None], scale_factor=1/16, mode='bilinear').reshape([1, -1, 1]) 1012 | control_mask_ = mask.to(device=latents.device, dtype=latents.dtype) 1013 | control_mask_list.append(control_mask_) 1014 | 1015 | # 6. 
Denoising loop 1016 | with self.progress_bar(total=num_inference_steps) as progress_bar: 1017 | for i, t in enumerate(timesteps): 1018 | 1019 | if self.interrupt: 1020 | continue 1021 | 1022 | latent_model_input = latents 1023 | 1024 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 1025 | timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) 1026 | 1027 | # handle guidance 1028 | if self.transformer.config.guidance_embeds: 1029 | guidance = torch.tensor([guidance_scale], device=device) 1030 | guidance = guidance.expand(latent_model_input.shape[0]) 1031 | else: 1032 | guidance = None 1033 | 1034 | control_block_samples = None 1035 | control_single_block_samples = None 1036 | 1037 | for control_index in range(len(control_image_list)): 1038 | 1039 | control_image = control_image_list[control_index] 1040 | control_mask = control_mask_list[control_index] if len(control_mask_list)>0 else None 1041 | 1042 | if i < controlnet_conditioning_step: 1043 | controlnet_block_samples, controlnet_single_block_samples = self.controlnet( 1044 | hidden_states=latent_model_input, 1045 | controlnet_cond=control_image, 1046 | controlnet_mode=control_mode, 1047 | conditioning_scale=controlnet_conditioning_scale, 1048 | timestep=timestep / 1000, 1049 | guidance=guidance, 1050 | pooled_projections=pooled_prompt_embeds, 1051 | encoder_hidden_states=prompt_embeds, 1052 | txt_ids=text_ids, 1053 | img_ids=latent_image_ids, 1054 | joint_attention_kwargs=self.joint_attention_kwargs, 1055 | return_dict=False, 1056 | ) 1057 | else: 1058 | controlnet_block_samples, controlnet_single_block_samples = None, None 1059 | 1060 | if controlnet_block_samples is not None: 1061 | if control_mask is not None: 1062 | controlnet_block_samples = [control_mask*sample.to(dtype=latents.dtype, device=device) for sample in controlnet_block_samples] 1063 | else: 1064 | controlnet_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_block_samples] 1065 | if controlnet_single_block_samples is not None: 1066 | if control_mask is not None: 1067 | controlnet_single_block_samples = [control_mask*sample.to(dtype=latents.dtype, device=device) for sample in controlnet_single_block_samples] 1068 | else: 1069 | controlnet_single_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_single_block_samples] 1070 | 1071 | # merge samples 1072 | if control_index == 0: 1073 | control_block_samples = controlnet_block_samples 1074 | control_single_block_samples = controlnet_single_block_samples 1075 | else: 1076 | if controlnet_block_samples is not None and control_block_samples is not None: 1077 | control_block_samples = [ 1078 | control_block_sample + block_sample 1079 | for control_block_sample, block_sample in zip(control_block_samples, controlnet_block_samples) 1080 | ] 1081 | if controlnet_single_block_samples is not None and control_single_block_samples is not None: 1082 | control_single_block_samples = [ 1083 | control_single_block_sample + block_sample 1084 | for control_single_block_sample, block_sample in zip( 1085 | control_single_block_samples, controlnet_single_block_samples 1086 | ) 1087 | ] 1088 | 1089 | controlnet_block_samples = control_block_samples 1090 | controlnet_single_block_samples = control_single_block_samples 1091 | 1092 | noise_pred = self.transformer( 1093 | hidden_states=latent_model_input, 1094 | timestep=timestep / 1000, 1095 | guidance=guidance, 1096 | pooled_projections=pooled_prompt_embeds, 1097 | 
encoder_hidden_states=prompt_embeds, 1098 | controlnet_block_samples=controlnet_block_samples, 1099 | controlnet_single_block_samples=controlnet_single_block_samples, 1100 | txt_ids=text_ids, 1101 | img_ids=latent_image_ids, 1102 | joint_attention_kwargs=self.joint_attention_kwargs, 1103 | return_dict=False, 1104 | )[0] 1105 | 1106 | # compute the previous noisy sample x_t -> x_t-1 1107 | latents_dtype = latents.dtype 1108 | 1109 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 1110 | 1111 | if latents.dtype != latents_dtype: 1112 | if torch.backends.mps.is_available(): 1113 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 1114 | latents = latents.to(latents_dtype) 1115 | 1116 | if callback_on_step_end is not None: 1117 | callback_kwargs = {} 1118 | for k in callback_on_step_end_tensor_inputs: 1119 | callback_kwargs[k] = locals()[k] 1120 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 1121 | 1122 | latents = callback_outputs.pop("latents", latents) 1123 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 1124 | 1125 | # call the callback, if provided 1126 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1127 | progress_bar.update() 1128 | 1129 | if XLA_AVAILABLE: 1130 | xm.mark_step() 1131 | 1132 | if output_type == "latent": 1133 | image = latents 1134 | 1135 | else: 1136 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 1137 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 1138 | 1139 | image = self.vae.decode(latents, return_dict=False)[0] 1140 | image = self.image_processor.postprocess(image, output_type=output_type) 1141 | 1142 | # Offload all models 1143 | self.maybe_free_model_hooks() 1144 | 1145 | if not return_dict: 1146 | return (image,) 1147 | 1148 | return FluxPipelineOutput(images=image) 1149 | -------------------------------------------------------------------------------- /pipeline_flux_controlnet_inpaint.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import ( 7 | CLIPTextModel, 8 | CLIPTokenizer, 9 | T5EncoderModel, 10 | T5TokenizerFast, 11 | ) 12 | 13 | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor 14 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin 15 | from diffusers.models.autoencoders import AutoencoderKL 16 | 17 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 18 | from diffusers.utils import ( 19 | USE_PEFT_BACKEND, 20 | is_torch_xla_available, 21 | logging, 22 | replace_example_docstring, 23 | scale_lora_layers, 24 | unscale_lora_layers, 25 | ) 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 28 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 29 | 30 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel 31 | 32 | # New Added! 
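# RepText-specific imports (below): `F` is used to interpolate glyph/region masks down to the
# latent resolution, `transforms` provides the Gaussian blur helper defined further down, and the
# local `controlnet_flux` module supplies FluxControlNetModel / FluxMultiControlNetModel.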
33 | import torch.nn.functional as F 34 | import torchvision.transforms as transforms 35 | from controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel 36 | 37 | gaussian_blur = transforms.GaussianBlur(kernel_size=5, sigma=1) 38 | 39 | 40 | if is_torch_xla_available(): 41 | import torch_xla.core.xla_model as xm 42 | 43 | XLA_AVAILABLE = True 44 | else: 45 | XLA_AVAILABLE = False 46 | 47 | 48 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 49 | 50 | EXAMPLE_DOC_STRING = """ 51 | Examples: 52 | ```py 53 | >>> import torch 54 | >>> from diffusers.utils import load_image 55 | >>> from diffusers import FluxControlNetPipeline 56 | >>> from diffusers import FluxControlNetModel 57 | 58 | >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha" 59 | >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) 60 | >>> pipe = FluxControlNetPipeline.from_pretrained( 61 | ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 62 | ... ) 63 | >>> pipe.to("cuda") 64 | >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 65 | >>> control_mask = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") 66 | >>> prompt = "A girl in city, 25 years old, cool, futuristic" 67 | >>> image = pipe( 68 | ... prompt, 69 | ... control_image=control_image, 70 | ... controlnet_conditioning_scale=0.6, 71 | ... num_inference_steps=28, 72 | ... guidance_scale=3.5, 73 | ... ).images[0] 74 | >>> image.save("flux.png") 75 | ``` 76 | """ 77 | 78 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 79 | def calculate_shift( 80 | image_seq_len, 81 | base_seq_len: int = 256, 82 | max_seq_len: int = 4096, 83 | base_shift: float = 0.5, 84 | max_shift: float = 1.16, 85 | ): 86 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 87 | b = base_shift - m * base_seq_len 88 | mu = image_seq_len * m + b 89 | return mu 90 | 91 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 92 | def retrieve_latents( 93 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 94 | ): 95 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 96 | return encoder_output.latent_dist.sample(generator) 97 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 98 | return encoder_output.latent_dist.mode() 99 | elif hasattr(encoder_output, "latents"): 100 | return encoder_output.latents 101 | else: 102 | raise AttributeError("Could not access latents of provided encoder_output") 103 | 104 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 105 | def retrieve_timesteps( 106 | scheduler, 107 | num_inference_steps: Optional[int] = None, 108 | device: Optional[Union[str, torch.device]] = None, 109 | timesteps: Optional[List[int]] = None, 110 | sigmas: Optional[List[float]] = None, 111 | **kwargs, 112 | ): 113 | """ 114 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 115 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 116 | 117 | Args: 118 | scheduler (`SchedulerMixin`): 119 | The scheduler to get timesteps from. 120 | num_inference_steps (`int`): 121 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 122 | must be `None`. 
123 | device (`str` or `torch.device`, *optional*): 124 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 125 | timesteps (`List[int]`, *optional*): 126 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 127 | `num_inference_steps` and `sigmas` must be `None`. 128 | sigmas (`List[float]`, *optional*): 129 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 130 | `num_inference_steps` and `timesteps` must be `None`. 131 | 132 | Returns: 133 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 134 | second element is the number of inference steps. 135 | """ 136 | if timesteps is not None and sigmas is not None: 137 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 138 | if timesteps is not None: 139 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 140 | if not accepts_timesteps: 141 | raise ValueError( 142 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 143 | f" timestep schedules. Please check whether you are using the correct scheduler." 144 | ) 145 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 146 | timesteps = scheduler.timesteps 147 | num_inference_steps = len(timesteps) 148 | elif sigmas is not None: 149 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 150 | if not accept_sigmas: 151 | raise ValueError( 152 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 153 | f" sigmas schedules. Please check whether you are using the correct scheduler." 154 | ) 155 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 156 | timesteps = scheduler.timesteps 157 | num_inference_steps = len(timesteps) 158 | else: 159 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 160 | timesteps = scheduler.timesteps 161 | return timesteps, num_inference_steps 162 | 163 | 164 | class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): 165 | r""" 166 | The Flux pipeline for text-to-image generation. 167 | 168 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 169 | 170 | Args: 171 | transformer ([`FluxTransformer2DModel`]): 172 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 173 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 174 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 175 | vae ([`AutoencoderKL`]): 176 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 177 | text_encoder ([`CLIPTextModel`]): 178 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 179 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 180 | text_encoder_2 ([`T5EncoderModel`]): 181 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 182 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
183 | tokenizer (`CLIPTokenizer`): 184 | Tokenizer of class 185 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 186 | tokenizer_2 (`T5TokenizerFast`): 187 | Second Tokenizer of class 188 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 189 | """ 190 | 191 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 192 | _optional_components = [] 193 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 194 | 195 | def __init__( 196 | self, 197 | scheduler: FlowMatchEulerDiscreteScheduler, 198 | vae: AutoencoderKL, 199 | text_encoder: CLIPTextModel, 200 | tokenizer: CLIPTokenizer, 201 | text_encoder_2: T5EncoderModel, 202 | tokenizer_2: T5TokenizerFast, 203 | transformer: FluxTransformer2DModel, 204 | controlnet: Union[ 205 | FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel 206 | ], 207 | controlnet_inpaint: Union[ 208 | FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel 209 | ], 210 | ): 211 | super().__init__() 212 | 213 | self.register_modules( 214 | vae=vae, 215 | text_encoder=text_encoder, 216 | text_encoder_2=text_encoder_2, 217 | tokenizer=tokenizer, 218 | tokenizer_2=tokenizer_2, 219 | transformer=transformer, 220 | scheduler=scheduler, 221 | controlnet=controlnet, 222 | controlnet_inpaint=controlnet_inpaint, 223 | ) 224 | self.vae_scale_factor = ( 225 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 226 | ) 227 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 228 | self.mask_processor = VaeImageProcessor( 229 | vae_scale_factor=self.vae_scale_factor, 230 | do_resize=True, 231 | do_convert_grayscale=True, 232 | do_normalize=False, 233 | do_binarize=True, 234 | ) 235 | self.tokenizer_max_length = ( 236 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 237 | ) 238 | self.default_sample_size = 64 239 | 240 | @property 241 | def do_classifier_free_guidance(self): 242 | return self._guidance_scale > 1 243 | 244 | def _get_t5_prompt_embeds( 245 | self, 246 | prompt: Union[str, List[str]] = None, 247 | num_images_per_prompt: int = 1, 248 | max_sequence_length: int = 512, 249 | device: Optional[torch.device] = None, 250 | dtype: Optional[torch.dtype] = None, 251 | ): 252 | device = device or self._execution_device 253 | dtype = dtype or self.text_encoder.dtype 254 | 255 | prompt = [prompt] if isinstance(prompt, str) else prompt 256 | batch_size = len(prompt) 257 | 258 | text_inputs = self.tokenizer_2( 259 | prompt, 260 | padding="max_length", 261 | max_length=max_sequence_length, 262 | truncation=True, 263 | return_length=False, 264 | return_overflowing_tokens=False, 265 | return_tensors="pt", 266 | ) 267 | text_input_ids = text_inputs.input_ids 268 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 269 | 270 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 271 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 272 | logger.warning( 273 | "The following part of your input was truncated because `max_sequence_length` is set to " 274 | f" {max_sequence_length} tokens: {removed_text}" 275 | ) 276 | 277 | prompt_embeds = self.text_encoder_2( 278 | 
text_input_ids.to(device), output_hidden_states=False, 279 | )[0] 280 | 281 | dtype = self.text_encoder_2.dtype 282 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 283 | 284 | _, seq_len, _ = prompt_embeds.shape 285 | 286 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 287 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 288 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 289 | 290 | return prompt_embeds 291 | 292 | def _get_clip_prompt_embeds( 293 | self, 294 | prompt: Union[str, List[str]], 295 | num_images_per_prompt: int = 1, 296 | device: Optional[torch.device] = None, 297 | ): 298 | device = device or self._execution_device 299 | 300 | prompt = [prompt] if isinstance(prompt, str) else prompt 301 | batch_size = len(prompt) 302 | 303 | text_inputs = self.tokenizer( 304 | prompt, 305 | padding="max_length", 306 | max_length=self.tokenizer_max_length, 307 | truncation=True, 308 | return_overflowing_tokens=False, 309 | return_length=False, 310 | return_tensors="pt", 311 | ) 312 | 313 | text_input_ids = text_inputs.input_ids 314 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 315 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 316 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 317 | logger.warning( 318 | "The following part of your input was truncated because CLIP can only handle sequences up to" 319 | f" {self.tokenizer_max_length} tokens: {removed_text}" 320 | ) 321 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 322 | 323 | # Use pooled output of CLIPTextModel 324 | prompt_embeds = prompt_embeds.pooler_output 325 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 326 | 327 | # duplicate text embeddings for each generation per prompt, using mps friendly method 328 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 329 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 330 | 331 | return prompt_embeds 332 | 333 | def encode_prompt( 334 | self, 335 | prompt: Union[str, List[str]], 336 | prompt_2: Union[str, List[str]], 337 | device: Optional[torch.device] = None, 338 | num_images_per_prompt: int = 1, 339 | do_classifier_free_guidance: bool = True, 340 | negative_prompt: Optional[Union[str, List[str]]] = None, 341 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 342 | prompt_embeds: Optional[torch.FloatTensor] = None, 343 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 344 | max_sequence_length: int = 512, 345 | lora_scale: Optional[float] = None, 346 | ): 347 | r""" 348 | 349 | Args: 350 | prompt (`str` or `List[str]`, *optional*): 351 | prompt to be encoded 352 | prompt_2 (`str` or `List[str]`, *optional*): 353 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is
354 |                 used in all text-encoders
355 |             device: (`torch.device`):
356 |                 torch device
357 |             num_images_per_prompt (`int`):
358 |                 number of images that should be generated per prompt
359 |             do_classifier_free_guidance (`bool`):
360 |                 whether to use classifier-free guidance or not
361 |             negative_prompt (`str` or `List[str]`, *optional*):
362 |                 negative prompt to be encoded
363 |             negative_prompt_2 (`str` or `List[str]`, *optional*):
364 |                 negative prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is
365 |                 used in all text-encoders
366 |             prompt_embeds (`torch.FloatTensor`, *optional*):
367 |                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
368 |                 provided, text embeddings will be generated from `prompt` input argument.
369 |             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
370 |                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
371 |                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
372 |             clip_skip (`int`, *optional*):
373 |                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
374 |                 the output of the pre-final layer will be used for computing the prompt embeddings.
375 |             lora_scale (`float`, *optional*):
376 |                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
377 |         """
378 |         device = device or self._execution_device
379 | 
380 |         # set lora scale so that monkey patched LoRA
381 |         # function of text encoder can correctly access it
382 |         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
383 |             self._lora_scale = lora_scale
384 | 
385 |             # dynamically adjust the LoRA scale
386 |             if self.text_encoder is not None and USE_PEFT_BACKEND:
387 |                 scale_lora_layers(self.text_encoder, lora_scale)
388 |             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
389 |                 scale_lora_layers(self.text_encoder_2, lora_scale)
390 | 
391 |         prompt = [prompt] if isinstance(prompt, str) else prompt
392 |         if prompt is not None:
393 |             batch_size = len(prompt)
394 |         else:
395 |             batch_size = prompt_embeds.shape[0]
396 | 
397 |         if prompt_embeds is None:
398 |             prompt_2 = prompt_2 or prompt
399 |             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
400 | 
401 |             # We only use the pooled prompt output from the CLIPTextModel
402 |             pooled_prompt_embeds = self._get_clip_prompt_embeds(
403 |                 prompt=prompt,
404 |                 device=device,
405 |                 num_images_per_prompt=num_images_per_prompt,
406 |             )
407 |             prompt_embeds = self._get_t5_prompt_embeds(
408 |                 prompt=prompt_2,
409 |                 num_images_per_prompt=num_images_per_prompt,
410 |                 max_sequence_length=max_sequence_length,
411 |                 device=device,
412 |             )
413 | 
414 |         if do_classifier_free_guidance:
415 |             # handle the negative prompt
416 |             negative_prompt = negative_prompt or "bad quality, worst quality, text, signature, watermark, extra words"
417 |             negative_prompt_2 = negative_prompt_2 or negative_prompt
418 | 
419 |             negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
420 |                 negative_prompt,
421 |                 device=device,
422 |                 num_images_per_prompt=num_images_per_prompt,
423 |             )
424 |             negative_prompt_embeds = self._get_t5_prompt_embeds(
425 |                 negative_prompt_2,
426 |                 num_images_per_prompt=num_images_per_prompt,
427 |                 max_sequence_length=max_sequence_length,
428 |                 device=device,
429 |             )
430 |         else:
431 |             negative_pooled_prompt_embeds = None
432 |             negative_prompt_embeds = None 
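        # NOTE: unlike the text-to-image pipeline, this inpainting variant also builds CLIP (pooled)
        # and T5 embeddings for a default negative prompt whenever classifier-free guidance is
        # enabled (`guidance_scale > 1`); otherwise they are left as None. Both positive and negative
        # embeddings are returned below so the caller can apply true CFG on top of FLUX.1-dev's
        # distilled `guidance` input.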
433 | 434 | if self.text_encoder is not None: 435 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 436 | # Retrieve the original scale by scaling back the LoRA layers 437 | unscale_lora_layers(self.text_encoder, lora_scale) 438 | 439 | if self.text_encoder_2 is not None: 440 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 441 | # Retrieve the original scale by scaling back the LoRA layers 442 | unscale_lora_layers(self.text_encoder_2, lora_scale) 443 | 444 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to( 445 | device=device, dtype=self.text_encoder.dtype 446 | ) 447 | 448 | return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds,text_ids 449 | 450 | # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image 451 | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): 452 | if isinstance(generator, list): 453 | image_latents = [ 454 | retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) 455 | for i in range(image.shape[0]) 456 | ] 457 | image_latents = torch.cat(image_latents, dim=0) 458 | else: 459 | image_latents = retrieve_latents(self.vae.encode(image), generator=generator) 460 | 461 | image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor 462 | 463 | return image_latents 464 | 465 | # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps 466 | def get_timesteps(self, num_inference_steps, strength, device): 467 | # get the original timestep using init_timestep 468 | init_timestep = min(num_inference_steps * strength, num_inference_steps) 469 | 470 | t_start = int(max(num_inference_steps - init_timestep, 0)) 471 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 472 | if hasattr(self.scheduler, "set_begin_index"): 473 | self.scheduler.set_begin_index(t_start * self.scheduler.order) 474 | 475 | return timesteps, num_inference_steps - t_start 476 | 477 | def check_inputs( 478 | self, 479 | prompt, 480 | prompt_2, 481 | height, 482 | width, 483 | prompt_embeds=None, 484 | pooled_prompt_embeds=None, 485 | callback_on_step_end_tensor_inputs=None, 486 | max_sequence_length=None, 487 | ): 488 | if height % 8 != 0 or width % 8 != 0: 489 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 490 | 491 | if callback_on_step_end_tensor_inputs is not None and not all( 492 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 493 | ): 494 | raise ValueError( 495 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 496 | ) 497 | 498 | if prompt is not None and prompt_embeds is not None: 499 | raise ValueError( 500 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 501 | " only forward one of the two." 502 | ) 503 | elif prompt_2 is not None and prompt_embeds is not None: 504 | raise ValueError( 505 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 506 | " only forward one of the two." 507 | ) 508 | elif prompt is None and prompt_embeds is None: 509 | raise ValueError( 510 | "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." 511 | ) 512 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 513 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 514 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 515 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 516 | 517 | if prompt_embeds is not None and pooled_prompt_embeds is None: 518 | raise ValueError( 519 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 520 | ) 521 | 522 | if max_sequence_length is not None and max_sequence_length > 512: 523 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 524 | 525 | @staticmethod 526 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids 527 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 528 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 529 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 530 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 531 | 532 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 533 | 534 | latent_image_ids = latent_image_ids.reshape( 535 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 536 | ) 537 | 538 | return latent_image_ids.to(device=device, dtype=dtype) 539 | 540 | @staticmethod 541 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 542 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 543 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 544 | latents = latents.permute(0, 2, 4, 1, 3, 5) 545 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 546 | 547 | return latents 548 | 549 | @staticmethod 550 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents 551 | def _unpack_latents(latents, height, width, vae_scale_factor): 552 | batch_size, num_patches, channels = latents.shape 553 | 554 | height = height // vae_scale_factor 555 | width = width // vae_scale_factor 556 | 557 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 558 | latents = latents.permute(0, 3, 1, 4, 2, 5) 559 | 560 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 561 | 562 | return latents 563 | 564 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents 565 | def prepare_latents( 566 | self, 567 | batch_size, 568 | num_channels_latents, 569 | height, 570 | width, 571 | dtype, 572 | device, 573 | generator, 574 | latents=None, 575 | ): 576 | height = 2 * (int(height) // self.vae_scale_factor) 577 | width = 2 * (int(width) // self.vae_scale_factor) 578 | 579 | shape = (batch_size, num_channels_latents, height, width) 580 | 581 | if latents is not None: 582 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 583 | return latents.to(device=device, dtype=dtype), latent_image_ids 584 | 585 | if isinstance(generator, list) and len(generator) != batch_size: 586 | raise ValueError( 587 | f"You 
have passed a list of generators of length {len(generator)}, but requested an effective batch" 588 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 589 | ) 590 | 591 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 592 | 593 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) 594 | 595 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 596 | 597 | return latents, latent_image_ids 598 | 599 | # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents 600 | def prepare_latents_reptext( 601 | self, 602 | image, 603 | batch_size, 604 | num_channels_latents, 605 | height, 606 | width, 607 | dtype, 608 | device, 609 | generator, 610 | latents=None, 611 | ): 612 | height = 2 * (int(height) // self.vae_scale_factor) 613 | width = 2 * (int(width) // self.vae_scale_factor) 614 | 615 | image = image.to(device=device, dtype=dtype) # torch.Size([1, 3, height, width]) 616 | image_latents = self._encode_vae_image(image=image, generator=generator) 617 | 618 | if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: 619 | # expand init_latents for batch_size 620 | additional_image_per_prompt = batch_size // image_latents.shape[0] 621 | image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) 622 | elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: 623 | raise ValueError( 624 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 625 | ) 626 | else: 627 | image_latents = torch.cat([image_latents], dim=0) 628 | 629 | shape = (batch_size, num_channels_latents, height, width) 630 | 631 | if latents is not None: 632 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 633 | return latents.to(device=device, dtype=dtype), latent_image_ids 634 | 635 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 636 | 637 | glyph_mask = (image > 0).any(dim=1, keepdim=True) 638 | glyph_mask = glyph_mask.repeat(1, 16, 1, 1).float() 639 | glyph_mask = F.interpolate(glyph_mask, size=(noise.shape[-2], noise.shape[-1]), mode='bilinear', align_corners=False) 640 | glyph_mask[glyph_mask>0] = 1 641 | glyph_mask[glyph_mask<0] = 0 642 | 643 | glyph_mask = glyph_mask > 0 644 | result = torch.zeros_like(noise) 645 | result[glyph_mask] = 0.10 * image_latents[glyph_mask] + 1.00 * noise[glyph_mask] 646 | result[~glyph_mask] = noise[~glyph_mask] 647 | noise = result 648 | 649 | latents = self._pack_latents(noise, batch_size, num_channels_latents, height, width) 650 | 651 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 652 | 653 | return latents, latent_image_ids 654 | 655 | # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image 656 | def prepare_image( 657 | self, 658 | image, 659 | width, 660 | height, 661 | batch_size, 662 | num_images_per_prompt, 663 | device, 664 | dtype, 665 | image_position=None, 666 | do_classifier_free_guidance=False, 667 | guess_mode=False, 668 | ): 669 | # prepare image 670 | if isinstance(image, torch.Tensor): 671 | pass 672 | else: 673 | image = self.image_processor.preprocess(image, height=height, width=width) 674 | image_batch_size = image.shape[0] 675 | if image_batch_size == 1: 676 | repeat_by = batch_size 677 | 
else: 678 | # image batch size is the same as prompt batch size 679 | repeat_by = num_images_per_prompt 680 | image = image.repeat_interleave(repeat_by, dim=0) 681 | image = image.to(device=device, dtype=dtype) 682 | 683 | # prepare position image 684 | if isinstance(image_position, torch.Tensor): 685 | pass 686 | else: 687 | image_position = self.image_processor.preprocess(image_position, height=height, width=width) 688 | image_batch_size = image_position.shape[0] 689 | if image_batch_size == 1: 690 | repeat_by = batch_size 691 | else: 692 | repeat_by = num_images_per_prompt 693 | image_position = image_position.repeat_interleave(repeat_by, dim=0) 694 | image_position = image_position.to(device=device, dtype=dtype) 695 | image_position = image_position.repeat(1,3,1,1) 696 | 697 | # Encode to latents 698 | image_latents = self.vae.encode(image.to(self.vae.dtype)).latent_dist.sample() 699 | image_latents = ( 700 | image_latents - self.vae.config.shift_factor 701 | ) * self.vae.config.scaling_factor 702 | image_latents = image_latents.to(dtype) 703 | 704 | position_image_latents = self.vae.encode(image_position.to(self.vae.dtype)).latent_dist.sample() 705 | position_image_latents = ( 706 | position_image_latents - self.vae.config.shift_factor 707 | ) * self.vae.config.scaling_factor 708 | position_image_latents = position_image_latents.to(dtype) 709 | 710 | control_image = torch.cat([image_latents, position_image_latents], dim=1) 711 | 712 | # Pack cond latents 713 | packed_control_image = self._pack_latents( 714 | control_image, 715 | batch_size * num_images_per_prompt, 716 | control_image.shape[1], 717 | control_image.shape[2], 718 | control_image.shape[3], 719 | ) 720 | 721 | if do_classifier_free_guidance: 722 | packed_control_image = torch.cat([packed_control_image] * 2) 723 | 724 | return packed_control_image, height, width 725 | 726 | # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image 727 | def prepare_image_union( 728 | self, 729 | image, 730 | width, 731 | height, 732 | batch_size, 733 | num_images_per_prompt, 734 | device, 735 | dtype, 736 | do_classifier_free_guidance=False, 737 | guess_mode=False, 738 | ): 739 | if isinstance(image, torch.Tensor): 740 | pass 741 | else: 742 | image = self.image_processor.preprocess(image, height=height, width=width) 743 | 744 | image_batch_size = image.shape[0] 745 | 746 | if image_batch_size == 1: 747 | repeat_by = batch_size 748 | else: 749 | # image batch size is the same as prompt batch size 750 | repeat_by = num_images_per_prompt 751 | 752 | image = image.repeat_interleave(repeat_by, dim=0) 753 | 754 | image = image.to(device=device, dtype=dtype) 755 | 756 | if do_classifier_free_guidance and not guess_mode: 757 | image = torch.cat([image] * 2) 758 | 759 | return image 760 | 761 | def prepare_image_with_mask( 762 | self, 763 | image, 764 | mask, 765 | width, 766 | height, 767 | batch_size, 768 | num_images_per_prompt, 769 | device, 770 | dtype, 771 | do_classifier_free_guidance = False, 772 | ): 773 | # Prepare image 774 | if isinstance(image, torch.Tensor): 775 | pass 776 | else: 777 | image = self.image_processor.preprocess(image, height=height, width=width) 778 | 779 | image_batch_size = image.shape[0] 780 | if image_batch_size == 1: 781 | repeat_by = batch_size 782 | else: 783 | # image batch size is the same as prompt batch size 784 | repeat_by = num_images_per_prompt 785 | image = image.repeat_interleave(repeat_by, dim=0) 786 | image = 
image.to(device=device, dtype=dtype)
787 | 
788 | # Prepare mask
789 | if isinstance(mask, torch.Tensor):
790 | pass
791 | else:
792 | mask = self.mask_processor.preprocess(mask, height=height, width=width)
793 | mask = mask.repeat_interleave(repeat_by, dim=0)
794 | mask = mask.to(device=device, dtype=dtype)
795 | 
796 | # Get masked image
797 | masked_image = image.clone()
798 | masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
799 | 
800 | # Encode to latents
801 | image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
802 | image_latents = (
803 | image_latents - self.vae.config.shift_factor
804 | ) * self.vae.config.scaling_factor
805 | image_latents = image_latents.to(dtype)
806 | 
807 | mask = torch.nn.functional.interpolate(
808 | mask, size=(height // self.vae_scale_factor * 2, width // self.vae_scale_factor * 2)
809 | )
810 | mask = 1 - mask
811 | 
812 | control_image = torch.cat([image_latents, mask], dim=1)
813 | 
814 | # Pack cond latents
815 | packed_control_image = self._pack_latents(
816 | control_image,
817 | batch_size * num_images_per_prompt,
818 | control_image.shape[1],
819 | control_image.shape[2],
820 | control_image.shape[3],
821 | )
822 | 
823 | if do_classifier_free_guidance:
824 | packed_control_image = torch.cat([packed_control_image] * 2)
825 | 
826 | return packed_control_image, height, width
827 | 
828 | @property
829 | def guidance_scale(self):
830 | return self._guidance_scale
831 | 
832 | @property
833 | def joint_attention_kwargs(self):
834 | return self._joint_attention_kwargs
835 | 
836 | @property
837 | def num_timesteps(self):
838 | return self._num_timesteps
839 | 
840 | @property
841 | def interrupt(self):
842 | return self._interrupt
843 | 
844 | @torch.no_grad()
845 | @replace_example_docstring(EXAMPLE_DOC_STRING)
846 | def __call__(
847 | self,
848 | prompt: Union[str, List[str]] = None,
849 | prompt_2: Optional[Union[str, List[str]]] = None,
850 | true_guidance_scale: float = 3.5,
851 | negative_prompt: Optional[Union[str, List[str]]] = None,
852 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
853 | height: Optional[int] = None,
854 | width: Optional[int] = None,
855 | num_inference_steps: int = 28,
856 | timesteps: List[int] = None,
857 | guidance_scale: float = 7.0,
858 | control_guidance_start: Union[float, List[float]] = 0.0,
859 | control_guidance_end: Union[float, List[float]] = 1.0,
860 | control_image: PipelineImageInput = None,
861 | control_mode: Optional[Union[int, List[int]]] = None,
862 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
863 | controlnet_conditioning_step: int = 30,
864 | num_images_per_prompt: Optional[int] = 1,
865 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
866 | latents: Optional[torch.FloatTensor] = None,
867 | prompt_embeds: Optional[torch.FloatTensor] = None,
868 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
869 | output_type: Optional[str] = "pil",
870 | return_dict: bool = True,
871 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
872 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
873 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
874 | max_sequence_length: int = 512,
875 | control_mask: Optional[torch.FloatTensor] = None,
876 | control_position: Optional[torch.FloatTensor] = None,
877 | control_glyph: Optional[torch.FloatTensor] = None,
878 | 
879 | # additional conditions for the inpaint ControlNet
880 | control_image_inpaint: PipelineImageInput = None,
881 | control_mask_inpaint: 
Optional[torch.FloatTensor] = None,
882 | controlnet_conditioning_scale_inpaint: Union[float, List[float]] = 1.0,
883 | ):
884 | r"""
885 | Function invoked when calling the pipeline for generation.
886 | 
887 | Args:
888 | prompt (`str` or `List[str]`, *optional*):
889 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
890 | instead.
891 | prompt_2 (`str` or `List[str]`, *optional*):
892 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
893 | will be used instead.
894 | height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
895 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
896 | width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
897 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
898 | num_inference_steps (`int`, *optional*, defaults to 28):
899 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
900 | expense of slower inference.
901 | timesteps (`List[int]`, *optional*):
902 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
903 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
904 | passed will be used. Must be in descending order.
905 | guidance_scale (`float`, *optional*, defaults to 7.0):
906 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
907 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
908 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
909 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
910 | usually at the expense of lower image quality.
911 | control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
912 | `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
913 | The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
914 | specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
915 | as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
916 | width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
917 | images must be passed as a list such that each element of the list can be correctly batched for input
918 | to a single ControlNet.
919 | controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
920 | The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
921 | to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
922 | the corresponding scale as a list.
923 | control_mode (`int` or `List[int]`, *optional*, defaults to None):
924 | The control mode when applying ControlNet-Union.
925 | num_images_per_prompt (`int`, *optional*, defaults to 1):
926 | The number of images to generate per prompt.
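The following RepText-specific arguments are not part of the upstream Flux ControlNet pipeline; the short
descriptions below summarize how they are used in this implementation.
true_guidance_scale (`float`, *optional*, defaults to 3.5):
Classifier-free guidance scale applied between the positive and negative prompt predictions.
controlnet_conditioning_step (`int`, *optional*, defaults to 30):
ControlNet features are only injected for denoising steps whose index is smaller than this value.
control_position (*optional*):
Position maps (filled text bounding boxes), paired one-to-one with the entries of `control_image`.
control_mask (*optional*):
Regional masks, one per text line. ControlNet block outputs are multiplied by these masks so that
features are only injected inside the text regions.
control_glyph (*optional*):
The rendered glyph image used to initialize the latents with a noisy glyph latent instead of pure
random noise.
control_image_inpaint / control_mask_inpaint / controlnet_conditioning_scale_inpaint (*optional*):
Source image, inpainting mask and conditioning scale for the auxiliary inpaint ControlNet.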
927 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
928 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
929 | to make generation deterministic.
930 | latents (`torch.FloatTensor`, *optional*):
931 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
932 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
933 | tensor will be generated by sampling using the supplied random `generator`.
934 | prompt_embeds (`torch.FloatTensor`, *optional*):
935 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
936 | provided, text embeddings will be generated from the `prompt` input argument.
937 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
938 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
939 | If not provided, pooled text embeddings will be generated from the `prompt` input argument.
940 | output_type (`str`, *optional*, defaults to `"pil"`):
941 | The output format of the generated image. Choose between
942 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
943 | return_dict (`bool`, *optional*, defaults to `True`):
944 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
945 | joint_attention_kwargs (`dict`, *optional*):
946 | A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
947 | `self.processor` in
948 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
949 | callback_on_step_end (`Callable`, *optional*):
950 | A function that is called at the end of each denoising step during inference. The function is called
951 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
952 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
953 | `callback_on_step_end_tensor_inputs`.
954 | callback_on_step_end_tensor_inputs (`List`, *optional*):
955 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
956 | will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
957 | `._callback_tensor_inputs` attribute of your pipeline class.
958 | max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
959 | 
960 | Examples:
961 | 
962 | Returns:
963 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
964 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
965 | images.
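Usage sketch (illustrative only, not taken from the repository; `pipe` stands for an instantiated pipeline of
this class, and `canny_images`, `position_images`, `region_masks`, `glyph_image`, `source_image` and
`inpaint_mask` are placeholder names for conditions prepared the same way as in the repository's inference
scripts):

```py
image = pipe(
    prompt="a wooden signboard above a shop entrance",
    control_image=canny_images,          # canny edges of the rendered glyphs, one per text line
    control_position=position_images,    # filled bounding-box position maps, paired with control_image
    control_mask=region_masks,           # regional masks restricting feature injection to text areas
    control_glyph=glyph_image,           # accumulated glyph image for noisy-glyph latent initialization
    control_image_inpaint=source_image,  # source image for the auxiliary inpaint ControlNet
    control_mask_inpaint=inpaint_mask,   # mask of the region to be repainted
    controlnet_conditioning_scale=1.0,
    controlnet_conditioning_step=30,
    num_inference_steps=30,
    height=1024,
    width=1024,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
```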
966 | """ 967 | 968 | height = height or self.default_sample_size * self.vae_scale_factor 969 | width = width or self.default_sample_size * self.vae_scale_factor 970 | 971 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): 972 | control_guidance_start = len(control_guidance_end) * [control_guidance_start] 973 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): 974 | control_guidance_end = len(control_guidance_start) * [control_guidance_end] 975 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): 976 | mult = len(control_image) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 977 | control_guidance_start, control_guidance_end = ( 978 | mult * [control_guidance_start], 979 | mult * [control_guidance_end], 980 | ) 981 | 982 | # 1. Check inputs. Raise error if not correct 983 | self.check_inputs( 984 | prompt, 985 | prompt_2, 986 | height, 987 | width, 988 | prompt_embeds=prompt_embeds, 989 | pooled_prompt_embeds=pooled_prompt_embeds, 990 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 991 | max_sequence_length=max_sequence_length, 992 | ) 993 | 994 | self._guidance_scale = guidance_scale 995 | self._joint_attention_kwargs = joint_attention_kwargs 996 | self._interrupt = False 997 | 998 | # 2. Define call parameters 999 | if prompt is not None and isinstance(prompt, str): 1000 | batch_size = 1 1001 | elif prompt is not None and isinstance(prompt, list): 1002 | batch_size = len(prompt) 1003 | else: 1004 | batch_size = prompt_embeds.shape[0] 1005 | 1006 | device = self._execution_device 1007 | dtype = self.transformer.dtype 1008 | 1009 | lora_scale = ( 1010 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 1011 | ) 1012 | 1013 | ( 1014 | prompt_embeds, 1015 | pooled_prompt_embeds, 1016 | negative_prompt_embeds, 1017 | negative_pooled_prompt_embeds, 1018 | text_ids 1019 | ) = self.encode_prompt( 1020 | prompt=prompt, 1021 | prompt_2=prompt_2, 1022 | prompt_embeds=prompt_embeds, 1023 | pooled_prompt_embeds=pooled_prompt_embeds, 1024 | do_classifier_free_guidance = self.do_classifier_free_guidance, 1025 | negative_prompt = negative_prompt, 1026 | negative_prompt_2 = negative_prompt_2, 1027 | device=device, 1028 | num_images_per_prompt=num_images_per_prompt, 1029 | max_sequence_length=max_sequence_length, 1030 | lora_scale=lora_scale, 1031 | ) 1032 | 1033 | if self.do_classifier_free_guidance: 1034 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim = 0) 1035 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim = 0) 1036 | 1037 | # handle text conditions(canny+position) 1038 | if isinstance(self.controlnet, FluxControlNetModel): 1039 | control_image_list = [] 1040 | for control_image_, control_position_ in zip(control_image, control_position): 1041 | control_image_, height, width = self.prepare_image( 1042 | image=control_image_, 1043 | image_position=control_position_, 1044 | width=width, 1045 | height=height, 1046 | batch_size=batch_size * num_images_per_prompt, 1047 | num_images_per_prompt=num_images_per_prompt, 1048 | device=device, 1049 | dtype=dtype, 1050 | do_classifier_free_guidance=self.do_classifier_free_guidance, 1051 | ) 1052 | control_image_list.append(control_image_) 1053 | 1054 | # handle inpaint condition 1055 | if isinstance(self.controlnet, FluxControlNetModel): 1056 | # inpaint only 1057 | 
control_image_inpaint, height, width = self.prepare_image_with_mask( 1058 | image=control_image_inpaint, 1059 | mask=control_mask_inpaint, 1060 | width=width, 1061 | height=height, 1062 | batch_size=batch_size * num_images_per_prompt, 1063 | num_images_per_prompt=num_images_per_prompt, 1064 | device=device, 1065 | dtype=dtype, 1066 | do_classifier_free_guidance=self.do_classifier_free_guidance, 1067 | ) 1068 | 1069 | # 4. Prepare latent variables 1070 | num_channels_latents = self.transformer.config.in_channels // 4 1071 | 1072 | # 5. Prepare timesteps 1073 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 1074 | image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) 1075 | 1076 | mu = calculate_shift( 1077 | image_seq_len, 1078 | self.scheduler.config.base_image_seq_len, 1079 | self.scheduler.config.max_image_seq_len, 1080 | self.scheduler.config.base_shift, 1081 | self.scheduler.config.max_shift, 1082 | ) 1083 | timesteps, num_inference_steps = retrieve_timesteps( 1084 | self.scheduler, 1085 | num_inference_steps, 1086 | device, 1087 | timesteps, 1088 | sigmas, 1089 | mu=mu, 1090 | ) 1091 | 1092 | if control_glyph is not None: 1093 | init_image = self.image_processor.preprocess(control_glyph, height=height, width=width) # torch.Size([1, 3, 1280, 960]) 1094 | init_image = init_image.to(dtype=torch.float32) 1095 | latents, latent_image_ids = self.prepare_latents_reptext( 1096 | init_image, 1097 | batch_size * num_images_per_prompt, 1098 | num_channels_latents, 1099 | height, 1100 | width, 1101 | prompt_embeds.dtype, 1102 | device, 1103 | generator, 1104 | None, 1105 | ) 1106 | else: 1107 | latents, latent_image_ids = self.prepare_latents( 1108 | batch_size * num_images_per_prompt, 1109 | num_channels_latents, 1110 | height, 1111 | width, 1112 | prompt_embeds.dtype, 1113 | device, 1114 | generator, 1115 | latents, 1116 | ) 1117 | 1118 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 1119 | self._num_timesteps = len(timesteps) 1120 | 1121 | # 6. Create tensor stating which controlnets to keep 1122 | controlnet_keep = [] 1123 | for i in range(len(timesteps)): 1124 | keeps = [ 1125 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 1126 | for s, e in zip(control_guidance_start, control_guidance_end) 1127 | ] 1128 | controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) 1129 | 1130 | control_mask_list = [] 1131 | if control_mask is not None: 1132 | for control_mask_ in control_mask: 1133 | region_mask = torch.from_numpy(np.array(control_mask_)) / 255. # (height, width) 1134 | mask = F.interpolate(region_mask[None, None], scale_factor=1/16, mode='bilinear').reshape([1, -1, 1]) 1135 | control_mask_ = mask.to(device=latents.device, dtype=latents.dtype) 1136 | control_mask_list.append(control_mask_) 1137 | 1138 | # 6. 
Denoising loop
1139 | with self.progress_bar(total=num_inference_steps) as progress_bar:
1140 | for i, t in enumerate(timesteps):
1141 | 
1142 | if self.interrupt:
1143 | continue
1144 | 
1145 | latent_model_input = latents
1146 | 
1147 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1148 | timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
1149 | 
1150 | # handle guidance
1151 | if self.transformer.config.guidance_embeds:
1152 | guidance = torch.tensor([guidance_scale], device=device)
1153 | guidance = guidance.expand(latent_model_input.shape[0])
1154 | else:
1155 | guidance = None
1156 | 
1157 | control_block_samples = None
1158 | control_single_block_samples = None
1159 | 
1160 | for control_index in range(len(control_image_list)):
1161 | 
1162 | control_image = control_image_list[control_index]
1163 | control_mask = control_mask_list[control_index] if len(control_mask_list) > 0 else None
1164 | 
1165 | if i < controlnet_conditioning_step:
1166 | # text ControlNet (glyph canny + position conditions), only for the first controlnet_conditioning_step steps
1167 | controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1168 | hidden_states=latent_model_input,
1169 | controlnet_cond=control_image,
1170 | controlnet_mode=control_mode,
1171 | conditioning_scale=controlnet_conditioning_scale,
1172 | timestep=timestep / 1000,
1173 | guidance=guidance,
1174 | pooled_projections=pooled_prompt_embeds,
1175 | encoder_hidden_states=prompt_embeds,
1176 | txt_ids=text_ids,
1177 | img_ids=latent_image_ids,
1178 | joint_attention_kwargs=self.joint_attention_kwargs,
1179 | return_dict=False,
1180 | )
1181 | else:
1182 | controlnet_block_samples, controlnet_single_block_samples = None, None
1183 | 
1184 | if controlnet_block_samples is not None:
1185 | if control_mask is not None:  # restrict feature injection to the text region
1186 | controlnet_block_samples = [control_mask*sample.to(dtype=latents.dtype, device=device) for sample in controlnet_block_samples]
1187 | else:
1188 | controlnet_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_block_samples]
1189 | if controlnet_single_block_samples is not None:
1190 | if control_mask is not None:
1191 | controlnet_single_block_samples = [control_mask*sample.to(dtype=latents.dtype, device=device) for sample in controlnet_single_block_samples]
1192 | else:
1193 | controlnet_single_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_single_block_samples]
1194 | 
1195 | # merge samples across text lines
1196 | if control_index == 0:
1197 | control_block_samples = controlnet_block_samples
1198 | control_single_block_samples = controlnet_single_block_samples
1199 | else:
1200 | if controlnet_block_samples is not None and control_block_samples is not None:
1201 | control_block_samples = [
1202 | control_block_sample + block_sample
1203 | for control_block_sample, block_sample in zip(control_block_samples, controlnet_block_samples)
1204 | ]
1205 | if controlnet_single_block_samples is not None and control_single_block_samples is not None:
1206 | control_single_block_samples = [
1207 | control_single_block_sample + block_sample
1208 | for control_single_block_sample, block_sample in zip(
1209 | control_single_block_samples, controlnet_single_block_samples
1210 | )
1211 | ]
1212 | 
1213 | # handle additional conditions (inpaint ControlNet)
1214 | controlnet_block_samples, controlnet_single_block_samples = self.controlnet_inpaint(
1215 | hidden_states=latent_model_input,
1216 | controlnet_cond=control_image_inpaint,
1217 | controlnet_mode=control_mode,
1218 | conditioning_scale=controlnet_conditioning_scale_inpaint,
1219 | 
timestep=timestep / 1000, 1220 | guidance=guidance, 1221 | pooled_projections=pooled_prompt_embeds, 1222 | encoder_hidden_states=prompt_embeds, 1223 | txt_ids=text_ids, 1224 | img_ids=latent_image_ids, 1225 | joint_attention_kwargs=self.joint_attention_kwargs, 1226 | return_dict=False, 1227 | ) 1228 | 1229 | if controlnet_block_samples is not None: 1230 | controlnet_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_block_samples] 1231 | if controlnet_single_block_samples is not None: 1232 | controlnet_single_block_samples = [sample.to(dtype=latents.dtype, device=device) for sample in controlnet_single_block_samples] 1233 | 1234 | if controlnet_block_samples is not None and control_block_samples is not None: 1235 | control_block_samples = [ 1236 | control_block_sample + block_sample 1237 | for control_block_sample, block_sample in zip(control_block_samples, controlnet_block_samples) 1238 | ] 1239 | if controlnet_single_block_samples is not None and control_single_block_samples is not None: 1240 | control_single_block_samples = [ 1241 | control_single_block_sample + block_sample 1242 | for control_single_block_sample, block_sample in zip( 1243 | control_single_block_samples, controlnet_single_block_samples 1244 | ) 1245 | ] 1246 | 1247 | controlnet_block_samples = control_block_samples 1248 | controlnet_single_block_samples = control_single_block_samples 1249 | 1250 | noise_pred = self.transformer( 1251 | hidden_states=latent_model_input, 1252 | timestep=timestep / 1000, 1253 | guidance=guidance, 1254 | pooled_projections=pooled_prompt_embeds, 1255 | encoder_hidden_states=prompt_embeds, 1256 | controlnet_block_samples=controlnet_block_samples, 1257 | controlnet_single_block_samples=controlnet_single_block_samples, 1258 | txt_ids=text_ids, 1259 | img_ids=latent_image_ids, 1260 | joint_attention_kwargs=self.joint_attention_kwargs, 1261 | return_dict=False, 1262 | )[0] 1263 | 1264 | if self.do_classifier_free_guidance: 1265 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1266 | if i > 0: 1267 | noise_pred = noise_pred_uncond + true_guidance_scale * (noise_pred_text - noise_pred_uncond) 1268 | else: 1269 | #noise_pred = noise_pred_uncond + 1.5 * (noise_pred_text - noise_pred_uncond) 1270 | noise_pred = noise_pred_text*0. 1271 | 1272 | # compute the previous noisy sample x_t -> x_t-1 1273 | latents_dtype = latents.dtype 1274 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 1275 | 1276 | if latents.dtype != latents_dtype: 1277 | if torch.backends.mps.is_available(): 1278 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 1279 | latents = latents.to(latents_dtype) 1280 | 1281 | if callback_on_step_end is not None: 1282 | callback_kwargs = {} 1283 | for k in callback_on_step_end_tensor_inputs: 1284 | callback_kwargs[k] = locals()[k] 1285 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 1286 | 1287 | latents = callback_outputs.pop("latents", latents) 1288 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 1289 | 1290 | # call the callback, if provided 1291 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1292 | progress_bar.update() 1293 | 1294 | if XLA_AVAILABLE: 1295 | xm.mark_step() 1296 | 1297 | if output_type == "latent": 1298 | image = latents 1299 | 1300 | else: 1301 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 1302 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 1303 | 1304 | image = self.vae.decode(latents, return_dict=False)[0] 1305 | image = self.image_processor.postprocess(image, output_type=output_type) 1306 | 1307 | # Offload all models 1308 | self.maybe_free_model_hooks() 1309 | 1310 | if not return_dict: 1311 | return (image,) 1312 | 1313 | return FluxPipelineOutput(images=image) 1314 | -------------------------------------------------------------------------------- /results/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/results/result.jpg -------------------------------------------------------------------------------- /results/result_inpaint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shakker-Labs/RepText/2080e4c0965cb00d7a1506da5469692b2d854d50/results/result_inpaint.jpg --------------------------------------------------------------------------------