├── pic ├── introduction.png └── 0001_0007_merge_seed1234.png ├── T2IS_Gen ├── pic │ └── 0001_0007_merge_seed1234.png ├── README.md ├── init_attention.py ├── utils.py ├── inference_t2is.py ├── t2is_transformer_flux.py └── t2is_pipeline_flux.py ├── T2IS_Eval ├── run_prompt_alignment.sh ├── run_prompt_consistency.sh ├── README.md ├── vqa_alignment.py └── test_prompt_consistency.py ├── README.md └── T2IS_Bench └── README.md /pic/introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/pic/introduction.png -------------------------------------------------------------------------------- /pic/0001_0007_merge_seed1234.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/pic/0001_0007_merge_seed1234.png -------------------------------------------------------------------------------- /T2IS_Gen/pic/0001_0007_merge_seed1234.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/T2IS_Gen/pic/0001_0007_merge_seed1234.png -------------------------------------------------------------------------------- /T2IS_Eval/run_prompt_alignment.sh: -------------------------------------------------------------------------------- 1 | IMAGE_DIR="/home/chengyou/AutoT2IS/output_images/layout_deepseek-reasoner" # Target directory 2 | OUTPUT_JSON="/home/chengyou/results/layout_deepseek-reasoner.json" # Output JSON file path 3 | 4 | cd /home/chengyou/clipscore/t2v_metrics 5 | # Process specified directory 6 | # see https://github.com/linzhiqiu/t2v_metrics for t2v_metrics 7 | 8 | echo "Evaluating images in ${IMAGE_DIR}..." 9 | CUDA_VISIBLE_DEVICES=0 python vqa_alignment.py \ 10 | --image_dir "${IMAGE_DIR}" \ 11 | --image_format png \ 12 | --prompt_json "./ChengyouJia/T2IS-Bench/prompt_alignment.json" \ 13 | --output_json "${OUTPUT_JSON}" \ 14 | --start_idx 0 15 | 16 | echo "Evaluation completed for ${IMAGE_DIR}!" -------------------------------------------------------------------------------- /T2IS_Eval/run_prompt_consistency.sh: -------------------------------------------------------------------------------- 1 | # Set basic parameters 2 | IMAGE_DIR="/home/chengyou/AutoT2IS/output_images/layout_deepseek-reasoner" # Target directory 3 | IMAGE_PATTERN=".png" # Image format 4 | OUTPUT_JSON="/home/chengyou/Qwen/results/layout_deepseek-reasoner.json" # Output JSON file path 5 | 6 | # Process specified directory 7 | echo "Evaluating images in ${IMAGE_DIR}..." 8 | CUDA_VISIBLE_DEVICES=0 python test_prompt_consistency.py \ 9 | --dataset_path "./ChengyouJia/T2IS-Bench/T2IS-Bench.json" \ 10 | --criteria_path "./ChengyouJia/T2IS-Bench/prompt_consistency.json" \ 11 | --image_base_path "${IMAGE_DIR}" \ 12 | --image_pattern "${IMAGE_PATTERN}" \ 13 | --output_path "${OUTPUT_JSON}" 14 | echo "Evaluation completed for ${IMAGE_DIR}!" 
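# Note (resume/output behavior): test_prompt_consistency.py loads Qwen2.5-VL-7B-Instruct
# from the local path ./Qwen/Qwen2.5-VL-7B-Instruct, skips cases already present in
# OUTPUT_JSON, and writes per-case scores under "case_results" plus aggregate
# "final_dimension_averages" and "final_overall_average" for Style / Identity / Logic.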
-------------------------------------------------------------------------------- /T2IS_Gen/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Usage 4 | 5 | ### Basic Usage 6 | ```bash 7 | python inference_t2is.py 8 | ``` 9 | 10 | ### With Parameters 11 | ```bash 12 | python inference_t2is.py --idx 0 --json_file ./data/gallery.json --output_dir ./output_images/ 13 | ``` 14 | 15 | ## Parameters 16 | 17 | | Parameter | Type | Default | Description | 18 | |-----------|------|---------|-------------| 19 | | `--idx` | int | None | Parameter index from JSON file | 20 | | `--json_file` | str | `./data/gallery.json` | Input JSON file path | 21 | | `--output_dir` | str | `./output_images/` | Output directory path | 22 | 23 | ## Output Structure 24 | 25 | Generated images are saved in: 26 | ``` 27 | output_images/ 28 | └── seed_*/ 29 | └── {idx}_merge_seed*.png 30 | ``` 31 | 32 | ## Generated ImageSet 33 | Examples 34 | 35 | 36 | 37 | 38 | 41 |
39 | python inference_t2is.py --idx=
40 | 
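
To render every entry in `gallery.json` in one go, a small driver can call `inference_t2is.py` once per index. This is a minimal sketch under the assumption that `./data/gallery.json` is a JSON list of entries (as `inference_t2is.py --idx` expects) and that it is run from `T2IS_Gen/`; the loop and error handling are illustrative, not part of the released code.

```python
import json
import subprocess
import sys

json_file = "./data/gallery.json"   # same default as inference_t2is.py
output_dir = "./output_images/"

with open(json_file) as f:
    num_items = len(json.load(f))   # entries carry idx / prompt / Divide_prompt_list

for i in range(num_items):
    # Each call reloads FLUX.1-dev, which is slow but keeps runs independent and restartable.
    ret = subprocess.run([
        sys.executable, "inference_t2is.py",
        "--idx", str(i),
        "--json_file", json_file,
        "--output_dir", output_dir,
    ])
    if ret.returncode != 0:
        print(f"Generation failed for index {i}; continuing with the next entry.")
```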
-------------------------------------------------------------------------------- /T2IS_Eval/README.md: -------------------------------------------------------------------------------- 1 | # T2IS Evaluation Scripts 2 | 3 | This directory contains evaluation scripts for Text-to-ImageSet (T2IS) models, including prompt alignment and visual consistency evaluation. 4 | 5 | ## Prerequisites 6 | 7 | Before running the evaluation scripts, you need to download the T2IS-Bench dataset from Hugging Face: 8 | 9 | 1. Visit [ChengyouJia/T2IS-Bench](https://huggingface.co/datasets/ChengyouJia/T2IS-Bench) on Hugging Face 10 | 2. Download the following files: 11 | - `T2IS-Bench.json`: Main dataset file containing prompts and metadata 12 | - `prompt_alignment.json`: Alignment evaluation criteria 13 | - `prompt_consistency.json`: Consistency evaluation criteria 14 | 3. Place the downloaded files in a directory named `ChengyouJia/T2IS-Bench/` relative to the evaluation scripts 15 | 16 | 17 | ## Scripts Overview 18 | 19 | ### 1. Prompt Alignment Evaluation (`run_prompt_alignment.sh`) 20 | 21 | **Purpose**: Evaluates the alignment between generated images and their corresponding prompts using VQAScore metrics. 22 | 23 | **Features**: 24 | - Uses VQAScore to measure semantic similarity between images and text prompts 25 | - Processes PNG images from a specified directory 26 | - Outputs evaluation results in JSON format 27 | - Configurable start index for batch processing 28 | 29 | **Usage**: 30 | ```bash 31 | # Make the script executable 32 | chmod +x run_prompt_alignment.sh 33 | 34 | # Run the evaluation 35 | ./run_prompt_alignment.sh 36 | ``` 37 | 38 | **Configuration**: 39 | - **Image Directory**: Set `IMAGE_DIR` to the path containing your generated images 40 | - **Output Path**: Set `OUTPUT_JSON` to where you want to save the results 41 | - **Image Format**: Currently set to PNG format 42 | - **Start Index**: Set `start_idx` for batch processing (default: 0) 43 | 44 | **Requirements**: 45 | - See details in https://github.com/linzhiqiu/t2v_metrics. 46 | 47 | ### 2. Prompt Consistency Evaluation (`run_prompt_consistency.sh`) 48 | 49 | **Purpose**: Evaluates the consistency of generated images with their prompts using Qwen model analysis. 50 | 51 | **Features**: 52 | - Uses Qwen2.5VL model to analyze image-prompt consistency 53 | - Processes images with configurable pattern matching 54 | - Outputs detailed consistency analysis in JSON format 55 | - Supports various image formats through pattern matching 56 | 57 | **Usage**: 58 | ```bash 59 | # Make the script executable 60 | chmod +x run_prompt_consistency.sh 61 | 62 | # Run the evaluation 63 | ./run_prompt_consistency.sh 64 | ``` 65 | 66 | **Configuration**: 67 | - **Image Directory**: Set `IMAGE_DIR` to the path containing your generated images 68 | - **Image Pattern**: Set `IMAGE_PATTERN` to match your image files (e.g., ".png", ".jpg") 69 | - **Output Path**: Set `OUTPUT_JSON` to where you want to save the results 70 | 71 | **Requirements**: 72 | - See details in https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct. 73 | -------------------------------------------------------------------------------- /T2IS_Gen/init_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torchvision.transforms.functional as F 4 | TOKENS = 75 5 | 6 | def split_dims(x_t, height, width, self=None): 7 | """Split an attention layer dimension to height + width. 
8 | The original estimate was latent_h = sqrt(hw_ratio*x_t), 9 | rounding to the nearest value. However, this proved inaccurate. 10 | The actual operation seems to be as follows: 11 | - Divide h,w by 8, rounding DOWN. 12 | - For every new layer (of 4), divide both by 2 and round UP (then back up). 13 | - Multiply h*w to yield x_t. 14 | There is no inverse function to this set of operations, 15 | so instead we mimic them without the multiplication part using the original h+w. 16 | It's worth noting that no known checkpoints follow a different system of layering, 17 | but it's theoretically possible. Please report if encountered. 18 | """ 19 | scale = math.ceil(math.log2(math.sqrt(height * width / x_t))) 20 | latent_h = repeat_div(height, scale) 21 | latent_w = repeat_div(width, scale) 22 | if x_t > latent_h * latent_w and hasattr(self, "nei_multi"): 23 | latent_h, latent_w = self.nei_multi[1], self.nei_multi[0] 24 | while latent_h * latent_w != x_t: 25 | latent_h, latent_w = latent_h // 2, latent_w // 2 26 | 27 | return latent_h, latent_w 28 | 29 | def repeat_div(x,y): 30 | """Imitates dimension halving common in convolution operations. 31 | 32 | This is a pretty big assumption of the model, 33 | but then if some model doesn't work like that it will be easy to spot. 34 | """ 35 | while y > 0: 36 | x = math.ceil(x / 2) 37 | y = y - 1 38 | return x 39 | 40 | def init_forwards(self, root_module: torch.nn.Module): 41 | for name, module in root_module.named_modules(): 42 | if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention": 43 | module.forward = FluxTransformerBlock_init_forward(self, module) 44 | elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention": 45 | module.forward = FluxSingleTransformerBlock_init_forward(self, module) 46 | 47 | def FluxSingleTransformerBlock_init_forward(self, module): 48 | def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None,RPG_encoder_hidden_states_list=None,RPG_norm_encoder_hidden_states_list=None,RPG_hidden_states_list=None,RPG_norm_hidden_states_list=None): 49 | return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb) 50 | return forward 51 | 52 | def FluxTransformerBlock_init_forward(self, module): 53 | def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None,RPG_encoder_hidden_states_list=None,RPG_norm_encoder_hidden_states_list=None,RPG_hidden_states_list=None,RPG_norm_hidden_states_list=None): 54 | return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb) 55 | return forward -------------------------------------------------------------------------------- /T2IS_Gen/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for T2IS generation and image processing. 3 | """ 4 | 5 | def calculate_layout_dimensions(num_prompts, sub_height, sub_width): 6 | """Calculate total height and width based on number of prompts and sub-image dimensions. 
7 | 8 | Args: 9 | num_prompts (int): Number of sub-images/prompts 10 | sub_height (int): Height of each sub-image 11 | sub_width (int): Width of each sub-image 12 | 13 | Returns: 14 | tuple: Total (height, width) dimensions 15 | """ 16 | if num_prompts == 4: 17 | # 2x2 layout 18 | height = sub_height * 2 19 | width = sub_width * 2 20 | elif num_prompts == 6: 21 | # 2x3 layout 22 | height = sub_height * 2 23 | width = sub_width * 3 24 | elif num_prompts == 8: 25 | # 2x4 layout 26 | height = sub_height * 2 27 | width = sub_width * 4 28 | elif num_prompts == 9: 29 | # 3x3 layout 30 | height = sub_height * 3 31 | width = sub_width * 3 32 | else: 33 | # Default: 1xN layout 34 | height = sub_height 35 | width = sub_width * num_prompts 36 | return height, width 37 | 38 | 39 | def calculate_cutting_layout(num_prompts): 40 | """Calculate rows and columns for cutting images based on number of prompts. 41 | 42 | Args: 43 | num_prompts (int): Number of sub-images/prompts 44 | 45 | Returns: 46 | tuple: (rows, cols) for cutting layout 47 | """ 48 | if num_prompts == 4: 49 | # 2*2 切割 50 | rows, cols = 2, 2 51 | elif num_prompts == 6: 52 | # 2*3 切割 53 | rows, cols = 2, 3 54 | elif num_prompts == 8: 55 | # 2*4 切割 56 | rows, cols = 2, 4 57 | elif num_prompts == 9: 58 | # 3*3 切割 59 | rows, cols = 3, 3 60 | else: 61 | # 默认沿着宽切割 62 | rows, cols = 1, num_prompts 63 | 64 | return rows, cols 65 | 66 | 67 | def get_grid_params(size): 68 | """Calculate grid parameters for image layout based on size. 69 | 70 | Args: 71 | size (int): Number of sub-images/prompts 72 | 73 | Returns: 74 | tuple: (m_offsets, n_offsets, m_scales, n_scales) for grid layout 75 | """ 76 | if size == 4: 77 | rows, cols = 2, 2 78 | elif size == 6: 79 | rows, cols = 2, 3 80 | elif size == 8: 81 | rows, cols = 2, 4 82 | elif size == 9: 83 | rows, cols = 3, 3 84 | else: 85 | rows, cols = 1, size 86 | 87 | m_scale = 1.0/cols 88 | n_scale = 1.0/rows 89 | 90 | m_offsets = [] 91 | n_offsets = [] 92 | m_scales = [] 93 | n_scales = [] 94 | 95 | for row in range(rows): 96 | for col in range(cols): 97 | m_offsets.append(col * m_scale) 98 | n_offsets.append(row * n_scale) 99 | m_scales.append(m_scale) 100 | n_scales.append(n_scale) 101 | 102 | return m_offsets, n_offsets, m_scales, n_scales 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /T2IS_Gen/inference_t2is.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import json 4 | import os 5 | from t2is_pipeline_flux import T2IS_FluxPipeline 6 | from PIL import Image 7 | from utils import calculate_layout_dimensions, calculate_cutting_layout 8 | 9 | def parse_arguments(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--idx', type=int, help="Loading parameters in json") 12 | parser.add_argument('--output_dir', type=str, default="./output_images/", help="Output directory") 13 | parser.add_argument('--json_file', type=str, default="./data/gallery.json", help="Input JSON file") 14 | return parser.parse_args() 15 | 16 | pipe = T2IS_FluxPipeline.from_pretrained("/home/chengyou/hugging/models/FLUX.1-dev", torch_dtype=torch.bfloat16) 17 | pipe = pipe.to("cuda") 18 | 19 | args = parse_arguments() 20 | 21 | if args.idx is not None: 22 | # Load parameters from JSON 23 | with open(args.json_file, 'r') as f: 24 | data = json.load(f) 25 | 26 | item = data[args.idx] 27 | 28 | idx = item["idx"] 29 | task_name_case_id = item["task_name_case_id"] 30 | prompt = item["prompt"] 
31 | Divide_prompt_list = item["Divide_prompt_list"] 32 | 33 | # Calculate dimensions 34 | num_prompts = len(Divide_prompt_list) 35 | sub_height = 512 36 | sub_width = 512 37 | height, width = calculate_layout_dimensions(num_prompts, sub_height, sub_width) 38 | 39 | Divide_replace = 2 40 | seed = 1234 41 | 42 | else: 43 | # Default parameters 44 | idx = "0001_0007" 45 | prompt = "THREE-PANEL Images with a 1x3 grid layout a teenage boy with short spiky black hair, a slight build, and dark brown eyes in hyper-realistic style.All images maintain hyper-realistic digital painting style with consistent character design, emphasizing the boy's distinct features and naturalistic lighting across varied environments. [LEFT]:The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting. [MIDDLE]:The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures. [RIGHT]:The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze." 46 | Divide_prompt_list = [ 47 | "The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting.", 48 | "The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures.", 49 | "The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze." 
50 | ] 51 | 52 | sub_height = 512 53 | sub_width = 512 54 | height, width = calculate_layout_dimensions(3, sub_height, sub_width) 55 | 56 | Divide_replace = 2 57 | seed = 1234 58 | 59 | # Create output directory 60 | base_output_path = args.output_dir 61 | if not os.path.exists(base_output_path): 62 | os.makedirs(base_output_path) 63 | 64 | seed_output_path = os.path.join(base_output_path, f"seed_{seed}") 65 | if not os.path.exists(seed_output_path): 66 | os.makedirs(seed_output_path) 67 | 68 | print(f"Generating with seed {seed}:") 69 | try: 70 | image = pipe( 71 | Divide_prompt_list=Divide_prompt_list, 72 | Divide_replace=Divide_replace, 73 | seed=seed, 74 | prompt=prompt, 75 | height=height, 76 | width=width, 77 | num_inference_steps=20, 78 | guidance_scale=3.5, 79 | ).images[0] 80 | 81 | # Save image 82 | output_filename = os.path.join(seed_output_path, f"{idx}_merge_seed{seed}.png") 83 | image.save(output_filename) 84 | print(f"Image saved as {output_filename}") 85 | 86 | except Exception as e: 87 | print(f"Error processing {idx} with seed {seed}: {str(e)}") 88 | -------------------------------------------------------------------------------- /T2IS_Eval/vqa_alignment.py: -------------------------------------------------------------------------------- 1 | import t2v_metrics 2 | import json 3 | import os 4 | from PIL import Image 5 | import torch 6 | from torchvision.transforms import ToTensor 7 | dimensions = ['Entity', 'Attribute', 'Relation'] 8 | # Initialize the scoring model 9 | clip_flant5_score = t2v_metrics.VQAScore(model='clip-flant5-xxl') 10 | 11 | 12 | 13 | # Function to get image path from task_id and sub_id 14 | def get_image_path(task_id, sub_id, image_dir, image_format): 15 | return os.path.join(image_dir, f"{task_id}_{sub_id}.{image_format}") 16 | 17 | # Function to evaluate a single image against its criteria 18 | def evaluate_image(image_path, criteria): 19 | if not os.path.exists(image_path): 20 | return None 21 | 22 | # Load the image 23 | image = Image.open(image_path) 24 | 25 | # Prepare texts for each dimension 26 | texts = [] 27 | for dimension in dimensions: 28 | texts.extend(criteria[dimension]) 29 | 30 | # Calculate scores 31 | scores = clip_flant5_score(images=[image_path], texts=texts) 32 | 33 | # Calculate average score for each dimension 34 | dimension_scores = {} 35 | start_idx = 0 36 | for dimension in dimensions: 37 | num_criteria = len(criteria[dimension]) 38 | dimension_scores[dimension] = sum(scores[0][start_idx:start_idx + num_criteria]) / num_criteria 39 | start_idx += num_criteria 40 | 41 | return dimension_scores 42 | 43 | # Main evaluation loop 44 | def main_evaluation(image_dir, image_format, prompt_json, output_json, start_idx=1): 45 | 46 | # Load the prompt alignment JSON file 47 | with open(prompt_json, 'r') as f: 48 | prompt_data = json.load(f) 49 | 50 | overall_dimension_scores = {dimension: 0 for dimension in dimensions} 51 | total_images = 0 52 | results = {} 53 | 54 | # Load existing results from the JSON file if it exists 55 | if os.path.exists(output_json): 56 | with open(output_json, 'r') as f: 57 | results = json.load(f) 58 | else: 59 | results = {} 60 | 61 | for task_id, task_data in prompt_data.items(): 62 | print(f"\nEvaluating task: {task_id}") 63 | 64 | for idx, (sub_id, criteria) in enumerate(task_data.items(), start=start_idx): 65 | formatted_sub_id = f"{idx:04d}" 66 | result_key = f"{task_id}_{formatted_sub_id}" 67 | 68 | # Skip processing if the result already exists 69 | if result_key in results: 70 | print(f"Skipping 
already processed image: {result_key}") 71 | continue 72 | 73 | image_path = get_image_path(task_id, formatted_sub_id, image_dir, image_format) 74 | print(f"\nEvaluating image: {image_path}") 75 | 76 | scores = evaluate_image(image_path, criteria) 77 | if scores: 78 | print("Dimension scores:") 79 | for dimension, score in scores.items(): 80 | print(f"{dimension}: {score:.4f}") 81 | overall_dimension_scores[dimension] += score 82 | total_images += 1 83 | results[result_key] = scores 84 | 85 | # Save the updated results to the JSON file after each image 86 | with open(output_json, 'w') as f: 87 | json.dump(results, f, indent=4, default=lambda o: o.tolist() if isinstance(o, torch.Tensor) else o) 88 | else: 89 | print("Image not found") 90 | 91 | # Calculate overall average scores for each dimension 92 | if total_images > 0: 93 | for dimension in overall_dimension_scores: 94 | overall_dimension_scores[dimension] /= total_images 95 | 96 | print("\nOverall average dimension scores:") 97 | for dimension, score in overall_dimension_scores.items(): 98 | print(f"{dimension}: {score:.4f}") 99 | 100 | # Example usage 101 | if __name__ == "__main__": 102 | import argparse 103 | 104 | parser = argparse.ArgumentParser(description="Evaluate image dimensions.") 105 | parser.add_argument("--image_dir", type=str, default="/home/chengyou/GroupGen/baseline/results/flux/", help="Directory containing the images.") 106 | parser.add_argument("--image_format", type=str, default="png", choices=["png", "jpg"], help="Format of the images.") 107 | parser.add_argument("--output_json", type=str, default="dimension_scores.json", help="Path to the output JSON file.") 108 | parser.add_argument("--start_idx", type=int, default=1, help="Starting index for image enumeration.") 109 | parser.add_argument("--prompt_json", type=str, default="/home/chengyou/GroupGen/evaluation/prompt_alignment.json", help="Path to the prompt alignment JSON file.") 110 | args = parser.parse_args() 111 | 112 | main_evaluation(args.image_dir, args.image_format, args.prompt_json, args.output_json, args.start_idx) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Why Settle for One? Text-to-ImageSet Generation and Evaluation
3 | 
5 | [🌐 Website] • [📜 Paper] • [🤗 HF Dataset]
9 | 
12 | Official Repo for "Why Settle for One? Text-to-ImageSet Generation and Evaluation"
14 | 
15 | 16 | 18 | 19 | ## T2IS 20 | ![T2IS](./pic/introduction.png) 21 | 22 | 23 | ## News 24 | 25 | - _2025.10_: We have added the latest **Seedream 4.0** results. Please refer to the [Seedream 4.0 Demo](https://www.volcengine.com/experience/ark?launch=seedream) and the attached file [**T2IS_Seedream.zip**](https://huggingface.co/datasets/ChengyouJia/T2IS-Bench). 26 | 27 | 28 | - _2025.09_: We release the [T2IS-Gen] simple version of set-aware generation code. 29 | 30 | 31 | - _2025.08_: We release the [T2IS-Eval] evaluation toolkit. 32 | - _2025.07_: We release the details of [T2IS-Bench]. 33 | 34 | 35 | ## 🛠️ Installation 36 | 37 | ### Text-to-ImageSet Generation 38 | 39 | ### 1. Set Environment 40 | ```bash 41 | conda create -n T2IS python==3.9 42 | conda activate T2IS 43 | pip install xformers==0.0.28.post1 diffusers peft torchvision==0.19.1 opencv-python==4.10.0.84 sentencepiece==0.2.0 protobuf==5.28.1 scipy==1.13.1 44 | ``` 45 | 46 | ### 2. Quick Start 47 | 48 | ```bash 49 | cd T2IS_Gen 50 | ``` 51 | 52 | ```python 53 | import torch 54 | import argparse 55 | import json 56 | import os 57 | from t2is_pipeline_flux import T2IS_FluxPipeline 58 | from PIL import Image 59 | from utils import calculate_layout_dimensions, calculate_cutting_layout 60 | pipe = T2IS_FluxPipeline.from_pretrained("/home/chengyou/hugging/models/FLUX.1-dev", torch_dtype=torch.bfloat16) 61 | pipe = pipe.to("cuda") 62 | 63 | # base_output_path = "../output_images/RAG_layout_deepseek-reasoner_3_30_seed_1234" 64 | base_output_path = "./output_images/" 65 | 66 | print(f"Processing file with task name case ID: 0001_0007") 67 | task_name_case_id = "dynamic_character_scenario_design_0007" 68 | Divide_prompt_list = [ 69 | "The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting.", 70 | "The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures.", 71 | "The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze." 72 | ] 73 | prompt = "THREE-PANEL Images with a 1x3 grid layout a teenage boy with short spiky black hair, a slight build, and dark brown eyes in hyper-realistic style.All images maintain hyper-realistic digital painting style with consistent character design, emphasizing the boy's distinct features and naturalistic lighting across varied environments. [LEFT]:The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting. [MIDDLE]:The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures. [RIGHT]:The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. 
Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze." 74 | 75 | # Set default sub-image size to 512x512 76 | sub_height = 512 77 | sub_width = 512 78 | 79 | # Calculate total height and width based on layout 80 | num_prompts = len(Divide_prompt_list) 81 | height, width = calculate_layout_dimensions(num_prompts, sub_height, sub_width) 82 | 83 | 84 | 85 | Divide_replace = 2 86 | num_inference_steps = 20 87 | 88 | seeds = [1234] 89 | 90 | for seed_idx, seed in enumerate(seeds): 91 | seed_output_path = os.path.join(base_output_path, f"seed_{seed}") 92 | if not os.path.exists(seed_output_path): 93 | os.makedirs(seed_output_path) 94 | 95 | print(f"Generating with seed {seed}:") 96 | try: 97 | image = pipe( 98 | Divide_prompt_list=Divide_prompt_list, 99 | Divide_replace=Divide_replace, 100 | seed=seed, 101 | prompt=prompt, 102 | height=height, 103 | width=width, 104 | num_inference_steps=num_inference_steps, 105 | guidance_scale=3.5, 106 | ).images[0] 107 | except Exception as e: 108 | print(f"Error processing {idx} with seed {seed}: {str(e)}") 109 | continue 110 | image.save(os.path.join(seed_output_path, f"{idx}_merge_seed{seed}.png")) 111 | ``` 112 | ## Generated ImageSet 113 | Examples 114 | 115 | 116 | 117 | 118 |
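
The pipeline above saves a single merged canvas per task. If the individual panels are needed (for example to feed the per-sub-image evaluation in `T2IS_Eval`), they can be cropped back out with `calculate_cutting_layout` from `T2IS_Gen/utils.py`. A minimal sketch, assuming 512×512 tiles laid out in row-major order and the example output path from above; the output file naming is only illustrative.

```python
from PIL import Image
from utils import calculate_cutting_layout  # run from T2IS_Gen/

merged_path = "./output_images/seed_1234/0001_0007_merge_seed1234.png"
num_prompts = 3                                       # number of sub-prompts for this task
rows, cols = calculate_cutting_layout(num_prompts)    # 1x3 for three prompts

merged = Image.open(merged_path)
tile_w, tile_h = merged.width // cols, merged.height // rows

for k in range(num_prompts):
    r, c = divmod(k, cols)            # row-major order, matching get_grid_params
    box = (c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h)
    out_path = f"./output_images/seed_1234/0001_0007_{k + 1:04d}.png"  # illustrative naming
    merged.crop(box).save(out_path)
```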
119 | 120 | 121 | ## Citation 122 | If you find it helpful, please kindly cite the paper. 123 | ``` 124 | @article{jia2025settle, 125 | title={Why Settle for One? Text-to-ImageSet Generation and Evaluation}, 126 | author={Jia, Chengyou and Shen, Xin and Dang, Zhuohang and Xia, Changliang and Wu, Weijia and Zhang, Xinyu and Qian, Hangwei and Tsang, Ivor W and Luo, Minnan}, 127 | journal={arXiv preprint arXiv:2506.23275}, 128 | year={2025} 129 | } 130 | ``` 131 | 132 | ## 📬 Contact 133 | 134 | If you have any inquiries, suggestions, or wish to contact us for any reason, we warmly invite you to email us at cp3jia@stu.xjtu.edu.cn. 135 | -------------------------------------------------------------------------------- /T2IS_Bench/README.md: -------------------------------------------------------------------------------- 1 | See https://huggingface.co/datasets/ChengyouJia/T2IS-Bench for more details. 2 | 3 | 4 | ### Dataset Overview 5 | 6 | **T2IS-Bench** is a comprehensive benchmark designed to evaluate generative models' performance in text-to-image set generation tasks. It includes **596 carefully constructed tasks** across **five major categories** (26 sub-categories), each targeting different aspects of set-level consistency such as identity preservation, style uniformity, and logical coherence. These tasks span a wide range of real-world applications, including character creation, visual storytelling, product mockups, procedural illustrations, and instructional content. 7 | 8 | T2IS-Bench provides a scalable evaluation framework that assesses image sets across **three critical consistency dimensions**: identity, style, and logic. Each of the **596 tasks** is paired with structured natural language instructions and evaluated using **LLM-driven criteria generation**, enabling automatic, interpretable, and fine-grained assessment. This design supports benchmarking generative models' ability to produce coherent visual outputs beyond prompt-level alignment, and reflects real-world requirements for controllability and consistency in multi-image generation. 9 | 10 | 11 | 12 | ### Supported Tasks 13 | 14 | The dataset comprises five main categories, each with a set of associated tasks and unique task IDs as listed below: 15 | 16 | #### **Character Generation** 17 | 18 | - `0001` – Multi-Scenario 19 | - `0002` – Multi-Expression 20 | - `0003` – Portrait Design 21 | - `0004` – Multi-view 22 | - `0005` – Multi-pose 23 | 24 | #### **Design Style Generation** 25 | 26 | - `0006` – Creative Style 27 | - `0007` – Poster Design 28 | - `0008` – Font Design 29 | - `0009` – IP Product 30 | - `0010` – Home Decoration 31 | 32 | #### **Story Generation** 33 | 34 | - `0011` – Movie Shot 35 | - `0012` – Comic Story 36 | - `0013` – Children Book 37 | - `0014` – News Illustration 38 | - `0015` – Hist. Narrative 39 | 40 | #### **Process Generation** 41 | 42 | - `0016` – Growth Process 43 | - `0017` – Draw Process 44 | - `0018` – Cooking Process 45 | - `0019` – Physical Law 46 | - `0020` – Arch. 
Building 47 | - `0021` – Evolution Illustration 48 | 49 | #### **Instruction Generation** 50 | 51 | - `0022` – Education Illustration 52 | - `0023` – Historical Panel 53 | - `0024` – Product Instruction 54 | - `0025` – Travel Guide 55 | - `0026` – Activity Arrange 56 | 57 | 58 | 59 | ### Use Cases 60 | 61 | **T2IS-Bench** is designed for evaluating generative models on multi-image consistency tasks, testing capabilities such as aesthetics, prompt alignment (including entity, attribute, and relation understanding), and visual consistency (covering identity, style, and logic) across image sets. It is suitable for benchmarking text-to-image models, diffusion transformers, and multimodal generation systems in real-world applications like product design, storytelling, and instructional visualization. 62 | 63 | ## Dataset Format and Structure 64 | 65 | ### Data Organization 66 | 67 | 1. **`T2IS-Bench.json`** 68 | A JSON file providing all of the cases. The structure of `T2IS-Bench.json` is as follows: 69 | 70 | ```json 71 | { 72 | ...... 73 | "0018_0001": { 74 | "task_name": "Cooking Process", 75 | "num_of_cases": 27, 76 | "uid": "0018", 77 | "output_image_count": 4, 78 | "case_id": "0001", 79 | "task_name_case_id": "cooking_process_0001", 80 | "category": "Process Generation", 81 | "instruction": "Please provide a detailed guide on melting chocolate, including 4 steps. For each step, generate an image.", 82 | "sub_caption": [ 83 | "A glass bowl filled with chopped dark chocolate pieces sits on top of a pot of simmering water. Steam rises gently around the bowl, and a thermometer is visible in the chocolate. The kitchen counter shows other baking ingredients in the background.", 84 | "Hands holding a silicone spatula are gently stirring melting chocolate in a glass bowl. The chocolate is partially melted, with some pieces still visible. The bowl is positioned over a steaming pot on a stovetop.", 85 | "A close-up view of a digital thermometer inserted into fully melted, glossy chocolate. The thermometer display shows a temperature of 88°F (31°C). The melted chocolate has a rich, dark color and smooth texture.", 86 | "A hand is seen removing the bowl of melted chocolate from the double boiler setup. The chocolate appears smooth and shiny. Next to the stove, various dessert items like strawberries, cookies, and a cake are ready for dipping or coating." 87 | ] 88 | } 89 | ...... 90 | } 91 | ``` 92 | 93 | - task_name: Name of the task. 94 | - num_of_cases: The number of individual cases in the task. 95 | - uid: Unique identifier for the task. 96 | - output_image_count: Number of images expected as output. 97 | - case_id: Identifier for this case. 98 | - task_name_case_id: Unique identifier for each specific case within a task, combining the task name and case ID. 99 | - category: The classification of the task. 100 | - instruction: The task's description, specifying what needs to be generated. 101 | - sub_caption: Descriptions for each image in the task by feeding instruction into LLM. 102 | 103 | 104 | 105 | 2. **`prompt_alignment_criterion.json`** 106 | 107 | This file contains evaluation criteria for assessing prompt alignment in image generation tasks. Each entry corresponds to a specific task and is organized by steps, with each step evaluated based on three key aspects: **Entity**, **Attribute**, and **Relation**. 108 | 109 | - **Entity** defines the key objects or characters required in the scene. 110 | - **Attribute** describes the properties or conditions that these entities must possess. 
111 | - **Relation** outlines how the entities interact or are positioned within the scene. 112 | 113 | This structured format helps evaluate the accuracy of the generated images in response to specific prompts. 114 | 115 | 3. **`prompt_consistency_criterion.json`** 116 | 117 | This file defines evaluation criteria for assessing *intra-sequence consistency* in image generation tasks. Each entry corresponds to a specific task and outlines standards across three core aspects: **Style**, **Identity**, and **Logic**. 118 | 119 | - **Style** evaluates the visual coherence across all generated images, including consistency in rendering style, color palette, lighting conditions, and background detail. It ensures that all images share a unified artistic and atmospheric aesthetic. 120 | 121 | - **Identity** focuses on maintaining character integrity across scenes. This includes preserving key facial features, body proportions, attire, and expressions so that the same individual or entity is clearly represented throughout the sequence. 122 | 123 | - **Logic** ensures semantic and physical plausibility across images. This includes spatial layout consistency, realistic actions, appropriate interactions with the environment, and coherent scene transitions. 124 | 125 | This structured format enables a systematic evaluation of how well generated images maintain consistency within a task. 126 | -------------------------------------------------------------------------------- /T2IS_Eval/test_prompt_consistency.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from tqdm import tqdm 5 | from PIL import Image 6 | import torch 7 | import argparse 8 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor 9 | from qwen_vl_utils import process_vision_info 10 | from transformers import GenerationConfig 11 | 12 | def load_model(model_path): 13 | # Load the model on the available device(s) 14 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 15 | model_path, torch_dtype="auto", device_map="auto" 16 | ) 17 | 18 | # default processer 19 | processor = AutoProcessor.from_pretrained(model_path) 20 | 21 | return model, processor 22 | 23 | def evaluate_prompt_consistency(dataset_path, criteria_path, image_base_path, image_pattern, output_path=None): 24 | # Load the model 25 | model, processor = load_model("./Qwen/Qwen2.5-VL-7B-Instruct") 26 | 27 | # Load the filtered dataset 28 | with open(dataset_path, 'r') as f: 29 | dataset = json.load(f) 30 | print(f"Total number of cases in dataset: {len(dataset)}") 31 | 32 | # Load the prompt consistency criteria 33 | with open(criteria_path, 'r') as f: 34 | criteria_data = json.load(f) 35 | print(f"Loaded prompt consistency criteria") 36 | 37 | # Get folder name for output file 38 | folder_name = os.path.basename(image_base_path) 39 | print(f"\n{'='*50}") 40 | print(f"Processing folder: {folder_name}") 41 | print(f"{'='*50}\n") 42 | 43 | # Set output path if not provided 44 | if output_path is None: 45 | output_path = f"evaluation_scores_{folder_name}.json" 46 | 47 | # Initialize results 48 | results = {} 49 | 50 | # Check if results file already exists and load it 51 | if os.path.exists(output_path): 52 | with open(output_path, 'r') as f: 53 | results = json.load(f) 54 | print(f"Loaded existing results from {output_path}") 55 | else: 56 | results = { 57 | "case_results": {} 58 | } 59 | 60 | # Process each case in the dataset 61 | for case_id, 
case_data in tqdm(dataset.items(), desc=f"Processing cases"): 62 | # Skip if this case has already been processed 63 | if case_id in results["case_results"]: 64 | print(f"\nSkipping already processed case: {case_id}") 65 | continue 66 | 67 | print(f"\nProcessing case: {case_id}") 68 | 69 | # Get criteria for this case 70 | case_criteria = criteria_data.get(case_id, {}) 71 | if not case_criteria: 72 | print(f"Warning: No criteria found for case {case_id}. Skipping.") 73 | continue 74 | 75 | # Get all image paths for this case based on the provided pattern 76 | all_files = os.listdir(image_base_path) 77 | image_paths = [] 78 | 79 | # Parse the image pattern to extract case_id and file extension 80 | ext = image_pattern 81 | 82 | for f in all_files: 83 | if f.startswith(case_id) and f.endswith(ext): 84 | try: 85 | # Extract the last number from filename (e.g., 0001 from 0001_0001_0001.png) 86 | num = int(f.split('_')[-1].replace(ext, '')) 87 | image_paths.append((num, os.path.join(image_base_path, f))) 88 | except ValueError: 89 | continue 90 | 91 | # Sort by the extracted number and get only the paths 92 | image_paths = [path for _, path in sorted(image_paths)] 93 | 94 | if len(image_paths) == 0: 95 | print(f"Warning: Expected {case_data['output_image_count']} images, but found {len(image_paths)}. Skipping this case.") 96 | continue 97 | 98 | # Initialize dimension scores 99 | dimension_scores = { 100 | "Style": {"scores": [], "criteria": []}, 101 | "Identity": {"scores": [], "criteria": []}, 102 | "Logic": {"scores": [], "criteria": []} 103 | } 104 | 105 | # Process each dimension 106 | for dimension in ["Style", "Identity", "Logic"]: 107 | if dimension not in case_criteria: 108 | continue 109 | 110 | # Get criteria for this dimension 111 | dimension_criteria = case_criteria[dimension][0] # Get the first (and only) dictionary in the list 112 | dimension_scores[dimension]["criteria"] = list(dimension_criteria.values()) 113 | 114 | # Process each criterion in this dimension 115 | for criterion_text in dimension_criteria.values(): 116 | softmax_values = [] 117 | 118 | # Compare each pair of images 119 | for i in range(len(image_paths)): 120 | for j in range(i + 1, len(image_paths)): 121 | messages = [] 122 | messages.append( 123 | { 124 | "role": "user", 125 | "content": [ 126 | {"type": "image", "image": image_paths[i], "resized_height": 512, "resized_width": 512}, 127 | {"type": "image", "image": image_paths[j], "resized_height": 512, "resized_width": 512}, 128 | {"type": "text", "text": f"Do images meet the following criteria? 
{criterion_text} Please answer Yes or No."}, 129 | ], 130 | } 131 | ) 132 | 133 | # Prepare for inference 134 | text = processor.apply_chat_template( 135 | messages, tokenize=False, add_generation_prompt=True 136 | ) 137 | image_inputs, video_inputs = process_vision_info(messages) 138 | inputs = processor( 139 | text=[text], 140 | images=image_inputs, 141 | videos=video_inputs, 142 | padding=True, 143 | return_tensors="pt", 144 | ) 145 | inputs = inputs.to("cuda") 146 | 147 | generated_ids = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, output_logits=True) 148 | sequences = generated_ids.sequences 149 | scores = generated_ids.scores 150 | logits = generated_ids.logits 151 | 152 | no_logits = logits[0][0][2753] 153 | yes_logits = logits[0][0][9454] 154 | 155 | # Calculate softmax 156 | logits = torch.tensor([no_logits, yes_logits]) 157 | softmax = torch.nn.functional.softmax(logits, dim=0) 158 | yes_softmax = softmax[1].item() 159 | 160 | generated_ids_trimmed = [ 161 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, sequences) 162 | ] 163 | 164 | output_text = processor.batch_decode( 165 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 166 | ) 167 | 168 | # Append the current yes_softmax value to the list 169 | softmax_values.append(yes_softmax) 170 | 171 | # Calculate average score for this criterion 172 | if softmax_values: 173 | average_softmax = sum(softmax_values) / len(softmax_values) 174 | dimension_scores[dimension]["scores"].append(average_softmax) 175 | 176 | # Calculate overall scores for each dimension 177 | dimension_averages = {} 178 | for dimension, data in dimension_scores.items(): 179 | if data["scores"]: 180 | dimension_averages[dimension] = sum(data["scores"]) / len(data["scores"]) 181 | else: 182 | dimension_averages[dimension] = 0.0 183 | 184 | # Calculate overall average across all dimensions 185 | overall_average = sum(dimension_averages.values()) / len(dimension_averages) if dimension_averages else 0.0 186 | 187 | # Store the results for this case 188 | results["case_results"][case_id] = { 189 | "task_name_case_id": case_data["task_name_case_id"], 190 | "num_images": case_data["output_image_count"], 191 | "dimension_scores": dimension_scores, 192 | "dimension_averages": dimension_averages, 193 | "overall_average": overall_average 194 | } 195 | 196 | print(f"\nScores for case {case_id}:") 197 | for dimension, avg in dimension_averages.items(): 198 | print(f"{dimension} average: {avg:.4f}") 199 | print(f"Overall average: {overall_average:.4f}") 200 | 201 | # Save results after each case is processed 202 | with open(output_path, 'w') as f: 203 | json.dump(results, f, indent=4) 204 | print(f"Updated results saved to: {output_path}") 205 | 206 | # Final save of results 207 | with open(output_path, 'w') as f: 208 | json.dump(results, f, indent=4) 209 | print(f"\nFinal results saved to: {output_path}") 210 | 211 | # Calculate overall dimension scores across all cases 212 | print(f"\nCalculating overall dimension scores...") 213 | 214 | # Initialize dimension totals 215 | dimension_totals = {} 216 | dimension_counts = {} 217 | 218 | # Aggregate scores across all cases 219 | for case_id, case_result in results["case_results"].items(): 220 | for dimension, avg in case_result["dimension_averages"].items(): 221 | if dimension not in dimension_totals: 222 | dimension_totals[dimension] = 0.0 223 | dimension_counts[dimension] = 0 224 | if avg > 0: # Only add non-zero scores 225 | 
dimension_totals[dimension] += avg 226 | dimension_counts[dimension] += 1 227 | 228 | # Calculate final averages for each dimension 229 | final_dimension_averages = {} 230 | for dimension, total in dimension_totals.items(): 231 | if dimension_counts[dimension] > 0: 232 | final_dimension_averages[dimension] = total / dimension_counts[dimension] 233 | else: 234 | final_dimension_averages[dimension] = 0.0 235 | 236 | # Calculate overall average across all dimensions 237 | final_overall_average = sum(final_dimension_averages.values()) / len(final_dimension_averages) if final_dimension_averages else 0.0 238 | 239 | # Add final scores to results 240 | results["final_dimension_averages"] = final_dimension_averages 241 | results["final_overall_average"] = final_overall_average 242 | 243 | # Save updated results with final scores 244 | with open(output_path, 'w') as f: 245 | json.dump(results, f, indent=4) 246 | 247 | # Print final scores 248 | print(f"\nFinal dimension scores:") 249 | for dimension, avg in final_dimension_averages.items(): 250 | print(f"{dimension} final average: {avg:.4f}") 251 | print(f"Final overall average: {final_overall_average:.4f}") 252 | 253 | return results 254 | 255 | def main(): 256 | parser = argparse.ArgumentParser(description="Evaluate prompt consistency for generated images") 257 | parser.add_argument("--dataset_path", type=str, default="/home/chengyou/GroupGen/baseline/filtered_responses_sorted.json", help="Path to the dataset JSON file") 258 | parser.add_argument("--criteria_path", type=str, default="/home/chengyou/Qwen/prompt_consistency.json", help="Path to the prompt consistency criteria JSON file") 259 | parser.add_argument("--image_base_path", type=str, default="/home/chengyou/sx/Gemini/gemini", help="Path to the directory containing generated images") 260 | parser.add_argument("--image_pattern", type=str, default="jpg", help="Pattern for image filenames, use {} as placeholder for case_id") 261 | parser.add_argument("--output_path", type=str, help="Path to save the evaluation results") 262 | 263 | args = parser.parse_args() 264 | 265 | evaluate_prompt_consistency( 266 | args.dataset_path, 267 | args.criteria_path, 268 | args.image_base_path, 269 | args.image_pattern, 270 | args.output_path 271 | ) 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /T2IS_Gen/t2is_transformer_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
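# NOTE: This module is an adapted copy of diffusers' Flux transformer. Relative to the
# stock FluxTransformer2DModel it adds the Divide_replace_hidden_states and
# Repainting_replace_hidden_states helpers plus extra forward() arguments
# (latent_h / latent_w, the Divide_* offset and scale lists, original_hidden_states_list,
# return_hidden_states_list), intended to let the T2IS pipeline write per-prompt hidden
# states into sub-regions of one shared latent canvas during denoising.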
14 | 15 | 16 | from typing import Any, Dict, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 25 | from diffusers.models.attention import FeedForward 26 | from diffusers.models.attention_processor import ( 27 | Attention, 28 | AttentionProcessor, 29 | FluxAttnProcessor2_0, 30 | FusedFluxAttnProcessor2_0, 31 | ) 32 | from diffusers.models.modeling_utils import ModelMixin 33 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 34 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 35 | from diffusers.utils.torch_utils import maybe_allow_in_graph 36 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 37 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 38 | from typing import List 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @maybe_allow_in_graph 44 | class FluxSingleTransformerBlock(nn.Module): 45 | r""" 46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 47 | 48 | Reference: https://arxiv.org/abs/2403.03206 49 | 50 | Parameters: 51 | dim (`int`): The number of channels in the input and output. 52 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 53 | attention_head_dim (`int`): The number of channels in each head. 54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 55 | processing of `context` conditions. 
56 | """ 57 | 58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 59 | super().__init__() 60 | self.mlp_hidden_dim = int(dim * mlp_ratio) 61 | 62 | self.norm = AdaLayerNormZeroSingle(dim) 63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 64 | self.act_mlp = nn.GELU(approximate="tanh") 65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 66 | 67 | processor = FluxAttnProcessor2_0() 68 | self.attn = Attention( 69 | query_dim=dim, 70 | cross_attention_dim=None, 71 | dim_head=attention_head_dim, 72 | heads=num_attention_heads, 73 | out_dim=dim, 74 | bias=True, 75 | processor=processor, 76 | qk_norm="rms_norm", 77 | eps=1e-6, 78 | pre_only=True, 79 | ) 80 | 81 | def forward( 82 | self, 83 | hidden_states: torch.FloatTensor, 84 | temb: torch.FloatTensor, 85 | image_rotary_emb=None, 86 | joint_attention_kwargs=None, 87 | ): 88 | residual = hidden_states 89 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 90 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 91 | joint_attention_kwargs = joint_attention_kwargs or {} 92 | 93 | attn_output = self.attn( 94 | hidden_states=norm_hidden_states, 95 | image_rotary_emb=image_rotary_emb, 96 | **joint_attention_kwargs 97 | ) 98 | 99 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 100 | gate = gate.unsqueeze(1) 101 | hidden_states = gate * self.proj_out(hidden_states) 102 | hidden_states = residual + hidden_states 103 | if hidden_states.dtype == torch.float16: 104 | hidden_states = hidden_states.clip(-65504, 65504) 105 | 106 | return hidden_states 107 | 108 | 109 | @maybe_allow_in_graph 110 | class FluxTransformerBlock(nn.Module): 111 | r""" 112 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 113 | 114 | Reference: https://arxiv.org/abs/2403.03206 115 | 116 | Parameters: 117 | dim (`int`): The number of channels in the input and output. 118 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 119 | attention_head_dim (`int`): The number of channels in each head. 120 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 121 | processing of `context` conditions. 122 | """ 123 | 124 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): 125 | super().__init__() 126 | 127 | self.norm1 = AdaLayerNormZero(dim) 128 | 129 | self.norm1_context = AdaLayerNormZero(dim) 130 | 131 | if hasattr(F, "scaled_dot_product_attention"): 132 | processor = FluxAttnProcessor2_0() 133 | else: 134 | raise ValueError( 135 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
136 | ) 137 | self.attn = Attention( 138 | query_dim=dim, 139 | cross_attention_dim=None, 140 | added_kv_proj_dim=dim, 141 | dim_head=attention_head_dim, 142 | heads=num_attention_heads, 143 | out_dim=dim, 144 | context_pre_only=False, 145 | bias=True, 146 | processor=processor, 147 | qk_norm=qk_norm, 148 | eps=eps, 149 | ) 150 | 151 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 152 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 153 | 154 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 155 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 156 | 157 | # let chunk size default to None 158 | self._chunk_size = None 159 | self._chunk_dim = 0 160 | 161 | def forward( 162 | self, 163 | hidden_states: torch.FloatTensor, 164 | encoder_hidden_states: torch.FloatTensor, 165 | temb: torch.FloatTensor, 166 | image_rotary_emb=None, 167 | joint_attention_kwargs=None, 168 | ): 169 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 170 | 171 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 172 | encoder_hidden_states, emb=temb 173 | ) 174 | joint_attention_kwargs = joint_attention_kwargs or {} 175 | 176 | 177 | 178 | # Attention. 179 | attn_output, context_attn_output = self.attn( 180 | hidden_states=norm_hidden_states, 181 | encoder_hidden_states=norm_encoder_hidden_states, 182 | image_rotary_emb=image_rotary_emb, 183 | **joint_attention_kwargs, 184 | ) 185 | 186 | # Process attention outputs for the `hidden_states`. 187 | attn_output = gate_msa.unsqueeze(1) * attn_output 188 | hidden_states = hidden_states + attn_output 189 | 190 | norm_hidden_states = self.norm2(hidden_states) 191 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 192 | 193 | ff_output = self.ff(norm_hidden_states) 194 | ff_output = gate_mlp.unsqueeze(1) * ff_output 195 | 196 | hidden_states = hidden_states + ff_output 197 | 198 | # Process attention outputs for the `encoder_hidden_states`. 199 | 200 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 201 | encoder_hidden_states = encoder_hidden_states + context_attn_output 202 | 203 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 204 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 205 | 206 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 207 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 208 | if encoder_hidden_states.dtype == torch.float16: 209 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 210 | 211 | return encoder_hidden_states, hidden_states 212 | 213 | 214 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 215 | """ 216 | The Transformer model introduced in Flux. 217 | 218 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 219 | 220 | Parameters: 221 | patch_size (`int`): Patch size to turn the input data into small patches. 222 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. 223 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. 224 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. 
225 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 226 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. 227 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 228 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. 229 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. 230 | """ 231 | 232 | _supports_gradient_checkpointing = True 233 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 234 | 235 | @register_to_config 236 | def __init__( 237 | self, 238 | patch_size: int = 1, 239 | in_channels: int = 64, 240 | num_layers: int = 19, 241 | num_single_layers: int = 38, 242 | attention_head_dim: int = 128, 243 | num_attention_heads: int = 24, 244 | joint_attention_dim: int = 4096, 245 | pooled_projection_dim: int = 768, 246 | guidance_embeds: bool = False, 247 | axes_dims_rope: Tuple[int] = (16, 56, 56), 248 | ): 249 | super().__init__() 250 | self.out_channels = in_channels 251 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 252 | 253 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 254 | 255 | text_time_guidance_cls = ( 256 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 257 | ) 258 | self.time_text_embed = text_time_guidance_cls( 259 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 260 | ) 261 | 262 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) 263 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) 264 | 265 | self.transformer_blocks = nn.ModuleList( 266 | [ 267 | FluxTransformerBlock( 268 | dim=self.inner_dim, 269 | num_attention_heads=self.config.num_attention_heads, 270 | attention_head_dim=self.config.attention_head_dim, 271 | ) 272 | for i in range(self.config.num_layers) 273 | ] 274 | ) 275 | 276 | self.single_transformer_blocks = nn.ModuleList( 277 | [ 278 | FluxSingleTransformerBlock( 279 | dim=self.inner_dim, 280 | num_attention_heads=self.config.num_attention_heads, 281 | attention_head_dim=self.config.attention_head_dim, 282 | ) 283 | for i in range(self.config.num_single_layers) 284 | ] 285 | ) 286 | 287 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 288 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 289 | 290 | self.gradient_checkpointing = False 291 | 292 | @property 293 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 294 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 295 | r""" 296 | Returns: 297 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 298 | indexed by its weight name. 
299 | """ 300 | # set recursively 301 | processors = {} 302 | 303 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 304 | if hasattr(module, "get_processor"): 305 | processors[f"{name}.processor"] = module.get_processor() 306 | 307 | for sub_name, child in module.named_children(): 308 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 309 | 310 | return processors 311 | 312 | for name, module in self.named_children(): 313 | fn_recursive_add_processors(name, module, processors) 314 | 315 | return processors 316 | 317 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 318 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 319 | r""" 320 | Sets the attention processor to use to compute attention. 321 | 322 | Parameters: 323 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 324 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 325 | for **all** `Attention` layers. 326 | 327 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 328 | processor. This is strongly recommended when setting trainable attention processors. 329 | 330 | """ 331 | count = len(self.attn_processors.keys()) 332 | 333 | if isinstance(processor, dict) and len(processor) != count: 334 | raise ValueError( 335 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 336 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 337 | ) 338 | 339 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 340 | if hasattr(module, "set_processor"): 341 | if not isinstance(processor, dict): 342 | module.set_processor(processor) 343 | else: 344 | module.set_processor(processor.pop(f"{name}.processor")) 345 | 346 | for sub_name, child in module.named_children(): 347 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 348 | 349 | for name, module in self.named_children(): 350 | fn_recursive_attn_processor(name, module, processor) 351 | 352 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 353 | def fuse_qkv_projections(self): 354 | """ 355 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 356 | are fused. For cross-attention modules, key and value projection matrices are fused. 357 | 358 | 359 | 360 | This API is 🧪 experimental. 361 | 362 | 363 | """ 364 | self.original_attn_processors = None 365 | 366 | for _, attn_processor in self.attn_processors.items(): 367 | if "Added" in str(attn_processor.__class__.__name__): 368 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 369 | 370 | self.original_attn_processors = self.attn_processors 371 | 372 | for module in self.modules(): 373 | if isinstance(module, Attention): 374 | module.fuse_projections(fuse=True) 375 | 376 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 377 | 378 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 379 | def unfuse_qkv_projections(self): 380 | """Disables the fused QKV projection if enabled. 381 | 382 | 383 | 384 | This API is 🧪 experimental. 
385 | 386 | 387 | 388 | """ 389 | if self.original_attn_processors is not None: 390 | self.set_attn_processor(self.original_attn_processors) 391 | 392 | def _set_gradient_checkpointing(self, module, value=False): 393 | if hasattr(module, "gradient_checkpointing"): 394 | module.gradient_checkpointing = value 395 | 396 | def Divide_replace_hidden_states(self, hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list,Divide_n_offset_list,Divide_m_scale_list,Divide_n_scale_list, latent_h, latent_w, Divide_idx): 397 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h,latent_w, hidden_states.shape[2]) 398 | 399 | for Divide_hidden_states_list, Divide_m_offset, Divide_n_offset, Divide_m_scale, Divide_n_scale in zip(Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list): 400 | Divide_hidden_states = Divide_hidden_states_list[Divide_idx] 401 | Divide_hidden_states = Divide_hidden_states.view(Divide_hidden_states.shape[0], Divide_n_scale,Divide_m_scale, Divide_hidden_states.shape[2]) 402 | hidden_states[:, Divide_n_offset:Divide_n_offset+Divide_n_scale, Divide_m_offset:Divide_m_offset+Divide_m_scale, :] = Divide_hidden_states 403 | 404 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[3]) 405 | Divide_idx += 1 406 | 407 | return hidden_states, Divide_idx 408 | 409 | def Repainting_replace_hidden_states(self, hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx): 410 | original_hidden_states = original_hidden_states_list[Repainting_idx] 411 | original_hidden_states = original_hidden_states.view(original_hidden_states.shape[0], latent_h, latent_w, original_hidden_states.shape[2]) 412 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h,latent_w, hidden_states.shape[2]) 413 | original_hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] = hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] 414 | hidden_states = original_hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[3]) 415 | Repainting_idx += 1 416 | 417 | return hidden_states, Repainting_idx 418 | 419 | def forward( 420 | self, 421 | hidden_states: torch.Tensor, 422 | encoder_hidden_states: torch.Tensor = None, 423 | pooled_projections: torch.Tensor = None, 424 | timestep: torch.LongTensor = None, 425 | img_ids: torch.Tensor = None, 426 | txt_ids: torch.Tensor = None, 427 | guidance: torch.Tensor = None, 428 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 429 | controlnet_block_samples=None, 430 | controlnet_single_block_samples=None, 431 | return_dict: bool = True, 432 | controlnet_blocks_repeat: bool = False, 433 | latent_h: int=None, 434 | latent_w: int=None, 435 | Divide_hidden_states_list_list: List[List[torch.Tensor]] = None, 436 | Divide_m_offset_list: List[int]=None, 437 | Divide_n_offset_list: List[int]=None, 438 | Divide_m_scale_list: List[int]=None, 439 | Divide_n_scale_list: List[int]=None, 440 | return_hidden_states_list: bool = False, 441 | original_hidden_states_list: List[torch.Tensor] = None, 442 | Repainting_Divide_m_offset: int=None, 443 | Repainting_Divide_n_offset: 
int=None, 444 | Repainting: torch.Tensor = None, 445 | Repainting_single: int=False, 446 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 447 | """ 448 | The [`FluxTransformer2DModel`] forward method. 449 | 450 | Args: 451 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 452 | Input `hidden_states`. 453 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 454 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 455 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 456 | from the embeddings of input conditions. 457 | timestep ( `torch.LongTensor`): 458 | Used to indicate denoising step. 459 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 460 | A list of tensors that if specified are added to the residuals of transformer blocks. 461 | joint_attention_kwargs (`dict`, *optional*): 462 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 463 | `self.processor` in 464 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 465 | return_dict (`bool`, *optional*, defaults to `True`): 466 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 467 | tuple. 468 | 469 | Returns: 470 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 471 | `tuple` where the first element is the sample tensor. 472 | """ 473 | if return_hidden_states_list: 474 | hidden_states_list=[] 475 | if joint_attention_kwargs is not None: 476 | joint_attention_kwargs = joint_attention_kwargs.copy() 477 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 478 | else: 479 | lora_scale = 1.0 480 | 481 | if USE_PEFT_BACKEND: 482 | # weight the lora layers by setting `lora_scale` for each PEFT layer 483 | scale_lora_layers(self, lora_scale) 484 | else: 485 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 486 | logger.warning( 487 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
488 | ) 489 | hidden_states = self.x_embedder(hidden_states) 490 | 491 | if Divide_hidden_states_list_list is not None: 492 | Divide_idx = 0 493 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx) 494 | 495 | if original_hidden_states_list is not None: 496 | Repainting_idx = 0 497 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx) 498 | 499 | if return_hidden_states_list: 500 | hidden_states_list.append(hidden_states) 501 | 502 | timestep = timestep.to(hidden_states.dtype) * 1000 503 | if guidance is not None: 504 | guidance = guidance.to(hidden_states.dtype) * 1000 505 | else: 506 | guidance = None 507 | temb = ( 508 | self.time_text_embed(timestep, pooled_projections) 509 | if guidance is None 510 | else self.time_text_embed(timestep, guidance, pooled_projections) 511 | ) 512 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 513 | 514 | if txt_ids.ndim == 3: 515 | logger.warning( 516 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 517 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 518 | ) 519 | txt_ids = txt_ids[0] 520 | if img_ids.ndim == 3: 521 | logger.warning( 522 | "Passing `img_ids` 3d torch.Tensor is deprecated." 523 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 524 | ) 525 | img_ids = img_ids[0] 526 | 527 | ids = torch.cat((txt_ids, img_ids), dim=0) 528 | image_rotary_emb = self.pos_embed(ids) 529 | 530 | for index_block, block in enumerate(self.transformer_blocks): 531 | if self.training and self.gradient_checkpointing: 532 | 533 | def create_custom_forward(module, return_dict=None): 534 | def custom_forward(*inputs): 535 | if return_dict is not None: 536 | return module(*inputs, return_dict=return_dict) 537 | else: 538 | return module(*inputs) 539 | 540 | return custom_forward 541 | 542 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 543 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 544 | create_custom_forward(block), 545 | hidden_states, 546 | encoder_hidden_states, 547 | temb, 548 | image_rotary_emb, 549 | **ckpt_kwargs, 550 | ) 551 | 552 | else: 553 | encoder_hidden_states, hidden_states = block( 554 | hidden_states=hidden_states, 555 | encoder_hidden_states=encoder_hidden_states, 556 | temb=temb, 557 | image_rotary_emb=image_rotary_emb, 558 | joint_attention_kwargs=joint_attention_kwargs, 559 | ) 560 | 561 | if Divide_hidden_states_list_list is not None: 562 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx) 563 | 564 | if original_hidden_states_list is not None: 565 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx) 566 | 567 | if return_hidden_states_list: 568 | hidden_states_list.append(hidden_states) 569 | 570 | # controlnet residual 571 | if controlnet_block_samples is not None: 572 | interval_control = len(self.transformer_blocks) / 
len(controlnet_block_samples) 573 | interval_control = int(np.ceil(interval_control)) 574 | # For Xlabs ControlNet. 575 | if controlnet_blocks_repeat: 576 | hidden_states = ( 577 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 578 | ) 579 | else: 580 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 581 | 582 | 583 | 584 | # if not Divide_m_offset_list: 585 | # print("save feature") 586 | # print(timestep.item()) 587 | # torch.save(hidden_states.squeeze(0).cpu(), f"tensors_new/feature_{timestep.item()}.pt") # save feature in the shape of [c, h, w] 588 | 589 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 590 | 591 | for index_block, block in enumerate(self.single_transformer_blocks): 592 | if self.training and self.gradient_checkpointing: 593 | 594 | def create_custom_forward(module, return_dict=None): 595 | def custom_forward(*inputs): 596 | if return_dict is not None: 597 | return module(*inputs, return_dict=return_dict) 598 | else: 599 | return module(*inputs) 600 | 601 | return custom_forward 602 | 603 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 604 | hidden_states = torch.utils.checkpoint.checkpoint( 605 | create_custom_forward(block), 606 | hidden_states, 607 | temb, 608 | image_rotary_emb, 609 | **ckpt_kwargs, 610 | ) 611 | 612 | else: 613 | hidden_states = block( 614 | hidden_states=hidden_states, 615 | temb=temb, 616 | image_rotary_emb=image_rotary_emb, 617 | joint_attention_kwargs=joint_attention_kwargs, 618 | ) 619 | 620 | if Divide_hidden_states_list_list is not None: 621 | hidden_states_clone = hidden_states.clone()[:, encoder_hidden_states.shape[1] :, ...].view(hidden_states.shape[0], latent_h, latent_w, hidden_states.shape[2]) 622 | 623 | for Divide_hidden_states_list, Divide_m_offset, Divide_n_offset, Divide_m_scale,Divide_n_scale in zip(Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list): 624 | Divide_hidden_states = Divide_hidden_states_list[Divide_idx] 625 | Divide_hidden_states = Divide_hidden_states[:, Divide_hidden_states.shape[1]-Divide_n_scale*Divide_m_scale :, ...].view(Divide_hidden_states.shape[0], Divide_n_scale, Divide_m_scale, Divide_hidden_states.shape[2]) 626 | hidden_states_clone[:, Divide_n_offset:Divide_n_offset+Divide_n_scale,Divide_m_offset:Divide_m_offset+Divide_m_scale, :] = Divide_hidden_states 627 | 628 | hidden_states_clone = hidden_states_clone.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[2]) 629 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
= hidden_states_clone 630 | Divide_idx += 1 631 | 632 | if original_hidden_states_list is not None: 633 | if Repainting_single: 634 | hidden_states_clone = hidden_states.clone()[:, encoder_hidden_states.shape[1] :, ...].view(hidden_states.shape[0], latent_h, latent_w, hidden_states.shape[2]) 635 | original_hidden_states = original_hidden_states_list[Repainting_idx] 636 | original_hidden_states = original_hidden_states[:, encoder_hidden_states.shape[1] :, ...].view(original_hidden_states.shape[0], latent_h, latent_w, original_hidden_states.shape[2]) 637 | original_hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] = hidden_states_clone[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] 638 | hidden_states_clone = original_hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[2]) 639 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = hidden_states_clone 640 | Repainting_idx += 1 641 | 642 | if return_hidden_states_list: 643 | hidden_states_list.append(hidden_states) 644 | 645 | # controlnet residual 646 | if controlnet_single_block_samples is not None: 647 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 648 | interval_control = int(np.ceil(interval_control)) 649 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 650 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 651 | + controlnet_single_block_samples[index_block // interval_control] 652 | ) 653 | 654 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
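# NOTE: the single-stream blocks above operate on a packed sequence [text tokens ; image tokens],
# which is why every region paste indexes the image part with an `encoder_hidden_states.shape[1]:`
# offset, and why the slice on the preceding line keeps only the image tokens before the final
# norm and projection. Illustrative shapes only, assuming max_sequence_length = 512 text tokens
# and a 1024x1024 image (latent_h = latent_w = height // 16 = 64):
#   packed:     (B, 512 + 64*64, C)
#   image part: packed[:, 512:, :]            -> (B, 4096, C)
#   grid view:  image.view(B, 64, 64, C), written at
#               [:, n_offset:n_offset+n_scale, m_offset:m_offset+m_scale, :]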
655 | 656 | hidden_states = self.norm_out(hidden_states, temb) 657 | 658 | if Divide_hidden_states_list_list is not None: 659 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx) 660 | 661 | if original_hidden_states_list is not None: 662 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx) 663 | 664 | if return_hidden_states_list: 665 | hidden_states_list.append(hidden_states) 666 | 667 | output = self.proj_out(hidden_states) 668 | 669 | if Divide_hidden_states_list_list is not None: 670 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx) 671 | 672 | if original_hidden_states_list is not None: 673 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx) 674 | 675 | if return_hidden_states_list: 676 | hidden_states_list.append(hidden_states) 677 | 678 | if USE_PEFT_BACKEND: 679 | # remove `lora_scale` from each PEFT layer 680 | unscale_lora_layers(self, lora_scale) 681 | 682 | if not return_dict: 683 | if return_hidden_states_list: 684 | return (output,),hidden_states_list 685 | else: 686 | return (output,) 687 | 688 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /T2IS_Gen/t2is_pipeline_flux.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
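The `Divide_replace_hidden_states` and `Repainting_replace_hidden_states` helpers in the transformer above, together with the in-loop pastes for the single-stream blocks, all reduce to one tensor idiom: view the packed `(B, latent_h * latent_w, C)` sequence as a `(B, latent_h, latent_w, C)` grid and write a region's hidden states at its row/column offsets. A self-contained toy sketch of that idiom (batch size, channel count, grid size, and offsets are illustrative, not taken from the repo):

```python
import torch

# Toy canvas: 4x6 latent grid, 8 channels, batch of 1 (illustrative sizes only).
B, C, latent_h, latent_w = 1, 8, 4, 6
n_offset, m_offset, n_scale, m_scale = 1, 2, 2, 3    # region placement: rows, then columns

canvas = torch.zeros(B, latent_h * latent_w, C)      # packed sequence, as in the forward pass
region = torch.ones(B, n_scale * m_scale, C)         # hidden states from a per-region pass

grid = canvas.view(B, latent_h, latent_w, C)         # unpack to a spatial grid (shares storage)
grid[:, n_offset:n_offset + n_scale, m_offset:m_offset + m_scale, :] = \
    region.view(B, n_scale, m_scale, C)              # paste the region at its offsets
canvas = grid.view(B, latent_h * latent_w, C)        # repack to a sequence

# Only the 2x3 region was overwritten.
assert canvas.abs().sum().item() == B * n_scale * m_scale * C
```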
14 | 15 | import inspect 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import numpy as np 19 | import torch 20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast 21 | 22 | from diffusers.image_processor import VaeImageProcessor 23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin 24 | from diffusers.models.autoencoders import AutoencoderKL 25 | from diffusers.models.transformers import FluxTransformer2DModel 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import ( 28 | USE_PEFT_BACKEND, 29 | is_torch_xla_available, 30 | logging, 31 | replace_example_docstring, 32 | scale_lora_layers, 33 | unscale_lora_layers, 34 | ) 35 | from diffusers.utils.torch_utils import randn_tensor 36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 38 | 39 | from init_attention import init_forwards,TOKENS 40 | 41 | from utils import get_grid_params 42 | import random 43 | import importlib.util 44 | import sys 45 | import PIL 46 | from PIL import Image, ImageChops 47 | from scipy.ndimage import binary_dilation 48 | import torchvision.transforms as transforms 49 | 50 | module_name = 'diffusers.models.transformers.transformer_flux' 51 | module_path = './t2is_transformer_flux.py' 52 | 53 | if module_name in sys.modules: 54 | del sys.modules[module_name] 55 | 56 | spec = importlib.util.spec_from_file_location(module_name, module_path) 57 | regionfluxmodel = importlib.util.module_from_spec(spec) 58 | sys.modules[module_name] = regionfluxmodel 59 | spec.loader.exec_module(regionfluxmodel) 60 | 61 | FluxTransformer2DModel = regionfluxmodel.FluxTransformer2DModel 62 | 63 | if is_torch_xla_available(): 64 | import torch_xla.core.xla_model as xm 65 | 66 | XLA_AVAILABLE = True 67 | else: 68 | XLA_AVAILABLE = False 69 | 70 | 71 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 72 | 73 | EXAMPLE_DOC_STRING = """ 74 | Examples: 75 | ```py 76 | >>> import torch 77 | >>> from diffusers import FluxPipeline 78 | 79 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) 80 | >>> pipe.to("cuda") 81 | >>> prompt = "A cat holding a sign that says hello world" 82 | >>> # Depending on the variant being used, the pipeline call will slightly vary. 83 | >>> # Refer to the pipeline documentation for more details. 84 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] 85 | >>> image.save("flux.png") 86 | ``` 87 | """ 88 | 89 | 90 | def calculate_shift( 91 | image_seq_len, 92 | base_seq_len: int = 256, 93 | max_seq_len: int = 4096, 94 | base_shift: float = 0.5, 95 | max_shift: float = 1.16, 96 | ): 97 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 98 | b = base_shift - m * base_seq_len 99 | mu = image_seq_len * m + b 100 | return mu 101 | 102 | 103 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 104 | def retrieve_timesteps( 105 | scheduler, 106 | num_inference_steps: Optional[int] = None, 107 | device: Optional[Union[str, torch.device]] = None, 108 | timesteps: Optional[List[int]] = None, 109 | sigmas: Optional[List[float]] = None, 110 | **kwargs, 111 | ): 112 | r""" 113 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 114 | custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. 115 | 116 | Args: 117 | scheduler (`SchedulerMixin`): 118 | The scheduler to get timesteps from. 119 | num_inference_steps (`int`): 120 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 121 | must be `None`. 122 | device (`str` or `torch.device`, *optional*): 123 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 124 | timesteps (`List[int]`, *optional*): 125 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 126 | `num_inference_steps` and `sigmas` must be `None`. 127 | sigmas (`List[float]`, *optional*): 128 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 129 | `num_inference_steps` and `timesteps` must be `None`. 130 | 131 | Returns: 132 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 133 | second element is the number of inference steps. 134 | """ 135 | if timesteps is not None and sigmas is not None: 136 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 137 | if timesteps is not None: 138 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 139 | if not accepts_timesteps: 140 | raise ValueError( 141 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 142 | f" timestep schedules. Please check whether you are using the correct scheduler." 143 | ) 144 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 145 | timesteps = scheduler.timesteps 146 | num_inference_steps = len(timesteps) 147 | elif sigmas is not None: 148 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 149 | if not accept_sigmas: 150 | raise ValueError( 151 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 152 | f" sigmas schedules. Please check whether you are using the correct scheduler." 153 | ) 154 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 155 | timesteps = scheduler.timesteps 156 | num_inference_steps = len(timesteps) 157 | else: 158 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 159 | timesteps = scheduler.timesteps 160 | return timesteps, num_inference_steps 161 | 162 | 163 | class T2IS_FluxPipeline( 164 | DiffusionPipeline, 165 | FluxLoraLoaderMixin, 166 | FromSingleFileMixin, 167 | TextualInversionLoaderMixin, 168 | ): 169 | r""" 170 | The Flux pipeline for text-to-image generation. 171 | 172 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ 173 | 174 | Args: 175 | transformer ([`FluxTransformer2DModel`]): 176 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 177 | scheduler ([`FlowMatchEulerDiscreteScheduler`]): 178 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 179 | vae ([`AutoencoderKL`]): 180 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 181 | text_encoder ([`CLIPTextModel`]): 182 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 183 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
184 | text_encoder_2 ([`T5EncoderModel`]): 185 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 186 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 187 | tokenizer (`CLIPTokenizer`): 188 | Tokenizer of class 189 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 190 | tokenizer_2 (`T5TokenizerFast`): 191 | Second Tokenizer of class 192 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 193 | """ 194 | 195 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" 196 | _optional_components = [] 197 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 198 | 199 | def __init__( 200 | self, 201 | scheduler: FlowMatchEulerDiscreteScheduler, 202 | vae: AutoencoderKL, 203 | text_encoder: CLIPTextModel, 204 | tokenizer: CLIPTokenizer, 205 | text_encoder_2: T5EncoderModel, 206 | tokenizer_2: T5TokenizerFast, 207 | transformer: FluxTransformer2DModel, 208 | ): 209 | super().__init__() 210 | 211 | self.register_modules( 212 | vae=vae, 213 | text_encoder=text_encoder, 214 | text_encoder_2=text_encoder_2, 215 | tokenizer=tokenizer, 216 | tokenizer_2=tokenizer_2, 217 | transformer=transformer, 218 | scheduler=scheduler, 219 | ) 220 | self.vae_scale_factor = ( 221 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 222 | ) 223 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 224 | self.tokenizer_max_length = ( 225 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 226 | ) 227 | self.default_sample_size = 64 228 | 229 | def _get_t5_prompt_embeds( 230 | self, 231 | prompt: Union[str, List[str]] = None, 232 | num_images_per_prompt: int = 1, 233 | max_sequence_length: int = 512, 234 | device: Optional[torch.device] = None, 235 | dtype: Optional[torch.dtype] = None, 236 | ): 237 | device = device or self._execution_device 238 | dtype = dtype or self.text_encoder.dtype 239 | 240 | prompt = [prompt] if isinstance(prompt, str) else prompt 241 | batch_size = len(prompt) 242 | 243 | if isinstance(self, TextualInversionLoaderMixin): 244 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) 245 | 246 | text_inputs = self.tokenizer_2( 247 | prompt, 248 | padding="max_length", 249 | max_length=max_sequence_length, 250 | truncation=True, 251 | return_length=False, 252 | return_overflowing_tokens=False, 253 | return_tensors="pt", 254 | ) 255 | text_input_ids = text_inputs.input_ids 256 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids 257 | 258 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 259 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 260 | logger.warning( 261 | "The following part of your input was truncated because `max_sequence_length` is set to " 262 | f" {max_sequence_length} tokens: {removed_text}" 263 | ) 264 | 265 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] 266 | 267 | dtype = self.text_encoder_2.dtype 268 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 269 | 270 | _, seq_len, _ = prompt_embeds.shape 271 | 272 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly 
method 273 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 274 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 275 | 276 | return prompt_embeds 277 | 278 | def _get_clip_prompt_embeds( 279 | self, 280 | prompt: Union[str, List[str]], 281 | num_images_per_prompt: int = 1, 282 | device: Optional[torch.device] = None, 283 | ): 284 | device = device or self._execution_device 285 | 286 | prompt = [prompt] if isinstance(prompt, str) else prompt 287 | batch_size = len(prompt) 288 | 289 | if isinstance(self, TextualInversionLoaderMixin): 290 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 291 | 292 | text_inputs = self.tokenizer( 293 | prompt, 294 | padding="max_length", 295 | max_length=self.tokenizer_max_length, 296 | truncation=True, 297 | return_overflowing_tokens=False, 298 | return_length=False, 299 | return_tensors="pt", 300 | ) 301 | 302 | text_input_ids = text_inputs.input_ids 303 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 304 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 305 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) 306 | logger.warning( 307 | "The following part of your input was truncated because CLIP can only handle sequences up to" 308 | f" {self.tokenizer_max_length} tokens: {removed_text}" 309 | ) 310 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) 311 | 312 | # Use pooled output of CLIPTextModel 313 | prompt_embeds = prompt_embeds.pooler_output 314 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 315 | 316 | # duplicate text embeddings for each generation per prompt, using mps friendly method 317 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 318 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 319 | 320 | return prompt_embeds 321 | 322 | def encode_prompt( 323 | self, 324 | prompt: Union[str, List[str]], 325 | prompt_2: Union[str, List[str]], 326 | device: Optional[torch.device] = None, 327 | num_images_per_prompt: int = 1, 328 | prompt_embeds: Optional[torch.FloatTensor] = None, 329 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 330 | max_sequence_length: int = 512, 331 | lora_scale: Optional[float] = None, 332 | ): 333 | r""" 334 | 335 | Args: 336 | prompt (`str` or `List[str]`, *optional*): 337 | prompt to be encoded 338 | prompt_2 (`str` or `List[str]`, *optional*): 339 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 340 | used in all text-encoders 341 | device: (`torch.device`): 342 | torch device 343 | num_images_per_prompt (`int`): 344 | number of images that should be generated per prompt 345 | prompt_embeds (`torch.FloatTensor`, *optional*): 346 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 347 | provided, text embeddings will be generated from `prompt` input argument. 348 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 349 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 350 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 351 | lora_scale (`float`, *optional*): 352 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
353 | """ 354 | device = device or self._execution_device 355 | 356 | # set lora scale so that monkey patched LoRA 357 | # function of text encoder can correctly access it 358 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): 359 | self._lora_scale = lora_scale 360 | 361 | # dynamically adjust the LoRA scale 362 | if self.text_encoder is not None and USE_PEFT_BACKEND: 363 | scale_lora_layers(self.text_encoder, lora_scale) 364 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND: 365 | scale_lora_layers(self.text_encoder_2, lora_scale) 366 | 367 | prompt = [prompt] if isinstance(prompt, str) else prompt 368 | 369 | if prompt_embeds is None: 370 | prompt_2 = prompt_2 or prompt 371 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 372 | 373 | # We only use the pooled prompt output from the CLIPTextModel 374 | pooled_prompt_embeds = self._get_clip_prompt_embeds( 375 | prompt=prompt, 376 | device=device, 377 | num_images_per_prompt=num_images_per_prompt, 378 | ) 379 | prompt_embeds = self._get_t5_prompt_embeds( 380 | prompt=prompt_2, 381 | num_images_per_prompt=num_images_per_prompt, 382 | max_sequence_length=max_sequence_length, 383 | device=device, 384 | ) 385 | 386 | if self.text_encoder is not None: 387 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 388 | # Retrieve the original scale by scaling back the LoRA layers 389 | unscale_lora_layers(self.text_encoder, lora_scale) 390 | 391 | if self.text_encoder_2 is not None: 392 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: 393 | # Retrieve the original scale by scaling back the LoRA layers 394 | unscale_lora_layers(self.text_encoder_2, lora_scale) 395 | 396 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype 397 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 398 | 399 | return prompt_embeds, pooled_prompt_embeds, text_ids 400 | 401 | def Divide_encode_prompt( 402 | self, 403 | Divide_prompt_list: None, 404 | Redux_list: None, 405 | device: Optional[torch.device] = None, 406 | num_images_per_prompt: int = 1, 407 | max_sequence_length: int = 512, 408 | lora_scale: Optional[float] = None, 409 | ): 410 | Divide_prompt_embeds_list = [] 411 | Divide_pooled_prompt_embeds_list = [] 412 | Divide_text_ids_list = [] 413 | 414 | if Redux_list is not None: 415 | for Redux in Redux_list: 416 | ( 417 | Divide_prompt_embeds, 418 | Divide_pooled_prompt_embeds, 419 | Divide_text_ids, 420 | ) = self.encode_prompt( 421 | **Redux, 422 | prompt=None, 423 | prompt_2=None, 424 | device=device, 425 | num_images_per_prompt=num_images_per_prompt, 426 | max_sequence_length=max_sequence_length, 427 | lora_scale=lora_scale, 428 | ) 429 | Divide_prompt_embeds_list.append(Divide_prompt_embeds) 430 | Divide_pooled_prompt_embeds_list.append(Divide_pooled_prompt_embeds) 431 | Divide_text_ids_list.append(Divide_text_ids) 432 | else: 433 | for Divide_prompt in Divide_prompt_list: 434 | ( 435 | Divide_prompt_embeds, 436 | Divide_pooled_prompt_embeds, 437 | Divide_text_ids, 438 | ) = self.encode_prompt( 439 | prompt=Divide_prompt, 440 | prompt_2=None, 441 | device=device, 442 | num_images_per_prompt=num_images_per_prompt, 443 | max_sequence_length=max_sequence_length, 444 | lora_scale=lora_scale, 445 | ) 446 | 447 | Divide_prompt_embeds_list.append(Divide_prompt_embeds) 448 | Divide_pooled_prompt_embeds_list.append(Divide_pooled_prompt_embeds) 449 | Divide_text_ids_list.append(Divide_text_ids) 450 | 451 | return 
Divide_prompt_embeds_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list 452 | 453 | 454 | 455 | 456 | 457 | 458 | def torch_fix_seed(self, seed=42): 459 | random.seed(seed) 460 | np.random.seed(seed) 461 | torch.manual_seed(seed) 462 | torch.cuda.manual_seed(seed) 463 | torch.backends.cudnn.deterministic = True 464 | torch.use_deterministic_algorithms = True 465 | 466 | def check_inputs( 467 | self, 468 | prompt, 469 | prompt_2, 470 | height, 471 | width, 472 | prompt_embeds=None, 473 | pooled_prompt_embeds=None, 474 | callback_on_step_end_tensor_inputs=None, 475 | max_sequence_length=None, 476 | ): 477 | if height % 8 != 0 or width % 8 != 0: 478 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 479 | 480 | if callback_on_step_end_tensor_inputs is not None and not all( 481 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 482 | ): 483 | raise ValueError( 484 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 485 | ) 486 | 487 | if prompt is not None and prompt_embeds is not None: 488 | raise ValueError( 489 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 490 | " only forward one of the two." 491 | ) 492 | elif prompt_2 is not None and prompt_embeds is not None: 493 | raise ValueError( 494 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 495 | " only forward one of the two." 496 | ) 497 | elif prompt is None and prompt_embeds is None: 498 | raise ValueError( 499 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 500 | ) 501 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 502 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 503 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 504 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 505 | 506 | if prompt_embeds is not None and pooled_prompt_embeds is None: 507 | raise ValueError( 508 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
509 | ) 510 | 511 | if max_sequence_length is not None and max_sequence_length > 512: 512 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") 513 | 514 | @staticmethod 515 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype): 516 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 517 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] 518 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] 519 | 520 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 521 | 522 | latent_image_ids = latent_image_ids.reshape( 523 | latent_image_id_height * latent_image_id_width, latent_image_id_channels 524 | ) 525 | 526 | return latent_image_ids.to(device=device, dtype=dtype) 527 | 528 | @staticmethod 529 | def _pack_latents(latents, batch_size, num_channels_latents, height, width): 530 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) 531 | latents = latents.permute(0, 2, 4, 1, 3, 5) 532 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) 533 | 534 | return latents 535 | 536 | @staticmethod 537 | def _unpack_latents(latents, height, width, vae_scale_factor): 538 | batch_size, num_patches, channels = latents.shape 539 | 540 | height = height // vae_scale_factor 541 | width = width // vae_scale_factor 542 | 543 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2) 544 | latents = latents.permute(0, 3, 1, 4, 2, 5) 545 | 546 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) 547 | 548 | return latents 549 | 550 | def enable_vae_slicing(self): 551 | r""" 552 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 553 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 554 | """ 555 | self.vae.enable_slicing() 556 | 557 | def disable_vae_slicing(self): 558 | r""" 559 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 560 | computing decoding in one step. 561 | """ 562 | self.vae.disable_slicing() 563 | 564 | def enable_vae_tiling(self): 565 | r""" 566 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 567 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 568 | processing larger images. 569 | """ 570 | self.vae.enable_tiling() 571 | 572 | def disable_vae_tiling(self): 573 | r""" 574 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 575 | computing decoding in one step. 
576 | """ 577 | self.vae.disable_tiling() 578 | 579 | 580 | 581 | def prepare_latents( 582 | self, 583 | batch_size, 584 | num_channels_latents, 585 | height, 586 | width, 587 | dtype, 588 | device, 589 | generator, 590 | latents=None, 591 | ): 592 | height = 2 * (int(height) // self.vae_scale_factor) 593 | width = 2 * (int(width) // self.vae_scale_factor) 594 | 595 | shape = (batch_size, num_channels_latents, height, width) 596 | 597 | if latents is not None: 598 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 599 | return latents.to(device=device, dtype=dtype), latent_image_ids 600 | 601 | if isinstance(generator, list) and len(generator) != batch_size: 602 | raise ValueError( 603 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 604 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 605 | ) 606 | 607 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 608 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) 609 | 610 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) 611 | 612 | return latents, latent_image_ids 613 | 614 | def prepare_Divide_latents( 615 | self, 616 | Divide_m_scale_list, 617 | Divide_n_scale_list, 618 | batch_size, 619 | num_channels_latents, 620 | dtype, 621 | device, 622 | generator 623 | ): 624 | Divide_latents_list = [] 625 | Divide_latent_image_ids_list = [] 626 | 627 | for Divide_m_scale, Divide_n_scale in zip(Divide_m_scale_list, Divide_n_scale_list): 628 | Divide_latents, Divide_latent_image_ids = self.prepare_latents( 629 | batch_size, 630 | num_channels_latents, 631 | Divide_n_scale*16, 632 | Divide_m_scale*16, 633 | dtype, 634 | device, 635 | generator 636 | ) 637 | 638 | Divide_latents_list.append(Divide_latents) 639 | Divide_latent_image_ids_list.append(Divide_latent_image_ids) 640 | 641 | return Divide_latents_list, Divide_latent_image_ids_list 642 | 643 | def prepare_Divide_replace( 644 | self, Divide_latents_list, timesteps, Divide_replace, latents, Divide_prompt_embeds_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list, Divide_latent_image_ids_list, guidance, Divide_m_scale_list, Divide_n_scale_list 645 | ): 646 | Divide_latents_list_list = [Divide_latents_list] 647 | Divide_hidden_states_list_list_list = [] 648 | 649 | for i, t in enumerate(timesteps): 650 | if(i >= Divide_replace): 651 | break 652 | 653 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 654 | Divide_noise_pred_list = [] 655 | Divide_hidden_states_list_list = [] 656 | 657 | for Divide_prompt_embeds, Divide_latents, Divide_pooled_prompt_embeds, Divide_text_ids,Divide_latent_image_ids in zip(Divide_prompt_embeds_list, Divide_latents_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list, Divide_latent_image_ids_list): 658 | Divide_noise_pred, Divide_hidden_states_list = self.transformer( 659 | hidden_states=Divide_latents, 660 | timestep=timestep / 1000, 661 | guidance=guidance, 662 | pooled_projections=Divide_pooled_prompt_embeds, 663 | encoder_hidden_states=Divide_prompt_embeds, 664 | txt_ids=Divide_text_ids, 665 | img_ids=Divide_latent_image_ids, 666 | joint_attention_kwargs=None, 667 | return_dict=False, 668 | return_hidden_states_list=True, 669 | ) 670 | Divide_noise_pred_list.append(Divide_noise_pred[0]) 671 | Divide_hidden_states_list_list.append(Divide_hidden_states_list) 672 | 
Divide_hidden_states_list_list_list.append(Divide_hidden_states_list_list) 673 | 674 | updated_Divide_latents_list = [] 675 | for Divide_latents, Divide_noise_pred in zip(Divide_latents_list, Divide_noise_pred_list): 676 | self.scheduler._init_step_index(t) 677 | Divide_latents = self.scheduler.step(Divide_noise_pred, t, Divide_latents, return_dict=False)[0] 678 | updated_Divide_latents_list.append(Divide_latents) 679 | Divide_latents_list = updated_Divide_latents_list 680 | Divide_latents_list_list.append(Divide_latents_list) 681 | 682 | Divide_latents_list_list = [ 683 | [ 684 | latents.view(latents.shape[0], n_scale, m_scale, latents.shape[2]) 685 | for latents, m_scale, n_scale in zip(latents_list, Divide_m_scale_list, Divide_n_scale_list) 686 | ] 687 | for latents_list in Divide_latents_list_list 688 | ] 689 | 690 | return Divide_latents_list_list, Divide_hidden_states_list_list_list 691 | 692 | 693 | 694 | def Divide_replace_latents(self, latents, Divide_latents_list, Divide_m_offset_list, Divide_n_offset_list, height, width): 695 | latents = latents.view(latents.shape[0], int(height//16), int(width//16), latents.shape[2]) 696 | for Divide_latents, Divide_m_offset, Divide_n_offset in zip(Divide_latents_list, Divide_m_offset_list, Divide_n_offset_list): 697 | latents[:, Divide_n_offset:Divide_n_offset+Divide_latents.shape[1], Divide_m_offset:Divide_m_offset+Divide_latents.shape[2], ] = Divide_latents 698 | latents = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2], latents.shape[3]) 699 | 700 | return latents 701 | 702 | 703 | @property 704 | def guidance_scale(self): 705 | return self._guidance_scale 706 | 707 | @property 708 | def joint_attention_kwargs(self): 709 | return self._joint_attention_kwargs 710 | 711 | @property 712 | def num_timesteps(self): 713 | return self._num_timesteps 714 | 715 | @property 716 | def interrupt(self): 717 | return self._interrupt 718 | 719 | @torch.no_grad() 720 | @replace_example_docstring(EXAMPLE_DOC_STRING) 721 | def __call__( 722 | self, 723 | Divide_replace: int, 724 | seed: int, 725 | Divide_prompt_list: List[str]=None, 726 | Redux_list = None, 727 | prompt: Union[str, List[str]] = None, 728 | prompt_2: Optional[Union[str, List[str]]] = None, 729 | height: Optional[int] = None, 730 | width: Optional[int] = None, 731 | num_inference_steps: int = 28, 732 | timesteps: List[int] = None, 733 | guidance_scale: float = 3.5, 734 | num_images_per_prompt: Optional[int] = 1, 735 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 736 | latents: Optional[torch.FloatTensor] = None, 737 | prompt_embeds: Optional[torch.FloatTensor] = None, 738 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 739 | output_type: Optional[str] = "pil", 740 | return_dict: bool = True, 741 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 742 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 743 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 744 | max_sequence_length: int = 512, 745 | ): 746 | r""" 747 | Function invoked when calling the pipeline for generation. 748 | 749 | Args: 750 | prompt (`str` or `List[str]`, *optional*): 751 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 752 | instead. 753 | prompt_2 (`str` or `List[str]`, *optional*): 754 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is 755 | will be used instead 756 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 757 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 758 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 759 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 760 | num_inference_steps (`int`, *optional*, defaults to 50): 761 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 762 | expense of slower inference. 763 | timesteps (`List[int]`, *optional*): 764 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 765 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is 766 | passed will be used. Must be in descending order. 767 | guidance_scale (`float`, *optional*, defaults to 7.0): 768 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 769 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 770 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 771 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 772 | usually at the expense of lower image quality. 773 | num_images_per_prompt (`int`, *optional*, defaults to 1): 774 | The number of images to generate per prompt. 775 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 776 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 777 | to make generation deterministic. 778 | latents (`torch.FloatTensor`, *optional*): 779 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 780 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 781 | tensor will ge generated by sampling using the supplied random `generator`. 782 | prompt_embeds (`torch.FloatTensor`, *optional*): 783 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 784 | provided, text embeddings will be generated from `prompt` input argument. 785 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 786 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 787 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 788 | output_type (`str`, *optional*, defaults to `"pil"`): 789 | The output format of the generate image. Choose between 790 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 791 | return_dict (`bool`, *optional*, defaults to `True`): 792 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 793 | joint_attention_kwargs (`dict`, *optional*): 794 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 795 | `self.processor` in 796 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 797 | callback_on_step_end (`Callable`, *optional*): 798 | A function that calls at the end of each denoising steps during the inference. 
The function is called 799 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 800 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 801 | `callback_on_step_end_tensor_inputs`. 802 | callback_on_step_end_tensor_inputs (`List`, *optional*): 803 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 804 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 805 | `._callback_tensor_inputs` attribute of your pipeline class. 806 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 807 | 808 | Examples: 809 | 810 | Returns: 811 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 812 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 813 | images. 814 | """ 815 | 816 | self.h = height 817 | self.w = width 818 | 819 | if (seed > 0): 820 | self.torch_fix_seed(seed = seed) 821 | 822 | init_forwards(self, self.transformer) 823 | 824 | 825 | 826 | # Get grid parameters based on number of prompts 827 | size = len(Divide_prompt_list) 828 | Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list = get_grid_params(size) 829 | 830 | # Convert to latent space coordinates 831 | Divide_m_offset_list = [int(m_offset * width // 16) for m_offset in Divide_m_offset_list] 832 | Divide_n_offset_list = [int(n_offset * height // 16) for n_offset in Divide_n_offset_list] 833 | Divide_m_scale_list = [int(m_scale * width // 16) for m_scale in Divide_m_scale_list] 834 | Divide_n_scale_list = [int(n_scale * height // 16) for n_scale in Divide_n_scale_list] 835 | 836 | height = height or self.default_sample_size * self.vae_scale_factor 837 | width = width or self.default_sample_size * self.vae_scale_factor 838 | 839 | # 1. Check inputs. Raise error if not correct 840 | self.check_inputs( 841 | prompt, 842 | prompt_2, 843 | height, 844 | width, 845 | prompt_embeds=prompt_embeds, 846 | pooled_prompt_embeds=pooled_prompt_embeds, 847 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 848 | max_sequence_length=max_sequence_length, 849 | ) 850 | 851 | self._guidance_scale = guidance_scale 852 | self._joint_attention_kwargs = joint_attention_kwargs 853 | self._interrupt = False 854 | 855 | # 2. 
Define call parameters 856 | if prompt is not None and isinstance(prompt, str): 857 | batch_size = 1 858 | elif prompt is not None and isinstance(prompt, list): 859 | batch_size = len(prompt) 860 | else: 861 | batch_size = prompt_embeds.shape[0] 862 | 863 | device = self._execution_device 864 | 865 | lora_scale = ( 866 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 867 | ) 868 | ( 869 | prompt_embeds, 870 | pooled_prompt_embeds, 871 | text_ids, 872 | ) = self.encode_prompt( 873 | prompt=prompt, 874 | prompt_2=prompt_2, 875 | prompt_embeds=prompt_embeds, 876 | pooled_prompt_embeds=pooled_prompt_embeds, 877 | device=device, 878 | num_images_per_prompt=num_images_per_prompt, 879 | max_sequence_length=max_sequence_length, 880 | lora_scale=lora_scale, 881 | ) 882 | 883 | ( 884 | Divide_prompt_embeds_list, 885 | Divide_pooled_prompt_embeds_list, 886 | Divide_text_ids_list, 887 | ) = self.Divide_encode_prompt( 888 | Divide_prompt_list=Divide_prompt_list, 889 | Redux_list=Redux_list, 890 | device=device, 891 | num_images_per_prompt=num_images_per_prompt, 892 | max_sequence_length=max_sequence_length, 893 | lora_scale=lora_scale, 894 | ) 895 | 896 | 897 | 898 | 899 | # 4. Prepare latent variables 900 | num_channels_latents = self.transformer.config.in_channels // 4 901 | latents, latent_image_ids = self.prepare_latents( 902 | batch_size * num_images_per_prompt, 903 | num_channels_latents, 904 | height, 905 | width, 906 | prompt_embeds.dtype, 907 | device, 908 | generator, 909 | latents, 910 | ) 911 | 912 | Divide_latents_list, Divide_latent_image_ids_list = self.prepare_Divide_latents( 913 | Divide_m_scale_list, 914 | Divide_n_scale_list, 915 | batch_size * num_images_per_prompt, 916 | num_channels_latents, 917 | prompt_embeds.dtype, 918 | device, 919 | generator 920 | ) 921 | 922 | 923 | 924 | 925 | # 5. Prepare timesteps 926 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 927 | image_seq_len = latents.shape[1] 928 | mu = calculate_shift( 929 | image_seq_len, 930 | self.scheduler.config.base_image_seq_len, 931 | self.scheduler.config.max_image_seq_len, 932 | self.scheduler.config.base_shift, 933 | self.scheduler.config.max_shift, 934 | ) 935 | timesteps, num_inference_steps = retrieve_timesteps( 936 | self.scheduler, 937 | num_inference_steps, 938 | device, 939 | timesteps, 940 | sigmas, 941 | mu=mu, 942 | ) 943 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 944 | self._num_timesteps = len(timesteps) 945 | 946 | # handle guidance 947 | if self.transformer.config.guidance_embeds: 948 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 949 | guidance = guidance.expand(latents.shape[0]) 950 | else: 951 | guidance = None 952 | 953 | # 6. 
        # 6. Denoising loop
        Divide_latents_list_list, Divide_hidden_states_list_list_list = self.prepare_Divide_replace(
            Divide_latents_list,
            timesteps,
            Divide_replace,
            latents,
            Divide_prompt_embeds_list,
            Divide_pooled_prompt_embeds_list,
            Divide_text_ids_list,
            Divide_latent_image_ids_list,
            guidance,
            Divide_m_scale_list,
            Divide_n_scale_list,
        )

        # hook_forwards(self, self.transformer)

        self.scheduler._init_step_index(timesteps[0])
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                if i <= Divide_replace:
                    latents = self.Divide_replace_latents(
                        latents, Divide_latents_list_list[i], Divide_m_offset_list, Divide_n_offset_list, height, width
                    )

                self._joint_attention_kwargs = None
                if i < Divide_replace:
                    noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                        Divide_hidden_states_list_list=Divide_hidden_states_list_list_list[i],
                        Divide_m_offset_list=Divide_m_offset_list,
                        Divide_n_offset_list=Divide_n_offset_list,
                        Divide_m_scale_list=Divide_m_scale_list,
                        Divide_n_scale_list=Divide_n_scale_list,
                        latent_h=height // 16,
                        latent_w=width // 16,
                    )[0]

                if i >= Divide_replace:
                    # Release memory of Divide_latents_list and Divide_hidden_states_list_list_list
                    if 'Divide_latents_list' in locals():
                        del Divide_latents_list
                    if 'Divide_hidden_states_list_list_list' in locals():
                        del Divide_hidden_states_list_list_list
                    torch.cuda.empty_cache()
                    init_forwards(self, self.transformer)
                    noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=guidance,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        # update
                        joint_attention_kwargs=None,
                        # joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]

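                # The branches above implement the divide-then-merge schedule: up to step
                # Divide_replace the merged canvas is overwritten region-by-region with the
                # pre-computed per-panel latents (Divide_replace_latents), and while
                # i < Divide_replace the transformer additionally receives the per-panel hidden
                # states plus grid offsets/scales; from i >= Divide_replace those buffers are
                # freed and the remaining steps run the standard Flux forward pass on the merged
                # latent so the panels are refined jointly.
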
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
--------------------------------------------------------------------------------
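The call sequence above can be exercised with a short driver. This is a minimal sketch only: the pipeline class name (`T2ISFluxPipeline`), the base checkpoint, and every argument value below are assumptions made for illustration rather than values taken from this file.

```python
# Minimal sketch; class name, checkpoint id and argument values are illustrative assumptions.
import torch
from t2is_pipeline_flux import T2ISFluxPipeline  # hypothetical class name for the pipeline above

pipe = T2ISFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

result = pipe(
    prompt="a four-panel travel poster set in one consistent style",  # global prompt for the merged canvas
    Divide_prompt_list=[                                              # one prompt per panel of the set
        "panel 1: beach at dawn",
        "panel 2: old town street",
        "panel 3: mountain pass",
        "panel 4: night market",
    ],
    Redux_list=None,        # optional Redux conditioning; assumed to be skippable with None
    Divide_replace=10,      # number of early steps that use per-panel latents before merging
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=1234,
)
result.images[0].save("merged_panels.png")
```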