├── pic
│   ├── introduction.png
│   └── 0001_0007_merge_seed1234.png
├── T2IS_Gen
│   ├── pic
│   │   └── 0001_0007_merge_seed1234.png
│   ├── README.md
│   ├── init_attention.py
│   ├── utils.py
│   ├── inference_t2is.py
│   ├── t2is_transformer_flux.py
│   └── t2is_pipeline_flux.py
├── T2IS_Eval
│   ├── run_prompt_alignment.sh
│   ├── run_prompt_consistency.sh
│   ├── README.md
│   ├── vqa_alignment.py
│   └── test_prompt_consistency.py
├── README.md
└── T2IS_Bench
    └── README.md
/pic/introduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/pic/introduction.png
--------------------------------------------------------------------------------
/pic/0001_0007_merge_seed1234.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/pic/0001_0007_merge_seed1234.png
--------------------------------------------------------------------------------
/T2IS_Gen/pic/0001_0007_merge_seed1234.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chengyou-jia/T2IS/HEAD/T2IS_Gen/pic/0001_0007_merge_seed1234.png
--------------------------------------------------------------------------------
/T2IS_Eval/run_prompt_alignment.sh:
--------------------------------------------------------------------------------
1 | IMAGE_DIR="/home/chengyou/AutoT2IS/output_images/layout_deepseek-reasoner" # Target directory
2 | OUTPUT_JSON="/home/chengyou/results/layout_deepseek-reasoner.json" # Output JSON file path
3 |
4 | cd /home/chengyou/clipscore/t2v_metrics
5 | # Process specified directory
6 | # see https://github.com/linzhiqiu/t2v_metrics for t2v_metrics
7 |
8 | echo "Evaluating images in ${IMAGE_DIR}..."
9 | CUDA_VISIBLE_DEVICES=0 python vqa_alignment.py \
10 | --image_dir "${IMAGE_DIR}" \
11 | --image_format png \
12 | --prompt_json "./ChengyouJia/T2IS-Bench/prompt_alignment.json" \
13 | --output_json "${OUTPUT_JSON}" \
14 | --start_idx 0
15 |
16 | echo "Evaluation completed for ${IMAGE_DIR}!"
--------------------------------------------------------------------------------
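
For reference, `vqa_alignment.py` (shown further below) drives the `t2v_metrics` VQAScore model; a minimal sketch of the underlying call, assuming `t2v_metrics` is installed as described in the linked repository and using a placeholder image path and criteria:

```python
import t2v_metrics

# Same checkpoint as in vqa_alignment.py
clip_flant5_score = t2v_metrics.VQAScore(model="clip-flant5-xxl")

# Score one image against a few textual criteria (placeholders);
# the result is a (num_images, num_texts) tensor of alignment scores.
scores = clip_flant5_score(
    images=["./output_images/0001_0007_0001.png"],
    texts=["a teenage boy with short spiky black hair", "a science fair scene"],
)
print(scores)
```
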
/T2IS_Eval/run_prompt_consistency.sh:
--------------------------------------------------------------------------------
1 | # Set basic parameters
2 | IMAGE_DIR="/home/chengyou/AutoT2IS/output_images/layout_deepseek-reasoner" # Target directory
3 | IMAGE_PATTERN=".png" # Image format
4 | OUTPUT_JSON="/home/chengyou/Qwen/results/layout_deepseek-reasoner.json" # Output JSON file path
5 |
6 | # Process specified directory
7 | echo "Evaluating images in ${IMAGE_DIR}..."
8 | CUDA_VISIBLE_DEVICES=0 python test_prompt_consistency.py \
9 | --dataset_path "./ChengyouJia/T2IS-Bench/T2IS-Bench.json" \
10 | --criteria_path "./ChengyouJia/T2IS-Bench/prompt_consistency.json" \
11 | --image_base_path "${IMAGE_DIR}" \
12 | --image_pattern "${IMAGE_PATTERN}" \
13 | --output_path "${OUTPUT_JSON}"
14 | echo "Evaluation completed for ${IMAGE_DIR}!"
--------------------------------------------------------------------------------
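
Internally, `test_prompt_consistency.py` (shown further below) asks Qwen2.5-VL a Yes/No question about every pair of images for each consistency criterion and turns the first generated token's logits into a probability of "Yes"; a condensed sketch of that conversion, with the 'No'/'Yes' token ids (2753/9454) taken from the script:

```python
import torch

def yes_probability(first_token_logits: torch.Tensor) -> float:
    """Softmax over the 'No' (2753) and 'Yes' (9454) token logits for one
    image pair, as done in test_prompt_consistency.py; returns P('Yes')."""
    yes_no = torch.stack([first_token_logits[2753], first_token_logits[9454]])
    return torch.softmax(yes_no, dim=0)[1].item()
```

Per-criterion scores are the average of this probability over all image pairs, and dimension scores average over criteria.
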
/T2IS_Gen/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Usage
4 |
5 | ### Basic Usage
6 | ```bash
7 | python inference_t2is.py
8 | ```
9 |
10 | ### With Parameters
11 | ```bash
12 | python inference_t2is.py --idx 0 --json_file ./data/gallery.json --output_dir ./output_images/
13 | ```
14 |
15 | ## Parameters
16 |
17 | | Parameter | Type | Default | Description |
18 | |-----------|------|---------|-------------|
19 | | `--idx` | int | None | Parameter index from JSON file |
20 | | `--json_file` | str | `./data/gallery.json` | Input JSON file path |
21 | | `--output_dir` | str | `./output_images/` | Output directory path |
22 |
23 | ## Output Structure
24 |
25 | Generated images are saved in:
26 | ```
27 | output_images/
28 | └── seed_*/
29 | └── {idx}_merge_seed*.png
30 | ```
31 |
32 | ## Generated ImageSet Examples
33 |
34 | Example output generated with `python inference_t2is.py --idx=`:
35 |
36 | 
37 |
--------------------------------------------------------------------------------
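
The merged grid that `inference_t2is.py` saves can be split back into individual sub-images with `calculate_cutting_layout` from `utils.py`; a minimal sketch assuming PIL, the default 512x512 sub-images, and an output filename convention of our own choosing:

```python
from PIL import Image
from utils import calculate_cutting_layout

merged = Image.open("./output_images/seed_1234/0001_0007_merge_seed1234.png")
num_prompts = 3                                      # sub-prompts in the set
rows, cols = calculate_cutting_layout(num_prompts)   # -> (1, 3) for three prompts
tile_w, tile_h = merged.width // cols, merged.height // rows

for r in range(rows):
    for c in range(cols):
        tile = merged.crop((c * tile_w, r * tile_h, (c + 1) * tile_w, (r + 1) * tile_h))
        tile.save(f"./output_images/seed_1234/0001_0007_{r * cols + c + 1:04d}.png")
```
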
/T2IS_Eval/README.md:
--------------------------------------------------------------------------------
1 | # T2IS Evaluation Scripts
2 |
3 | This directory contains evaluation scripts for Text-to-ImageSet (T2IS) models, including prompt alignment and visual consistency evaluation.
4 |
5 | ## Prerequisites
6 |
7 | Before running the evaluation scripts, you need to download the T2IS-Bench dataset from Hugging Face:
8 |
9 | 1. Visit [ChengyouJia/T2IS-Bench](https://huggingface.co/datasets/ChengyouJia/T2IS-Bench) on Hugging Face
10 | 2. Download the following files:
11 | - `T2IS-Bench.json`: Main dataset file containing prompts and metadata
12 | - `prompt_alignment.json`: Alignment evaluation criteria
13 | - `prompt_consistency.json`: Consistency evaluation criteria
14 | 3. Place the downloaded files in a directory named `ChengyouJia/T2IS-Bench/` relative to the evaluation scripts
15 |
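If you prefer to script the download instead of fetching the files manually, a minimal sketch using `huggingface_hub` (optional; the manual steps above are equivalent):

```python
from huggingface_hub import snapshot_download

# Fetch only the JSON files listed above into the directory layout
# expected by the evaluation scripts.
snapshot_download(
    repo_id="ChengyouJia/T2IS-Bench",
    repo_type="dataset",
    local_dir="./ChengyouJia/T2IS-Bench",
    allow_patterns=["*.json"],
)
```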
16 |
17 | ## Scripts Overview
18 |
19 | ### 1. Prompt Alignment Evaluation (`run_prompt_alignment.sh`)
20 |
21 | **Purpose**: Evaluates the alignment between generated images and their corresponding prompts using VQAScore metrics.
22 |
23 | **Features**:
24 | - Uses VQAScore to measure semantic similarity between images and text prompts
25 | - Processes PNG images from a specified directory
26 | - Outputs evaluation results in JSON format
27 | - Configurable start index for batch processing
28 |
29 | **Usage**:
30 | ```bash
31 | # Make the script executable
32 | chmod +x run_prompt_alignment.sh
33 |
34 | # Run the evaluation
35 | ./run_prompt_alignment.sh
36 | ```
37 |
38 | **Configuration**:
39 | - **Image Directory**: Set `IMAGE_DIR` to the path containing your generated images
40 | - **Output Path**: Set `OUTPUT_JSON` to where you want to save the results
41 | - **Image Format**: Currently set to PNG format
42 | - **Start Index**: Set `start_idx` for batch processing (default: 0)
43 |
44 | **Requirements**:
45 | - See details in https://github.com/linzhiqiu/t2v_metrics.
46 |
47 | ### 2. Prompt Consistency Evaluation (`run_prompt_consistency.sh`)
48 |
49 | **Purpose**: Evaluates the consistency of generated images with their prompts using Qwen model analysis.
50 |
51 | **Features**:
52 | - Uses Qwen2.5VL model to analyze image-prompt consistency
53 | - Processes images with configurable pattern matching
54 | - Outputs detailed consistency analysis in JSON format
55 | - Supports various image formats through pattern matching
56 |
57 | **Usage**:
58 | ```bash
59 | # Make the script executable
60 | chmod +x run_prompt_consistency.sh
61 |
62 | # Run the evaluation
63 | ./run_prompt_consistency.sh
64 | ```
65 |
66 | **Configuration**:
67 | - **Image Directory**: Set `IMAGE_DIR` to the path containing your generated images
68 | - **Image Pattern**: Set `IMAGE_PATTERN` to match your image files (e.g., ".png", ".jpg")
69 | - **Output Path**: Set `OUTPUT_JSON` to where you want to save the results
70 |
71 | **Requirements**:
72 | - See details in https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct.
73 |
--------------------------------------------------------------------------------
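
Both evaluation scripts assume the generated set has already been split into per-prompt images named by task id and zero-padded sub-image index. A hypothetical helper (the name `expected_image_path` is ours) mirroring `get_image_path` in `vqa_alignment.py`:

```python
import os

def expected_image_path(image_dir, task_id, sub_idx, image_format="png"):
    """Images are looked up as "{task_id}_{NNNN}.{ext}"; start_idx controls the
    first index (run_prompt_alignment.sh uses 0, the Python default is 1)."""
    return os.path.join(image_dir, f"{task_id}_{sub_idx:04d}.{image_format}")

print(expected_image_path("./output_images", "0001_0007", 1))
# ./output_images/0001_0007_0001.png
```
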
/T2IS_Gen/init_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torchvision.transforms.functional as F
4 | TOKENS = 75
5 |
6 | def split_dims(x_t, height, width, self=None):
7 | """Split an attention layer dimension to height + width.
8 | The original estimate was latent_h = sqrt(hw_ratio*x_t),
9 | rounding to the nearest value. However, this proved inaccurate.
10 | The actual operation seems to be as follows:
11 | - Divide h,w by 8, rounding DOWN.
12 | - For every new layer (of 4), divide both by 2 and round UP (then back up).
13 | - Multiply h*w to yield x_t.
14 | There is no inverse function to this set of operations,
15 | so instead we mimic them without the multiplication part using the original h+w.
16 | It's worth noting that no known checkpoints follow a different system of layering,
17 | but it's theoretically possible. Please report if encountered.
18 | """
19 | scale = math.ceil(math.log2(math.sqrt(height * width / x_t)))
20 | latent_h = repeat_div(height, scale)
21 | latent_w = repeat_div(width, scale)
22 | if x_t > latent_h * latent_w and hasattr(self, "nei_multi"):
23 | latent_h, latent_w = self.nei_multi[1], self.nei_multi[0]
24 | while latent_h * latent_w != x_t:
25 | latent_h, latent_w = latent_h // 2, latent_w // 2
26 |
27 | return latent_h, latent_w
28 |
29 | def repeat_div(x,y):
30 | """Imitates dimension halving common in convolution operations.
31 |
32 | This is a pretty big assumption of the model,
33 | but then if some model doesn't work like that it will be easy to spot.
34 | """
35 | while y > 0:
36 | x = math.ceil(x / 2)
37 | y = y - 1
38 | return x
39 |
40 | def init_forwards(self, root_module: torch.nn.Module):
41 | for name, module in root_module.named_modules():
42 | if "attn" in name and "transformer_blocks" in name and "single_transformer_blocks" not in name and module.__class__.__name__ == "Attention":
43 | module.forward = FluxTransformerBlock_init_forward(self, module)
44 | elif "attn" in name and "single_transformer_blocks" in name and module.__class__.__name__ == "Attention":
45 | module.forward = FluxSingleTransformerBlock_init_forward(self, module)
46 |
47 | def FluxSingleTransformerBlock_init_forward(self, module):
48 | def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None,RPG_encoder_hidden_states_list=None,RPG_norm_encoder_hidden_states_list=None,RPG_hidden_states_list=None,RPG_norm_hidden_states_list=None):
49 | return module.processor(module, hidden_states=hidden_states, image_rotary_emb=image_rotary_emb)
50 | return forward
51 |
52 | def FluxTransformerBlock_init_forward(self, module):
53 | def forward(hidden_states=None, encoder_hidden_states=None, image_rotary_emb=None,RPG_encoder_hidden_states_list=None,RPG_norm_encoder_hidden_states_list=None,RPG_hidden_states_list=None,RPG_norm_hidden_states_list=None):
54 | return module.processor(module, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, image_rotary_emb=image_rotary_emb)
55 | return forward
--------------------------------------------------------------------------------
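
As a concrete check of `split_dims`: a 2x3 set of 512x512 sub-images yields a 1024x1536 canvas, whose packed FLUX latent has 64x96 = 6144 image tokens (assuming the usual 8x VAE downsample plus 2x2 patch packing), and the function recovers that grid:

```python
from init_attention import split_dims

# 1024x1536 output with 6144 flattened image tokens -> 64x96 latent grid
latent_h, latent_w = split_dims(6144, 1024, 1536)
print(latent_h, latent_w)  # 64 96
```
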
/T2IS_Gen/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for T2IS generation and image processing.
3 | """
4 |
5 | def calculate_layout_dimensions(num_prompts, sub_height, sub_width):
6 | """Calculate total height and width based on number of prompts and sub-image dimensions.
7 |
8 | Args:
9 | num_prompts (int): Number of sub-images/prompts
10 | sub_height (int): Height of each sub-image
11 | sub_width (int): Width of each sub-image
12 |
13 | Returns:
14 | tuple: Total (height, width) dimensions
15 | """
16 | if num_prompts == 4:
17 | # 2x2 layout
18 | height = sub_height * 2
19 | width = sub_width * 2
20 | elif num_prompts == 6:
21 | # 2x3 layout
22 | height = sub_height * 2
23 | width = sub_width * 3
24 | elif num_prompts == 8:
25 | # 2x4 layout
26 | height = sub_height * 2
27 | width = sub_width * 4
28 | elif num_prompts == 9:
29 | # 3x3 layout
30 | height = sub_height * 3
31 | width = sub_width * 3
32 | else:
33 | # Default: 1xN layout
34 | height = sub_height
35 | width = sub_width * num_prompts
36 | return height, width
37 |
38 |
39 | def calculate_cutting_layout(num_prompts):
40 | """Calculate rows and columns for cutting images based on number of prompts.
41 |
42 | Args:
43 | num_prompts (int): Number of sub-images/prompts
44 |
45 | Returns:
46 | tuple: (rows, cols) for cutting layout
47 | """
48 | if num_prompts == 4:
49 |         # 2x2 cut
50 | rows, cols = 2, 2
51 | elif num_prompts == 6:
52 |         # 2x3 cut
53 | rows, cols = 2, 3
54 | elif num_prompts == 8:
55 |         # 2x4 cut
56 | rows, cols = 2, 4
57 | elif num_prompts == 9:
58 |         # 3x3 cut
59 | rows, cols = 3, 3
60 | else:
61 |         # Default: cut along the width (1xN)
62 | rows, cols = 1, num_prompts
63 |
64 | return rows, cols
65 |
66 |
67 | def get_grid_params(size):
68 | """Calculate grid parameters for image layout based on size.
69 |
70 | Args:
71 | size (int): Number of sub-images/prompts
72 |
73 | Returns:
74 | tuple: (m_offsets, n_offsets, m_scales, n_scales) for grid layout
75 | """
76 | if size == 4:
77 | rows, cols = 2, 2
78 | elif size == 6:
79 | rows, cols = 2, 3
80 | elif size == 8:
81 | rows, cols = 2, 4
82 | elif size == 9:
83 | rows, cols = 3, 3
84 | else:
85 | rows, cols = 1, size
86 |
87 | m_scale = 1.0/cols
88 | n_scale = 1.0/rows
89 |
90 | m_offsets = []
91 | n_offsets = []
92 | m_scales = []
93 | n_scales = []
94 |
95 | for row in range(rows):
96 | for col in range(cols):
97 | m_offsets.append(col * m_scale)
98 | n_offsets.append(row * n_scale)
99 | m_scales.append(m_scale)
100 | n_scales.append(n_scale)
101 |
102 | return m_offsets, n_offsets, m_scales, n_scales
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
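
A quick sketch of how these helpers fit together for a six-prompt set; the values follow directly from the definitions above:

```python
from utils import calculate_layout_dimensions, calculate_cutting_layout, get_grid_params

height, width = calculate_layout_dimensions(6, 512, 512)  # (1024, 1536): 2 rows x 3 cols of 512px tiles
rows, cols = calculate_cutting_layout(6)                  # (2, 3)

# Per-sub-image offsets and scales in normalized [0, 1] coordinates, row-major order.
m_offsets, n_offsets, m_scales, n_scales = get_grid_params(6)
print(m_offsets)  # [0.0, 0.33..., 0.66..., 0.0, 0.33..., 0.66...]
print(n_offsets)  # [0.0, 0.0, 0.0, 0.5, 0.5, 0.5]
```
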
/T2IS_Gen/inference_t2is.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import json
4 | import os
5 | from t2is_pipeline_flux import T2IS_FluxPipeline
6 | from PIL import Image
7 | from utils import calculate_layout_dimensions, calculate_cutting_layout
8 |
9 | def parse_arguments():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--idx', type=int, help="Loading parameters in json")
12 | parser.add_argument('--output_dir', type=str, default="./output_images/", help="Output directory")
13 | parser.add_argument('--json_file', type=str, default="./data/gallery.json", help="Input JSON file")
14 | return parser.parse_args()
15 |
16 | pipe = T2IS_FluxPipeline.from_pretrained("/home/chengyou/hugging/models/FLUX.1-dev", torch_dtype=torch.bfloat16)
17 | pipe = pipe.to("cuda")
18 |
19 | args = parse_arguments()
20 |
21 | if args.idx is not None:
22 | # Load parameters from JSON
23 | with open(args.json_file, 'r') as f:
24 | data = json.load(f)
25 |
26 | item = data[args.idx]
27 |
28 | idx = item["idx"]
29 | task_name_case_id = item["task_name_case_id"]
30 | prompt = item["prompt"]
31 | Divide_prompt_list = item["Divide_prompt_list"]
32 |
33 | # Calculate dimensions
34 | num_prompts = len(Divide_prompt_list)
35 | sub_height = 512
36 | sub_width = 512
37 | height, width = calculate_layout_dimensions(num_prompts, sub_height, sub_width)
38 |
39 | Divide_replace = 2
40 | seed = 1234
41 |
42 | else:
43 | # Default parameters
44 | idx = "0001_0007"
45 | prompt = "THREE-PANEL Images with a 1x3 grid layout a teenage boy with short spiky black hair, a slight build, and dark brown eyes in hyper-realistic style.All images maintain hyper-realistic digital painting style with consistent character design, emphasizing the boy's distinct features and naturalistic lighting across varied environments. [LEFT]:The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting. [MIDDLE]:The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures. [RIGHT]:The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze."
46 | Divide_prompt_list = [
47 | "The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting.",
48 | "The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures.",
49 | "The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze."
50 | ]
51 |
52 | sub_height = 512
53 | sub_width = 512
54 | height, width = calculate_layout_dimensions(3, sub_height, sub_width)
55 |
56 | Divide_replace = 2
57 | seed = 1234
58 |
59 | # Create output directory
60 | base_output_path = args.output_dir
61 | if not os.path.exists(base_output_path):
62 | os.makedirs(base_output_path)
63 |
64 | seed_output_path = os.path.join(base_output_path, f"seed_{seed}")
65 | if not os.path.exists(seed_output_path):
66 | os.makedirs(seed_output_path)
67 |
68 | print(f"Generating with seed {seed}:")
69 | try:
70 | image = pipe(
71 | Divide_prompt_list=Divide_prompt_list,
72 | Divide_replace=Divide_replace,
73 | seed=seed,
74 | prompt=prompt,
75 | height=height,
76 | width=width,
77 | num_inference_steps=20,
78 | guidance_scale=3.5,
79 | ).images[0]
80 |
81 | # Save image
82 | output_filename = os.path.join(seed_output_path, f"{idx}_merge_seed{seed}.png")
83 | image.save(output_filename)
84 | print(f"Image saved as {output_filename}")
85 |
86 | except Exception as e:
87 | print(f"Error processing {idx} with seed {seed}: {str(e)}")
88 |
--------------------------------------------------------------------------------
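
`--json_file` should point to a JSON list of cases; the keys below are exactly the ones `inference_t2is.py` reads, with placeholder prompt text (a sketch, not the shipped `./data/gallery.json`):

```python
import json

gallery = [
    {
        "idx": "0001_0007",
        "task_name_case_id": "dynamic_character_scenario_design_0007",
        "prompt": "THREE-PANEL Images with a 1x3 grid layout ... [LEFT]: ... [MIDDLE]: ... [RIGHT]: ...",
        "Divide_prompt_list": [
            "Left panel description ...",
            "Middle panel description ...",
            "Right panel description ...",
        ],
    }
]

with open("./data/gallery.json", "w") as f:
    json.dump(gallery, f, indent=2)
```
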
/T2IS_Eval/vqa_alignment.py:
--------------------------------------------------------------------------------
1 | import t2v_metrics
2 | import json
3 | import os
4 | from PIL import Image
5 | import torch
6 | from torchvision.transforms import ToTensor
7 | dimensions = ['Entity', 'Attribute', 'Relation']
8 | # Initialize the scoring model
9 | clip_flant5_score = t2v_metrics.VQAScore(model='clip-flant5-xxl')
10 |
11 |
12 |
13 | # Function to get image path from task_id and sub_id
14 | def get_image_path(task_id, sub_id, image_dir, image_format):
15 | return os.path.join(image_dir, f"{task_id}_{sub_id}.{image_format}")
16 |
17 | # Function to evaluate a single image against its criteria
18 | def evaluate_image(image_path, criteria):
19 | if not os.path.exists(image_path):
20 | return None
21 |
22 | # Load the image
23 | image = Image.open(image_path)
24 |
25 | # Prepare texts for each dimension
26 | texts = []
27 | for dimension in dimensions:
28 | texts.extend(criteria[dimension])
29 |
30 | # Calculate scores
31 | scores = clip_flant5_score(images=[image_path], texts=texts)
32 |
33 | # Calculate average score for each dimension
34 | dimension_scores = {}
35 | start_idx = 0
36 | for dimension in dimensions:
37 | num_criteria = len(criteria[dimension])
38 | dimension_scores[dimension] = sum(scores[0][start_idx:start_idx + num_criteria]) / num_criteria
39 | start_idx += num_criteria
40 |
41 | return dimension_scores
42 |
43 | # Main evaluation loop
44 | def main_evaluation(image_dir, image_format, prompt_json, output_json, start_idx=1):
45 |
46 | # Load the prompt alignment JSON file
47 | with open(prompt_json, 'r') as f:
48 | prompt_data = json.load(f)
49 |
50 | overall_dimension_scores = {dimension: 0 for dimension in dimensions}
51 | total_images = 0
52 | results = {}
53 |
54 | # Load existing results from the JSON file if it exists
55 | if os.path.exists(output_json):
56 | with open(output_json, 'r') as f:
57 | results = json.load(f)
58 | else:
59 | results = {}
60 |
61 | for task_id, task_data in prompt_data.items():
62 | print(f"\nEvaluating task: {task_id}")
63 |
64 | for idx, (sub_id, criteria) in enumerate(task_data.items(), start=start_idx):
65 | formatted_sub_id = f"{idx:04d}"
66 | result_key = f"{task_id}_{formatted_sub_id}"
67 |
68 | # Skip processing if the result already exists
69 | if result_key in results:
70 | print(f"Skipping already processed image: {result_key}")
71 | continue
72 |
73 | image_path = get_image_path(task_id, formatted_sub_id, image_dir, image_format)
74 | print(f"\nEvaluating image: {image_path}")
75 |
76 | scores = evaluate_image(image_path, criteria)
77 | if scores:
78 | print("Dimension scores:")
79 | for dimension, score in scores.items():
80 | print(f"{dimension}: {score:.4f}")
81 | overall_dimension_scores[dimension] += score
82 | total_images += 1
83 | results[result_key] = scores
84 |
85 | # Save the updated results to the JSON file after each image
86 | with open(output_json, 'w') as f:
87 | json.dump(results, f, indent=4, default=lambda o: o.tolist() if isinstance(o, torch.Tensor) else o)
88 | else:
89 | print("Image not found")
90 |
91 | # Calculate overall average scores for each dimension
92 | if total_images > 0:
93 | for dimension in overall_dimension_scores:
94 | overall_dimension_scores[dimension] /= total_images
95 |
96 | print("\nOverall average dimension scores:")
97 | for dimension, score in overall_dimension_scores.items():
98 | print(f"{dimension}: {score:.4f}")
99 |
100 | # Example usage
101 | if __name__ == "__main__":
102 | import argparse
103 |
104 | parser = argparse.ArgumentParser(description="Evaluate image dimensions.")
105 | parser.add_argument("--image_dir", type=str, default="/home/chengyou/GroupGen/baseline/results/flux/", help="Directory containing the images.")
106 | parser.add_argument("--image_format", type=str, default="png", choices=["png", "jpg"], help="Format of the images.")
107 | parser.add_argument("--output_json", type=str, default="dimension_scores.json", help="Path to the output JSON file.")
108 | parser.add_argument("--start_idx", type=int, default=1, help="Starting index for image enumeration.")
109 | parser.add_argument("--prompt_json", type=str, default="/home/chengyou/GroupGen/evaluation/prompt_alignment.json", help="Path to the prompt alignment JSON file.")
110 | args = parser.parse_args()
111 |
112 | main_evaluation(args.image_dir, args.image_format, args.prompt_json, args.output_json, args.start_idx)
--------------------------------------------------------------------------------
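
The `--prompt_json` file is expected to map each task id to its sub-image entries, each carrying Entity/Attribute/Relation criterion lists. A minimal hypothetical example of the shape `main_evaluation` consumes (criterion texts are placeholders):

```python
prompt_data_example = {
    "0001_0007": {      # task_id
        "sub_1": {      # sub-image key (the script enumerates entries rather than using the key)
            "Entity": ["a teenage boy", "a science fair booth"],
            "Attribute": ["short spiky black hair", "bright indoor lighting"],
            "Relation": ["the boy holds a blueprint"],
        },
    },
}
```
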
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Why Settle for One? Text-to-ImageSet Generation and Evaluation
3 |
4 |
5 | [🌐 Website] •
6 | [📜 Paper](https://arxiv.org/abs/2506.23275) •
7 | [🤗 HF Dataset](https://huggingface.co/datasets/ChengyouJia/T2IS-Bench)
8 |
9 |
10 |
11 |
12 | Official Repo for "Why Settle for One?
13 | Text-to-ImageSet Generation and Evaluation"
14 |
15 |
16 |
18 |
19 | ## T2IS
20 | 
21 |
22 |
23 | ## News
24 |
25 | - _2025.10_: We have added the latest **Seedream 4.0** results. Please refer to the [Seedream 4.0 Demo](https://www.volcengine.com/experience/ark?launch=seedream) and the attached file [**T2IS_Seedream.zip**](https://huggingface.co/datasets/ChengyouJia/T2IS-Bench).
26 |
27 |
28 | - _2025.09_: We release [T2IS-Gen], a simple version of the set-aware generation code.
29 |
30 |
31 | - _2025.08_: We release the [T2IS-Eval] evaluation toolkit.
32 | - _2025.07_: We release the details of [T2IS-Bench].
33 |
34 |
35 | ## 🛠️ Installation
36 |
37 | ### Text-to-ImageSet Generation
38 |
39 | ### 1. Set Environment
40 | ```bash
41 | conda create -n T2IS python==3.9
42 | conda activate T2IS
43 | pip install xformers==0.0.28.post1 diffusers peft torchvision==0.19.1 opencv-python==4.10.0.84 sentencepiece==0.2.0 protobuf==5.28.1 scipy==1.13.1
44 | ```
45 |
46 | ### 2. Quick Start
47 |
48 | ```bash
49 | cd T2IS_Gen
50 | ```
51 |
52 | ```python
53 | import torch
54 | import argparse
55 | import json
56 | import os
57 | from t2is_pipeline_flux import T2IS_FluxPipeline
58 | from PIL import Image
59 | from utils import calculate_layout_dimensions, calculate_cutting_layout
60 | pipe = T2IS_FluxPipeline.from_pretrained("/home/chengyou/hugging/models/FLUX.1-dev", torch_dtype=torch.bfloat16)
61 | pipe = pipe.to("cuda")
62 |
63 | # base_output_path = "../output_images/RAG_layout_deepseek-reasoner_3_30_seed_1234"
64 | base_output_path = "./output_images/"
65 |
66 | print(f"Processing file with task name case ID: 0001_0007")
67 | task_name_case_id = "dynamic_character_scenario_design_0007"
68 | Divide_prompt_list = [
69 | "The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting.",
70 | "The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures.",
71 | "The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze."
72 | ]
73 | prompt = "THREE-PANEL Images with a 1x3 grid layout a teenage boy with short spiky black hair, a slight build, and dark brown eyes in hyper-realistic style.All images maintain hyper-realistic digital painting style with consistent character design, emphasizing the boy's distinct features and naturalistic lighting across varied environments. [LEFT]:The boy stands at a science fair, surrounded by project displays and glowing holographic models. He holds a blueprint, his expression bright with curiosity. The background features blurred crowds and colorful experiment stations under bright indoor lighting. [MIDDLE]:The boy crouches in a sunlit garden, digging soil with a trowel. Dirt stains his hands and casual clothes, with scattered gardening tools nearby. His focused gaze and slightly parted lips suggest discovery, sunlight casting sharp shadows on the earthy textures. [RIGHT]:The boy wears a green knitted hat in a snowy urban park, breath visible in cold air. Frosted trees frame the scene as he clutches a steaming drink. The hat's yarn details contrast with his spiky hair, while distant ice-skating figures blur into the winter haze."
74 |
75 | # Set default sub-image size to 512x512
76 | sub_height = 512
77 | sub_width = 512
78 |
79 | # Calculate total height and width based on layout
80 | num_prompts = len(Divide_prompt_list)
81 | height, width = calculate_layout_dimensions(num_prompts, sub_height, sub_width)
82 |
83 |
84 |
85 | Divide_replace = 2
86 | num_inference_steps = 20
87 |
88 | seeds = [1234]
89 |
90 | for seed_idx, seed in enumerate(seeds):
91 | seed_output_path = os.path.join(base_output_path, f"seed_{seed}")
92 | if not os.path.exists(seed_output_path):
93 | os.makedirs(seed_output_path)
94 |
95 | print(f"Generating with seed {seed}:")
96 | try:
97 | image = pipe(
98 | Divide_prompt_list=Divide_prompt_list,
99 | Divide_replace=Divide_replace,
100 | seed=seed,
101 | prompt=prompt,
102 | height=height,
103 | width=width,
104 | num_inference_steps=num_inference_steps,
105 | guidance_scale=3.5,
106 | ).images[0]
107 | except Exception as e:
108 | print(f"Error processing {idx} with seed {seed}: {str(e)}")
109 | continue
110 | image.save(os.path.join(seed_output_path, f"{idx}_merge_seed{seed}.png"))
111 | ```
112 | ## Generated ImageSet Examples
113 |
114 | 
115 |
116 |
121 | ## Citation
122 | If you find this work helpful, please cite the paper:
123 | ```
124 | @article{jia2025settle,
125 | title={Why Settle for One? Text-to-ImageSet Generation and Evaluation},
126 | author={Jia, Chengyou and Shen, Xin and Dang, Zhuohang and Xia, Changliang and Wu, Weijia and Zhang, Xinyu and Qian, Hangwei and Tsang, Ivor W and Luo, Minnan},
127 | journal={arXiv preprint arXiv:2506.23275},
128 | year={2025}
129 | }
130 | ```
131 |
132 | ## 📬 Contact
133 |
134 | If you have any inquiries or suggestions, or wish to contact us for any other reason, please email us at cp3jia@stu.xjtu.edu.cn.
135 |
--------------------------------------------------------------------------------
/T2IS_Bench/README.md:
--------------------------------------------------------------------------------
1 | See https://huggingface.co/datasets/ChengyouJia/T2IS-Bench for more details.
2 |
3 |
4 | ### Dataset Overview
5 |
6 | **T2IS-Bench** is a comprehensive benchmark designed to evaluate generative models' performance in text-to-image set generation tasks. It includes **596 carefully constructed tasks** across **five major categories** (26 sub-categories), each targeting different aspects of set-level consistency such as identity preservation, style uniformity, and logical coherence. These tasks span a wide range of real-world applications, including character creation, visual storytelling, product mockups, procedural illustrations, and instructional content.
7 |
8 | T2IS-Bench provides a scalable evaluation framework that assesses image sets across **three critical consistency dimensions**: identity, style, and logic. Each of the **596 tasks** is paired with structured natural language instructions and evaluated using **LLM-driven criteria generation**, enabling automatic, interpretable, and fine-grained assessment. This design supports benchmarking generative models' ability to produce coherent visual outputs beyond prompt-level alignment, and reflects real-world requirements for controllability and consistency in multi-image generation.
9 |
10 |
11 |
12 | ### Supported Tasks
13 |
14 | The dataset comprises five main categories, each with a set of associated tasks and unique task IDs as listed below:
15 |
16 | #### **Character Generation**
17 |
18 | - `0001` – Multi-Scenario
19 | - `0002` – Multi-Expression
20 | - `0003` – Portrait Design
21 | - `0004` – Multi-view
22 | - `0005` – Multi-pose
23 |
24 | #### **Design Style Generation**
25 |
26 | - `0006` – Creative Style
27 | - `0007` – Poster Design
28 | - `0008` – Font Design
29 | - `0009` – IP Product
30 | - `0010` – Home Decoration
31 |
32 | #### **Story Generation**
33 |
34 | - `0011` – Movie Shot
35 | - `0012` – Comic Story
36 | - `0013` – Children Book
37 | - `0014` – News Illustration
38 | - `0015` – Hist. Narrative
39 |
40 | #### **Process Generation**
41 |
42 | - `0016` – Growth Process
43 | - `0017` – Draw Process
44 | - `0018` – Cooking Process
45 | - `0019` – Physical Law
46 | - `0020` – Arch. Building
47 | - `0021` – Evolution Illustration
48 |
49 | #### **Instruction Generation**
50 |
51 | - `0022` – Education Illustration
52 | - `0023` – Historical Panel
53 | - `0024` – Product Instruction
54 | - `0025` – Travel Guide
55 | - `0026` – Activity Arrange
56 |
57 |
58 |
59 | ### Use Cases
60 |
61 | **T2IS-Bench** is designed for evaluating generative models on multi-image consistency tasks, testing capabilities such as aesthetics, prompt alignment (including entity, attribute, and relation understanding), and visual consistency (covering identity, style, and logic) across image sets. It is suitable for benchmarking text-to-image models, diffusion transformers, and multimodal generation systems in real-world applications like product design, storytelling, and instructional visualization.
62 |
63 | ## Dataset Format and Structure
64 |
65 | ### Data Organization
66 |
67 | 1. **`T2IS-Bench.json`**
68 | A JSON file providing all of the cases. The structure of `T2IS-Bench.json` is as follows:
69 |
70 | ```json
71 | {
72 | ......
73 | "0018_0001": {
74 | "task_name": "Cooking Process",
75 | "num_of_cases": 27,
76 | "uid": "0018",
77 | "output_image_count": 4,
78 | "case_id": "0001",
79 | "task_name_case_id": "cooking_process_0001",
80 | "category": "Process Generation",
81 | "instruction": "Please provide a detailed guide on melting chocolate, including 4 steps. For each step, generate an image.",
82 | "sub_caption": [
83 | "A glass bowl filled with chopped dark chocolate pieces sits on top of a pot of simmering water. Steam rises gently around the bowl, and a thermometer is visible in the chocolate. The kitchen counter shows other baking ingredients in the background.",
84 | "Hands holding a silicone spatula are gently stirring melting chocolate in a glass bowl. The chocolate is partially melted, with some pieces still visible. The bowl is positioned over a steaming pot on a stovetop.",
85 | "A close-up view of a digital thermometer inserted into fully melted, glossy chocolate. The thermometer display shows a temperature of 88°F (31°C). The melted chocolate has a rich, dark color and smooth texture.",
86 | "A hand is seen removing the bowl of melted chocolate from the double boiler setup. The chocolate appears smooth and shiny. Next to the stove, various dessert items like strawberries, cookies, and a cake are ready for dipping or coating."
87 | ]
88 | }
89 | ......
90 | }
91 | ```
92 |
93 | - `task_name`: Name of the task.
94 | - `num_of_cases`: The number of individual cases in the task.
95 | - `uid`: Unique identifier for the task.
96 | - `output_image_count`: Number of images expected as output.
97 | - `case_id`: Identifier for this case.
98 | - `task_name_case_id`: Unique identifier for each specific case within a task, combining the task name and case ID.
99 | - `category`: The classification of the task.
100 | - `instruction`: The task's description, specifying what needs to be generated.
101 | - `sub_caption`: Descriptions for each image in the task, generated by feeding the instruction into an LLM.
102 |
103 |
104 |
105 | 2. **`prompt_alignment_criterion.json`**
106 |
107 | This file contains evaluation criteria for assessing prompt alignment in image generation tasks. Each entry corresponds to a specific task and is organized by steps, with each step evaluated based on three key aspects: **Entity**, **Attribute**, and **Relation**.
108 |
109 | - **Entity** defines the key objects or characters required in the scene.
110 | - **Attribute** describes the properties or conditions that these entities must possess.
111 | - **Relation** outlines how the entities interact or are positioned within the scene.
112 |
113 | This structured format helps evaluate the accuracy of the generated images in response to specific prompts.
114 |
115 | 3. **`prompt_consistency_criterion.json`**
116 |
117 | This file defines evaluation criteria for assessing *intra-sequence consistency* in image generation tasks. Each entry corresponds to a specific task and outlines standards across three core aspects: **Style**, **Identity**, and **Logic**.
118 |
119 | - **Style** evaluates the visual coherence across all generated images, including consistency in rendering style, color palette, lighting conditions, and background detail. It ensures that all images share a unified artistic and atmospheric aesthetic.
120 |
121 | - **Identity** focuses on maintaining character integrity across scenes. This includes preserving key facial features, body proportions, attire, and expressions so that the same individual or entity is clearly represented throughout the sequence.
122 |
123 | - **Logic** ensures semantic and physical plausibility across images. This includes spatial layout consistency, realistic actions, appropriate interactions with the environment, and coherent scene transitions.
124 |
125 | This structured format enables a systematic evaluation of how well generated images maintain consistency within a task.
126 |
--------------------------------------------------------------------------------
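
Given the structure above, iterating over the benchmark takes only a few lines; a minimal sketch assuming the JSON files sit under `./ChengyouJia/T2IS-Bench/` as in the evaluation scripts:

```python
import json

with open("./ChengyouJia/T2IS-Bench/T2IS-Bench.json") as f:
    bench = json.load(f)

case = bench["0018_0001"]                    # the "Cooking Process" example above
print(case["instruction"])                   # task-level instruction
for i, caption in enumerate(case["sub_caption"], start=1):
    print(f"[{i}/{case['output_image_count']}] {caption}")
```
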
/T2IS_Eval/test_prompt_consistency.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from tqdm import tqdm
5 | from PIL import Image
6 | import torch
7 | import argparse
8 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
9 | from qwen_vl_utils import process_vision_info
10 | from transformers import GenerationConfig
11 |
12 | def load_model(model_path):
13 | # Load the model on the available device(s)
14 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
15 | model_path, torch_dtype="auto", device_map="auto"
16 | )
17 |
18 | # default processer
19 | processor = AutoProcessor.from_pretrained(model_path)
20 |
21 | return model, processor
22 |
23 | def evaluate_prompt_consistency(dataset_path, criteria_path, image_base_path, image_pattern, output_path=None):
24 | # Load the model
25 | model, processor = load_model("./Qwen/Qwen2.5-VL-7B-Instruct")
26 |
27 | # Load the filtered dataset
28 | with open(dataset_path, 'r') as f:
29 | dataset = json.load(f)
30 | print(f"Total number of cases in dataset: {len(dataset)}")
31 |
32 | # Load the prompt consistency criteria
33 | with open(criteria_path, 'r') as f:
34 | criteria_data = json.load(f)
35 | print(f"Loaded prompt consistency criteria")
36 |
37 | # Get folder name for output file
38 | folder_name = os.path.basename(image_base_path)
39 | print(f"\n{'='*50}")
40 | print(f"Processing folder: {folder_name}")
41 | print(f"{'='*50}\n")
42 |
43 | # Set output path if not provided
44 | if output_path is None:
45 | output_path = f"evaluation_scores_{folder_name}.json"
46 |
47 | # Initialize results
48 | results = {}
49 |
50 | # Check if results file already exists and load it
51 | if os.path.exists(output_path):
52 | with open(output_path, 'r') as f:
53 | results = json.load(f)
54 | print(f"Loaded existing results from {output_path}")
55 | else:
56 | results = {
57 | "case_results": {}
58 | }
59 |
60 | # Process each case in the dataset
61 | for case_id, case_data in tqdm(dataset.items(), desc=f"Processing cases"):
62 | # Skip if this case has already been processed
63 | if case_id in results["case_results"]:
64 | print(f"\nSkipping already processed case: {case_id}")
65 | continue
66 |
67 | print(f"\nProcessing case: {case_id}")
68 |
69 | # Get criteria for this case
70 | case_criteria = criteria_data.get(case_id, {})
71 | if not case_criteria:
72 | print(f"Warning: No criteria found for case {case_id}. Skipping.")
73 | continue
74 |
75 | # Get all image paths for this case based on the provided pattern
76 | all_files = os.listdir(image_base_path)
77 | image_paths = []
78 |
79 | # Parse the image pattern to extract case_id and file extension
80 | ext = image_pattern
81 |
82 | for f in all_files:
83 | if f.startswith(case_id) and f.endswith(ext):
84 | try:
85 | # Extract the last number from filename (e.g., 0001 from 0001_0001_0001.png)
86 | num = int(f.split('_')[-1].replace(ext, ''))
87 | image_paths.append((num, os.path.join(image_base_path, f)))
88 | except ValueError:
89 | continue
90 |
91 | # Sort by the extracted number and get only the paths
92 | image_paths = [path for _, path in sorted(image_paths)]
93 |
94 | if len(image_paths) == 0:
95 | print(f"Warning: Expected {case_data['output_image_count']} images, but found {len(image_paths)}. Skipping this case.")
96 | continue
97 |
98 | # Initialize dimension scores
99 | dimension_scores = {
100 | "Style": {"scores": [], "criteria": []},
101 | "Identity": {"scores": [], "criteria": []},
102 | "Logic": {"scores": [], "criteria": []}
103 | }
104 |
105 | # Process each dimension
106 | for dimension in ["Style", "Identity", "Logic"]:
107 | if dimension not in case_criteria:
108 | continue
109 |
110 | # Get criteria for this dimension
111 | dimension_criteria = case_criteria[dimension][0] # Get the first (and only) dictionary in the list
112 | dimension_scores[dimension]["criteria"] = list(dimension_criteria.values())
113 |
114 | # Process each criterion in this dimension
115 | for criterion_text in dimension_criteria.values():
116 | softmax_values = []
117 |
118 | # Compare each pair of images
119 | for i in range(len(image_paths)):
120 | for j in range(i + 1, len(image_paths)):
121 | messages = []
122 | messages.append(
123 | {
124 | "role": "user",
125 | "content": [
126 | {"type": "image", "image": image_paths[i], "resized_height": 512, "resized_width": 512},
127 | {"type": "image", "image": image_paths[j], "resized_height": 512, "resized_width": 512},
128 | {"type": "text", "text": f"Do images meet the following criteria? {criterion_text} Please answer Yes or No."},
129 | ],
130 | }
131 | )
132 |
133 | # Prepare for inference
134 | text = processor.apply_chat_template(
135 | messages, tokenize=False, add_generation_prompt=True
136 | )
137 | image_inputs, video_inputs = process_vision_info(messages)
138 | inputs = processor(
139 | text=[text],
140 | images=image_inputs,
141 | videos=video_inputs,
142 | padding=True,
143 | return_tensors="pt",
144 | )
145 | inputs = inputs.to("cuda")
146 |
147 | generated_ids = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, output_logits=True)
148 | sequences = generated_ids.sequences
149 | scores = generated_ids.scores
150 | logits = generated_ids.logits
151 |
152 | no_logits = logits[0][0][2753]
153 | yes_logits = logits[0][0][9454]
154 |
155 | # Calculate softmax
156 | logits = torch.tensor([no_logits, yes_logits])
157 | softmax = torch.nn.functional.softmax(logits, dim=0)
158 | yes_softmax = softmax[1].item()
159 |
160 | generated_ids_trimmed = [
161 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, sequences)
162 | ]
163 |
164 | output_text = processor.batch_decode(
165 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
166 | )
167 |
168 | # Append the current yes_softmax value to the list
169 | softmax_values.append(yes_softmax)
170 |
171 | # Calculate average score for this criterion
172 | if softmax_values:
173 | average_softmax = sum(softmax_values) / len(softmax_values)
174 | dimension_scores[dimension]["scores"].append(average_softmax)
175 |
176 | # Calculate overall scores for each dimension
177 | dimension_averages = {}
178 | for dimension, data in dimension_scores.items():
179 | if data["scores"]:
180 | dimension_averages[dimension] = sum(data["scores"]) / len(data["scores"])
181 | else:
182 | dimension_averages[dimension] = 0.0
183 |
184 | # Calculate overall average across all dimensions
185 | overall_average = sum(dimension_averages.values()) / len(dimension_averages) if dimension_averages else 0.0
186 |
187 | # Store the results for this case
188 | results["case_results"][case_id] = {
189 | "task_name_case_id": case_data["task_name_case_id"],
190 | "num_images": case_data["output_image_count"],
191 | "dimension_scores": dimension_scores,
192 | "dimension_averages": dimension_averages,
193 | "overall_average": overall_average
194 | }
195 |
196 | print(f"\nScores for case {case_id}:")
197 | for dimension, avg in dimension_averages.items():
198 | print(f"{dimension} average: {avg:.4f}")
199 | print(f"Overall average: {overall_average:.4f}")
200 |
201 | # Save results after each case is processed
202 | with open(output_path, 'w') as f:
203 | json.dump(results, f, indent=4)
204 | print(f"Updated results saved to: {output_path}")
205 |
206 | # Final save of results
207 | with open(output_path, 'w') as f:
208 | json.dump(results, f, indent=4)
209 | print(f"\nFinal results saved to: {output_path}")
210 |
211 | # Calculate overall dimension scores across all cases
212 | print(f"\nCalculating overall dimension scores...")
213 |
214 | # Initialize dimension totals
215 | dimension_totals = {}
216 | dimension_counts = {}
217 |
218 | # Aggregate scores across all cases
219 | for case_id, case_result in results["case_results"].items():
220 | for dimension, avg in case_result["dimension_averages"].items():
221 | if dimension not in dimension_totals:
222 | dimension_totals[dimension] = 0.0
223 | dimension_counts[dimension] = 0
224 | if avg > 0: # Only add non-zero scores
225 | dimension_totals[dimension] += avg
226 | dimension_counts[dimension] += 1
227 |
228 | # Calculate final averages for each dimension
229 | final_dimension_averages = {}
230 | for dimension, total in dimension_totals.items():
231 | if dimension_counts[dimension] > 0:
232 | final_dimension_averages[dimension] = total / dimension_counts[dimension]
233 | else:
234 | final_dimension_averages[dimension] = 0.0
235 |
236 | # Calculate overall average across all dimensions
237 | final_overall_average = sum(final_dimension_averages.values()) / len(final_dimension_averages) if final_dimension_averages else 0.0
238 |
239 | # Add final scores to results
240 | results["final_dimension_averages"] = final_dimension_averages
241 | results["final_overall_average"] = final_overall_average
242 |
243 | # Save updated results with final scores
244 | with open(output_path, 'w') as f:
245 | json.dump(results, f, indent=4)
246 |
247 | # Print final scores
248 | print(f"\nFinal dimension scores:")
249 | for dimension, avg in final_dimension_averages.items():
250 | print(f"{dimension} final average: {avg:.4f}")
251 | print(f"Final overall average: {final_overall_average:.4f}")
252 |
253 | return results
254 |
255 | def main():
256 | parser = argparse.ArgumentParser(description="Evaluate prompt consistency for generated images")
257 | parser.add_argument("--dataset_path", type=str, default="/home/chengyou/GroupGen/baseline/filtered_responses_sorted.json", help="Path to the dataset JSON file")
258 | parser.add_argument("--criteria_path", type=str, default="/home/chengyou/Qwen/prompt_consistency.json", help="Path to the prompt consistency criteria JSON file")
259 | parser.add_argument("--image_base_path", type=str, default="/home/chengyou/sx/Gemini/gemini", help="Path to the directory containing generated images")
260 | parser.add_argument("--image_pattern", type=str, default="jpg", help="Pattern for image filenames, use {} as placeholder for case_id")
261 | parser.add_argument("--output_path", type=str, help="Path to save the evaluation results")
262 |
263 | args = parser.parse_args()
264 |
265 | evaluate_prompt_consistency(
266 | args.dataset_path,
267 | args.criteria_path,
268 | args.image_base_path,
269 | args.image_pattern,
270 | args.output_path
271 | )
272 |
273 | if __name__ == "__main__":
274 | main()
275 |
--------------------------------------------------------------------------------
/T2IS_Gen/t2is_transformer_flux.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Any, Dict, Optional, Tuple, Union
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | from diffusers.configuration_utils import ConfigMixin, register_to_config
24 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25 | from diffusers.models.attention import FeedForward
26 | from diffusers.models.attention_processor import (
27 | Attention,
28 | AttentionProcessor,
29 | FluxAttnProcessor2_0,
30 | FusedFluxAttnProcessor2_0,
31 | )
32 | from diffusers.models.modeling_utils import ModelMixin
33 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35 | from diffusers.utils.torch_utils import maybe_allow_in_graph
36 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
38 | from typing import List
39 |
40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41 |
42 |
43 | @maybe_allow_in_graph
44 | class FluxSingleTransformerBlock(nn.Module):
45 | r"""
46 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
47 |
48 | Reference: https://arxiv.org/abs/2403.03206
49 |
50 | Parameters:
51 | dim (`int`): The number of channels in the input and output.
52 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
53 | attention_head_dim (`int`): The number of channels in each head.
54 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
55 | processing of `context` conditions.
56 | """
57 |
58 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
59 | super().__init__()
60 | self.mlp_hidden_dim = int(dim * mlp_ratio)
61 |
62 | self.norm = AdaLayerNormZeroSingle(dim)
63 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
64 | self.act_mlp = nn.GELU(approximate="tanh")
65 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
66 |
67 | processor = FluxAttnProcessor2_0()
68 | self.attn = Attention(
69 | query_dim=dim,
70 | cross_attention_dim=None,
71 | dim_head=attention_head_dim,
72 | heads=num_attention_heads,
73 | out_dim=dim,
74 | bias=True,
75 | processor=processor,
76 | qk_norm="rms_norm",
77 | eps=1e-6,
78 | pre_only=True,
79 | )
80 |
81 | def forward(
82 | self,
83 | hidden_states: torch.FloatTensor,
84 | temb: torch.FloatTensor,
85 | image_rotary_emb=None,
86 | joint_attention_kwargs=None,
87 | ):
88 | residual = hidden_states
89 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
90 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
91 | joint_attention_kwargs = joint_attention_kwargs or {}
92 |
93 | attn_output = self.attn(
94 | hidden_states=norm_hidden_states,
95 | image_rotary_emb=image_rotary_emb,
96 | **joint_attention_kwargs
97 | )
98 |
99 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
100 | gate = gate.unsqueeze(1)
101 | hidden_states = gate * self.proj_out(hidden_states)
102 | hidden_states = residual + hidden_states
103 | if hidden_states.dtype == torch.float16:
104 | hidden_states = hidden_states.clip(-65504, 65504)
105 |
106 | return hidden_states
107 |
108 |
109 | @maybe_allow_in_graph
110 | class FluxTransformerBlock(nn.Module):
111 | r"""
112 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
113 |
114 | Reference: https://arxiv.org/abs/2403.03206
115 |
116 | Parameters:
117 | dim (`int`): The number of channels in the input and output.
118 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
119 | attention_head_dim (`int`): The number of channels in each head.
120 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
121 | processing of `context` conditions.
122 | """
123 |
124 | def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
125 | super().__init__()
126 |
127 | self.norm1 = AdaLayerNormZero(dim)
128 |
129 | self.norm1_context = AdaLayerNormZero(dim)
130 |
131 | if hasattr(F, "scaled_dot_product_attention"):
132 | processor = FluxAttnProcessor2_0()
133 | else:
134 | raise ValueError(
135 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
136 | )
137 | self.attn = Attention(
138 | query_dim=dim,
139 | cross_attention_dim=None,
140 | added_kv_proj_dim=dim,
141 | dim_head=attention_head_dim,
142 | heads=num_attention_heads,
143 | out_dim=dim,
144 | context_pre_only=False,
145 | bias=True,
146 | processor=processor,
147 | qk_norm=qk_norm,
148 | eps=eps,
149 | )
150 |
151 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
152 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
153 |
154 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
155 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
156 |
157 | # let chunk size default to None
158 | self._chunk_size = None
159 | self._chunk_dim = 0
160 |
161 | def forward(
162 | self,
163 | hidden_states: torch.FloatTensor,
164 | encoder_hidden_states: torch.FloatTensor,
165 | temb: torch.FloatTensor,
166 | image_rotary_emb=None,
167 | joint_attention_kwargs=None,
168 | ):
169 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
170 |
171 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
172 | encoder_hidden_states, emb=temb
173 | )
174 | joint_attention_kwargs = joint_attention_kwargs or {}
175 |
176 |
177 |
178 | # Attention.
179 | attn_output, context_attn_output = self.attn(
180 | hidden_states=norm_hidden_states,
181 | encoder_hidden_states=norm_encoder_hidden_states,
182 | image_rotary_emb=image_rotary_emb,
183 | **joint_attention_kwargs,
184 | )
185 |
186 | # Process attention outputs for the `hidden_states`.
187 | attn_output = gate_msa.unsqueeze(1) * attn_output
188 | hidden_states = hidden_states + attn_output
189 |
190 | norm_hidden_states = self.norm2(hidden_states)
191 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
192 |
193 | ff_output = self.ff(norm_hidden_states)
194 | ff_output = gate_mlp.unsqueeze(1) * ff_output
195 |
196 | hidden_states = hidden_states + ff_output
197 |
198 | # Process attention outputs for the `encoder_hidden_states`.
199 |
200 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
201 | encoder_hidden_states = encoder_hidden_states + context_attn_output
202 |
203 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
204 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
205 |
206 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
207 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
208 | if encoder_hidden_states.dtype == torch.float16:
209 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
210 |
211 | return encoder_hidden_states, hidden_states
212 |
213 |
214 | class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
215 | """
216 | The Transformer model introduced in Flux.
217 |
218 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
219 |
220 | Parameters:
221 | patch_size (`int`): Patch size to turn the input data into small patches.
222 | in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
223 | num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
224 | num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
225 | attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
226 | num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
227 | joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
228 | pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
229 | guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
230 | """
231 |
232 | _supports_gradient_checkpointing = True
233 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
234 |
235 | @register_to_config
236 | def __init__(
237 | self,
238 | patch_size: int = 1,
239 | in_channels: int = 64,
240 | num_layers: int = 19,
241 | num_single_layers: int = 38,
242 | attention_head_dim: int = 128,
243 | num_attention_heads: int = 24,
244 | joint_attention_dim: int = 4096,
245 | pooled_projection_dim: int = 768,
246 | guidance_embeds: bool = False,
247 | axes_dims_rope: Tuple[int] = (16, 56, 56),
248 | ):
249 | super().__init__()
250 | self.out_channels = in_channels
251 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
252 |
253 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
254 |
255 | text_time_guidance_cls = (
256 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
257 | )
258 | self.time_text_embed = text_time_guidance_cls(
259 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
260 | )
261 |
262 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
263 | self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
264 |
265 | self.transformer_blocks = nn.ModuleList(
266 | [
267 | FluxTransformerBlock(
268 | dim=self.inner_dim,
269 | num_attention_heads=self.config.num_attention_heads,
270 | attention_head_dim=self.config.attention_head_dim,
271 | )
272 | for i in range(self.config.num_layers)
273 | ]
274 | )
275 |
276 | self.single_transformer_blocks = nn.ModuleList(
277 | [
278 | FluxSingleTransformerBlock(
279 | dim=self.inner_dim,
280 | num_attention_heads=self.config.num_attention_heads,
281 | attention_head_dim=self.config.attention_head_dim,
282 | )
283 | for i in range(self.config.num_single_layers)
284 | ]
285 | )
286 |
287 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
288 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
289 |
290 | self.gradient_checkpointing = False
291 |
292 | @property
293 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
294 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
295 | r"""
296 | Returns:
297 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
298 |             indexed by their weight names.
299 | """
300 | # set recursively
301 | processors = {}
302 |
303 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
304 | if hasattr(module, "get_processor"):
305 | processors[f"{name}.processor"] = module.get_processor()
306 |
307 | for sub_name, child in module.named_children():
308 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
309 |
310 | return processors
311 |
312 | for name, module in self.named_children():
313 | fn_recursive_add_processors(name, module, processors)
314 |
315 | return processors
316 |
317 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
318 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
319 | r"""
320 | Sets the attention processor to use to compute attention.
321 |
322 | Parameters:
323 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
324 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
325 | for **all** `Attention` layers.
326 |
327 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
328 | processor. This is strongly recommended when setting trainable attention processors.
329 |
330 | """
331 | count = len(self.attn_processors.keys())
332 |
333 | if isinstance(processor, dict) and len(processor) != count:
334 | raise ValueError(
335 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
336 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
337 | )
338 |
339 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
340 | if hasattr(module, "set_processor"):
341 | if not isinstance(processor, dict):
342 | module.set_processor(processor)
343 | else:
344 | module.set_processor(processor.pop(f"{name}.processor"))
345 |
346 | for sub_name, child in module.named_children():
347 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
348 |
349 | for name, module in self.named_children():
350 | fn_recursive_attn_processor(name, module, processor)
351 |
352 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
353 | def fuse_qkv_projections(self):
354 | """
355 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
356 | are fused. For cross-attention modules, key and value projection matrices are fused.
357 |
358 |
359 |
360 | This API is 🧪 experimental.
361 |
362 |
363 | """
364 | self.original_attn_processors = None
365 |
366 | for _, attn_processor in self.attn_processors.items():
367 | if "Added" in str(attn_processor.__class__.__name__):
368 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
369 |
370 | self.original_attn_processors = self.attn_processors
371 |
372 | for module in self.modules():
373 | if isinstance(module, Attention):
374 | module.fuse_projections(fuse=True)
375 |
376 | self.set_attn_processor(FusedFluxAttnProcessor2_0())
377 |
378 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
379 | def unfuse_qkv_projections(self):
380 | """Disables the fused QKV projection if enabled.
381 |
382 |
383 |
384 | This API is 🧪 experimental.
385 |
386 |
387 |
388 | """
389 | if self.original_attn_processors is not None:
390 | self.set_attn_processor(self.original_attn_processors)
391 |
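    # Illustrative sketch (not executed): fuse the projections for inference, then restore the saved processors.
    #
    #   model.fuse_qkv_projections()      # installs FusedFluxAttnProcessor2_0
    #   ...                               # run inference
    #   model.unfuse_qkv_projections()    # puts back `original_attn_processors`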
392 | def _set_gradient_checkpointing(self, module, value=False):
393 | if hasattr(module, "gradient_checkpointing"):
394 | module.gradient_checkpointing = value
395 |
396 | def Divide_replace_hidden_states(self, hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list,Divide_n_offset_list,Divide_m_scale_list,Divide_n_scale_list, latent_h, latent_w, Divide_idx):
397 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h,latent_w, hidden_states.shape[2])
398 |
399 | for Divide_hidden_states_list, Divide_m_offset, Divide_n_offset, Divide_m_scale, Divide_n_scale in zip(Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list):
400 | Divide_hidden_states = Divide_hidden_states_list[Divide_idx]
401 | Divide_hidden_states = Divide_hidden_states.view(Divide_hidden_states.shape[0], Divide_n_scale,Divide_m_scale, Divide_hidden_states.shape[2])
402 | hidden_states[:, Divide_n_offset:Divide_n_offset+Divide_n_scale, Divide_m_offset:Divide_m_offset+Divide_m_scale, :] = Divide_hidden_states
403 |
404 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[3])
405 | Divide_idx += 1
406 |
407 | return hidden_states, Divide_idx
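    # Worked example (comments only, hypothetical sizes): on a 1024x1024 canvas latent_h = latent_w = 64.
    # A region with Divide_m_offset = Divide_n_offset = 0 and Divide_m_scale = Divide_n_scale = 32 overwrites
    # the top-left quarter of the canvas with its cached features for the current block:
    #
    #   hidden_states:        (B, 64*64, C) -> view (B, 64, 64, C)
    #   Divide_hidden_states: (B, 32*32, C) -> view (B, 32, 32, C)
    #   hidden_states[:, 0:32, 0:32, :] = Divide_hidden_states
    #   hidden_states:        viewed back to (B, 64*64, C); Divide_idx advances by one per call.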
408 |
409 | def Repainting_replace_hidden_states(self, hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx):
410 | original_hidden_states = original_hidden_states_list[Repainting_idx]
411 | original_hidden_states = original_hidden_states.view(original_hidden_states.shape[0], latent_h, latent_w, original_hidden_states.shape[2])
412 | hidden_states = hidden_states.view(hidden_states.shape[0], latent_h,latent_w, hidden_states.shape[2])
413 | original_hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] = hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1]
414 | hidden_states = original_hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[3])
415 | Repainting_idx += 1
416 |
417 | return hidden_states, Repainting_idx
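    # Illustrative sketch (comments only): `Repainting` acts as a binary mask over a sub-window of the
    # latent grid. Inside the window, positions where the mask is 1 keep the freshly computed
    # `hidden_states`; every other position is reset to the stored features for this block, so only the
    # masked region is repainted while the rest of the canvas stays frozen:
    #
    #   original[window][mask == 1] = new[window][mask == 1]
    #   hidden_states = original            # flattened back to (B, latent_h*latent_w, C)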
418 |
419 | def forward(
420 | self,
421 | hidden_states: torch.Tensor,
422 | encoder_hidden_states: torch.Tensor = None,
423 | pooled_projections: torch.Tensor = None,
424 | timestep: torch.LongTensor = None,
425 | img_ids: torch.Tensor = None,
426 | txt_ids: torch.Tensor = None,
427 | guidance: torch.Tensor = None,
428 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
429 | controlnet_block_samples=None,
430 | controlnet_single_block_samples=None,
431 | return_dict: bool = True,
432 | controlnet_blocks_repeat: bool = False,
433 |         latent_h: Optional[int] = None,
434 |         latent_w: Optional[int] = None,
435 |         Divide_hidden_states_list_list: Optional[List[List[torch.Tensor]]] = None,
436 |         Divide_m_offset_list: Optional[List[int]] = None,
437 |         Divide_n_offset_list: Optional[List[int]] = None,
438 |         Divide_m_scale_list: Optional[List[int]] = None,
439 |         Divide_n_scale_list: Optional[List[int]] = None,
440 |         return_hidden_states_list: bool = False,
441 |         original_hidden_states_list: Optional[List[torch.Tensor]] = None,
442 |         Repainting_Divide_m_offset: Optional[int] = None,
443 |         Repainting_Divide_n_offset: Optional[int] = None,
444 |         Repainting: Optional[torch.Tensor] = None,
445 |         Repainting_single: bool = False,
446 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
447 | """
448 | The [`FluxTransformer2DModel`] forward method.
449 |
450 | Args:
451 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
452 | Input `hidden_states`.
453 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
454 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
455 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
456 | from the embeddings of input conditions.
457 | timestep ( `torch.LongTensor`):
458 | Used to indicate denoising step.
459 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
460 | A list of tensors that if specified are added to the residuals of transformer blocks.
461 | joint_attention_kwargs (`dict`, *optional*):
462 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
463 | `self.processor` in
464 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
465 | return_dict (`bool`, *optional*, defaults to `True`):
466 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
467 | tuple.
468 |
469 | Returns:
470 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
471 | `tuple` where the first element is the sample tensor.
472 | """
473 | if return_hidden_states_list:
474 | hidden_states_list=[]
475 | if joint_attention_kwargs is not None:
476 | joint_attention_kwargs = joint_attention_kwargs.copy()
477 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
478 | else:
479 | lora_scale = 1.0
480 |
481 | if USE_PEFT_BACKEND:
482 | # weight the lora layers by setting `lora_scale` for each PEFT layer
483 | scale_lora_layers(self, lora_scale)
484 | else:
485 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
486 | logger.warning(
487 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
488 | )
489 | hidden_states = self.x_embedder(hidden_states)
490 |
491 | if Divide_hidden_states_list_list is not None:
492 | Divide_idx = 0
493 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx)
494 |
495 | if original_hidden_states_list is not None:
496 | Repainting_idx = 0
497 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx)
498 |
499 | if return_hidden_states_list:
500 | hidden_states_list.append(hidden_states)
501 |
502 | timestep = timestep.to(hidden_states.dtype) * 1000
503 | if guidance is not None:
504 | guidance = guidance.to(hidden_states.dtype) * 1000
505 | else:
506 | guidance = None
507 | temb = (
508 | self.time_text_embed(timestep, pooled_projections)
509 | if guidance is None
510 | else self.time_text_embed(timestep, guidance, pooled_projections)
511 | )
512 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
513 |
514 | if txt_ids.ndim == 3:
515 | logger.warning(
516 |                 "Passing `txt_ids` 3d torch.Tensor is deprecated. "
517 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
518 | )
519 | txt_ids = txt_ids[0]
520 | if img_ids.ndim == 3:
521 | logger.warning(
522 |                 "Passing `img_ids` 3d torch.Tensor is deprecated. "
523 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
524 | )
525 | img_ids = img_ids[0]
526 |
527 | ids = torch.cat((txt_ids, img_ids), dim=0)
528 | image_rotary_emb = self.pos_embed(ids)
529 |
530 | for index_block, block in enumerate(self.transformer_blocks):
531 | if self.training and self.gradient_checkpointing:
532 |
533 | def create_custom_forward(module, return_dict=None):
534 | def custom_forward(*inputs):
535 | if return_dict is not None:
536 | return module(*inputs, return_dict=return_dict)
537 | else:
538 | return module(*inputs)
539 |
540 | return custom_forward
541 |
542 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
543 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
544 | create_custom_forward(block),
545 | hidden_states,
546 | encoder_hidden_states,
547 | temb,
548 | image_rotary_emb,
549 | **ckpt_kwargs,
550 | )
551 |
552 | else:
553 | encoder_hidden_states, hidden_states = block(
554 | hidden_states=hidden_states,
555 | encoder_hidden_states=encoder_hidden_states,
556 | temb=temb,
557 | image_rotary_emb=image_rotary_emb,
558 | joint_attention_kwargs=joint_attention_kwargs,
559 | )
560 |
561 | if Divide_hidden_states_list_list is not None:
562 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx)
563 |
564 | if original_hidden_states_list is not None:
565 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx)
566 |
567 | if return_hidden_states_list:
568 | hidden_states_list.append(hidden_states)
569 |
570 | # controlnet residual
571 | if controlnet_block_samples is not None:
572 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
573 | interval_control = int(np.ceil(interval_control))
574 | # For Xlabs ControlNet.
575 | if controlnet_blocks_repeat:
576 | hidden_states = (
577 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
578 | )
579 | else:
580 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
581 |
582 |
583 |
584 | # if not Divide_m_offset_list:
585 | # print("save feature")
586 | # print(timestep.item())
587 | # torch.save(hidden_states.squeeze(0).cpu(), f"tensors_new/feature_{timestep.item()}.pt") # save feature in the shape of [c, h, w]
588 |
589 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
590 |
591 | for index_block, block in enumerate(self.single_transformer_blocks):
592 | if self.training and self.gradient_checkpointing:
593 |
594 | def create_custom_forward(module, return_dict=None):
595 | def custom_forward(*inputs):
596 | if return_dict is not None:
597 | return module(*inputs, return_dict=return_dict)
598 | else:
599 | return module(*inputs)
600 |
601 | return custom_forward
602 |
603 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
604 | hidden_states = torch.utils.checkpoint.checkpoint(
605 | create_custom_forward(block),
606 | hidden_states,
607 | temb,
608 | image_rotary_emb,
609 | **ckpt_kwargs,
610 | )
611 |
612 | else:
613 | hidden_states = block(
614 | hidden_states=hidden_states,
615 | temb=temb,
616 | image_rotary_emb=image_rotary_emb,
617 | joint_attention_kwargs=joint_attention_kwargs,
618 | )
619 |
620 | if Divide_hidden_states_list_list is not None:
621 | hidden_states_clone = hidden_states.clone()[:, encoder_hidden_states.shape[1] :, ...].view(hidden_states.shape[0], latent_h, latent_w, hidden_states.shape[2])
622 |
623 | for Divide_hidden_states_list, Divide_m_offset, Divide_n_offset, Divide_m_scale,Divide_n_scale in zip(Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list):
624 | Divide_hidden_states = Divide_hidden_states_list[Divide_idx]
625 | Divide_hidden_states = Divide_hidden_states[:, Divide_hidden_states.shape[1]-Divide_n_scale*Divide_m_scale :, ...].view(Divide_hidden_states.shape[0], Divide_n_scale, Divide_m_scale, Divide_hidden_states.shape[2])
626 | hidden_states_clone[:, Divide_n_offset:Divide_n_offset+Divide_n_scale,Divide_m_offset:Divide_m_offset+Divide_m_scale, :] = Divide_hidden_states
627 |
628 | hidden_states_clone = hidden_states_clone.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[2])
629 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = hidden_states_clone
630 | Divide_idx += 1
631 |
632 | if original_hidden_states_list is not None:
633 | if Repainting_single:
634 | hidden_states_clone = hidden_states.clone()[:, encoder_hidden_states.shape[1] :, ...].view(hidden_states.shape[0], latent_h, latent_w, hidden_states.shape[2])
635 | original_hidden_states = original_hidden_states_list[Repainting_idx]
636 | original_hidden_states = original_hidden_states[:, encoder_hidden_states.shape[1] :, ...].view(original_hidden_states.shape[0], latent_h, latent_w, original_hidden_states.shape[2])
637 | original_hidden_states[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1] = hidden_states_clone[:, Repainting_Divide_n_offset:Repainting_Divide_n_offset+Repainting.shape[1], Repainting_Divide_m_offset:Repainting_Divide_m_offset+Repainting.shape[2], :][Repainting == 1]
638 | hidden_states_clone = original_hidden_states.view(hidden_states.shape[0], latent_h*latent_w, hidden_states.shape[2])
639 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = hidden_states_clone
640 | Repainting_idx += 1
641 |
642 | if return_hidden_states_list:
643 | hidden_states_list.append(hidden_states)
644 |
645 | # controlnet residual
646 | if controlnet_single_block_samples is not None:
647 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
648 | interval_control = int(np.ceil(interval_control))
649 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
650 | hidden_states[:, encoder_hidden_states.shape[1] :, ...]
651 | + controlnet_single_block_samples[index_block // interval_control]
652 | )
653 |
654 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
655 |
656 | hidden_states = self.norm_out(hidden_states, temb)
657 |
658 | if Divide_hidden_states_list_list is not None:
659 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx)
660 |
661 | if original_hidden_states_list is not None:
662 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx)
663 |
664 | if return_hidden_states_list:
665 | hidden_states_list.append(hidden_states)
666 |
667 | output = self.proj_out(hidden_states)
668 |
669 | if Divide_hidden_states_list_list is not None:
670 | hidden_states, Divide_idx = self.Divide_replace_hidden_states(hidden_states, Divide_hidden_states_list_list, Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list, latent_h, latent_w, Divide_idx)
671 |
672 | if original_hidden_states_list is not None:
673 | hidden_states, Repainting_idx = self.Repainting_replace_hidden_states(hidden_states, original_hidden_states_list, Repainting_Divide_m_offset, Repainting_Divide_n_offset, Repainting, latent_h, latent_w, Repainting_idx)
674 |
675 | if return_hidden_states_list:
676 | hidden_states_list.append(hidden_states)
677 |
678 | if USE_PEFT_BACKEND:
679 | # remove `lora_scale` from each PEFT layer
680 | unscale_lora_layers(self, lora_scale)
681 |
682 | if not return_dict:
683 | if return_hidden_states_list:
684 | return (output,),hidden_states_list
685 | else:
686 | return (output,)
687 |
688 | return Transformer2DModelOutput(sample=output)
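# Illustrative sketch (not executed): how the pipeline in t2is_pipeline_flux.py drives this patched
# forward during the per-region pre-denoising phase. All tensor names below are placeholders.
#
#   noise_pred_tuple, hidden_states_list = transformer(
#       hidden_states=region_latents,                 # packed (B, n_scale*m_scale, 64) latents for one region
#       timestep=timestep / 1000,
#       guidance=guidance,
#       pooled_projections=pooled_prompt_embeds,
#       encoder_hidden_states=prompt_embeds,
#       txt_ids=text_ids, img_ids=latent_image_ids,
#       joint_attention_kwargs=None,
#       return_dict=False,
#       return_hidden_states_list=True,
#   )
#   noise_pred = noise_pred_tuple[0]
#   # hidden_states_list holds one tensor per block; it is later spliced back into the full canvas
#   # through Divide_replace_hidden_states / Repainting_replace_hidden_states.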
--------------------------------------------------------------------------------
/T2IS_Gen/t2is_pipeline_flux.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import inspect
16 | from typing import Any, Callable, Dict, List, Optional, Union
17 |
18 | import numpy as np
19 | import torch
20 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21 |
22 | from diffusers.image_processor import VaeImageProcessor
23 | from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24 | from diffusers.models.autoencoders import AutoencoderKL
25 | from diffusers.models.transformers import FluxTransformer2DModel
26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27 | from diffusers.utils import (
28 | USE_PEFT_BACKEND,
29 | is_torch_xla_available,
30 | logging,
31 | replace_example_docstring,
32 | scale_lora_layers,
33 | unscale_lora_layers,
34 | )
35 | from diffusers.utils.torch_utils import randn_tensor
36 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38 |
39 | from init_attention import init_forwards,TOKENS
40 |
41 | from utils import get_grid_params
42 | import random
43 | import importlib.util
44 | import sys
45 | import PIL
46 | from PIL import Image, ImageChops
47 | from scipy.ndimage import binary_dilation
48 | import torchvision.transforms as transforms
49 |
50 | module_name = 'diffusers.models.transformers.transformer_flux'  # stock diffusers module to shadow
51 | module_path = './t2is_transformer_flux.py'  # local T2IS-patched transformer implementation
52 |
53 | if module_name in sys.modules:
54 | del sys.modules[module_name]
55 |
56 | spec = importlib.util.spec_from_file_location(module_name, module_path)
57 | regionfluxmodel = importlib.util.module_from_spec(spec)
58 | sys.modules[module_name] = regionfluxmodel
59 | spec.loader.exec_module(regionfluxmodel)
60 |
61 | FluxTransformer2DModel = regionfluxmodel.FluxTransformer2DModel
62 |
63 | if is_torch_xla_available():
64 | import torch_xla.core.xla_model as xm
65 |
66 | XLA_AVAILABLE = True
67 | else:
68 | XLA_AVAILABLE = False
69 |
70 |
71 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
72 |
73 | EXAMPLE_DOC_STRING = """
74 | Examples:
75 | ```py
76 | >>> import torch
77 | >>> from diffusers import FluxPipeline
78 |
79 | >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
80 | >>> pipe.to("cuda")
81 | >>> prompt = "A cat holding a sign that says hello world"
82 | >>> # Depending on the variant being used, the pipeline call will slightly vary.
83 | >>> # Refer to the pipeline documentation for more details.
84 | >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
85 | >>> image.save("flux.png")
86 | ```
87 | """
88 |
89 |
90 | def calculate_shift(
91 | image_seq_len,
92 | base_seq_len: int = 256,
93 | max_seq_len: int = 4096,
94 | base_shift: float = 0.5,
95 | max_shift: float = 1.16,
96 | ):
97 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
98 | b = base_shift - m * base_seq_len
99 | mu = image_seq_len * m + b
100 | return mu
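# Worked example (comments only): with the defaults above, m = (1.16 - 0.5) / (4096 - 256) ≈ 1.72e-4 and
# b = 0.5 - m * 256 ≈ 0.456, so a 1024x1024 image (image_seq_len = 4096 packed tokens) gets mu ≈ 1.16,
# while shorter sequences interpolate linearly down toward the base shift of 0.5.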
101 |
102 |
103 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104 | def retrieve_timesteps(
105 | scheduler,
106 | num_inference_steps: Optional[int] = None,
107 | device: Optional[Union[str, torch.device]] = None,
108 | timesteps: Optional[List[int]] = None,
109 | sigmas: Optional[List[float]] = None,
110 | **kwargs,
111 | ):
112 | r"""
113 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115 |
116 | Args:
117 | scheduler (`SchedulerMixin`):
118 | The scheduler to get timesteps from.
119 | num_inference_steps (`int`):
120 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
121 | must be `None`.
122 | device (`str` or `torch.device`, *optional*):
123 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
124 | timesteps (`List[int]`, *optional*):
125 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
126 | `num_inference_steps` and `sigmas` must be `None`.
127 | sigmas (`List[float]`, *optional*):
128 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
129 | `num_inference_steps` and `timesteps` must be `None`.
130 |
131 | Returns:
132 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
133 | second element is the number of inference steps.
134 | """
135 | if timesteps is not None and sigmas is not None:
136 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
137 | if timesteps is not None:
138 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
139 | if not accepts_timesteps:
140 | raise ValueError(
141 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
142 | f" timestep schedules. Please check whether you are using the correct scheduler."
143 | )
144 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
145 | timesteps = scheduler.timesteps
146 | num_inference_steps = len(timesteps)
147 | elif sigmas is not None:
148 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
149 | if not accept_sigmas:
150 | raise ValueError(
151 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
152 | f" sigmas schedules. Please check whether you are using the correct scheduler."
153 | )
154 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
155 | timesteps = scheduler.timesteps
156 | num_inference_steps = len(timesteps)
157 | else:
158 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
159 | timesteps = scheduler.timesteps
160 | return timesteps, num_inference_steps
161 |
162 |
163 | class T2IS_FluxPipeline(
164 | DiffusionPipeline,
165 | FluxLoraLoaderMixin,
166 | FromSingleFileMixin,
167 | TextualInversionLoaderMixin,
168 | ):
169 | r"""
170 | The Flux pipeline for text-to-image generation.
171 |
172 | Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
173 |
174 | Args:
175 | transformer ([`FluxTransformer2DModel`]):
176 | Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
177 | scheduler ([`FlowMatchEulerDiscreteScheduler`]):
178 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
179 | vae ([`AutoencoderKL`]):
180 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
181 | text_encoder ([`CLIPTextModel`]):
182 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
183 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
184 | text_encoder_2 ([`T5EncoderModel`]):
185 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
186 | the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
187 | tokenizer (`CLIPTokenizer`):
188 | Tokenizer of class
189 | [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
190 | tokenizer_2 (`T5TokenizerFast`):
191 | Second Tokenizer of class
192 | [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
193 | """
194 |
195 | model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
196 | _optional_components = []
197 | _callback_tensor_inputs = ["latents", "prompt_embeds"]
198 |
199 | def __init__(
200 | self,
201 | scheduler: FlowMatchEulerDiscreteScheduler,
202 | vae: AutoencoderKL,
203 | text_encoder: CLIPTextModel,
204 | tokenizer: CLIPTokenizer,
205 | text_encoder_2: T5EncoderModel,
206 | tokenizer_2: T5TokenizerFast,
207 | transformer: FluxTransformer2DModel,
208 | ):
209 | super().__init__()
210 |
211 | self.register_modules(
212 | vae=vae,
213 | text_encoder=text_encoder,
214 | text_encoder_2=text_encoder_2,
215 | tokenizer=tokenizer,
216 | tokenizer_2=tokenizer_2,
217 | transformer=transformer,
218 | scheduler=scheduler,
219 | )
220 | self.vae_scale_factor = (
221 | 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
222 | )
223 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
224 | self.tokenizer_max_length = (
225 | self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
226 | )
227 | self.default_sample_size = 64
228 |
229 | def _get_t5_prompt_embeds(
230 | self,
231 | prompt: Union[str, List[str]] = None,
232 | num_images_per_prompt: int = 1,
233 | max_sequence_length: int = 512,
234 | device: Optional[torch.device] = None,
235 | dtype: Optional[torch.dtype] = None,
236 | ):
237 | device = device or self._execution_device
238 | dtype = dtype or self.text_encoder.dtype
239 |
240 | prompt = [prompt] if isinstance(prompt, str) else prompt
241 | batch_size = len(prompt)
242 |
243 | if isinstance(self, TextualInversionLoaderMixin):
244 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
245 |
246 | text_inputs = self.tokenizer_2(
247 | prompt,
248 | padding="max_length",
249 | max_length=max_sequence_length,
250 | truncation=True,
251 | return_length=False,
252 | return_overflowing_tokens=False,
253 | return_tensors="pt",
254 | )
255 | text_input_ids = text_inputs.input_ids
256 | untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
257 |
258 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
259 | removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
260 | logger.warning(
261 | "The following part of your input was truncated because `max_sequence_length` is set to "
262 | f" {max_sequence_length} tokens: {removed_text}"
263 | )
264 |
265 | prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
266 |
267 | dtype = self.text_encoder_2.dtype
268 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
269 |
270 | _, seq_len, _ = prompt_embeds.shape
271 |
272 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
273 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
274 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
275 |
276 | return prompt_embeds
277 |
278 | def _get_clip_prompt_embeds(
279 | self,
280 | prompt: Union[str, List[str]],
281 | num_images_per_prompt: int = 1,
282 | device: Optional[torch.device] = None,
283 | ):
284 | device = device or self._execution_device
285 |
286 | prompt = [prompt] if isinstance(prompt, str) else prompt
287 | batch_size = len(prompt)
288 |
289 | if isinstance(self, TextualInversionLoaderMixin):
290 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
291 |
292 | text_inputs = self.tokenizer(
293 | prompt,
294 | padding="max_length",
295 | max_length=self.tokenizer_max_length,
296 | truncation=True,
297 | return_overflowing_tokens=False,
298 | return_length=False,
299 | return_tensors="pt",
300 | )
301 |
302 | text_input_ids = text_inputs.input_ids
303 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
304 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
305 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
306 | logger.warning(
307 | "The following part of your input was truncated because CLIP can only handle sequences up to"
308 | f" {self.tokenizer_max_length} tokens: {removed_text}"
309 | )
310 | prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
311 |
312 | # Use pooled output of CLIPTextModel
313 | prompt_embeds = prompt_embeds.pooler_output
314 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
315 |
316 | # duplicate text embeddings for each generation per prompt, using mps friendly method
317 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
318 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
319 |
320 | return prompt_embeds
321 |
322 | def encode_prompt(
323 | self,
324 | prompt: Union[str, List[str]],
325 | prompt_2: Union[str, List[str]],
326 | device: Optional[torch.device] = None,
327 | num_images_per_prompt: int = 1,
328 | prompt_embeds: Optional[torch.FloatTensor] = None,
329 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
330 | max_sequence_length: int = 512,
331 | lora_scale: Optional[float] = None,
332 | ):
333 | r"""
334 |
335 | Args:
336 | prompt (`str` or `List[str]`, *optional*):
337 | prompt to be encoded
338 | prompt_2 (`str` or `List[str]`, *optional*):
339 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
340 | used in all text-encoders
341 | device: (`torch.device`):
342 | torch device
343 | num_images_per_prompt (`int`):
344 | number of images that should be generated per prompt
345 | prompt_embeds (`torch.FloatTensor`, *optional*):
346 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
347 | provided, text embeddings will be generated from `prompt` input argument.
348 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
349 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
350 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
351 | lora_scale (`float`, *optional*):
352 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
353 | """
354 | device = device or self._execution_device
355 |
356 | # set lora scale so that monkey patched LoRA
357 | # function of text encoder can correctly access it
358 | if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
359 | self._lora_scale = lora_scale
360 |
361 | # dynamically adjust the LoRA scale
362 | if self.text_encoder is not None and USE_PEFT_BACKEND:
363 | scale_lora_layers(self.text_encoder, lora_scale)
364 | if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
365 | scale_lora_layers(self.text_encoder_2, lora_scale)
366 |
367 | prompt = [prompt] if isinstance(prompt, str) else prompt
368 |
369 | if prompt_embeds is None:
370 | prompt_2 = prompt_2 or prompt
371 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
372 |
373 | # We only use the pooled prompt output from the CLIPTextModel
374 | pooled_prompt_embeds = self._get_clip_prompt_embeds(
375 | prompt=prompt,
376 | device=device,
377 | num_images_per_prompt=num_images_per_prompt,
378 | )
379 | prompt_embeds = self._get_t5_prompt_embeds(
380 | prompt=prompt_2,
381 | num_images_per_prompt=num_images_per_prompt,
382 | max_sequence_length=max_sequence_length,
383 | device=device,
384 | )
385 |
386 | if self.text_encoder is not None:
387 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
388 | # Retrieve the original scale by scaling back the LoRA layers
389 | unscale_lora_layers(self.text_encoder, lora_scale)
390 |
391 | if self.text_encoder_2 is not None:
392 | if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
393 | # Retrieve the original scale by scaling back the LoRA layers
394 | unscale_lora_layers(self.text_encoder_2, lora_scale)
395 |
396 | dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
397 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
398 |
399 | return prompt_embeds, pooled_prompt_embeds, text_ids
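    # Illustrative sketch (not executed): encoding a single prompt with both text encoders. Shapes assume
    # the stock Flux encoders (T5-XXL hidden size 4096, CLIP projection size 768, max_sequence_length=512).
    #
    #   prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    #       prompt="a watercolor fox", prompt_2=None,
    #   )
    #   # prompt_embeds: (1, 512, 4096)   pooled_prompt_embeds: (1, 768)   text_ids: (512, 3)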
400 |
401 | def Divide_encode_prompt(
402 | self,
403 |         Divide_prompt_list: Optional[List[str]] = None,
404 |         Redux_list: Optional[List[Dict[str, Any]]] = None,
405 | device: Optional[torch.device] = None,
406 | num_images_per_prompt: int = 1,
407 | max_sequence_length: int = 512,
408 | lora_scale: Optional[float] = None,
409 | ):
410 | Divide_prompt_embeds_list = []
411 | Divide_pooled_prompt_embeds_list = []
412 | Divide_text_ids_list = []
413 |
414 | if Redux_list is not None:
415 | for Redux in Redux_list:
416 | (
417 | Divide_prompt_embeds,
418 | Divide_pooled_prompt_embeds,
419 | Divide_text_ids,
420 | ) = self.encode_prompt(
421 | **Redux,
422 | prompt=None,
423 | prompt_2=None,
424 | device=device,
425 | num_images_per_prompt=num_images_per_prompt,
426 | max_sequence_length=max_sequence_length,
427 | lora_scale=lora_scale,
428 | )
429 | Divide_prompt_embeds_list.append(Divide_prompt_embeds)
430 | Divide_pooled_prompt_embeds_list.append(Divide_pooled_prompt_embeds)
431 | Divide_text_ids_list.append(Divide_text_ids)
432 | else:
433 | for Divide_prompt in Divide_prompt_list:
434 | (
435 | Divide_prompt_embeds,
436 | Divide_pooled_prompt_embeds,
437 | Divide_text_ids,
438 | ) = self.encode_prompt(
439 | prompt=Divide_prompt,
440 | prompt_2=None,
441 | device=device,
442 | num_images_per_prompt=num_images_per_prompt,
443 | max_sequence_length=max_sequence_length,
444 | lora_scale=lora_scale,
445 | )
446 |
447 | Divide_prompt_embeds_list.append(Divide_prompt_embeds)
448 | Divide_pooled_prompt_embeds_list.append(Divide_pooled_prompt_embeds)
449 | Divide_text_ids_list.append(Divide_text_ids)
450 |
451 | return Divide_prompt_embeds_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list
452 |
453 |
454 |
455 |
456 |
457 |
458 | def torch_fix_seed(self, seed=42):
459 | random.seed(seed)
460 | np.random.seed(seed)
461 | torch.manual_seed(seed)
462 | torch.cuda.manual_seed(seed)
463 | torch.backends.cudnn.deterministic = True
464 |         torch.use_deterministic_algorithms(True)
465 |
466 | def check_inputs(
467 | self,
468 | prompt,
469 | prompt_2,
470 | height,
471 | width,
472 | prompt_embeds=None,
473 | pooled_prompt_embeds=None,
474 | callback_on_step_end_tensor_inputs=None,
475 | max_sequence_length=None,
476 | ):
477 | if height % 8 != 0 or width % 8 != 0:
478 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
479 |
480 | if callback_on_step_end_tensor_inputs is not None and not all(
481 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
482 | ):
483 | raise ValueError(
484 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
485 | )
486 |
487 | if prompt is not None and prompt_embeds is not None:
488 | raise ValueError(
489 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
490 | " only forward one of the two."
491 | )
492 | elif prompt_2 is not None and prompt_embeds is not None:
493 | raise ValueError(
494 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
495 | " only forward one of the two."
496 | )
497 | elif prompt is None and prompt_embeds is None:
498 | raise ValueError(
499 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
500 | )
501 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
502 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
503 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
504 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
505 |
506 | if prompt_embeds is not None and pooled_prompt_embeds is None:
507 | raise ValueError(
508 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
509 | )
510 |
511 | if max_sequence_length is not None and max_sequence_length > 512:
512 | raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
513 |
514 | @staticmethod
515 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
516 | latent_image_ids = torch.zeros(height // 2, width // 2, 3)
517 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
518 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
519 |
520 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
521 |
522 | latent_image_ids = latent_image_ids.reshape(
523 | latent_image_id_height * latent_image_id_width, latent_image_id_channels
524 | )
525 |
526 | return latent_image_ids.to(device=device, dtype=dtype)
527 |
528 | @staticmethod
529 | def _pack_latents(latents, batch_size, num_channels_latents, height, width):
530 | latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
531 | latents = latents.permute(0, 2, 4, 1, 3, 5)
532 | latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
533 |
534 | return latents
535 |
536 | @staticmethod
537 | def _unpack_latents(latents, height, width, vae_scale_factor):
538 | batch_size, num_patches, channels = latents.shape
539 |
540 | height = height // vae_scale_factor
541 | width = width // vae_scale_factor
542 |
543 | latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
544 | latents = latents.permute(0, 3, 1, 4, 2, 5)
545 |
546 | latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
547 |
548 | return latents
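    # Worked example (comments only), assuming vae_scale_factor = 16 as set in __init__ for Flux:
    # a 1024x1024 image corresponds to an unpacked latent of shape (B, 16, 128, 128). `_pack_latents`
    # folds each 2x2 spatial patch into the channel axis and `_unpack_latents` reverses it:
    #
    #   packed   = T2IS_FluxPipeline._pack_latents(latent, B, 16, 128, 128)    # (B, 4096, 64)
    #   restored = T2IS_FluxPipeline._unpack_latents(packed, 1024, 1024, 16)   # (B, 16, 128, 128)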
549 |
550 | def enable_vae_slicing(self):
551 | r"""
552 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
553 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
554 | """
555 | self.vae.enable_slicing()
556 |
557 | def disable_vae_slicing(self):
558 | r"""
559 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
560 | computing decoding in one step.
561 | """
562 | self.vae.disable_slicing()
563 |
564 | def enable_vae_tiling(self):
565 | r"""
566 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
567 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
568 | processing larger images.
569 | """
570 | self.vae.enable_tiling()
571 |
572 | def disable_vae_tiling(self):
573 | r"""
574 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
575 | computing decoding in one step.
576 | """
577 | self.vae.disable_tiling()
578 |
579 |
580 |
581 | def prepare_latents(
582 | self,
583 | batch_size,
584 | num_channels_latents,
585 | height,
586 | width,
587 | dtype,
588 | device,
589 | generator,
590 | latents=None,
591 | ):
592 | height = 2 * (int(height) // self.vae_scale_factor)
593 | width = 2 * (int(width) // self.vae_scale_factor)
594 |
595 | shape = (batch_size, num_channels_latents, height, width)
596 |
597 | if latents is not None:
598 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
599 | return latents.to(device=device, dtype=dtype), latent_image_ids
600 |
601 | if isinstance(generator, list) and len(generator) != batch_size:
602 | raise ValueError(
603 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
604 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
605 | )
606 |
607 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
608 | latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
609 |
610 | latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
611 |
612 | return latents, latent_image_ids
613 |
614 | def prepare_Divide_latents(
615 | self,
616 | Divide_m_scale_list,
617 | Divide_n_scale_list,
618 | batch_size,
619 | num_channels_latents,
620 | dtype,
621 | device,
622 | generator
623 | ):
624 | Divide_latents_list = []
625 | Divide_latent_image_ids_list = []
626 |
627 | for Divide_m_scale, Divide_n_scale in zip(Divide_m_scale_list, Divide_n_scale_list):
628 | Divide_latents, Divide_latent_image_ids = self.prepare_latents(
629 | batch_size,
630 | num_channels_latents,
631 | Divide_n_scale*16,
632 | Divide_m_scale*16,
633 | dtype,
634 | device,
635 | generator
636 | )
637 |
638 | Divide_latents_list.append(Divide_latents)
639 | Divide_latent_image_ids_list.append(Divide_latent_image_ids)
640 |
641 | return Divide_latents_list, Divide_latent_image_ids_list
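    # Illustrative sketch (not executed, hypothetical sizes): two 1024x1024 regions, i.e. 64x64 latent
    # tokens each, get independent packed latents of shape (B, 64*64, 64):
    #
    #   region_latents, region_ids = pipe.prepare_Divide_latents(
    #       Divide_m_scale_list=[64, 64], Divide_n_scale_list=[64, 64],
    #       batch_size=1, num_channels_latents=16,
    #       dtype=torch.bfloat16, device="cuda", generator=None,
    #   )
    #   # region_latents[0].shape == (1, 4096, 64); region_ids[0].shape == (4096, 3)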
642 |
643 | def prepare_Divide_replace(
644 | self, Divide_latents_list, timesteps, Divide_replace, latents, Divide_prompt_embeds_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list, Divide_latent_image_ids_list, guidance, Divide_m_scale_list, Divide_n_scale_list
645 | ):
646 | Divide_latents_list_list = [Divide_latents_list]
647 | Divide_hidden_states_list_list_list = []
648 |
649 | for i, t in enumerate(timesteps):
650 | if(i >= Divide_replace):
651 | break
652 |
653 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
654 | Divide_noise_pred_list = []
655 | Divide_hidden_states_list_list = []
656 |
657 | for Divide_prompt_embeds, Divide_latents, Divide_pooled_prompt_embeds, Divide_text_ids,Divide_latent_image_ids in zip(Divide_prompt_embeds_list, Divide_latents_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list, Divide_latent_image_ids_list):
658 | Divide_noise_pred, Divide_hidden_states_list = self.transformer(
659 | hidden_states=Divide_latents,
660 | timestep=timestep / 1000,
661 | guidance=guidance,
662 | pooled_projections=Divide_pooled_prompt_embeds,
663 | encoder_hidden_states=Divide_prompt_embeds,
664 | txt_ids=Divide_text_ids,
665 | img_ids=Divide_latent_image_ids,
666 | joint_attention_kwargs=None,
667 | return_dict=False,
668 | return_hidden_states_list=True,
669 | )
670 | Divide_noise_pred_list.append(Divide_noise_pred[0])
671 | Divide_hidden_states_list_list.append(Divide_hidden_states_list)
672 | Divide_hidden_states_list_list_list.append(Divide_hidden_states_list_list)
673 |
674 | updated_Divide_latents_list = []
675 | for Divide_latents, Divide_noise_pred in zip(Divide_latents_list, Divide_noise_pred_list):
676 | self.scheduler._init_step_index(t)
677 | Divide_latents = self.scheduler.step(Divide_noise_pred, t, Divide_latents, return_dict=False)[0]
678 | updated_Divide_latents_list.append(Divide_latents)
679 | Divide_latents_list = updated_Divide_latents_list
680 | Divide_latents_list_list.append(Divide_latents_list)
681 |
682 | Divide_latents_list_list = [
683 | [
684 | latents.view(latents.shape[0], n_scale, m_scale, latents.shape[2])
685 | for latents, m_scale, n_scale in zip(latents_list, Divide_m_scale_list, Divide_n_scale_list)
686 | ]
687 | for latents_list in Divide_latents_list_list
688 | ]
689 |
690 | return Divide_latents_list_list, Divide_hidden_states_list_list_list
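    # Illustrative sketch (comments only): for Divide_replace = K this runs the first K scheduler steps on
    # every region independently and records, per step, both the partially denoised per-region latents and
    # the per-block hidden states emitted by the patched transformer. The main denoising loop can then
    # splice these cached features into the full canvas via Divide_replace_hidden_states.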
691 |
692 |
693 |
694 | def Divide_replace_latents(self, latents, Divide_latents_list, Divide_m_offset_list, Divide_n_offset_list, height, width):
695 | latents = latents.view(latents.shape[0], int(height//16), int(width//16), latents.shape[2])
696 | for Divide_latents, Divide_m_offset, Divide_n_offset in zip(Divide_latents_list, Divide_m_offset_list, Divide_n_offset_list):
697 | latents[:, Divide_n_offset:Divide_n_offset+Divide_latents.shape[1], Divide_m_offset:Divide_m_offset+Divide_latents.shape[2], ] = Divide_latents
698 | latents = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2], latents.shape[3])
699 |
700 | return latents
701 |
702 |
703 | @property
704 | def guidance_scale(self):
705 | return self._guidance_scale
706 |
707 | @property
708 | def joint_attention_kwargs(self):
709 | return self._joint_attention_kwargs
710 |
711 | @property
712 | def num_timesteps(self):
713 | return self._num_timesteps
714 |
715 | @property
716 | def interrupt(self):
717 | return self._interrupt
718 |
719 | @torch.no_grad()
720 | @replace_example_docstring(EXAMPLE_DOC_STRING)
721 | def __call__(
722 | self,
723 | Divide_replace: int,
724 | seed: int,
725 |         Divide_prompt_list: Optional[List[str]] = None,
726 |         Redux_list: Optional[List[Dict[str, Any]]] = None,
727 | prompt: Union[str, List[str]] = None,
728 | prompt_2: Optional[Union[str, List[str]]] = None,
729 | height: Optional[int] = None,
730 | width: Optional[int] = None,
731 | num_inference_steps: int = 28,
732 | timesteps: List[int] = None,
733 | guidance_scale: float = 3.5,
734 | num_images_per_prompt: Optional[int] = 1,
735 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
736 | latents: Optional[torch.FloatTensor] = None,
737 | prompt_embeds: Optional[torch.FloatTensor] = None,
738 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
739 | output_type: Optional[str] = "pil",
740 | return_dict: bool = True,
741 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
742 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
743 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
744 | max_sequence_length: int = 512,
745 | ):
746 | r"""
747 | Function invoked when calling the pipeline for generation.
748 |
749 | Args:
750 | prompt (`str` or `List[str]`, *optional*):
751 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
752 |                 instead.
753 |             prompt_2 (`str` or `List[str]`, *optional*):
754 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
755 |                 will be used instead.
756 |             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
757 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
758 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
759 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
760 |             num_inference_steps (`int`, *optional*, defaults to 28):
761 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
762 | expense of slower inference.
763 | timesteps (`List[int]`, *optional*):
764 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
765 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
766 | passed will be used. Must be in descending order.
767 |             guidance_scale (`float`, *optional*, defaults to 3.5):
768 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
769 |                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
770 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
771 |                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
772 |                 text `prompt`, usually at the expense of lower image quality.
773 | num_images_per_prompt (`int`, *optional*, defaults to 1):
774 | The number of images to generate per prompt.
775 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
776 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
777 | to make generation deterministic.
778 | latents (`torch.FloatTensor`, *optional*):
779 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
780 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
781 |                 tensor will be generated by sampling using the supplied random `generator`.
782 | prompt_embeds (`torch.FloatTensor`, *optional*):
783 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
784 | provided, text embeddings will be generated from `prompt` input argument.
785 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
786 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
787 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
788 | output_type (`str`, *optional*, defaults to `"pil"`):
789 |                 The output format of the generated image. Choose between
790 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
791 | return_dict (`bool`, *optional*, defaults to `True`):
792 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
793 | joint_attention_kwargs (`dict`, *optional*):
794 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
795 | `self.processor` in
796 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
797 | callback_on_step_end (`Callable`, *optional*):
798 |                 A function that is called at the end of each denoising step during inference. The function is called
799 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
800 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
801 | `callback_on_step_end_tensor_inputs`.
802 | callback_on_step_end_tensor_inputs (`List`, *optional*):
803 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
804 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
805 | `._callback_tensor_inputs` attribute of your pipeline class.
806 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
807 |
808 | Examples:
809 |
810 | Returns:
811 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
812 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
813 | images.
814 | """
815 |
816 | self.h = height
817 | self.w = width
818 |
819 |         if seed > 0:
820 | self.torch_fix_seed(seed = seed)
821 |
822 | init_forwards(self, self.transformer)
823 |
824 |
825 |
826 |         height = height or self.default_sample_size * self.vae_scale_factor
827 |         width = width or self.default_sample_size * self.vae_scale_factor
828 |
829 |         # Get grid parameters based on number of prompts
830 |         size = len(Divide_prompt_list)
831 |         Divide_m_offset_list, Divide_n_offset_list, Divide_m_scale_list, Divide_n_scale_list = get_grid_params(size)
832 |
833 |         # Convert to latent space coordinates
834 |         Divide_m_offset_list = [int(m_offset * width // 16) for m_offset in Divide_m_offset_list]
835 |         Divide_n_offset_list = [int(n_offset * height // 16) for n_offset in Divide_n_offset_list]
836 |         Divide_m_scale_list = [int(m_scale * width // 16) for m_scale in Divide_m_scale_list]
837 |         Divide_n_scale_list = [int(n_scale * height // 16) for n_scale in Divide_n_scale_list]
838 |
839 | # 1. Check inputs. Raise error if not correct
840 | self.check_inputs(
841 | prompt,
842 | prompt_2,
843 | height,
844 | width,
845 | prompt_embeds=prompt_embeds,
846 | pooled_prompt_embeds=pooled_prompt_embeds,
847 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
848 | max_sequence_length=max_sequence_length,
849 | )
850 |
851 | self._guidance_scale = guidance_scale
852 | self._joint_attention_kwargs = joint_attention_kwargs
853 | self._interrupt = False
854 |
855 | # 2. Define call parameters
856 | if prompt is not None and isinstance(prompt, str):
857 | batch_size = 1
858 | elif prompt is not None and isinstance(prompt, list):
859 | batch_size = len(prompt)
860 | else:
861 | batch_size = prompt_embeds.shape[0]
862 |
863 | device = self._execution_device
864 |
865 | lora_scale = (
866 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
867 | )
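    |         # 3. Encode the global prompt; the per-panel Divide prompts are encoded separately below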
868 | (
869 | prompt_embeds,
870 | pooled_prompt_embeds,
871 | text_ids,
872 | ) = self.encode_prompt(
873 | prompt=prompt,
874 | prompt_2=prompt_2,
875 | prompt_embeds=prompt_embeds,
876 | pooled_prompt_embeds=pooled_prompt_embeds,
877 | device=device,
878 | num_images_per_prompt=num_images_per_prompt,
879 | max_sequence_length=max_sequence_length,
880 | lora_scale=lora_scale,
881 | )
882 |
883 | (
884 | Divide_prompt_embeds_list,
885 | Divide_pooled_prompt_embeds_list,
886 | Divide_text_ids_list,
887 | ) = self.Divide_encode_prompt(
888 | Divide_prompt_list=Divide_prompt_list,
889 | Redux_list=Redux_list,
890 | device=device,
891 | num_images_per_prompt=num_images_per_prompt,
892 | max_sequence_length=max_sequence_length,
893 | lora_scale=lora_scale,
894 | )
895 |
896 |
897 |
898 |
899 | # 4. Prepare latent variables
900 | num_channels_latents = self.transformer.config.in_channels // 4
901 | latents, latent_image_ids = self.prepare_latents(
902 | batch_size * num_images_per_prompt,
903 | num_channels_latents,
904 | height,
905 | width,
906 | prompt_embeds.dtype,
907 | device,
908 | generator,
909 | latents,
910 | )
911 |
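    |         # Per-panel latents for the Divide stage, one tensor per grid cell, sized by the Divide scale lists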
912 | Divide_latents_list, Divide_latent_image_ids_list = self.prepare_Divide_latents(
913 | Divide_m_scale_list,
914 | Divide_n_scale_list,
915 | batch_size * num_images_per_prompt,
916 | num_channels_latents,
917 | prompt_embeds.dtype,
918 | device,
919 | generator
920 | )
921 |
922 |
923 |
924 |
925 | # 5. Prepare timesteps
926 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
927 | image_seq_len = latents.shape[1]
928 | mu = calculate_shift(
929 | image_seq_len,
930 | self.scheduler.config.base_image_seq_len,
931 | self.scheduler.config.max_image_seq_len,
932 | self.scheduler.config.base_shift,
933 | self.scheduler.config.max_shift,
934 | )
935 | timesteps, num_inference_steps = retrieve_timesteps(
936 | self.scheduler,
937 | num_inference_steps,
938 | device,
939 | timesteps,
940 | sigmas,
941 | mu=mu,
942 | )
943 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
944 | self._num_timesteps = len(timesteps)
945 |
946 | # handle guidance
947 | if self.transformer.config.guidance_embeds:
948 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
949 | guidance = guidance.expand(latents.shape[0])
950 | else:
951 | guidance = None
952 |
953 | # 6. Denoising loop
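    |         # Precompute, per timestep, the panel latents and transformer hidden states injected during the first `Divide_replace` steps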
954 | Divide_latents_list_list, Divide_hidden_states_list_list_list = self.prepare_Divide_replace(Divide_latents_list, timesteps, Divide_replace, latents, Divide_prompt_embeds_list, Divide_pooled_prompt_embeds_list, Divide_text_ids_list, Divide_latent_image_ids_list, guidance, Divide_m_scale_list, Divide_n_scale_list)
955 |
956 |
957 |
958 |
959 | # hook_forwards(self, self.transformer)
960 |
961 | self.scheduler._init_step_index(timesteps[0])
962 | with self.progress_bar(total=num_inference_steps) as progress_bar:
963 | for i, t in enumerate(timesteps):
964 | if self.interrupt:
965 | continue
966 |
967 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
968 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
969 |
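    |                 # Two-phase denoising: during the first `Divide_replace` steps, panel regions of the canvas
    |                 # latents are overwritten with the precomputed per-panel latents and the transformer receives
    |                 # the per-panel hidden states; afterwards the standard Flux forward refines the composed canvas.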
970 |                 if i <= Divide_replace:
971 | latents = self.Divide_replace_latents(latents, Divide_latents_list_list[i], Divide_m_offset_list, Divide_n_offset_list, height, width)
972 |
973 | self._joint_attention_kwargs = None
974 | if i < Divide_replace:
975 | noise_pred = self.transformer(
976 | hidden_states=latents,
977 | timestep=timestep / 1000,
978 | guidance=guidance,
979 | pooled_projections=pooled_prompt_embeds,
980 | encoder_hidden_states=prompt_embeds,
981 | txt_ids=text_ids,
982 | img_ids=latent_image_ids,
983 | joint_attention_kwargs=self.joint_attention_kwargs,
984 | return_dict=False,
985 | Divide_hidden_states_list_list=Divide_hidden_states_list_list_list[i],
986 | Divide_m_offset_list=Divide_m_offset_list,
987 | Divide_n_offset_list=Divide_n_offset_list,
988 | Divide_m_scale_list=Divide_m_scale_list,
989 | Divide_n_scale_list=Divide_n_scale_list,
990 | latent_h=height//16,
991 | latent_w=width//16
992 | )[0]
993 |
994 | if i >= Divide_replace:
995 | # Release memory of Divide_latents_list and Divide_hidden_states_list_list_list
996 |
997 | if 'Divide_latents_list' in locals():
998 | del Divide_latents_list
999 | if 'Divide_hidden_states_list_list_list' in locals():
1000 | del Divide_hidden_states_list_list_list
1001 | torch.cuda.empty_cache()
1002 | init_forwards(self, self.transformer)
1003 | noise_pred = self.transformer(
1004 | hidden_states=latents,
1005 | timestep=timestep / 1000,
1006 | guidance=guidance,
1007 | pooled_projections=pooled_prompt_embeds,
1008 | encoder_hidden_states=prompt_embeds,
1009 | txt_ids=text_ids,
1010 | img_ids=latent_image_ids,
1011 |                         # after the Divide stage, run the default forward without custom attention kwargs
1012 | joint_attention_kwargs=None,
1013 | # joint_attention_kwargs=self.joint_attention_kwargs,
1014 | return_dict=False,
1015 | )[0]
1016 |
1017 |
1018 |
1019 | # compute the previous noisy sample x_t -> x_t-1
1020 | latents_dtype = latents.dtype
1021 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1022 |
1023 | if latents.dtype != latents_dtype:
1024 | if torch.backends.mps.is_available():
1025 |                     # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1026 | latents = latents.to(latents_dtype)
1027 |
1028 |
1029 |
1030 | if callback_on_step_end is not None:
1031 | callback_kwargs = {}
1032 | for k in callback_on_step_end_tensor_inputs:
1033 | callback_kwargs[k] = locals()[k]
1034 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1035 |
1036 | latents = callback_outputs.pop("latents", latents)
1037 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1038 |
1039 | # call the callback, if provided
1040 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1041 | progress_bar.update()
1042 |
1043 | if XLA_AVAILABLE:
1044 | xm.mark_step()
1045 |
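    |         # Decode the final latents to images unless raw latents were requested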
1046 | if output_type == "latent":
1047 | image = latents
1048 |
1049 | else:
1050 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1051 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1052 | image = self.vae.decode(latents, return_dict=False)[0]
1053 | image = self.image_processor.postprocess(image, output_type=output_type)
1054 |
1055 |
1056 |
1057 | # Offload all models
1058 | self.maybe_free_model_hooks()
1059 |
1060 | if not return_dict:
1061 | return (image,)
1062 |
1063 | return FluxPipelineOutput(images=image)
1064 |
--------------------------------------------------------------------------------