├── src
│   ├── __init__.py
│   ├── prompt_helper.py
│   ├── jsonl_datasets.py
│   └── pipeline_pe_clone.py
├── .gitignore
├── assets
│   ├── teaser.png
│   ├── result_show.png
│   ├── close-eye-src1.jpg
│   ├── close-on-cond1.jpg
│   ├── close-on-cond2.jpg
│   ├── close-on-src1.jpg
│   ├── close-eye-cond1.jpg
│   ├── close-eye-cond2.jpg
│   ├── close-eye-target.jpg
│   └── close-on-target.jpg
├── requirements.txt
├── README.md
├── attention_processor.py
├── infer_single.py
├── LICENSE
├── transformer_flux.py
└── pipeline_flux_ipa.py
/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.pyo
4 | *.pyd
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/result_show.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/result_show.png
--------------------------------------------------------------------------------
/assets/close-eye-src1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-eye-src1.jpg
--------------------------------------------------------------------------------
/assets/close-on-cond1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-on-cond1.jpg
--------------------------------------------------------------------------------
/assets/close-on-cond2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-on-cond2.jpg
--------------------------------------------------------------------------------
/assets/close-on-src1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-on-src1.jpg
--------------------------------------------------------------------------------
/assets/close-eye-cond1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-eye-cond1.jpg
--------------------------------------------------------------------------------
/assets/close-eye-cond2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-eye-cond2.jpg
--------------------------------------------------------------------------------
/assets/close-eye-target.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-eye-target.jpg
--------------------------------------------------------------------------------
/assets/close-on-target.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gy8888/RelationAdapter/HEAD/assets/close-on-target.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==3.5.0 2 | diffusers==0.32.2 3 | numpy==1.26.4 4 | Pillow==11.2.1 5 | tqdm==4.67.1 6 | transformers==4.44.0 7 | accelerate==0.33.0 8 | protobuf==5.29.4 9 | sentencepiece==0.2.0 10 | peft==0.15.1 11 | 12 | # torch version 13 | # torch==2.5.1+cu124 14 | # torchvision==0.20.1+cu124 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RelationAdapter 2 | 3 | > **RelationAdapter: Learning and Transferring Visual Relation with Diffusion Transformers** 4 | >
5 | > Yan Gong, 6 | > Yiren Song, 7 | > Yicheng Li, 8 | > Chenglin Li, 9 | > and 10 | > Yin Zhang 11 | >
12 | > Zhejiang University, National University of Singapore 13 | >
14 | 
15 | [arXiv](https://arxiv.org/abs/2506.02528)
16 | [HuggingFace Model](https://huggingface.co/handsomeWilliam/RelationAdapter)
17 | [HuggingFace Dataset](https://huggingface.co/datasets/handsomeWilliam/Relation252K)
18 | 
19 | 
20 | 
21 | 
22 | 
23 | ## Quick Start
24 | ### 1. Configuration
25 | #### 1. **Environment setup**
26 | ```bash
27 | git clone git@github.com:gy8888/RelationAdapter.git
28 | cd RelationAdapter
29 | 
30 | conda create -n RelationAdapter python=3.11.10
31 | conda activate RelationAdapter
32 | ```
33 | #### 2. **Requirements installation**
34 | ```bash
35 | pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
36 | pip install --upgrade -r requirements.txt
37 | ```
38 | 
39 | 
40 | ### 2. Inference
41 | We provide an integration of our model with the `FluxPipeline` pipeline and have uploaded the model weights to Hugging Face, so the model is easy to use, as in the example below.
42 | 
43 | Simply run the inference script:
44 | ```
45 | python infer_single.py
46 | ```
47 | 
48 | 
49 | ### 3. Weights
50 | You can download the trained RelationAdapter and LoRA checkpoints for inference. Below are the details of the available models.
51 | 
52 | You need to load the `RelationAdapter` checkpoint in order to fuse the `LoRA` checkpoint (see `infer_single.py` for how both are loaded).
53 | 
54 | | **Model** | **Description** |
55 | | :----------------------------------------------------------: | :---------------------------------------------------------: |
56 | | [RelationAdapter](https://huggingface.co/handsomeWilliam/RelationAdapter/blob/main/ip_adapter-100000.bin) | Additional parameters of the RelationAdapter module, trained on the `Relation252K` dataset |
57 | | [LoRA](https://huggingface.co/handsomeWilliam/RelationAdapter/blob/main/pytorch_lora_weights.safetensors) | LoRA parameters trained on the `Relation252K` dataset |
58 | 
59 | 
60 | ### 4. Dataset
61 | 
62 | #### 4.1 Paired Dataset Format
63 | The paired dataset is stored in a `.jsonl` file, where each entry contains image file paths and the corresponding text descriptions: a source caption, a target caption, and an edit instruction describing the transformation from the source image to the target image.
64 | 
65 | Example format:
66 | 
67 | ```json
68 | {
69 |     "left_image_description": "Description of the left image",
70 |     "right_image_description": "Description of the right image",
71 |     "edit_instruction": "Instructions for the desired modifications",
72 |     "img_name": "path/to/image_pair.jpg"
73 | },
74 | {
75 |     "left_image_description": "Description of the left image2",
76 |     "right_image_description": "Description of the right image2",
77 |     "edit_instruction": "Another instruction",
78 |     "img_name": "path/to/image_pair2.jpg"
79 | }
80 | ```
81 | We have uploaded our dataset to [Hugging Face](https://huggingface.co/datasets/handsomeWilliam/Relation252K).
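For reference, these entries can be read with the Hugging Face `datasets` JSON loader, which is the same loader used in `src/jsonl_datasets.py`. Below is a minimal sketch, assuming a local copy of the paired file saved as `pairs.jsonl` (the file name is only a placeholder) with the fields shown above:

```python
from datasets import load_dataset

# Load the paired .jsonl file; each line is one JSON entry.
# "pairs.jsonl" is a placeholder path — point it at your copy of the data.
dataset = load_dataset("json", data_files="pairs.jsonl", split="train")

# Inspect the first entry: the two captions, the edit instruction, and the image path.
record = dataset[0]
print(record["left_image_description"])
print(record["right_image_description"])
print(record["edit_instruction"])
print(record["img_name"])
```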
82 | 
83 | #### 4.2 Run-Ready Dataset Generation
84 | To prepare the dataset for relational learning tasks such as analogy-based instruction scenarios, use the provided script:
85 | ```
86 | python dataset-All-2000-turn-5test.py
87 | ```
88 | 
89 | This script takes the original paired image dataset and converts it into a structured format in which each entry includes a prompt image (`cond1`), a reference image (`cond2`), a source image, a target image, and a text instruction.
90 | Example format:
91 | 
92 | ```json
93 | {
94 |     "cond1": "path/to/prompt_image.jpg",
95 |     "cond2": "path/to/reference_image.jpg",
96 |     "source": "path/to/source_image.jpg",
97 |     "target": "path/to/target_image.jpg",
98 |     "text": "Instruction for the intended modifications"
99 | },
100 | {
101 |     "cond1": "path/to/prompt_image2.jpg",
102 |     "cond2": "path/to/reference_image2.jpg",
103 |     "source": "path/to/source_image2.jpg",
104 |     "target": "path/to/target_image2.jpg",
105 |     "text": "Instruction for the second modification"
106 | }
107 | ```
108 | 
109 | ### 5. Results
110 | 
111 | ![S-U](./assets/result_show.png)
112 | 
113 | 
114 | ## Citation
115 | ```
116 | @misc{gong2025relationadapterlearningtransferringvisual,
117 |       title={RelationAdapter: Learning and Transferring Visual Relation with Diffusion Transformers},
118 |       author={Yan Gong and Yiren Song and Yicheng Li and Chenglin Li and Yin Zhang},
119 |       year={2025},
120 |       eprint={2506.02528},
121 |       archivePrefix={arXiv},
122 |       primaryClass={cs.CV},
123 |       url={https://arxiv.org/abs/2506.02528},
124 | }
125 | ```
--------------------------------------------------------------------------------
/attention_processor.py:
--------------------------------------------------------------------------------
1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from diffusers.models.normalization import RMSNorm 5 | from typing import Callable, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | from PIL import Image 9 | 10 | class IPAFluxAttnProcessor2_0(nn.Module): 11 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 12 | 13 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 14 | super().__init__() 15 | 16 | self.hidden_size = hidden_size # 3072 17 | self.cross_attention_dim = cross_attention_dim # 4096 18 | self.scale = scale 19 | self.num_tokens = num_tokens 20 | 21 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 22 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 23 | 24 | self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False) 25 | 26 | def __call__( 27 | self, 28 | attn, 29 | hidden_states: torch.FloatTensor, 30 | image_emb: torch.FloatTensor, 31 | encoder_hidden_states: torch.FloatTensor = None, 32 | attention_mask: Optional[torch.FloatTensor] = None, 33 | image_rotary_emb: Optional[torch.Tensor] = None, 34 | ) -> torch.FloatTensor: 35 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 36 | 37 | # `sample` projections. 
38 | query = attn.to_q(hidden_states) 39 | key = attn.to_k(hidden_states) 40 | value = attn.to_v(hidden_states) 41 | 42 | inner_dim = key.shape[-1] 43 | head_dim = inner_dim // attn.heads 44 | 45 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # torch.Size([1, 24, 4800, 128]) 46 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 47 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 48 | 49 | if attn.norm_q is not None: 50 | query = attn.norm_q(query) 51 | if attn.norm_k is not None: 52 | key = attn.norm_k(key) 53 | 54 | if image_emb is not None: 55 | # `ip-adapter` projections 56 | ip_hidden_states = image_emb 57 | ip_hidden_states_key_proj = self.to_k_ip(ip_hidden_states) 58 | ip_hidden_states_value_proj = self.to_v_ip(ip_hidden_states) 59 | 60 | ip_hidden_states_key_proj = ip_hidden_states_key_proj.view( 61 | batch_size, -1, attn.heads, head_dim 62 | ).transpose(1, 2) 63 | ip_hidden_states_value_proj = ip_hidden_states_value_proj.view( 64 | batch_size, -1, attn.heads, head_dim 65 | ).transpose(1, 2) 66 | 67 | ip_hidden_states_key_proj = self.norm_added_k(ip_hidden_states_key_proj) 68 | 69 | ip_hidden_states = F.scaled_dot_product_attention(query, 70 | ip_hidden_states_key_proj, 71 | ip_hidden_states_value_proj, 72 | dropout_p=0.0, is_causal=False) 73 | 74 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 75 | ip_hidden_states = ip_hidden_states.to(query.dtype) 76 | 77 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 78 | if encoder_hidden_states is not None: 79 | 80 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 81 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 82 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 83 | 84 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 85 | batch_size, -1, attn.heads, head_dim 86 | ).transpose(1, 2) 87 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 88 | batch_size, -1, attn.heads, head_dim 89 | ).transpose(1, 2) 90 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 91 | batch_size, -1, attn.heads, head_dim 92 | ).transpose(1, 2) 93 | 94 | if attn.norm_added_q is not None: 95 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 96 | if attn.norm_added_k is not None: 97 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 98 | 99 | # attention 100 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 101 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 102 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) # (512+3840,128) 103 | 104 | if image_rotary_emb is not None: 105 | from diffusers.models.embeddings import apply_rotary_emb 106 | 107 | query = apply_rotary_emb(query, image_rotary_emb) 108 | key = apply_rotary_emb(key, image_rotary_emb) 109 | 110 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 111 | 112 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 113 | hidden_states = hidden_states.to(query.dtype) 114 | 115 | if encoder_hidden_states is not None: 116 | 117 | encoder_hidden_states, hidden_states = ( 118 | hidden_states[:, : encoder_hidden_states.shape[1]], 119 | hidden_states[:, encoder_hidden_states.shape[1] :], 120 | ) 121 
| if image_emb is not None: 122 | hidden_states = hidden_states + self.scale * ip_hidden_states 123 | 124 | # linear proj 125 | hidden_states = attn.to_out[0](hidden_states) 126 | # dropout 127 | hidden_states = attn.to_out[1](hidden_states) 128 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 129 | 130 | return hidden_states, encoder_hidden_states 131 | else: 132 | if image_emb is not None: 133 | hidden_states = hidden_states + self.scale * ip_hidden_states 134 | 135 | return hidden_states -------------------------------------------------------------------------------- /src/prompt_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def load_text_encoders(args, class_one, class_two): 5 | text_encoder_one = class_one.from_pretrained( 6 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 7 | ) 8 | text_encoder_two = class_two.from_pretrained( 9 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant 10 | ) 11 | return text_encoder_one, text_encoder_two 12 | 13 | 14 | def tokenize_prompt(tokenizer, prompt, max_sequence_length): 15 | text_inputs = tokenizer( 16 | prompt, 17 | padding="max_length", 18 | max_length=max_sequence_length, 19 | truncation=True, 20 | return_length=False, 21 | return_overflowing_tokens=False, 22 | return_tensors="pt", 23 | ) 24 | text_input_ids = text_inputs.input_ids 25 | return text_input_ids 26 | 27 | 28 | def tokenize_prompt_clip(tokenizer, prompt): 29 | text_inputs = tokenizer( 30 | prompt, 31 | padding="max_length", 32 | max_length=77, 33 | truncation=True, 34 | return_length=False, 35 | return_overflowing_tokens=False, 36 | return_tensors="pt", 37 | ) 38 | text_input_ids = text_inputs.input_ids 39 | return text_input_ids 40 | 41 | 42 | def tokenize_prompt_t5(tokenizer, prompt): 43 | text_inputs = tokenizer( 44 | prompt, 45 | padding="max_length", 46 | max_length=512, 47 | truncation=True, 48 | return_length=False, 49 | return_overflowing_tokens=False, 50 | return_tensors="pt", 51 | ) 52 | text_input_ids = text_inputs.input_ids 53 | return text_input_ids 54 | 55 | 56 | def _encode_prompt_with_t5( 57 | text_encoder, 58 | tokenizer, 59 | max_sequence_length=512, 60 | prompt=None, 61 | num_images_per_prompt=1, 62 | device=None, 63 | text_input_ids=None, 64 | ): 65 | prompt = [prompt] if isinstance(prompt, str) else prompt 66 | batch_size = len(prompt) 67 | 68 | if tokenizer is not None: 69 | text_inputs = tokenizer( 70 | prompt, 71 | padding="max_length", 72 | max_length=max_sequence_length, 73 | truncation=True, 74 | return_length=False, 75 | return_overflowing_tokens=False, 76 | return_tensors="pt", 77 | ) 78 | text_input_ids = text_inputs.input_ids 79 | else: 80 | if text_input_ids is None: 81 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 82 | 83 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 84 | 85 | dtype = text_encoder.dtype 86 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 87 | 88 | _, seq_len, _ = prompt_embeds.shape 89 | 90 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 91 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 92 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 93 | 94 | return prompt_embeds 95 | 96 | 97 | def _encode_prompt_with_clip( 98 | text_encoder, 99 | 
tokenizer, 100 | prompt: str, 101 | device=None, 102 | text_input_ids=None, 103 | num_images_per_prompt: int = 1, 104 | ): 105 | prompt = [prompt] if isinstance(prompt, str) else prompt 106 | batch_size = len(prompt) 107 | 108 | if tokenizer is not None: 109 | text_inputs = tokenizer( 110 | prompt, 111 | padding="max_length", 112 | max_length=77, 113 | truncation=True, 114 | return_overflowing_tokens=False, 115 | return_length=False, 116 | return_tensors="pt", 117 | ) 118 | 119 | text_input_ids = text_inputs.input_ids 120 | else: 121 | if text_input_ids is None: 122 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 123 | 124 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 125 | 126 | # Use pooled output of CLIPTextModel 127 | prompt_embeds = prompt_embeds.pooler_output 128 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 129 | 130 | # duplicate text embeddings for each generation per prompt, using mps friendly method 131 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 132 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 133 | 134 | return prompt_embeds 135 | 136 | 137 | def encode_prompt( 138 | text_encoders, 139 | tokenizers, 140 | prompt: str, 141 | max_sequence_length, 142 | device=None, 143 | num_images_per_prompt: int = 1, 144 | text_input_ids_list=None, 145 | ): 146 | prompt = [prompt] if isinstance(prompt, str) else prompt 147 | dtype = text_encoders[0].dtype 148 | 149 | pooled_prompt_embeds = _encode_prompt_with_clip( 150 | text_encoder=text_encoders[0], 151 | tokenizer=tokenizers[0], 152 | prompt=prompt, 153 | device=device if device is not None else text_encoders[0].device, 154 | num_images_per_prompt=num_images_per_prompt, 155 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, 156 | ) 157 | 158 | prompt_embeds = _encode_prompt_with_t5( 159 | text_encoder=text_encoders[1], 160 | tokenizer=tokenizers[1], 161 | max_sequence_length=max_sequence_length, 162 | prompt=prompt, 163 | num_images_per_prompt=num_images_per_prompt, 164 | device=device if device is not None else text_encoders[1].device, 165 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, 166 | ) 167 | 168 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 169 | 170 | return prompt_embeds, pooled_prompt_embeds, text_ids 171 | 172 | 173 | def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None): 174 | text_encoder_clip = text_encoders[0] 175 | text_encoder_t5 = text_encoders[1] 176 | tokens_clip, tokens_t5 = tokens[0], tokens[1] 177 | batch_size = tokens_clip.shape[0] 178 | 179 | if device == "cpu": 180 | device = "cpu" 181 | else: 182 | device = accelerator.device 183 | 184 | # clip 185 | prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False) 186 | # Use pooled output of CLIPTextModelpreprocess_train 187 | prompt_embeds = prompt_embeds.pooler_output 188 | prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device) 189 | # duplicate text embeddings for each generation per prompt, using mps friendly method 190 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 191 | pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 192 | pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device) 193 | 194 | # t5 195 | 
prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0] 196 | dtype = text_encoder_t5.dtype 197 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device) 198 | _, seq_len, _ = prompt_embeds.shape 199 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 200 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 201 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 202 | 203 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype) 204 | 205 | return prompt_embeds, pooled_prompt_embeds, text_ids -------------------------------------------------------------------------------- /src/jsonl_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from datasets import load_dataset 4 | from torchvision import transforms 5 | import random 6 | import os 7 | import numpy as np 8 | from transformers import AutoProcessor 9 | import torchvision.transforms.functional as F 10 | 11 | Image.MAX_IMAGE_PIXELS = None 12 | 13 | def make_train_dataset(args, tokenizer, accelerator=None): 14 | if args.train_data_dir is not None: 15 | print("loading dataset ... ") 16 | dataset = load_dataset('json', data_files=args.train_data_dir) 17 | base_path = os.path.dirname(os.path.abspath(args.train_data_dir)) 18 | 19 | column_names = dataset["train"].column_names 20 | 21 | # 6. Get the column names for input/target. 22 | if args.caption_column is None: 23 | caption_column = column_names[0] 24 | print(f"caption column defaulting to {caption_column}") 25 | else: 26 | caption_column = args.caption_column 27 | if caption_column not in column_names: 28 | raise ValueError( 29 | f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 30 | ) 31 | if args.source_column is None: 32 | source_column = column_names[1] 33 | print(f"source column defaulting to {source_column}") 34 | else: 35 | source_column = args.source_column 36 | if source_column not in column_names: 37 | raise ValueError( 38 | f"`--source_column` value '{args.source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 39 | ) 40 | if args.target_column is None: 41 | target_column = column_names[2] 42 | print(f"target column defaulting to {target_column}") 43 | else: 44 | target_column = args.target_column 45 | if target_column not in column_names: 46 | raise ValueError( 47 | f"`--target_column` value '{args.target_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 48 | ) 49 | #Add two images as columnname 50 | if args.ipa_source_column is None: 51 | ipa_source_column = column_names[3] 52 | print(f"ipa-source column defaulting to {ipa_source_column}") 53 | else: 54 | ipa_source_column = args.ipa_source_column 55 | if ipa_source_column not in column_names: 56 | raise ValueError( 57 | f"`--ipa_source_column` value '{args.ipa_source_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 58 | ) 59 | if args.ipa_target_column is None: 60 | ipa_target_column = column_names[4] 61 | print(f"ipa-target column defaulting to {ipa_target_column}") 62 | else: 63 | ipa_target_column = args.ipa_target_column 64 | if ipa_target_column not in column_names: 65 | raise ValueError( 66 | f"`--ipa_target_column` value '{args.ipa_target_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" 67 | ) 68 | 69 | def resize_long_side(img, target_long_side, interpolation=transforms.InterpolationMode.BILINEAR): 70 | w, h = img.size 71 | if w >= h: 72 | new_w = target_long_side 73 | new_h = int(target_long_side * h / w) 74 | else: 75 | new_h = target_long_side 76 | new_w = int(target_long_side * w / h) 77 | return F.resize(img, (new_h, new_w), interpolation=interpolation) 78 | 79 | train_transforms = transforms.Compose( 80 | [ 81 | transforms.Lambda(lambda img: resize_long_side(img, args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.5], [0.5]), 84 | ] 85 | ) 86 | 87 | tokenizer_clip = tokenizer[0] 88 | tokenizer_t5 = tokenizer[1] 89 | 90 | def tokenize_prompt_clip_t5(examples): 91 | captions = [] 92 | for caption in examples[caption_column]: 93 | if isinstance(caption, str): 94 | captions.append(caption) 95 | elif isinstance(caption, list): 96 | captions.append(random.choice(caption)) 97 | else: 98 | raise ValueError( 99 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 100 | ) 101 | text_inputs = tokenizer_clip( 102 | captions, 103 | padding="max_length", 104 | max_length=77, 105 | truncation=True, 106 | return_length=False, 107 | return_overflowing_tokens=False, 108 | return_tensors="pt", 109 | ) 110 | text_input_ids_1 = text_inputs.input_ids 111 | 112 | text_inputs = tokenizer_t5( 113 | captions, 114 | padding="max_length", 115 | max_length=512, 116 | truncation=True, 117 | return_length=False, 118 | return_overflowing_tokens=False, 119 | return_tensors="pt", 120 | ) 121 | text_input_ids_2 = text_inputs.input_ids 122 | return text_input_ids_1, text_input_ids_2 123 | 124 | def preprocess_train(examples): 125 | _examples = {} 126 | 127 | source_images = [Image.open(os.path.join(base_path, image)).convert("RGB") 128 | for image in examples[source_column]] 129 | target_images = [Image.open(os.path.join(base_path, image)).convert("RGB") 130 | for image in examples[target_column]] 131 | 132 | _examples["cond_pixel_values"] = [train_transforms(source) for source in source_images] 133 | _examples["pixel_values"] = [train_transforms(image) for image in target_images] 134 | _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(examples) 135 | 136 | # clip pre-processor 137 | # ———————————————————————————————————————————————————————— 138 | clip_image_processor = AutoProcessor.from_pretrained('../models/siglip-so400m-patch14-384') 139 | ipa_source_images = [Image.open(os.path.join(base_path, image)).convert("RGB") 140 | for image in examples[ipa_source_column]] 141 | ipa_target_images = [Image.open(os.path.join(base_path, image)).convert("RGB") 142 | for image in examples[ipa_target_column]] 143 | _examples["ipa_source_images"] = [clip_image_processor(images=source, return_tensors="pt").pixel_values for source in ipa_source_images] 144 | _examples["ipa_target_images"] = [clip_image_processor(images=image, return_tensors="pt").pixel_values for image in ipa_target_images] 145 | 146 | drop_image_embeds = [1 if random.random() < 0.05 else 0 for _ in examples[ipa_target_column]] 147 | _examples["drop_image_embeds"] = drop_image_embeds 148 | # print(f"ipa_source_images[0] shape: {_examples['ipa_source_images'][0].shape}") 149 | # print(f"ipa_target_images[0] shape: {_examples['ipa_target_images'][0].shape}") 150 | 151 | # ————————————————————————————————————————————————————————— 152 | return _examples 153 | 
154 | if accelerator is not None: 155 | with accelerator.main_process_first(): 156 | train_dataset = dataset["train"].with_transform(preprocess_train) 157 | else: 158 | train_dataset = dataset["train"].with_transform(preprocess_train) 159 | 160 | return train_dataset 161 | 162 | 163 | def collate_fn(examples): 164 | cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples]) 165 | cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() 166 | target_pixel_values = torch.stack([example["pixel_values"] for example in examples]) 167 | target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float() 168 | # token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples]) 169 | # token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples]) 170 | 171 | token_ids_clip = torch.stack([ 172 | example["token_ids_clip"].clone().detach() if isinstance(example["token_ids_clip"], torch.Tensor) 173 | else torch.tensor(example["token_ids_clip"]) 174 | for example in examples 175 | ]) 176 | token_ids_t5 = torch.stack([ 177 | example["token_ids_t5"].clone().detach() if isinstance(example["token_ids_t5"], torch.Tensor) 178 | else torch.tensor(example["token_ids_t5"]) 179 | for example in examples 180 | ]) 181 | 182 | # ———————————————————————————————————————————————————————————————————————————— 183 | ipa_source_pixel_values = torch.cat([example["ipa_source_images"] for example in examples]) 184 | ipa_source_pixel_values = ipa_source_pixel_values.to(memory_format=torch.contiguous_format).float() 185 | ipa_target_pixel_values = torch.cat([example["ipa_target_images"] for example in examples]) 186 | ipa_target_pixel_values = ipa_target_pixel_values.to(memory_format=torch.contiguous_format).float() 187 | drop_image_embeds = [example["drop_image_embeds"] for example in examples] 188 | return { 189 | "cond_pixel_values": cond_pixel_values, 190 | "pixel_values": target_pixel_values, 191 | "text_ids_1": token_ids_clip, 192 | "text_ids_2": token_ids_t5, 193 | "ipa_source_pixel_values": ipa_source_pixel_values, 194 | "ipa_target_pixel_values": ipa_target_pixel_values, 195 | "drop_image_embeds": drop_image_embeds 196 | } 197 | # ———————————————————————————————————————————————————————————————————————————— -------------------------------------------------------------------------------- /infer_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | import json 6 | from tqdm import tqdm # for displaying progress bar 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | # from pipeline_flux_ipa import FluxPipeline 12 | from src.pipeline_pe_clone import FluxPipeline 13 | from transformer_flux import FluxTransformer2DModel 14 | from attention_processor import IPAFluxAttnProcessor2_0 15 | from transformers import AutoProcessor, SiglipVisionModel 16 | 17 | # ======================== Image Resize Function =========================== 18 | def resize_img(input_image, pad_to_regular=False, target_long_side=512, mode=Image.BILINEAR): 19 | w, h = input_image.size 20 | aspect_ratios = [(3, 4), (4, 3), (1, 1), (16, 9), (9, 16)] 21 | 22 | if pad_to_regular: 23 | img_ratio = w / h 24 | 25 | # Find the aspect ratio closest to the original image 26 | best_ratio = min( 27 | aspect_ratios, 28 | key=lambda r: abs((r[0] / r[1]) - img_ratio) 29 | ) 30 | 31 | target_w_ratio, 
target_h_ratio = best_ratio 32 | if w / h >= target_w_ratio / target_h_ratio: 33 | target_w = w 34 | target_h = int(w * target_h_ratio / target_w_ratio) 35 | else: 36 | target_h = h 37 | target_w = int(h * target_w_ratio / target_h_ratio) 38 | 39 | # Create white background and paste the image centered 40 | padded_img = Image.new("RGB", (target_w, target_h), (255, 255, 255)) 41 | offset_x = (target_w - w) // 2 42 | offset_y = (target_h - h) // 2 43 | padded_img.paste(input_image, (offset_x, offset_y)) 44 | input_image = padded_img 45 | w, h = input_image.size 46 | 47 | # Resize while keeping aspect ratio 48 | scale_ratio = target_long_side / max(w, h) 49 | new_w = round(w * scale_ratio) 50 | new_h = round(h * scale_ratio) 51 | input_image = input_image.resize((new_w, new_h), mode) 52 | 53 | return input_image 54 | 55 | # ======================== MLP Projection Module =========================== 56 | class MLPProjModel(torch.nn.Module): 57 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): 58 | super().__init__() 59 | 60 | self.cross_attention_dim = cross_attention_dim 61 | self.num_tokens = num_tokens 62 | 63 | self.proj = torch.nn.Sequential( 64 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 65 | torch.nn.GELU(), 66 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 67 | ) 68 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 69 | 70 | def forward(self, id_embeds): 71 | x = self.proj(id_embeds) 72 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 73 | x = self.norm(x) 74 | return x 75 | 76 | # ======================== IPAdapter Wrapper =========================== 77 | class IPAdapter: 78 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): 79 | self.device = device 80 | self.image_encoder_path = image_encoder_path 81 | self.ip_ckpt = ip_ckpt 82 | self.num_tokens = num_tokens 83 | 84 | self.pipe = sd_pipe.to(self.device) 85 | self.set_ip_adapter() 86 | 87 | # Load image encoder 88 | self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path).to(self.device, dtype=torch.bfloat16) 89 | self.clip_image_processor = AutoProcessor.from_pretrained(self.image_encoder_path) 90 | 91 | # Initialize image projection model 92 | self.image_proj_model = self.init_proj() 93 | 94 | self.load_ip_adapter() 95 | 96 | def init_proj(self): 97 | image_proj_model = MLPProjModel( 98 | cross_attention_dim=self.pipe.transformer.config.joint_attention_dim, # 4096 99 | id_embeddings_dim=1152, 100 | num_tokens=self.num_tokens, 101 | ).to(self.device, dtype=torch.bfloat16) 102 | 103 | return image_proj_model 104 | 105 | def set_ip_adapter(self): 106 | transformer = self.pipe.transformer 107 | ip_attn_procs = {} # total 57 layers: 19 + 38 108 | for name in transformer.attn_processors.keys(): 109 | if name.startswith("transformer_blocks.") or name.startswith("single_transformer_blocks"): 110 | ip_attn_procs[name] = IPAFluxAttnProcessor2_0( 111 | hidden_size=transformer.config.num_attention_heads * transformer.config.attention_head_dim, 112 | cross_attention_dim=transformer.config.joint_attention_dim, 113 | num_tokens=self.num_tokens, 114 | ).to(self.device, dtype=torch.bfloat16) 115 | else: 116 | ip_attn_procs[name] = transformer.attn_processors[name] 117 | 118 | transformer.set_attn_processor(ip_attn_procs) 119 | 120 | def load_ip_adapter(self): 121 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 122 | 123 | # ---- 1. 
Load weights for image_proj_model ---- 124 | image_proj_state_dict = state_dict["image_proj"] 125 | if list(image_proj_state_dict.keys())[0].startswith("module."): 126 | # Remove DataParallel/DDP prefix 127 | image_proj_state_dict = { 128 | k.replace("module.", ""): v for k, v in image_proj_state_dict.items() 129 | } 130 | self.image_proj_model.load_state_dict(image_proj_state_dict, strict=True) 131 | 132 | # ---- 2. Load weights for ip_adapter (attn_processors) ---- 133 | ip_adapter_state_dict = state_dict["ip_adapter"] 134 | if list(ip_adapter_state_dict.keys())[0].startswith("module."): 135 | ip_adapter_state_dict = { 136 | k.replace("module.", ""): v for k, v in ip_adapter_state_dict.items() 137 | } 138 | 139 | ip_layers = torch.nn.ModuleList(self.pipe.transformer.attn_processors.values()) 140 | ip_layers.load_state_dict(ip_adapter_state_dict, strict=False) 141 | 142 | @torch.inference_mode() 143 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 144 | if pil_image is not None: 145 | if isinstance(pil_image, Image.Image): 146 | pil_image = [pil_image] 147 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 148 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output 149 | clip_image_embeds = clip_image_embeds.to(dtype=torch.bfloat16) 150 | else: 151 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16) 152 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 153 | return image_prompt_embeds 154 | 155 | def set_scale(self, scale): 156 | for attn_processor in self.pipe.transformer.attn_processors.values(): 157 | if isinstance(attn_processor, IPAFluxAttnProcessor2_0): 158 | attn_processor.scale = scale 159 | 160 | def generate( 161 | self, 162 | condition_image=None, 163 | pil_image=None, # supports list or tuple of two PIL images 164 | clip_image_embeds=None, 165 | prompt=None, 166 | scale=1.0, 167 | num_samples=1, 168 | seed=None, 169 | guidance_scale=3.5, 170 | num_inference_steps=24, 171 | **kwargs, 172 | ): 173 | self.set_scale(scale) 174 | 175 | # Support case with two input images 176 | if isinstance(pil_image, (list, tuple)) and len(pil_image) == 2: 177 | image_prompt_embeds1 = self.get_image_embeds(pil_image=pil_image[0]) 178 | image_prompt_embeds2 = self.get_image_embeds(pil_image=pil_image[1]) 179 | image_prompt_embeds = torch.cat([image_prompt_embeds1, image_prompt_embeds2], dim=0) 180 | else: 181 | image_prompt_embeds = self.get_image_embeds( 182 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 183 | ) 184 | 185 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 186 | 187 | images = self.pipe( 188 | prompt=prompt, 189 | condition_image=condition_image, 190 | image_emb=image_prompt_embeds, 191 | guidance_scale=guidance_scale, 192 | num_inference_steps=num_inference_steps, 193 | generator=generator, 194 | **kwargs, 195 | ).images 196 | 197 | return images 198 | 199 | # ======================== Parameter Setup =========================== 200 | BASE_MODEL_PATH = '../models/FLUX.1-dev' 201 | IMAGE_ENCODER_PATH = '../models/siglip-so400m-patch14-384' 202 | IPADAPTER_PATH = '../models/ip_adapter-100000.bin' 203 | LORA_WEIGHTS_PATH = "../models/checkpoint-100000-lora" 204 | LORA_WEIGHTS_FILE = "pytorch_lora_weights.safetensors" 205 | DEVICE = "cuda:0" 206 | 207 | # ========== Sample Image Paths ========== 208 | # cond1_path = "assets/close-eye-cond1.jpg" 209 | # cond2_path = 
"assets/close-eye-cond2.jpg" 210 | # source_path = "assets/close-eye-src1.jpg" 211 | # prompt = "Apply a closed-eyes expression to the person in the image." 212 | 213 | cond1_path = "assets/close-on-cond1.jpg" 214 | cond2_path = "assets/close-on-cond2.jpg" 215 | source_path = "assets/close-on-src1.jpg" 216 | prompt = "Add a model wearing the black pants along with a taupe polo shirt and white sneakers, standing in a neutral pose." 217 | 218 | # ======================== Model Setup =========================== 219 | transformer = FluxTransformer2DModel.from_pretrained( 220 | BASE_MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16 221 | ) 222 | pipe = FluxPipeline.from_pretrained( 223 | BASE_MODEL_PATH, transformer=transformer, torch_dtype=torch.bfloat16 224 | ) 225 | 226 | pipe.load_lora_weights(LORA_WEIGHTS_PATH, weight_name=LORA_WEIGHTS_FILE) 227 | pipe.fuse_lora() 228 | pipe.unload_lora_weights() 229 | print("LoRA weights loaded ✔️") 230 | 231 | ip_model = IPAdapter(pipe, IMAGE_ENCODER_PATH, IPADAPTER_PATH, device=DEVICE, num_tokens=128) 232 | print("IP-Adapter initialized ✔️") 233 | 234 | # ======================== Image Preparation =========================== 235 | image1 = Image.open(cond1_path).convert("RGB") 236 | image2 = Image.open(cond2_path).convert("RGB") 237 | image1 = resize_img(image1, pad_to_regular=True, target_long_side=512) 238 | image2 = resize_img(image2, pad_to_regular=True, target_long_side=512) 239 | 240 | condition_image = Image.open(source_path).convert("RGB") 241 | width, height = condition_image.size 242 | 243 | # ======================== Inference =========================== 244 | generated_images = ip_model.generate( 245 | prompt=prompt, 246 | condition_image=condition_image, 247 | height=height, 248 | width=width, 249 | pil_image=[image1, image2], 250 | scale=1.0, 251 | seed=1000, 252 | ) 253 | 254 | # ======================== Show Result =========================== 255 | generated_images[0].save("generated.jpg") 256 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /transformer_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, Optional, Tuple, Union 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 24 | from diffusers.models.attention import FeedForward 25 | from diffusers.models.attention_processor import ( 26 | Attention, 27 | AttentionProcessor, 28 | FluxAttnProcessor2_0, 29 | FusedFluxAttnProcessor2_0, 30 | ) 31 | from diffusers.models.modeling_utils import ModelMixin 32 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 33 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 34 | from diffusers.utils.torch_utils import maybe_allow_in_graph 35 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 36 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 37 | 38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 39 | 40 | 41 | @maybe_allow_in_graph 42 | class FluxSingleTransformerBlock(nn.Module): 43 | r""" 44 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 45 | 46 | Reference: https://arxiv.org/abs/2403.03206 47 | 48 | Parameters: 49 | dim (`int`): The number of channels in the input and output. 50 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 51 | attention_head_dim (`int`): The number of channels in each head. 52 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 53 | processing of `context` conditions. 
54 | """ 55 | 56 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 57 | super().__init__() 58 | self.mlp_hidden_dim = int(dim * mlp_ratio) 59 | 60 | self.norm = AdaLayerNormZeroSingle(dim) 61 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 62 | self.act_mlp = nn.GELU(approximate="tanh") 63 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 64 | 65 | processor = FluxAttnProcessor2_0() 66 | self.attn = Attention( 67 | query_dim=dim, 68 | cross_attention_dim=None, 69 | dim_head=attention_head_dim, 70 | heads=num_attention_heads, 71 | out_dim=dim, 72 | bias=True, 73 | processor=processor, 74 | qk_norm="rms_norm", 75 | eps=1e-6, 76 | pre_only=True, 77 | ) 78 | 79 | def forward( 80 | self, 81 | hidden_states: torch.FloatTensor, 82 | temb: torch.FloatTensor, 83 | image_emb=None, 84 | image_rotary_emb=None, 85 | ): 86 | residual = hidden_states 87 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 88 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 89 | 90 | attn_output = self.attn( 91 | hidden_states=norm_hidden_states, 92 | image_rotary_emb=image_rotary_emb, 93 | image_emb=image_emb, 94 | ) 95 | 96 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 97 | gate = gate.unsqueeze(1) # torch.Size([1, 1, 3072]) 98 | hidden_states = gate * self.proj_out(hidden_states) # torch.Size([1, 4352, 3072]) 99 | 100 | hidden_states = residual + hidden_states 101 | if hidden_states.dtype == torch.float16: 102 | hidden_states = hidden_states.clip(-65504, 65504) 103 | 104 | return hidden_states 105 | 106 | 107 | @maybe_allow_in_graph 108 | class FluxTransformerBlock(nn.Module): 109 | r""" 110 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 111 | 112 | Reference: https://arxiv.org/abs/2403.03206 113 | 114 | Parameters: 115 | dim (`int`): The number of channels in the input and output. 116 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 117 | attention_head_dim (`int`): The number of channels in each head. 118 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 119 | processing of `context` conditions. 120 | """ 121 | 122 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): 123 | super().__init__() 124 | 125 | self.norm1 = AdaLayerNormZero(dim) 126 | 127 | self.norm1_context = AdaLayerNormZero(dim) 128 | 129 | if hasattr(F, "scaled_dot_product_attention"): 130 | processor = FluxAttnProcessor2_0() 131 | else: 132 | raise ValueError( 133 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
134 | ) 135 | self.attn = Attention( 136 | query_dim=dim, 137 | cross_attention_dim=None, 138 | added_kv_proj_dim=dim, 139 | dim_head=attention_head_dim, 140 | heads=num_attention_heads, 141 | out_dim=dim, 142 | context_pre_only=False, 143 | bias=True, 144 | processor=processor, 145 | qk_norm=qk_norm, 146 | eps=eps, 147 | ) 148 | 149 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 150 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 151 | 152 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 153 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 154 | 155 | # let chunk size default to None 156 | self._chunk_size = None 157 | self._chunk_dim = 0 158 | 159 | def forward( 160 | self, 161 | hidden_states: torch.FloatTensor, 162 | encoder_hidden_states: torch.FloatTensor, 163 | temb: torch.FloatTensor, 164 | image_emb=None, 165 | image_rotary_emb=None, 166 | ): 167 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 168 | 169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 170 | encoder_hidden_states, emb=temb 171 | ) 172 | 173 | # Attention. 174 | attn_output, context_attn_output = self.attn( 175 | hidden_states=norm_hidden_states, 176 | encoder_hidden_states=norm_encoder_hidden_states, 177 | image_rotary_emb=image_rotary_emb, 178 | image_emb=image_emb, 179 | ) 180 | 181 | # Process attention outputs for the `hidden_states`. 182 | attn_output = gate_msa.unsqueeze(1) * attn_output 183 | hidden_states = hidden_states + attn_output 184 | 185 | norm_hidden_states = self.norm2(hidden_states) 186 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 187 | 188 | ff_output = self.ff(norm_hidden_states) 189 | ff_output = gate_mlp.unsqueeze(1) * ff_output 190 | hidden_states = hidden_states + ff_output 191 | 192 | # Process attention outputs for the `encoder_hidden_states`. 193 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 194 | encoder_hidden_states = encoder_hidden_states + context_attn_output 195 | 196 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 197 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 198 | 199 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 200 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 201 | if encoder_hidden_states.dtype == torch.float16: 202 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 203 | 204 | return encoder_hidden_states, hidden_states 205 | 206 | 207 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 208 | """ 209 | The Transformer model introduced in Flux. 210 | 211 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 212 | 213 | Parameters: 214 | patch_size (`int`): Patch size to turn the input data into small patches. 215 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 216 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 217 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 218 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
219 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 220 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 221 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 222 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 223 | """ 224 | 225 | _supports_gradient_checkpointing = True 226 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 227 | 228 | @register_to_config 229 | def __init__( 230 | self, 231 | patch_size: int = 1, 232 | in_channels: int = 64, 233 | num_layers: int = 19, 234 | num_single_layers: int = 38, 235 | attention_head_dim: int = 128, 236 | num_attention_heads: int = 24, 237 | joint_attention_dim: int = 4096, 238 | pooled_projection_dim: int = 768, 239 | guidance_embeds: bool = False, 240 | axes_dims_rope: Tuple[int] = (16, 56, 56), 241 | ): 242 | super().__init__() 243 | self.out_channels = in_channels 244 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 245 | 246 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 247 | 248 | text_time_guidance_cls = ( 249 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 250 | ) 251 | self.time_text_embed = text_time_guidance_cls( 252 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 253 | ) 254 | 255 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 256 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 257 | 258 | self.transformer_blocks = nn.ModuleList( 259 | [ 260 | FluxTransformerBlock( 261 | dim=self.inner_dim, 262 | num_attention_heads=self.config.num_attention_heads, 263 | attention_head_dim=self.config.attention_head_dim, 264 | ) 265 | for i in range(self.config.num_layers) 266 | ] 267 | ) 268 | 269 | self.single_transformer_blocks = nn.ModuleList( 270 | [ 271 | FluxSingleTransformerBlock( 272 | dim=self.inner_dim, 273 | num_attention_heads=self.config.num_attention_heads, 274 | attention_head_dim=self.config.attention_head_dim, 275 | ) 276 | for i in range(self.config.num_single_layers) 277 | ] 278 | ) 279 | 280 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 281 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 282 | 283 | self.gradient_checkpointing = False 284 | 285 | @property 286 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 287 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 288 | r""" 289 | Returns: 290 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 291 | indexed by its weight name. 
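Illustrative usage (the tiny one-layer configuration below is only for demonstration; the default model has 19 + 38 layers): the returned keys mirror the module paths and can be fed straight back into `set_attn_processor`, either as a single processor instance or as a dict with the same keys.

```py
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from transformer_flux import FluxTransformer2DModel  # the class defined in this file

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)  # small config, illustration only
procs = model.attn_processors
print(sorted(procs))
# ['single_transformer_blocks.0.attn.processor', 'transformer_blocks.0.attn.processor']

# Replace every attention processor with a fresh instance (a dict keyed as above also works).
model.set_attn_processor(FluxAttnProcessor2_0())
```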
292 | """ 293 | # set recursively 294 | processors = {} 295 | 296 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 297 | if hasattr(module, "get_processor"): 298 | processors[f"{name}.processor"] = module.get_processor() 299 | 300 | for sub_name, child in module.named_children(): 301 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 302 | 303 | return processors 304 | 305 | for name, module in self.named_children(): 306 | fn_recursive_add_processors(name, module, processors) 307 | 308 | return processors 309 | 310 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 311 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 312 | r""" 313 | Sets the attention processor to use to compute attention. 314 | 315 | Parameters: 316 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 317 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 318 | for **all** `Attention` layers. 319 | 320 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 321 | processor. This is strongly recommended when setting trainable attention processors. 322 | 323 | """ 324 | count = len(self.attn_processors.keys()) 325 | 326 | if isinstance(processor, dict) and len(processor) != count: 327 | raise ValueError( 328 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 329 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 330 | ) 331 | 332 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 333 | if hasattr(module, "set_processor"): 334 | if not isinstance(processor, dict): 335 | module.set_processor(processor) 336 | else: 337 | module.set_processor(processor.pop(f"{name}.processor")) 338 | 339 | for sub_name, child in module.named_children(): 340 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 341 | 342 | for name, module in self.named_children(): 343 | fn_recursive_attn_processor(name, module, processor) 344 | 345 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 346 | def fuse_qkv_projections(self): 347 | """ 348 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 349 | are fused. For cross-attention modules, key and value projection matrices are fused. 350 | 351 | 352 | 353 | This API is 🧪 experimental. 354 | 355 | 356 | """ 357 | self.original_attn_processors = None 358 | 359 | for _, attn_processor in self.attn_processors.items(): 360 | if "Added" in str(attn_processor.__class__.__name__): 361 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 362 | 363 | self.original_attn_processors = self.attn_processors 364 | 365 | for module in self.modules(): 366 | if isinstance(module, Attention): 367 | module.fuse_projections(fuse=True) 368 | 369 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 370 | 371 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 372 | def unfuse_qkv_projections(self): 373 | """Disables the fused QKV projection if enabled. 374 | 375 | 376 | 377 | This API is 🧪 experimental. 
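Illustrative round trip (a sketch, assuming the `FluxTransformer2DModel` defined in this file): `fuse_qkv_projections` merges the separate q/k/v linears and installs `FusedFluxAttnProcessor2_0`; calling `unfuse_qkv_projections` afterwards restores the processors saved in `original_attn_processors`.

```py
from transformer_flux import FluxTransformer2DModel  # the class defined in this file

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)  # small config, illustration only

model.fuse_qkv_projections()    # fuses projections and switches to FusedFluxAttnProcessor2_0
# ... run inference with the fused projections ...
model.unfuse_qkv_projections()  # restores the previously saved attention processors
```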
378 | 379 | 380 | 381 | """ 382 | if self.original_attn_processors is not None: 383 | self.set_attn_processor(self.original_attn_processors) 384 | 385 | def _set_gradient_checkpointing(self, module, value=False): 386 | if hasattr(module, "gradient_checkpointing"): 387 | module.gradient_checkpointing = value 388 | 389 | def forward( 390 | self, 391 | hidden_states: torch.Tensor, 392 | encoder_hidden_states: torch.Tensor = None, 393 | image_emb: torch.FloatTensor = None, 394 | pooled_projections: torch.Tensor = None, 395 | timestep: torch.LongTensor = None, 396 | img_ids: torch.Tensor = None, 397 | txt_ids: torch.Tensor = None, 398 | guidance: torch.Tensor = None, 399 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 400 | controlnet_block_samples=None, 401 | controlnet_single_block_samples=None, 402 | return_dict: bool = True, 403 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 404 | """ 405 | The [`FluxTransformer2DModel`] forward method. 406 | 407 | Args: 408 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 409 | Input `hidden_states`. 410 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 411 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 412 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 413 | from the embeddings of input conditions. 414 | timestep ( `torch.LongTensor`): 415 | Used to indicate denoising step. 416 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 417 | A list of tensors that if specified are added to the residuals of transformer blocks. 418 | joint_attention_kwargs (`dict`, *optional*): 419 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 420 | `self.processor` in 421 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 422 | return_dict (`bool`, *optional*, defaults to `True`): 423 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 424 | tuple. 425 | 426 | Returns: 427 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 428 | `tuple` where the first element is the sample tensor. 429 | """ 430 | if joint_attention_kwargs is not None: 431 | joint_attention_kwargs = joint_attention_kwargs.copy() 432 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 433 | else: 434 | lora_scale = 1.0 435 | 436 | if USE_PEFT_BACKEND: 437 | # weight the lora layers by setting `lora_scale` for each PEFT layer 438 | scale_lora_layers(self, lora_scale) 439 | else: 440 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 441 | logger.warning( 442 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
443 | ) 444 | hidden_states = self.x_embedder(hidden_states) 445 | 446 | timestep = timestep.to(hidden_states.dtype) * 1000 447 | if guidance is not None: 448 | guidance = guidance.to(hidden_states.dtype) * 1000 449 | else: 450 | guidance = None 451 | temb = ( 452 | self.time_text_embed(timestep, pooled_projections) 453 | if guidance is None 454 | else self.time_text_embed(timestep, guidance, pooled_projections) 455 | ) 456 | # torch.Size([1, 512*num_prompt, 4096]) -> torch.Size([1, 512*num_prompt, 3072]) 457 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 458 | 459 | if txt_ids.ndim == 3: 460 | logger.warning( 461 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 462 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 463 | ) 464 | txt_ids = txt_ids[0] 465 | if img_ids.ndim == 3: 466 | logger.warning( 467 | "Passing `img_ids` 3d torch.Tensor is deprecated." 468 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 469 | ) 470 | img_ids = img_ids[0] 471 | 472 | ids = torch.cat((txt_ids, img_ids), dim=0) 473 | image_rotary_emb = self.pos_embed(ids) 474 | 475 | for index_block, block in enumerate(self.transformer_blocks): 476 | if self.training and self.gradient_checkpointing: 477 | 478 | def create_custom_forward(module, return_dict=None): 479 | def custom_forward(*inputs): 480 | if return_dict is not None: 481 | return module(*inputs, return_dict=return_dict) 482 | else: 483 | return module(*inputs) 484 | 485 | return custom_forward 486 | 487 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 488 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 489 | create_custom_forward(block), 490 | hidden_states, 491 | encoder_hidden_states, 492 | temb, 493 | image_emb, 494 | image_rotary_emb, 495 | **ckpt_kwargs, 496 | ) 497 | 498 | else: 499 | encoder_hidden_states, hidden_states = block( 500 | hidden_states=hidden_states, 501 | encoder_hidden_states=encoder_hidden_states, 502 | temb=temb, 503 | image_emb=image_emb, 504 | image_rotary_emb=image_rotary_emb, 505 | ) 506 | 507 | # controlnet residual 508 | if controlnet_block_samples is not None: 509 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 510 | interval_control = int(np.ceil(interval_control)) 511 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 512 | 513 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 514 | 515 | for index_block, block in enumerate(self.single_transformer_blocks): 516 | if self.training and self.gradient_checkpointing: 517 | 518 | def create_custom_forward(module, return_dict=None): 519 | def custom_forward(*inputs): 520 | if return_dict is not None: 521 | return module(*inputs, return_dict=return_dict) 522 | else: 523 | return module(*inputs) 524 | 525 | return custom_forward 526 | 527 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 528 | hidden_states = torch.utils.checkpoint.checkpoint( 529 | create_custom_forward(block), 530 | hidden_states, 531 | temb, 532 | image_emb, 533 | image_rotary_emb, 534 | **ckpt_kwargs, 535 | ) 536 | 537 | else: 538 | hidden_states = block( 539 | hidden_states=hidden_states, 540 | temb=temb, 541 | image_emb=image_emb, 542 | image_rotary_emb=image_rotary_emb, 543 | ) 544 | 545 | # controlnet residual 546 | if controlnet_single_block_samples is not None: 547 | interval_control = 
len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 548 | interval_control = int(np.ceil(interval_control)) 549 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 550 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 551 | + controlnet_single_block_samples[index_block // interval_control] 552 | ) 553 | 554 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 555 | 556 | hidden_states = self.norm_out(hidden_states, temb) 557 | output = self.proj_out(hidden_states) 558 | 559 | if USE_PEFT_BACKEND: 560 | # remove `lora_scale` from each PEFT layer 561 | unscale_lora_layers(self, lora_scale) 562 | 563 | if not return_dict: 564 | return (output,) 565 | 566 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /src/pipeline_pe_clone.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 7 | 8 | from diffusers.image_processor import (VaeImageProcessor) 9 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin 10 | from diffusers.models.autoencoders import AutoencoderKL 11 | from diffusers.models.transformers import FluxTransformer2DModel 12 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 13 | from diffusers.utils import ( 14 | USE_PEFT_BACKEND, 15 | is_torch_xla_available, 16 | logging, 17 | scale_lora_layers, 18 | unscale_lora_layers, 19 | ) 20 | from diffusers.utils.torch_utils import randn_tensor 21 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 22 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 23 | 24 | if is_torch_xla_available(): 25 | import torch_xla.core.xla_model as xm 26 | 27 | XLA_AVAILABLE = True 28 | else: 29 | XLA_AVAILABLE = False 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | def calculate_shift( 34 | image_seq_len, 35 | base_seq_len: int = 256, 36 | max_seq_len: int = 4096, 37 | base_shift: float = 0.5, 38 | max_shift: float = 1.16, 39 | ): 40 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 41 | b = base_shift - m * base_seq_len 42 | mu = image_seq_len * m + b 43 | return mu 44 | 45 | def prepare_latent_image_ids_2(height, width, device, dtype): 46 | latent_image_ids = torch.zeros(height//2, width//2, 3, device=device, dtype=dtype) 47 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height//2, device=device)[:, None] # y坐标 48 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width//2, device=device)[None, :] # x坐标 49 | return latent_image_ids 50 | 51 | def position_encoding_clone(batch_size, original_height, original_width, device, dtype): 52 | latent_image_ids = prepare_latent_image_ids_2(original_height, original_width, device, dtype) 53 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 54 | latent_image_ids = latent_image_ids.reshape( 55 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 56 | ) 57 | cond_latent_image_ids = latent_image_ids 58 | latent_image_ids = torch.concat([latent_image_ids, cond_latent_image_ids], dim=-2) 59 | return latent_image_ids 60 | 61 | # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 62 | def retrieve_latents( 63 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" 64 | ): 65 | if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": 66 | return encoder_output.latent_dist.sample(generator) 67 | elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": 68 | return encoder_output.latent_dist.mode() 69 | elif hasattr(encoder_output, "latents"): 70 | return encoder_output.latents 71 | else: 72 | raise AttributeError("Could not access latents of provided encoder_output") 73 | 74 | 75 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 76 | def retrieve_timesteps( 77 | scheduler, 78 | num_inference_steps: Optional[int] = None, 79 | device: Optional[Union[str, torch.device]] = None, 80 | timesteps: Optional[List[int]] = None, 81 | sigmas: Optional[List[float]] = None, 82 | **kwargs, 83 | ): 84 | """ 85 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 86 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 87 | 88 | Args: 89 | scheduler (`SchedulerMixin`): 90 | The scheduler to get timesteps from. 91 | num_inference_steps (`int`): 92 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 93 | must be `None`. 94 | device (`str` or `torch.device`, *optional*): 95 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 96 | timesteps (`List[int]`, *optional*): 97 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 98 | `num_inference_steps` and `sigmas` must be `None`. 99 | sigmas (`List[float]`, *optional*): 100 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 101 | `num_inference_steps` and `timesteps` must be `None`. 102 | 103 | Returns: 104 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 105 | second element is the number of inference steps. 106 | """ 107 | if timesteps is not None and sigmas is not None: 108 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 109 | if timesteps is not None: 110 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 111 | if not accepts_timesteps: 112 | raise ValueError( 113 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 114 | f" timestep schedules. Please check whether you are using the correct scheduler." 115 | ) 116 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 117 | timesteps = scheduler.timesteps 118 | num_inference_steps = len(timesteps) 119 | elif sigmas is not None: 120 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 121 | if not accept_sigmas: 122 | raise ValueError( 123 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 124 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
125 | ) 126 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 127 | timesteps = scheduler.timesteps 128 | num_inference_steps = len(timesteps) 129 | else: 130 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 131 | timesteps = scheduler.timesteps 132 | return timesteps, num_inference_steps 133 | 134 | 135 | class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): 136 | r""" 137 | The Flux pipeline for text-to-image generation. 138 | 139 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 140 | 141 | Args: 142 | transformer ([`FluxTransformer2DModel`]): 143 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 144 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 145 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 146 | vae ([`AutoencoderKL`]): 147 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 148 | text_encoder ([`CLIPTextModel`]): 149 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 150 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 151 | text_encoder_2 ([`T5EncoderModel`]): 152 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 153 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 154 | tokenizer (`CLIPTokenizer`): 155 | Tokenizer of class 156 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 157 | tokenizer_2 (`T5TokenizerFast`): 158 | Second Tokenizer of class 159 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
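A minimal end-to-end sketch (not taken from the repo's own scripts): it assumes the public `black-forest-labs/FLUX.1-dev` weights, that the modified transformer from `transformer_flux.py` is swapped in so the extra `image_emb`/condition arguments are accepted, and placeholder prompt and image paths. The repo's `infer_single.py` is the authoritative entry point and may wire things differently.

```py
import torch
from PIL import Image
from transformer_flux import FluxTransformer2DModel   # repo variant that accepts image_emb
from src.pipeline_pe_clone import FluxPipeline        # this file's pipeline class

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

condition = Image.open("path/to/condition.jpg").convert("RGB")  # placeholder path
image = pipe(
    prompt="a placeholder editing instruction",
    condition_image=condition,   # VAE-encoded and concatenated to the noise latents
    height=512,
    width=512,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("output.png")
```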
160 | """ 161 | 162 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 163 | _optional_components = [] 164 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 165 | 166 | def __init__( 167 | self, 168 | scheduler: FlowMatchEulerDiscreteScheduler, 169 | vae: AutoencoderKL, 170 | text_encoder: CLIPTextModel, 171 | tokenizer: CLIPTokenizer, 172 | text_encoder_2: T5EncoderModel, 173 | tokenizer_2: T5TokenizerFast, 174 | transformer: FluxTransformer2DModel, 175 | ): 176 | super().__init__() 177 | 178 | self.register_modules( 179 | vae=vae, 180 | text_encoder=text_encoder, 181 | text_encoder_2=text_encoder_2, 182 | tokenizer=tokenizer, 183 | tokenizer_2=tokenizer_2, 184 | transformer=transformer, 185 | scheduler=scheduler, 186 | ) 187 | self.vae_scale_factor = ( 188 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 189 | ) 190 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 191 | self.tokenizer_max_length = ( 192 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 193 | ) 194 | self.default_sample_size = 64 195 | 196 | def _get_t5_prompt_embeds( 197 | self, 198 | prompt: Union[str, List[str]] = None, 199 | num_images_per_prompt: int = 1, 200 | max_sequence_length: int = 512, 201 | device: Optional[torch.device] = None, 202 | dtype: Optional[torch.dtype] = None, 203 | ): 204 | device = device or self._execution_device 205 | dtype = dtype or self.text_encoder.dtype 206 | 207 | prompt = [prompt] if isinstance(prompt, str) else prompt 208 | batch_size = len(prompt) 209 | 210 | text_inputs = self.tokenizer_2( 211 | prompt, 212 | padding="max_length", 213 | max_length=max_sequence_length, 214 | truncation=True, 215 | return_length=False, 216 | return_overflowing_tokens=False, 217 | return_tensors="pt", 218 | ) 219 | text_input_ids = text_inputs.input_ids 220 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 221 | 222 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 223 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1]) 224 | logger.warning( 225 | "The following part of your input was truncated because `max_sequence_length` is set to " 226 | f" {max_sequence_length} tokens: {removed_text}" 227 | ) 228 | 229 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] 230 | 231 | dtype = self.text_encoder_2.dtype 232 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 233 | 234 | _, seq_len, _ = prompt_embeds.shape 235 | 236 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 237 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 238 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 239 | 240 | return prompt_embeds 241 | 242 | def _get_clip_prompt_embeds( 243 | self, 244 | prompt: Union[str, List[str]], 245 | num_images_per_prompt: int = 1, 246 | device: Optional[torch.device] = None, 247 | ): 248 | device = device or self._execution_device 249 | 250 | prompt = [prompt] if isinstance(prompt, str) else prompt 251 | batch_size = len(prompt) 252 | 253 | text_inputs = self.tokenizer( 254 | prompt, 255 | padding="max_length", 256 | max_length=self.tokenizer_max_length, 257 | truncation=True, 258 | return_overflowing_tokens=False, 
259 | return_length=False, 260 | return_tensors="pt", 261 | ) 262 | 263 | text_input_ids = text_inputs.input_ids 264 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 265 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 266 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1]) 267 | logger.warning( 268 | "The following part of your input was truncated because CLIP can only handle sequences up to" 269 | f" {self.tokenizer_max_length} tokens: {removed_text}" 270 | ) 271 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 272 | 273 | # Use pooled output of CLIPTextModel 274 | prompt_embeds = prompt_embeds.pooler_output 275 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 276 | 277 | # duplicate text embeddings for each generation per prompt, using mps friendly method 278 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 279 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 280 | 281 | return prompt_embeds 282 | 283 | def encode_prompt( 284 | self, 285 | prompt: Union[str, List[str]], 286 | prompt_2: Union[str, List[str]], 287 | device: Optional[torch.device] = None, 288 | num_images_per_prompt: int = 1, 289 | prompt_embeds: Optional[torch.FloatTensor] = None, 290 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 291 | max_sequence_length: int = 512, 292 | lora_scale: Optional[float] = None, 293 | ): 294 | r""" 295 | 296 | Args: 297 | prompt (`str` or `List[str]`, *optional*): 298 | prompt to be encoded 299 | prompt_2 (`str` or `List[str]`, *optional*): 300 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 301 | used in all text-encoders 302 | device: (`torch.device`): 303 | torch device 304 | num_images_per_prompt (`int`): 305 | number of images that should be generated per prompt 306 | prompt_embeds (`torch.FloatTensor`, *optional*): 307 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 308 | provided, text embeddings will be generated from `prompt` input argument. 309 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 310 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 311 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 312 | lora_scale (`float`, *optional*): 313 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
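For reference, an illustrative call and the tensor shapes it produces with the default Flux text encoders and `max_sequence_length=512` (assuming `pipe` is an already-loaded instance of this pipeline):

```py
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of a cat",
    prompt_2=None,                 # falls back to `prompt` for the T5 encoder
    num_images_per_prompt=1,
)
print(prompt_embeds.shape)         # torch.Size([1, 512, 4096])  - per-token T5 embeddings
print(pooled_prompt_embeds.shape)  # torch.Size([1, 768])        - pooled CLIP embedding
print(text_ids.shape)              # torch.Size([512, 3])        - all-zero position ids for the text tokens
```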
314 | """ 315 | device = device or self._execution_device 316 | 317 | # set lora scale so that monkey patched LoRA 318 | # function of text encoder can correctly access it 319 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 320 | self._lora_scale = lora_scale 321 | 322 | # dynamically adjust the LoRA scale 323 | if self.text_encoder is not None and USE_PEFT_BACKEND: 324 | scale_lora_layers(self.text_encoder, lora_scale) 325 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 326 | scale_lora_layers(self.text_encoder_2, lora_scale) 327 | 328 | prompt = [prompt] if isinstance(prompt, str) else prompt 329 | 330 | if prompt_embeds is None: 331 | prompt_2 = prompt_2 or prompt 332 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 333 | 334 | # We only use the pooled prompt output from the CLIPTextModel 335 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 336 | prompt=prompt, 337 | device=device, 338 | num_images_per_prompt=num_images_per_prompt, 339 | ) 340 | prompt_embeds = self._get_t5_prompt_embeds( 341 | prompt=prompt_2, 342 | num_images_per_prompt=num_images_per_prompt, 343 | max_sequence_length=max_sequence_length, 344 | device=device, 345 | ) 346 | 347 | if self.text_encoder is not None: 348 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 349 | # Retrieve the original scale by scaling back the LoRA layers 350 | unscale_lora_layers(self.text_encoder, lora_scale) 351 | 352 | if self.text_encoder_2 is not None: 353 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 354 | # Retrieve the original scale by scaling back the LoRA layers 355 | unscale_lora_layers(self.text_encoder_2, lora_scale) 356 | 357 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 358 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 359 | 360 | return prompt_embeds, pooled_prompt_embeds, text_ids 361 | 362 | # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image 363 | def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): 364 | if isinstance(generator, list): 365 | image_latents = [ 366 | retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i]) 367 | for i in range(image.shape[0]) 368 | ] 369 | image_latents = torch.cat(image_latents, dim=0) 370 | else: 371 | image_latents = retrieve_latents(self.vae.encode(image), generator=generator) 372 | 373 | image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor 374 | 375 | return image_latents 376 | 377 | def check_inputs( 378 | self, 379 | prompt, 380 | prompt_2, 381 | height, 382 | width, 383 | prompt_embeds=None, 384 | pooled_prompt_embeds=None, 385 | callback_on_step_end_tensor_inputs=None, 386 | max_sequence_length=None, 387 | ): 388 | if height % 8 != 0 or width % 8 != 0: 389 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 390 | 391 | if callback_on_step_end_tensor_inputs is not None and not all( 392 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 393 | ): 394 | raise ValueError( 395 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 396 | ) 397 | 398 | if prompt is not None and prompt_embeds is not None: 399 | raise ValueError( 400 | 
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 401 | " only forward one of the two." 402 | ) 403 | elif prompt_2 is not None and prompt_embeds is not None: 404 | raise ValueError( 405 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 406 | " only forward one of the two." 407 | ) 408 | elif prompt is None and prompt_embeds is None: 409 | raise ValueError( 410 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 411 | ) 412 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 413 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 414 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 415 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 416 | 417 | if prompt_embeds is not None and pooled_prompt_embeds is None: 418 | raise ValueError( 419 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 420 | ) 421 | 422 | if max_sequence_length is not None and max_sequence_length > 512: 423 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 424 | 425 | @staticmethod 426 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 427 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 428 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 429 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 430 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 431 | latent_image_ids = latent_image_ids.reshape( 432 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 433 | ) 434 | return latent_image_ids.to(device=device, dtype=dtype) 435 | 436 | @staticmethod 437 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 438 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 439 | latents = latents.permute(0, 2, 4, 1, 3, 5) 440 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 441 | 442 | return latents 443 | 444 | @staticmethod 445 | def _unpack_latents(latents, height, width, vae_scale_factor): 446 | batch_size, num_patches, channels = latents.shape 447 | 448 | height = height // vae_scale_factor 449 | width = width // vae_scale_factor 450 | 451 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 452 | latents = latents.permute(0, 3, 1, 4, 2, 5) 453 | 454 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 455 | 456 | return latents 457 | 458 | def enable_vae_slicing(self): 459 | r""" 460 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 461 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 462 | """ 463 | self.vae.enable_slicing() 464 | 465 | def disable_vae_slicing(self): 466 | r""" 467 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 468 | computing decoding in one step. 
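Returning to the `_pack_latents`/`_unpack_latents` helpers defined just above, a small round-trip sketch makes the 2x2 patchification concrete (the 1024x1024 size, 16 latent channels, and `vae_scale_factor=16` are assumptions matching the defaults used elsewhere in this file):

```py
import torch
from src.pipeline_pe_clone import FluxPipeline  # path assumes the repo root is on PYTHONPATH

vae_scale_factor = 16                     # 2 ** len(vae.config.block_out_channels), as in __init__
height, width = 1024, 1024                # output size in pixels
lat_h = 2 * (height // vae_scale_factor)  # 128
lat_w = 2 * (width // vae_scale_factor)   # 128

latents = torch.randn(1, 16, lat_h, lat_w)
packed = FluxPipeline._pack_latents(latents, 1, 16, lat_h, lat_w)
print(packed.shape)    # torch.Size([1, 4096, 64]) -> (lat_h/2)*(lat_w/2) tokens of 16*4 channels

unpacked = FluxPipeline._unpack_latents(packed, height, width, vae_scale_factor)
print(unpacked.shape)  # torch.Size([1, 16, 128, 128])
assert torch.equal(unpacked, latents)  # exact round trip
```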
469 | """ 470 | self.vae.disable_slicing() 471 | 472 | def enable_vae_tiling(self): 473 | r""" 474 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 475 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 476 | processing larger images. 477 | """ 478 | self.vae.enable_tiling() 479 | 480 | def disable_vae_tiling(self): 481 | r""" 482 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 483 | computing decoding in one step. 484 | """ 485 | self.vae.disable_tiling() 486 | 487 | def prepare_latents( 488 | self, 489 | batch_size, 490 | num_channels_latents, 491 | height, 492 | width, 493 | dtype, 494 | device, 495 | generator, 496 | latents=None, 497 | condition_image=None, 498 | ): 499 | height = 2 * (int(height) // self.vae_scale_factor) 500 | width = 2 * (int(width) // self.vae_scale_factor) 501 | 502 | shape = (batch_size, num_channels_latents, height, width) # 1 16 106 80 503 | 504 | if latents is not None: 505 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 506 | return latents.to(device=device, dtype=dtype), latent_image_ids 507 | 508 | if isinstance(generator, list) and len(generator) != batch_size: 509 | raise ValueError( 510 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 511 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 512 | ) 513 | if condition_image is not None: 514 | condition_image = condition_image.to(device=device, dtype=dtype) 515 | image_latents = self._encode_vae_image(image=condition_image, generator=generator) 516 | if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: 517 | # expand init_latents for batch_size 518 | additional_image_per_prompt = batch_size // image_latents.shape[0] 519 | image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) 520 | elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: 521 | raise ValueError( 522 | f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
523 | ) 524 | else: 525 | image_latents = torch.cat([image_latents], dim=0) 526 | 527 | # import pdb; pdb.set_trace() 528 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 529 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) 530 | cond_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) 531 | latents = torch.concat([latents, cond_latents], dim=-2) 532 | 533 | latent_image_ids = position_encoding_clone(batch_size, height, width, device, dtype) # add position 534 | 535 | mask1 = torch.ones(shape, device=device, dtype=dtype) 536 | mask2 = torch.zeros(shape, device=device, dtype=dtype) 537 | mask1 = self._pack_latents(mask1, batch_size, num_channels_latents, height, width) # 1 4096 64 538 | mask2 = self._pack_latents(mask2, batch_size, num_channels_latents, height, width) # 1 4096 64 539 | mask = torch.concat([mask1, mask2], dim=-2) 540 | return latents, latent_image_ids, mask, cond_latents 541 | 542 | @property 543 | def guidance_scale(self): 544 | return self._guidance_scale 545 | 546 | @property 547 | def joint_attention_kwargs(self): 548 | return self._joint_attention_kwargs 549 | 550 | @property 551 | def num_timesteps(self): 552 | return self._num_timesteps 553 | 554 | @property 555 | def interrupt(self): 556 | return self._interrupt 557 | 558 | @torch.no_grad() 559 | def __call__( 560 | self, 561 | prompt: Union[str, List[str]] = None, 562 | prompt_2: Optional[Union[str, List[str]]] = None, 563 | height: Optional[int] = None, 564 | width: Optional[int] = None, 565 | num_inference_steps: int = 28, 566 | timesteps: List[int] = None, 567 | guidance_scale: float = 3.5, 568 | num_images_per_prompt: Optional[int] = 1, 569 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 570 | latents: Optional[torch.FloatTensor] = None, 571 | prompt_embeds: Optional[torch.FloatTensor] = None, 572 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 573 | image_emb: Optional[torch.FloatTensor] = None, 574 | output_type: Optional[str] = "pil", 575 | return_dict: bool = True, 576 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 577 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 578 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 579 | max_sequence_length: int = 512, 580 | condition_image=None, 581 | ): 582 | height = height or self.default_sample_size * self.vae_scale_factor 583 | width = width or self.default_sample_size * self.vae_scale_factor 584 | 585 | # 1. Check inputs. Raise error if not correct 586 | self.check_inputs( 587 | prompt, 588 | prompt_2, 589 | height, 590 | width, 591 | prompt_embeds=prompt_embeds, 592 | pooled_prompt_embeds=pooled_prompt_embeds, 593 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 594 | max_sequence_length=max_sequence_length, 595 | ) 596 | 597 | self._guidance_scale = guidance_scale 598 | self._joint_attention_kwargs = joint_attention_kwargs 599 | self._interrupt = False 600 | 601 | condition_image = self.image_processor.preprocess(condition_image, height=height, width=width) 602 | condition_image = condition_image.to(dtype=torch.float32) 603 | 604 | # 2. 
Define call parameters 605 | if prompt is not None and isinstance(prompt, str): 606 | batch_size = 1 607 | elif prompt is not None and isinstance(prompt, list): 608 | batch_size = len(prompt) 609 | else: 610 | batch_size = prompt_embeds.shape[0] 611 | 612 | device = self._execution_device 613 | 614 | lora_scale = ( 615 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 616 | ) 617 | ( 618 | prompt_embeds, 619 | pooled_prompt_embeds, 620 | text_ids, 621 | ) = self.encode_prompt( 622 | prompt=prompt, 623 | prompt_2=prompt_2, 624 | prompt_embeds=prompt_embeds, 625 | pooled_prompt_embeds=pooled_prompt_embeds, 626 | device=device, 627 | num_images_per_prompt=num_images_per_prompt, 628 | max_sequence_length=max_sequence_length, 629 | lora_scale=lora_scale, 630 | ) 631 | 632 | # 4. Prepare latent variables 633 | num_channels_latents = self.transformer.config.in_channels // 4 # 16 634 | latents, latent_image_ids, mask, cond_latents = self.prepare_latents( 635 | batch_size * num_images_per_prompt, 636 | num_channels_latents, 637 | height, 638 | width, 639 | prompt_embeds.dtype, 640 | device, 641 | generator, 642 | latents, 643 | condition_image 644 | ) 645 | clean_latents = latents.clone() 646 | 647 | # 5. Prepare timesteps 648 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 649 | image_seq_len = latents.shape[1] 650 | mu = calculate_shift( 651 | image_seq_len, 652 | self.scheduler.config.base_image_seq_len, 653 | self.scheduler.config.max_image_seq_len, 654 | self.scheduler.config.base_shift, 655 | self.scheduler.config.max_shift, 656 | ) 657 | timesteps, num_inference_steps = retrieve_timesteps( 658 | self.scheduler, 659 | num_inference_steps, 660 | device, 661 | timesteps, 662 | sigmas, 663 | mu=mu, 664 | ) 665 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 666 | self._num_timesteps = len(timesteps) 667 | 668 | # handle guidance 669 | if self.transformer.config.guidance_embeds: 670 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 671 | guidance = guidance.expand(latents.shape[0]) 672 | else: 673 | guidance = None 674 | 675 | # 6. Denoising loop 676 | with self.progress_bar(total=num_inference_steps) as progress_bar: 677 | for i, t in enumerate(timesteps): 678 | if self.interrupt: 679 | continue 680 | 681 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 682 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 683 | 684 | noise_pred = self.transformer( 685 | hidden_states=latents, # 1 4096 64 686 | timestep=timestep / 1000, 687 | guidance=guidance, 688 | pooled_projections=pooled_prompt_embeds, 689 | encoder_hidden_states=prompt_embeds, 690 | image_emb=image_emb, 691 | txt_ids=text_ids, 692 | img_ids=latent_image_ids, 693 | joint_attention_kwargs=self.joint_attention_kwargs, 694 | return_dict=False, 695 | )[0] 696 | 697 | # compute the previous noisy sample x_t -> x_t-1 698 | latents_dtype = latents.dtype 699 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 700 | latents = latents * mask + clean_latents * (1 - mask) 701 | 702 | if latents.dtype != latents_dtype: 703 | if torch.backends.mps.is_available(): 704 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 705 | latents = latents.to(latents_dtype) 706 | 707 | if callback_on_step_end is not None: 708 | callback_kwargs = {} 709 | for k in callback_on_step_end_tensor_inputs: 710 | callback_kwargs[k] = locals()[k] 711 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 712 | 713 | latents = callback_outputs.pop("latents", latents) 714 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 715 | 716 | # call the callback, if provided 717 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 718 | progress_bar.update() 719 | 720 | if XLA_AVAILABLE: 721 | xm.mark_step() 722 | 723 | if output_type == "latent": 724 | image = latents 725 | 726 | else: 727 | latents = self._unpack_latents(latents[:,:latents.shape[-2]-cond_latents.shape[-2],:], height, width, self.vae_scale_factor) 728 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 729 | image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] 730 | image = self.image_processor.postprocess(image, output_type=output_type) 731 | 732 | # Offload all models 733 | self.maybe_free_model_hooks() 734 | 735 | if not return_dict: 736 | return (image,) 737 | 738 | return FluxPipelineOutput(images=image) -------------------------------------------------------------------------------- /pipeline_flux_ipa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
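Before the code proper, a small worked example of the dynamic-shift rule implemented by the `calculate_shift` helper further down in this file (the numbers are its default base/max settings; the token counts assume the standard Flux latent packing):

```py
base_seq_len, max_seq_len = 256, 4096
base_shift, max_shift = 0.5, 1.16

m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # ~= 0.000171875
b = base_shift - m * base_seq_len                            # ~= 0.456

print(m * 4096 + b)  # 1.16   -> a 1024x1024 image (4096 packed latent tokens) gets the maximum shift
print(m * 1024 + b)  # ~0.632 -> a 512x512 image (1024 packed latent tokens)
```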
14 | 15 | 16 | import inspect 17 | from typing import Any, Callable, Dict, List, Optional, Union 18 | 19 | import numpy as np 20 | import torch 21 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 22 | 23 | from diffusers.image_processor import VaeImageProcessor 24 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin 25 | from diffusers.models.autoencoders import AutoencoderKL 26 | from diffusers.models.transformers import FluxTransformer2DModel 27 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 28 | from diffusers.utils import ( 29 | USE_PEFT_BACKEND, 30 | is_torch_xla_available, 31 | logging, 32 | replace_example_docstring, 33 | scale_lora_layers, 34 | unscale_lora_layers, 35 | ) 36 | from diffusers.utils.torch_utils import randn_tensor 37 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 38 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 39 | 40 | 41 | if is_torch_xla_available(): 42 | import torch_xla.core.xla_model as xm 43 | 44 | XLA_AVAILABLE = True 45 | else: 46 | XLA_AVAILABLE = False 47 | 48 | from PIL import Image 49 | import numpy as np 50 | import torch 51 | import torch.nn.functional as F 52 | 53 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 54 | 55 | EXAMPLE_DOC_STRING = """ 56 | Examples: 57 | ```py 58 | >>> import torch 59 | >>> from diffusers import FluxPipeline 60 | 61 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 62 | >>> pipe.to("cuda") 63 | >>> prompt = "A cat holding a sign that says hello world" 64 | >>> # Depending on the variant being used, the pipeline call will slightly vary. 65 | >>> # Refer to the pipeline documentation for more details. 66 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] 67 | >>> image.save("flux.png") 68 | ``` 69 | """ 70 | 71 | 72 | def calculate_shift( 73 | image_seq_len, 74 | base_seq_len: int = 256, 75 | max_seq_len: int = 4096, 76 | base_shift: float = 0.5, 77 | max_shift: float = 1.16, 78 | ): 79 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 80 | b = base_shift - m * base_seq_len 81 | mu = image_seq_len * m + b 82 | return mu 83 | 84 | 85 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 86 | def retrieve_timesteps( 87 | scheduler, 88 | num_inference_steps: Optional[int] = None, 89 | device: Optional[Union[str, torch.device]] = None, 90 | timesteps: Optional[List[int]] = None, 91 | sigmas: Optional[List[float]] = None, 92 | **kwargs, 93 | ): 94 | """ 95 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 96 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 97 | 98 | Args: 99 | scheduler (`SchedulerMixin`): 100 | The scheduler to get timesteps from. 101 | num_inference_steps (`int`): 102 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 103 | must be `None`. 104 | device (`str` or `torch.device`, *optional*): 105 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 106 | timesteps (`List[int]`, *optional*): 107 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 108 | `num_inference_steps` and `sigmas` must be `None`. 
109 | sigmas (`List[float]`, *optional*): 110 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 111 | `num_inference_steps` and `timesteps` must be `None`. 112 | 113 | Returns: 114 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 115 | second element is the number of inference steps. 116 | """ 117 | if timesteps is not None and sigmas is not None: 118 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 119 | if timesteps is not None: 120 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 121 | if not accepts_timesteps: 122 | raise ValueError( 123 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 124 | f" timestep schedules. Please check whether you are using the correct scheduler." 125 | ) 126 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 127 | timesteps = scheduler.timesteps 128 | num_inference_steps = len(timesteps) 129 | elif sigmas is not None: 130 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 131 | if not accept_sigmas: 132 | raise ValueError( 133 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 134 | f" sigmas schedules. Please check whether you are using the correct scheduler." 135 | ) 136 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 137 | timesteps = scheduler.timesteps 138 | num_inference_steps = len(timesteps) 139 | else: 140 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 141 | timesteps = scheduler.timesteps 142 | return timesteps, num_inference_steps 143 | 144 | 145 | class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): 146 | r""" 147 | The Flux pipeline for text-to-image generation. 148 | 149 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 150 | 151 | Args: 152 | transformer ([`FluxTransformer2DModel`]): 153 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 154 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 155 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 156 | vae ([`AutoencoderKL`]): 157 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 158 | text_encoder ([`CLIPTextModel`]): 159 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 160 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 161 | text_encoder_2 ([`T5EncoderModel`]): 162 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 163 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 164 | tokenizer (`CLIPTokenizer`): 165 | Tokenizer of class 166 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 167 | tokenizer_2 (`T5TokenizerFast`): 168 | Second Tokenizer of class 169 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
170 | """ 171 | 172 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 173 | _optional_components = [] 174 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 175 | 176 | def __init__( 177 | self, 178 | scheduler: FlowMatchEulerDiscreteScheduler, 179 | vae: AutoencoderKL, 180 | text_encoder: CLIPTextModel, 181 | tokenizer: CLIPTokenizer, 182 | text_encoder_2: T5EncoderModel, 183 | tokenizer_2: T5TokenizerFast, 184 | transformer: FluxTransformer2DModel, 185 | ): 186 | super().__init__() 187 | 188 | self.register_modules( 189 | vae=vae, 190 | text_encoder=text_encoder, 191 | text_encoder_2=text_encoder_2, 192 | tokenizer=tokenizer, 193 | tokenizer_2=tokenizer_2, 194 | transformer=transformer, 195 | scheduler=scheduler, 196 | ) 197 | self.vae_scale_factor = ( 198 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 199 | ) 200 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 201 | self.tokenizer_max_length = ( 202 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 203 | ) 204 | self.default_sample_size = 64 205 | 206 | def _get_t5_prompt_embeds( 207 | self, 208 | prompt: Union[str, List[str]] = None, 209 | num_images_per_prompt: int = 1, 210 | max_sequence_length: int = 512, 211 | device: Optional[torch.device] = None, 212 | dtype: Optional[torch.dtype] = None, 213 | ): 214 | device = device or self._execution_device 215 | dtype = dtype or self.text_encoder.dtype 216 | 217 | prompt = [prompt] if isinstance(prompt, str) else prompt 218 | batch_size = len(prompt) 219 | 220 | text_inputs = self.tokenizer_2( 221 | prompt, 222 | padding="max_length", 223 | max_length=max_sequence_length, 224 | truncation=True, 225 | return_length=False, 226 | return_overflowing_tokens=False, 227 | return_tensors="pt", 228 | ) 229 | text_input_ids = text_inputs.input_ids 230 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 231 | 232 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 233 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 234 | logger.warning( 235 | "The following part of your input was truncated because `max_sequence_length` is set to " 236 | f" {max_sequence_length} tokens: {removed_text}" 237 | ) 238 | 239 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] 240 | 241 | dtype = self.text_encoder_2.dtype 242 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 243 | 244 | _, seq_len, _ = prompt_embeds.shape 245 | 246 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 247 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 248 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 249 | 250 | return prompt_embeds 251 | 252 | def _get_clip_prompt_embeds( 253 | self, 254 | prompt: Union[str, List[str]], 255 | num_images_per_prompt: int = 1, 256 | device: Optional[torch.device] = None, 257 | ): 258 | device = device or self._execution_device 259 | 260 | prompt = [prompt] if isinstance(prompt, str) else prompt 261 | batch_size = len(prompt) 262 | 263 | text_inputs = self.tokenizer( 264 | prompt, 265 | padding="max_length", 266 | max_length=self.tokenizer_max_length, 267 | truncation=True, 268 | return_overflowing_tokens=False, 
269 | return_length=False, 270 | return_tensors="pt", 271 | ) 272 | 273 | text_input_ids = text_inputs.input_ids 274 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 275 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 276 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 277 | logger.warning( 278 | "The following part of your input was truncated because CLIP can only handle sequences up to" 279 | f" {self.tokenizer_max_length} tokens: {removed_text}" 280 | ) 281 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 282 | 283 | # Use pooled output of CLIPTextModel 284 | prompt_embeds = prompt_embeds.pooler_output 285 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 286 | 287 | # duplicate text embeddings for each generation per prompt, using mps friendly method 288 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 289 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 290 | 291 | return prompt_embeds 292 | 293 | def encode_prompt( 294 | self, 295 | prompt: Union[str, List[str]], 296 | prompt_2: Union[str, List[str]], 297 | device: Optional[torch.device] = None, 298 | num_images_per_prompt: int = 1, 299 | prompt_embeds: Optional[torch.FloatTensor] = None, 300 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 301 | max_sequence_length: int = 512, 302 | lora_scale: Optional[float] = None, 303 | ): 304 | r""" 305 | 306 | Args: 307 | prompt (`str` or `List[str]`, *optional*): 308 | prompt to be encoded 309 | prompt_2 (`str` or `List[str]`, *optional*): 310 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 311 | used in all text-encoders 312 | device: (`torch.device`): 313 | torch device 314 | num_images_per_prompt (`int`): 315 | number of images that should be generated per prompt 316 | prompt_embeds (`torch.FloatTensor`, *optional*): 317 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 318 | provided, text embeddings will be generated from `prompt` input argument. 319 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 320 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 321 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 322 | lora_scale (`float`, *optional*): 323 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
324 | """ 325 | device = device or self._execution_device 326 | 327 | # set lora scale so that monkey patched LoRA 328 | # function of text encoder can correctly access it 329 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 330 | self._lora_scale = lora_scale 331 | 332 | # dynamically adjust the LoRA scale 333 | if self.text_encoder is not None and USE_PEFT_BACKEND: 334 | scale_lora_layers(self.text_encoder, lora_scale) 335 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 336 | scale_lora_layers(self.text_encoder_2, lora_scale) 337 | 338 | prompt = [prompt] if isinstance(prompt, str) else prompt 339 | 340 | if prompt_embeds is None: 341 | prompt_2 = prompt_2 or prompt 342 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 343 | 344 | # We only use the pooled prompt output from the CLIPTextModel 345 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 346 | prompt=prompt, 347 | device=device, 348 | num_images_per_prompt=num_images_per_prompt, 349 | ) 350 | prompt_embeds = self._get_t5_prompt_embeds( 351 | prompt=prompt_2, 352 | num_images_per_prompt=num_images_per_prompt, 353 | max_sequence_length=max_sequence_length, 354 | device=device, 355 | ) 356 | 357 | if self.text_encoder is not None: 358 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 359 | # Retrieve the original scale by scaling back the LoRA layers 360 | unscale_lora_layers(self.text_encoder, lora_scale) 361 | 362 | if self.text_encoder_2 is not None: 363 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 364 | # Retrieve the original scale by scaling back the LoRA layers 365 | unscale_lora_layers(self.text_encoder_2, lora_scale) 366 | 367 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 368 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 369 | 370 | return prompt_embeds, pooled_prompt_embeds, text_ids 371 | 372 | def encode_regional_prompt( 373 | self, 374 | prompt: Union[str, List[str]], 375 | prompt_2: Union[str, List[str]], 376 | device: Optional[torch.device] = None, 377 | num_images_per_prompt: int = 1, 378 | prompt_embeds: Optional[torch.FloatTensor] = None, 379 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 380 | max_sequence_length: int = 512, 381 | lora_scale: Optional[float] = None, 382 | ): 383 | r""" 384 | 385 | Args: 386 | prompt (`str` or `List[str]`, *optional*): 387 | prompt to be encoded 388 | prompt_2 (`str` or `List[str]`, *optional*): 389 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 390 | used in all text-encoders 391 | device: (`torch.device`): 392 | torch device 393 | num_images_per_prompt (`int`): 394 | number of images that should be generated per prompt 395 | prompt_embeds (`torch.FloatTensor`, *optional*): 396 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 397 | provided, text embeddings will be generated from `prompt` input argument. 398 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 399 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 400 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 401 | lora_scale (`float`, *optional*): 402 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
403 | """ 404 | device = device or self._execution_device 405 | 406 | # set lora scale so that monkey patched LoRA 407 | # function of text encoder can correctly access it 408 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 409 | self._lora_scale = lora_scale 410 | 411 | # dynamically adjust the LoRA scale 412 | if self.text_encoder is not None and USE_PEFT_BACKEND: 413 | scale_lora_layers(self.text_encoder, lora_scale) 414 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 415 | scale_lora_layers(self.text_encoder_2, lora_scale) 416 | 417 | prompt = [prompt] if isinstance(prompt, str) else prompt 418 | 419 | if prompt_embeds is None: 420 | prompt_2 = prompt_2 or prompt 421 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 422 | 423 | # We only use the pooled prompt output from the CLIPTextModel 424 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 425 | prompt=prompt, 426 | device=device, 427 | num_images_per_prompt=num_images_per_prompt, 428 | ) 429 | prompt_embeds = self._get_t5_prompt_embeds( 430 | prompt=prompt_2, 431 | num_images_per_prompt=num_images_per_prompt, 432 | max_sequence_length=max_sequence_length, 433 | device=device, 434 | ) 435 | 436 | if self.text_encoder is not None: 437 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 438 | # Retrieve the original scale by scaling back the LoRA layers 439 | unscale_lora_layers(self.text_encoder, lora_scale) 440 | 441 | if self.text_encoder_2 is not None: 442 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 443 | # Retrieve the original scale by scaling back the LoRA layers 444 | unscale_lora_layers(self.text_encoder_2, lora_scale) 445 | 446 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 447 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 448 | 449 | # hard code here! 
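        # Editor's note on the hard-coded block below: only prompt[0] is used, it is split on ";"
        # into regional prompts, each region is encoded separately with the T5 encoder, and the
        # per-region embeddings are concatenated along the sequence dimension. For example, with
        # the default max_sequence_length of 512, "a red car; a snowy street" produces two
        # [1, 512, 4096] embeddings that are concatenated into a single [1, 1024, 4096] tensor,
        # while `text_ids` keeps the shape of a single 512-token prompt ([512, 3]).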
450 | regional_prompts = prompt[0].split(";") 451 | prompt_embeds_list = [] 452 | for regional_prompt in regional_prompts: 453 | prompt_embeds = self._get_t5_prompt_embeds( 454 | prompt=regional_prompt, 455 | num_images_per_prompt=num_images_per_prompt, 456 | max_sequence_length=max_sequence_length, 457 | device=device, 458 | ) 459 | prompt_embeds_list.append(prompt_embeds) 460 | prompt_embeds = torch.concat(prompt_embeds_list, dim=1) 461 | 462 | #print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape) 463 | # torch.Size([1, 512*num_prompt, 4096]) torch.Size([1, 768]) torch.Size([512, 3]) 464 | 465 | return prompt_embeds, pooled_prompt_embeds, text_ids 466 | 467 | def check_inputs( 468 | self, 469 | prompt, 470 | prompt_2, 471 | height, 472 | width, 473 | prompt_embeds=None, 474 | pooled_prompt_embeds=None, 475 | callback_on_step_end_tensor_inputs=None, 476 | max_sequence_length=None, 477 | ): 478 | if height % 8 != 0 or width % 8 != 0: 479 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 480 | 481 | if callback_on_step_end_tensor_inputs is not None and not all( 482 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 483 | ): 484 | raise ValueError( 485 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 486 | ) 487 | 488 | if prompt is not None and prompt_embeds is not None: 489 | raise ValueError( 490 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 491 | " only forward one of the two." 492 | ) 493 | elif prompt_2 is not None and prompt_embeds is not None: 494 | raise ValueError( 495 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 496 | " only forward one of the two." 497 | ) 498 | elif prompt is None and prompt_embeds is None: 499 | raise ValueError( 500 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 501 | ) 502 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 503 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 504 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 505 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 506 | 507 | if prompt_embeds is not None and pooled_prompt_embeds is None: 508 | raise ValueError( 509 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
510 | ) 511 | 512 | if max_sequence_length is not None and max_sequence_length > 512: 513 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 514 | 515 | @staticmethod 516 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 517 | # print(batch_size, height, width) 518 | # 1 96 160 519 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 520 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 521 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 522 | 523 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 524 | 525 | latent_image_ids = latent_image_ids.reshape( 526 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 527 | ) 528 | 529 | return latent_image_ids.to(device=device, dtype=dtype) 530 | 531 | @staticmethod 532 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 533 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 534 | latents = latents.permute(0, 2, 4, 1, 3, 5) 535 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 536 | 537 | return latents 538 | 539 | @staticmethod 540 | def _unpack_latents(latents, height, width, vae_scale_factor): 541 | batch_size, num_patches, channels = latents.shape 542 | 543 | height = height // vae_scale_factor 544 | width = width // vae_scale_factor 545 | 546 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 547 | latents = latents.permute(0, 3, 1, 4, 2, 5) 548 | 549 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 550 | 551 | return latents 552 | 553 | def enable_vae_slicing(self): 554 | r""" 555 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 556 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 557 | """ 558 | self.vae.enable_slicing() 559 | 560 | def disable_vae_slicing(self): 561 | r""" 562 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 563 | computing decoding in one step. 564 | """ 565 | self.vae.disable_slicing() 566 | 567 | def enable_vae_tiling(self): 568 | r""" 569 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 570 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 571 | processing larger images. 572 | """ 573 | self.vae.enable_tiling() 574 | 575 | def disable_vae_tiling(self): 576 | r""" 577 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 578 | computing decoding in one step. 
579 | """
580 | self.vae.disable_tiling()
581 | 
582 | def prepare_latents(
583 | self,
584 | batch_size,
585 | num_channels_latents,
586 | height,
587 | width,
588 | dtype,
589 | device,
590 | generator,
591 | latents=None,
592 | ):
593 | height = 2 * (int(height) // self.vae_scale_factor)
594 | width = 2 * (int(width) // self.vae_scale_factor)
595 | 
596 | shape = (batch_size, num_channels_latents, height, width)
597 | 
598 | if latents is not None:
599 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
600 | return latents.to(device=device, dtype=dtype), latent_image_ids
601 | 
602 | if isinstance(generator, list) and len(generator) != batch_size:
603 | raise ValueError(
604 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
605 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
606 | )
607 | 
608 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # torch.Size([1, 16, 96, 160])
609 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) # torch.Size([1, 3840, 64])
610 | 
611 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) # torch.Size([3840, 3])
612 | 
613 | return latents, latent_image_ids
614 | 
615 | @property
616 | def guidance_scale(self):
617 | return self._guidance_scale
618 | 
619 | @property
620 | def joint_attention_kwargs(self):
621 | return self._joint_attention_kwargs
622 | 
623 | @property
624 | def num_timesteps(self):
625 | return self._num_timesteps
626 | 
627 | @property
628 | def interrupt(self):
629 | return self._interrupt
630 | 
631 | @torch.no_grad()
632 | @replace_example_docstring(EXAMPLE_DOC_STRING)
633 | def __call__(
634 | self,
635 | prompt: Union[str, List[str]] = None,
636 | prompt_2: Optional[Union[str, List[str]]] = None,
637 | height: Optional[int] = None,
638 | width: Optional[int] = None,
639 | num_inference_steps: int = 28,
640 | timesteps: List[int] = None,
641 | guidance_scale: float = 3.5,
642 | num_images_per_prompt: Optional[int] = 1,
643 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
644 | latents: Optional[torch.FloatTensor] = None,
645 | prompt_embeds: Optional[torch.FloatTensor] = None,
646 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
647 | image_emb: Optional[torch.FloatTensor] = None,
648 | output_type: Optional[str] = "pil",
649 | return_dict: bool = True,
650 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
651 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
652 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
653 | max_sequence_length: int = 512,
654 | ):
655 | r"""
656 | Function invoked when calling the pipeline for generation.
657 | 
658 | Args:
659 | prompt (`str` or `List[str]`, *optional*):
660 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
661 | instead.
662 | prompt_2 (`str` or `List[str]`, *optional*):
663 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
664 | will be used instead.
665 | height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
666 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
667 | width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
668 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
669 | num_inference_steps (`int`, *optional*, defaults to 28):
670 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
671 | expense of slower inference.
672 | timesteps (`List[int]`, *optional*):
673 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
674 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
675 | passed will be used. Must be in descending order.
676 | guidance_scale (`float`, *optional*, defaults to 3.5):
677 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
678 | `guidance_scale` is defined as `w` of equation 2 of [Imagen
679 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
680 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
681 | usually at the expense of lower image quality.
682 | num_images_per_prompt (`int`, *optional*, defaults to 1):
683 | The number of images to generate per prompt.
684 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
685 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
686 | to make generation deterministic.
687 | latents (`torch.FloatTensor`, *optional*):
688 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
689 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
690 | tensor will be generated by sampling using the supplied random `generator`.
691 | prompt_embeds (`torch.FloatTensor`, *optional*):
692 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
693 | provided, text embeddings will be generated from `prompt` input argument.
694 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
695 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
696 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
697 | output_type (`str`, *optional*, defaults to `"pil"`):
698 | The output format of the generated image. Choose between
699 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
700 | return_dict (`bool`, *optional*, defaults to `True`):
701 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
702 | joint_attention_kwargs (`dict`, *optional*):
703 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
704 | `self.processor` in
705 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
706 | callback_on_step_end (`Callable`, *optional*):
707 | A function that is called at the end of each denoising step during inference. The function is called
708 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
709 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
710 | `callback_on_step_end_tensor_inputs`.
711 | callback_on_step_end_tensor_inputs (`List`, *optional*): 712 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 713 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 714 | `._callback_tensor_inputs` attribute of your pipeline class. 715 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 716 | 717 | Examples: 718 | 719 | Returns: 720 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 721 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 722 | images. 723 | """ 724 | 725 | height = height or self.default_sample_size * self.vae_scale_factor 726 | width = width or self.default_sample_size * self.vae_scale_factor 727 | 728 | # 1. Check inputs. Raise error if not correct 729 | self.check_inputs( 730 | prompt, 731 | prompt_2, 732 | height, 733 | width, 734 | prompt_embeds=prompt_embeds, 735 | pooled_prompt_embeds=pooled_prompt_embeds, 736 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 737 | max_sequence_length=max_sequence_length, 738 | ) 739 | 740 | self._guidance_scale = guidance_scale 741 | self._joint_attention_kwargs = joint_attention_kwargs 742 | self._interrupt = False 743 | 744 | # 2. Define call parameters 745 | if prompt is not None and isinstance(prompt, str): 746 | batch_size = 1 747 | elif prompt is not None and isinstance(prompt, list): 748 | batch_size = len(prompt) 749 | else: 750 | batch_size = prompt_embeds.shape[0] 751 | 752 | device = self._execution_device 753 | 754 | lora_scale = ( 755 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 756 | ) 757 | ( 758 | prompt_embeds, 759 | pooled_prompt_embeds, 760 | text_ids, 761 | ) = self.encode_prompt( 762 | prompt=prompt, 763 | prompt_2=prompt_2, 764 | prompt_embeds=prompt_embeds, 765 | pooled_prompt_embeds=pooled_prompt_embeds, 766 | device=device, 767 | num_images_per_prompt=num_images_per_prompt, 768 | max_sequence_length=max_sequence_length, 769 | lora_scale=lora_scale, 770 | ) 771 | 772 | # 4. Prepare latent variables 773 | num_channels_latents = self.transformer.config.in_channels // 4 774 | latents, latent_image_ids = self.prepare_latents( 775 | batch_size * num_images_per_prompt, 776 | num_channels_latents, 777 | height, 778 | width, 779 | prompt_embeds.dtype, 780 | device, 781 | generator, 782 | latents, 783 | ) 784 | 785 | # 5. Prepare timesteps 786 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 787 | image_seq_len = latents.shape[1] 788 | mu = calculate_shift( 789 | image_seq_len, 790 | self.scheduler.config.base_image_seq_len, 791 | self.scheduler.config.max_image_seq_len, 792 | self.scheduler.config.base_shift, 793 | self.scheduler.config.max_shift, 794 | ) 795 | timesteps, num_inference_steps = retrieve_timesteps( 796 | self.scheduler, 797 | num_inference_steps, 798 | device, 799 | timesteps, 800 | sigmas, 801 | mu=mu, 802 | ) 803 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 804 | self._num_timesteps = len(timesteps) 805 | 806 | # handle guidance 807 | if self.transformer.config.guidance_embeds: 808 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 809 | guidance = guidance.expand(latents.shape[0]) 810 | else: 811 | guidance = None 812 | 813 | # 6. 
Denoising loop 814 | with self.progress_bar(total=num_inference_steps) as progress_bar: 815 | for i, t in enumerate(timesteps): 816 | if self.interrupt: 817 | continue 818 | 819 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 820 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 821 | 822 | noise_pred = self.transformer( 823 | hidden_states=latents, 824 | timestep=timestep / 1000, 825 | guidance=guidance, 826 | pooled_projections=pooled_prompt_embeds, 827 | encoder_hidden_states=prompt_embeds, 828 | image_emb=image_emb, 829 | txt_ids=text_ids, 830 | img_ids=latent_image_ids, 831 | joint_attention_kwargs=self.joint_attention_kwargs, 832 | return_dict=False, 833 | )[0] 834 | 835 | # compute the previous noisy sample x_t -> x_t-1 836 | latents_dtype = latents.dtype 837 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 838 | 839 | if latents.dtype != latents_dtype: 840 | if torch.backends.mps.is_available(): 841 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 842 | latents = latents.to(latents_dtype) 843 | 844 | if callback_on_step_end is not None: 845 | callback_kwargs = {} 846 | for k in callback_on_step_end_tensor_inputs: 847 | callback_kwargs[k] = locals()[k] 848 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 849 | 850 | latents = callback_outputs.pop("latents", latents) 851 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 852 | 853 | # call the callback, if provided 854 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 855 | progress_bar.update() 856 | 857 | if XLA_AVAILABLE: 858 | xm.mark_step() 859 | 860 | if output_type == "latent": 861 | image = latents 862 | 863 | else: 864 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 865 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 866 | image = self.vae.decode(latents, return_dict=False)[0] 867 | image = self.image_processor.postprocess(image, output_type=output_type) 868 | 869 | # Offload all models 870 | self.maybe_free_model_hooks() 871 | 872 | if not return_dict: 873 | return (image,) 874 | 875 | return FluxPipelineOutput(images=image) 876 | --------------------------------------------------------------------------------
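Usage note (not part of the repository): the snippet below is a minimal sketch of how the customized `FluxPipeline` in `pipeline_flux_ipa.py` might be driven. The base checkpoint id, the dummy `image_emb` shape, and the output path are illustrative assumptions; in the actual project the conditioning embedding would be produced by the adapter components (see `infer_single.py` and `attention_processor.py`) rather than sampled randomly.

```python
# Hypothetical usage sketch. Assumptions: FLUX.1-dev base weights, a CUDA device,
# and a placeholder `image_emb`; names and shapes are illustrative only.
import torch

from pipeline_flux_ipa import FluxPipeline
from transformer_flux import FluxTransformer2DModel

base_model = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint

# Load the customized transformer (it accepts `image_emb`) and plug it into the pipeline.
transformer = FluxTransformer2DModel.from_pretrained(
    base_model, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxPipeline.from_pretrained(
    base_model, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Dummy conditioning embedding; in practice this comes from the adapter's image encoder.
image_emb = torch.randn(1, 729, 1152, dtype=torch.bfloat16, device="cuda")

image = pipe(
    prompt="a photo of a person with closed eyes",
    height=768,
    width=1280,
    num_inference_steps=28,   # pipeline default
    guidance_scale=3.5,       # pipeline default
    image_emb=image_emb,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("output.jpg")
```

With height=768 and width=1280 and `vae_scale_factor=16`, the shapes match the debug comments in `prepare_latents`: a [1, 16, 96, 160] latent is packed into [1, 3840, 64] tokens with [3840, 3] positional ids.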