├── train_flux ├── flux │ ├── __init__.py │ ├── pipeline_tools.py │ ├── lora_controller.py │ ├── condition.py │ ├── transformer.py │ ├── generate.py │ └── block.py ├── train │ ├── __init__.py │ ├── callbacks.py │ ├── train.py │ ├── data.py │ └── model.py ├── train.sh ├── sample.sh ├── config_prompt.yaml ├── config.yaml └── sample.py ├── tts ├── verifiers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── nvila_verifier.cpython-310.pyc │ │ └── openai_verifier.cpython-310.pyc │ ├── nvila_verifier.py │ ├── reflexion_prompt.txt │ ├── refine_prompt.txt │ ├── verifier_prompt.txt │ ├── geneval_verifier_prompt.txt │ └── geneval_detailed_verifier_prompt.json ├── configs │ ├── our_reflectionmodel.yaml │ ├── flux.1_dev_nvilascore.json │ └── flux.1_dev_gptscore.json ├── tts_t2i_noise_scaling.py ├── utils.py ├── verifier_filter.py └── tts_t2i_noise_prompt_scaling.py ├── examples └── teaser.jpg ├── reward_modeling ├── __pycache__ │ ├── data.cpython-310.pyc │ ├── data.cpython-312.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-312.pyc │ ├── trainer.cpython-310.pyc │ ├── trainer.cpython-312.pyc │ ├── inference.cpython-310.pyc │ ├── test_reward.cpython-310.pyc │ ├── train_reward.cpython-310.pyc │ ├── prompt_template.cpython-310.pyc │ ├── prompt_template.cpython-312.pyc │ ├── vision_process.cpython-310.pyc │ └── vision_process.cpython-312.pyc ├── test_reward.py ├── prompt_template.py ├── data.py ├── inference.py ├── utils.py └── train_reward.py ├── requirements.txt ├── .gitignore └── README.md /train_flux/flux/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train_flux/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tts/verifiers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/examples/teaser.jpg -------------------------------------------------------------------------------- /tts/configs/our_reflectionmodel.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: path/infer/30000 2 | template: qwen2_vl 3 | finetuning_type: lora -------------------------------------------------------------------------------- /reward_modeling/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/data.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/data.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/trainer.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/trainer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/trainer.cpython-312.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/inference.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/inference.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/test_reward.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/test_reward.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/train_reward.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/train_reward.cpython-310.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/nvila_verifier.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/nvila_verifier.cpython-310.pyc -------------------------------------------------------------------------------- /tts/verifiers/__pycache__/openai_verifier.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/tts/verifiers/__pycache__/openai_verifier.cpython-310.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | transformers 3 | peft 4 | opencv-python 5 | sentencepiece 6 | lightning 7 | datasets 8 | torchvision 9 | prodigyopt 10 | wandb 11 | webdataset 
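The list above pins no versions; it only names the packages. A small optional sanity check (not part of the repository) that the corresponding import names resolve after installing requirements.txt — note that opencv-python imports as cv2:

import importlib.util

# Import names for the packages listed in requirements.txt (opencv-python -> cv2).
names = ["diffusers", "transformers", "peft", "cv2", "sentencepiece", "lightning",
         "datasets", "torchvision", "prodigyopt", "wandb", "webdataset"]
missing = [n for n in names if importlib.util.find_spec(n) is None]
print("missing packages:", missing if missing else "none")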
-------------------------------------------------------------------------------- /reward_modeling/__pycache__/prompt_template.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/prompt_template.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/prompt_template.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/prompt_template.cpython-312.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/vision_process.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/vision_process.cpython-310.pyc -------------------------------------------------------------------------------- /reward_modeling/__pycache__/vision_process.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Diffusion-CoT/ReflectionFlow/HEAD/reward_modeling/__pycache__/vision_process.cpython-312.pyc -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.batch 3 | 4 | train_flux/run/ 5 | train_flux/wandb/ 6 | train_flux/train/__pycache__ 7 | train_flux/flux/__pycache__ 8 | 9 | tts/dcgm/ 10 | tts/output/ 11 | tts/__pycache__/ 12 | -------------------------------------------------------------------------------- /train_flux/train.sh: -------------------------------------------------------------------------------- 1 | conda init 2 | source ~/.bashrc 3 | conda activate /ceph/data-bk/miniconda3/envs/flux 4 | 5 | cd /ceph/data-bk/zl/ReflectionFlow/train_flux 6 | 7 | export TOKENIZERS_PARALLELISM=true 8 | accelerate launch --main_process_port 41353 -m train.train 9 | -------------------------------------------------------------------------------- /train_flux/sample.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | CUDA_VISIBLE_DEVICES=2 python sample.py \ 4 | --model_name flux \ 5 | --step 30 \ 6 | --condition_size 512 \ 7 | --target_size 1024 \ 8 | --task_name geneval \ 9 | --lora_dir /mnt/petrelfs/zhuole/ReflectionFlow/train_flux/runs/full_data_v4_cond_512/20250410-191141/ckpt/16000/pytorch_lora_weights.safetensors \ 10 | --output_dir /mnt/petrelfs/zhuole/ReflectionFlow/train_flux/samples/full_data_v4_512_16k \ 11 | --seed 0 \ 12 | -------------------------------------------------------------------------------- /tts/verifiers/nvila_verifier.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel 2 | 3 | # nvila verifier 4 | def load_model(model_name, cache_dir): 5 | print("loading NVILA model") 6 | model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="auto", cache_dir = cache_dir) 7 | yes_id = model.tokenizer.encode("yes", add_special_tokens=False)[0] 8 | no_id = model.tokenizer.encode("no", add_special_tokens=False)[0] 9 | print("loading NVILA finished") 10 | return model, yes_id, no_id 
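For reference (not part of this file): the verifier returned by load_model() is consumed later in tts/verifier_filter.py by calling its generate_content() method, which comes from the NVILA remote code loaded with trust_remote_code=True, and by reading the logit of the "yes" token as the image score. A minimal, hypothetical usage sketch — the image path and prompt below are placeholders:

from PIL import Image

# Load the verifier as configured in tts/configs/flux.1_dev_nvilascore.json.
model, yes_id, no_id = load_model(
    "Efficient-Large-Model/NVILA-Lite-2B-Verifier", cache_dir="models/nvila"
)

# generate_content() returns a yes/no answer plus raw scores; the yes-token
# logit is extracted the same way as in tts/verifier_filter.py.
answer, scores = model.generate_content([Image.open("sample.png"), "a photo of a red apple"])
yes_score = scores[0][0, yes_id].detach().cpu().float().item()
print(answer, yes_score)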
-------------------------------------------------------------------------------- /tts/verifiers/reflexion_prompt.txt: -------------------------------------------------------------------------------- 1 | You are an expert assistant for generating image improvement instructions. Analyze the original prompt, the updated prompt to generate the image, the evaluation of the generated image, and the generated image, give instructions to create specific technical directions following these guidelines:\n1. Structure and Focus Areas:\nFocus strictly on these three aspects in order: Prompt Following.\n2. Detailed Requirements for Each Aspect:\nA. Prompt Following Instructions:\nExamine the original prompt sentence by sentence. \nList exact discrepancies between the bad image and prompt specifications. \nUse direct action verbs: Add, Remove, Replace, Reposition, Adjust, to modify the image. \nSpecify precise locations and modification commands. \nNever use vague terms like ensure or confirm.\n3. Format Specifications:\nUse exact section headers without markdown:\n1. Prompt Following:\n-\n Each instruction must start with a hyphen and complete command. \nInclude spatial references and implementation details. \nOmit sections with no required improvements. \nNever include explanations or examples.\n\n4. Content Principles:\nEvery instruction must be directly executable by an artist. \nPrioritize critical errors first. \nDescribe only missing or incorrect elements. \nUse imperative verb forms exclusively. \nMaintain technical specificity without assumptions.\n -------------------------------------------------------------------------------- /tts/configs/flux.1_dev_nvilascore.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "pipeline_args": { 4 | "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", 5 | "cache_dir": "FLUX_PATH", 6 | "torch_dtype": "bf16", 7 | "height": 1024, 8 | "width": 1024, 9 | "condition_size": 512, 10 | "max_sequence_length": 512, 11 | "guidance_scale": 3.5, 12 | "num_inference_steps": 30, 13 | "lora_path": "LORA_PATH" 14 | }, 15 | "verifier_args": { 16 | "name": "nvila", 17 | "model_name": "Efficient-Large-Model/NVILA-Lite-2B-Verifier", 18 | "cache_dir": "models/nvila" 19 | }, 20 | "refine_args": { 21 | "name": "openai", 22 | "choice_of_metric": "overall_score", 23 | "max_new_tokens": 1280, 24 | "refine_prompt_relpath": "refine_prompt.txt", 25 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 26 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 27 | }, 28 | "search_args": { 29 | "search_method": "random", 30 | "search_branch": 2, 31 | "search_rounds": 16 32 | }, 33 | "model": { 34 | "add_cond_attn": false, 35 | "latent_lora": false, 36 | "union_cond_attn": true 37 | }, 38 | "reflection_args": { 39 | "run_reflection": true, 40 | "name": "openai" 41 | }, 42 | "prompt_refiner_args": { 43 | "run_refinement": true 44 | }, 45 | "use_low_gpu_vram": false, 46 | "batch_size_for_img_gen": 1 47 | } -------------------------------------------------------------------------------- /tts/verifiers/refine_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with refining user's input prompt to 3 | create images using a text-to-image model. 
Given a original prompt, a current prompt, 4 | a batch of images generated by the prompt, a reflection prompt about the generated images and their corresponding assessments 5 | evaluated by a multi-domain scoring system, your goal is to refine the current prompt to 6 | improve the overall quality of the generated images. You should analyze the strengths 7 | and drawbacks of current prompt based on the given images and their evaluations. Consider 8 | aspects like subject, scene, style, lighting, tone, mood, camera style, composition, and 9 | others to refine the current prompt. Do not alter the original description from 10 | the original prompt. The refined prompt should not contradict with the reflection prompt. Directly output the refined prompt without any other text. 11 | 12 | Some further instructions you should keep in mind: 13 | 14 | 1) The current prompt is an iterative refinement of the original prompt. 15 | 16 | 2) In case the original prompt and current prompt are the same, ignore the current prompt. 17 | 18 | 3) In some cases, some of the above-mentioned inputs may not be available. For example, the images, 19 | the assessments, etc. In such situations, you should still do your best, analyze the inputs carefully, 20 | and arrive at a refined prompt that would potentially lead to improvements in the final generated images. 21 | 22 | 4) When the evaluations are provided, please consider all aspects of the evaluations very carefully. 23 | """ -------------------------------------------------------------------------------- /tts/configs/flux.1_dev_gptscore.json: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "pipeline_args": { 4 | "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev", 5 | "cache_dir": "FLUX_PATH", 6 | "torch_dtype": "bf16", 7 | "height": 1024, 8 | "width": 1024, 9 | "condition_size": 512, 10 | "max_sequence_length": 512, 11 | "guidance_scale": 3.5, 12 | "num_inference_steps": 30, 13 | "lora_path": "LORA_PATH" 14 | }, 15 | "verifier_args": { 16 | "name": "openai", 17 | "choice_of_metric": "overall_score", 18 | "max_new_tokens": 1280, 19 | "refine_prompt_relpath": "refine_prompt.txt", 20 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 21 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 22 | }, 23 | "refine_args": { 24 | "name": "openai", 25 | "choice_of_metric": "overall_score", 26 | "max_new_tokens": 1280, 27 | "refine_prompt_relpath": "refine_prompt.txt", 28 | "reflexion_prompt_relpath": "reflexion_prompt.txt", 29 | "verifier_prompt_relpath": "geneval_detailed_verifier_prompt.json" 30 | }, 31 | "search_args": { 32 | "search_method": "random", 33 | "search_branch": 2, 34 | "search_rounds": 16 35 | }, 36 | "model": { 37 | "add_cond_attn": false, 38 | "latent_lora": false, 39 | "union_cond_attn": true 40 | }, 41 | "reflection_args": { 42 | "run_reflection": true, 43 | "name": "openai" 44 | }, 45 | "prompt_refiner_args": { 46 | "run_refinement": true 47 | }, 48 | "use_low_gpu_vram": false, 49 | "batch_size_for_img_gen": 1 50 | } -------------------------------------------------------------------------------- /train_flux/flux/pipeline_tools.py: -------------------------------------------------------------------------------- 1 | from diffusers.pipelines import FluxPipeline 2 | from diffusers.utils import logging 3 | from diffusers.pipelines.flux.pipeline_flux import logger 4 | from torch import Tensor 5 | 6 | 7 | def encode_images(pipeline: FluxPipeline, images: Tensor): 8 | images = 
pipeline.image_processor.preprocess(images) 9 | images = images.to(pipeline.device).to(pipeline.dtype) 10 | images = pipeline.vae.encode(images).latent_dist.sample() 11 | images = ( 12 | images - pipeline.vae.config.shift_factor 13 | ) * pipeline.vae.config.scaling_factor 14 | images_tokens = pipeline._pack_latents(images, *images.shape) 15 | images_ids = pipeline._prepare_latent_image_ids( 16 | images.shape[0], 17 | images.shape[2], 18 | images.shape[3], 19 | pipeline.device, 20 | pipeline.dtype, 21 | ) 22 | if images_tokens.shape[1] != images_ids.shape[0]: 23 | images_ids = pipeline._prepare_latent_image_ids( 24 | images.shape[0], 25 | images.shape[2] // 2, 26 | images.shape[3] // 2, 27 | pipeline.device, 28 | pipeline.dtype, 29 | ) 30 | return images_tokens, images_ids 31 | 32 | 33 | def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512, prompts_2=None): 34 | # Turn off warnings (CLIP overflow) 35 | logger.setLevel(logging.ERROR) 36 | ( 37 | prompt_embeds, 38 | pooled_prompt_embeds, 39 | text_ids, 40 | ) = pipeline.encode_prompt( 41 | prompt=prompts, 42 | prompt_2=prompts_2, 43 | prompt_embeds=None, 44 | pooled_prompt_embeds=None, 45 | device=pipeline.device, 46 | num_images_per_prompt=1, 47 | max_sequence_length=max_sequence_length, 48 | lora_scale=None, 49 | ) 50 | # Turn on warnings 51 | logger.setLevel(logging.WARNING) 52 | return prompt_embeds, pooled_prompt_embeds, text_ids 53 | -------------------------------------------------------------------------------- /train_flux/config_prompt.yaml: -------------------------------------------------------------------------------- 1 | model_path: "black-forest-labs/FLUX.1-dev" 2 | dtype: "bfloat16" 3 | cache_dir: "CACHE_DIR" 4 | 5 | model: 6 | union_cond_attn: true 7 | add_cond_attn: false 8 | latent_lora: false 9 | 10 | train: 11 | batch_size: 8 12 | accumulate_grad_batches: 1 13 | dataloader_workers: 8 14 | save_interval: 2000 15 | sample_interval: 2000 16 | max_steps: -1 17 | gradient_checkpointing: true 18 | save_path: "./runs/prompt_cond_512" 19 | 20 | # Specify the type of condition to use. 21 | condition_type: "cot" 22 | resume_training_from_last_checkpoint: false 23 | resume_training_from_checkpoint_path: "" 24 | dataset: 25 | type: "img" 26 | path: "/ceph/hf-cache/hub/datasets--diffusion-cot--GenRef-wds/snapshots/42c837b891fc34a944ed0c8124876a7e8225266f/*.tar" 27 | split_ratios: { 28 | "general": [0.1, 0.3], 29 | "length": [0.1, 0.3], 30 | "rule": [0.1, 0.4], 31 | "editing": [0.7, 0.0] 32 | } 33 | training_stages: [0, 5000] 34 | root_dir: "" 35 | # val_path: { 36 | # "general": "VAL_TARS" 37 | # } 38 | # val_root_dir: "" 39 | condition_size: 512 40 | target_size: 1024 41 | drop_text_prob: 0.1 42 | drop_image_prob: 0.1 43 | drop_reflection_prob: 1.0 44 | 45 | wandb: 46 | project: "ReflectionFlow" 47 | name: "prompt_cond_512" 48 | 49 | lora_config: 50 | r: 32 51 | lora_alpha: 32 52 | init_lora_weights: "gaussian" 53 | target_modules: "(.*x_embedder|.*(? 
None: 7 | self.activated: bool = activated 8 | if activated: 9 | return 10 | self.lora_modules: List[BaseTunerLayer] = [ 11 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 12 | ] 13 | self.scales = [ 14 | { 15 | active_adapter: lora_module.scaling[active_adapter] 16 | for active_adapter in lora_module.active_adapters 17 | } 18 | for lora_module in self.lora_modules 19 | ] 20 | 21 | def __enter__(self) -> None: 22 | if self.activated: 23 | return 24 | 25 | for lora_module in self.lora_modules: 26 | if not isinstance(lora_module, BaseTunerLayer): 27 | continue 28 | lora_module.scale_layer(0) 29 | 30 | def __exit__( 31 | self, 32 | exc_type: Optional[Type[BaseException]], 33 | exc_val: Optional[BaseException], 34 | exc_tb: Optional[Any], 35 | ) -> None: 36 | if self.activated: 37 | return 38 | for i, lora_module in enumerate(self.lora_modules): 39 | if not isinstance(lora_module, BaseTunerLayer): 40 | continue 41 | for active_adapter in lora_module.active_adapters: 42 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 43 | 44 | 45 | class set_lora_scale: 46 | def __init__(self, lora_modules: List[BaseTunerLayer], scale: float) -> None: 47 | self.lora_modules: List[BaseTunerLayer] = [ 48 | each for each in lora_modules if isinstance(each, BaseTunerLayer) 49 | ] 50 | self.scales = [ 51 | { 52 | active_adapter: lora_module.scaling[active_adapter] 53 | for active_adapter in lora_module.active_adapters 54 | } 55 | for lora_module in self.lora_modules 56 | ] 57 | self.scale = scale 58 | 59 | def __enter__(self) -> None: 60 | for lora_module in self.lora_modules: 61 | if not isinstance(lora_module, BaseTunerLayer): 62 | continue 63 | lora_module.scale_layer(self.scale) 64 | 65 | def __exit__( 66 | self, 67 | exc_type: Optional[Type[BaseException]], 68 | exc_val: Optional[BaseException], 69 | exc_tb: Optional[Any], 70 | ) -> None: 71 | for i, lora_module in enumerate(self.lora_modules): 72 | if not isinstance(lora_module, BaseTunerLayer): 73 | continue 74 | for active_adapter in lora_module.active_adapters: 75 | lora_module.scaling[active_adapter] = self.scales[i][active_adapter] 76 | -------------------------------------------------------------------------------- /train_flux/train/callbacks.py: -------------------------------------------------------------------------------- 1 | import lightning as L 2 | from PIL import Image, ImageFilter, ImageDraw 3 | import numpy as np 4 | from transformers import pipeline 5 | import cv2 6 | import torch 7 | import os 8 | 9 | try: 10 | import wandb 11 | except ImportError: 12 | wandb = None 13 | 14 | import time 15 | 16 | 17 | class TrainingCallback(L.Callback): 18 | def __init__(self, run_name, training_config: dict = {}): 19 | self.run_name, self.training_config = run_name, training_config 20 | 21 | self.print_every_n_steps = training_config.get("print_every_n_steps", 10) 22 | self.save_interval = training_config.get("save_interval", 1000) 23 | self.save_path = training_config.get("save_path", "./output") 24 | 25 | self.wandb_config = training_config.get("wandb", None) 26 | self.use_wandb = ( 27 | wandb is not None and os.environ.get("WANDB_API_KEY") is not None 28 | ) 29 | 30 | self.total_steps = 0 31 | 32 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 33 | gradient_size = 0 34 | max_gradient_size = 0 35 | count = 0 36 | for _, param in pl_module.named_parameters(): 37 | if param.grad is not None: 38 | gradient_size += param.grad.norm(2).item() 39 | max_gradient_size = 
max(max_gradient_size, param.grad.norm(2).item()) 40 | count += 1 41 | if count > 0: 42 | gradient_size /= count 43 | 44 | self.total_steps += 1 45 | 46 | # Update split ratios 47 | trainer.train_dataloader.dataset._update_split_ratios() 48 | 49 | # Print training progress every n steps 50 | if self.use_wandb: 51 | report_dict = { 52 | "steps": batch_idx, 53 | "steps": self.total_steps, 54 | "epoch": trainer.current_epoch, 55 | "gradient_size": gradient_size, 56 | } 57 | loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches 58 | report_dict["loss"] = loss_value 59 | report_dict["t"] = pl_module.last_t 60 | wandb.log(report_dict) 61 | 62 | if self.total_steps % self.print_every_n_steps == 0: 63 | print( 64 | f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, Loss: {pl_module.log_loss:.4f}, Gradient size: {gradient_size:.4f}, Max gradient size: {max_gradient_size:.4f}" 65 | ) 66 | 67 | # Save LoRA weights at specified intervals 68 | if self.total_steps % self.save_interval == 0: 69 | print( 70 | f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving LoRA weights" 71 | ) 72 | pl_module.save_lora( 73 | f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}" 74 | ) -------------------------------------------------------------------------------- /tts/verifiers/verifier_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with evaluating images 3 | generated by a text-to-image model. Your goal is to assess each generated 4 | image based on specific aspects and provide a detailed critique, along with 5 | a scoring system. The final output should be formatted as a JSON object 6 | containing individual scores for each aspect and an overall score. The keys 7 | in the JSON object should be: `accuracy_to_prompt`, `creativity_and_originality`, 8 | `visual_quality_and_realism`, `consistency_and_cohesion`, 9 | `emotional_or_thematic_resonance`, and `overall_score`. Below is a comprehensive 10 | guide to follow in your evaluation process: 11 | 12 | 1. Key Evaluation Aspects and Scoring Criteria: 13 | For each aspect, provide a score from 0 to 10, where 0 represents poor 14 | performance and 10 represents excellent performance. For each score, include 15 | a short explanation or justification (1-2 sentences) explaining why that 16 | score was given. The aspects to evaluate are as follows: 17 | 18 | a) Accuracy to Prompt 19 | Assess how well the image matches the description given in the prompt. 20 | Consider whether all requested elements are present and if the scene, 21 | objects, and setting align accurately with the text. Score: 0 (no 22 | alignment) to 10 (perfect match to prompt). 23 | 24 | b) Creativity and Originality 25 | Evaluate the uniqueness and creativity of the generated image. Does the 26 | model present an imaginative or aesthetically engaging interpretation of the 27 | prompt? Is there any evidence of creativity beyond a literal interpretation? 28 | Score: 0 (lacks creativity) to 10 (highly creative and original). 29 | 30 | c) Visual Quality and Realism 31 | Assess the overall visual quality, including resolution, detail, and realism. 32 | Look for coherence in lighting, shading, and perspective. Even if the image 33 | is stylized or abstract, judge whether the visual elements are well-rendered 34 | and visually appealing. Score: 0 (poor quality) to 10 (high-quality and 35 | realistic). 
36 | 37 | d) Consistency and Cohesion 38 | Check for internal consistency within the image. Are all elements cohesive 39 | and aligned with the prompt? For instance, does the perspective make sense, 40 | and do objects fit naturally within the scene without visual anomalies? 41 | Score: 0 (inconsistent) to 10 (fully cohesive and consistent). 42 | 43 | e) Emotional or Thematic Resonance 44 | Evaluate how well the image evokes the intended emotional or thematic tone of 45 | the prompt. For example, if the prompt is meant to be serene, does the image 46 | convey calmness? If it’s adventurous, does it evoke excitement? Score: 0 47 | (no resonance) to 10 (strong resonance with the prompt’s theme). 48 | 49 | 2. Overall Score 50 | After scoring each aspect individually, provide an overall score, 51 | representing the model’s general performance on this image. This should be 52 | a weighted average based on the importance of each aspect to the prompt or an 53 | average of all aspects. 54 | """ -------------------------------------------------------------------------------- /tts/verifiers/geneval_verifier_prompt.txt: -------------------------------------------------------------------------------- 1 | """ 2 | You are a multimodal large-language model tasked with evaluating images 3 | generated by a text-to-image model. Your goal is to assess each generated 4 | image based on specific aspects and provide a detailed critique, along with 5 | a scoring system. The final output should be formatted as a JSON object 6 | containing individual scores for each aspect and an overall score. The keys 7 | in the JSON object should be: `single_object`, `two_object`, `counting`, 8 | `colors`, `position`, `color_attr`, and `overall_score`. Below is a comprehensive 9 | guide to follow in your evaluation process: 10 | 11 | 1. Key Evaluation Aspects and Scoring Criteria: 12 | For each aspect, provide a score from 0 to 10, where 0 represents poor 13 | performance and 10 represents excellent performance. For each score, include 14 | a short explanation or justification (1-2 sentences) explaining why that 15 | score was given. The aspects to evaluate are as follows: 16 | 17 | a) Single Object 18 | Assess the completeness and detectability of individual objects. Consider 19 | structural integrity, absence of deformations, and clear visibility against 20 | the background. Score: 0 (unrecognizable fragment) to 10 (perfectly defined 21 | and isolated object). 22 | 23 | b) Two Objects 24 | Evaluate separation quality and dual integrity. Check for clear boundaries 25 | between objects, appropriate spatial relationship, and preservation of 26 | distinct features for both elements. Score: 0 (merged/blended objects) to 27 | 10 (perfectly separated with individual clarity). 28 | 29 | c) Counting 30 | Verify accurate quantity representation and distinctiveness. Assess whether 31 | the correct number of objects are present, clearly distinguishable without 32 | overlap or occlusion. Score: 0 (critical counting errors) to 10 (exact count 33 | with unambiguous instances). 34 | 35 | d) Colors 36 | Evaluate color fidelity and contrast. Check if object colors match the prompt 37 | specifications and maintain consistency across multiple instances, with 38 | sufficient contrast against background. Score: 0 (color pollution/confusion) 39 | to 10 (precise color matching with high contrast). 40 | 41 | e) Position 42 | Analyze spatial relationships and occlusion handling. 
Verify if positional 43 | descriptors (left/right, above/below) are accurately represented, even with 44 | partial overlaps. Score: 0 (contradictory positioning) to 10 (pixel-perfect 45 | spatial alignment). 46 | 47 | f) Color-Attribute Binding 48 | Assess color-object association accuracy. Check if specific colors are 49 | correctly applied to designated objects and maintain differentiation between 50 | distinct elements. Score: 0 (color-attribute mismatch) to 10 (perfect 51 | color-object binding). 52 | 53 | 2. Overall Score 54 | After scoring each aspect individually, provide an overall score calculated 55 | as: 56 | (0.3×Single Object) + (0.2×Two Objects) + (0.15×Counting) + (0.15×Colors) 57 | + (0.1×Position) + (0.1×Color-Attribute Binding). Include a brief rationale 58 | for the weighting approach. 59 | """ -------------------------------------------------------------------------------- /train_flux/flux/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union, List, Tuple 3 | from diffusers.pipelines import FluxPipeline 4 | from PIL import Image, ImageFilter 5 | import numpy as np 6 | import cv2 7 | 8 | from .pipeline_tools import encode_images 9 | 10 | condition_dict = { 11 | "depth": 0, 12 | "canny": 1, 13 | "subject": 4, 14 | "coloring": 6, 15 | "deblurring": 7, 16 | "depth_pred": 8, 17 | "fill": 9, 18 | "sr": 10, 19 | "cartoon": 11, 20 | "cot": 12, 21 | } 22 | 23 | 24 | class Condition(object): 25 | def __init__( 26 | self, 27 | condition_type: str, 28 | raw_img: Union[Image.Image, torch.Tensor] = None, 29 | condition: Union[Image.Image, torch.Tensor] = None, 30 | mask=None, 31 | position_delta=None, 32 | ) -> None: 33 | self.condition_type = condition_type 34 | assert raw_img is not None or condition is not None 35 | if raw_img is not None: 36 | self.condition = self.get_condition(condition_type, raw_img) 37 | else: 38 | self.condition = condition 39 | self.position_delta = position_delta 40 | # TODO: Add mask support 41 | assert mask is None, "Mask not supported yet" 42 | 43 | def get_condition( 44 | self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor] 45 | ) -> Union[Image.Image, torch.Tensor]: 46 | """ 47 | Returns the condition image. 48 | """ 49 | if condition_type == "depth": 50 | from transformers import pipeline 51 | 52 | depth_pipe = pipeline( 53 | task="depth-estimation", 54 | model="LiheYoung/depth-anything-small-hf", 55 | device="cuda", 56 | ) 57 | source_image = raw_img.convert("RGB") 58 | condition_img = depth_pipe(source_image)["depth"].convert("RGB") 59 | return condition_img 60 | elif condition_type == "canny": 61 | img = np.array(raw_img) 62 | edges = cv2.Canny(img, 100, 200) 63 | edges = Image.fromarray(edges).convert("RGB") 64 | return edges 65 | elif condition_type == "subject": 66 | return raw_img 67 | elif condition_type == "coloring": 68 | return raw_img.convert("L").convert("RGB") 69 | elif condition_type == "deblurring": 70 | condition_image = ( 71 | raw_img.convert("RGB") 72 | .filter(ImageFilter.GaussianBlur(10)) 73 | .convert("RGB") 74 | ) 75 | return condition_image 76 | elif condition_type == "fill": 77 | return raw_img.convert("RGB") 78 | elif condition_type == "cartoon": 79 | return raw_img.convert("RGB") 80 | return self.condition 81 | 82 | @property 83 | def type_id(self) -> int: 84 | """ 85 | Returns the type id of the condition. 
86 | """ 87 | return condition_dict[self.condition_type] 88 | 89 | @classmethod 90 | def get_type_id(cls, condition_type: str) -> int: 91 | """ 92 | Returns the type id of the condition. 93 | """ 94 | return condition_dict[condition_type] 95 | 96 | def encode( 97 | self, pipe: FluxPipeline, empty: bool = False 98 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 99 | """ 100 | Encodes the condition into tokens, ids and type_id. 101 | """ 102 | if self.condition_type in [ 103 | "depth", 104 | "canny", 105 | "subject", 106 | "coloring", 107 | "deblurring", 108 | "depth_pred", 109 | "fill", 110 | "sr", 111 | "cartoon", 112 | "cot", 113 | ]: 114 | if empty: 115 | # make the condition black 116 | e_condition = Image.new("RGB", self.condition.size, (0, 0, 0)) 117 | e_condition = e_condition.convert("RGB") 118 | tokens, ids = encode_images(pipe, e_condition) 119 | else: 120 | tokens, ids = encode_images(pipe, self.condition) 121 | tokens, ids = encode_images(pipe, self.condition) 122 | else: 123 | raise NotImplementedError( 124 | f"Condition type {self.condition_type} not implemented" 125 | ) 126 | if self.position_delta is None and self.condition_type == "subject": 127 | self.position_delta = [0, -self.condition.size[0] // 16] 128 | if self.position_delta is not None: 129 | ids[:, 1] += self.position_delta[0] 130 | ids[:, 2] += self.position_delta[1] 131 | type_id = torch.ones_like(ids[:, :1]) * self.type_id 132 | return tokens, ids, type_id 133 | -------------------------------------------------------------------------------- /tts/tts_t2i_noise_scaling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import numpy as np 5 | import torch 6 | from diffusers import DiffusionPipeline 7 | from tqdm.auto import tqdm 8 | import copy 9 | from PIL import Image 10 | 11 | from utils import get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args 12 | 13 | # Non-configurable constants 14 | MAX_SEED = np.iinfo(np.int32).max # To generate random seeds 15 | 16 | def sample( 17 | noises: dict[int, torch.Tensor], 18 | prompts: list[str], 19 | search_round: int, 20 | pipe: DiffusionPipeline, 21 | config: dict, 22 | original_prompt: str, 23 | midimg_path: str, 24 | ) -> dict: 25 | """ 26 | For a given prompt, generate images using all provided noises in batches, 27 | score them with the verifier, and select the top-K noise. 28 | The images and JSON artifacts are saved under `root_dir`. 29 | """ 30 | config_cp = copy.deepcopy(config) 31 | 32 | use_low_gpu_vram = config_cp.get("use_low_gpu_vram", False) 33 | batch_size_for_img_gen = config_cp.get("batch_size_for_img_gen", 1) 34 | 35 | images_for_prompt = [] 36 | noises_used = [] 37 | seeds_used = [] 38 | 39 | # Convert the noises dictionary into a list of (seed, noise) tuples. 40 | noise_items = list(noises.items()) 41 | 42 | # Process the noises in batches. 43 | full_imgnames = [] 44 | for i in range(0, len(noise_items), batch_size_for_img_gen): 45 | batch = noise_items[i : i + batch_size_for_img_gen] 46 | seeds_batch, noises_batch = zip(*batch) 47 | filenames_batch = [ 48 | os.path.join(midimg_path, f"{search_round}_round@{seed}.png") for seed in seeds_batch 49 | ] 50 | full_imgnames.extend(filenames_batch) 51 | 52 | if use_low_gpu_vram: 53 | pipe = pipe.to("cuda:0") 54 | print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.") 55 | 56 | # Create a batched prompt list and stack the latents. 
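# Each value in noises_batch comes from utils.get_noises(), which prepares one
# packed FLUX latent per seed with batch_size=1, i.e. shape (1, seq_len, channels);
# torch.stack() therefore yields (B, 1, seq_len, channels) and squeeze(dim=1)
# produces the (B, seq_len, channels) batch passed to the pipeline as latents=.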
57 | batched_latents = torch.stack(noises_batch).squeeze(dim=1) 58 | batched_prompts = prompts[i : i + batch_size_for_img_gen] 59 | # breakpoint() 60 | batch_result = pipe(prompt=batched_prompts, latents=batched_latents, guidance_scale=config_cp["pipeline_args"]["guidance_scale"], num_inference_steps=config_cp["pipeline_args"]["num_inference_steps"], height=config_cp["pipeline_args"]["height"], width=config_cp["pipeline_args"]["width"]) 61 | batch_images = batch_result.images 62 | if use_low_gpu_vram : 63 | pipe = pipe.to("cpu") 64 | 65 | # Iterate over the batch and save the images. 66 | for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch): 67 | images_for_prompt.append(image) 68 | noises_used.append(noise) 69 | seeds_used.append(seed) 70 | image.save(filename) 71 | 72 | datapoint = { 73 | "prompt": original_prompt, 74 | "search_round": search_round, 75 | "num_noises": len(noises), 76 | } 77 | return datapoint 78 | 79 | @torch.no_grad() 80 | def main(): 81 | """ 82 | Main function: 83 | - Parses CLI arguments. 84 | - Creates an output directory based on verifier and current datetime. 85 | - Loads prompts. 86 | - Loads the image-generation pipeline. 87 | - Loads the verifier model. 88 | - Runs several search rounds where for each prompt a pool of random noises is generated, 89 | candidate images are produced and verified, and the best noise is chosen. 90 | """ 91 | args = parse_cli_args() 92 | os.environ["API_KEY"] = os.environ["OPENAI_API_KEY"] # args.openai_api_key 93 | 94 | # Build a config dictionary for parameters that need to be passed around. 95 | with open(args.pipeline_config_path, "r") as f: 96 | config = json.load(f) 97 | 98 | config.update(vars(args)) 99 | 100 | search_rounds = config["search_args"]["search_rounds"] 101 | search_branch = config["search_args"]["search_branch"] 102 | 103 | # Create a root output directory: output/{verifier_to_use}/{current_datetime} 104 | pipeline_name = config["pipeline_args"].get("pretrained_model_name_or_path") 105 | cache_dir = config["pipeline_args"]["cache_dir"] 106 | root_dir = config["output_dir"] 107 | os.makedirs(root_dir, exist_ok=True) 108 | 109 | # Set up the image-generation pipeline (on the first GPU if available). 110 | torch_dtype = TORCH_DTYPE_MAP[config["pipeline_args"].get("torch_dtype")] 111 | pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype, cache_dir=cache_dir) 112 | if not config["use_low_gpu_vram"]: 113 | pipe = pipe.to("cuda:0") 114 | pipe.set_progress_bar_config(disable=True) 115 | 116 | # Main loop: For each search round and each prompt, generate images, verify, and save artifacts. 
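# args.meta_path is expected to be a JSONL file with one JSON object per line,
# each containing at least a "prompt" field (read below as metadata['prompt']);
# --start_index / --end_index then select a slice of these prompts.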
117 | with open(args.meta_path) as fp: 118 | metadatas = [json.loads(line) for line in fp] 119 | 120 | # meta splits 121 | if args.end_index == -1: 122 | metadatas = metadatas[args.start_index:] 123 | else: 124 | metadatas = metadatas[args.start_index:args.end_index] 125 | 126 | for index, metadata in tqdm(enumerate(metadatas), desc="Sampling prompts"): 127 | original_prompt = metadata['prompt'] 128 | current_prompts = [original_prompt] * search_branch 129 | # create output directory 130 | outpath = os.path.join(root_dir, f"{index + args.start_index:0>5}") 131 | os.makedirs(outpath, exist_ok=True) 132 | 133 | # create middle img directory 134 | midimg_path = os.path.join(outpath, "samples") 135 | os.makedirs(midimg_path, exist_ok=True) 136 | 137 | # create metadata file 138 | with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp: 139 | json.dump(metadata, fp) 140 | 141 | for round in range(1, search_rounds + 1): 142 | print(f"\n=== Round: {round} ===") 143 | noises = get_noises( 144 | max_seed=MAX_SEED, 145 | num_samples=search_branch, 146 | height=config["pipeline_args"]["height"], 147 | width=config["pipeline_args"]["width"], 148 | dtype=torch_dtype, 149 | fn=get_latent_prep_fn(pipeline_name), 150 | ) 151 | datapoint = sample( 152 | noises=noises, 153 | prompts=current_prompts, 154 | search_round=round, 155 | pipe=pipe, 156 | config=config, 157 | original_prompt=original_prompt, 158 | midimg_path=midimg_path, 159 | ) 160 | 161 | 162 | if __name__ == "__main__": 163 | main() 164 | -------------------------------------------------------------------------------- /train_flux/train/train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torch 3 | import lightning as L 4 | import yaml 5 | import os 6 | import time 7 | 8 | from diffusers.utils.logging import set_verbosity_error 9 | set_verbosity_error() 10 | 11 | from .data import ImageConditionWebDataset 12 | from .model import OminiModel 13 | from .callbacks import TrainingCallback 14 | 15 | def get_rank(): 16 | try: 17 | rank = int(os.environ.get("LOCAL_RANK")) 18 | except: 19 | rank = 0 20 | return rank 21 | 22 | def get_config(): 23 | config_path = os.environ.get("XFL_CONFIG") 24 | assert config_path is not None, "Please set the XFL_CONFIG environment variable" 25 | with open(config_path, "r") as f: 26 | config = yaml.safe_load(f) 27 | return config 28 | 29 | 30 | def init_wandb(wandb_config, run_name): 31 | import wandb 32 | 33 | try: 34 | assert os.environ.get("WANDB_API_KEY") is not None 35 | wandb.init( 36 | project=wandb_config["project"], 37 | name=wandb_config["name"] if wandb_config["name"] else run_name, 38 | config={}, 39 | settings=wandb.Settings(start_method="fork"), 40 | ) 41 | except Exception as e: 42 | print("Failed to initialize WanDB:", e) 43 | 44 | 45 | def main(): 46 | # Initialize 47 | is_main_process, rank = get_rank() == 0, get_rank() 48 | torch.cuda.set_device(rank) 49 | config = get_config() 50 | training_config = config["train"] 51 | run_name = time.strftime("%Y%m%d-%H%M%S") 52 | 53 | # Initialize WanDB 54 | wandb_config = training_config.get("wandb", None) 55 | if wandb_config is not None and is_main_process: 56 | init_wandb(wandb_config, run_name) 57 | 58 | print("Rank:", rank) 59 | if is_main_process: 60 | print("Config:", config) 61 | 62 | # Initialize dataset and dataloader 63 | if training_config["dataset"]["type"] == "img": 64 | dataset = ImageConditionWebDataset( 65 | training_config["dataset"]["path"], 66 | 
condition_size=training_config["dataset"]["condition_size"], 67 | target_size=training_config["dataset"]["target_size"], 68 | condition_type=training_config["condition_type"], 69 | drop_text_prob=training_config["dataset"]["drop_text_prob"], 70 | drop_image_prob=training_config["dataset"]["drop_image_prob"], 71 | drop_reflection_prob=training_config["dataset"]["drop_reflection_prob"], 72 | root_dir=training_config["dataset"]["root_dir"], 73 | split_ratios=training_config["dataset"]["split_ratios"], 74 | training_stages=training_config["dataset"]["training_stages"], 75 | ) 76 | if "val_path" in training_config["dataset"]: 77 | val_dataset = ImageConditionWebDataset( 78 | training_config["dataset"]["val_path"], 79 | condition_size=training_config["dataset"]["condition_size"], 80 | target_size=training_config["dataset"]["target_size"], 81 | condition_type=training_config["condition_type"], 82 | root_dir=training_config["dataset"]["val_root_dir"], 83 | drop_text_prob=0, 84 | drop_image_prob=0, 85 | drop_reflection_prob=0, 86 | shuffle_buffer=0, 87 | ) 88 | else: 89 | val_dataset = None 90 | else: 91 | raise NotImplementedError 92 | 93 | train_loader = DataLoader( 94 | dataset, 95 | batch_size=training_config["batch_size"], 96 | num_workers=training_config["dataloader_workers"], 97 | ) 98 | if val_dataset is not None: 99 | val_loader = DataLoader( 100 | val_dataset, 101 | batch_size=1, 102 | shuffle=False, 103 | num_workers=0 104 | ) 105 | else: 106 | val_loader = None 107 | 108 | # Try add resume training 109 | lora_path = None 110 | 111 | if training_config['resume_training_from_last_checkpoint'] and os.path.exists(training_config['save_path']): 112 | # get latest directory in training_config['save_path'], ignore hidden files 113 | all_training_sessions = [d for d in os.listdir(training_config['save_path']) if not d.startswith('.')] 114 | all_training_sessions.sort(reverse=True) 115 | last_training_session = all_training_sessions[0] 116 | if os.path.exists(f"{training_config['save_path']}/{last_training_session}/ckpt"): 117 | ckpt_paths = [d for d in os.listdir(f"{training_config['save_path']}/{last_training_session}/ckpt") if not d.startswith('.')] 118 | ckpt_paths.sort(reverse=True) 119 | lora_path = f"{training_config['save_path']}/{last_training_session}/ckpt/{ckpt_paths[0]}" 120 | print(f"Resuming training from {lora_path}") 121 | else: 122 | print("No checkpoint found. Training without LoRA weights.") 123 | 124 | elif training_config['resume_training_from_checkpoint_path'] != "": 125 | _lora_path = training_config['resume_training_from_checkpoint_path'] 126 | # Check if the path exists 127 | if os.path.exists(_lora_path): 128 | lora_path = _lora_path 129 | print(f"Training with LoRA weights from {_lora_path}") 130 | else: 131 | print(f"Path {_lora_path} does not exist. 
Training without LoRA weights.") 132 | 133 | # Initialize model 134 | trainable_model = OminiModel( 135 | flux_pipe_id=config["model_path"], 136 | lora_path=lora_path, 137 | lora_config=training_config["lora_config"], 138 | data_config=training_config["dataset"], 139 | device=f"cuda:{rank}", 140 | dtype=getattr(torch, config["dtype"]), 141 | optimizer_config=training_config["optimizer"], 142 | model_config=config.get("model", {}), 143 | gradient_checkpointing=training_config.get("gradient_checkpointing", False), 144 | save_path=training_config.get("save_path", "./output"), 145 | run_name=run_name, 146 | cache_dir=config["cache_dir"], 147 | ) 148 | 149 | # Callbacks for logging and saving checkpoints 150 | training_callbacks = ( 151 | [TrainingCallback(run_name, training_config=training_config)] 152 | if is_main_process 153 | else [] 154 | ) 155 | 156 | # Initialize trainer 157 | trainer = L.Trainer( 158 | accumulate_grad_batches=training_config["accumulate_grad_batches"], 159 | callbacks=training_callbacks, 160 | enable_checkpointing=False, 161 | enable_progress_bar=False, 162 | logger=False, 163 | max_steps=training_config.get("max_steps", -1), 164 | max_epochs=training_config.get("max_epochs", -1), 165 | gradient_clip_val=training_config.get("gradient_clip_val", 0.5), 166 | val_check_interval=training_config.get("sample_interval", 1000), 167 | num_sanity_val_steps=training_config.get("num_sanity_val_steps", -1), 168 | ) 169 | 170 | setattr(trainer, "training_config", training_config) 171 | 172 | # Save config 173 | save_path = training_config.get("save_path", "./output") 174 | if is_main_process: 175 | os.makedirs(f"{save_path}/{run_name}") 176 | os.makedirs(f"{save_path}/{run_name}/val") 177 | with open(f"{save_path}/{run_name}/config.yaml", "w") as f: 178 | yaml.dump(config, f) 179 | 180 | # Start training 181 | trainer.fit(trainable_model, train_loader, val_loader) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /reward_modeling/test_reward.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import pdb 5 | import argparse 6 | from collections.abc import Mapping 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import sys 10 | sys.path.append('ReflectionFlow') 11 | sys.path.append('reward_modeling') 12 | 13 | import torch 14 | from vision_process import process_vision_info 15 | 16 | 17 | from data import DataConfig 18 | from reward_modeling.utils import ModelConfig, PEFTLoraConfig, TrainingConfig 19 | from reward_modeling.utils import load_model_from_checkpoint 20 | from train_reward import create_model_and_processor 21 | from prompt_template import build_prompt 22 | 23 | 24 | def load_configs_from_json(config_path): 25 | with open(config_path, "r") as f: 26 | config_dict = json.load(f) 27 | 28 | del config_dict["data_config"]["meta_data"] 29 | del config_dict["data_config"]["data_dir"] 30 | 31 | return config_dict["data_config"], None, config_dict["model_config"], config_dict["peft_lora_config"], \ 32 | config_dict["inference_config"] if "inference_config" in config_dict else None 33 | 34 | 35 | class ImageVLMRewardInference(): 36 | def __init__(self, load_from_pretrained, load_from_pretrained_step=-1, device='cuda', dtype=torch.bfloat16): 37 | config_path = os.path.join(load_from_pretrained, "model_config.json") 38 | data_config, _, model_config, peft_lora_config, inference_config = 
load_configs_from_json(config_path) 39 | data_config = DataConfig(**data_config) 40 | model_config = ModelConfig(**model_config) 41 | peft_lora_config = PEFTLoraConfig(**peft_lora_config) 42 | 43 | training_args = TrainingConfig( 44 | load_from_pretrained=load_from_pretrained, 45 | load_from_pretrained_step=load_from_pretrained_step, 46 | gradient_checkpointing=False, 47 | disable_flash_attn2=False, 48 | bf16=True if dtype == torch.bfloat16 else False, 49 | fp16=True if dtype == torch.float16 else False, 50 | output_dir="", 51 | ) 52 | 53 | model, processor, peft_config = create_model_and_processor( 54 | model_config=model_config, 55 | peft_lora_config=peft_lora_config, 56 | training_args=training_args, 57 | cache_dir="/ibex/user/zhaol0c/uniediting_continue/our_reward/initialreward" 58 | ) 59 | 60 | self.device = device 61 | 62 | model, checkpoint_step = load_model_from_checkpoint(model, load_from_pretrained, load_from_pretrained_step) 63 | model.eval() 64 | 65 | self.model = model 66 | self.processor = processor 67 | 68 | self.model.to(self.device) 69 | 70 | self.data_config = data_config 71 | 72 | self.inference_config = inference_config 73 | 74 | def _norm(self, reward): 75 | if self.inference_config is None: 76 | return reward 77 | else: 78 | reward['VQ'] = (reward['VQ'] - self.inference_config['VQ_mean']) / self.inference_config['VQ_std'] 79 | return reward 80 | 81 | def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'): 82 | assert padding_side in ['right', 'left'] 83 | if sequences.shape[1] >= max_len: 84 | return sequences, attention_mask 85 | 86 | pad_len = max_len - sequences.shape[1] 87 | padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0) 88 | 89 | sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id) 90 | attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0) 91 | 92 | return sequences_padded, attention_mask_padded 93 | 94 | def _prepare_input(self, data): 95 | if isinstance(data, Mapping): 96 | return type(data)({k: self._prepare_input(v) for k, v in data.items()}) 97 | elif isinstance(data, (tuple, list)): 98 | return type(data)(self._prepare_input(v) for v in data) 99 | elif isinstance(data, torch.Tensor): 100 | kwargs = {"device": self.device} 101 | return data.to(**kwargs) 102 | return data 103 | 104 | def _prepare_inputs(self, inputs): 105 | inputs = self._prepare_input(inputs) 106 | if len(inputs) == 0: 107 | raise ValueError 108 | return inputs 109 | 110 | def prepare_batch(self, image_paths, prompts, max_pixels=None): 111 | max_pixels = self.data_config.max_frame_pixels if max_pixels is None else max_pixels 112 | 113 | chat_data = [ 114 | [ 115 | { 116 | "role": "user", 117 | "content": [ 118 | { 119 | "type": "image", 120 | "image": image_path, 121 | "max_pixels": max_pixels, 122 | }, 123 | {"type": "text", "text": build_prompt(prompt, self.data_config.eval_dim, self.data_config.prompt_template_type)}, 124 | ], 125 | }, 126 | ] for image_path, prompt in zip(image_paths, prompts) 127 | ] 128 | image_inputs, video_inputs = process_vision_info(chat_data) 129 | 130 | batch = self.processor( 131 | text=self.processor.apply_chat_template(chat_data, tokenize=False, add_generation_prompt=True), 132 | images=image_inputs, 133 | videos=video_inputs, 134 | padding=True, 135 | return_tensors="pt", 136 | videos_kwargs={"do_rescale": True}, 137 | ) 138 | batch = self._prepare_inputs(batch) 139 | 140 | return batch 141 | 142 | def reward(self, 
image_paths, prompts, max_pixels=None, use_norm=True): 143 | batch = self.prepare_batch(image_paths, prompts, max_pixels) 144 | rewards = self.model( 145 | return_dict=True, 146 | **batch 147 | )["logits"] 148 | 149 | rewards = [{'VQ': reward[0].item()} for reward in rewards] 150 | for i in range(len(rewards)): 151 | if use_norm: 152 | rewards[i] = self._norm(rewards[i]) 153 | rewards[i]['Overall'] = rewards[i]['VQ'] 154 | 155 | return rewards 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser(description="Video Alignment Reward Inference") 160 | parser.add_argument("--load_from_pretrained", type=str, required=True, 161 | help="Path to pretrained model") 162 | parser.add_argument("--device", type=str, default="cuda", 163 | help="Device to run inference on") 164 | parser.add_argument("--ckpt_step", type=int, default=-1, 165 | help="Checkpoint step for processing") 166 | args = parser.parse_args() 167 | 168 | inferencer = ImageVLMRewardInference(args.load_from_pretrained, load_from_pretrained_step=args.ckpt_step, device=args.device, dtype=torch.bfloat16) 169 | 170 | # 手动输入图像路径和 Prompt 171 | image_paths = ["/ibex/user/zhaol0c/uniediting_continue/nvilaverifier_exps/b2_d16_6000model/00548/samples_best/00001.png", "/ibex/user/zhaol0c/uniediting_continue/nvilaverifier_exps/b2_d16_6000model/00548/midimg/14_round@958442093.png"] 172 | prompts = ["a photo of a yellow bicycle and a red motorcycle", "a photo of a yellow bicycle and a red motorcycle"] 173 | 174 | # 进行打分 175 | with torch.no_grad(): 176 | rewards = inferencer.reward(image_paths, prompts, use_norm=True) 177 | breakpoint() 178 | print(f"Rewards: {rewards}") -------------------------------------------------------------------------------- /tts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils.torch_utils import randn_tensor 3 | from diffusers import FluxPipeline 4 | import re 5 | import hashlib 6 | from typing import Dict 7 | import json 8 | from typing import Union 9 | from PIL import Image 10 | import requests 11 | import argparse 12 | import io 13 | 14 | 15 | TORCH_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 16 | MODEL_NAME_MAP = { 17 | "black-forest-labs/FLUX.1-dev": "flux.1-dev", 18 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart-sigma-1024-ms", 19 | "stabilityai/stable-diffusion-xl-base-1.0": "sdxl-base", 20 | "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd-v1.5", 21 | } 22 | 23 | 24 | def parse_cli_args(): 25 | """ 26 | Parse and return CLI arguments. 
27 | """ 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--pipeline_config_path", 31 | type=str, 32 | default="configs/flux.1_dev_nvilascore.json", 33 | help="Pipeline configuration path that should include loading info and __call__() args and their values.", 34 | ) 35 | parser.add_argument( 36 | "--start_index", 37 | type=int, 38 | default=0, 39 | help="start index", 40 | ) 41 | parser.add_argument( 42 | "--end_index", 43 | type=int, 44 | default=-1, 45 | help="end index", 46 | ) 47 | parser.add_argument( 48 | "--imgpath", 49 | type=str, 50 | default="", 51 | help="path to generated images and their metadata", 52 | ) 53 | parser.add_argument( 54 | "--output_dir", 55 | type=str, 56 | default="output", 57 | help="output directory", 58 | ) 59 | parser.add_argument( 60 | "--meta_path", 61 | type=str, 62 | default="meta.jsonl", 63 | help="meta path", 64 | ) 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | # Adapted from Diffusers. 71 | def prepare_latents_for_flux( 72 | batch_size: int, 73 | height: int, 74 | width: int, 75 | generator: torch.Generator, 76 | device: str, 77 | dtype: torch.dtype, 78 | ) -> torch.Tensor: 79 | num_latent_channels = 16 80 | vae_scale_factor = 8 81 | 82 | height = 2 * (int(height) // (vae_scale_factor * 2)) 83 | width = 2 * (int(width) // (vae_scale_factor * 2)) 84 | shape = (batch_size, num_latent_channels, height, width) 85 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 86 | latents = FluxPipeline._pack_latents(latents, batch_size, num_latent_channels, height, width) 87 | return latents 88 | 89 | 90 | # Adapted from Diffusers. 91 | def prepare_latents( 92 | batch_size: int, height: int, width: int, generator: torch.Generator, device: str, dtype: torch.dtype 93 | ): 94 | num_channels_latents = 4 95 | vae_scale_factor = 8 96 | shape = ( 97 | batch_size, 98 | num_channels_latents, 99 | int(height) // vae_scale_factor, 100 | int(width) // vae_scale_factor, 101 | ) 102 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 103 | return latents 104 | 105 | def prepare_latents_for_sd3( 106 | batch_size: int, height: int, width: int, generator: torch.Generator, device: str, dtype: torch.dtype 107 | ): 108 | num_channels_latents = 16 109 | vae_scale_factor = 8 110 | shape = ( 111 | batch_size, 112 | num_channels_latents, 113 | int(height) // vae_scale_factor, 114 | int(width) // vae_scale_factor, 115 | ) 116 | latents = randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype) 117 | return latents 118 | 119 | 120 | def get_latent_prep_fn(pretrained_model_name_or_path: str) -> callable: 121 | fn_map = { 122 | "black-forest-labs/FLUX.1-dev": prepare_latents_for_flux, 123 | "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": prepare_latents, 124 | "stabilityai/stable-diffusion-xl-base-1.0": prepare_latents, 125 | "stable-diffusion-v1-5/stable-diffusion-v1-5": prepare_latents, 126 | "stabilityai/stable-diffusion-3-medium-diffusers": prepare_latents_for_sd3, 127 | }[pretrained_model_name_or_path] 128 | return fn_map 129 | 130 | 131 | def get_noises( 132 | max_seed: int, 133 | num_samples: int, 134 | height: int, 135 | width: int, 136 | device="cuda", 137 | dtype: torch.dtype = torch.bfloat16, 138 | fn: callable = prepare_latents_for_flux, 139 | ) -> Dict[int, torch.Tensor]: 140 | seeds = torch.randint(0, high=max_seed, size=(num_samples,)) 141 | 142 | noises = {} 143 | for noise_seed in seeds: 144 | latents = fn( 145 | batch_size=1, 
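# One latent is prepared per sampled seed (batch_size=1); callers such as
# sample() in tts_t2i_noise_scaling.py stack these per-seed latents into a
# single batch before invoking the pipeline.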
146 | height=height, 147 | width=width, 148 | generator=torch.manual_seed(int(noise_seed)), 149 | device=device, 150 | dtype=dtype, 151 | ) 152 | noises.update({int(noise_seed): latents}) 153 | 154 | assert len(noises) == len(seeds) 155 | return noises 156 | 157 | def load_verifier_prompt(path: str): 158 | # Determine the file type 159 | if path.endswith(".txt"): 160 | # Handle plain-text prompt files 161 | with open(path, "r") as f: 162 | verifier_prompt = f.read().replace('"""', "") 163 | return verifier_prompt 164 | elif path.endswith(".json"): 165 | # Handle JSON prompt files 166 | with open(path, "r") as f: 167 | data = json.load(f) # load the JSON file 168 | return data 169 | else: 170 | # Raise an error for unsupported file formats 171 | raise ValueError("Unsupported file type. Please provide a .txt or .json file.") 172 | 173 | 174 | def prompt_to_filename(prompt, max_length=100): 175 | """Build a filesystem-safe filename from the prompt, suffixed with a short hash.""" 176 | filename = re.sub(r"[^a-zA-Z0-9]", "_", prompt.strip()) 177 | filename = re.sub(r"_+", "_", filename) 178 | hash_digest = hashlib.sha256(prompt.encode()).hexdigest()[:8] 179 | base_filename = f"prompt@{filename}_hash@{hash_digest}" 180 | 181 | if len(base_filename) > max_length: 182 | base_length = max_length - len(hash_digest) - 7 183 | base_filename = f"prompt@{filename[:base_length]}_hash@{hash_digest}" 184 | 185 | return base_filename 186 | 187 | 188 | def load_image(path_or_url: Union[str, Image.Image]) -> Image.Image: 189 | """ 190 | Load an image from a local path or a URL and return a PIL Image object. 191 | 192 | `path_or_url` is returned as is if it's an `Image` already. 193 | """ 194 | if isinstance(path_or_url, Image.Image): 195 | return path_or_url 196 | elif path_or_url.startswith("http"): 197 | response = requests.get(path_or_url, stream=True) 198 | response.raise_for_status() 199 | return Image.open(io.BytesIO(response.content)) 200 | return Image.open(path_or_url) 201 | 202 | 203 | def convert_to_bytes(path_or_url: Union[str, Image.Image]) -> bytes: 204 | """Load an image from a path or URL and convert it to bytes.""" 205 | image = load_image(path_or_url).convert("RGB") 206 | image_bytes_io = io.BytesIO() 207 | image.save(image_bytes_io, format="PNG") 208 | return image_bytes_io.getvalue() 209 | 210 | 211 | def recover_json_from_output(output: str): 212 | start = output.find("{") 213 | end = output.rfind("}") + 1 214 | json_part = output[start:end] 215 | return json.loads(json_part) 216 | 217 | 218 | def get_batches(items, batch_size): 219 | num_batches = (len(items) + batch_size - 1) // batch_size 220 | batches = [] 221 | 222 | for i in range(num_batches): 223 | start_index = i * batch_size 224 | end_index = min((i + 1) * batch_size, len(items)) 225 | batch = items[start_index:end_index] 226 | batches.append(batch) 227 | 228 | return batches 229 | -------------------------------------------------------------------------------- /tts/verifier_filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from tqdm.auto import tqdm 9 | import copy 10 | from typing import Union, List, Optional 11 | from PIL import Image 12 | import sys 13 | sys.path.append('../') 14 | from train_flux.flux.generate import generate 15 | from train_flux.flux.condition import Condition 16 | import time 17 | from verifiers.openai_verifier import OpenAIVerifier 18 | from verifiers.nvila_verifier import load_model 19 | 20 | from utils import prompt_to_filename, get_noises, TORCH_DTYPE_MAP,
get_latent_prep_fn, parse_cli_args, MODEL_NAME_MAP 21 | 22 | # Non-configurable constants 23 | MAX_SEED = np.iinfo(np.int32).max # To generate random seeds 24 | MAX_RETRIES = 5 25 | RETRY_DELAY = 2 26 | 27 | @torch.no_grad() 28 | def main(): 29 | """ 30 | Main function: 31 | - Parses CLI arguments. 32 | - Loads the pipeline configuration from `--pipeline_config_path`. 33 | - Loads the NVILA verifier used for scoring. 34 | - Collects the generated images and their metadata under `--imgpath`. 35 | - Scores every candidate image with the verifier. 36 | - For each NFE budget (1, 2, 4, 8, 16, 32), ranks the first N candidates 37 | by verifier score and saves the best one to the corresponding nfeN folder. 38 | """ 39 | args = parse_cli_args() 40 | 41 | # Build a config dictionary for parameters that need to be passed around. 42 | with open(args.pipeline_config_path, "r") as f: 43 | config = json.load(f) 44 | 45 | config.update(vars(args)) 46 | 47 | ### load nvila verifier for scoring 48 | verifier_args = config["verifier_args"] 49 | verifier, yes_id, no_id = load_model(model_name=verifier_args["model_name"], cache_dir=verifier_args["cache_dir"]) 50 | 51 | metadatas = [] 52 | for folder_name in sorted(os.listdir(args.imgpath)): 53 | folder_path = os.path.join(args.imgpath, folder_name) 54 | 55 | if os.path.isdir(folder_path): 56 | metadata_path = os.path.join(folder_path, 'metadata.jsonl') 57 | midimg_path = os.path.join(folder_path, 'midimg') 58 | 59 | with open(metadata_path, "r") as f: 60 | metadata = [json.loads(line) for line in f] 61 | folder_data = { 62 | 'metadata': metadata, 63 | 'images': [] 64 | } 65 | 66 | round_images = {} 67 | for file in sorted(os.listdir(midimg_path)): 68 | if file.endswith('.png'): 69 | round_key = file.split('_round@')[0] 70 | if round_key not in round_images: 71 | round_images[round_key] = [] 72 | round_images[round_key].append(file) 73 | 74 | # breakpoint() 75 | round_images = dict(sorted(round_images.items(), key=lambda x: int(x[0]))) 76 | for round_key, files in round_images.items(): 77 | for file in files: 78 | img_path = os.path.join(midimg_path, file) 79 | folder_data['images'].append({'img_path': img_path}) 80 | 81 | metadatas.append(folder_data) 82 | 83 | # meta splits 84 | if args.end_index == -1: 85 | metadatas = metadatas[args.start_index:] 86 | else: 87 | metadatas = metadatas[args.start_index:args.end_index] 88 | 89 | for index, metadata in tqdm(enumerate(metadatas), desc="Sampling data"): 90 | prompt = metadata['metadata'][0]['prompt'] 91 | imgs = metadata['images'] 92 | imgs = [tmp['img_path'] for tmp in imgs] 93 | cur_dir = os.path.dirname(imgs[0]) 94 | cur_dir = os.path.dirname(cur_dir) 95 | nfe1_path = os.path.join(cur_dir, "nfe1") 96 | nfe2_path = os.path.join(cur_dir, "nfe2") 97 | nfe4_path = os.path.join(cur_dir, "nfe4") 98 | nfe8_path = os.path.join(cur_dir, "nfe8") 99 | nfe16_path = os.path.join(cur_dir, "nfe16") 100 | nfe32_path = os.path.join(cur_dir, "nfe32") 101 | os.makedirs(nfe1_path, exist_ok=True) 102 | os.makedirs(nfe2_path, exist_ok=True) 103 | os.makedirs(nfe4_path, exist_ok=True) 104 | os.makedirs(nfe8_path, exist_ok=True) 105 | os.makedirs(nfe16_path, exist_ok=True) 106 | os.makedirs(nfe32_path, exist_ok=True) 107 | 108 | start_time = time.time() 109 | 110 | outputs = [] 111 | # nvila verifier 112 | for imgname in imgs: 113 | r1, scores1 = verifier.generate_content([Image.open(imgname), prompt]) 114 | if r1 == "yes": 115 | outputs.append({"image_name": imgname, "label": "yes", "score": scores1[0][0, yes_id].detach().cpu().float().item()}) 116 | else: 117 | outputs.append({"image_name": imgname, "label": "no", "score": scores1[0][0, no_id].detach().cpu().float().item()}) 118 | 119 | end_time = time.time() 120 | print(f"Time taken for evaluation: {end_time - start_time} seconds") 121 | 122 | # nvila verifier filter rule 123 | def f(x): 124 | if x["label"] == "yes": 125 | return (0, -x["score"]) 126 | else: 127 | return (1, x["score"]) 128 | 129 | # Rank the first n candidates for each NFE budget and save the best-scoring image. 130 | nfe_paths = {1: nfe1_path, 2: nfe2_path, 4: nfe4_path, 8: nfe8_path, 16: nfe16_path, 32: nfe32_path} 131 | for n, nfe_path in nfe_paths.items(): 132 | sorted_list = sorted(outputs[:n], key=f) 133 | topk_idx = [outputs.index(x) for x in sorted_list] 134 | selected_imgs = [imgs[i] for i in topk_idx] 135 | img = Image.open(selected_imgs[0]) 136 | img.save(os.path.join(nfe_path, f"{0:05}.png")) 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 |
-------------------------------------------------------------------------------- /train_flux/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import random 5 | import torch 6 | import numpy as np 7 | import argparse 8 | from diffusers.pipelines import FluxPipeline, StableDiffusion3Pipeline 9 | from src.flux.condition import Condition 10 | from src.sd3.condition import Condition as ConditionSD3 11 | from PIL import Image 12 | 13 | from src.flux.generate import generate, seed_everything 14 | from src.sd3.generate import generate as generate_sd3 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser(description="Run FLUX pipeline with reflection prompts") 18 | parser.add_argument("--model_name", type=str, default="flux", help="Model name") 19 | parser.add_argument("--step", type=int, default=30, help="Number 
of inference steps") 20 | parser.add_argument("--condition_size", type=int, default=1024, help="Size of condition image") 21 | parser.add_argument("--target_size", type=int, default=1024, help="Size of target image") 22 | parser.add_argument("--task_name", type=str, default="geneval", 23 | choices=["edit", "geneval", "flux_pro_short", "flux_pro_detailed", "flux_pro"], 24 | help="Task name for selecting the appropriate dataset") 25 | parser.add_argument("--lora_dir", type=str, 26 | default="/mnt/petrelfs/gaopeng/zl/ReflectionFlow/train_flux/runs/full_data_v2/20250227-040606/ckpt/5000/pytorch_lora_weights.safetensors", 27 | help="Path to LoRA weights") 28 | parser.add_argument("--output_dir", type=str, 29 | default="/mnt/petrelfs/gaopeng/zl/ReflectionFlow/train_flux/samples/full_data_1024_5k", 30 | help="Base directory for output") 31 | parser.add_argument("--root_dir", type=str, default="", help="Root directory for image paths") 32 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 33 | parser.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale") 34 | parser.add_argument("--image_guidance_scale", type=float, default=1.0, help="Guidance scale") 35 | 36 | args = parser.parse_args() 37 | 38 | # Set variables from parsed arguments 39 | step = args.step 40 | condition_size = args.condition_size 41 | target_size = args.target_size 42 | task_name = args.task_name 43 | lora_dir = args.lora_dir 44 | output_dir = args.output_dir 45 | root_dir = args.root_dir 46 | 47 | # Set json_dir based on task_name 48 | if task_name == "edit": 49 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/edit_reflection_cleaned_val.json" 50 | output_dir = os.path.join(output_dir, "edit") 51 | elif task_name == "geneval": 52 | json_dir = "/mnt/petrelfs/zhuole/data/metadata_clean/geneval_pairs_val.json" 53 | output_dir = os.path.join(output_dir, "geneval") 54 | elif task_name == "flux_pro_short": 55 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_detailed_reflection_cleaned_val.json" 56 | output_dir = os.path.join(output_dir, "flux_pro_short") 57 | elif task_name == "flux_pro_detailed": 58 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_detailed_reflection_cleaned_val.json" 59 | output_dir = os.path.join(output_dir, "flux_pro_detailed") 60 | elif task_name == "flux_pro": 61 | json_dir = "/mnt/petrelfs/gaopeng/zl/data/reflection/metadata/flux_pro_reflection_cleaned_val.json" 62 | output_dir = os.path.join(output_dir, "flux_pro") 63 | else: 64 | raise ValueError(f"Invalid task name: {task_name}") 65 | os.makedirs(output_dir, exist_ok=True) 66 | 67 | if args.model_name == "flux": 68 | pipe = FluxPipeline.from_pretrained( 69 | "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 70 | ) 71 | else: 72 | pipe = StableDiffusion3Pipeline.from_pretrained( 73 | "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16, cache_dir = "/ibex/user/zhaol0c/uniediting/ReflectionFlow-main/sd3medium" 74 | ) 75 | pipe = pipe.to("cuda") 76 | pipe.load_lora_weights( 77 | lora_dir, 78 | adapter_name="cot", 79 | ) 80 | 81 | # Load the JSON list containing prompt and image details 82 | print(f"Loading JSON from {json_dir}") 83 | with open(json_dir, "r") as f: 84 | # items = [json.loads(line.strip()) for line in f] 85 | items = json.load(f) 86 | 87 | seed_everything(args.seed) 88 | 89 | for idx, item in enumerate(items): 90 | # Open the bad image and good image 91 | # bad_image_path = os.path.join(root_dir, 
item["bad_image"]) 92 | # good_image_path = os.path.join(root_dir, item["good_image"]) 93 | bad_image_path = item["bad_image"] 94 | good_image_path = item["good_image"] 95 | 96 | # Load images 97 | bad_image = Image.open(bad_image_path).convert("RGB") 98 | good_image = Image.open(good_image_path).convert("RGB") 99 | 100 | # Get dimensions 101 | bad_w, bad_h = bad_image.size 102 | good_w, good_h = good_image.size 103 | 104 | # Resize bad image to match good image dimensions 105 | bad_image = bad_image.resize((good_w, good_h), Image.BICUBIC) 106 | 107 | # Resize the shorter edge to target_size while maintaining aspect ratio 108 | ratio = target_size / min(good_w, good_h) 109 | new_w = math.ceil(good_w * ratio) 110 | new_h = math.ceil(good_h * ratio) 111 | 112 | # Resize both images to the same dimensions 113 | good_image = good_image.resize((new_w, new_h), Image.BICUBIC) 114 | bad_image = bad_image.resize((new_w, new_h), Image.BICUBIC) 115 | 116 | # Randomly crop both images to exactly target_size x target_size 117 | if new_w > target_size or new_h > target_size: 118 | left = random.randint(0, max(0, new_w - target_size)) 119 | top = random.randint(0, max(0, new_h - target_size)) 120 | 121 | # Apply the same crop to both images to maintain pixel correspondence 122 | good_image = good_image.crop((left, top, left + target_size, top + target_size)) 123 | bad_image = bad_image.crop((left, top, left + target_size, top + target_size)) 124 | 125 | # Finally, resize bad_image to condition_size 126 | image = bad_image.resize((condition_size, condition_size), Image.BICUBIC) 127 | 128 | # Create a condition for the pipeline 129 | if args.model_name == "flux": 130 | condition_cls = Condition 131 | else: 132 | condition_cls = ConditionSD3 133 | condition = condition_cls( 134 | condition_type="cot", 135 | condition=image, 136 | position_delta=np.array([0, -condition_size // 16]) 137 | ) 138 | 139 | # Build the prompt by combining base prompt and reflection prompt if available 140 | original_prompt = item["prompt"] 141 | prompt = item["prompt"] 142 | if "reflection_prompt" in item: 143 | prompt += " [Reflexion] " + item["reflection_prompt"] 144 | elif "instruction" in item: 145 | prompt += " [Reflexion] " + item["instruction"] 146 | elif "reflection" in item: 147 | prompt += " [Reflexion] " + item["reflection"] 148 | elif "edited_prompt_list" in item: 149 | prompt += " [Reflexion] " + item["edited_prompt_list"][-1] 150 | else: 151 | raise ValueError(f"No reflection found in item: {item}") 152 | 153 | # Generate the result image 154 | if args.model_name == "flux": 155 | generate_func = generate 156 | else: 157 | generate_func = generate_sd3 158 | # breakpoint() 159 | result_img = generate_func( 160 | pipe, 161 | prompt=original_prompt, 162 | prompt_2=prompt, 163 | conditions=[condition], 164 | num_inference_steps=step, 165 | height=target_size, 166 | width=target_size, 167 | guidance_scale=args.guidance_scale, 168 | image_guidance_scale=args.image_guidance_scale 169 | ).images[0] 170 | 171 | # Concatenate bad image, good image, and generated image side by side 172 | concat_image = Image.new("RGB", (condition_size + target_size + target_size, target_size)) 173 | concat_image.paste(image, (0, 0)) 174 | concat_image.paste(good_image, (condition_size, 0)) 175 | concat_image.paste(result_img, (condition_size + target_size, 0)) 176 | 177 | # Save the concatenated image, using image_id if present 178 | output_name = item.get("image_id", f"result_{idx}") 179 | concat_image.save(os.path.join(output_dir, 
f"{output_name}.jpg")) -------------------------------------------------------------------------------- /reward_modeling/prompt_template.py: -------------------------------------------------------------------------------- 1 | 2 | VIDEOSCORE_QUERY_PROMPT = """ 3 | Suppose you are an expert in judging and evaluating the quality of AI-generated videos, 4 | please watch the frames of a given video and see the text prompt for generating the video, 5 | then give scores based on its {dimension_name}, i.e., {dimension_description}. 6 | Output a float number from 1.0 to 5.0 for this dimension, 7 | the higher the number is, the better the video performs in that sub-score, 8 | the lowest 1.0 means Bad, the highest 5.0 means Perfect/Real (the video is like a real video). 9 | The text prompt used for generation is "{text_prompt}". 10 | """ 11 | 12 | DIMENSION_DESCRIPTIONS = { 13 | 'VQ': ['visual quality', 'the quality of the video in terms of clearness, resolution, brightness, and color'], 14 | 'TA': ['text-to-video alignment', 'the alignment between the text prompt and the video content and motion'], 15 | 'MQ': ['motion quality', 'the quality of the motion in terms of consistency, smoothness, and completeness'], 16 | 'Overall': ['Overall Performance', 'the overall performance of the video in terms of visual quality, text-to-video alignment, and motion quality'], 17 | } 18 | 19 | SIMPLE_PROMPT = """ 20 | Please evaluate the {dimension_name} of a generated video. Consider {dimension_description}. 21 | The text prompt used for generation is "{text_prompt}". 22 | """ 23 | 24 | DETAILED_PROMPT_WITH_SPECIAL_TOKEN = """ 25 | You are tasked with evaluating a generated image based on two distinct criteria: Visual Quality and Text Alignment. Please provide a overall rating from 0 to 10, with 0 being the worst and 10 being the best. 26 | 27 | **Visual Quality:** 28 | Evaluate the overall visual quality of the image. The following sub-dimensions should be considered: 29 | - **Reasonableness:** The image should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups. 30 | - **Clarity:** Evaluate the sharpness and visibility of the image. The image should be clear and easy to interpret, with no blurring or indistinct areas. 31 | - **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows). 32 | - **Aesthetic and Creativity:** Assess the artistic aspects of the image, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance. 33 | 34 | **Text Alignment:** 35 | Assess how well the image matches the textual prompt across the following sub-dimensions: 36 | - **Subject Relevance** Evaluate how accurately the subject(s) in the image (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, interaction, etc. 37 | - **Environment Relevance:** Assess whether the background and scene fit the prompt. This includes checking if real-world locations or scenes are accurately represented, though some stylistic adaptation is acceptable. 38 | - **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the image adheres to this style. 
39 | 40 | Textual prompt - {text_prompt} 41 | Please provide the overall rating: <|VQ_reward|> 42 | """ 43 | 44 | DETAILED_PROMPT = """ 45 | You are tasked with evaluating a generated video based on three distinct criteria: Visual Quality, Motion Quality, and Text Alignment. Please provide a rating from 0 to 10 for each of the three categories, with 0 being the worst and 10 being the best. Each evaluation should be independent of the others. 46 | 47 | **Visual Quality:** 48 | Evaluate the overall visual quality of the video, with a focus on static factors. The following sub-dimensions should be considered: 49 | - **Reasonableness:** The video should not contain any significant biological or logical errors, such as abnormal body structures or nonsensical environmental setups. 50 | - **Clarity:** Evaluate the sharpness and visibility of the video. The image should be clear and easy to interpret, with no blurring or indistinct areas. 51 | - **Detail Richness:** Consider the level of detail in textures, materials, lighting, and other visual elements (e.g., hair, clothing, shadows). 52 | - **Aesthetic and Creativity:** Assess the artistic aspects of the video, including the color scheme, composition, atmosphere, depth of field, and the overall creative appeal. The scene should convey a sense of harmony and balance. 53 | - **Safety:** The video should not contain harmful or inappropriate content, such as political, violent, or adult material. If such content is present, the image quality and satisfaction score should be the lowest possible. 54 | 55 | **Motion Quality:** 56 | Assess the dynamic aspects of the video, with a focus on dynamic factors. Consider the following sub-dimensions: 57 | - **Stability:** Evaluate the continuity and stability between frames. There should be no sudden, unnatural jumps, and the video should maintain stable attributes (e.g., no fluctuating colors, textures, or missing body parts). 58 | - **Naturalness:** The movement should align with physical laws and be realistic. For example, clothing should flow naturally with motion, and facial expressions should change appropriately (e.g., blinking, mouth movements). 59 | - **Aesthetic Quality:** The movement should be smooth and fluid. The transitions between different motions or camera angles should be seamless, and the overall dynamic feel should be visually pleasing. 60 | - **Fusion:** Ensure that elements in motion (e.g., edges of the subject, hair, clothing) blend naturally with the background, without obvious artifacts or the feeling of cut-and-paste effects. 61 | - **Clarity of Motion:** The video should be clear and smooth in motion. Pay attention to any areas where the video might have blurry or unsteady sections that hinder visual continuity. 62 | - **Amplitude:** If the video is largely static or has little movement, assign a low score for motion quality. 63 | 64 | 65 | **Text Alignment:** 66 | Assess how well the video matches the textual prompt across the following sub-dimensions: 67 | - **Subject Relevance** Evaluate how accurately the subject(s) in the video (e.g., person, animal, object) align with the textual description. The subject should match the description in terms of number, appearance, and behavior. 68 | - **Motion Relevance:** Evaluate if the dynamic actions (e.g., gestures, posture, facial expressions like talking or blinking) align with the described prompt. The motion should match the prompt in terms of type, scale, and direction. 
69 | - **Environment Relevance:** Assess whether the background and scene fit the prompt. This includes checking if real-world locations or scenes are accurately represented, though some stylistic adaptation is acceptable. 70 | - **Style Relevance:** If the prompt specifies a particular artistic or stylistic style, evaluate how well the video adheres to this style. 71 | - **Camera Movement Relevance:** Check if the camera movements (e.g., following the subject, focus shifts) are consistent with the expected behavior from the prompt. 72 | 73 | Textual prompt - {text_prompt} 74 | Please provide the ratings of Visual Quality, Motion Quality, and Text Alignment. 75 | """ 76 | 77 | SIMPLE_PROMPT_NO_PROMPT = """ 78 | Please evaluate the {dimension_name} of a generated video. Consider {dimension_description}. 79 | """ 80 | 81 | def build_prompt(prompt, dimension, template_type): 82 | if isinstance(dimension, list) and len(dimension) > 1: 83 | dimension_name = ", ".join([DIMENSION_DESCRIPTIONS[d][0] for d in dimension]) 84 | dimension_name = f'overall performance({dimension_name})' 85 | dimension_description = "the overall performance of the video" 86 | else: 87 | if isinstance(dimension, list): 88 | dimension = dimension[0] 89 | dimension_name = DIMENSION_DESCRIPTIONS[dimension][0] 90 | dimension_description = DIMENSION_DESCRIPTIONS[dimension][1] 91 | 92 | if template_type == "none": 93 | return prompt 94 | elif template_type == "simple": 95 | return SIMPLE_PROMPT.format(dimension_name=dimension_name, 96 | dimension_description=dimension_description, 97 | text_prompt=prompt) 98 | elif template_type == "video_score": 99 | return VIDEOSCORE_QUERY_PROMPT.format(dimension_name=dimension_name, 100 | dimension_description=dimension_description, 101 | text_prompt=prompt) 102 | elif template_type == "detailed_special": 103 | return DETAILED_PROMPT_WITH_SPECIAL_TOKEN.format(text_prompt=prompt) 104 | elif template_type == "detailed": 105 | return DETAILED_PROMPT.format(text_prompt=prompt) 106 | else: 107 | raise ValueError("Invalid template type") 108 | -------------------------------------------------------------------------------- /train_flux/train/data.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import glob 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import webdataset as wds 8 | import torch 9 | from torch.utils.data import IterableDataset, DataLoader 10 | import torchvision.transforms as T 11 | 12 | # Taken from https://github.com/tmbdev-archive/webdataset-imagenet-2/blob/01a4ab54307b9156c527d45b6b171f88623d2dec/imagenet.py#L65. 13 | def nodesplitter(src, group=None): 14 | if torch.distributed.is_initialized(): 15 | if group is None: 16 | group = torch.distributed.group.WORLD 17 | rank = torch.distributed.get_rank(group=group) 18 | size = torch.distributed.get_world_size(group=group) 19 | count = 0 20 | for i, item in enumerate(src): 21 | if i % size == rank: 22 | yield item 23 | count += 1 24 | else: 25 | yield from src 26 | 27 | class ImageConditionWebDataset(IterableDataset): 28 | def __init__( 29 | self, 30 | shards_pattern: str, 31 | condition_size: int = 1024, 32 | target_size: int = 1024, 33 | condition_type: str = "cot", 34 | drop_text_prob: float = 0.1, 35 | drop_image_prob: float = 0.1, 36 | drop_reflection_prob: float = 0.2, 37 | root_dir: str = "", 38 | split_ratios: dict = None, # e.g. {"general":[.7,.3], "length":[.3,.7], …} 39 | training_stages: list = None, # e.g. 
[0, 10000, 20000] 40 | return_pil_image: bool = False, 41 | shuffle_buffer: int = 1000, 42 | ): 43 | super().__init__() 44 | self.condition_size = condition_size 45 | self.target_size = target_size 46 | self.condition_type = condition_type 47 | self.drop_text_prob = drop_text_prob 48 | self.drop_image_prob = drop_image_prob 49 | self.drop_reflection_prob = drop_reflection_prob 50 | self.return_pil_image = return_pil_image 51 | 52 | # prepare WebDataset pipelines per subset 53 | self.splits = list(split_ratios.keys()) 54 | self.all_split_ratios = split_ratios 55 | # start with first stage’s ratios 56 | self.split_ratios = {s: split_ratios[s][0] for s in self.splits} 57 | self.training_stages = training_stages or [0] 58 | 59 | # one independent pipeline for each subset 60 | self.datasets = {} 61 | if "https://" not in shards_pattern: 62 | shards_pattern = glob.glob(shards_pattern) 63 | for split in self.splits: 64 | # Define mandatory keys that must exist in each sample 65 | MANDATORY_KEYS = {"good_image.jpg", "bad_image.jpg", "reflection.txt", "prompt.txt", "subset.txt"} 66 | 67 | ds = ( 68 | wds.WebDataset(shards_pattern, handler=wds.ignore_and_continue, nodesplitter=nodesplitter) 69 | # First filter samples to ensure all required keys exist 70 | .select(lambda sample: set(sample.keys()) >= MANDATORY_KEYS) 71 | .shuffle(shuffle_buffer) 72 | .decode("pil") # good_image.jpg / bad_image.jpg → PIL 73 | .to_tuple( 74 | "good_image.jpg", 75 | "bad_image.jpg", 76 | "reflection.txt", 77 | "prompt.txt", 78 | "subset.txt", 79 | ) 80 | # keep only records whose subset matches this split 81 | .select(lambda sample: sample[4] == split) 82 | ) 83 | self.datasets[split] = ds 84 | 85 | # create one iterator per subset 86 | self.iters = {s: iter(ds) for s, ds in self.datasets.items()} 87 | self.to_tensor = T.ToTensor() 88 | self.iteration = 0 89 | 90 | def _update_split_ratios(self): 91 | itr = self.iteration 92 | stages = self.training_stages 93 | # beyond last => use last ratios 94 | if itr >= stages[-1]: 95 | for s in self.splits: 96 | self.split_ratios[s] = self.all_split_ratios[s][-1] 97 | return 98 | 99 | # find current stage index 100 | idx = max(i for i, t in enumerate(stages) if itr >= t) 101 | next_idx = min(idx+1, len(stages)-1) 102 | start, end = stages[idx], stages[next_idx] 103 | progress = (itr - start) / (end - start) if end>start else 1.0 104 | 105 | for s in self.splits: 106 | r0 = self.all_split_ratios[s][idx] 107 | r1 = self.all_split_ratios[s][next_idx] 108 | self.split_ratios[s] = r0 + progress*(r1-r0) 109 | 110 | def _preprocess_pair(self, good: Image.Image, bad: Image.Image): 111 | # match bad → good dims 112 | gw, gh = good.size 113 | bad = bad.resize((gw, gh), Image.BICUBIC) 114 | # scale shorter edge → target_size 115 | ratio = self.target_size / min(gw, gh) 116 | nw, nh = math.ceil(gw*ratio), math.ceil(gh*ratio) 117 | good = good.resize((nw, nh), Image.BICUBIC) 118 | bad = bad.resize((nw, nh), Image.BICUBIC) 119 | 120 | # same random crop 121 | if nw>self.target_size or nh>self.target_size: 122 | left = random.randint(0, nw-self.target_size) 123 | top = random.randint(0, nh-self.target_size) 124 | box = (left, top, left+self.target_size, top+self.target_size) 125 | good = good.crop(box) 126 | bad = bad.crop(box) 127 | 128 | # final resize bad → condition_size 129 | bad = bad.resize((self.condition_size, self.condition_size), Image.BICUBIC) 130 | return good, bad 131 | 132 | def __iter__(self): 133 | while True: 134 | # 1) update dynamic ratios 135 | self._update_split_ratios() 
136 | 137 | # 2) pick a split by current weights 138 | split = random.choices( 139 | self.splits, 140 | weights=[self.split_ratios[s] for s in self.splits], 141 | k=1, 142 | )[0] 143 | 144 | # 3) pull next sample (re‑reset iterator on exhaustion) 145 | try: 146 | good, bad, ref_bytes, prom_bytes, sub_bytes = next(self.iters[split]) 147 | except StopIteration: 148 | self.iters[split] = iter(self.datasets[split]) 149 | good, bad, ref_bytes, prom_bytes, sub_bytes = next(self.iters[split]) 150 | 151 | # decode text 152 | reflection = ref_bytes 153 | prompt = prom_bytes 154 | subset = sub_bytes 155 | 156 | # convert to RGB 157 | good = good.convert("RGB") 158 | bad = bad.convert("RGB") 159 | 160 | # 4) apply your resize/crop logic 161 | good, bad = self._preprocess_pair(good, bad) 162 | 163 | # 5) decide drops 164 | drop_text = random.random() < self.drop_text_prob 165 | drop_image_flag = random.random() < self.drop_image_prob and subset!="editing" 166 | drop_reflection = ( 167 | random.random() < self.drop_reflection_prob 168 | or len(reflection)<5 169 | ) 170 | 171 | if drop_reflection or drop_image_flag: 172 | description = prompt 173 | else: 174 | description = f"{prompt} [Reflexion] {reflection}" 175 | if drop_text: 176 | description = "" 177 | if drop_image_flag: 178 | # black out condition 179 | bad = Image.new("RGB", (self.condition_size, self.condition_size), (0,0,0)) 180 | 181 | # 6) to tensors 182 | image = self.to_tensor(good) 183 | condition = self.to_tensor(bad) 184 | 185 | out = { 186 | "image": image, 187 | "condition": condition, 188 | "original_prompt": prompt, 189 | "condition_type": self.condition_type, 190 | "description": description, 191 | "position_delta": np.array([0, -self.condition_size//16]), 192 | "subset": subset 193 | } 194 | if self.return_pil_image: 195 | out["pil_image"] = [good, bad] 196 | 197 | self.iteration += 1 198 | yield out 199 | 200 | # usage: 201 | if __name__ == "__main__": 202 | 203 | split_ratios = { 204 | "general": [0.8, 0.6, 0.4], 205 | "length": [0.1, 0.2, 0.3], 206 | "rule": [0.1, 0.2, 0.3], 207 | "editing": [0.0, 0.0, 0.0], 208 | } 209 | training_stages = [0, 10000, 20000] 210 | 211 | # local path should be provided as 212 | # shards_pattern = "DIR_WHERE_GenRef-wds_IS_DOWNLOADED/*.tar" 213 | shards_pattern = "/ceph/hf-cache/hub/datasets--diffusion-cot--GenRef-wds/snapshots/42c837b891fc34a944ed0c8124876a7e8225266f/*.tar" 214 | dataset = ImageConditionWebDataset( 215 | shards_pattern=shards_pattern, 216 | condition_size=1024, 217 | target_size=1024, 218 | condition_type="cot", 219 | drop_text_prob=0.1, 220 | drop_image_prob=0.1, 221 | drop_reflection_prob=0.2, 222 | split_ratios=split_ratios, 223 | training_stages=training_stages, 224 | return_pil_image=False, 225 | ) 226 | 227 | loader = DataLoader(dataset, batch_size=8, num_workers=4) 228 | 229 | # iterate: 230 | from tqdm import tqdm 231 | for batch in tqdm(loader): 232 | continue 233 | # print(batch.keys()) 234 | # print(batch["image"].size()) 235 | # print(batch["condition"].size()) 236 | # break 237 | -------------------------------------------------------------------------------- /reward_modeling/data.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from dataclasses import dataclass 3 | from typing import Optional, List, Union 4 | 5 | import pandas as pd 6 | import torch 7 | from prompt_template import build_prompt 8 | # from qwen_vl_utils import process_vision_info 9 | from vision_process import process_vision_info 10 | from 
torch.utils.data import Dataset 11 | import torchvision.transforms.functional as F 12 | 13 | 14 | @dataclass 15 | class DataConfig: 16 | meta_data: str = "/path/to/dataset/meta_data.csv" 17 | data_dir: str = "/path/to/dataset" 18 | meta_data_test: str = None 19 | max_frame_pixels: int = 240 * 320 20 | num_frames: float = None 21 | fps: float = 2.0 22 | p_shuffle_frames: float = 0.0 23 | p_color_jitter: float = 0.0 24 | eval_dim: Union[str, List[str]] = "VQ" 25 | prompt_template_type: str = "none" 26 | add_noise: bool = False 27 | sample_type: str = "uniform" 28 | use_tied_data: bool = True 29 | 30 | def convert_GSB_csv_to_reward_data(example, data_dir, eval_dims=["VQ"], max_pixels=448 * 448, prompt_template_type="none"): 31 | """ 32 | Convert Good/Same/Bad csv data to reward data. 33 | 34 | Args: 35 | example (dict): A dataframe containing the GSB csv data. 36 | data_dir (str): The directory path to the video files. 37 | eval_dim (str): The dimension to evaluate ("VQ"/"MQ"/"TA"). 38 | max_pixels (int): The maximum number of pixels allowed for videos. 39 | num_frames (float): Number of frames. 40 | prompt_template_type (str): The type of prompt template to use ("none"/"simple"/"video_score"). 41 | 42 | Returns: 43 | dict: A dictionary containing the reward data. 44 | """ 45 | 46 | A_data = [ 47 | { 48 | "role": "user", 49 | "content": [ 50 | { 51 | "type": "image", 52 | "image": example[f'image1'], 53 | "max_pixels": max_pixels, 54 | }, 55 | {"type": "text", "text": build_prompt(example["prompt"], eval_dims, prompt_template_type)}, 56 | ], 57 | } 58 | ] 59 | B_data = [ 60 | { 61 | "role": "user", 62 | "content": [ 63 | { 64 | "type": "image", 65 | "image": example[f'image2'], 66 | "max_pixels": max_pixels, 67 | }, 68 | {"type": "text", "text": build_prompt(example["prompt"], eval_dims, prompt_template_type)}, 69 | ], 70 | } 71 | ] 72 | 73 | chosen_labels = [] 74 | A_scores = [] 75 | B_scores = [] 76 | 77 | for eval_dim in eval_dims: 78 | ### chosen_label: 1 if A is chosen, -1 if B is chosen, 0 if tied. 79 | ### 22 if invalid. 
80 | try: 81 | if example[f"{eval_dim}"] is not None: 82 | if example[f"{eval_dim}"] == "A": 83 | chosen_label = 1 84 | elif example[f"{eval_dim}"] == "B": 85 | chosen_label = -1 86 | elif example[f"{eval_dim}"] == "same": 87 | chosen_label = 0 88 | elif example[f"{eval_dim}"] == "invalid": 89 | chosen_label = 22 90 | else: 91 | chosen_label = 22 92 | else: 93 | chosen_label = 22 94 | except Exception as e: 95 | chosen_label = 22 96 | 97 | chosen_labels.append(chosen_label) 98 | if f"MOS_A_{eval_dim}" in example and f"MOS_B_{eval_dim}" in example: 99 | try: 100 | A_score = example[f"MOS_A_{eval_dim}"] if example[f"MOS_A_{eval_dim}"] is not None else 0.0 101 | B_score = example[f"MOS_B_{eval_dim}"] if example[f"MOS_B_{eval_dim}"] is not None else 0.0 102 | except Exception as e: 103 | A_score = 0.0 104 | B_score = 0.0 105 | A_scores.append(A_score) 106 | B_scores.append(B_score) 107 | else: 108 | A_scores.append(0.0) 109 | B_scores.append(0.0) 110 | 111 | chosen_labels = torch.tensor(chosen_labels, dtype=torch.long) 112 | A_scores = torch.tensor(A_scores, dtype=torch.float) 113 | B_scores = torch.tensor(B_scores, dtype=torch.float) 114 | metainfo_idx = None 115 | if 'metainfo_idx' in example: 116 | metainfo_idx = example['metainfo_idx'] 117 | 118 | return {"A_data": A_data, "B_data": B_data, 119 | "A_scores": A_scores, "B_scores": B_scores, 120 | "chosen_label": chosen_labels, 121 | "metainfo_idx": metainfo_idx,} 122 | 123 | class QWen2VLDataCollator(): 124 | def __init__(self, processor, add_noise=False, p_shuffle_frames=0.0, p_color_jitter=0.0): 125 | self.processor = processor 126 | self.add_noise = add_noise 127 | self.set_noise_step = None 128 | 129 | self.p_shuffle_frames = p_shuffle_frames 130 | self.p_color_jitter = p_color_jitter 131 | 132 | self.noise_adder = None 133 | 134 | def _clean_message(self, message): 135 | """ 136 | Remove unnecessary keys from the message (very necessary). 137 | """ 138 | out_message = [ 139 | { 140 | "role": "user", 141 | "content": [ 142 | { 143 | "type": "image", 144 | "image": message[0]["content"][0]["image"], 145 | "max_pixels": message[0]["content"][0]["max_pixels"], 146 | }, 147 | {"type": "text", "text": message[0]["content"][1]["text"]}, 148 | ], 149 | } 150 | ] 151 | 152 | return out_message 153 | 154 | 155 | def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'): 156 | """ 157 | Pad the sequences to the maximum length. 158 | """ 159 | assert padding_side in ['right', 'left'] 160 | if sequences.shape[1] >= max_len: 161 | return sequences, attention_mask 162 | 163 | pad_len = max_len - sequences.shape[1] 164 | padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0) 165 | 166 | sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id) 167 | attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0) 168 | 169 | return sequences_padded, attention_mask_padded 170 | 171 | def __call__(self, features, enable_noise=True): 172 | """ 173 | Preprocess inputs to token sequences and return a batch 174 | """ 175 | # try: 176 | features_A = [] 177 | features_B = [] 178 | # check if we have a margin. 
If we do, we need to batch it as well 179 | # has_margin = "margin" in features[0] 180 | has_idx = "metainfo_idx" in features[0] and features[0]["metainfo_idx"] is not None 181 | 182 | for idx, feature in enumerate(features): 183 | features_A.append(self._clean_message(feature["A_data"])) 184 | features_B.append(self._clean_message(feature["B_data"])) 185 | 186 | # import pdb; pdb.set_trace() 187 | image_inputs_A, video_inputs_A = process_vision_info(features_A) 188 | image_inputs_B, video_inputs_B = process_vision_info(features_B) 189 | 190 | do_rescale = False 191 | # print(f"{video_inputs_A[0].shape}, {video_inputs_B[0].shape}") 192 | 193 | # if not enable_noise: 194 | # print("Not training, no noise added.") 195 | batch_A = self.processor( 196 | text=self.processor.apply_chat_template(features_A, tokenize=False, add_generation_prompt=True), 197 | images=image_inputs_A, 198 | videos=video_inputs_A, 199 | padding=True, 200 | return_tensors="pt", 201 | videos_kwargs={"do_rescale": do_rescale}, 202 | ) 203 | batch_B = self.processor( 204 | text=self.processor.apply_chat_template(features_B, tokenize=False, add_generation_prompt=True), 205 | images=image_inputs_B, 206 | videos=video_inputs_B, 207 | padding=True, 208 | return_tensors="pt", 209 | videos_kwargs={"do_rescale": do_rescale}, 210 | ) 211 | 212 | # pdb.set_trace() 213 | max_len = max(batch_A["input_ids"].shape[1], batch_B["input_ids"].shape[1]) 214 | batch_A["input_ids"], batch_A["attention_mask"] = self._pad_sequence(batch_A["input_ids"], batch_A["attention_mask"], max_len, "right") 215 | batch_B["input_ids"], batch_B["attention_mask"] = self._pad_sequence(batch_B["input_ids"], batch_B["attention_mask"], max_len, "right") 216 | # print(f"Batch A: {batch_A['input_ids'].shape}, Batch B: {batch_B['input_ids'].shape}") 217 | 218 | chosen_label = torch.stack([torch.tensor(feature["chosen_label"]) for feature in features]) 219 | 220 | A_scores = torch.stack([torch.tensor(feature["A_scores"]) for feature in features]) 221 | B_scores = torch.stack([torch.tensor(feature["B_scores"]) for feature in features]) 222 | 223 | batch = { 224 | "A": batch_A, 225 | "B": batch_B, 226 | "return_loss": True, 227 | "chosen_label": chosen_label, 228 | "A_scores": A_scores, 229 | "B_scores": B_scores, 230 | } 231 | 232 | if has_idx: 233 | metainfo_idx = torch.stack([torch.tensor(feature["metainfo_idx"]) for feature in features]) 234 | batch["metainfo_idx"] = metainfo_idx 235 | 236 | # pdb.set_trace() 237 | return batch 238 | 239 | # except Exception as e: 240 | # print(f"Error processing batch: {e} in reading.") 241 | # # get next batch 242 | # return None 243 | -------------------------------------------------------------------------------- /train_flux/train/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lightning as L 3 | from diffusers.pipelines import FluxPipeline 4 | import torch 5 | from torchvision import transforms 6 | from peft import LoraConfig, get_peft_model_state_dict 7 | 8 | import prodigyopt 9 | 10 | from flux.generate import generate 11 | from flux.transformer import tranformer_forward 12 | from flux.condition import Condition 13 | from flux.pipeline_tools import encode_images, prepare_text_input 14 | 15 | 16 | class OminiModel(L.LightningModule): 17 | def __init__( 18 | self, 19 | flux_pipe_id: str, 20 | lora_path: str = None, 21 | lora_config: dict = None, 22 | data_config: dict = None, 23 | device: str = "cuda", 24 | dtype: torch.dtype = torch.bfloat16, 25 | 
model_config: dict = {}, 26 | optimizer_config: dict = None, 27 | gradient_checkpointing: bool = False, 28 | save_path: str = None, 29 | run_name: str = None, 30 | cache_dir: str = None, 31 | ): 32 | # Initialize the LightningModule 33 | super().__init__() 34 | self.model_config = model_config 35 | self.optimizer_config = optimizer_config 36 | self.data_config = data_config 37 | self.save_path = save_path 38 | self.run_name = run_name 39 | 40 | # Load the Flux pipeline 41 | self.flux_pipe: FluxPipeline = ( 42 | FluxPipeline.from_pretrained(flux_pipe_id, cache_dir=cache_dir, torch_dtype=dtype).to(device) 43 | ) 44 | self.transformer = self.flux_pipe.transformer 45 | self.transformer.gradient_checkpointing = gradient_checkpointing 46 | self.transformer.train() 47 | 48 | # Freeze the Flux pipeline 49 | self.flux_pipe.text_encoder.requires_grad_(False).eval() 50 | self.flux_pipe.text_encoder_2.requires_grad_(False).eval() 51 | self.flux_pipe.vae.requires_grad_(False).eval() 52 | 53 | # Initialize LoRA layers 54 | self.lora_layers = self.init_lora(lora_path, lora_config) 55 | 56 | self.to(device).to(dtype) 57 | 58 | def init_lora(self, lora_path: str, lora_config: dict): 59 | assert lora_path or lora_config 60 | if lora_path: 61 | # # TODO: Implement this 62 | # raise NotImplementedError 63 | self.flux_pipe.load_lora_weights(lora_path, adapter_name="default") 64 | # TODO: Check if this is correct (p.requires_grad) 65 | lora_layers = [] 66 | for name, p in self.transformer.named_parameters(): 67 | if "lora" in name: 68 | lora_layers.append(p) 69 | # lora_layers = filter( 70 | # lambda p: p.requires_grad, self.transformer.parameters() 71 | # ) 72 | else: 73 | if lora_config.get("target_modules", None) == "all-linear": 74 | target_modules = set() 75 | for name, module in self.transformer.named_modules(): 76 | if isinstance(module, torch.nn.Linear): 77 | target_modules.add(name) 78 | target_modules = list(target_modules) 79 | lora_config["target_modules"] = target_modules 80 | self.transformer.add_adapter(LoraConfig(**lora_config)) 81 | # TODO: Check if this is correct (p.requires_grad) 82 | lora_layers = filter( 83 | lambda p: p.requires_grad, self.transformer.parameters() 84 | ) 85 | return list(lora_layers) if not isinstance(lora_layers, list) else lora_layers 86 | 87 | def save_lora(self, path: str): 88 | FluxPipeline.save_lora_weights( 89 | save_directory=path, 90 | transformer_lora_layers=get_peft_model_state_dict(self.transformer), 91 | safe_serialization=True, 92 | ) 93 | 94 | def configure_optimizers(self): 95 | # Freeze the transformer 96 | self.transformer.requires_grad_(False) 97 | opt_config = self.optimizer_config 98 | 99 | # Set the trainable parameters 100 | self.trainable_params = self.lora_layers 101 | 102 | # Unfreeze trainable parameters 103 | for p in self.trainable_params: 104 | p.requires_grad_(True) 105 | 106 | # Initialize the optimizer 107 | if opt_config["type"] == "AdamW": 108 | optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"]) 109 | elif opt_config["type"] == "Prodigy": 110 | optimizer = prodigyopt.Prodigy( 111 | self.trainable_params, 112 | **opt_config["params"], 113 | ) 114 | elif opt_config["type"] == "SGD": 115 | optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"]) 116 | else: 117 | raise NotImplementedError 118 | 119 | return optimizer 120 | 121 | def validation_step(self, batch, batch_idx): 122 | generator = torch.Generator(device=self.device) 123 | generator.manual_seed(42) 124 | target_size = 
self.data_config["target_size"] 125 | condition_size = self.data_config["condition_size"] 126 | prompt = batch["description"][0] 127 | original_prompt = batch["original_prompt"][0] 128 | condition_type = batch["condition_type"][0] 129 | condition_img = batch["condition"][0] 130 | to_pil = transforms.ToPILImage() 131 | condition_img = to_pil(condition_img.cpu().float()) 132 | position_delta = batch["position_delta"][0] 133 | condition = Condition( 134 | condition_type=condition_type, 135 | condition=condition_img, 136 | position_delta=position_delta, 137 | ) 138 | 139 | res = generate( 140 | self.flux_pipe, 141 | prompt=original_prompt, 142 | prompt_2=prompt, 143 | conditions=[condition], 144 | height=target_size, 145 | width=target_size, 146 | generator=generator, 147 | model_config=self.model_config, 148 | default_lora=True, 149 | ) 150 | os.makedirs(os.path.join(self.save_path, self.run_name, "val"), exist_ok=True) 151 | res.images[0].save( 152 | os.path.join(self.save_path, self.run_name, "val", f"{self.global_step}_{condition_type}_{batch_idx}.jpg") 153 | ) 154 | 155 | def training_step(self, batch, batch_idx): 156 | step_loss = self.step(batch) 157 | self.log_loss = ( 158 | step_loss.item() 159 | if not hasattr(self, "log_loss") 160 | else self.log_loss * 0.95 + step_loss.item() * 0.05 161 | ) 162 | return step_loss 163 | 164 | def step(self, batch): 165 | imgs = batch["image"] 166 | conditions = batch["condition"] 167 | condition_types = batch["condition_type"] 168 | prompts = batch["original_prompt"] 169 | position_delta = batch["position_delta"][0] 170 | prompts_2 = batch["description"] 171 | 172 | # Prepare inputs 173 | with torch.no_grad(): 174 | # Prepare image input 175 | x_0, img_ids = encode_images(self.flux_pipe, imgs) 176 | 177 | # Prepare text input 178 | prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input( 179 | self.flux_pipe, prompts, prompts_2=prompts_2 180 | ) 181 | 182 | # Prepare t and x_t 183 | t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device)) 184 | x_1 = torch.randn_like(x_0).to(self.device) 185 | t_ = t.unsqueeze(1).unsqueeze(1) 186 | x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype) 187 | 188 | # # Prepare conditions 189 | # condition_latents, condition_ids = encode_images(self.flux_pipe, conditions) 190 | 191 | # # Add position delta 192 | # condition_ids[:, 1] += position_delta[0] 193 | # condition_ids[:, 2] += position_delta[1] 194 | 195 | # # Prepare condition type 196 | # condition_type_ids = torch.tensor( 197 | # [ 198 | # Condition.get_type_id(condition_type) 199 | # for condition_type in condition_types 200 | # ] 201 | # ).to(self.device) 202 | # condition_type_ids = ( 203 | # torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0] 204 | # ).unsqueeze(1) 205 | 206 | condition_latents = None 207 | condition_ids = None 208 | condition_type_ids = None 209 | 210 | # Prepare guidance 211 | guidance = ( 212 | torch.ones_like(t).to(self.device) 213 | if self.transformer.config.guidance_embeds 214 | else None 215 | ) 216 | 217 | # Forward pass 218 | transformer_out = tranformer_forward( 219 | self.transformer, 220 | # Model config 221 | model_config=self.model_config, 222 | # Inputs of the condition (new feature) 223 | condition_latents=condition_latents, 224 | condition_ids=condition_ids, 225 | condition_type_ids=condition_type_ids, 226 | # Inputs to the original transformer 227 | hidden_states=x_t, 228 | timestep=t, 229 | guidance=guidance, 230 | pooled_projections=pooled_prompt_embeds, 231 | 
encoder_hidden_states=prompt_embeds, 232 | txt_ids=text_ids, 233 | img_ids=img_ids, 234 | joint_attention_kwargs=None, 235 | return_dict=False, 236 | ) 237 | pred = transformer_out[0] 238 | 239 | # Compute loss 240 | loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean") 241 | self.last_t = t.mean().item() 242 | return loss 243 | -------------------------------------------------------------------------------- /train_flux/flux/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.pipelines import FluxPipeline 3 | from typing import List, Union, Optional, Dict, Any, Callable 4 | from .block import block_forward, single_block_forward 5 | from .lora_controller import enable_lora 6 | from accelerate.utils import is_torch_version 7 | from diffusers.models.transformers.transformer_flux import ( 8 | FluxTransformer2DModel, 9 | Transformer2DModelOutput, 10 | USE_PEFT_BACKEND, 11 | scale_lora_layers, 12 | unscale_lora_layers, 13 | logger, 14 | ) 15 | import numpy as np 16 | 17 | 18 | def prepare_params( 19 | hidden_states: torch.Tensor, 20 | encoder_hidden_states: torch.Tensor = None, 21 | pooled_projections: torch.Tensor = None, 22 | timestep: torch.LongTensor = None, 23 | img_ids: torch.Tensor = None, 24 | txt_ids: torch.Tensor = None, 25 | guidance: torch.Tensor = None, 26 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 27 | controlnet_block_samples=None, 28 | controlnet_single_block_samples=None, 29 | return_dict: bool = True, 30 | **kwargs: dict, 31 | ): 32 | return ( 33 | hidden_states, 34 | encoder_hidden_states, 35 | pooled_projections, 36 | timestep, 37 | img_ids, 38 | txt_ids, 39 | guidance, 40 | joint_attention_kwargs, 41 | controlnet_block_samples, 42 | controlnet_single_block_samples, 43 | return_dict, 44 | ) 45 | 46 | 47 | def tranformer_forward( 48 | transformer: FluxTransformer2DModel, 49 | condition_latents: torch.Tensor, 50 | condition_ids: torch.Tensor, 51 | condition_type_ids: torch.Tensor, 52 | model_config: Optional[Dict[str, Any]] = {}, 53 | c_t=0, 54 | **params: dict, 55 | ): 56 | self = transformer 57 | use_condition = condition_latents is not None 58 | 59 | ( 60 | hidden_states, 61 | encoder_hidden_states, 62 | pooled_projections, 63 | timestep, 64 | img_ids, 65 | txt_ids, 66 | guidance, 67 | joint_attention_kwargs, 68 | controlnet_block_samples, 69 | controlnet_single_block_samples, 70 | return_dict, 71 | ) = prepare_params(**params) 72 | 73 | if joint_attention_kwargs is not None: 74 | joint_attention_kwargs = joint_attention_kwargs.copy() 75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 76 | else: 77 | lora_scale = 1.0 78 | 79 | if USE_PEFT_BACKEND: 80 | # weight the lora layers by setting `lora_scale` for each PEFT layer 81 | scale_lora_layers(self, lora_scale) 82 | else: 83 | if ( 84 | joint_attention_kwargs is not None 85 | and joint_attention_kwargs.get("scale", None) is not None 86 | ): 87 | logger.warning( 88 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
89 | ) 90 | 91 | with enable_lora((self.x_embedder,), model_config.get("latent_lora", False)): 92 | hidden_states = self.x_embedder(hidden_states) 93 | condition_latents = self.x_embedder(condition_latents) if use_condition else None 94 | 95 | timestep = timestep.to(hidden_states.dtype) * 1000 96 | 97 | if guidance is not None: 98 | guidance = guidance.to(hidden_states.dtype) * 1000 99 | else: 100 | guidance = None 101 | 102 | temb = ( 103 | self.time_text_embed(timestep, pooled_projections) 104 | if guidance is None 105 | else self.time_text_embed(timestep, guidance, pooled_projections) 106 | ) 107 | 108 | cond_temb = ( 109 | self.time_text_embed(torch.ones_like(timestep) * c_t * 1000, pooled_projections) 110 | if guidance is None 111 | else self.time_text_embed( 112 | torch.ones_like(timestep) * c_t * 1000, torch.ones_like(guidance) * 1000, pooled_projections 113 | ) 114 | ) 115 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 116 | 117 | if txt_ids.ndim == 3: 118 | logger.warning( 119 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 120 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 121 | ) 122 | txt_ids = txt_ids[0] 123 | if img_ids.ndim == 3: 124 | logger.warning( 125 | "Passing `img_ids` 3d torch.Tensor is deprecated." 126 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 127 | ) 128 | img_ids = img_ids[0] 129 | 130 | ids = torch.cat((txt_ids, img_ids), dim=0) 131 | image_rotary_emb = self.pos_embed(ids) 132 | if use_condition: 133 | # condition_ids[:, :1] = condition_type_ids 134 | cond_rotary_emb = self.pos_embed(condition_ids) 135 | 136 | # hidden_states = torch.cat([hidden_states, condition_latents], dim=1) 137 | 138 | for index_block, block in enumerate(self.transformer_blocks): 139 | if self.training and self.gradient_checkpointing: 140 | ckpt_kwargs: Dict[str, Any] = ( 141 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 142 | ) 143 | encoder_hidden_states, hidden_states, condition_latents = ( 144 | torch.utils.checkpoint.checkpoint( 145 | block_forward, 146 | self=block, 147 | model_config=model_config, 148 | hidden_states=hidden_states, 149 | encoder_hidden_states=encoder_hidden_states, 150 | condition_latents=condition_latents if use_condition else None, 151 | temb=temb, 152 | cond_temb=cond_temb if use_condition else None, 153 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 154 | image_rotary_emb=image_rotary_emb, 155 | **ckpt_kwargs, 156 | ) 157 | ) 158 | 159 | else: 160 | encoder_hidden_states, hidden_states, condition_latents = block_forward( 161 | block, 162 | model_config=model_config, 163 | hidden_states=hidden_states, 164 | encoder_hidden_states=encoder_hidden_states, 165 | condition_latents=condition_latents if use_condition else None, 166 | temb=temb, 167 | cond_temb=cond_temb if use_condition else None, 168 | cond_rotary_emb=cond_rotary_emb if use_condition else None, 169 | image_rotary_emb=image_rotary_emb, 170 | ) 171 | 172 | # controlnet residual 173 | if controlnet_block_samples is not None: 174 | interval_control = len(self.transformer_blocks) / len( 175 | controlnet_block_samples 176 | ) 177 | interval_control = int(np.ceil(interval_control)) 178 | hidden_states = ( 179 | hidden_states 180 | + controlnet_block_samples[index_block // interval_control] 181 | ) 182 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 183 | 184 | for index_block, block in enumerate(self.single_transformer_blocks): 185 | if self.training and 
self.gradient_checkpointing: 186 | ckpt_kwargs: Dict[str, Any] = ( 187 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 188 | ) 189 | result = torch.utils.checkpoint.checkpoint( 190 | single_block_forward, 191 | self=block, 192 | model_config=model_config, 193 | hidden_states=hidden_states, 194 | temb=temb, 195 | image_rotary_emb=image_rotary_emb, 196 | **( 197 | { 198 | "condition_latents": condition_latents, 199 | "cond_temb": cond_temb, 200 | "cond_rotary_emb": cond_rotary_emb, 201 | } 202 | if use_condition 203 | else {} 204 | ), 205 | **ckpt_kwargs, 206 | ) 207 | 208 | else: 209 | result = single_block_forward( 210 | block, 211 | model_config=model_config, 212 | hidden_states=hidden_states, 213 | temb=temb, 214 | image_rotary_emb=image_rotary_emb, 215 | **( 216 | { 217 | "condition_latents": condition_latents, 218 | "cond_temb": cond_temb, 219 | "cond_rotary_emb": cond_rotary_emb, 220 | } 221 | if use_condition 222 | else {} 223 | ), 224 | ) 225 | if use_condition: 226 | hidden_states, condition_latents = result 227 | else: 228 | hidden_states = result 229 | 230 | # controlnet residual 231 | if controlnet_single_block_samples is not None: 232 | interval_control = len(self.single_transformer_blocks) / len( 233 | controlnet_single_block_samples 234 | ) 235 | interval_control = int(np.ceil(interval_control)) 236 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 237 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 238 | + controlnet_single_block_samples[index_block // interval_control] 239 | ) 240 | 241 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 242 | 243 | hidden_states = self.norm_out(hidden_states, temb) 244 | output = self.proj_out(hidden_states) 245 | 246 | if USE_PEFT_BACKEND: 247 | # remove `lora_scale` from each PEFT layer 248 | unscale_lora_layers(self, lora_scale) 249 | 250 | if not return_dict: 251 | return (output,) 252 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /tts/verifiers/geneval_detailed_verifier_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "single_object": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `object_completeness`, `detectability`, `occlusion_handling`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Object Completeness:\nAssess the structural integrity of the object (no defects/deformations), detail clarity and legibility.\nScore: 0 (severely fragmented) to 10 (perfectly intact).\n\nb) Detectability:\nEvaluate the distinction and visual saliency of objects and backgrounds using contrast analysis.\nScore: 0 (camouflaged) to 10 (immediately noticeable).\n\nc) Occlusion Handling:\nAssess whether there is unreasonable occlusion (natural occlusion needs to keep the subject visible).\nScore: 0 (key parts are blocked) to 10 (no blockage/natural and reasonable blockage).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 3 | 4 | "two_object": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `separation_clarity`, `individual_completeness`, `relationship_accuracy`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Seperation Clarity:\nAssess the spatial separation and boundary clarity of two objects.\nScore: 0 (fully overlapped) to 10 (completely separate and clearly defined boundaries)\n\nb) Indivisual Completeness:\nEvaluate each object's individual integrity and detail retention.\nScore: 0 (both objects are incomplete) to 10 (both objects are complete).\n\nc) Relationship Accuracy:\nAssess the rationality of size proportions.\nScore: 0 (wrong proportions) to 10 (perfectly in line with physical laws).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 5 | 6 | "counting": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `count_accuracy`, `object_uniformity`, `spatial_legibility`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Count Accuracy:\nAssess the number of generated objects matches the exact prompt.\nScore: 0 (number wrong) to 10 (number correct)\n\nb) Object Uniformity:\nEvaluate the consistency of shape/size/color among same kind of objects.\nScore: 0 (same kind but total different shape/size/color) to 10 (same kind and same shape/size/color).\n\nc) Spatial Legibility:\nEvaluate the plausibility and visibility of object distribution (no excessive overlap).\nScore: 0 (heavily overlapped) to 10 (perfect displayed and all easily seen).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 7 | 8 | "colors": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `color_fidelity`, `contrast_effectiveness`, `multi_object_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Color Fidelity:\nAssess the exact match between the object color and the input prompt.\nScore: 0 (color wrong) to 10 (color correct)\n\nb) Contrast Effectiveness:\nEvaluate the difference between foreground and background colors.\nScore: 0 (similar colors, difficult to distinguish) to 10 (high contrast).\n\nc) Multi-Object Consistency:\nAssess color consistency across multiple same kind of objects.\nScore: 0 (same kind of objects with total different colors) to 10 (same kind with same color).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 9 | 10 | "position": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `position_accuracy`, `occlusion_management`, `perspective_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. 
The aspects to evaluate are as follows: \n\n a) Positional Accuracy:\nAssess the matching accuracy between spatial position and prompt description.\nScore: 0 (totally wrong) to 10 (postion correct)\n\nb) Occlusion Management:\nEvaluate position discernibility in the presence of occlusion.\nScore: 0 (fully occlusion) to 10 (clearly dsiplay the relationship).\n\nc) Perspective Consistency:\nAssess the rationality of perspective relationship and spatial depth.\nScore: 0 (perspective contradiction) to 10 (completely reasonable).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects.", 11 | 12 | "color_attr": "You are a multimodal large-language model tasked with evaluating images generated by a text-to-image model. Your goal is to assess each generated image based on specific aspects and provide a detailed critique, along with a scoring system. The final output should be formatted as a JSON object containing individual scores for each aspect and an overall score. The keys in the JSON object should be: `attribute_binding`, `contrast_effectiveness`, `material_consistency`, and `overall_score`. Below is a comprehensive guide to follow in your evaluation process: Your evaluation should focus on these aspects:\n\n 1. Key Evaluation Aspects and Scoring Criteria: For each aspect, provide a score from 0 to 10, where 0 represents poor performance and 10 represents excellent performance. For each score, include a short explanation or justification (1-2 sentences) explaining why that score was given. The aspects to evaluate are as follows: \n\n a) Attrribute Binding:\nCorrect binding of colors to designated objects (no color mismatches).\nScore: 0 (color mismatch) to 10 (correct binding)\n\nb) Evaluate the difference between foreground and background colors.\nScore: 0 (similar colors, difficult to distinguish) to 10 (high contrast).\n\nc) Material Consistency:\nAssess the coordination of color and material performance.\nScore: 0 (material conflicts) to 10 (perfect harmony).\n\n2. Overall Score \nAfter scoring each aspect individually, provide an overall score, representing the model's general performance on this image. This should be a weighted average based on the importance of each aspect to the prompt or an average of all aspects." 13 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

From Reflection to Perfection:
Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning

3 | 4 | 5 | 6 | 7 | arXiv 8 | 9 | Website 10 | 11 | HF Dataset: ReflectionFlow 12 | 13 |
14 | Le Zhuo1,4, 15 | Liangbing Zhao2, 16 | Sayak Paul3, 17 | Yue Liao1, 18 | Renrui Zhang1, 19 | Yi Xin4, 20 | Peng Gao4, 21 |
22 | Mohamed Elhoseiny2, 23 | Hongsheng Li1 24 |
25 | 26 | 27 |
28 | 1CUHK MMLAB  29 | 2KAUST  30 | 3Hugging Face  31 | 4Shanghai AI Lab  32 |
33 | 34 | 35 | 36 | 37 | Overall pipeline of the ReflectionFlow framework with qualitative and quantitative results of scaling compute at inference time. 38 | 39 |
40 | 41 | ## :fire: News 42 | 43 | - [2025/6/25] Our paper is accepted by ICCV 2025! 44 | - [2025/5/23] Release the code for our image verifier. 45 | - [2025/4/23] Release [paper](https://arxiv.org/abs/2504.16080). 46 | - [2025/4/20] Release GenRef dataset, model checkpoints, as well as the training and inference code. 47 | 48 | ## ✨ Quick Start 49 | 50 | ### Installation 51 | 52 | 1. **Environment setup** 53 | ```bash 54 | conda create -n ReflectionFlow python=3.10 55 | conda activate ReflectionFlow 56 | ``` 57 | 2. **Requirements installation** 58 | ```bash 59 | pip install -r requirements.txt 60 | ``` 61 | 62 | ## 🚀 Models and Datasets 63 | 64 | ### Datasets 65 | | Name | Description | Link | 66 | | --- | --- | --- | 67 | | GenRef-wds | WebDataset format of full GenRef | [HuggingFace](https://huggingface.co/datasets/diffusion-cot/GenRef-wds) | 68 | | GenRef-CoT | Chain-of-Thought reflection dataset | [HuggingFace](https://huggingface.co/datasets/diffusion-cot/GenRef-CoT) | 69 | 70 | ### Models 71 | | Name | Description | Finetune Data | Link | 72 | | --- | --- | --- | --- | 73 | | FLUX Corrector | Main FLUX-based "text image -> image" model | GenRef-wds | [HuggingFace](https://huggingface.co/diffusion-cot/FLUX-Corrector) | 74 | | Reflection Generator | Qwen-based reflection generator | GenRef-CoT | [HuggingFace](https://huggingface.co/diffusion-cot/Reflection-Generator) | 75 | | Image Verifier | Qwen-based image verifier | GenRef-CoT | [HuggingFace](https://huggingface.co/diffusion-cot/Image-Verifier) | 76 | 77 | 78 | ## 🤖 Reflection Tuning 79 | 80 | [`train_flux/config.yaml`](./train_flux/config.yaml) exposes all the arguments to control 81 | all the training-time configurations. 82 | 83 | First, get the data. You can either download the `webdataset` shards from [`diffusion-cot/GenRef-wds`](https://huggingface.co/datasets/diffusion-cot/GenRef-wds) or directly pass URLs. 84 | 85 | When using local paths, set `path` under `[train][dataset]` to a glob pattern: `DATA_DIR/genref_*.tar`. The current `config.yaml` configures training to stream from the `diffusion-cot/GenRef-wds` repository. You can even 86 | change the number of tars you want to stream for easier debugging. Just change `genref_{0..208}.tar` to something 87 | like `genref_{0..4}.tar`, depending on the number of shards you want to use. 88 | 89 | Run the following command for training the FLUX Corrector: 90 | 91 | ```bash 92 | bash train_flux/train.sh 93 | ``` 94 | 95 | We tested our implementation on a single node of 8 80GB A100s and H100s. We acknowledge that there are opportunities 96 | for optimization, but we didn't prioritize them in this release. 97 | 98 | >[!NOTE] 99 | > Validation during training is yet to be implemented. 100 | 101 | ## ⚡ Inference Time Scaling 102 | 103 | ### Introduction 104 | We provide the code for the inference time scaling of our reflection-tuned models. Currently, we support: 105 | * GPT-4o as verifier, reflection generator, and prompt refiner. 106 | * [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) verifier from SANA. 107 | * Our [reflection generator](https://huggingface.co/diffusion-cot/Reflection-Generator). 
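Under the hood, the scaling scripts dispatch to the configured verifier via the `verifier_args.name` field (`openai` or `nvila`) of the pipeline config JSON. Below is a minimal sketch of that dispatch, condensed from `tts/tts_t2i_noise_scaling.py` / `tts/tts_t2i_noise_prompt_scaling.py` (assumes you run it from the `tts/` directory so the relative config and prompt paths resolve):

```python
import json

from verifiers.openai_verifier import OpenAIVerifier
from verifiers.nvila_verifier import load_model

# Pick the config that matches the verifier you want to use.
with open("configs/flux.1_dev_gptscore.json") as f:  # or configs/flux.1_dev_nvilascore.json
    config = json.load(f)

verifier_args = config["verifier_args"]
if verifier_args.get("name", "openai") == "openai":
    # GPT-4o verifier, built from the prompt templates under tts/verifiers/
    verifier = OpenAIVerifier(
        refine_prompt_relpath=verifier_args["refine_prompt_relpath"],
        reflexion_prompt_relpath=verifier_args["reflexion_prompt_relpath"],
        verifier_prompt_relpath=verifier_args["verifier_prompt_relpath"],
    )
elif verifier_args.get("name") == "nvila":
    # load_model also returns the token ids used to read off the yes/no scores
    verifier, yes_id, no_id = load_model(
        model_name=verifier_args["model_name"], cache_dir=verifier_args["cache_dir"]
    )
```

The `reflection_args` block of the same config controls the reflection generator in the same way (set its `name` to `ours` to use our finetuned model, as described in the Setup below).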
108 | 109 | ### Setup 110 | First, you need to set up the following: 111 | 112 | ```bash 113 | export OPENAI_API_KEY=your_api_key 114 | # if you want to use NVILA as verifier 115 | pip install transformers==4.46 116 | pip install git+https://github.com/bfshi/scaling_on_scales.git 117 | ``` 118 | Then you need to set up the `FLUX_PATH` and `LORA_PATH` in the config file of your choice from [tts/config](./tts/configs/). The `FLUX_PATH` is basically the contents of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main) which can be downloaded like so: 119 | 120 | ```py 121 | from huggingface_hub import snapshot_download 122 | 123 | local_dir = "SOME_DIR" 124 | snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", local_dir=local_dir) 125 | ``` 126 | 127 | The `LORA_PATH` is our [corrector model](https://huggingface.co/diffusion-cot/FLUX-Corrector) path. 128 | 129 | If you want to use our finetuned reflection generator, you need to first install [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory). Then download the model from [here](https://huggingface.co/diffusion-cot/Reflection-Generator) and change the `model_name_or_path` in the config file of 130 | `tts/config/our_reflectionmodel.yaml` to the reflection generator path. To be specific, the path should be like `Reflection-Generator/infer/30000`. Next, host the model with: 131 | 132 | ```bash 133 | API_PORT=8001 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api configs/our_reflectionmodel.yaml 134 | ``` 135 | And change the `name` of `reflection_args` in the config file (for example: [tts/configs/flux.1_dev_gptscore.json](./tts/config/flux.1_dev_gptscore.json)) to `ours`. 136 | 137 | > [!NOTE] 138 | > When using our reflection generator model, please consider using at least two GPUs for better allocating resources. 139 | 140 | ### Run 141 | 142 | First, please run `tts_t2i_noise_scaling.py` to generate naive noise scaling results, with the commands: 143 | 144 | ```bash 145 | export OUTPUT_DIR=output_dir 146 | cd tts 147 | python tts_t2i_noise_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json 148 | ``` 149 | 150 | Next, you can run the following command to generate the results of reflection tuning: 151 | 152 | ```bash 153 | export NEW_OUTPUT_DIR=reflection_tuning_dir 154 | python tts_reflectionflow.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_gptscore.json --output_dir=$NEW_OUTPUT_DIR 155 | ``` 156 | 157 | We also provide the code for only noise & prompt scaling: 158 | 159 | ```bash 160 | python tts_t2i_noise_prompt_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json 161 | ``` 162 | 163 | You can also change to [tts/configs/flux.1_dev_nvilascore.json](./tts/config/flux.1_dev_nvilascore.json) to use the NVILA verifier. 164 | 165 | By default, we use prompts from [tts/config/geneval/evaluation_metadata.jsonl](./tts/config/geneval/evaluation_metadata.jsonl). If you don't want to use all the prompts from it, you can specify `--start_index` and `--end_index` CLI args. 166 | 167 | ### NVILA Verifier Filter 168 | 169 | After generation, we provide the code using NVILA verifier to filter and get different numbers of sample results. 
170 | 171 | ```bash 172 | python verifier_filter.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_nvilascore.json 173 | ``` 174 | 175 | ### Our Image Verifier 176 | 177 | We provide a simple start code for our image verifier. To run the code, please first upgrade the `transformers`. Currently, we use the `transformers` version `4.51.3`. 178 | 179 | ```bash 180 | pip install transformers==4.51.3 181 | ``` 182 | 183 | Then you can run the following code to get the score of the image: 184 | 185 | ```python 186 | from reward_modeling.test_reward import ImageVLMRewardInference 187 | import torch 188 | 189 | imgname = IMG_PATH 190 | original_prompt = ORIGINAL_PROMPT 191 | 192 | score_verfier = ImageVLMRewardInference(MODEL_PATH, load_from_pretrained_step=10080, device="cuda", dtype=torch.bfloat16) 193 | scores = score_verfier.reward([imgname], [original_prompt], use_norm=True) 194 | print(scores[0]['VQ']) 195 | ``` 196 | 197 | The `MODEL_PATH` is the path to the model [checkpoint](https://huggingface.co/diffusion-cot/Image-Verifier). And `scores[0]['VQ']` is the score of the text-image pair, which is higher the better. 198 | 199 | ## 🤝 Acknowledgement 200 | 201 | We are deeply grateful for the following GitHub repositories, as their valuable code and efforts have been incredibly helpful: 202 | 203 | * OminiControl (https://github.com/Yuanshi9815/OminiControl) 204 | * Flux-TTS (https://github.com/sayakpaul/tt-scale-flux) 205 | 206 | 207 | ## ✏️ Citation 208 | 209 | If you find ReflectionFlow useful for your research and applications, please cite using this BibTeX: 210 | 211 | ```bibtex 212 | @misc{zhuo2025reflectionperfectionscalinginferencetime, 213 | title={From Reflection to Perfection: Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning}, 214 | author={Le Zhuo and Liangbing Zhao and Sayak Paul and Yue Liao and Renrui Zhang and Yi Xin and Peng Gao and Mohamed Elhoseiny and Hongsheng Li}, 215 | year={2025}, 216 | eprint={2504.16080}, 217 | archivePrefix={arXiv}, 218 | primaryClass={cs.CV}, 219 | url={https://arxiv.org/abs/2504.16080}, 220 | } 221 | ``` 222 | -------------------------------------------------------------------------------- /tts/tts_t2i_noise_prompt_scaling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | from diffusers import DiffusionPipeline 8 | from tqdm.auto import tqdm 9 | import copy 10 | from PIL import Image 11 | import time 12 | from typing import Union, List, Optional 13 | from verifiers.openai_verifier import OpenAIVerifier 14 | from verifiers.nvila_verifier import load_model 15 | 16 | from utils import get_noises, TORCH_DTYPE_MAP, get_latent_prep_fn, parse_cli_args 17 | 18 | global verifier, yes_id, no_id 19 | # Non-configurable constants 20 | MAX_SEED = np.iinfo(np.int32).max # To generate random seeds 21 | 22 | def sample( 23 | noises: dict[int, torch.Tensor], 24 | original_prompt: str, 25 | updated_prompt: Union[str, List[str]], 26 | search_round: int, 27 | pipe: DiffusionPipeline, 28 | topk: int, 29 | root_dir: str, 30 | config: dict, 31 | midimg_path: str, 32 | tag: str, 33 | ) -> dict: 34 | """ 35 | For a given prompt, generate images using all provided noises in batches, 36 | score them with the verifier, and select the top-K noise. 37 | The images and JSON artifacts are saved under `root_dir`. 
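    After scoring, the verifier outputs are also fed to the (OpenAI-based) prompt refiner; the
    refined per-branch prompts are logged to `best_img_meta.jsonl` under `root_dir` and returned
    in the datapoint as `refined_prompt` for the next search round.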
38 | """ 39 | global verifier, yes_id, no_id 40 | config_cp = copy.deepcopy(config) 41 | verifier_args = config["verifier_args"] 42 | verifier_name = verifier_args.get("name", "openai") 43 | 44 | refine_args = config["refine_args"] 45 | max_new_tokens = refine_args.get("max_new_tokens", None) 46 | choice_of_metric = refine_args.get("choice_of_metric", None) 47 | # currently only support openai refiner 48 | refiner = OpenAIVerifier(refine_prompt_relpath=refine_args["refine_prompt_relpath"], reflexion_prompt_relpath=refine_args["reflexion_prompt_relpath"], verifier_prompt_relpath=refine_args["verifier_prompt_relpath"]) 49 | 50 | use_low_gpu_vram = config_cp.get("use_low_gpu_vram", False) 51 | batch_size_for_img_gen = config_cp.get("batch_size_for_img_gen", 1) 52 | 53 | images_for_prompt = [] 54 | noises_used = [] 55 | seeds_used = [] 56 | prompts = updated_prompt 57 | 58 | # Convert the noises dictionary into a list of (seed, noise) tuples. 59 | noise_items = list(noises.items()) 60 | 61 | # Process the noises in batches. 62 | full_imgnames = [] 63 | for i in range(0, len(noise_items), batch_size_for_img_gen): 64 | batch = noise_items[i : i + batch_size_for_img_gen] 65 | seeds_batch, noises_batch = zip(*batch) 66 | filenames_batch = [ 67 | os.path.join(midimg_path, f"{search_round}_round@{seed}.png") for seed in seeds_batch 68 | ] 69 | full_imgnames.extend(filenames_batch) 70 | 71 | if use_low_gpu_vram: 72 | pipe = pipe.to("cuda:0") 73 | print(f"Generating images for batch with seeds: {[s for s in seeds_batch]}.") 74 | 75 | # Create a batched prompt list and stack the latents. 76 | batched_latents = torch.stack(noises_batch).squeeze(dim=1) 77 | batched_prompts = prompts[i : i + batch_size_for_img_gen] 78 | 79 | batch_result = pipe(prompt=batched_prompts, latents=batched_latents, guidance_scale=config_cp["pipeline_args"]["guidance_scale"], num_inference_steps=config_cp["pipeline_args"]["num_inference_steps"], height=config_cp["pipeline_args"]["height"], width=config_cp["pipeline_args"]["width"]) 80 | batch_images = batch_result.images 81 | if use_low_gpu_vram: 82 | pipe = pipe.to("cpu") 83 | 84 | # Iterate over the batch and save the images. 85 | for seed, noise, image, filename in zip(seeds_batch, noises_batch, batch_images, filenames_batch): 86 | images_for_prompt.append(image) 87 | noises_used.append(noise) 88 | seeds_used.append(seed) 89 | image.save(filename) 90 | 91 | # Prepare verifier inputs and perform inference. 92 | start_time = time.time() 93 | if verifier_name == "openai": 94 | verifier_inputs = verifier.prepare_inputs(images=images_for_prompt, prompts=[original_prompt]*len(images_for_prompt)) 95 | outputs = verifier.score( 96 | inputs=verifier_inputs, 97 | tag=tag, 98 | max_new_tokens=max_new_tokens, # Ignored when using Gemini for now. 
99 | ) 100 | def f(x): 101 | if isinstance(x[choice_of_metric], dict): 102 | return x[choice_of_metric]["score"] 103 | return x[choice_of_metric] 104 | sorted_list = sorted(outputs, key=lambda x: f(x), reverse=True) 105 | elif verifier_name == "nvila": 106 | outputs = [] 107 | for imgname in full_imgnames: 108 | r1, scores1 = verifier.generate_content([Image.open(imgname), original_prompt]) 109 | if r1 == "yes": 110 | outputs.append({"image_name": imgname, "label": "yes", "score": scores1[0][0, yes_id].detach().cpu().float().item()}) 111 | else: 112 | outputs.append({"image_name": imgname, "label": "no", "score": scores1[0][0, no_id].detach().cpu().float().item()}) 113 | def f(x): 114 | if x["label"] == "yes": 115 | return (0, -x["score"]) 116 | else: 117 | return (1, x["score"]) 118 | sorted_list = sorted(outputs, key=lambda x: f(x)) 119 | end_time = time.time() 120 | print(f"Time taken for evaluation: {end_time - start_time} seconds") 121 | 122 | topk_scores = sorted_list[:topk] 123 | topk_idx = [outputs.index(x) for x in topk_scores] 124 | 125 | # Refine the prompt for the next round 126 | evaluations = [json.dumps(json_dict) for json_dict in outputs] 127 | if verifier_name == "openai": 128 | refined_prompt_inputs = refiner.prepare_refine_prompt_inputs(images=images_for_prompt, evaluations=evaluations, original_prompt=[original_prompt] * len(images_for_prompt), current_prompt=prompts) 129 | else: 130 | refined_prompt_inputs = refiner.prepare_refine_prompt_inputs(images=images_for_prompt, original_prompt=[original_prompt] * len(images_for_prompt), current_prompt=prompts) 131 | refined_prompt = refiner.refine_prompt(inputs=refined_prompt_inputs) 132 | assert len(refined_prompt) == len(prompts) 133 | prompts = refined_prompt 134 | 135 | with open(os.path.join(root_dir, f"best_img_meta.jsonl"), "a") as f: 136 | f.write(f"refined_prompt{search_round}: "+json.dumps(prompts) + "\n") 137 | 138 | datapoint = { 139 | "original_prompt": original_prompt, 140 | "refined_prompt": prompts, 141 | "search_round": search_round, 142 | "num_noises": len(noises), 143 | "choice_of_metric": choice_of_metric, 144 | } 145 | return datapoint 146 | 147 | 148 | @torch.no_grad() 149 | def main(): 150 | """ 151 | Main function: 152 | - Parses CLI arguments. 153 | - Creates an output directory based on verifier and current datetime. 154 | - Loads prompts. 155 | - Loads the image-generation pipeline. 156 | - Loads the verifier model. 157 | - Runs several search rounds where for each prompt a pool of random noises is generated, 158 | candidate images are produced and verified, and the best noise is chosen. 159 | """ 160 | global verifier, yes_id, no_id 161 | args = parse_cli_args() 162 | os.environ["API_KEY"] = os.environ["OPENAI_API_KEY"] 163 | 164 | # Build a config dictionary for parameters that need to be passed around. 165 | with open(args.pipeline_config_path, "r") as f: 166 | config = json.load(f) 167 | 168 | config.update(vars(args)) 169 | search_rounds = config["search_args"]["search_rounds"] 170 | search_branch = config["search_args"]["search_branch"] 171 | 172 | # Create a root output directory: output/{verifier_to_use}/{current_datetime} 173 | pipeline_name = config["pipeline_args"].get("pretrained_model_name_or_path") 174 | cache_dir = config["pipeline_args"]["cache_dir"] 175 | root_dir = config["output_dir"] 176 | os.makedirs(root_dir, exist_ok=True) 177 | 178 | # Set up the image-generation pipeline (on the first GPU if available). 
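    # The model id, torch dtype, and cache_dir below come from `pipeline_args` in the pipeline
    # config JSON; with `use_low_gpu_vram`, sample() instead moves the pipe to the GPU only while
    # generating a batch and back to CPU afterwards.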
179 | torch_dtype = TORCH_DTYPE_MAP[config["pipeline_args"].get("torch_dtype")] 180 | pipe = DiffusionPipeline.from_pretrained(pipeline_name, torch_dtype=torch_dtype, cache_dir=cache_dir) 181 | if not config["use_low_gpu_vram"]: 182 | pipe = pipe.to("cuda:0") 183 | pipe.set_progress_bar_config(disable=True) 184 | 185 | # Doesn't help that much currently as several things within the transformer are changing. 186 | if config["pipeline_args"].get("compile", False): 187 | pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True) 188 | print("Compilation.") 189 | 190 | # Load the verifier model. 191 | verifier_args = config["verifier_args"] 192 | verifier_name = verifier_args.get("name", "openai") 193 | if verifier_name == "openai": 194 | verifier = OpenAIVerifier(refine_prompt_relpath=verifier_args["refine_prompt_relpath"], reflexion_prompt_relpath=verifier_args["reflexion_prompt_relpath"], verifier_prompt_relpath=verifier_args["verifier_prompt_relpath"]) 195 | elif verifier_name == "nvila": 196 | verifier, yes_id, no_id = load_model(model_name=verifier_args["model_name"], cache_dir=verifier_args["cache_dir"]) 197 | else: 198 | raise ValueError(f"Verifier {verifier_name} not supported") 199 | 200 | # Main loop: For each search round and each prompt, generate images, verify, and save artifacts. 201 | with open(args.meta_path) as fp: 202 | metadatas = [json.loads(line) for line in fp] 203 | 204 | # meta splits 205 | if args.end_index == -1: 206 | metadatas = metadatas[args.start_index:] 207 | else: 208 | metadatas = metadatas[args.start_index:args.end_index] 209 | 210 | for index, metadata in tqdm(enumerate(metadatas), desc="Sampling data"): 211 | # create output directory 212 | outpath = os.path.join(root_dir, f"{index + args.start_index:0>5}") 213 | os.makedirs(outpath, exist_ok=True) 214 | 215 | # create middle img directory 216 | midimg_path = os.path.join(outpath, "samples") 217 | os.makedirs(midimg_path, exist_ok=True) 218 | 219 | # create metadata file 220 | with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp: 221 | json.dump(metadata, fp) 222 | 223 | updated_prompt = [metadata['prompt']] * search_branch 224 | original_prompt = metadata['prompt'] 225 | for round in range(1, search_rounds + 1): 226 | print(f"\n=== Round: {round} ===") 227 | noises = get_noises( 228 | max_seed=MAX_SEED, 229 | num_samples=search_branch, 230 | height=config["pipeline_args"]["height"], 231 | width=config["pipeline_args"]["width"], 232 | dtype=torch_dtype, 233 | fn=get_latent_prep_fn(pipeline_name), 234 | ) 235 | print(f"Number of noise samples: {len(noises)}") 236 | datapoint = sample( 237 | noises=noises, 238 | original_prompt=original_prompt, 239 | updated_prompt=updated_prompt, 240 | search_round=round, 241 | pipe=pipe, 242 | topk=search_branch, 243 | root_dir=outpath, 244 | config=config, 245 | midimg_path=midimg_path, 246 | tag=metadata['tag'], 247 | ) 248 | updated_prompt = datapoint['refined_prompt'] 249 | 250 | if __name__ == "__main__": 251 | main() 252 | -------------------------------------------------------------------------------- /reward_modeling/inference.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import pdb 5 | import argparse 6 | from collections.abc import Mapping 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import sys 10 | sys.path.append('ReflectionFlow') 11 | sys.path.append('reward_modeling') 12 | import torch 13 | from vision_process import 
process_vision_info 14 | 15 | from data import DataConfig 16 | from reward_modeling.utils import ModelConfig, PEFTLoraConfig, TrainingConfig 17 | from reward_modeling.utils import load_model_from_checkpoint 18 | from train_reward import create_model_and_processor 19 | from prompt_template import build_prompt 20 | 21 | 22 | def load_configs_from_json(config_path): 23 | with open(config_path, "r") as f: 24 | config_dict = json.load(f) 25 | 26 | # del config_dict["training_args"]["_n_gpu"] 27 | del config_dict["data_config"]["meta_data"] 28 | del config_dict["data_config"]["data_dir"] 29 | 30 | return config_dict["data_config"], None, config_dict["model_config"], config_dict["peft_lora_config"], \ 31 | config_dict["inference_config"] if "inference_config" in config_dict else None 32 | 33 | class ImageVLMRewardInference(): 34 | def __init__(self, load_from_pretrained, load_from_pretrained_step=-1, device='cuda', dtype=torch.bfloat16): 35 | config_path = os.path.join(load_from_pretrained, "model_config.json") 36 | data_config, _, model_config, peft_lora_config, inference_config = load_configs_from_json(config_path) 37 | data_config = DataConfig(**data_config) 38 | model_config = ModelConfig(**model_config) 39 | peft_lora_config = PEFTLoraConfig(**peft_lora_config) 40 | 41 | training_args = TrainingConfig( 42 | load_from_pretrained=load_from_pretrained, 43 | load_from_pretrained_step=load_from_pretrained_step, 44 | gradient_checkpointing=False, 45 | disable_flash_attn2=False, 46 | bf16=True if dtype == torch.bfloat16 else False, 47 | fp16=True if dtype == torch.float16 else False, 48 | output_dir="", 49 | ) 50 | 51 | model, processor, peft_config = create_model_and_processor( 52 | model_config=model_config, 53 | peft_lora_config=peft_lora_config, 54 | training_args=training_args, 55 | ) 56 | 57 | self.device = device 58 | 59 | model, checkpoint_step = load_model_from_checkpoint(model, load_from_pretrained, load_from_pretrained_step) 60 | model.eval() 61 | 62 | self.model = model 63 | self.processor = processor 64 | 65 | self.model.to(self.device) 66 | 67 | self.data_config = data_config 68 | 69 | self.inference_config = inference_config 70 | 71 | def _norm(self, reward): 72 | if self.inference_config is None: 73 | return reward 74 | else: 75 | reward['VQ'] = (reward['VQ'] - self.inference_config['VQ_mean']) / self.inference_config['VQ_std'] 76 | return reward 77 | 78 | def _pad_sequence(self, sequences, attention_mask, max_len, padding_side='right'): 79 | """ 80 | Pad the sequences to the maximum length. 81 | """ 82 | assert padding_side in ['right', 'left'] 83 | if sequences.shape[1] >= max_len: 84 | return sequences, attention_mask 85 | 86 | pad_len = max_len - sequences.shape[1] 87 | padding = (0, pad_len) if padding_side == 'right' else (pad_len, 0) 88 | 89 | sequences_padded = torch.nn.functional.pad(sequences, padding, 'constant', self.processor.tokenizer.pad_token_id) 90 | attention_mask_padded = torch.nn.functional.pad(attention_mask, padding, 'constant', 0) 91 | 92 | return sequences_padded, attention_mask_padded 93 | 94 | def _prepare_input(self, data): 95 | """ 96 | Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and 97 | handling potential state. 
98 | """ 99 | if isinstance(data, Mapping): 100 | return type(data)({k: self._prepare_input(v) for k, v in data.items()}) 101 | elif isinstance(data, (tuple, list)): 102 | return type(data)(self._prepare_input(v) for v in data) 103 | elif isinstance(data, torch.Tensor): 104 | kwargs = {"device": self.device} 105 | ## TODO: Maybe need to add dtype 106 | # if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): 107 | # # NLP models inputs are int/uint and those get adjusted to the right dtype of the 108 | # # embedding. Other models such as wav2vec2's inputs are already float and thus 109 | # # may need special handling to match the dtypes of the model 110 | # kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) 111 | return data.to(**kwargs) 112 | return data 113 | 114 | def _prepare_inputs(self, inputs): 115 | """ 116 | Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and 117 | handling potential state. 118 | """ 119 | inputs = self._prepare_input(inputs) 120 | if len(inputs) == 0: 121 | raise ValueError 122 | return inputs 123 | 124 | def prepare_batch(self, image_paths, prompts, max_pixels=None,): 125 | max_pixels = self.data_config.max_frame_pixels if max_pixels is None else max_pixels 126 | 127 | chat_data = [ 128 | [ 129 | { 130 | "role": "user", 131 | "content": [ 132 | { 133 | "type": "image", 134 | "image": image_path, 135 | "max_pixels": max_pixels, 136 | }, 137 | {"type": "text", "text": build_prompt(prompt, self.data_config.eval_dim, self.data_config.prompt_template_type)}, 138 | ], 139 | }, 140 | ] for image_path, prompt in zip(image_paths, prompts) 141 | ] 142 | image_inputs, video_inputs = process_vision_info(chat_data) 143 | 144 | batch = self.processor( 145 | text=self.processor.apply_chat_template(chat_data, tokenize=False, add_generation_prompt=True), 146 | images=image_inputs, 147 | videos=video_inputs, 148 | padding=True, 149 | return_tensors="pt", 150 | videos_kwargs={"do_rescale": True}, 151 | ) 152 | batch = self._prepare_inputs(batch) 153 | 154 | return batch 155 | 156 | def reward(self, image_paths, prompts, max_pixels=None, use_norm=True): 157 | """ 158 | Inputs: 159 | image_paths: List[str], B paths of the videos. 160 | prompts: List[str], B prompts for the videos. 161 | eval_dims: List[str], N evaluation dimensions. 162 | max_pixels: int, maximum pixels of the videos. If None, use the default value in the config. 163 | use_norm: bool, whether to rescale the output rewards 164 | Outputs: 165 | Rewards: List[dict], N + 1 rewards of the B videos. 
166 | """ 167 | 168 | batch = self.prepare_batch(image_paths, prompts, max_pixels) 169 | rewards = self.model( 170 | return_dict=True, 171 | **batch 172 | )["logits"] 173 | 174 | rewards = [{'VQ': reward[0].item()} for reward in rewards] 175 | for i in range(len(rewards)): 176 | if use_norm: 177 | rewards[i] = self._norm(rewards[i]) 178 | rewards[i]['Overall'] = rewards[i]['VQ'] 179 | 180 | return rewards 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser(description="Video Alignment Reward Inference") 185 | parser.add_argument("--json_path", type=str, default="/mnt/petrelfs/zhuole/gaopeng_for_zl/data/reflection/geneval_pairs.json", 186 | help="Path to input JSON file") 187 | parser.add_argument("--load_from_pretrained", type=str, default="/mnt/petrelfs/zhuole/VideoAlign/rm_output", 188 | help="Path to pretrained model") 189 | parser.add_argument("--device", type=str, default="cuda", 190 | help="Device to run inference on") 191 | parser.add_argument("--output_path", type=str, default="/mnt/petrelfs/zhuole/data/geneval_pairs_reward.json", 192 | help="Path to output JSON file") 193 | parser.add_argument("--start_index", type=int, default=0, 194 | help="Start index for processing") 195 | parser.add_argument("--end_index", type=int, default=-1, 196 | help="End index for processing (-1 for all)") 197 | parser.add_argument("--batch_size", type=int, default=32, 198 | help="Batch size for processing") 199 | parser.add_argument("--ckpt_step", type=int, default=-1, 200 | help="Checkpoint step for processing") 201 | args = parser.parse_args() 202 | 203 | with open(args.json_path, "r") as f: 204 | data = json.load(f) 205 | 206 | # Check if output file exists and load processed items 207 | if os.path.exists(args.output_path): 208 | with open(args.output_path, "r") as f: 209 | outputs = json.load(f) 210 | else: 211 | outputs = [] 212 | 213 | # Process data and check for already processed items 214 | args.end_index = len(data) if args.end_index == -1 else args.end_index 215 | processed_count = len(outputs) 216 | to_process = [] 217 | for idx, item in enumerate(data[args.start_index:args.end_index]): 218 | item["good_image"] = item["good_image"].replace("gaopeng/zl", "gaopeng_for_zl").replace("ReflectionFlow", "data/reflection") 219 | item["bad_image"] = item["bad_image"].replace("gaopeng/zl", "gaopeng_for_zl").replace("ReflectionFlow", "data/reflection") 220 | 221 | # Skip if already processed 222 | if idx < processed_count: 223 | if outputs[idx]["good_image"] == item["good_image"] and \ 224 | outputs[idx]["bad_image"] == item["bad_image"]: 225 | print(f"Skipping {idx} because it already exists") 226 | continue 227 | else: 228 | raise ValueError(f"Can't find {idx} in outputs") 229 | 230 | to_process.append(item) 231 | 232 | inferencer = ImageVLMRewardInference(args.load_from_pretrained, load_from_pretrained_step=args.ckpt_step, device=args.device, dtype=torch.bfloat16) 233 | 234 | # Process in batches 235 | for i in tqdm(range(0, len(to_process), args.batch_size), 236 | desc="Processing batches"): 237 | batch_items = to_process[i:i+args.batch_size] 238 | 239 | # Prepare batch data 240 | good_image_paths = [] 241 | bad_image_paths = [] 242 | good_prompts = [] 243 | bad_prompts = [] 244 | for item in batch_items: 245 | good_image_paths.append(item["good_image"]) 246 | bad_image_paths.append(item["bad_image"]) 247 | good_prompts.append(item["prompt"]) 248 | bad_prompts.append(item["prompt"]) 249 | image_paths = good_image_paths + bad_image_paths 250 | prompts = good_prompts 
+ bad_prompts 251 | 252 | with torch.no_grad(): 253 | batch_rewards = inferencer.reward(image_paths, prompts, use_norm=True) 254 | 255 | good_rewards = batch_rewards[:len(good_image_paths)] 256 | bad_rewards = batch_rewards[len(good_image_paths):] 257 | # Process results for each item in the batch 258 | for j, item in enumerate(batch_items): 259 | item_with_rewards = item.copy() 260 | item_with_rewards["good_reward"] = good_rewards[j]["VQ"] 261 | item_with_rewards["bad_reward"] = bad_rewards[j]["VQ"] 262 | outputs.append(item_with_rewards) 263 | 264 | # Save after each batch to allow resuming 265 | with open(args.output_path, "w") as f: 266 | json.dump(outputs, f, indent=4) 267 | -------------------------------------------------------------------------------- /train_flux/flux/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml, os 3 | from diffusers.pipelines import FluxPipeline 4 | from typing import List, Union, Optional, Dict, Any, Callable 5 | from .transformer import tranformer_forward 6 | from .condition import Condition 7 | 8 | from diffusers.pipelines.flux.pipeline_flux import ( 9 | FluxPipelineOutput, 10 | calculate_shift, 11 | retrieve_timesteps, 12 | np, 13 | ) 14 | 15 | 16 | def get_config(config_path: str = None): 17 | config_path = config_path or os.environ.get("XFL_CONFIG") 18 | if not config_path: 19 | return {} 20 | with open(config_path, "r") as f: 21 | config = yaml.safe_load(f) 22 | return config 23 | 24 | 25 | def prepare_params( 26 | prompt: Union[str, List[str]] = None, 27 | prompt_2: Optional[Union[str, List[str]]] = None, 28 | height: Optional[int] = 512, 29 | width: Optional[int] = 512, 30 | num_inference_steps: int = 28, 31 | timesteps: List[int] = None, 32 | guidance_scale: float = 3.5, 33 | num_images_per_prompt: Optional[int] = 1, 34 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 35 | latents: Optional[torch.FloatTensor] = None, 36 | prompt_embeds: Optional[torch.FloatTensor] = None, 37 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 38 | output_type: Optional[str] = "pil", 39 | return_dict: bool = True, 40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 41 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 42 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 43 | max_sequence_length: int = 512, 44 | **kwargs: dict, 45 | ): 46 | return ( 47 | prompt, 48 | prompt_2, 49 | height, 50 | width, 51 | num_inference_steps, 52 | timesteps, 53 | guidance_scale, 54 | num_images_per_prompt, 55 | generator, 56 | latents, 57 | prompt_embeds, 58 | pooled_prompt_embeds, 59 | output_type, 60 | return_dict, 61 | joint_attention_kwargs, 62 | callback_on_step_end, 63 | callback_on_step_end_tensor_inputs, 64 | max_sequence_length, 65 | ) 66 | 67 | 68 | def seed_everything(seed: int = 42): 69 | torch.backends.cudnn.deterministic = True 70 | torch.manual_seed(seed) 71 | np.random.seed(seed) 72 | 73 | 74 | @torch.no_grad() 75 | def generate( 76 | pipeline: FluxPipeline, 77 | conditions: List[Condition] = None, 78 | config_path: str = None, 79 | model_config: Optional[Dict[str, Any]] = {}, 80 | condition_scale: float = 1.0, 81 | default_lora: bool = False, 82 | image_guidance_scale: float = 1.0, 83 | **params: dict, 84 | ): 85 | model_config = model_config or get_config(config_path).get("model", {}) 86 | if condition_scale != 1: 87 | for name, module in pipeline.transformer.named_modules(): 88 | if not 
name.endswith(".attn"): 89 | continue 90 | module.c_factor = torch.ones(1, 1) * condition_scale 91 | 92 | self = pipeline 93 | ( 94 | prompt, 95 | prompt_2, 96 | height, 97 | width, 98 | num_inference_steps, 99 | timesteps, 100 | guidance_scale, 101 | num_images_per_prompt, 102 | generator, 103 | latents, 104 | prompt_embeds, 105 | pooled_prompt_embeds, 106 | output_type, 107 | return_dict, 108 | joint_attention_kwargs, 109 | callback_on_step_end, 110 | callback_on_step_end_tensor_inputs, 111 | max_sequence_length, 112 | ) = prepare_params(**params) 113 | 114 | height = height or self.default_sample_size * self.vae_scale_factor 115 | width = width or self.default_sample_size * self.vae_scale_factor 116 | 117 | # 1. Check inputs. Raise error if not correct 118 | self.check_inputs( 119 | prompt, 120 | prompt_2, 121 | height, 122 | width, 123 | prompt_embeds=prompt_embeds, 124 | pooled_prompt_embeds=pooled_prompt_embeds, 125 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 126 | max_sequence_length=max_sequence_length, 127 | ) 128 | 129 | self._guidance_scale = guidance_scale 130 | self._joint_attention_kwargs = joint_attention_kwargs 131 | self._interrupt = False 132 | 133 | # 2. Define call parameters 134 | if prompt is not None and isinstance(prompt, str): 135 | batch_size = 1 136 | elif prompt is not None and isinstance(prompt, list): 137 | batch_size = len(prompt) 138 | else: 139 | batch_size = prompt_embeds.shape[0] 140 | 141 | device = self._execution_device 142 | 143 | lora_scale = ( 144 | self.joint_attention_kwargs.get("scale", None) 145 | if self.joint_attention_kwargs is not None 146 | else None 147 | ) 148 | ( 149 | prompt_embeds, 150 | pooled_prompt_embeds, 151 | text_ids, 152 | ) = self.encode_prompt( 153 | prompt=prompt, 154 | prompt_2=prompt_2, 155 | prompt_embeds=prompt_embeds, 156 | pooled_prompt_embeds=pooled_prompt_embeds, 157 | device=device, 158 | num_images_per_prompt=num_images_per_prompt, 159 | max_sequence_length=max_sequence_length, 160 | lora_scale=lora_scale, 161 | ) 162 | 163 | # 4. Prepare latent variables 164 | num_channels_latents = self.transformer.config.in_channels // 4 165 | latents, latent_image_ids = self.prepare_latents( 166 | batch_size * num_images_per_prompt, 167 | num_channels_latents, 168 | height, 169 | width, 170 | prompt_embeds.dtype, 171 | device, 172 | generator, 173 | latents, 174 | ) 175 | 176 | # 4.1. Prepare conditions 177 | condition_latents, condition_ids, condition_type_ids = ([] for _ in range(3)) 178 | use_condition = conditions is not None or [] 179 | if use_condition: 180 | assert len(conditions) <= 1, "Only one condition is supported for now." 181 | if not default_lora: 182 | pipeline.set_adapters(conditions[0].condition_type) 183 | for condition in conditions: 184 | tokens, ids, type_id = condition.encode(self) 185 | condition_latents.append(tokens) # [batch_size, token_n, token_dim] 186 | condition_ids.append(ids) # [token_n, id_dim(3)] 187 | condition_type_ids.append(type_id) # [token_n, 1] 188 | condition_latents = torch.cat(condition_latents, dim=1) 189 | condition_ids = torch.cat(condition_ids, dim=0) 190 | condition_type_ids = torch.cat(condition_type_ids, dim=0) 191 | 192 | # 5. 
Prepare timesteps 193 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 194 | image_seq_len = latents.shape[1] 195 | mu = calculate_shift( 196 | image_seq_len, 197 | self.scheduler.config.base_image_seq_len, 198 | self.scheduler.config.max_image_seq_len, 199 | self.scheduler.config.base_shift, 200 | self.scheduler.config.max_shift, 201 | ) 202 | timesteps, num_inference_steps = retrieve_timesteps( 203 | self.scheduler, 204 | num_inference_steps, 205 | device, 206 | timesteps, 207 | sigmas, 208 | mu=mu, 209 | ) 210 | num_warmup_steps = max( 211 | len(timesteps) - num_inference_steps * self.scheduler.order, 0 212 | ) 213 | self._num_timesteps = len(timesteps) 214 | 215 | # 6. Denoising loop 216 | with self.progress_bar(total=num_inference_steps) as progress_bar: 217 | for i, t in enumerate(timesteps): 218 | if self.interrupt: 219 | continue 220 | 221 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 222 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 223 | 224 | # handle guidance 225 | if self.transformer.config.guidance_embeds: 226 | guidance = torch.tensor([guidance_scale], device=device) 227 | guidance = guidance.expand(latents.shape[0]) 228 | else: 229 | guidance = None 230 | noise_pred = tranformer_forward( 231 | self.transformer, 232 | model_config=model_config, 233 | # Inputs of the condition (new feature) 234 | condition_latents=condition_latents if use_condition else None, 235 | condition_ids=condition_ids if use_condition else None, 236 | condition_type_ids=condition_type_ids if use_condition else None, 237 | # Inputs to the original transformer 238 | hidden_states=latents, 239 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 240 | timestep=timestep / 1000, 241 | guidance=guidance, 242 | pooled_projections=pooled_prompt_embeds, 243 | encoder_hidden_states=prompt_embeds, 244 | txt_ids=text_ids, 245 | img_ids=latent_image_ids, 246 | joint_attention_kwargs=self.joint_attention_kwargs, 247 | return_dict=False, 248 | )[0] 249 | 250 | if image_guidance_scale != 1.0: 251 | uncondition_latents = condition.encode(self, empty=True)[0] 252 | unc_pred = tranformer_forward( 253 | self.transformer, 254 | model_config=model_config, 255 | # Inputs of the condition (new feature) 256 | condition_latents=uncondition_latents if use_condition else None, 257 | condition_ids=condition_ids if use_condition else None, 258 | condition_type_ids=condition_type_ids if use_condition else None, 259 | # Inputs to the original transformer 260 | hidden_states=latents, 261 | # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) 262 | timestep=timestep / 1000, 263 | guidance=torch.ones_like(guidance), 264 | pooled_projections=pooled_prompt_embeds, 265 | encoder_hidden_states=prompt_embeds, 266 | txt_ids=text_ids, 267 | img_ids=latent_image_ids, 268 | joint_attention_kwargs=self.joint_attention_kwargs, 269 | return_dict=False, 270 | )[0] 271 | 272 | noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred) 273 | 274 | # compute the previous noisy sample x_t -> x_t-1 275 | latents_dtype = latents.dtype 276 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 277 | 278 | if latents.dtype != latents_dtype: 279 | if torch.backends.mps.is_available(): 280 | # some platforms 
(eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 281 | latents = latents.to(latents_dtype) 282 | 283 | if callback_on_step_end is not None: 284 | callback_kwargs = {} 285 | for k in callback_on_step_end_tensor_inputs: 286 | callback_kwargs[k] = locals()[k] 287 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 288 | 289 | latents = callback_outputs.pop("latents", latents) 290 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 291 | 292 | # call the callback, if provided 293 | if i == len(timesteps) - 1 or ( 294 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 295 | ): 296 | progress_bar.update() 297 | 298 | if output_type == "latent": 299 | image = latents 300 | 301 | else: 302 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 303 | latents = ( 304 | latents / self.vae.config.scaling_factor 305 | ) + self.vae.config.shift_factor 306 | image = self.vae.decode(latents, return_dict=False)[0] 307 | image = self.image_processor.postprocess(image, output_type=output_type) 308 | 309 | # Offload all models 310 | self.maybe_free_model_hooks() 311 | 312 | if condition_scale != 1: 313 | for name, module in pipeline.transformer.named_modules(): 314 | if not name.endswith(".attn"): 315 | continue 316 | del module.c_factor 317 | 318 | if not return_dict: 319 | return (image,) 320 | 321 | return FluxPipelineOutput(images=image) -------------------------------------------------------------------------------- /reward_modeling/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from dataclasses import dataclass, field 4 | from typing import List, Literal, Optional 5 | import sys 6 | sys.path.append('ReflectionFlow') 7 | sys.path.append('ReflectionFlow/reward_modeling') 8 | import safetensors 9 | import torch 10 | from peft import PeftModel 11 | from transformers import BitsAndBytesConfig, Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig, Qwen2_5_VLForConditionalGeneration, TrainingArguments 12 | import warnings 13 | import json 14 | 15 | ########## DataClass For Configure ########## 16 | 17 | @dataclass 18 | class TrainingConfig(TrainingArguments): 19 | max_length: Optional[int] = None 20 | dataset_num_proc: Optional[int] = None 21 | center_rewards_coefficient: Optional[float] = None 22 | disable_flash_attn2: bool = field(default=False) 23 | 24 | vision_lr: Optional[float] = None 25 | merger_lr: Optional[float] = None 26 | special_token_lr: Optional[float] = None 27 | 28 | conduct_eval: Optional[bool] = True 29 | load_from_pretrained: str = None 30 | load_from_pretrained_step: int = None 31 | logging_epochs: Optional[float] = None 32 | eval_epochs: Optional[float] = None 33 | save_epochs: Optional[float] = None 34 | remove_unused_columns: Optional[bool] = False 35 | 36 | save_full_model: Optional[bool] = False 37 | 38 | @dataclass 39 | class PEFTLoraConfig: 40 | lora_enable: bool = False 41 | vision_lora: bool = False 42 | lora_r: int = 16 43 | lora_alpha: int = 32 44 | lora_dropout: float = 0.05 45 | lora_target_modules: Optional[List[str]] = None 46 | lora_namespan_exclude: Optional[List[str]] = None 47 | lora_modules_to_save: Optional[List[str]] = None 48 | lora_task_type: str = "CAUSAL_LM" 49 | use_rslora: bool = False 50 | num_lora_modules: int = -1 51 | 52 | def __post_init__(self): 53 | if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1: 54 | 
self.lora_target_modules = self.lora_target_modules[0] 55 | 56 | if isinstance(self.lora_namespan_exclude, list) and len(self.lora_namespan_exclude) == 1: 57 | self.lora_namespan_exclude = self.lora_namespan_exclude[0] 58 | 59 | @dataclass 60 | class ModelConfig: 61 | model_name_or_path: Optional[str] = None 62 | model_revision: str = "main" 63 | 64 | output_dim: int = 1 65 | 66 | use_special_tokens: bool = False 67 | 68 | freeze_vision_tower: bool = field(default=False) 69 | freeze_llm: bool = field(default=False) 70 | tune_merger: bool = field(default=False) 71 | 72 | torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None 73 | trust_remote_code: bool = False 74 | attn_implementation: Optional[str] = None 75 | load_in_8bit: bool = False 76 | load_in_4bit: bool = False 77 | bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4" 78 | use_bnb_nested_quant: bool = False 79 | reward_token: Literal["last", "mean", "special"] = "last" 80 | loss_type: Literal["bt", "reg", "btt", "margin", "constant_margin", "scaled"] = "regular" 81 | 82 | def __post_init__(self): 83 | if self.load_in_8bit and self.load_in_4bit: 84 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 85 | 86 | # if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1: 87 | # self.lora_target_modules = self.lora_target_modules[0] 88 | 89 | # if isinstance(self.lora_namespan_exclude, list) and len(self.lora_namespan_exclude) == 1: 90 | # self.lora_namespan_exclude = self.lora_namespan_exclude[0] 91 | 92 | ########## Functions for get trainable modules' parameters ########## 93 | 94 | def maybe_zero_3(param, ignore_status=False, name=None): 95 | from deepspeed import zero 96 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 97 | if hasattr(param, "ds_id"): 98 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 99 | if not ignore_status: 100 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 101 | with zero.GatheredParameters([param]): 102 | param = param.data.detach().cpu().clone() 103 | else: 104 | param = param.detach().cpu().clone() 105 | return param 106 | 107 | # Borrowed from peft.utils.get_peft_model_state_dict 108 | def get_peft_state_maybe_zero_3(named_params, bias): 109 | if bias == "none": 110 | to_return = {k: t for k, t in named_params if "lora_" in k} 111 | elif bias == "all": 112 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 113 | elif bias == "lora_only": 114 | to_return = {} 115 | maybe_lora_bias = {} 116 | lora_bias_names = set() 117 | for k, t in named_params: 118 | if "lora_" in k: 119 | to_return[k] = t 120 | bias_name = k.split("lora_")[0] + "bias" 121 | lora_bias_names.add(bias_name) 122 | elif "bias" in k: 123 | maybe_lora_bias[k] = t 124 | for k, t in maybe_lora_bias: 125 | if bias_name in lora_bias_names: 126 | to_return[bias_name] = t 127 | else: 128 | raise NotImplementedError 129 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} 130 | return to_return 131 | 132 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): 133 | to_return = {k: t for k, t in named_params if "lora_" not in k} 134 | if require_grad_only: 135 | to_return = {k: t for k, t in to_return.items() if t.requires_grad} 136 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 137 | return to_return 138 | 139 | ########## Load Models From Folder ########## 140 | 141 
| def _insert_adapter_name_into_state_dict( 142 | state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str 143 | ) -> dict[str, torch.Tensor]: 144 | """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name.""" 145 | peft_model_state_dict = {} 146 | for key, val in state_dict.items(): 147 | if parameter_prefix in key: 148 | suffix = key.split(parameter_prefix)[1] 149 | if "." in suffix: 150 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 151 | key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 152 | else: 153 | key = f"{key}.{adapter_name}" 154 | peft_model_state_dict[key] = val 155 | else: 156 | peft_model_state_dict[key] = val 157 | return peft_model_state_dict 158 | 159 | 160 | def save_video(tensor, path): 161 | from torchvision.io import write_video 162 | tensor = tensor * 255.0 163 | tensor = tensor.permute(0, 2, 3, 1) 164 | tensor = tensor.clamp(0, 255).byte() 165 | write_video(path, tensor, 4, video_codec='h264') 166 | 167 | 168 | def load_model_from_checkpoint( 169 | model, checkpoint_dir, checkpoint_step 170 | ): 171 | checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*")) 172 | checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True) 173 | 174 | if checkpoint_step is None or checkpoint_step == -1: 175 | # get the latest checkpoint 176 | checkpoint_path = checkpoint_paths[0] 177 | print(f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}") 178 | else: 179 | checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}") 180 | if checkpoint_path not in checkpoint_paths: 181 | checkpoint_path = checkpoint_paths[0] 182 | print(f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}") 183 | else: 184 | print(f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}") 185 | 186 | checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0] 187 | 188 | full_ckpt = os.path.join(checkpoint_path, "model.pth") 189 | lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors") 190 | non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth") 191 | if os.path.exists(full_ckpt): 192 | model_state_dict = torch.load(full_ckpt, map_location="cpu") 193 | model.load_state_dict(model_state_dict) 194 | else: 195 | lora_state_dict = safetensors.torch.load_file(lora_ckpt) 196 | non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu") 197 | 198 | lora_state_dict = _insert_adapter_name_into_state_dict(lora_state_dict, adapter_name="default", parameter_prefix="lora_") 199 | 200 | model_state_dict = model.state_dict() 201 | model_state_dict.update(non_lora_state_dict) 202 | model_state_dict.update(lora_state_dict) 203 | model.load_state_dict(model_state_dict) 204 | 205 | return model, checkpoint_step 206 | 207 | 208 | def disable_torch_init(): 209 | """ 210 | Disable the redundant torch default initialization to accelerate model creation. 
211 | """ 212 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 213 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 214 | 215 | # This code is borrowed from LLaVA 216 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, 217 | device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 218 | kwargs = {"device_map": device_map} 219 | 220 | if device != "cuda": 221 | kwargs['device_map'] = {"":device} 222 | 223 | if load_8bit: 224 | kwargs['load_in_8bit'] = True 225 | elif load_4bit: 226 | kwargs['quantization_config'] = BitsAndBytesConfig( 227 | load_in_4bit=True, 228 | bnb_4bit_compute_dtype=torch.float16, 229 | bnb_4bit_use_double_quant=True, 230 | bnb_4bit_quant_type='nf4' 231 | ) 232 | else: 233 | kwargs['torch_dtype'] = torch.float16 234 | 235 | if use_flash_attn: 236 | kwargs['_attn_implementation'] = 'flash_attention_2' 237 | 238 | if 'lora' in model_name.lower() and model_base is None: 239 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.') 240 | if 'lora' in model_name.lower() and model_base is not None: 241 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_base) 242 | if hasattr(lora_cfg_pretrained, 'quantization_config'): 243 | del lora_cfg_pretrained.quantization_config 244 | processor = AutoProcessor.from_pretrained(model_base) 245 | print('Loading Qwen2-VL from base model...') 246 | if "Qwen2.5" in model_base: 247 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 248 | else: 249 | model = Qwen2VLForConditionalGeneration.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 250 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 251 | if model.lm_head.weight.shape[0] != token_num: 252 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 253 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 254 | 255 | print('Loading additional Qwen2-VL weights...') 256 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_state_dict.pth'), map_location='cpu') 257 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 258 | if any(k.startswith('model.model.') for k in non_lora_trainables): 259 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 260 | model.load_state_dict(non_lora_trainables, strict=False) 261 | 262 | print('Loading LoRA weights...') 263 | model = PeftModel.from_pretrained(model, model_path) 264 | 265 | print('Merging LoRA weights...') 266 | model = model.merge_and_unload() 267 | 268 | print('Model Loaded!!!') 269 | 270 | else: 271 | with open(os.path.join(model_path, 'config.json'), 'r') as f: 272 | config = json.load(f) 273 | 274 | if "Qwen2_5" in config["architectures"]: 275 | processor = AutoProcessor.from_pretrained(model_path) 276 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 277 | 278 | else: 279 | processor = AutoProcessor.from_pretrained(model_path) 280 | model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 281 | 282 | return processor, model 
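# Usage sketch for `load_pretrained_model` (illustrative only; the checkpoint
# path and base-model name below are hypothetical, not fixed by this file):
#
#     model_path = "checkpoints/qwen2vl_reward_lora/checkpoint-30000"
#     model_base = "Qwen/Qwen2-VL-7B-Instruct"
#     model_name = get_model_name_from_path(model_path)  # contains "lora", so the LoRA-merge branch runs
#     processor, model = load_pretrained_model(model_path, model_base, model_name, use_flash_attn=True)
#
# For a full (non-LoRA) checkpoint, pass a `model_name` without "lora" or set
# `model_base=None`; the model is then loaded directly from `model_path` using
# the architecture recorded in its config.json.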
283 | 284 | 285 | def get_model_name_from_path(model_path): 286 | model_path = model_path.strip("/") 287 | model_paths = model_path.split("/") 288 | if model_paths[-1].startswith('checkpoint-'): 289 | return model_paths[-2] + "_" + model_paths[-1] 290 | else: 291 | return model_paths[-1] -------------------------------------------------------------------------------- /reward_modeling/train_reward.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import pdb 5 | import random 6 | from dataclasses import asdict 7 | from functools import partial 8 | 9 | import torch 10 | from datasets import load_dataset, concatenate_datasets 11 | from peft import LoraConfig, get_peft_model 12 | from transformers import AutoProcessor, HfArgumentParser 13 | from trl import get_kbit_device_map, get_quantization_config 14 | 15 | from trainer import Qwen2VLRewardModelBT, ImageVLMRewardTrainer, compute_multi_attr_accuracy, PartialEmbeddingUpdateCallback 16 | from data import DataConfig, QWen2VLDataCollator, convert_GSB_csv_to_reward_data 17 | from reward_modeling.utils import ModelConfig, PEFTLoraConfig, TrainingConfig 18 | from reward_modeling.utils import load_model_from_checkpoint 19 | 20 | 21 | def save_configs_to_json(data_config, training_args, model_config, peft_lora_config): 22 | """ 23 | Save all configurations to a JSON file. 24 | """ 25 | config_dict = { 26 | "data_config": asdict(data_config), 27 | "training_args": asdict(training_args), 28 | "model_config": asdict(model_config), 29 | "peft_lora_config": asdict(peft_lora_config), 30 | } 31 | # del information about local device 32 | del config_dict["training_args"]["local_rank"] 33 | del config_dict["training_args"]["_n_gpu"] 34 | 35 | save_path = os.path.join(training_args.output_dir, "model_config.json") 36 | 37 | os.makedirs(training_args.output_dir, exist_ok=True) 38 | print(training_args.output_dir) 39 | 40 | with open(save_path, "w") as f: 41 | json.dump(config_dict, f, indent=4) 42 | 43 | def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=[], verbose=False): 44 | """ 45 | Find the target linear modules for LoRA. 
46 | """ 47 | linear_cls = torch.nn.Linear 48 | embedding_cls = torch.nn.Embedding 49 | lora_module_names = [] 50 | 51 | for name, module in model.named_modules(): 52 | if any(ex_keyword in name for ex_keyword in lora_namespan_exclude): 53 | # print(f"Excluding module: {name}") 54 | continue 55 | 56 | if isinstance(module, (linear_cls, embedding_cls)): 57 | lora_module_names.append(name) 58 | 59 | if num_lora_modules > 0: 60 | lora_module_names = lora_module_names[-num_lora_modules:] 61 | if verbose: 62 | print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}") 63 | return lora_module_names 64 | 65 | def set_requires_grad(parameters, requires_grad): 66 | for p in parameters: 67 | p.requires_grad = requires_grad 68 | 69 | def create_model_and_processor( 70 | model_config, peft_lora_config, training_args, 71 | cache_dir=None, 72 | ): 73 | # create model 74 | torch_dtype = ( 75 | model_config.torch_dtype 76 | if model_config.torch_dtype in ["auto", None] 77 | else getattr(torch, model_config.torch_dtype) 78 | ) 79 | quantization_config = get_quantization_config(model_config) 80 | model_kwargs = dict( 81 | revision=model_config.model_revision, 82 | device_map=get_kbit_device_map() if quantization_config is not None else None, 83 | quantization_config=quantization_config, 84 | use_cache=True if training_args.gradient_checkpointing else False, 85 | ) 86 | # pdb.set_trace() 87 | 88 | # create processor and set padding 89 | processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, 90 | padding_side="right", 91 | cache_dir=cache_dir) 92 | 93 | special_token_ids = None 94 | if model_config.use_special_tokens: 95 | special_tokens = ["<|VQ_reward|>"] 96 | processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) 97 | special_token_ids = processor.tokenizer.convert_tokens_to_ids(special_tokens) 98 | 99 | model = Qwen2VLRewardModelBT.from_pretrained( 100 | model_config.model_name_or_path, 101 | output_dim=model_config.output_dim, 102 | reward_token=model_config.reward_token, 103 | special_token_ids=special_token_ids, 104 | torch_dtype=torch_dtype, 105 | attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa", 106 | cache_dir=cache_dir, 107 | **model_kwargs 108 | ) 109 | if model_config.use_special_tokens: 110 | model.resize_token_embeddings(len(processor.tokenizer)) 111 | 112 | if training_args.bf16: 113 | model.to(torch.bfloat16) 114 | if training_args.fp16: 115 | model.to(torch.float16) 116 | 117 | # create lora and peft model 118 | if peft_lora_config.lora_enable: 119 | target_modules = find_target_linear_names(model, 120 | num_lora_modules=peft_lora_config.num_lora_modules, 121 | lora_namespan_exclude=peft_lora_config.lora_namespan_exclude) 122 | peft_config = LoraConfig( 123 | target_modules=target_modules, 124 | r=peft_lora_config.lora_r, 125 | lora_alpha=peft_lora_config.lora_alpha, 126 | lora_dropout=peft_lora_config.lora_dropout, 127 | task_type=peft_lora_config.lora_task_type, 128 | use_rslora=peft_lora_config.use_rslora, 129 | bias="none", 130 | modules_to_save=peft_lora_config.lora_modules_to_save, 131 | ) 132 | model = get_peft_model(model, peft_config) 133 | else: 134 | peft_config = None 135 | 136 | model.config.tokenizer_padding_side = processor.tokenizer.padding_side 137 | model.config.pad_token_id = processor.tokenizer.pad_token_id 138 | 139 | return model, processor, peft_config 140 | 141 | def create_dataset(data_config, meta_file=None): 142 | if meta_file is None: 143 | meta_file = 
data_config.meta_data 144 | dataset = load_dataset('json', data_files=meta_file) 145 | def add_idx(example, idx): 146 | example['metainfo_idx'] = idx 147 | return example 148 | dataset['train'] = dataset['train'].map(lambda example, idx: add_idx(example, idx), with_indices=True) 149 | 150 | if not data_config.use_tied_data: 151 | filter_func = lambda example: any(example[f"{dim}"] != "same" for dim in data_config.eval_dim) 152 | dataset = dataset.filter(filter_func) 153 | 154 | # convert data to reward data 155 | convert_func = lambda example: convert_GSB_csv_to_reward_data(example, data_config.data_dir, data_config.eval_dim, 156 | data_config.max_frame_pixels, data_config.prompt_template_type) 157 | dataset = dataset.map(convert_func, remove_columns=dataset['train'].column_names, load_from_cache_file=False) 158 | dataset = dataset['train'] 159 | # pdb.set_trace() 160 | return dataset 161 | 162 | def train(): 163 | ## ===> Step 1: Parse arguments 164 | parser = HfArgumentParser((DataConfig, TrainingConfig, ModelConfig, PEFTLoraConfig)) 165 | data_config, training_args, model_config, peft_lora_config = parser.parse_args_into_dataclasses() 166 | # pdb.set_trace() 167 | 168 | # check valid (lora config) 169 | assert not (peft_lora_config.lora_enable and model_config.freeze_llm), 'When using LoRA, the LLM should not be frozen. If you want to freeze the LLM, please disable LoRA.' 170 | if not peft_lora_config.lora_enable: 171 | assert not peft_lora_config.vision_lora, \ 172 | "Error: model_config.lora_enable is not enabled, but model_config.vision_lora is enabled." 173 | else: 174 | if peft_lora_config.lora_namespan_exclude is not None: 175 | peft_lora_config.lora_namespan_exclude = ast.literal_eval(peft_lora_config.lora_namespan_exclude) 176 | else: 177 | peft_lora_config.lora_namespan_exclude = [] 178 | if not peft_lora_config.vision_lora: 179 | peft_lora_config.lora_namespan_exclude += ["visual"] 180 | 181 | # pdb.set_trace() 182 | 183 | ## ===> Step 2: Load model and configure 184 | model, processor, peft_config = create_model_and_processor( 185 | model_config=model_config, 186 | peft_lora_config=peft_lora_config, 187 | training_args=training_args, 188 | ) 189 | 190 | ## load model 191 | if training_args.load_from_pretrained is not None: 192 | model, checkpoint_step = load_model_from_checkpoint(model, training_args.load_from_pretrained, training_args.load_from_pretrained_step) 193 | model.train() 194 | 195 | if peft_lora_config.lora_enable: 196 | model_to_configure = model.model 197 | else: 198 | model_to_configure = model 199 | # set requires_grad for LLM 200 | set_requires_grad(model_to_configure.model.parameters(), not model_config.freeze_llm) 201 | 202 | if not peft_lora_config.vision_lora: 203 | # set requires_grad for visual encoder and merger 204 | set_requires_grad(model_to_configure.visual.parameters(), not model_config.freeze_vision_tower) 205 | set_requires_grad(model_to_configure.visual.merger.parameters(), model_config.tune_merger) 206 | 207 | # set requires_grad for regression head 208 | set_requires_grad(model_to_configure.rm_head.parameters(), True) 209 | 210 | ## ===> Step 3: Load Dataset and configure 211 | if isinstance(data_config.eval_dim, str): 212 | data_config.eval_dim = [data_config.eval_dim] 213 | # datasets = create_dataset(data_config) 214 | # train_dataset = concatenate_datasets([datasets[dim] for dim in data_config.eval_dim]) 215 | train_dataset = create_dataset(data_config) 216 | train_dataset = train_dataset.shuffle(seed=42) 217 | 218 | if 
training_args.conduct_eval: 219 | if data_config.meta_data_test is not None: 220 | random.seed(42) 221 | valid_dataset = create_dataset(data_config, meta_file=data_config.meta_data_test) 222 | indices = random.sample(range(len(valid_dataset)), 1000) 223 | valid_dataset = valid_dataset.select(indices) 224 | else: 225 | dataset = train_dataset.train_test_split(test_size=0.02) 226 | train_dataset = dataset['train'] 227 | valid_dataset = dataset['test'] 228 | else: 229 | valid_dataset = None 230 | 231 | print(f"===> Selected {len(train_dataset)} samples for training.") 232 | print(f"===> Selected {len(valid_dataset)} samples for testing.") 233 | 234 | num_gpu = int(os.environ.get("WORLD_SIZE", 1)) 235 | data_collator = QWen2VLDataCollator(processor, add_noise=data_config.add_noise, 236 | p_shuffle_frames=data_config.p_shuffle_frames, 237 | p_color_jitter=data_config.p_color_jitter,) 238 | compute_metrics = partial(compute_multi_attr_accuracy, eval_dims=data_config.eval_dim) 239 | 240 | actual_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * num_gpu 241 | total_steps = training_args.num_train_epochs * len(train_dataset) // actual_batch_size 242 | if training_args.save_epochs is not None: 243 | training_args.save_steps = round(training_args.save_epochs * len(train_dataset) / actual_batch_size) 244 | if training_args.eval_epochs is not None: 245 | training_args.eval_steps = round(training_args.eval_epochs * len(train_dataset) / actual_batch_size) 246 | if training_args.logging_epochs is not None: 247 | training_args.logging_steps = round(training_args.logging_epochs * len(train_dataset) / actual_batch_size) 248 | 249 | if training_args.local_rank == -1 or training_args.local_rank == 0: 250 | print(f"===> Using {num_gpu} GPUs.") 251 | print(f"===> Total Batch Size: {actual_batch_size}") 252 | print(f"===> Training Epochs: {training_args.num_train_epochs}") 253 | print(f"===> Total Steps: {total_steps}") 254 | print(f"===> Save Steps: {training_args.save_steps}") 255 | print(f"===> Eval Steps: {training_args.eval_steps}") 256 | print(f"===> Logging Steps: {training_args.logging_steps}") 257 | 258 | 259 | # pdb.set_trace() 260 | 261 | ## ===> Step 4: Save configs for re-check 262 | if training_args.local_rank == -1 or training_args.local_rank == 0: 263 | save_configs_to_json(data_config, training_args, model_config, peft_lora_config) 264 | 265 | print(train_dataset) 266 | ## ===> Step 5: Start Training! 
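# (Illustrative check of the save/eval/logging step scheduling computed above,
# with made-up numbers: 8 GPUs, per_device_train_batch_size=4 and
# gradient_accumulation_steps=2 give actual_batch_size = 4 * 2 * 8 = 64; with
# 100_000 training samples and save_epochs=0.5,
# save_steps = round(0.5 * 100_000 / 64) = 781.)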
267 | 268 | special_token_ids = model.special_token_ids 269 | callbacks = [] 270 | if special_token_ids is not None: 271 | callbacks.append(PartialEmbeddingUpdateCallback(special_token_ids)) 272 | 273 | trainer = ImageVLMRewardTrainer( 274 | model=model, 275 | compute_metrics=compute_metrics, 276 | data_collator=data_collator, 277 | args=training_args, 278 | train_dataset=train_dataset, 279 | eval_dataset=valid_dataset if training_args.conduct_eval else None, 280 | peft_config=peft_config, 281 | callbacks=callbacks, 282 | loss_type=model_config.loss_type, 283 | tokenizer=processor.tokenizer, 284 | ) 285 | 286 | trainer.train() 287 | 288 | if training_args.local_rank == -1 or training_args.local_rank == 0: 289 | model_state_dict = model.state_dict() 290 | torch.save(model_state_dict, os.path.join(training_args.output_dir, 'final_model.pth')) 291 | model.config.save_pretrained(training_args.output_dir) 292 | 293 | 294 | if __name__ == "__main__": 295 | train() -------------------------------------------------------------------------------- /train_flux/flux/block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Union, Optional, Dict, Any, Callable 3 | from diffusers.models.attention_processor import Attention, F 4 | from .lora_controller import enable_lora 5 | 6 | 7 | def attn_forward( 8 | attn: Attention, 9 | hidden_states: torch.FloatTensor, 10 | encoder_hidden_states: torch.FloatTensor = None, 11 | condition_latents: torch.FloatTensor = None, 12 | attention_mask: Optional[torch.FloatTensor] = None, 13 | image_rotary_emb: Optional[torch.Tensor] = None, 14 | cond_rotary_emb: Optional[torch.Tensor] = None, 15 | model_config: Optional[Dict[str, Any]] = {}, 16 | ) -> torch.FloatTensor: 17 | batch_size, _, _ = ( 18 | hidden_states.shape 19 | if encoder_hidden_states is None 20 | else encoder_hidden_states.shape 21 | ) 22 | 23 | with enable_lora( 24 | (attn.to_q, attn.to_k, attn.to_v), model_config.get("latent_lora", False) 25 | ): 26 | # `sample` projections. 27 | query = attn.to_q(hidden_states) 28 | key = attn.to_k(hidden_states) 29 | value = attn.to_v(hidden_states) 30 | 31 | inner_dim = key.shape[-1] 32 | head_dim = inner_dim // attn.heads 33 | 34 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 35 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 36 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 37 | 38 | if attn.norm_q is not None: 39 | query = attn.norm_q(query) 40 | if attn.norm_k is not None: 41 | key = attn.norm_k(key) 42 | 43 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 44 | if encoder_hidden_states is not None: 45 | # `context` projections. 
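# (Note: in Flux's double-stream blocks the text/context tokens carry their own
# projections (`add_q_proj`, `add_k_proj`, `add_v_proj`); they are concatenated
# in front of the image-token queries/keys/values below so that a single joint
# attention call covers both streams, and the combined output is split back
# into the two streams after attention.)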
46 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 47 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 48 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 49 | 50 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 51 | batch_size, -1, attn.heads, head_dim 52 | ).transpose(1, 2) 53 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 54 | batch_size, -1, attn.heads, head_dim 55 | ).transpose(1, 2) 56 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 57 | batch_size, -1, attn.heads, head_dim 58 | ).transpose(1, 2) 59 | 60 | if attn.norm_added_q is not None: 61 | encoder_hidden_states_query_proj = attn.norm_added_q( 62 | encoder_hidden_states_query_proj 63 | ) 64 | if attn.norm_added_k is not None: 65 | encoder_hidden_states_key_proj = attn.norm_added_k( 66 | encoder_hidden_states_key_proj 67 | ) 68 | 69 | # attention 70 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 71 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 72 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 73 | 74 | if image_rotary_emb is not None: 75 | from diffusers.models.embeddings import apply_rotary_emb 76 | 77 | query = apply_rotary_emb(query, image_rotary_emb) 78 | key = apply_rotary_emb(key, image_rotary_emb) 79 | 80 | if condition_latents is not None: 81 | cond_query = attn.to_q(condition_latents) 82 | cond_key = attn.to_k(condition_latents) 83 | cond_value = attn.to_v(condition_latents) 84 | 85 | cond_query = cond_query.view(batch_size, -1, attn.heads, head_dim).transpose( 86 | 1, 2 87 | ) 88 | cond_key = cond_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 89 | cond_value = cond_value.view(batch_size, -1, attn.heads, head_dim).transpose( 90 | 1, 2 91 | ) 92 | if attn.norm_q is not None: 93 | cond_query = attn.norm_q(cond_query) 94 | if attn.norm_k is not None: 95 | cond_key = attn.norm_k(cond_key) 96 | 97 | if cond_rotary_emb is not None: 98 | cond_query = apply_rotary_emb(cond_query, cond_rotary_emb) 99 | cond_key = apply_rotary_emb(cond_key, cond_rotary_emb) 100 | 101 | if condition_latents is not None: 102 | query = torch.cat([query, cond_query], dim=2) 103 | key = torch.cat([key, cond_key], dim=2) 104 | value = torch.cat([value, cond_value], dim=2) 105 | 106 | if not model_config.get("union_cond_attn", True): 107 | # If we don't want to use the union condition attention, we need to mask the attention 108 | # between the hidden states and the condition latents 109 | attention_mask = torch.ones( 110 | query.shape[2], key.shape[2], device=query.device, dtype=torch.bool 111 | ) 112 | condition_n = cond_query.shape[2] 113 | attention_mask[-condition_n:, :-condition_n] = False 114 | attention_mask[:-condition_n, -condition_n:] = False 115 | if hasattr(attn, "c_factor"): 116 | attention_mask = torch.zeros( 117 | query.shape[2], key.shape[2], device=query.device, dtype=query.dtype 118 | ) 119 | condition_n = cond_query.shape[2] 120 | bias = torch.log(attn.c_factor[0]) 121 | attention_mask[-condition_n:, :-condition_n] = bias 122 | attention_mask[:-condition_n, -condition_n:] = bias 123 | hidden_states = F.scaled_dot_product_attention( 124 | query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask 125 | ) 126 | hidden_states = hidden_states.transpose(1, 2).reshape( 127 | batch_size, -1, attn.heads * head_dim 128 | ) 129 | hidden_states = hidden_states.to(query.dtype) 130 | 131 | if 
encoder_hidden_states is not None: 132 | if condition_latents is not None: 133 | encoder_hidden_states, hidden_states, condition_latents = ( 134 | hidden_states[:, : encoder_hidden_states.shape[1]], 135 | hidden_states[ 136 | :, encoder_hidden_states.shape[1] : -condition_latents.shape[1] 137 | ], 138 | hidden_states[:, -condition_latents.shape[1] :], 139 | ) 140 | else: 141 | encoder_hidden_states, hidden_states = ( 142 | hidden_states[:, : encoder_hidden_states.shape[1]], 143 | hidden_states[:, encoder_hidden_states.shape[1] :], 144 | ) 145 | 146 | with enable_lora((attn.to_out[0],), model_config.get("latent_lora", False)): 147 | # linear proj 148 | hidden_states = attn.to_out[0](hidden_states) 149 | # dropout 150 | hidden_states = attn.to_out[1](hidden_states) 151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 152 | 153 | if condition_latents is not None: 154 | condition_latents = attn.to_out[0](condition_latents) 155 | condition_latents = attn.to_out[1](condition_latents) 156 | 157 | return ( 158 | (hidden_states, encoder_hidden_states, condition_latents) 159 | if condition_latents is not None 160 | else (hidden_states, encoder_hidden_states) 161 | ) 162 | elif condition_latents is not None: 163 | # if there are condition_latents, we need to separate the hidden_states and the condition_latents 164 | hidden_states, condition_latents = ( 165 | hidden_states[:, : -condition_latents.shape[1]], 166 | hidden_states[:, -condition_latents.shape[1] :], 167 | ) 168 | return hidden_states, condition_latents 169 | else: 170 | return hidden_states 171 | 172 | 173 | def block_forward( 174 | self, 175 | hidden_states: torch.FloatTensor, 176 | encoder_hidden_states: torch.FloatTensor, 177 | condition_latents: torch.FloatTensor, 178 | temb: torch.FloatTensor, 179 | cond_temb: torch.FloatTensor, 180 | cond_rotary_emb=None, 181 | image_rotary_emb=None, 182 | model_config: Optional[Dict[str, Any]] = {}, 183 | ): 184 | use_cond = condition_latents is not None 185 | with enable_lora((self.norm1.linear,), model_config.get("latent_lora", False)): 186 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 187 | hidden_states, emb=temb 188 | ) 189 | 190 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( 191 | self.norm1_context(encoder_hidden_states, emb=temb) 192 | ) 193 | 194 | if use_cond: 195 | ( 196 | norm_condition_latents, 197 | cond_gate_msa, 198 | cond_shift_mlp, 199 | cond_scale_mlp, 200 | cond_gate_mlp, 201 | ) = self.norm1(condition_latents, emb=cond_temb) 202 | 203 | # Attention. 204 | result = attn_forward( 205 | self.attn, 206 | model_config=model_config, 207 | hidden_states=norm_hidden_states, 208 | encoder_hidden_states=norm_encoder_hidden_states, 209 | condition_latents=norm_condition_latents if use_cond else None, 210 | image_rotary_emb=image_rotary_emb, 211 | cond_rotary_emb=cond_rotary_emb if use_cond else None, 212 | ) 213 | attn_output, context_attn_output = result[:2] 214 | cond_attn_output = result[2] if use_cond else None 215 | 216 | # Process attention outputs for the `hidden_states`. 217 | # 1. hidden_states 218 | attn_output = gate_msa.unsqueeze(1) * attn_output 219 | hidden_states = hidden_states + attn_output 220 | # 2. encoder_hidden_states 221 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 222 | encoder_hidden_states = encoder_hidden_states + context_attn_output 223 | # 3. 
condition_latents 224 | if use_cond: 225 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 226 | condition_latents = condition_latents + cond_attn_output 227 | if model_config.get("add_cond_attn", False): 228 | hidden_states += cond_attn_output 229 | 230 | # LayerNorm + MLP. 231 | # 1. hidden_states 232 | norm_hidden_states = self.norm2(hidden_states) 233 | norm_hidden_states = ( 234 | norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 235 | ) 236 | # 2. encoder_hidden_states 237 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 238 | norm_encoder_hidden_states = ( 239 | norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 240 | ) 241 | # 3. condition_latents 242 | if use_cond: 243 | norm_condition_latents = self.norm2(condition_latents) 244 | norm_condition_latents = ( 245 | norm_condition_latents * (1 + cond_scale_mlp[:, None]) 246 | + cond_shift_mlp[:, None] 247 | ) 248 | 249 | # Feed-forward. 250 | with enable_lora((self.ff.net[2],), model_config.get("latent_lora", False)): 251 | # 1. hidden_states 252 | ff_output = self.ff(norm_hidden_states) 253 | ff_output = gate_mlp.unsqueeze(1) * ff_output 254 | # 2. encoder_hidden_states 255 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 256 | context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output 257 | # 3. condition_latents 258 | if use_cond: 259 | cond_ff_output = self.ff(norm_condition_latents) 260 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 261 | 262 | # Process feed-forward outputs. 263 | hidden_states = hidden_states + ff_output 264 | encoder_hidden_states = encoder_hidden_states + context_ff_output 265 | if use_cond: 266 | condition_latents = condition_latents + cond_ff_output 267 | 268 | # Clip to avoid overflow. 
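# (65504 is the largest finite float16 value; the clamp below keeps the
# residual stream from overflowing to inf when the block runs in half
# precision.)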
269 | if encoder_hidden_states.dtype == torch.float16: 270 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 271 | 272 | return encoder_hidden_states, hidden_states, condition_latents if use_cond else None 273 | 274 | 275 | def single_block_forward( 276 | self, 277 | hidden_states: torch.FloatTensor, 278 | temb: torch.FloatTensor, 279 | image_rotary_emb=None, 280 | condition_latents: torch.FloatTensor = None, 281 | cond_temb: torch.FloatTensor = None, 282 | cond_rotary_emb=None, 283 | model_config: Optional[Dict[str, Any]] = {}, 284 | ): 285 | 286 | using_cond = condition_latents is not None 287 | residual = hidden_states 288 | with enable_lora( 289 | ( 290 | self.norm.linear, 291 | self.proj_mlp, 292 | ), 293 | model_config.get("latent_lora", False), 294 | ): 295 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 296 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 297 | if using_cond: 298 | residual_cond = condition_latents 299 | norm_condition_latents, cond_gate = self.norm(condition_latents, emb=cond_temb) 300 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_condition_latents)) 301 | 302 | attn_output = attn_forward( 303 | self.attn, 304 | model_config=model_config, 305 | hidden_states=norm_hidden_states, 306 | image_rotary_emb=image_rotary_emb, 307 | **( 308 | { 309 | "condition_latents": norm_condition_latents, 310 | "cond_rotary_emb": cond_rotary_emb if using_cond else None, 311 | } 312 | if using_cond 313 | else {} 314 | ), 315 | ) 316 | if using_cond: 317 | attn_output, cond_attn_output = attn_output 318 | 319 | with enable_lora((self.proj_out,), model_config.get("latent_lora", False)): 320 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 321 | gate = gate.unsqueeze(1) 322 | hidden_states = gate * self.proj_out(hidden_states) 323 | hidden_states = residual + hidden_states 324 | if using_cond: 325 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 326 | cond_gate = cond_gate.unsqueeze(1) 327 | condition_latents = cond_gate * self.proj_out(condition_latents) 328 | condition_latents = residual_cond + condition_latents 329 | 330 | if hidden_states.dtype == torch.float16: 331 | hidden_states = hidden_states.clip(-65504, 65504) 332 | 333 | return hidden_states if not using_cond else (hidden_states, condition_latents) 334 | --------------------------------------------------------------------------------
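# Illustrative sketch (not a file from this repository): a minimal, self-contained
# reproduction of the condition-attention bias built inside `attn_forward` in
# train_flux/flux/block.py. When the attention module carries a `c_factor`, the
# cross terms between the condition tokens and all other tokens receive an
# additive bias of log(c_factor) instead of being masked out. All sizes and the
# c_factor value below are made up for demonstration.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, head_dim = 2, 8
n_image, n_cond = 6, 3                      # hypothetical token counts
seq = n_image + n_cond

query = torch.randn(1, heads, seq, head_dim)
key = torch.randn(1, heads, seq, head_dim)
value = torch.randn(1, heads, seq, head_dim)

c_factor = torch.tensor([1.5])              # hypothetical condition strength
bias = torch.log(c_factor[0])

# Zero bias everywhere except the blocks that couple condition tokens with the
# rest of the sequence, mirroring the `hasattr(attn, "c_factor")` branch.
attn_mask = torch.zeros(seq, seq, dtype=query.dtype)
attn_mask[-n_cond:, :-n_cond] = bias        # condition tokens attending to the rest
attn_mask[:-n_cond, -n_cond:] = bias        # the rest attending to condition tokens

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)                            # torch.Size([1, 2, 9, 8])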