├── train_flux
    ├── flux
    │   ├── __init__.py
    │   ├── pipeline_tools.py
    │   ├── lora_controller.py
    │   ├── condition.py
    │   ├── transformer.py
    │   ├── generate.py
    │   └── block.py
    ├── train
    │   ├── __init__.py
    │   ├── callbacks.py
    │   ├── train.py
    │   ├── data.py
    │   └── model.py
    ├── train.sh
    ├── sample.sh
    ├── config_prompt.yaml
    ├── config.yaml
    └── sample.py
├── tts
    ├── verifiers
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── nvila_verifier.cpython-310.pyc
    │   │   └── openai_verifier.cpython-310.pyc
    │   ├── nvila_verifier.py
    │   ├── reflexion_prompt.txt
    │   ├── refine_prompt.txt
    │   ├── verifier_prompt.txt
    │   ├── geneval_verifier_prompt.txt
    │   └── geneval_detailed_verifier_prompt.json
    ├── configs
    │   ├── our_reflectionmodel.yaml
    │   ├── flux.1_dev_nvilascore.json
    │   └── flux.1_dev_gptscore.json
    ├── tts_t2i_noise_scaling.py
    ├── utils.py
    ├── verifier_filter.py
    └── tts_t2i_noise_prompt_scaling.py
├── examples
    └── teaser.jpg
├── reward_modeling
    ├── __pycache__
    │   ├── data.cpython-310.pyc
    │   ├── data.cpython-312.pyc
    │   ├── utils.cpython-310.pyc
    │   ├── utils.cpython-312.pyc
    │   ├── trainer.cpython-310.pyc
    │   ├── trainer.cpython-312.pyc
    │   ├── inference.cpython-310.pyc
    │   ├── test_reward.cpython-310.pyc
    │   ├── train_reward.cpython-310.pyc
    │   ├── prompt_template.cpython-310.pyc
    │   ├── prompt_template.cpython-312.pyc
    │   ├── vision_process.cpython-310.pyc
    │   └── vision_process.cpython-312.pyc
    ├── test_reward.py
    ├── prompt_template.py
    ├── data.py
    ├── inference.py
    ├── utils.py
    └── train_reward.py
├── requirements.txt
├── .gitignore
└── README.md
/train_flux/flux/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_flux/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tts/verifiers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/examples/teaser.jpg -------------------------------------------------------------------------------- /tts/configs/our_reflectionmodel.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: path/infer/30000 2 | template: qwen2_vl 3 | finetuning_type: lora -------------------------------------------------------------------------------- /reward_modeling/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/data.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/data.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/utils.cpython-310.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/trainer.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/trainer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/trainer.cpython-312.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/inference.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/inference.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/test_reward.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/test_reward.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/train_reward.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/train_reward.cpython-310.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/nvila_verifier.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/nvila_verifier.cpython-310.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/openai_verifier.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/openai_verifier.cpython-310.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | transformers 3 | peft 4 | opencv-python 5 | sentencepiece 6 | lightning 7 | datasets 8 | torchvision 9 | prodigyopt 10 | wandb 11 | webdataset 
-------------------------------------------------------------------------------- /reward_modeling/__pycache__/prompt_template.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/prompt_template.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/prompt_template.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/prompt_template.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/vision_process.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/vision_process.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/vision_process.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/vision_process.cpython-312.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.batch 3 | 4 | train_flux/run/ 5 | train_flux/wandb/ 6 | train_flux/train/__pycache__ 7 | train_flux/flux/__pycache__ 8 | 9 | tts/dcgm/ 10 | tts/output/ 11 | tts/__pycache__/ 12 | -------------------------------------------------------------------------------- /train_flux/train.sh: -------------------------------------------------------------------------------- 1 | conda init 2 | source ~/.bashrc 3 | conda activate /ceph/data-bk/miniconda3/envs/flux 4 | 5 | cd /ceph/data-bk/zl/ReflectionFlow/train_flux 6 | 7 | export TOKENIZERS_PARALLELISM=true 8 | accelerate launch --main_process_port 41353 -m train.train 9 | -------------------------------------------------------------------------------- /train_flux/sample.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | CUDA_VISIBLE_DEVICES=2 python sample.py \ 4 | --model_name flux \ 5 | --step 30 \ 6 | --condition_size 512 \ 7 | --target_size 1024 \ 8 | --task_name geneval \ 9 | --lora_dir /mnt/petrelfs/zhuole/ReflectionFlow/train_flux/runs/full_data_v4_cond_512/20250410-191141/ckpt/16000/pytorch_lora_weights.safetensors \ 10 | --output_dir /mnt/petrelfs/zhuole/ReflectionFlow/train_flux/samples/full_data_v4_512_16k \ 11 | --seed 0 \ 12 | -------------------------------------------------------------------------------- /tts/verifiers/nvila_verifier.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | 3 | # nvila verifier 4 | def load_model(model_name, cache_dir): 5 | print("loading NVILA model") 6 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="auto", cache_dir = cache_dir) 7 | yes_id = model.tokenizer.encode("yes", add_special_tokens=False)[0] 8 | no_id = model.tokenizer.encode("no", add_special_tokens=False)[0] 9 | print("loading NVILA finished") 10 | return model, yes_id, no_id 
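# Illustrative usage sketch (not part of the original file). It assumes the
# NVILA checkpoint named in tts/configs/flux.1_dev_nvilascore.json and mirrors
# how tts/verifier_filter.py calls the verifier: generate_content() returns a
# "yes"/"no" answer plus logits, and the probability mass on `yes_id` is read
# out as the image score. The image path below is a placeholder.
#
#   from PIL import Image
#   model, yes_id, no_id = load_model("Efficient-Large-Model/NVILA-Lite-2B-Verifier", cache_dir="models/nvila")
#   answer, scores = model.generate_content([Image.open("sample.png"), "a photo of a yellow bicycle and a red motorcycle"])
#   yes_score = scores[0][0, yes_id].detach().cpu().float().item()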
-------------------------------------------------------------------------------- /tts/verifiers/reflexion_prompt.txt: -------------------------------------------------------------------------------- 1 | You are an expert assistant for generating image improvement instructions. Analyze the original prompt, the updated prompt to generate the image, the evaluation of the generated image, and the generated image, give instructions to create specific technical directions following these guidelines:\n1. Structure and Focus Areas:\nFocus strictly on these three aspects in order: Prompt Following.\n2. Detailed Requirements for Each Aspect:\nA. Prompt Following Instructions:\nExamine the original prompt sentence by sentence. \nList exact discrepancies between the bad image and prompt specifications. \nUse direct action verbs: Add, Remove, Replace, Reposition, Adjust, to modify the image. \nSpecify precise locations and modification commands. \nNever use vague terms like ensure or confirm.\n3. Format Specifications:\nUse exact section headers without markdown:\n1. Prompt Following:\n-\n Each instruction must start with a hyphen and complete command. \nInclude spatial references and implementation details. \nOmit sections with no required improvements. \nNever include explanations or examples.\n\n4. Content Principles:\nEvery instruction must be directly executable by an artist. \nPrioritize critical errors first. \nDescribe only missing or incorrect elements. \nUse imperative verb forms exclusively. \nMaintain technical specificity without assumptions.\n -------------------------------------------------------------------------------- /tts/configs/flux.1_dev_nvilascore.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "pipeline_args": { 4 | "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", 5 | "cache_dir": "FLUX_PATH", 6 | "torch_dtype": "bf16", 7 | "height": 1024, 8 | "width": 1024, 9 | "condition_size": 512, 10 | "max_sequence_length": 512, 11 | "guidance_scale": 3.5, 12 | "num_inference_steps": 30, 13 | "lora_path": "LORA_PATH" 14 | }, 15 | "verifier_args": { 16 | "name": "nvila", 17 | "model_name": "Efficient-Large-Model/NVILA-Lite-2B-Verifier", 18 | "cache_dir": "models/nvila" 19 | }, 20 | "refine_args": { 21 | "name": "openai", 22 | "choice_of_metric": "overall_score", 23 | "max_new_tokens": 1280, 24 | "refine_prompt_relpath": "refine_prompt.txt", 25 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 26 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 27 | }, 28 | "search_args": { 29 | "search_method": "random", 30 | "search_branch": 2, 31 | "search_rounds": 16 32 | }, 33 | "model": { 34 | "add_cond_attn": false, 35 | "latent_lora": false, 36 | "union_cond_attn": true 37 | }, 38 | "reflection_args": { 39 | "run_reflection": true, 40 | "name": "openai" 41 | }, 42 | "prompt_refiner_args": { 43 | "run_refinement": true 44 | }, 45 | "use_low_gpu_vram": false, 46 | "batch_size_for_img_gen": 1 47 | } -------------------------------------------------------------------------------- /tts/verifiers/refine_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with refining user's input prompt to 3 | create images using a text-to-image model. 
Given a original prompt, a current prompt, 4 | a batch of images generated by the prompt, a reflection prompt about the generated images and their corresponding assessments 5 | evaluated by a multi-domain scoring system, your goal is to refine the current prompt to 6 | improve the overall quality of the generated images. You should analyze the strengths 7 | and drawbacks of current prompt based on the given images and their evaluations. Consider 8 | aspects like subject, scene, style, lighting, tone, mood, camera style, composition, and 9 | others to refine the current prompt. Do not alter the original description from 10 | the original prompt. The refined prompt should not contradict with the reflection prompt. Directly output the refined prompt without any other text. 11 | 12 | Some further instructions you should keep in mind: 13 | 14 | 1) The current prompt is an iterative refinement of the original prompt. 15 | 16 | 2) In case the original prompt and current prompt are the same, ignore the current prompt. 17 | 18 | 3) In some cases, some of the above-mentioned inputs may not be available. For example, the images, 19 | the assessments, etc. In such situations, you should still do your best, analyze the inputs carefully, 20 | and arrive at a refined prompt that would potentially lead to improvements in the final generated images. 21 | 22 | 4) When the evaluations are provided, please consider all aspects of the evaluations very carefully. 23 | """ -------------------------------------------------------------------------------- /tts/configs/flux.1_dev_gptscore.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "pipeline_args": { 4 | "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", 5 | "cache_dir": "FLUX_PATH", 6 | "torch_dtype": "bf16", 7 | "height": 1024, 8 | "width": 1024, 9 | "condition_size": 512, 10 | "max_sequence_length": 512, 11 | "guidance_scale": 3.5, 12 | "num_inference_steps": 30, 13 | "lora_path": "LORA_PATH" 14 | }, 15 | "verifier_args": { 16 | "name": "openai", 17 | "choice_of_metric": "overall_score", 18 | "max_new_tokens": 1280, 19 | "refine_prompt_relpath": "refine_prompt.txt", 20 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 21 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 22 | }, 23 | "refine_args": { 24 | "name": "openai", 25 | "choice_of_metric": "overall_score", 26 | "max_new_tokens": 1280, 27 | "refine_prompt_relpath": "refine_prompt.txt", 28 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 29 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 30 | }, 31 | "search_args": { 32 | "search_method": "random", 33 | "search_branch": 2, 34 | "search_rounds": 16 35 | }, 36 | "model": { 37 | "add_cond_attn": false, 38 | "latent_lora": false, 39 | "union_cond_attn": true 40 | }, 41 | "reflection_args": { 42 | "run_reflection": true, 43 | "name": "openai" 44 | }, 45 | "prompt_refiner_args": { 46 | "run_refinement": true 47 | }, 48 | "use_low_gpu_vram": false, 49 | "batch_size_for_img_gen": 1 50 | } -------------------------------------------------------------------------------- /train_flux/flux/pipeline_tools.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines import FluxPipeline 2 | from diffusers.utils import logging 3 | from diffusers.pipelines.flux.pipeline_flux import logger 4 | from torch import Tensor 5 | 6 | 7 | def encode_images(pipeline: FluxPipeline, images: Tensor): 8 | images = 
pipeline.image_processor.preprocess(images) 9 | images = images.to(pipeline.device).to(pipeline.dtype) 10 | images = pipeline.vae.encode(images).latent_dist.sample() 11 | images = ( 12 | images - pipeline.vae.config.shift_factor 13 | ) * pipeline.vae.config.scaling_factor 14 | images_tokens = pipeline._pack_latents(images, *images.shape) 15 | images_ids = pipeline._prepare_latent_image_ids( 16 | images.shape[0], 17 | images.shape[2], 18 | images.shape[3], 19 | pipeline.device, 20 | pipeline.dtype, 21 | ) 22 | if images_tokens.shape[1] != images_ids.shape[0]: 23 | images_ids = pipeline._prepare_latent_image_ids( 24 | images.shape[0], 25 | images.shape[2] // 2, 26 | images.shape[3] // 2, 27 | pipeline.device, 28 | pipeline.dtype, 29 | ) 30 | return images_tokens, images_ids 31 | 32 | 33 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512, prompts_2=None): 34 | # Turn off warnings (CLIP overflow) 35 | logger.setLevel(logging.ERROR) 36 | ( 37 | prompt_embeds, 38 | pooled_prompt_embeds, 39 | text_ids, 40 | ) = pipeline.encode_prompt( 41 | prompt=prompts, 42 | prompt_2=prompts_2, 43 | prompt_embeds=None, 44 | pooled_prompt_embeds=None, 45 | device=pipeline.device, 46 | num_images_per_prompt=1, 47 | max_sequence_length=max_sequence_length, 48 | lora_scale=None, 49 | ) 50 | # Turn on warnings 51 | logger.setLevel(logging.WARNING) 52 | return prompt_embeds, pooled_prompt_embeds, text_ids 53 | -------------------------------------------------------------------------------- /train_flux/config_prompt.yaml: -------------------------------------------------------------------------------- 1 | model_path: "black-forest-labs/FLUX.1-dev" 2 | dtype: "bfloat16" 3 | cache_dir: "CACHE_DIR" 4 | 5 | model: 6 | union_cond_attn: true 7 | add_cond_attn: false 8 | latent_lora: false 9 | 10 | train: 11 | batch_size: 8 12 | accumulate_grad_batches: 1 13 | dataloader_workers: 8 14 | save_interval: 2000 15 | sample_interval: 2000 16 | max_steps: -1 17 | gradient_checkpointing: true 18 | save_path: "./runs/prompt_cond_512" 19 | 20 | # Specify the type of condition to use. 21 | condition_type: "cot" 22 | resume_training_from_last_checkpoint: false 23 | resume_training_from_checkpoint_path: "" 24 | dataset: 25 | type: "img" 26 | path: "/ceph/hf-cache/hub/datasets--diffusion-cot--GenRef-wds/snapshots/42c837b891fc34a944ed0c8124876a7e8225266f/*.tar" 27 | split_ratios: { 28 | "general": [0.1, 0.3], 29 | "length": [0.1, 0.3], 30 | "rule": [0.1, 0.4], 31 | "editing": [0.7, 0.0] 32 | } 33 | training_stages: [0, 5000] 34 | root_dir: "" 35 | # val_path: { 36 | # "general": "VAL_TARS" 37 | # } 38 | # val_root_dir: "" 39 | condition_size: 512 40 | target_size: 1024 41 | drop_text_prob: 0.1 42 | drop_image_prob: 0.1 43 | drop_reflection_prob: 1.0 44 | 45 | wandb: 46 | project: "ReflectionFlow" 47 | name: "prompt_cond_512" 48 | 49 | lora_config: 50 | r: 32 51 | lora_alpha: 32 52 | init_lora_weights: "gaussian" 53 | target_modules: "(.*x_embedder|.*(? 
None: 7 | self.activated: bool = activated 8 | if activated: 9 | return 10 | self.lora_modules: List[BaseTunerLayer] = [ 11 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 12 | ] 13 | self.scales = [ 14 | { 15 | active_adapter: lora_module.scaling[active_adapter] 16 | for active_adapter in lora_module.active_adapters 17 | } 18 | for lora_module in self.lora_modules 19 | ] 20 | 21 | def __enter__(self) -> None: 22 | if self.activated: 23 | return 24 | 25 | for lora_module in self.lora_modules: 26 | if not isinstance(lora_module, BaseTunerLayer): 27 | continue 28 | lora_module.scale_layer(0) 29 | 30 | def __exit__( 31 | self, 32 | exc_type: Optional[Type[BaseException]], 33 | exc_val: Optional[BaseException], 34 | exc_tb: Optional[Any], 35 | ) -> None: 36 | if self.activated: 37 | return 38 | for i, lora_module in enumerate(self.lora_modules): 39 | if not isinstance(lora_module, BaseTunerLayer): 40 | continue 41 | for active_adapter in lora_module.active_adapters: 42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 43 | 44 | 45 | class set_lora_scale: 46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: 47 | self.lora_modules: List[BaseTunerLayer] = [ 48 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 49 | ] 50 | self.scales = [ 51 | { 52 | active_adapter: lora_module.scaling[active_adapter] 53 | for active_adapter in lora_module.active_adapters 54 | } 55 | for lora_module in self.lora_modules 56 | ] 57 | self.scale = scale 58 | 59 | def __enter__(self) -> None: 60 | for lora_module in self.lora_modules: 61 | if not isinstance(lora_module, BaseTunerLayer): 62 | continue 63 | lora_module.scale_layer(self.scale) 64 | 65 | def __exit__( 66 | self, 67 | exc_type: Optional[Type[BaseException]], 68 | exc_val: Optional[BaseException], 69 | exc_tb: Optional[Any], 70 | ) -> None: 71 | for i, lora_module in enumerate(self.lora_modules): 72 | if not isinstance(lora_module, BaseTunerLayer): 73 | continue 74 | for active_adapter in lora_module.active_adapters: 75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 76 | -------------------------------------------------------------------------------- /train_flux/train/callbacks.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from PIL import Image, ImageFilter, ImageDraw 3 | import numpy as np 4 | from transformers import pipeline 5 | import cv2 6 | import torch 7 | import os 8 | 9 | try: 10 | import wandb 11 | except ImportError: 12 | wandb = None 13 | 14 | import time 15 | 16 | 17 | class TrainingCallback(L.Callback): 18 | def __init__(self, run_name, training_config: dict = {}): 19 | self.run_name, self.training_config = run_name, training_config 20 | 21 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10) 22 | self.save_interval = training_config.get("save_interval", 1000) 23 | self.save_path = training_config.get("save_path", "./output") 24 | 25 | self.wandb_config = training_config.get("wandb", None) 26 | self.use_wandb = ( 27 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None 28 | ) 29 | 30 | self.total_steps = 0 31 | 32 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 33 | gradient_size = 0 34 | max_gradient_size = 0 35 | count = 0 36 | for _, param in pl_module.named_parameters(): 37 | if param.grad is not None: 38 | gradient_size += param.grad.norm(2).item() 39 | max_gradient_size = 
max(max_gradient_size, param.grad.norm(2).item()) 40 | count += 1 41 | if count > 0: 42 | gradient_size /= count 43 | 44 | self.total_steps += 1 45 | 46 | # Update split ratios 47 | trainer.train_dataloader.dataset._update_split_ratios() 48 | 49 | # Print training progress every n steps 50 | if self.use_wandb: 51 | report_dict = { 52 | "batch_idx": batch_idx, 53 | "steps": self.total_steps, 54 | "epoch": trainer.current_epoch, 55 | "gradient_size": gradient_size, 56 | } 57 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches 58 | report_dict["loss"] = loss_value 59 | report_dict["t"] = pl_module.last_t 60 | wandb.log(report_dict) 61 | 62 | if self.total_steps % self.print_every_n_steps == 0: 63 | print( 64 | f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}" 65 | ) 66 | 67 | # Save LoRA weights at specified intervals 68 | if self.total_steps % self.save_interval == 0: 69 | print( 70 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights" 71 | ) 72 | pl_module.save_lora( 73 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}" 74 | ) -------------------------------------------------------------------------------- /tts/verifiers/verifier_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with evaluating images 3 | generated by a text-to-image model. Your goal is to assess each generated 4 | image based on specific aspects and provide a detailed critique, along with 5 | a scoring system. The final output should be formatted as a JSON object 6 | containing individual scores for each aspect and an overall score. The keys 7 | in the JSON object should be: `accuracy_to_prompt`, `creativity_and_originality`, 8 | `visual_quality_and_realism`, `consistency_and_cohesion`, 9 | `emotional_or_thematic_resonance`, and `overall_score`. Below is a comprehensive 10 | guide to follow in your evaluation process: 11 | 12 | 1. Key Evaluation Aspects and Scoring Criteria: 13 | For each aspect, provide a score from 0 to 10, where 0 represents poor 14 | performance and 10 represents excellent performance. For each score, include 15 | a short explanation or justification (1-2 sentences) explaining why that 16 | score was given. The aspects to evaluate are as follows: 17 | 18 | a) Accuracy to Prompt 19 | Assess how well the image matches the description given in the prompt. 20 | Consider whether all requested elements are present and if the scene, 21 | objects, and setting align accurately with the text. Score: 0 (no 22 | alignment) to 10 (perfect match to prompt). 23 | 24 | b) Creativity and Originality 25 | Evaluate the uniqueness and creativity of the generated image. Does the 26 | model present an imaginative or aesthetically engaging interpretation of the 27 | prompt? Is there any evidence of creativity beyond a literal interpretation? 28 | Score: 0 (lacks creativity) to 10 (highly creative and original). 29 | 30 | c) Visual Quality and Realism 31 | Assess the overall visual quality, including resolution, detail, and realism. 32 | Look for coherence in lighting, shading, and perspective. Even if the image 33 | is stylized or abstract, judge whether the visual elements are well-rendered 34 | and visually appealing. Score: 0 (poor quality) to 10 (high-quality and 35 | realistic).
36 | 37 | d) Consistency and Cohesion 38 | Check for internal consistency within the image. Are all elements cohesive 39 | and aligned with the prompt? For instance, does the perspective make sense, 40 | and do objects fit naturally within the scene without visual anomalies? 41 | Score: 0 (inconsistent) to 10 (fully cohesive and consistent). 42 | 43 | e) Emotional or Thematic Resonance 44 | Evaluate how well the image evokes the intended emotional or thematic tone of 45 | the prompt. For example, if the prompt is meant to be serene, does the image 46 | convey calmness? If it’s adventurous, does it evoke excitement? Score: 0 47 | (no resonance) to 10 (strong resonance with the prompt’s theme). 48 | 49 | 2. Overall Score 50 | After scoring each aspect individually, provide an overall score, 51 | representing the model’s general performance on this image. This should be 52 | a weighted average based on the importance of each aspect to the prompt or an 53 | average of all aspects. 54 | """ -------------------------------------------------------------------------------- /tts/verifiers/geneval_verifier_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with evaluating images 3 | generated by a text-to-image model. Your goal is to assess each generated 4 | image based on specific aspects and provide a detailed critique, along with 5 | a scoring system. The final output should be formatted as a JSON object 6 | containing individual scores for each aspect and an overall score. The keys 7 | in the JSON object should be: `single_object`, `two_object`, `counting`, 8 | `colors`, `position`, `color_attr`, and `overall_score`. Below is a comprehensive 9 | guide to follow in your evaluation process: 10 | 11 | 1. Key Evaluation Aspects and Scoring Criteria: 12 | For each aspect, provide a score from 0 to 10, where 0 represents poor 13 | performance and 10 represents excellent performance. For each score, include 14 | a short explanation or justification (1-2 sentences) explaining why that 15 | score was given. The aspects to evaluate are as follows: 16 | 17 | a) Single Object 18 | Assess the completeness and detectability of individual objects. Consider 19 | structural integrity, absence of deformations, and clear visibility against 20 | the background. Score: 0 (unrecognizable fragment) to 10 (perfectly defined 21 | and isolated object). 22 | 23 | b) Two Objects 24 | Evaluate separation quality and dual integrity. Check for clear boundaries 25 | between objects, appropriate spatial relationship, and preservation of 26 | distinct features for both elements. Score: 0 (merged/blended objects) to 27 | 10 (perfectly separated with individual clarity). 28 | 29 | c) Counting 30 | Verify accurate quantity representation and distinctiveness. Assess whether 31 | the correct number of objects are present, clearly distinguishable without 32 | overlap or occlusion. Score: 0 (critical counting errors) to 10 (exact count 33 | with unambiguous instances). 34 | 35 | d) Colors 36 | Evaluate color fidelity and contrast. Check if object colors match the prompt 37 | specifications and maintain consistency across multiple instances, with 38 | sufficient contrast against background. Score: 0 (color pollution/confusion) 39 | to 10 (precise color matching with high contrast). 40 | 41 | e) Position 42 | Analyze spatial relationships and occlusion handling. 
Verify if positional 43 | descriptors (left/right, above/below) are accurately represented, even with 44 | partial overlaps. Score: 0 (contradictory positioning) to 10 (pixel-perfect 45 | spatial alignment). 46 | 47 | f) Color-Attribute Binding 48 | Assess color-object association accuracy. Check if specific colors are 49 | correctly applied to designated objects and maintain differentiation between 50 | distinct elements. Score: 0 (color-attribute mismatch) to 10 (perfect 51 | color-object binding). 52 | 53 | 2. Overall Score 54 | After scoring each aspect individually, provide an overall score calculated 55 | as: 56 | (0.3×Single Object) + (0.2×Two Objects) + (0.15×Counting) + (0.15×Colors) 57 | + (0.1×Position) + (0.1×Color-Attribute Binding). Include a brief rationale 58 | for the weighting approach. 59 | """ -------------------------------------------------------------------------------- /train_flux/flux/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union, List, Tuple 3 | from diffusers.pipelines import FluxPipeline 4 | from PIL import Image, ImageFilter 5 | import numpy as np 6 | import cv2 7 | 8 | from .pipeline_tools import encode_images 9 | 10 | condition_dict = { 11 | "depth": 0, 12 | "canny": 1, 13 | "subject": 4, 14 | "coloring": 6, 15 | "deblurring": 7, 16 | "depth_pred": 8, 17 | "fill": 9, 18 | "sr": 10, 19 | "cartoon": 11, 20 | "cot": 12, 21 | } 22 | 23 | 24 | class Condition(object): 25 | def __init__( 26 | self, 27 | condition_type: str, 28 | raw_img: Union[Image.Image, torch.Tensor] = None, 29 | condition: Union[Image.Image, torch.Tensor] = None, 30 | mask=None, 31 | position_delta=None, 32 | ) -> None: 33 | self.condition_type = condition_type 34 | assert raw_img is not None or condition is not None 35 | if raw_img is not None: 36 | self.condition = self.get_condition(condition_type, raw_img) 37 | else: 38 | self.condition = condition 39 | self.position_delta = position_delta 40 | # TODO: Add mask support 41 | assert mask is None, "Mask not supported yet" 42 | 43 | def get_condition( 44 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] 45 | ) -> Union[Image.Image, torch.Tensor]: 46 | """ 47 | Returns the condition image. 48 | """ 49 | if condition_type == "depth": 50 | from transformers import pipeline 51 | 52 | depth_pipe = pipeline( 53 | task="depth-estimation", 54 | model="LiheYoung/depth-anything-small-hf", 55 | device="cuda", 56 | ) 57 | source_image = raw_img.convert("RGB") 58 | condition_img = depth_pipe(source_image)["depth"].convert("RGB") 59 | return condition_img 60 | elif condition_type == "canny": 61 | img = np.array(raw_img) 62 | edges = cv2.Canny(img, 100, 200) 63 | edges = Image.fromarray(edges).convert("RGB") 64 | return edges 65 | elif condition_type == "subject": 66 | return raw_img 67 | elif condition_type == "coloring": 68 | return raw_img.convert("L").convert("RGB") 69 | elif condition_type == "deblurring": 70 | condition_image = ( 71 | raw_img.convert("RGB") 72 | .filter(ImageFilter.GaussianBlur(10)) 73 | .convert("RGB") 74 | ) 75 | return condition_image 76 | elif condition_type == "fill": 77 | return raw_img.convert("RGB") 78 | elif condition_type == "cartoon": 79 | return raw_img.convert("RGB") 80 | return self.condition 81 | 82 | @property 83 | def type_id(self) -> int: 84 | """ 85 | Returns the type id of the condition. 
86 | """ 87 | return condition_dict[self.condition_type] 88 | 89 | @classmethod 90 | def get_type_id(cls, condition_type: str) -> int: 91 | """ 92 | Returns the type id of the condition. 93 | """ 94 | return condition_dict[condition_type] 95 | 96 | def encode( 97 | self, pipe: FluxPipeline, empty: bool = False 98 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 99 | """ 100 | Encodes the condition into tokens, ids and type_id. 101 | """ 102 | if self.condition_type in [ 103 | "depth", 104 | "canny", 105 | "subject", 106 | "coloring", 107 | "deblurring", 108 | "depth_pred", 109 | "fill", 110 | "sr", 111 | "cartoon", 112 | "cot", 113 | ]: 114 | if empty: 115 | # make the condition black 116 | e_condition = Image.new("RGB", self.condition.size, (0, 0, 0)) 117 | e_condition = e_condition.convert("RGB") 118 | tokens, ids = encode_images(pipe, e_condition) 119 | else: 120 | tokens, ids = encode_images(pipe, self.condition) 121 | tokens, ids = encode_images(pipe, self.condition) 122 | else: 123 | raise NotImplementedError( 124 | f"Condition type {self.condition_type} not implemented" 125 | ) 126 | if self.position_delta is None and self.condition_type == "subject": 127 | self.position_delta = [0, -self.condition.size[0] // 16] 128 | if self.position_delta is not None: 129 | ids[:, 1] += self.position_delta[0] 130 | ids[:, 2] += self.position_delta[1] 131 | type_id = torch.ones_like(ids[:, :1]) * self.type_id 132 | return tokens, ids, type_id 133 | -------------------------------------------------------------------------------- /tts/tts_t2i_noise_scaling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import numpy as np 5 | import torch 6 | from diffusers import DiffusionPipeline 7 | from tqdm.auto import tqdm 8 | import copy 9 | from PIL import Image 10 | 11 | from utils import get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args 12 | 13 | # Non-configurable constants 14 | MAX_SEED = np.iinfo(np.int32).max # To generate random seeds 15 | 16 | def sample( 17 | noises: dict[int, torch.Tensor], 18 | prompts: list[str], 19 | search_round: int, 20 | pipe: DiffusionPipeline, 21 | config: dict, 22 | original_prompt: str, 23 | midimg_path: str, 24 | ) -> dict: 25 | """ 26 | For a given prompt, generate images using all provided noises in batches, 27 | score them with the verifier, and select the top-K noise. 28 | The images and JSON artifacts are saved under `root_dir`. 29 | """ 30 | config_cp = copy.deepcopy(config) 31 | 32 | use_low_gpu_vram = config_cp.get("use_low_gpu_vram", False) 33 | batch_size_for_img_gen = config_cp.get("batch_size_for_img_gen", 1) 34 | 35 | images_for_prompt = [] 36 | noises_used = [] 37 | seeds_used = [] 38 | 39 | # Convert the noises dictionary into a list of (seed, noise) tuples. 40 | noise_items = list(noises.items()) 41 | 42 | # Process the noises in batches. 43 | full_imgnames = [] 44 | for i in range(0, len(noise_items), batch_size_for_img_gen): 45 | batch = noise_items[i : i + batch_size_for_img_gen] 46 | seeds_batch, noises_batch = zip(*batch) 47 | filenames_batch = [ 48 | os.path.join(midimg_path, f"{search_round}_round@{seed}.png") for seed in seeds_batch 49 | ] 50 | full_imgnames.extend(filenames_batch) 51 | 52 | if use_low_gpu_vram: 53 | pipe = pipe.to("cuda:0") 54 | print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.") 55 | 56 | # Create a batched prompt list and stack the latents. 
57 | batched_latents = torch.stack(noises_batch).squeeze(dim=1) 58 | batched_prompts = prompts[i : i + batch_size_for_img_gen] 59 | # breakpoint() 60 | batch_result = pipe(prompt=batched_prompts, latents=batched_latents, guidance_scale=config_cp["pipeline_args"]["guidance_scale"], num_inference_steps=config_cp["pipeline_args"]["num_inference_steps"], height=config_cp["pipeline_args"]["height"], width=config_cp["pipeline_args"]["width"]) 61 | batch_images = batch_result.images 62 | if use_low_gpu_vram : 63 | pipe = pipe.to("cpu") 64 | 65 | # Iterate over the batch and save the images. 66 | for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch): 67 | images_for_prompt.append(image) 68 | noises_used.append(noise) 69 | seeds_used.append(seed) 70 | image.save(filename) 71 | 72 | datapoint = { 73 | "prompt": original_prompt, 74 | "search_round": search_round, 75 | "num_noises": len(noises), 76 | } 77 | return datapoint 78 | 79 | @torch.no_grad() 80 | def main(): 81 | """ 82 | Main function: 83 | - Parses CLI arguments. 84 | - Creates an output directory based on verifier and current datetime. 85 | - Loads prompts. 86 | - Loads the image-generation pipeline. 87 | - Loads the verifier model. 88 | - Runs several search rounds where for each prompt a pool of random noises is generated, 89 | candidate images are produced and verified, and the best noise is chosen. 90 | """ 91 | args = parse_cli_args() 92 | os.environ["API_KEY"] = os.environ["OPENAI_API_KEY"] # args.openai_api_key 93 | 94 | # Build a config dictionary for parameters that need to be passed around. 95 | with open(args.pipeline_config_path, "r") as f: 96 | config = json.load(f) 97 | 98 | config.update(vars(args)) 99 | 100 | search_rounds = config["search_args"]["search_rounds"] 101 | search_branch = config["search_args"]["search_branch"] 102 | 103 | # Create a root output directory: output/{verifier_to_use}/{current_datetime} 104 | pipeline_name = config["pipeline_args"].get("pretrained_model_name_or_path") 105 | cache_dir = config["pipeline_args"]["cache_dir"] 106 | root_dir = config["output_dir"] 107 | os.makedirs(root_dir, exist_ok=True) 108 | 109 | # Set up the image-generation pipeline (on the first GPU if available). 110 | torch_dtype = TORCH_DTYPE_MAP[config["pipeline_args"].get("torch_dtype")] 111 | pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype, cache_dir=cache_dir) 112 | if not config["use_low_gpu_vram"]: 113 | pipe = pipe.to("cuda:0") 114 | pipe.set_progress_bar_config(disable=True) 115 | 116 | # Main loop: For each search round and each prompt, generate images, verify, and save artifacts. 
117 | with open(args.meta_path) as fp: 118 | metadatas = [json.loads(line) for line in fp] 119 | 120 | # meta splits 121 | if args.end_index == -1: 122 | metadatas = metadatas[args.start_index:] 123 | else: 124 | metadatas = metadatas[args.start_index:args.end_index] 125 | 126 | for index, metadata in tqdm(enumerate(metadatas), desc="Sampling prompts"): 127 | original_prompt = metadata['prompt'] 128 | current_prompts = [original_prompt] * search_branch 129 | # create output directory 130 | outpath = os.path.join(root_dir, f"{index + args.start_index:0>5}") 131 | os.makedirs(outpath, exist_ok=True) 132 | 133 | # create middle img directory 134 | midimg_path = os.path.join(outpath, "samples") 135 | os.makedirs(midimg_path, exist_ok=True) 136 | 137 | # create metadata file 138 | with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp: 139 | json.dump(metadata, fp) 140 | 141 | for round in range(1, search_rounds + 1): 142 | print(f"\n=== Round: {round} ===") 143 | noises = get_noises( 144 | max_seed=MAX_SEED, 145 | num_samples=search_branch, 146 | height=config["pipeline_args"]["height"], 147 | width=config["pipeline_args"]["width"], 148 | dtype=torch_dtype, 149 | fn=get_latent_prep_fn(pipeline_name), 150 | ) 151 | datapoint = sample( 152 | noises=noises, 153 | prompts=current_prompts, 154 | search_round=round, 155 | pipe=pipe, 156 | config=config, 157 | original_prompt=original_prompt, 158 | midimg_path=midimg_path, 159 | ) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /train_flux/train/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import lightning as L 4 | import yaml 5 | import os 6 | import time 7 | 8 | from diffusers.utils.logging import set_verbosity_error 9 | set_verbosity_error() 10 | 11 | from .data import ImageConditionWebDataset 12 | from .model import OminiModel 13 | from .callbacks import TrainingCallback 14 | 15 | def get_rank(): 16 | try: 17 | rank = int(os.environ.get("LOCAL_RANK")) 18 | except: 19 | rank = 0 20 | return rank 21 | 22 | def get_config(): 23 | config_path = os.environ.get("XFL_CONFIG") 24 | assert config_path is not None, "Please set the XFL_CONFIG environment variable" 25 | with open(config_path, "r") as f: 26 | config = yaml.safe_load(f) 27 | return config 28 | 29 | 30 | def init_wandb(wandb_config, run_name): 31 | import wandb 32 | 33 | try: 34 | assert os.environ.get("WANDB_API_KEY") is not None 35 | wandb.init( 36 | project=wandb_config["project"], 37 | name=wandb_config["name"] if wandb_config["name"] else run_name, 38 | config={}, 39 | settings=wandb.Settings(start_method="fork"), 40 | ) 41 | except Exception as e: 42 | print("Failed to initialize WanDB:", e) 43 | 44 | 45 | def main(): 46 | # Initialize 47 | is_main_process, rank = get_rank() == 0, get_rank() 48 | torch.cuda.set_device(rank) 49 | config = get_config() 50 | training_config = config["train"] 51 | run_name = time.strftime("%Y%m%d-%H%M%S") 52 | 53 | # Initialize WanDB 54 | wandb_config = training_config.get("wandb", None) 55 | if wandb_config is not None and is_main_process: 56 | init_wandb(wandb_config, run_name) 57 | 58 | print("Rank:", rank) 59 | if is_main_process: 60 | print("Config:", config) 61 | 62 | # Initialize dataset and dataloader 63 | if training_config["dataset"]["type"] == "img": 64 | dataset = ImageConditionWebDataset( 65 | training_config["dataset"]["path"], 66 | 
condition_size=training_config["dataset"]["condition_size"], 67 | target_size=training_config["dataset"]["target_size"], 68 | condition_type=training_config["condition_type"], 69 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 70 | drop_image_prob=training_config["dataset"]["drop_image_prob"], 71 | drop_reflection_prob=training_config["dataset"]["drop_reflection_prob"], 72 | root_dir=training_config["dataset"]["root_dir"], 73 | split_ratios=training_config["dataset"]["split_ratios"], 74 | training_stages=training_config["dataset"]["training_stages"], 75 | ) 76 | if "val_path" in training_config["dataset"]: 77 | val_dataset = ImageConditionWebDataset( 78 | training_config["dataset"]["val_path"], 79 | condition_size=training_config["dataset"]["condition_size"], 80 | target_size=training_config["dataset"]["target_size"], 81 | condition_type=training_config["condition_type"], 82 | root_dir=training_config["dataset"]["val_root_dir"], 83 | drop_text_prob=0, 84 | drop_image_prob=0, 85 | drop_reflection_prob=0, 86 | shuffle_buffer=0, 87 | ) 88 | else: 89 | val_dataset = None 90 | else: 91 | raise NotImplementedError 92 | 93 | train_loader = DataLoader( 94 | dataset, 95 | batch_size=training_config["batch_size"], 96 | num_workers=training_config["dataloader_workers"], 97 | ) 98 | if val_dataset is not None: 99 | val_loader = DataLoader( 100 | val_dataset, 101 | batch_size=1, 102 | shuffle=False, 103 | num_workers=0 104 | ) 105 | else: 106 | val_loader = None 107 | 108 | # Try add resume training 109 | lora_path = None 110 | 111 | if training_config['resume_training_from_last_checkpoint'] and os.path.exists(training_config['save_path']): 112 | # get latest directory in training_config['save_path'], ignore hidden files 113 | all_training_sessions = [d for d in os.listdir(training_config['save_path']) if not d.startswith('.')] 114 | all_training_sessions.sort(reverse=True) 115 | last_training_session = all_training_sessions[0] 116 | if os.path.exists(f"{training_config['save_path']}/{last_training_session}/ckpt"): 117 | ckpt_paths = [d for d in os.listdir(f"{training_config['save_path']}/{last_training_session}/ckpt") if not d.startswith('.')] 118 | ckpt_paths.sort(reverse=True) 119 | lora_path = f"{training_config['save_path']}/{last_training_session}/ckpt/{ckpt_paths[0]}" 120 | print(f"Resuming training from {lora_path}") 121 | else: 122 | print("No checkpoint found. Training without LoRA weights.") 123 | 124 | elif training_config['resume_training_from_checkpoint_path'] != "": 125 | _lora_path = training_config['resume_training_from_checkpoint_path'] 126 | # Check if the path exists 127 | if os.path.exists(_lora_path): 128 | lora_path = _lora_path 129 | print(f"Training with LoRA weights from {_lora_path}") 130 | else: 131 | print(f"Path {_lora_path} does not exist. 
Training without LoRA weights.") 132 | 133 | # Initialize model 134 | trainable_model = OminiModel( 135 | flux_pipe_id=config["model_path"], 136 | lora_path=lora_path, 137 | lora_config=training_config["lora_config"], 138 | data_config=training_config["dataset"], 139 | device=f"cuda:{rank}", 140 | dtype=getattr(torch, config["dtype"]), 141 | optimizer_config=training_config["optimizer"], 142 | model_config=config.get("model", {}), 143 | gradient_checkpointing=training_config.get("gradient_checkpointing", False), 144 | save_path=training_config.get("save_path", "./output"), 145 | run_name=run_name, 146 | cache_dir=config["cache_dir"], 147 | ) 148 | 149 | # Callbacks for logging and saving checkpoints 150 | training_callbacks = ( 151 | [TrainingCallback(run_name, training_config=training_config)] 152 | if is_main_process 153 | else [] 154 | ) 155 | 156 | # Initialize trainer 157 | trainer = L.Trainer( 158 | accumulate_grad_batches=training_config["accumulate_grad_batches"], 159 | callbacks=training_callbacks, 160 | enable_checkpointing=False, 161 | enable_progress_bar=False, 162 | logger=False, 163 | max_steps=training_config.get("max_steps", -1), 164 | max_epochs=training_config.get("max_epochs", -1), 165 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5), 166 | val_check_interval=training_config.get("sample_interval", 1000), 167 | num_sanity_val_steps=training_config.get("num_sanity_val_steps", -1), 168 | ) 169 | 170 | setattr(trainer, "training_config", training_config) 171 | 172 | # Save config 173 | save_path = training_config.get("save_path", "./output") 174 | if is_main_process: 175 | os.makedirs(f"{save_path}/{run_name}") 176 | os.makedirs(f"{save_path}/{run_name}/val") 177 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f: 178 | yaml.dump(config, f) 179 | 180 | # Start training 181 | trainer.fit(trainable_model, train_loader, val_loader) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /reward_modeling/test_reward.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import pdb 5 | import argparse 6 | from collections.abc import Mapping 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import sys 10 | sys.path.append('ReflectionFlow') 11 | sys.path.append('reward_modeling') 12 | 13 | import torch 14 | from vision_process import process_vision_info 15 | 16 | 17 | from data import DataConfig 18 | from reward_modeling.utils import ModelConfig, PEFTLoraConfig, TrainingConfig 19 | from reward_modeling.utils import load_model_from_checkpoint 20 | from train_reward import create_model_and_processor 21 | from prompt_template import build_prompt 22 | 23 | 24 | def load_configs_from_json(config_path): 25 | with open(config_path, "r") as f: 26 | config_dict = json.load(f) 27 | 28 | del config_dict["data_config"]["meta_data"] 29 | del config_dict["data_config"]["data_dir"] 30 | 31 | return config_dict["data_config"], None, config_dict["model_config"], config_dict["peft_lora_config"], \ 32 | config_dict["inference_config"] if "inference_config" in config_dict else None 33 | 34 | 35 | class ImageVLMRewardInference(): 36 | def __init__(self, load_from_pretrained, load_from_pretrained_step=-1, device='cuda', dtype=torch.bfloat16): 37 | config_path = os.path.join(load_from_pretrained, "model_config.json") 38 | data_config, _, model_config, peft_lora_config, inference_config = 
load_configs_from_json(config_path) 39 | data_config = DataConfig(**data_config) 40 | model_config = ModelConfig(**model_config) 41 | peft_lora_config = PEFTLoraConfig(**peft_lora_config) 42 | 43 | training_args = TrainingConfig( 44 | load_from_pretrained=load_from_pretrained, 45 | load_from_pretrained_step=load_from_pretrained_step, 46 | gradient_checkpointing=False, 47 | disable_flash_attn2=False, 48 | bf16=True if dtype == torch.bfloat16 else False, 49 | fp16=True if dtype == torch.float16 else False, 50 | output_dir="", 51 | ) 52 | 53 | model, processor, peft_config = create_model_and_processor( 54 | model_config=model_config, 55 | peft_lora_config=peft_lora_config, 56 | training_args=training_args, 57 | cache_dir="/ibex/user/zhaol0c/uniediting_continue/our_reward/initialreward" 58 | ) 59 | 60 | self.device = device 61 | 62 | model, checkpoint_step = load_model_from_checkpoint(model, load_from_pretrained, load_from_pretrained_step) 63 | model.eval() 64 | 65 | self.model = model 66 | self.processor = processor 67 | 68 | self.model.to(self.device) 69 | 70 | self.data_config = data_config 71 | 72 | self.inference_config = inference_config 73 | 74 | def _norm(self, reward): 75 | if self.inference_config is None: 76 | return reward 77 | else: 78 | reward['VQ'] = (reward['VQ'] - self.inference_config['VQ_mean']) / self.inference_config['VQ_std'] 79 | return reward 80 | 81 | def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'): 82 | assert padding_side in ['right', 'left'] 83 | if sequences.shape[1] >= max_len: 84 | return sequences, attention_mask 85 | 86 | pad_len = max_len - sequences.shape[1] 87 | padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0) 88 | 89 | sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id) 90 | attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0) 91 | 92 | return sequences_padded, attention_mask_padded 93 | 94 | def _prepare_input(self, data): 95 | if isinstance(data, Mapping): 96 | return type(data)({k: self._prepare_input(v) for k, v in data.items()}) 97 | elif isinstance(data, (tuple, list)): 98 | return type(data)(self._prepare_input(v) for v in data) 99 | elif isinstance(data, torch.Tensor): 100 | kwargs = {"device": self.device} 101 | return data.to(**kwargs) 102 | return data 103 | 104 | def _prepare_inputs(self, inputs): 105 | inputs = self._prepare_input(inputs) 106 | if len(inputs) == 0: 107 | raise ValueError 108 | return inputs 109 | 110 | def prepare_batch(self, image_paths, prompts, max_pixels=None): 111 | max_pixels = self.data_config.max_frame_pixels if max_pixels is None else max_pixels 112 | 113 | chat_data = [ 114 | [ 115 | { 116 | "role": "user", 117 | "content": [ 118 | { 119 | "type": "image", 120 | "image": image_path, 121 | "max_pixels": max_pixels, 122 | }, 123 | {"type": "text", "text": build_prompt(prompt, self.data_config.eval_dim, self.data_config.prompt_template_type)}, 124 | ], 125 | }, 126 | ] for image_path, prompt in zip(image_paths, prompts) 127 | ] 128 | image_inputs, video_inputs = process_vision_info(chat_data) 129 | 130 | batch = self.processor( 131 | text=self.processor.apply_chat_template(chat_data, tokenize=False, add_generation_prompt=True), 132 | images=image_inputs, 133 | videos=video_inputs, 134 | padding=True, 135 | return_tensors="pt", 136 | videos_kwargs={"do_rescale": True}, 137 | ) 138 | batch = self._prepare_inputs(batch) 139 | 140 | return batch 141 | 142 | def reward(self, 
image_paths, prompts, max_pixels=None, use_norm=True): 143 | batch = self.prepare_batch(image_paths, prompts, max_pixels) 144 | rewards = self.model( 145 | return_dict=True, 146 | **batch 147 | )["logits"] 148 | 149 | rewards = [{'VQ': reward[0].item()} for reward in rewards] 150 | for i in range(len(rewards)): 151 | if use_norm: 152 | rewards[i] = self._norm(rewards[i]) 153 | rewards[i]['Overall'] = rewards[i]['VQ'] 154 | 155 | return rewards 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser(description="Image Alignment Reward Inference") 160 | parser.add_argument("--load_from_pretrained", type=str, required=True, 161 | help="Path to pretrained model") 162 | parser.add_argument("--device", type=str, default="cuda", 163 | help="Device to run inference on") 164 | parser.add_argument("--ckpt_step", type=int, default=-1, 165 | help="Checkpoint step for processing") 166 | args = parser.parse_args() 167 | 168 | inferencer = ImageVLMRewardInference(args.load_from_pretrained, load_from_pretrained_step=args.ckpt_step, device=args.device, dtype=torch.bfloat16) 169 | 170 | # Manually specified image paths and prompts 171 | image_paths = ["/ibex/user/zhaol0c/uniediting_continue/nvilaverifier_exps/b2_d16_6000model/00548/samples_best/00001.png", "/ibex/user/zhaol0c/uniediting_continue/nvilaverifier_exps/b2_d16_6000model/00548/midimg/14_round@958442093.png"] 172 | prompts = ["a photo of a yellow bicycle and a red motorcycle", "a photo of a yellow bicycle and a red motorcycle"] 173 | 174 | # Score the images with the reward model 175 | with torch.no_grad(): 176 | rewards = inferencer.reward(image_paths, prompts, use_norm=True) 177 | # breakpoint() 178 | print(f"Rewards: {rewards}") -------------------------------------------------------------------------------- /tts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils.torch_utils import randn_tensor 3 | from diffusers import FluxPipeline 4 | import re 5 | import hashlib 6 | from typing import Dict 7 | import json 8 | from typing import Union 9 | from PIL import Image 10 | import requests 11 | import argparse 12 | import io 13 | 14 | 15 | TORCH_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 16 | MODEL_NAME_MAP = { 17 | "black-forest-labs/FLUX.1-dev": "flux.1-dev", 18 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart-sigma-1024-ms", 19 | "stabilityai/stable-diffusion-xl-base-1.0": "sdxl-base", 20 | "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd-v1.5", 21 | } 22 | 23 | 24 | def parse_cli_args(): 25 | """ 26 | Parse and return CLI arguments.
27 | """ 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--pipeline_config_path", 31 | type=str, 32 | default="configs/flux.1_dev_nvilascore.json", 33 | help="Pipeline configuration path that should include loading info and __call__() args and their values.", 34 | ) 35 | parser.add_argument( 36 | "--start_index", 37 | type=int, 38 | default=0, 39 | help="start index", 40 | ) 41 | parser.add_argument( 42 | "--end_index", 43 | type=int, 44 | default=-1, 45 | help="end index", 46 | ) 47 | parser.add_argument( 48 | "--imgpath", 49 | type=str, 50 | default="", 51 | help="path to generated images and their metadata", 52 | ) 53 | parser.add_argument( 54 | "--output_dir", 55 | type=str, 56 | default="output", 57 | help="output directory", 58 | ) 59 | parser.add_argument( 60 | "--meta_path", 61 | type=str, 62 | default="meta.jsonl", 63 | help="meta path", 64 | ) 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | # Adapted from Diffusers. 71 | def prepare_latents_for_flux( 72 | batch_size: int, 73 | height: int, 74 | width: int, 75 | generator: torch.Generator, 76 | device: str, 77 | dtype: torch.dtype, 78 | ) -> torch.Tensor: 79 | num_latent_channels = 16 80 | vae_scale_factor = 8 81 | 82 | height = 2 * (int(height) // (vae_scale_factor * 2)) 83 | width = 2 * (int(width) // (vae_scale_factor * 2)) 84 | shape = (batch_size, num_latent_channels, height, width) 85 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 86 | latents = FluxPipeline._pack_latents(latents, batch_size, num_latent_channels, height, width) 87 | return latents 88 | 89 | 90 | # Adapted from Diffusers. 91 | def prepare_latents( 92 | batch_size: int, height: int, width: int, generator: torch.Generator, device: str, dtype: torch.dtype 93 | ): 94 | num_channels_latents = 4 95 | vae_scale_factor = 8 96 | shape = ( 97 | batch_size, 98 | num_channels_latents, 99 | int(height) // vae_scale_factor, 100 | int(width) // vae_scale_factor, 101 | ) 102 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 103 | return latents 104 | 105 | def prepare_latents_for_sd3( 106 | batch_size: int, height: int, width: int, generator: torch.Generator, device: str, dtype: torch.dtype 107 | ): 108 | num_channels_latents = 16 109 | vae_scale_factor = 8 110 | shape = ( 111 | batch_size, 112 | num_channels_latents, 113 | int(height) // vae_scale_factor, 114 | int(width) // vae_scale_factor, 115 | ) 116 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 117 | return latents 118 | 119 | 120 | def get_latent_prep_fn(pretrained_model_name_or_path: str) -> callable: 121 | fn_map = { 122 | "black-forest-labs/FLUX.1-dev": prepare_latents_for_flux, 123 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": prepare_latents, 124 | "stabilityai/stable-diffusion-xl-base-1.0": prepare_latents, 125 | "stable-diffusion-v1-5/stable-diffusion-v1-5": prepare_latents, 126 | "stabilityai/stable-diffusion-3-medium-diffusers": prepare_latents_for_sd3, 127 | }[pretrained_model_name_or_path] 128 | return fn_map 129 | 130 | 131 | def get_noises( 132 | max_seed: int, 133 | num_samples: int, 134 | height: int, 135 | width: int, 136 | device="cuda", 137 | dtype: torch.dtype = torch.bfloat16, 138 | fn: callable = prepare_latents_for_flux, 139 | ) -> Dict[int, torch.Tensor]: 140 | seeds = torch.randint(0, high=max_seed, size=(num_samples,)) 141 | 142 | noises = {} 143 | for noise_seed in seeds: 144 | latents = fn( 145 | batch_size=1, 
146 | height=height, 147 | width=width, 148 | generator=torch.manual_seed(int(noise_seed)), 149 | device=device, 150 | dtype=dtype, 151 | ) 152 | noises.update({int(noise_seed): latents}) 153 | 154 | assert len(noises) == len(seeds) 155 | return noises 156 | 157 | def load_verifier_prompt(path: str): 158 | # Determine the file type from the extension 159 | if path.endswith(".txt"): 160 | # Handle plain-text prompt files 161 | with open(path, "r") as f: 162 | verifier_prompt = f.read().replace('"""', "") 163 | return verifier_prompt 164 | elif path.endswith(".json"): 165 | # Handle JSON prompt files 166 | with open(path, "r") as f: 167 | data = json.load(f) # load the JSON file 168 | return data 169 | else: 170 | # Raise an exception for unsupported file formats 171 | raise ValueError("Unsupported file type. Please provide a .txt or .json file.") 172 | 173 | 174 | def prompt_to_filename(prompt, max_length=100): 175 | """Thanks ChatGPT.""" 176 | filename = re.sub(r"[^a-zA-Z0-9]", "_", prompt.strip()) 177 | filename = re.sub(r"_+", "_", filename) 178 | hash_digest = hashlib.sha256(prompt.encode()).hexdigest()[:8] 179 | base_filename = f"prompt@{filename}_hash@{hash_digest}" 180 | 181 | if len(base_filename) > max_length: 182 | base_length = max_length - len(hash_digest) - 7 183 | base_filename = f"prompt@{filename[:base_length]}_hash@{hash_digest}" 184 | 185 | return base_filename 186 | 187 | 188 | def load_image(path_or_url: Union[str, Image.Image]) -> Image.Image: 189 | """ 190 | Load an image from a local path or a URL and return a PIL Image object. 191 | 192 | `path_or_url` is returned as is if it's an `Image` already. 193 | """ 194 | if isinstance(path_or_url, Image.Image): 195 | return path_or_url 196 | elif path_or_url.startswith("http"): 197 | response = requests.get(path_or_url, stream=True) 198 | response.raise_for_status() 199 | return Image.open(io.BytesIO(response.content)) 200 | return Image.open(path_or_url) 201 | 202 | 203 | def convert_to_bytes(path_or_url: Union[str, Image.Image]) -> bytes: 204 | """Load an image from a path or URL and convert it to bytes.""" 205 | image = load_image(path_or_url).convert("RGB") 206 | image_bytes_io = io.BytesIO() 207 | image.save(image_bytes_io, format="PNG") 208 | return image_bytes_io.getvalue() 209 | 210 | 211 | def recover_json_from_output(output: str): 212 | start = output.find("{") 213 | end = output.rfind("}") + 1 214 | json_part = output[start:end] 215 | return json.loads(json_part) 216 | 217 | 218 | def get_batches(items, batch_size): 219 | num_batches = (len(items) + batch_size - 1) // batch_size 220 | batches = [] 221 | 222 | for i in range(num_batches): 223 | start_index = i * batch_size 224 | end_index = min((i + 1) * batch_size, len(items)) 225 | batch = items[start_index:end_index] 226 | batches.append(batch) 227 | 228 | return batches 229 | -------------------------------------------------------------------------------- /tts/verifier_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from tqdm.auto import tqdm 9 | import copy 10 | from typing import Union, List, Optional 11 | from PIL import Image 12 | import sys 13 | sys.path.append('../') 14 | from train_flux.flux.generate import generate 15 | from train_flux.flux.condition import Condition 16 | import time 17 | from verifiers.openai_verifier import OpenAIVerifier 18 | from verifiers.nvila_verifier import load_model 19 | 20 | from utils import prompt_to_filename, get_noises, TORCH_DTYPE_MAP, 
get_latent_prep_fn, parse_cli_args, MODEL_NAME_MAP 21 | 22 | # Non-configurable constants 23 | MAX_SEED = np.iinfo(np.int32).max # To generate random seeds 24 | MAX_RETRIES = 5 25 | RETRY_DELAY = 2 26 | 27 | @torch.no_grad() 28 | def main(): 29 | """ 30 | Main function: 31 | - Parses CLI arguments and loads the pipeline JSON config. 32 | - Loads the NVILA verifier model for scoring. 33 | - Collects the candidate images and their metadata from --imgpath. 34 | - Scores every candidate image against its prompt with the verifier. 35 | - Ranks the candidates and, for NFE budgets of 1, 2, 4, 8, 16 and 32, 36 | saves the top-ranked image of each budget into the corresponding 37 | nfe{N} folder next to the candidate images. 38 | """ 39 | args = parse_cli_args() 40 | 41 | # Build a config dictionary for parameters that need to be passed around. 42 | with open(args.pipeline_config_path, "r") as f: 43 | config = json.load(f) 44 | 45 | config.update(vars(args)) 46 | 47 | ### load nvila verifier for scoring 48 | verifier_args = config["verifier_args"] 49 | verifier, yes_id, no_id = load_model(model_name=verifier_args["model_name"], cache_dir=verifier_args["cache_dir"]) 50 | 51 | metadatas = [] 52 | for folder_name in sorted(os.listdir(args.imgpath)): 53 | folder_path = os.path.join(args.imgpath, folder_name) 54 | 55 | if os.path.isdir(folder_path): 56 | metadata_path = os.path.join(folder_path, 'metadata.jsonl') 57 | midimg_path = os.path.join(folder_path, 'midimg') 58 | 59 | with open(metadata_path, "r") as f: 60 | metadata = [json.loads(line) for line in f] 61 | folder_data = { 62 | 'metadata': metadata, 63 | 'images': [] 64 | } 65 | 66 | round_images = {} 67 | for file in sorted(os.listdir(midimg_path)): 68 | if file.endswith('.png'): 69 | round_key = file.split('_round@')[0] 70 | if round_key not in round_images: 71 | round_images[round_key] = [] 72 | round_images[round_key].append(file) 73 | 74 | # breakpoint() 75 | round_images = dict(sorted(round_images.items(), key=lambda x: int(x[0]))) 76 | for round_key, files in round_images.items(): 77 | for file in files: 78 | img_path = os.path.join(midimg_path, file) 79 | folder_data['images'].append({'img_path': img_path}) 80 | 81 | metadatas.append(folder_data) 82 | 83 | # meta splits 84 | if args.end_index == -1: 85 | metadatas = metadatas[args.start_index:] 86 | else: 87 | metadatas = metadatas[args.start_index:args.end_index] 88 | 89 | for index, metadata in tqdm(enumerate(metadatas), desc="Sampling data"): 90 | prompt = metadata['metadata'][0]['prompt'] 91 | imgs = metadata['images'] 92 | imgs = [tmp['img_path'] for tmp in imgs] 93 | cur_dir = os.path.dirname(imgs[0]) 94 | cur_dir = os.path.dirname(cur_dir) 95 | nfe1_path = os.path.join(cur_dir, "nfe1") 96 | nfe2_path = os.path.join(cur_dir, "nfe2") 97 | nfe4_path = os.path.join(cur_dir, "nfe4") 98 | nfe8_path = os.path.join(cur_dir, "nfe8") 99 | nfe16_path = os.path.join(cur_dir, "nfe16") 100 | nfe32_path = os.path.join(cur_dir, "nfe32") 101 | os.makedirs(nfe1_path, exist_ok=True) 102 | os.makedirs(nfe2_path, exist_ok=True) 103 | os.makedirs(nfe4_path, exist_ok=True) 104 | os.makedirs(nfe8_path, exist_ok=True) 105 | os.makedirs(nfe16_path, exist_ok=True) 106 | os.makedirs(nfe32_path, exist_ok=True) 107 | 108 | start_time = time.time() 109 | 110 | outputs = [] 111 | # nvila verifier 112 | for imgname in imgs: 113 | r1, scores1 = verifier.generate_content([Image.open(imgname), prompt]) 114 | if r1 == "yes": 115 | outputs.append({"image_name": imgname, "label": "yes", "score": scores1[0][0, yes_id].detach().cpu().float().item()}) 116 | else: 117 | outputs.append({"image_name": 
imgname, "label": "no", "score": scores1[0][0, no_id].detach().cpu().float().item()}) 118 | 119 | end_time = time.time() 120 | print(f"Time taken for evaluation: {end_time - start_time} seconds") 121 | 122 | # nvila verfier filter rule 123 | def f(x): 124 | if x["label"] == "yes": 125 | return (0, -x["score"]) 126 | else: 127 | return (1, x["score"]) 128 | 129 | # do nfe1 130 | sorted_list_nfe1 = sorted(outputs[:1], key=lambda x: f(x)) 131 | topk_scores_nfe1 = sorted_list_nfe1 132 | topk_idx_nfe1 = [outputs.index(x) for x in topk_scores_nfe1] 133 | selected_imgs_nfe1 = [imgs[i] for i in topk_idx_nfe1] 134 | img = Image.open(selected_imgs_nfe1[0]) 135 | img.save(os.path.join(nfe1_path, f"{0:05}.png")) 136 | 137 | # do nfe2 138 | sorted_list_nfe2 = sorted(outputs[:2], key=lambda x: f(x)) 139 | topk_scores_nfe2 = sorted_list_nfe2 140 | topk_idx_nfe2 = [outputs.index(x) for x in topk_scores_nfe2] 141 | selected_imgs_nfe2 = [imgs[i] for i in topk_idx_nfe2] 142 | img = Image.open(selected_imgs_nfe2[0]) 143 | img.save(os.path.join(nfe2_path, f"{0:05}.png")) 144 | 145 | # do nfe4 146 | sorted_list_nfe4 = sorted(outputs[:4], key=lambda x: f(x)) 147 | topk_scores_nfe4 = sorted_list_nfe4[:4] 148 | topk_idx_nfe4 = [outputs.index(x) for x in topk_scores_nfe4] 149 | selected_imgs_nfe4 = [imgs[i] for i in topk_idx_nfe4] 150 | img = Image.open(selected_imgs_nfe4[0]) 151 | img.save(os.path.join(nfe4_path, f"{0:05}.png")) 152 | 153 | # do nfe8 154 | sorted_list_nfe8 = sorted(outputs[:8], key=lambda x: f(x)) 155 | topk_scores_nfe8 = sorted_list_nfe8[:8] 156 | topk_idx_nfe8 = [outputs.index(x) for x in topk_scores_nfe8] 157 | selected_imgs_nfe8 = [imgs[i] for i in topk_idx_nfe8] 158 | img = Image.open(selected_imgs_nfe8[0]) 159 | img.save(os.path.join(nfe8_path, f"{0:05}.png")) 160 | 161 | # do nfe16 162 | sorted_list_nfe16 = sorted(outputs[:16], key=lambda x: f(x)) 163 | topk_scores_nfe16 = sorted_list_nfe16[:16] 164 | topk_idx_nfe16 = [outputs.index(x) for x in topk_scores_nfe16] 165 | selected_imgs_nfe16 = [imgs[i] for i in topk_idx_nfe16] 166 | img = Image.open(selected_imgs_nfe16[0]) 167 | img.save(os.path.join(nfe16_path, f"{0:05}.png")) 168 | 169 | # do nfe32 170 | # breakpoint() 171 | sorted_list_nfe32 = sorted(outputs[:32], key=lambda x: f(x)) 172 | topk_scores_nfe32 = sorted_list_nfe32[:32] 173 | topk_idx_nfe32 = [outputs.index(x) for x in topk_scores_nfe32] 174 | selected_imgs_nfe32 = [imgs[i] for i in topk_idx_nfe32] 175 | img = Image.open(selected_imgs_nfe32[0]) 176 | img.save(os.path.join(nfe32_path, f"{0:05}.png")) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /train_flux/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import random 5 | import torch 6 | import numpy as np 7 | import argparse 8 | from diffusers.pipelines import FluxPipeline, StableDiffusion3Pipeline 9 | from src.flux.condition import Condition 10 | from src.sd3.condition import Condition as ConditionSD3 11 | from PIL import Image 12 | 13 | from src.flux.generate import generate, seed_everything 14 | from src.sd3.generate import generate as generate_sd3 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser(description="Run FLUX pipeline with reflection prompts") 18 | parser.add_argument("--model_name", type=str, default="flux", help="Model name") 19 | parser.add_argument("--step", type=int, default=30, help="Number 
of inference steps") 20 | parser.add_argument("--condition_size", type=int, default=1024, help="Size of condition image") 21 | parser.add_argument("--target_size", type=int, default=1024, help="Size of target image") 22 | parser.add_argument("--task_name", type=str, default="geneval", 23 | choices=["edit", "geneval", "flux_pro_short", "flux_pro_detailed", "flux_pro"], 24 | help="Task name for selecting the appropriate dataset") 25 | parser.add_argument("--lora_dir", type=str, 26 | default="/mnt/petrelfs/gaopeng/zl/ReflectionFlow/train_flux/runs/full_data_v2/20250227-040606/ckpt/5000/pytorch_lora_weights.safetensors", 27 | help="Path to LoRA weights") 28 | parser.add_argument("--output_dir", type=str, 29 | default="/mnt/petrelfs/gaopeng/zl/ReflectionFlow/train_flux/samples/full_data_1024_5k", 30 | help="Base directory for output") 31 | parser.add_argument("--root_dir", type=str, default="", help="Root directory for image paths") 32 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 33 | parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale") 34 | parser.add_argument("--image_guidance_scale", type=float, default=1.0, help="Guidance scale") 35 | 36 | args = parser.parse_args() 37 | 38 | # Set variables from parsed arguments 39 | step = args.step 40 | condition_size = args.condition_size 41 | target_size = args.target_size 42 | task_name = args.task_name 43 | lora_dir = args.lora_dir 44 | output_dir = args.output_dir 45 | root_dir = args.root_dir 46 | 47 | # Set json_dir based on task_name 48 | if task_name == "edit": 49 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/edit_reflection_cleaned_val.json" 50 | output_dir = os.path.join(output_dir, "edit") 51 | elif task_name == "geneval": 52 | json_dir = "/mnt/petrelfs/zhuole/data/metadata_clean/geneval_pairs_val.json" 53 | output_dir = os.path.join(output_dir, "geneval") 54 | elif task_name == "flux_pro_short": 55 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_detailed_reflection_cleaned_val.json" 56 | output_dir = os.path.join(output_dir, "flux_pro_short") 57 | elif task_name == "flux_pro_detailed": 58 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_detailed_reflection_cleaned_val.json" 59 | output_dir = os.path.join(output_dir, "flux_pro_detailed") 60 | elif task_name == "flux_pro": 61 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_reflection_cleaned_val.json" 62 | output_dir = os.path.join(output_dir, "flux_pro") 63 | else: 64 | raise ValueError(f"Invalid task name: {task_name}") 65 | os.makedirs(output_dir, exist_ok=True) 66 | 67 | if args.model_name == "flux": 68 | pipe = FluxPipeline.from_pretrained( 69 | "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 70 | ) 71 | else: 72 | pipe = StableDiffusion3Pipeline.from_pretrained( 73 | "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16, cache_dir = "/ibex/user/zhaol0c/uniediting/ReflectionFlow-main/sd3medium" 74 | ) 75 | pipe = pipe.to("cuda") 76 | pipe.load_lora_weights( 77 | lora_dir, 78 | adapter_name="cot", 79 | ) 80 | 81 | # Load the JSON list containing prompt and image details 82 | print(f"Loading JSON from {json_dir}") 83 | with open(json_dir, "r") as f: 84 | # items = [json.loads(line.strip()) for line in f] 85 | items = json.load(f) 86 | 87 | seed_everything(args.seed) 88 | 89 | for idx, item in enumerate(items): 90 | # Open the bad image and good image 91 | # bad_image_path = os.path.join(root_dir, 
item["bad_image"]) 92 | # good_image_path = os.path.join(root_dir, item["good_image"]) 93 | bad_image_path = item["bad_image"] 94 | good_image_path = item["good_image"] 95 | 96 | # Load images 97 | bad_image = Image.open(bad_image_path).convert("RGB") 98 | good_image = Image.open(good_image_path).convert("RGB") 99 | 100 | # Get dimensions 101 | bad_w, bad_h = bad_image.size 102 | good_w, good_h = good_image.size 103 | 104 | # Resize bad image to match good image dimensions 105 | bad_image = bad_image.resize((good_w, good_h), Image.BICUBIC) 106 | 107 | # Resize the shorter edge to target_size while maintaining aspect ratio 108 | ratio = target_size / min(good_w, good_h) 109 | new_w = math.ceil(good_w * ratio) 110 | new_h = math.ceil(good_h * ratio) 111 | 112 | # Resize both images to the same dimensions 113 | good_image = good_image.resize((new_w, new_h), Image.BICUBIC) 114 | bad_image = bad_image.resize((new_w, new_h), Image.BICUBIC) 115 | 116 | # Randomly crop both images to exactly target_size x target_size 117 | if new_w > target_size or new_h > target_size: 118 | left = random.randint(0, max(0, new_w - target_size)) 119 | top = random.randint(0, max(0, new_h - target_size)) 120 | 121 | # Apply the same crop to both images to maintain pixel correspondence 122 | good_image = good_image.crop((left, top, left + target_size, top + target_size)) 123 | bad_image = bad_image.crop((left, top, left + target_size, top + target_size)) 124 | 125 | # Finally, resize bad_image to condition_size 126 | image = bad_image.resize((condition_size, condition_size), Image.BICUBIC) 127 | 128 | # Create a condition for the pipeline 129 | if args.model_name == "flux": 130 | condition_cls = Condition 131 | else: 132 | condition_cls = ConditionSD3 133 | condition = condition_cls( 134 | condition_type="cot", 135 | condition=image, 136 | position_delta=np.array([0, -condition_size // 16]) 137 | ) 138 | 139 | # Build the prompt by combining base prompt and reflection prompt if available 140 | original_prompt = item["prompt"] 141 | prompt = item["prompt"] 142 | if "reflection_prompt" in item: 143 | prompt += " [Reflexion] " + item["reflection_prompt"] 144 | elif "instruction" in item: 145 | prompt += " [Reflexion] " + item["instruction"] 146 | elif "reflection" in item: 147 | prompt += " [Reflexion] " + item["reflection"] 148 | elif "edited_prompt_list" in item: 149 | prompt += " [Reflexion] " + item["edited_prompt_list"][-1] 150 | else: 151 | raise ValueError(f"No reflection found in item: {item}") 152 | 153 | # Generate the result image 154 | if args.model_name == "flux": 155 | generate_func = generate 156 | else: 157 | generate_func = generate_sd3 158 | # breakpoint() 159 | result_img = generate_func( 160 | pipe, 161 | prompt=original_prompt, 162 | prompt_2=prompt, 163 | conditions=[condition], 164 | num_inference_steps=step, 165 | height=target_size, 166 | width=target_size, 167 | guidance_scale=args.guidance_scale, 168 | image_guidance_scale=args.image_guidance_scale 169 | ).images[0] 170 | 171 | # Concatenate bad image, good image, and generated image side by side 172 | concat_image = Image.new("RGB", (condition_size + target_size + target_size, target_size)) 173 | concat_image.paste(image, (0, 0)) 174 | concat_image.paste(good_image, (condition_size, 0)) 175 | concat_image.paste(result_img, (condition_size + target_size, 0)) 176 | 177 | # Save the concatenated image, using image_id if present 178 | output_name = item.get("image_id", f"result_{idx}") 179 | concat_image.save(os.path.join(output_dir, 
f"{output_name}.jpg")) -------------------------------------------------------------------------------- /reward_modeling/prompt_template.py: -------------------------------------------------------------------------------- 1 | 2 | VIDEOSCORE_QUERY_PROMPT = """ 3 | Suppose you are an expert in judging and evaluating the quality of AI-generated videos, 4 | please watch the frames of a given video and see the text prompt for generating the video, 5 | then give scores based on its {dimension_name}, i.e., {dimension_description}. 6 | Output a float number from 1.0 to 5.0 for this dimension, 7 | the higher the number is, the better the video performs in that sub-score, 8 | the lowest 1.0 means Bad, the highest 5.0 means Perfect/Real (the video is like a real video). 9 | The text prompt used for generation is "{text_prompt}". 10 | """ 11 | 12 | DIMENSION_DESCRIPTIONS = { 13 | 'VQ': ['visual quality', 'the quality of the video in terms of clearness, resolution, brightness, and color'], 14 | 'TA': ['text-to-video alignment', 'the alignment between the text prompt and the video content and motion'], 15 | 'MQ': ['motion quality', 'the quality of the motion in terms of consistency, smoothness, and completeness'], 16 | 'Overall': ['Overall Performance', 'the overall performance of the video in terms of visual quality, text-to-video alignment, and motion quality'], 17 | } 18 | 19 | SIMPLE_PROMPT = """ 20 | Please evaluate the {dimension_name} of a generated video. Consider {dimension_description}. 21 | The text prompt used for generation is "{text_prompt}". 22 | """ 23 | 24 | DETAILED_PROMPT_WITH_SPECIAL_TOKEN = """ 25 | You are tasked with evaluating a generated image based on two distinct criteria: Visual Quality and Text Alignment. Please provide a overall rating from 0 to 10, with 0 being the worst and 10 being the best. 26 | 27 | **Visual Quality:** 28 | Evaluate the overall visual quality of the image. The following sub-dimensions should be considered: 29 | - **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups. 30 | - **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas. 31 | - **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows). 32 | - **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance. 33 | 34 | **Text Alignment:** 35 | Assess how well the image matches the textual prompt across the following sub-dimensions: 36 | - **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, interaction, etc. 37 | - **Environment Relevance:** Assess whether the background and scene fit the prompt. This includes checking if real-world locations or scenes are accurately represented, though some stylistic adaptation is acceptable. 38 | - **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style. 
39 | 40 | Textual prompt - {text_prompt} 41 | Please provide the overall rating: <|VQ_reward|> 42 | """ 43 | 44 | DETAILED_PROMPT = """ 45 | You are tasked with evaluating a generated video based on three distinct criteria: Visual Quality, Motion Quality, and Text Alignment. Please provide a rating from 0 to 10 for each of the three categories, with 0 being the worst and 10 being the best. Each evaluation should be independent of the others. 46 | 47 | **Visual Quality:** 48 | Evaluate the overall visual quality of the video, with a focus on static factors. The following sub-dimensions should be considered: 49 | - **Reasonableness:** The video should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups. 50 | - **Clarity:** Evaluate the sharpness and visibility of the video. The image should be clear and easy to interpret, with no blurring or indistinct areas. 51 | - **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows). 52 | - **Aesthetic and Creativity:** Assess the artistic aspects of the video, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance. 53 | - **Safety:** The video should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible. 54 | 55 | **Motion Quality:** 56 | Assess the dynamic aspects of the video, with a focus on dynamic factors. Consider the following sub-dimensions: 57 | - **Stability:** Evaluate the continuity and stability between frames. There should be no sudden, unnatural jumps, and the video should maintain stable attributes (e.g., no fluctuating colors, textures, or missing body parts). 58 | - **Naturalness:** The movement should align with physical laws and be realistic. For example, clothing should flow naturally with motion, and facial expressions should change appropriately (e.g., blinking, mouth movements). 59 | - **Aesthetic Quality:** The movement should be smooth and fluid. The transitions between different motions or camera angles should be seamless, and the overall dynamic feel should be visually pleasing. 60 | - **Fusion:** Ensure that elements in motion (e.g., edges of the subject, hair, clothing) blend naturally with the background, without obvious artifacts or the feeling of cut-and-paste effects. 61 | - **Clarity of Motion:** The video should be clear and smooth in motion. Pay attention to any areas where the video might have blurry or unsteady sections that hinder visual continuity. 62 | - **Amplitude:** If the video is largely static or has little movement, assign a low score for motion quality. 63 | 64 | 65 | **Text Alignment:** 66 | Assess how well the video matches the textual prompt across the following sub-dimensions: 67 | - **Subject Relevance** Evaluate how accurately the subject(s) in the video (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior. 68 | - **Motion Relevance:** Evaluate if the dynamic actions (e.g., gestures, posture, facial expressions like talking or blinking) align with the described prompt. The motion should match the prompt in terms of type, scale, and direction. 
69 | - **Environment Relevance:** Assess whether the background and scene fit the prompt. This includes checking if real-world locations or scenes are accurately represented, though some stylistic adaptation is acceptable. 70 | - **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the video adheres to this style. 71 | - **Camera Movement Relevance:** Check if the camera movements (e.g., following the subject, focus shifts) are consistent with the expected behavior from the prompt. 72 | 73 | Textual prompt - {text_prompt} 74 | Please provide the ratings of Visual Quality, Motion Quality, and Text Alignment. 75 | """ 76 | 77 | SIMPLE_PROMPT_NO_PROMPT = """ 78 | Please evaluate the {dimension_name} of a generated video. Consider {dimension_description}. 79 | """ 80 | 81 | def build_prompt(prompt, dimension, template_type): 82 | if isinstance(dimension, list) and len(dimension) > 1: 83 | dimension_name = ", ".join([DIMENSION_DESCRIPTIONS[d][0] for d in dimension]) 84 | dimension_name = f'overall performance({dimension_name})' 85 | dimension_description = "the overall performance of the video" 86 | else: 87 | if isinstance(dimension, list): 88 | dimension = dimension[0] 89 | dimension_name = DIMENSION_DESCRIPTIONS[dimension][0] 90 | dimension_description = DIMENSION_DESCRIPTIONS[dimension][1] 91 | 92 | if template_type == "none": 93 | return prompt 94 | elif template_type == "simple": 95 | return SIMPLE_PROMPT.format(dimension_name=dimension_name, 96 | dimension_description=dimension_description, 97 | text_prompt=prompt) 98 | elif template_type == "video_score": 99 | return VIDEOSCORE_QUERY_PROMPT.format(dimension_name=dimension_name, 100 | dimension_description=dimension_description, 101 | text_prompt=prompt) 102 | elif template_type == "detailed_special": 103 | return DETAILED_PROMPT_WITH_SPECIAL_TOKEN.format(text_prompt=prompt) 104 | elif template_type == "detailed": 105 | return DETAILED_PROMPT.format(text_prompt=prompt) 106 | else: 107 | raise ValueError("Invalid template type") 108 | -------------------------------------------------------------------------------- /train_flux/train/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import glob 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import webdataset as wds 8 | import torch 9 | from torch.utils.data import IterableDataset, DataLoader 10 | import torchvision.transforms as T 11 | 12 | # Taken from https://github.com/tmbdev-archive/webdataset-imagenet-2/blob/01a4ab54307b9156c527d45b6b171f88623d2dec/imagenet.py#L65. 13 | def nodesplitter(src, group=None): 14 | if torch.distributed.is_initialized(): 15 | if group is None: 16 | group = torch.distributed.group.WORLD 17 | rank = torch.distributed.get_rank(group=group) 18 | size = torch.distributed.get_world_size(group=group) 19 | count = 0 20 | for i, item in enumerate(src): 21 | if i % size == rank: 22 | yield item 23 | count += 1 24 | else: 25 | yield from src 26 | 27 | class ImageConditionWebDataset(IterableDataset): 28 | def __init__( 29 | self, 30 | shards_pattern: str, 31 | condition_size: int = 1024, 32 | target_size: int = 1024, 33 | condition_type: str = "cot", 34 | drop_text_prob: float = 0.1, 35 | drop_image_prob: float = 0.1, 36 | drop_reflection_prob: float = 0.2, 37 | root_dir: str = "", 38 | split_ratios: dict = None, # e.g. {"general":[.7,.3], "length":[.3,.7], …} 39 | training_stages: list = None, # e.g. 
[0, 10000, 20000] 40 | return_pil_image: bool = False, 41 | shuffle_buffer: int = 1000, 42 | ): 43 | super().__init__() 44 | self.condition_size = condition_size 45 | self.target_size = target_size 46 | self.condition_type = condition_type 47 | self.drop_text_prob = drop_text_prob 48 | self.drop_image_prob = drop_image_prob 49 | self.drop_reflection_prob = drop_reflection_prob 50 | self.return_pil_image = return_pil_image 51 | 52 | # prepare WebDataset pipelines per subset 53 | self.splits = list(split_ratios.keys()) 54 | self.all_split_ratios = split_ratios 55 | # start with first stage’s ratios 56 | self.split_ratios = {s: split_ratios[s][0] for s in self.splits} 57 | self.training_stages = training_stages or [0] 58 | 59 | # one independent pipeline for each subset 60 | self.datasets = {} 61 | if "https://" not in shards_pattern: 62 | shards_pattern = glob.glob(shards_pattern) 63 | for split in self.splits: 64 | # Define mandatory keys that must exist in each sample 65 | MANDATORY_KEYS = {"good_image.jpg", "bad_image.jpg", "reflection.txt", "prompt.txt", "subset.txt"} 66 | 67 | ds = ( 68 | wds.WebDataset(shards_pattern, handler=wds.ignore_and_continue, nodesplitter=nodesplitter) 69 | # First filter samples to ensure all required keys exist 70 | .select(lambda sample: set(sample.keys()) >= MANDATORY_KEYS) 71 | .shuffle(shuffle_buffer) 72 | .decode("pil") # good_image.jpg / bad_image.jpg → PIL 73 | .to_tuple( 74 | "good_image.jpg", 75 | "bad_image.jpg", 76 | "reflection.txt", 77 | "prompt.txt", 78 | "subset.txt", 79 | ) 80 | # keep only records whose subset matches this split 81 | .select(lambda sample: sample[4] == split) 82 | ) 83 | self.datasets[split] = ds 84 | 85 | # create one iterator per subset 86 | self.iters = {s: iter(ds) for s, ds in self.datasets.items()} 87 | self.to_tensor = T.ToTensor() 88 | self.iteration = 0 89 | 90 | def _update_split_ratios(self): 91 | itr = self.iteration 92 | stages = self.training_stages 93 | # beyond last => use last ratios 94 | if itr >= stages[-1]: 95 | for s in self.splits: 96 | self.split_ratios[s] = self.all_split_ratios[s][-1] 97 | return 98 | 99 | # find current stage index 100 | idx = max(i for i, t in enumerate(stages) if itr >= t) 101 | next_idx = min(idx+1, len(stages)-1) 102 | start, end = stages[idx], stages[next_idx] 103 | progress = (itr - start) / (end - start) if end>start else 1.0 104 | 105 | for s in self.splits: 106 | r0 = self.all_split_ratios[s][idx] 107 | r1 = self.all_split_ratios[s][next_idx] 108 | self.split_ratios[s] = r0 + progress*(r1-r0) 109 | 110 | def _preprocess_pair(self, good: Image.Image, bad: Image.Image): 111 | # match bad → good dims 112 | gw, gh = good.size 113 | bad = bad.resize((gw, gh), Image.BICUBIC) 114 | # scale shorter edge → target_size 115 | ratio = self.target_size / min(gw, gh) 116 | nw, nh = math.ceil(gw*ratio), math.ceil(gh*ratio) 117 | good = good.resize((nw, nh), Image.BICUBIC) 118 | bad = bad.resize((nw, nh), Image.BICUBIC) 119 | 120 | # same random crop 121 | if nw>self.target_size or nh>self.target_size: 122 | left = random.randint(0, nw-self.target_size) 123 | top = random.randint(0, nh-self.target_size) 124 | box = (left, top, left+self.target_size, top+self.target_size) 125 | good = good.crop(box) 126 | bad = bad.crop(box) 127 | 128 | # final resize bad → condition_size 129 | bad = bad.resize((self.condition_size, self.condition_size), Image.BICUBIC) 130 | return good, bad 131 | 132 | def __iter__(self): 133 | while True: 134 | # 1) update dynamic ratios 135 | self._update_split_ratios() 
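# `_update_split_ratios` linearly interpolates each subset's sampling weight between
# its ratios for the current and the next training stage. E.g. with the example
# configuration from __main__ below (training_stages=[0, 10000, 20000],
# split_ratios["general"]=[0.8, 0.6, 0.4]), iteration 5000 sits halfway through the
# first stage, so the "general" weight becomes 0.8 + 0.5 * (0.6 - 0.8) = 0.7;
# from iteration 20000 onwards the final ratio 0.4 is used.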
136 | 137 | # 2) pick a split by current weights 138 | split = random.choices( 139 | self.splits, 140 | weights=[self.split_ratios[s] for s in self.splits], 141 | k=1, 142 | )[0] 143 | 144 | # 3) pull next sample (re‑reset iterator on exhaustion) 145 | try: 146 | good, bad, ref_bytes, prom_bytes, sub_bytes = next(self.iters[split]) 147 | except StopIteration: 148 | self.iters[split] = iter(self.datasets[split]) 149 | good, bad, ref_bytes, prom_bytes, sub_bytes = next(self.iters[split]) 150 | 151 | # decode text 152 | reflection = ref_bytes 153 | prompt = prom_bytes 154 | subset = sub_bytes 155 | 156 | # convert to RGB 157 | good = good.convert("RGB") 158 | bad = bad.convert("RGB") 159 | 160 | # 4) apply your resize/crop logic 161 | good, bad = self._preprocess_pair(good, bad) 162 | 163 | # 5) decide drops 164 | drop_text = random.random() < self.drop_text_prob 165 | drop_image_flag = random.random() < self.drop_image_prob and subset!="editing" 166 | drop_reflection = ( 167 | random.random() < self.drop_reflection_prob 168 | or len(reflection)<5 169 | ) 170 | 171 | if drop_reflection or drop_image_flag: 172 | description = prompt 173 | else: 174 | description = f"{prompt} [Reflexion] {reflection}" 175 | if drop_text: 176 | description = "" 177 | if drop_image_flag: 178 | # black out condition 179 | bad = Image.new("RGB", (self.condition_size, self.condition_size), (0,0,0)) 180 | 181 | # 6) to tensors 182 | image = self.to_tensor(good) 183 | condition = self.to_tensor(bad) 184 | 185 | out = { 186 | "image": image, 187 | "condition": condition, 188 | "original_prompt": prompt, 189 | "condition_type": self.condition_type, 190 | "description": description, 191 | "position_delta": np.array([0, -self.condition_size//16]), 192 | "subset": subset 193 | } 194 | if self.return_pil_image: 195 | out["pil_image"] = [good, bad] 196 | 197 | self.iteration += 1 198 | yield out 199 | 200 | # usage: 201 | if __name__ == "__main__": 202 | 203 | split_ratios = { 204 | "general": [0.8, 0.6, 0.4], 205 | "length": [0.1, 0.2, 0.3], 206 | "rule": [0.1, 0.2, 0.3], 207 | "editing": [0.0, 0.0, 0.0], 208 | } 209 | training_stages = [0, 10000, 20000] 210 | 211 | # local path should be provided as 212 | # shards_pattern = "DIR_WHERE_GenRef-wds_IS_DOWNLOADED/*.tar" 213 | shards_pattern = "/ceph/hf-cache/hub/datasets--diffusion-cot--GenRef-wds/snapshots/42c837b891fc34a944ed0c8124876a7e8225266f/*.tar" 214 | dataset = ImageConditionWebDataset( 215 | shards_pattern=shards_pattern, 216 | condition_size=1024, 217 | target_size=1024, 218 | condition_type="cot", 219 | drop_text_prob=0.1, 220 | drop_image_prob=0.1, 221 | drop_reflection_prob=0.2, 222 | split_ratios=split_ratios, 223 | training_stages=training_stages, 224 | return_pil_image=False, 225 | ) 226 | 227 | loader = DataLoader(dataset, batch_size=8, num_workers=4) 228 | 229 | # iterate: 230 | from tqdm import tqdm 231 | for batch in tqdm(loader): 232 | continue 233 | # print(batch.keys()) 234 | # print(batch["image"].size()) 235 | # print(batch["condition"].size()) 236 | # break 237 | -------------------------------------------------------------------------------- /reward_modeling/data.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from dataclasses import dataclass 3 | from typing import Optional, List, Union 4 | 5 | import pandas as pd 6 | import torch 7 | from prompt_template import build_prompt 8 | # from qwen_vl_utils import process_vision_info 9 | from vision_process import process_vision_info 10 | from 
torch.utils.data import Dataset 11 | import torchvision.transforms.functional as F 12 | 13 | 14 | @dataclass 15 | class DataConfig: 16 | meta_data: str = "/path/to/dataset/meta_data.csv" 17 | data_dir: str = "/path/to/dataset" 18 | meta_data_test: str = None 19 | max_frame_pixels: int = 240 * 320 20 | num_frames: float = None 21 | fps: float = 2.0 22 | p_shuffle_frames: float = 0.0 23 | p_color_jitter: float = 0.0 24 | eval_dim: Union[str, List[str]] = "VQ" 25 | prompt_template_type: str = "none" 26 | add_noise: bool = False 27 | sample_type: str = "uniform" 28 | use_tied_data: bool = True 29 | 30 | def convert_GSB_csv_to_reward_data(example, data_dir, eval_dims=["VQ"], max_pixels=448 * 448, prompt_template_type="none"): 31 | """ 32 | Convert Good/Same/Bad csv data to reward data. 33 | 34 | Args: 35 | example (dict): A single row of the GSB csv data. 36 | data_dir (str): The directory path to the image files. 37 | eval_dims (list): The dimensions to evaluate ("VQ"/"MQ"/"TA"). 38 | max_pixels (int): The maximum number of pixels allowed for each image. 39 | prompt_template_type (str): The type of prompt template to use ("none"/"simple"/"video_score"). 40 | 41 | 42 | Returns: 43 | dict: A dictionary containing the reward data. 44 | """ 45 | 46 | A_data = [ 47 | { 48 | "role": "user", 49 | "content": [ 50 | { 51 | "type": "image", 52 | "image": example[f'image1'], 53 | "max_pixels": max_pixels, 54 | }, 55 | {"type": "text", "text": build_prompt(example["prompt"], eval_dims, prompt_template_type)}, 56 | ], 57 | } 58 | ] 59 | B_data = [ 60 | { 61 | "role": "user", 62 | "content": [ 63 | { 64 | "type": "image", 65 | "image": example[f'image2'], 66 | "max_pixels": max_pixels, 67 | }, 68 | {"type": "text", "text": build_prompt(example["prompt"], eval_dims, prompt_template_type)}, 69 | ], 70 | } 71 | ] 72 | 73 | chosen_labels = [] 74 | A_scores = [] 75 | B_scores = [] 76 | 77 | for eval_dim in eval_dims: 78 | ### chosen_label: 1 if A is chosen, -1 if B is chosen, 0 if tied. 79 | ### 22 if invalid. 
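### e.g. for eval_dim == "VQ": "A" -> 1 (image1 preferred), "B" -> -1 (image2 preferred),
### "same" -> 0 (tie); anything missing, "invalid", or unparseable -> 22.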
80 | try: 81 | if example[f"{eval_dim}"] is not None: 82 | if example[f"{eval_dim}"] == "A": 83 | chosen_label = 1 84 | elif example[f"{eval_dim}"] == "B": 85 | chosen_label = -1 86 | elif example[f"{eval_dim}"] == "same": 87 | chosen_label = 0 88 | elif example[f"{eval_dim}"] == "invalid": 89 | chosen_label = 22 90 | else: 91 | chosen_label = 22 92 | else: 93 | chosen_label = 22 94 | except Exception as e: 95 | chosen_label = 22 96 | 97 | chosen_labels.append(chosen_label) 98 | if f"MOS_A_{eval_dim}" in example and f"MOS_B_{eval_dim}" in example: 99 | try: 100 | A_score = example[f"MOS_A_{eval_dim}"] if example[f"MOS_A_{eval_dim}"] is not None else 0.0 101 | B_score = example[f"MOS_B_{eval_dim}"] if example[f"MOS_B_{eval_dim}"] is not None else 0.0 102 | except Exception as e: 103 | A_score = 0.0 104 | B_score = 0.0 105 | A_scores.append(A_score) 106 | B_scores.append(B_score) 107 | else: 108 | A_scores.append(0.0) 109 | B_scores.append(0.0) 110 | 111 | chosen_labels = torch.tensor(chosen_labels, dtype=torch.long) 112 | A_scores = torch.tensor(A_scores, dtype=torch.float) 113 | B_scores = torch.tensor(B_scores, dtype=torch.float) 114 | metainfo_idx = None 115 | if 'metainfo_idx' in example: 116 | metainfo_idx = example['metainfo_idx'] 117 | 118 | return {"A_data": A_data, "B_data": B_data, 119 | "A_scores": A_scores, "B_scores": B_scores, 120 | "chosen_label": chosen_labels, 121 | "metainfo_idx": metainfo_idx,} 122 | 123 | class QWen2VLDataCollator(): 124 | def __init__(self, processor, add_noise=False, p_shuffle_frames=0.0, p_color_jitter=0.0): 125 | self.processor = processor 126 | self.add_noise = add_noise 127 | self.set_noise_step = None 128 | 129 | self.p_shuffle_frames = p_shuffle_frames 130 | self.p_color_jitter = p_color_jitter 131 | 132 | self.noise_adder = None 133 | 134 | def _clean_message(self, message): 135 | """ 136 | remove unnecessary keys from message(very very necessary) 137 | """ 138 | out_message = [ 139 | { 140 | "role": "user", 141 | "content": [ 142 | { 143 | "type": "image", 144 | "image": message[0]["content"][0]["image"], 145 | "max_pixels": message[0]["content"][0]["max_pixels"], 146 | }, 147 | {"type": "text", "text": message[0]["content"][1]["text"]}, 148 | ], 149 | } 150 | ] 151 | 152 | return out_message 153 | 154 | 155 | def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'): 156 | """ 157 | Pad the sequences to the maximum length. 158 | """ 159 | assert padding_side in ['right', 'left'] 160 | if sequences.shape[1] >= max_len: 161 | return sequences, attention_mask 162 | 163 | pad_len = max_len - sequences.shape[1] 164 | padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0) 165 | 166 | sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id) 167 | attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0) 168 | 169 | return sequences_padded, attention_mask_padded 170 | 171 | def __call__(self, features, enable_noise=True): 172 | """ 173 | Preprocess inputs to token sequences and return a batch 174 | """ 175 | # try: 176 | features_A = [] 177 | features_B = [] 178 | # check if we have a margin. 
If we do, we need to batch it as well 179 | # has_margin = "margin" in features[0] 180 | has_idx = "metainfo_idx" in features[0] and features[0]["metainfo_idx"] is not None 181 | 182 | for idx, feature in enumerate(features): 183 | features_A.append(self._clean_message(feature["A_data"])) 184 | features_B.append(self._clean_message(feature["B_data"])) 185 | 186 | # import pdb; pdb.set_trace() 187 | image_inputs_A, video_inputs_A = process_vision_info(features_A) 188 | image_inputs_B, video_inputs_B = process_vision_info(features_B) 189 | 190 | do_rescale = False 191 | # print(f"{video_inputs_A[0].shape}, {video_inputs_B[0].shape}") 192 | 193 | # if not enable_noise: 194 | # print("Not training, no noise added.") 195 | batch_A = self.processor( 196 | text=self.processor.apply_chat_template(features_A, tokenize=False, add_generation_prompt=True), 197 | images=image_inputs_A, 198 | videos=video_inputs_A, 199 | padding=True, 200 | return_tensors="pt", 201 | videos_kwargs={"do_rescale": do_rescale}, 202 | ) 203 | batch_B = self.processor( 204 | text=self.processor.apply_chat_template(features_B, tokenize=False, add_generation_prompt=True), 205 | images=image_inputs_B, 206 | videos=video_inputs_B, 207 | padding=True, 208 | return_tensors="pt", 209 | videos_kwargs={"do_rescale": do_rescale}, 210 | ) 211 | 212 | # pdb.set_trace() 213 | max_len = max(batch_A["input_ids"].shape[1], batch_B["input_ids"].shape[1]) 214 | batch_A["input_ids"], batch_A["attention_mask"] = self._pad_sequence(batch_A["input_ids"], batch_A["attention_mask"], max_len, "right") 215 | batch_B["input_ids"], batch_B["attention_mask"] = self._pad_sequence(batch_B["input_ids"], batch_B["attention_mask"], max_len, "right") 216 | # print(f"Batch A: {batch_A['input_ids'].shape}, Batch B: {batch_B['input_ids'].shape}") 217 | 218 | chosen_label = torch.stack([torch.tensor(feature["chosen_label"]) for feature in features]) 219 | 220 | A_scores = torch.stack([torch.tensor(feature["A_scores"]) for feature in features]) 221 | B_scores = torch.stack([torch.tensor(feature["B_scores"]) for feature in features]) 222 | 223 | batch = { 224 | "A": batch_A, 225 | "B": batch_B, 226 | "return_loss": True, 227 | "chosen_label": chosen_label, 228 | "A_scores": A_scores, 229 | "B_scores": B_scores, 230 | } 231 | 232 | if has_idx: 233 | metainfo_idx = torch.stack([torch.tensor(feature["metainfo_idx"]) for feature in features]) 234 | batch["metainfo_idx"] = metainfo_idx 235 | 236 | # pdb.set_trace() 237 | return batch 238 | 239 | # except Exception as e: 240 | # print(f"Error processing batch: {e} in reading.") 241 | # # get next batch 242 | # return None 243 | -------------------------------------------------------------------------------- /train_flux/train/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lightning as L 3 | from diffusers.pipelines import FluxPipeline 4 | import torch 5 | from torchvision import transforms 6 | from peft import LoraConfig, get_peft_model_state_dict 7 | 8 | import prodigyopt 9 | 10 | from flux.generate import generate 11 | from flux.transformer import tranformer_forward 12 | from flux.condition import Condition 13 | from flux.pipeline_tools import encode_images, prepare_text_input 14 | 15 | 16 | class OminiModel(L.LightningModule): 17 | def __init__( 18 | self, 19 | flux_pipe_id: str, 20 | lora_path: str = None, 21 | lora_config: dict = None, 22 | data_config: dict = None, 23 | device: str = "cuda", 24 | dtype: torch.dtype = torch.bfloat16, 25 | 
model_config: dict = {}, 26 | optimizer_config: dict = None, 27 | gradient_checkpointing: bool = False, 28 | save_path: str = None, 29 | run_name: str = None, 30 | cache_dir: str = None, 31 | ): 32 | # Initialize the LightningModule 33 | super().__init__() 34 | self.model_config = model_config 35 | self.optimizer_config = optimizer_config 36 | self.data_config = data_config 37 | self.save_path = save_path 38 | self.run_name = run_name 39 | 40 | # Load the Flux pipeline 41 | self.flux_pipe: FluxPipeline = ( 42 | FluxPipeline.from_pretrained(flux_pipe_id, cache_dir=cache_dir, torch_dtype=dtype).to(device) 43 | ) 44 | self.transformer = self.flux_pipe.transformer 45 | self.transformer.gradient_checkpointing = gradient_checkpointing 46 | self.transformer.train() 47 | 48 | # Freeze the Flux pipeline 49 | self.flux_pipe.text_encoder.requires_grad_(False).eval() 50 | self.flux_pipe.text_encoder_2.requires_grad_(False).eval() 51 | self.flux_pipe.vae.requires_grad_(False).eval() 52 | 53 | # Initialize LoRA layers 54 | self.lora_layers = self.init_lora(lora_path, lora_config) 55 | 56 | self.to(device).to(dtype) 57 | 58 | def init_lora(self, lora_path: str, lora_config: dict): 59 | assert lora_path or lora_config 60 | if lora_path: 61 | # # TODO: Implement this 62 | # raise NotImplementedError 63 | self.flux_pipe.load_lora_weights(lora_path, adapter_name="default") 64 | # TODO: Check if this is correct (p.requires_grad) 65 | lora_layers = [] 66 | for name, p in self.transformer.named_parameters(): 67 | if "lora" in name: 68 | lora_layers.append(p) 69 | # lora_layers = filter( 70 | # lambda p: p.requires_grad, self.transformer.parameters() 71 | # ) 72 | else: 73 | if lora_config.get("target_modules", None) == "all-linear": 74 | target_modules = set() 75 | for name, module in self.transformer.named_modules(): 76 | if isinstance(module, torch.nn.Linear): 77 | target_modules.add(name) 78 | target_modules = list(target_modules) 79 | lora_config["target_modules"] = target_modules 80 | self.transformer.add_adapter(LoraConfig(**lora_config)) 81 | # TODO: Check if this is correct (p.requires_grad) 82 | lora_layers = filter( 83 | lambda p: p.requires_grad, self.transformer.parameters() 84 | ) 85 | return list(lora_layers) if not isinstance(lora_layers, list) else lora_layers 86 | 87 | def save_lora(self, path: str): 88 | FluxPipeline.save_lora_weights( 89 | save_directory=path, 90 | transformer_lora_layers=get_peft_model_state_dict(self.transformer), 91 | safe_serialization=True, 92 | ) 93 | 94 | def configure_optimizers(self): 95 | # Freeze the transformer 96 | self.transformer.requires_grad_(False) 97 | opt_config = self.optimizer_config 98 | 99 | # Set the trainable parameters 100 | self.trainable_params = self.lora_layers 101 | 102 | # Unfreeze trainable parameters 103 | for p in self.trainable_params: 104 | p.requires_grad_(True) 105 | 106 | # Initialize the optimizer 107 | if opt_config["type"] == "AdamW": 108 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"]) 109 | elif opt_config["type"] == "Prodigy": 110 | optimizer = prodigyopt.Prodigy( 111 | self.trainable_params, 112 | **opt_config["params"], 113 | ) 114 | elif opt_config["type"] == "SGD": 115 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"]) 116 | else: 117 | raise NotImplementedError 118 | 119 | return optimizer 120 | 121 | def validation_step(self, batch, batch_idx): 122 | generator = torch.Generator(device=self.device) 123 | generator.manual_seed(42) 124 | target_size = 
self.data_config["target_size"] 125 | condition_size = self.data_config["condition_size"] 126 | prompt = batch["description"][0] 127 | original_prompt = batch["original_prompt"][0] 128 | condition_type = batch["condition_type"][0] 129 | condition_img = batch["condition"][0] 130 | to_pil = transforms.ToPILImage() 131 | condition_img = to_pil(condition_img.cpu().float()) 132 | position_delta = batch["position_delta"][0] 133 | condition = Condition( 134 | condition_type=condition_type, 135 | condition=condition_img, 136 | position_delta=position_delta, 137 | ) 138 | 139 | res = generate( 140 | self.flux_pipe, 141 | prompt=original_prompt, 142 | prompt_2=prompt, 143 | conditions=[condition], 144 | height=target_size, 145 | width=target_size, 146 | generator=generator, 147 | model_config=self.model_config, 148 | default_lora=True, 149 | ) 150 | os.makedirs(os.path.join(self.save_path, self.run_name, "val"), exist_ok=True) 151 | res.images[0].save( 152 | os.path.join(self.save_path, self.run_name, "val", f"{self.global_step}_{condition_type}_{batch_idx}.jpg") 153 | ) 154 | 155 | def training_step(self, batch, batch_idx): 156 | step_loss = self.step(batch) 157 | self.log_loss = ( 158 | step_loss.item() 159 | if not hasattr(self, "log_loss") 160 | else self.log_loss * 0.95 + step_loss.item() * 0.05 161 | ) 162 | return step_loss 163 | 164 | def step(self, batch): 165 | imgs = batch["image"] 166 | conditions = batch["condition"] 167 | condition_types = batch["condition_type"] 168 | prompts = batch["original_prompt"] 169 | position_delta = batch["position_delta"][0] 170 | prompts_2 = batch["description"] 171 | 172 | # Prepare inputs 173 | with torch.no_grad(): 174 | # Prepare image input 175 | x_0, img_ids = encode_images(self.flux_pipe, imgs) 176 | 177 | # Prepare text input 178 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input( 179 | self.flux_pipe, prompts, prompts_2=prompts_2 180 | ) 181 | 182 | # Prepare t and x_t 183 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) 184 | x_1 = torch.randn_like(x_0).to(self.device) 185 | t_ = t.unsqueeze(1).unsqueeze(1) 186 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) 187 | 188 | # # Prepare conditions 189 | # condition_latents, condition_ids = encode_images(self.flux_pipe, conditions) 190 | 191 | # # Add position delta 192 | # condition_ids[:, 1] += position_delta[0] 193 | # condition_ids[:, 2] += position_delta[1] 194 | 195 | # # Prepare condition type 196 | # condition_type_ids = torch.tensor( 197 | # [ 198 | # Condition.get_type_id(condition_type) 199 | # for condition_type in condition_types 200 | # ] 201 | # ).to(self.device) 202 | # condition_type_ids = ( 203 | # torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0] 204 | # ).unsqueeze(1) 205 | 206 | condition_latents = None 207 | condition_ids = None 208 | condition_type_ids = None 209 | 210 | # Prepare guidance 211 | guidance = ( 212 | torch.ones_like(t).to(self.device) 213 | if self.transformer.config.guidance_embeds 214 | else None 215 | ) 216 | 217 | # Forward pass 218 | transformer_out = tranformer_forward( 219 | self.transformer, 220 | # Model config 221 | model_config=self.model_config, 222 | # Inputs of the condition (new feature) 223 | condition_latents=condition_latents, 224 | condition_ids=condition_ids, 225 | condition_type_ids=condition_type_ids, 226 | # Inputs to the original transformer 227 | hidden_states=x_t, 228 | timestep=t, 229 | guidance=guidance, 230 | pooled_projections=pooled_prompt_embeds, 231 | 
encoder_hidden_states=prompt_embeds, 232 | txt_ids=text_ids, 233 | img_ids=img_ids, 234 | joint_attention_kwargs=None, 235 | return_dict=False, 236 | ) 237 | pred = transformer_out[0] 238 | 239 | # Compute loss 240 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean") 241 | self.last_t = t.mean().item() 242 | return loss 243 | -------------------------------------------------------------------------------- /train_flux/flux/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines import FluxPipeline 3 | from typing import List, Union, Optional, Dict, Any, Callable 4 | from .block import block_forward, single_block_forward 5 | from .lora_controller import enable_lora 6 | from accelerate.utils import is_torch_version 7 | from diffusers.models.transformers.transformer_flux import ( 8 | FluxTransformer2DModel, 9 | Transformer2DModelOutput, 10 | USE_PEFT_BACKEND, 11 | scale_lora_layers, 12 | unscale_lora_layers, 13 | logger, 14 | ) 15 | import numpy as np 16 | 17 | 18 | def prepare_params( 19 | hidden_states: torch.Tensor, 20 | encoder_hidden_states: torch.Tensor = None, 21 | pooled_projections: torch.Tensor = None, 22 | timestep: torch.LongTensor = None, 23 | img_ids: torch.Tensor = None, 24 | txt_ids: torch.Tensor = None, 25 | guidance: torch.Tensor = None, 26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 27 | controlnet_block_samples=None, 28 | controlnet_single_block_samples=None, 29 | return_dict: bool = True, 30 | **kwargs: dict, 31 | ): 32 | return ( 33 | hidden_states, 34 | encoder_hidden_states, 35 | pooled_projections, 36 | timestep, 37 | img_ids, 38 | txt_ids, 39 | guidance, 40 | joint_attention_kwargs, 41 | controlnet_block_samples, 42 | controlnet_single_block_samples, 43 | return_dict, 44 | ) 45 | 46 | 47 | def tranformer_forward( 48 | transformer: FluxTransformer2DModel, 49 | condition_latents: torch.Tensor, 50 | condition_ids: torch.Tensor, 51 | condition_type_ids: torch.Tensor, 52 | model_config: Optional[Dict[str, Any]] = {}, 53 | c_t=0, 54 | **params: dict, 55 | ): 56 | self = transformer 57 | use_condition = condition_latents is not None 58 | 59 | ( 60 | hidden_states, 61 | encoder_hidden_states, 62 | pooled_projections, 63 | timestep, 64 | img_ids, 65 | txt_ids, 66 | guidance, 67 | joint_attention_kwargs, 68 | controlnet_block_samples, 69 | controlnet_single_block_samples, 70 | return_dict, 71 | ) = prepare_params(**params) 72 | 73 | if joint_attention_kwargs is not None: 74 | joint_attention_kwargs = joint_attention_kwargs.copy() 75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 76 | else: 77 | lora_scale = 1.0 78 | 79 | if USE_PEFT_BACKEND: 80 | # weight the lora layers by setting `lora_scale` for each PEFT layer 81 | scale_lora_layers(self, lora_scale) 82 | else: 83 | if ( 84 | joint_attention_kwargs is not None 85 | and joint_attention_kwargs.get("scale", None) is not None 86 | ): 87 | logger.warning( 88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
89 | ) 90 | 91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): 92 | hidden_states = self.x_embedder(hidden_states) 93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None 94 | 95 | timestep = timestep.to(hidden_states.dtype) * 1000 96 | 97 | if guidance is not None: 98 | guidance = guidance.to(hidden_states.dtype) * 1000 99 | else: 100 | guidance = None 101 | 102 | temb = ( 103 | self.time_text_embed(timestep, pooled_projections) 104 | if guidance is None 105 | else self.time_text_embed(timestep, guidance, pooled_projections) 106 | ) 107 | 108 | cond_temb = ( 109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 110 | if guidance is None 111 | else self.time_text_embed( 112 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections 113 | ) 114 | ) 115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 116 | 117 | if txt_ids.ndim == 3: 118 | logger.warning( 119 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 120 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 121 | ) 122 | txt_ids = txt_ids[0] 123 | if img_ids.ndim == 3: 124 | logger.warning( 125 | "Passing `img_ids` 3d torch.Tensor is deprecated." 126 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 127 | ) 128 | img_ids = img_ids[0] 129 | 130 | ids = torch.cat((txt_ids, img_ids), dim=0) 131 | image_rotary_emb = self.pos_embed(ids) 132 | if use_condition: 133 | # condition_ids[:, :1] = condition_type_ids 134 | cond_rotary_emb = self.pos_embed(condition_ids) 135 | 136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) 137 | 138 | for index_block, block in enumerate(self.transformer_blocks): 139 | if self.training and self.gradient_checkpointing: 140 | ckpt_kwargs: Dict[str, Any] = ( 141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 142 | ) 143 | encoder_hidden_states, hidden_states, condition_latents = ( 144 | torch.utils.checkpoint.checkpoint( 145 | block_forward, 146 | self=block, 147 | model_config=model_config, 148 | hidden_states=hidden_states, 149 | encoder_hidden_states=encoder_hidden_states, 150 | condition_latents=condition_latents if use_condition else None, 151 | temb=temb, 152 | cond_temb=cond_temb if use_condition else None, 153 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 154 | image_rotary_emb=image_rotary_emb, 155 | **ckpt_kwargs, 156 | ) 157 | ) 158 | 159 | else: 160 | encoder_hidden_states, hidden_states, condition_latents = block_forward( 161 | block, 162 | model_config=model_config, 163 | hidden_states=hidden_states, 164 | encoder_hidden_states=encoder_hidden_states, 165 | condition_latents=condition_latents if use_condition else None, 166 | temb=temb, 167 | cond_temb=cond_temb if use_condition else None, 168 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 169 | image_rotary_emb=image_rotary_emb, 170 | ) 171 | 172 | # controlnet residual 173 | if controlnet_block_samples is not None: 174 | interval_control = len(self.transformer_blocks) / len( 175 | controlnet_block_samples 176 | ) 177 | interval_control = int(np.ceil(interval_control)) 178 | hidden_states = ( 179 | hidden_states 180 | + controlnet_block_samples[index_block // interval_control] 181 | ) 182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 183 | 184 | for index_block, block in enumerate(self.single_transformer_blocks): 185 | if self.training and 
self.gradient_checkpointing: 186 | ckpt_kwargs: Dict[str, Any] = ( 187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 188 | ) 189 | result = torch.utils.checkpoint.checkpoint( 190 | single_block_forward, 191 | self=block, 192 | model_config=model_config, 193 | hidden_states=hidden_states, 194 | temb=temb, 195 | image_rotary_emb=image_rotary_emb, 196 | **( 197 | { 198 | "condition_latents": condition_latents, 199 | "cond_temb": cond_temb, 200 | "cond_rotary_emb": cond_rotary_emb, 201 | } 202 | if use_condition 203 | else {} 204 | ), 205 | **ckpt_kwargs, 206 | ) 207 | 208 | else: 209 | result = single_block_forward( 210 | block, 211 | model_config=model_config, 212 | hidden_states=hidden_states, 213 | temb=temb, 214 | image_rotary_emb=image_rotary_emb, 215 | **( 216 | { 217 | "condition_latents": condition_latents, 218 | "cond_temb": cond_temb, 219 | "cond_rotary_emb": cond_rotary_emb, 220 | } 221 | if use_condition 222 | else {} 223 | ), 224 | ) 225 | if use_condition: 226 | hidden_states, condition_latents = result 227 | else: 228 | hidden_states = result 229 | 230 | # controlnet residual 231 | if controlnet_single_block_samples is not None: 232 | interval_control = len(self.single_transformer_blocks) / len( 233 | controlnet_single_block_samples 234 | ) 235 | interval_control = int(np.ceil(interval_control)) 236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 238 | + controlnet_single_block_samples[index_block // interval_control] 239 | ) 240 | 241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 242 | 243 | hidden_states = self.norm_out(hidden_states, temb) 244 | output = self.proj_out(hidden_states) 245 | 246 | if USE_PEFT_BACKEND: 247 | # remove `lora_scale` from each PEFT layer 248 | unscale_lora_layers(self, lora_scale) 249 | 250 | if not return_dict: 251 | return (output,) 252 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /tts/verifiers/geneval_detailed_verifier_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "single_object": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `object_completeness`, `detectability`, `occlusion_handling`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Object Completeness:\nAssess the structural integrity of the object (no defects/deformations) and the clarity and legibility of its details.\nScore: 0 (severely fragmented) to 10 (perfectly intact).\n\nb) Detectability:\nEvaluate the distinction and visual saliency of objects and backgrounds using contrast analysis.\nScore: 0 (camouflaged) to 10 (immediately noticeable).\n\nc) Occlusion Handling:\nAssess whether there is unreasonable occlusion (natural occlusion should keep the subject visible).\nScore: 0 (key parts are blocked) to 10 (no blockage, or natural and reasonable blockage).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 3 | 4 | "two_object": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `separation_clarity`, `individual_completeness`, `relationship_accuracy`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Separation Clarity:\nAssess the spatial separation and boundary clarity of the two objects.\nScore: 0 (fully overlapped) to 10 (completely separate, with clearly defined boundaries).\n\nb) Individual Completeness:\nEvaluate each object's individual integrity and detail retention.\nScore: 0 (both objects are incomplete) to 10 (both objects are complete).\n\nc) Relationship Accuracy:\nAssess whether the size proportions between the two objects are physically reasonable.\nScore: 0 (wrong proportions) to 10 (perfectly in line with physical laws).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 5 | 6 | "counting": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `count_accuracy`, `object_uniformity`, `spatial_legibility`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Count Accuracy:\nAssess whether the number of generated objects exactly matches the prompt.\nScore: 0 (number wrong) to 10 (number correct)\n\nb) Object Uniformity:\nEvaluate the consistency of shape/size/color among objects of the same kind.\nScore: 0 (same kind but totally different shape/size/color) to 10 (same kind and same shape/size/color).\n\nc) Spatial Legibility:\nEvaluate the plausibility and visibility of the object distribution (no excessive overlap).\nScore: 0 (heavily overlapped) to 10 (perfectly displayed and all easily seen).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 7 | 8 | "colors": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `color_fidelity`, `contrast_effectiveness`, `multi_object_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Color Fidelity:\nAssess the exact match between the object color and the input prompt.\nScore: 0 (color wrong) to 10 (color correct)\n\nb) Contrast Effectiveness:\nEvaluate the difference between foreground and background colors.\nScore: 0 (similar colors, difficult to distinguish) to 10 (high contrast).\n\nc) Multi-Object Consistency:\nAssess color consistency across multiple objects of the same kind.\nScore: 0 (same kind of objects with totally different colors) to 10 (same kind with same color).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 9 | 10 | "position": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `position_accuracy`, `occlusion_management`, `perspective_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Positional Accuracy:\nAssess how accurately the spatial positions match the prompt description.\nScore: 0 (totally wrong) to 10 (position correct)\n\nb) Occlusion Management:\nEvaluate whether positions remain discernible in the presence of occlusion.\nScore: 0 (fully occluded) to 10 (relationship clearly displayed).\n\nc) Perspective Consistency:\nAssess the rationality of the perspective relationships and spatial depth.\nScore: 0 (perspective contradiction) to 10 (completely reasonable).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 11 | 12 | "color_attr": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `attribute_binding`, `contrast_effectiveness`, `material_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Attribute Binding:\nAssess the correct binding of colors to their designated objects (no color mismatches).\nScore: 0 (color mismatch) to 10 (correct binding)\n\nb) Contrast Effectiveness:\nEvaluate the difference between foreground and background colors.\nScore: 0 (similar colors, difficult to distinguish) to 10 (high contrast).\n\nc) Material Consistency:\nAssess the coordination of color and material appearance.\nScore: 0 (material conflicts) to 10 (perfect harmony).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects." 13 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 
36 |
37 |
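
The training loss near the top of this section regresses the transformer output onto `x_1 - x_0` with a plain MSE, i.e. a flow-matching (rectified-flow) velocity target along the straight path between the two latent endpoints. Below is a minimal, self-contained sketch of that objective; the linear interpolation used for `x_t`, the per-sample timestep sampling, and the `model(...)` call signature are assumptions, since that part of the training code is not reproduced in this excerpt.

```python
import torch


def flow_matching_loss(model, x_0, x_1, cond):
    """Sketch of a straight-path velocity-matching loss (assumed convention)."""
    # Sample one timestep per example and broadcast it over the remaining dims.
    t = torch.rand(x_0.shape[0], device=x_0.device).view(-1, *([1] * (x_0.dim() - 1)))
    # Linear interpolation between the two endpoints.
    x_t = (1.0 - t) * x_0 + t * x_1
    # The model is assumed to predict the constant velocity of the path x_0 -> x_1.
    v_pred = model(x_t, t.flatten(), cond)
    return torch.nn.functional.mse_loss(v_pred, x_1 - x_0)
```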
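
Each rubric in `tts/verifiers/geneval_detailed_verifier_prompt.json` instructs the verifier to return a JSON object with per-aspect scores and an `overall_score`. The sketch below shows one way such a rubric could be loaded and a verifier reply reduced to a single number; `query_model` is a hypothetical stand-in for the repository's OpenAI/NVILA verifier call, which is not part of this excerpt.

```python
import json
from statistics import mean
from typing import Callable, Dict


def load_rubric(path: str, tag: str) -> str:
    """Return the evaluation prompt for one GenEval category (e.g. 'counting')."""
    with open(path, "r", encoding="utf-8") as f:
        rubrics: Dict[str, str] = json.load(f)
    return rubrics[tag]


def overall_score(verifier_reply: str) -> float:
    """Parse the verifier's JSON reply; fall back to the mean of aspect scores."""
    scores = json.loads(verifier_reply)
    if "overall_score" in scores:
        return float(scores["overall_score"])
    return mean(float(v) for v in scores.values())


def score_image(query_model: Callable[[str, str], str], image_path: str, tag: str) -> float:
    """`query_model(rubric, image_path)` is assumed to return the raw JSON reply."""
    rubric = load_rubric("tts/verifiers/geneval_detailed_verifier_prompt.json", tag)
    return overall_score(query_model(rubric, image_path))
```

Under these assumptions, `score_image(query_model, "sample.png", "counting")` would return the averaged rubric score for a counting-style prompt.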