├── imgs ├── img1 ├── 0-0.jpg ├── 0-1.webp ├── 0-2.webp ├── 0-3.webp ├── 0-4.webp ├── 1-0.jpg ├── 1-1.webp ├── 1-2.webp ├── 1-3.webp ├── 1-4.webp ├── 2-0.jpg ├── 2-1.webp ├── 2-2.webp ├── 2-3.webp ├── 2-4.webp ├── 3-0.png ├── 3-1.webp ├── 3-2.webp ├── 3-3.webp ├── 3-4.webp ├── 4-0.jpg ├── 4-1.webp ├── 4-2.webp ├── 4-3.webp ├── 4-4.webp ├── 5-0.jpg ├── 5-1.webp ├── 5-2.webp ├── 5-3.webp └── 5-4.webp ├── flow_grpo ├── assets │ ├── simple_ocr_animals.txt │ ├── activities.txt │ ├── activities_v0.txt │ ├── simple_animals.txt │ ├── object_names.txt │ ├── simple_ocr_animals_digit1.txt │ ├── simple_ocr_animals_digit3.txt │ └── simple_ocr_animals_digit5.txt ├── test_cases │ ├── cat.jpg │ ├── nasa.jpg │ ├── hello world.jpg │ └── a photo of a brown giraffe and a white stop sign.png ├── reward_ckpt_path.py ├── imagereward_scorer.py ├── aesthetic_scorer.py ├── pickscore_scorer.py ├── prompts.py ├── clip_scorer.py ├── ocr.py ├── unifiedreward_scorer.py ├── ema.py ├── stat_tracking.py ├── diffusers_patch │ ├── train_dreambooth_lora_flux.py │ ├── train_dreambooth_lora_sd3.py │ ├── kontext_pipeline_with_logprob.py │ ├── pipeline_with_logprob.py │ ├── solver.py │ └── qwen_image_edit_old_pipeline_with_logprob.py ├── hpsv2_scorer.py ├── fsdp2_utils.py └── rewards.py ├── reward_server ├── requirements.txt ├── prompt_template.py ├── test_reward_server.py └── reward_server.py ├── reproduction ├── convert_to_diffusers_lora.py ├── README.md └── sampling │ ├── sampling_kontext_imgedit.py │ ├── sampling_qwen_imgedit.py │ ├── sampling_kontext_gedit.py │ └── sampling_qwen_gedit.py ├── examples ├── train_kontext.sh └── train_qwen_image_edit.sh ├── setup.py ├── config ├── qwen_image_edit_nft.py ├── kontext_nft.py └── base.py ├── README.md ├── .gitignore ├── LICENSE └── scripts └── evaluation.py /imgs/img1: -------------------------------------------------------------------------------- 1 | 111 2 | -------------------------------------------------------------------------------- /flow_grpo/assets/simple_ocr_animals.txt: -------------------------------------------------------------------------------- 1 | cat 2 | dog 3 | horse 4 | monkey 5 | rabbit -------------------------------------------------------------------------------- /imgs/0-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/0-0.jpg -------------------------------------------------------------------------------- /imgs/0-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/0-1.webp -------------------------------------------------------------------------------- /imgs/0-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/0-2.webp -------------------------------------------------------------------------------- /imgs/0-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/0-3.webp -------------------------------------------------------------------------------- /imgs/0-4.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/0-4.webp -------------------------------------------------------------------------------- /imgs/1-0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/1-0.jpg -------------------------------------------------------------------------------- /imgs/1-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/1-1.webp -------------------------------------------------------------------------------- /imgs/1-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/1-2.webp -------------------------------------------------------------------------------- /imgs/1-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/1-3.webp -------------------------------------------------------------------------------- /imgs/1-4.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/1-4.webp -------------------------------------------------------------------------------- /imgs/2-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/2-0.jpg -------------------------------------------------------------------------------- /imgs/2-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/2-1.webp -------------------------------------------------------------------------------- /imgs/2-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/2-2.webp -------------------------------------------------------------------------------- /imgs/2-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/2-3.webp -------------------------------------------------------------------------------- /imgs/2-4.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/2-4.webp -------------------------------------------------------------------------------- /imgs/3-0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/3-0.png -------------------------------------------------------------------------------- /imgs/3-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/3-1.webp -------------------------------------------------------------------------------- /imgs/3-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/3-2.webp -------------------------------------------------------------------------------- /imgs/3-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/3-3.webp -------------------------------------------------------------------------------- /imgs/3-4.webp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/3-4.webp -------------------------------------------------------------------------------- /imgs/4-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/4-0.jpg -------------------------------------------------------------------------------- /imgs/4-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/4-1.webp -------------------------------------------------------------------------------- /imgs/4-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/4-2.webp -------------------------------------------------------------------------------- /imgs/4-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/4-3.webp -------------------------------------------------------------------------------- /imgs/4-4.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/4-4.webp -------------------------------------------------------------------------------- /imgs/5-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/5-0.jpg -------------------------------------------------------------------------------- /imgs/5-1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/5-1.webp -------------------------------------------------------------------------------- /imgs/5-2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/5-2.webp -------------------------------------------------------------------------------- /imgs/5-3.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/5-3.webp -------------------------------------------------------------------------------- /imgs/5-4.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/imgs/5-4.webp -------------------------------------------------------------------------------- /flow_grpo/assets/activities.txt: -------------------------------------------------------------------------------- 1 | washing the dishes 2 | riding a bike 3 | playing chess -------------------------------------------------------------------------------- /flow_grpo/assets/activities_v0.txt: -------------------------------------------------------------------------------- 1 | washing the dishes 2 | riding a bike 3 | playing chess -------------------------------------------------------------------------------- /reward_server/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.55.4 2 | vllm==0.9.2 3 | tokenizers==0.21.4 -------------------------------------------------------------------------------- 
/flow_grpo/test_cases/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/flow_grpo/test_cases/cat.jpg -------------------------------------------------------------------------------- /flow_grpo/test_cases/nasa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/flow_grpo/test_cases/nasa.jpg -------------------------------------------------------------------------------- /flow_grpo/test_cases/hello world.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/flow_grpo/test_cases/hello world.jpg -------------------------------------------------------------------------------- /flow_grpo/reward_ckpt_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../reward_ckpts") 4 | -------------------------------------------------------------------------------- /flow_grpo/test_cases/a photo of a brown giraffe and a white stop sign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/Edit-R1/HEAD/flow_grpo/test_cases/a photo of a brown giraffe and a white stop sign.png -------------------------------------------------------------------------------- /reproduction/convert_to_diffusers_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import load_file, save_file 3 | 4 | base_dir = "lora/" 5 | 6 | state_dict = load_file(f"{base_dir}/adapter_model.safetensors") 7 | 8 | new_state_dict = {} 9 | 10 | for key, value in state_dict.items(): 11 | new_key = key.replace("base_model.model", "transformer") 12 | new_state_dict[new_key] = value 13 | 14 | 15 | save_file(new_state_dict, f"{base_dir}/adapter_model_converted.safetensors") 16 | -------------------------------------------------------------------------------- /flow_grpo/assets/simple_animals.txt: -------------------------------------------------------------------------------- 1 | cat 2 | dog 3 | horse 4 | monkey 5 | rabbit 6 | zebra 7 | spider 8 | bird 9 | sheep 10 | deer 11 | cow 12 | goat 13 | lion 14 | tiger 15 | bear 16 | raccoon 17 | fox 18 | wolf 19 | lizard 20 | beetle 21 | ant 22 | butterfly 23 | fish 24 | shark 25 | whale 26 | dolphin 27 | squirrel 28 | mouse 29 | rat 30 | snake 31 | turtle 32 | frog 33 | chicken 34 | duck 35 | goose 36 | bee 37 | pig 38 | turkey 39 | fly 40 | llama 41 | camel 42 | bat 43 | gorilla 44 | hedgehog 45 | kangaroo 46 | -------------------------------------------------------------------------------- /reward_server/prompt_template.py: -------------------------------------------------------------------------------- 1 | SCORE_LOGIT = """Here are two images: the original and the edited version. Please evaluate the edited image based on the following editing instruction and requirement. 2 | Instruction: {prompt} 3 | Requirements: {requirement} 4 | You need to rate the editing result from 0 to 5 based on the accuracy and quality of the edit. 5 | 0: The wrong object was edited, or the edit completely fails to meet the requirements. 6 | 5: The correct object was edited, the requirements were met, and the visual result is high quality. 
7 | Response Format (Directly response the score number): 8 | 0-5""" -------------------------------------------------------------------------------- /examples/train_kontext.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | export TOKENIZERS_PARALLELISM=true 3 | export REWARD_SERVER=10.0.67.19:12341 4 | 5 | export NCCL_IB_TC=136 6 | export NCCL_IB_SL=5 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth 9 | export NCCL_IB_HCA=mlx5 10 | export NCCL_IB_TIMEOUT=22 11 | export NCCL_IB_QPS_PER_CONNECTION=8 12 | export NCCL_NET_PLUGIN=none 13 | 14 | torchrun --nproc_per_node=8 \ 15 | --nnodes=${WORLD_SIZE} \ 16 | --master_addr=${MASTER_ADDR} \ 17 | --master_port=${MASTER_PORT} \ 18 | --node_rank ${RANK} \ 19 | scripts/train_nft_kontext.py --config config/kontext_nft.py:$1 -------------------------------------------------------------------------------- /examples/train_qwen_image_edit.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | export TOKENIZERS_PARALLELISM=true 3 | export REWARD_SERVER=172.17.0.139:12341 4 | 5 | export NCCL_IB_TC=136 6 | export NCCL_IB_SL=5 7 | export NCCL_IB_GID_INDEX=3 8 | export NCCL_SOCKET_IFNAME=eth 9 | export NCCL_IB_HCA=mlx5 10 | export NCCL_IB_TIMEOUT=22 11 | export NCCL_IB_QPS_PER_CONNECTION=8 12 | export NCCL_NET_PLUGIN=none 13 | 14 | torchrun --nproc_per_node=8 \ 15 | --nnodes=${WORLD_SIZE} \ 16 | --master_addr=${MASTER_ADDR} \ 17 | --master_port=${MASTER_PORT} \ 18 | --node_rank ${RANK} \ 19 | scripts/train_nft_qwen_image_edit.py --config config/qwen_image_edit_nft.py:$1 -------------------------------------------------------------------------------- /flow_grpo/assets/object_names.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | computer mouse 66 | tv remote 67 | computer keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /flow_grpo/imagereward_scorer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch 4 | import ImageReward as RM 5 | 6 | 7 | class ImageRewardScorer(torch.nn.Module): 8 | def __init__(self, device="cuda", dtype=torch.float32): 9 | super().__init__() 10 | self.device = device 11 | self.dtype = dtype 12 | self.model = ( 13 | RM.load( 14 | "ImageReward-v1.0", 15 | device=device, 16 | download_root=os.path.join(os.environ.get("HF_HOME", 
"~/.cache/"), "ImageReward"), 17 | ) 18 | .eval() 19 | .to(dtype=dtype) 20 | ) 21 | self.model.requires_grad_(False) 22 | 23 | @torch.no_grad() 24 | def __call__(self, prompts, images): 25 | _, rewards = self.model.inference_rank(prompts, images) 26 | rewards = torch.diagonal(torch.Tensor(rewards).to(self.device).reshape(len(prompts), len(prompts)), 0) 27 | return rewards.contiguous() 28 | 29 | 30 | # Usage example 31 | def main(): 32 | scorer = ImageRewardScorer(device="cuda", dtype=torch.float32) 33 | 34 | images = [ 35 | "test_cases/nasa.jpg", 36 | "test_cases/hello world.jpg", 37 | ] 38 | pil_images = [Image.open(img) for img in images] 39 | prompts = [ 40 | 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', 41 | 'New York Skyline with "Hello World" written with fireworks on the sky', 42 | ] 43 | print(scorer(prompts, pil_images)) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="diffusion-nft", 5 | version="0.0.1", 6 | packages=find_packages(), 7 | python_requires=">=3.10", 8 | install_requires=[ 9 | "torch==2.6.0", 10 | "torchvision==0.21.0", 11 | "transformers==4.40.0", 12 | "accelerate==1.4.0", 13 | "diffusers==0.33.1", 14 | 15 | "numpy==1.26.4", 16 | "pandas==2.2.3", 17 | "scipy==1.15.2", 18 | "scikit-learn==1.6.1", 19 | "scikit-image==0.25.2", 20 | 21 | "albumentations==1.4.10", 22 | "opencv-python==4.11.0.86", 23 | "pillow==10.4.0", 24 | 25 | "tqdm==4.67.1", 26 | "wandb==0.18.7", 27 | "pydantic==2.10.6", 28 | "requests", 29 | "matplotlib==3.10.0", 30 | 31 | "flash-attn==2.7.4.post1", 32 | "deepspeed==0.16.4", 33 | "peft==0.10.0", 34 | "bitsandbytes==0.45.3", 35 | 36 | "aiohttp==3.11.13", 37 | "fastapi==0.115.11", 38 | "uvicorn==0.34.0", 39 | 40 | "huggingface-hub==0.29.1", 41 | "datasets==3.3.2", 42 | "tokenizers==0.19.1", 43 | 44 | "einops==0.8.1", 45 | "nvidia-ml-py==12.570.86", 46 | "xformers", 47 | "absl-py", 48 | "ml_collections", 49 | "sentencepiece", 50 | ], 51 | extras_require={ 52 | "dev": [ 53 | "ipython==8.34.0", 54 | "black==24.2.0", 55 | "pytest==8.2.0" 56 | ] 57 | } 58 | ) -------------------------------------------------------------------------------- /reward_server/test_reward_server.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import pickle 3 | from PIL import Image 4 | import io 5 | 6 | ref_path = "assets/ref_image.png" 7 | img_path = "assets/image.png" 8 | 9 | def test_cot_continue(): 10 | ref_image = Image.open(ref_path).convert("RGB") 11 | ref_image_io = io.BytesIO() 12 | ref_image.save(ref_image_io, format="JPEG") 13 | ref_image_io.seek(0) 14 | 15 | image = Image.open(img_path).convert("RGB") 16 | image_io = io.BytesIO() 17 | image.save(image_io, format="JPEG") 18 | image_io.seek(0) 19 | 20 | data = { 21 | "images": [image_io.getvalue(), image_io.getvalue()], 22 | "ref_images": [ref_image_io.getvalue(), ref_image_io.getvalue()], 23 | "prompts": ["Change the color of the snowman's hat from black to green", "Change the color of the snowman's hat from black to red"], 24 | "metadatas": [ 25 | {"requirement": "None"}, 26 | {"requirement": "None"}, 27 | ], 28 | } 29 | 30 | payload = pickle.dumps(data) 31 | url = "http://127.0.0.1:12341/mode/logits_non_cot" 32 | proxies = { 33 | "http": None, 34 | "https": None, 35 | } 36 | 
response = requests.post( 37 | url, 38 | data=payload, 39 | proxies=proxies, 40 | headers={"Content-Type": "application/octet-stream"}, 41 | ) 42 | 43 | if response.status_code == 200: 44 | result = pickle.loads(response.content) 45 | print("Scores:", result["scores"]) 46 | else: 47 | print("Error:", response.status_code, response.text) 48 | 49 | if __name__ == "__main__": 50 | test_cot_continue() 51 | -------------------------------------------------------------------------------- /flow_grpo/assets/simple_ocr_animals_digit1.txt: -------------------------------------------------------------------------------- 1 | A cat holding a sign that says '0' 2 | A dog holding a sign that says '0' 3 | A horse holding a sign that says '0' 4 | A monkey holding a sign that says '0' 5 | A rabbit holding a sign that says '0' 6 | A cat holding a sign that says '1' 7 | A dog holding a sign that says '1' 8 | A horse holding a sign that says '1' 9 | A monkey holding a sign that says '1' 10 | A rabbit holding a sign that says '1' 11 | A cat holding a sign that says '2' 12 | A dog holding a sign that says '2' 13 | A horse holding a sign that says '2' 14 | A monkey holding a sign that says '2' 15 | A rabbit holding a sign that says '2' 16 | A cat holding a sign that says '3' 17 | A dog holding a sign that says '3' 18 | A horse holding a sign that says '3' 19 | A monkey holding a sign that says '3' 20 | A rabbit holding a sign that says '3' 21 | A cat holding a sign that says '4' 22 | A dog holding a sign that says '4' 23 | A horse holding a sign that says '4' 24 | A monkey holding a sign that says '4' 25 | A rabbit holding a sign that says '4' 26 | A cat holding a sign that says '5' 27 | A dog holding a sign that says '5' 28 | A horse holding a sign that says '5' 29 | A monkey holding a sign that says '5' 30 | A rabbit holding a sign that says '5' 31 | A cat holding a sign that says '6' 32 | A dog holding a sign that says '6' 33 | A horse holding a sign that says '6' 34 | A monkey holding a sign that says '6' 35 | A rabbit holding a sign that says '6' 36 | A cat holding a sign that says '7' 37 | A dog holding a sign that says '7' 38 | A horse holding a sign that says '7' 39 | A monkey holding a sign that says '7' 40 | A rabbit holding a sign that says '7' 41 | A cat holding a sign that says '8' 42 | A dog holding a sign that says '8' 43 | A horse holding a sign that says '8' 44 | A monkey holding a sign that says '8' 45 | A rabbit holding a sign that says '8' -------------------------------------------------------------------------------- /flow_grpo/assets/simple_ocr_animals_digit3.txt: -------------------------------------------------------------------------------- 1 | A cat holding a sign that says '123' 2 | A dog holding a sign that says '234' 3 | A horse holding a sign that says '345' 4 | A monkey holding a sign that says '456' 5 | A rabbit holding a sign that says '567' 6 | A cat holding a sign that says '678' 7 | A dog holding a sign that says '789' 8 | A horse holding a sign that says '123' 9 | A monkey holding a sign that says '234' 10 | A rabbit holding a sign that says '345' 11 | A cat holding a sign that says '456' 12 | A dog holding a sign that says '567' 13 | A horse holding a sign that says '678' 14 | A monkey holding a sign that says '789' 15 | A rabbit holding a sign that says '123' 16 | A cat holding a sign that says '234' 17 | A dog holding a sign that says '345' 18 | A horse holding a sign that says '456' 19 | A monkey holding a sign that says '567' 20 | A rabbit holding a sign that says '678' 
21 | A cat holding a sign that says '789' 22 | A dog holding a sign that says '123' 23 | A horse holding a sign that says '234' 24 | A monkey holding a sign that says '345' 25 | A rabbit holding a sign that says '456' 26 | A cat holding a sign that says '567' 27 | A dog holding a sign that says '678' 28 | A horse holding a sign that says '789' 29 | A monkey holding a sign that says '123' 30 | A rabbit holding a sign that says '234' 31 | A cat holding a sign that says '345' 32 | A dog holding a sign that says '456' 33 | A horse holding a sign that says '567' 34 | A monkey holding a sign that says '678' 35 | A rabbit holding a sign that says '789' 36 | A cat holding a sign that says '123' 37 | A dog holding a sign that says '234' 38 | A horse holding a sign that says '345' 39 | A monkey holding a sign that says '456' 40 | A rabbit holding a sign that says '567' 41 | A cat holding a sign that says '678' 42 | A dog holding a sign that says '789' 43 | A horse holding a sign that says '123' 44 | A monkey holding a sign that says '234' 45 | A rabbit holding a sign that says '345' -------------------------------------------------------------------------------- /reproduction/README.md: -------------------------------------------------------------------------------- 1 | # Reproduction 2 | 3 | ## GEdit-Bench 4 | 5 | ### Sampling 6 | 7 | Qwen Image Edit [2509] Baseline: 8 | 9 | ``` 10 | python reproduction/sampling/sampling_qwen_gedit.py \ 11 | --pretrained_name_or_path [pretrained_model] \ 12 | --gedit_bench_path [gedit_bench_path] \ 13 | --output_dir [absolute_output_path] \ 14 | --seed [seed] 15 | ``` 16 | 17 | UniWorld-Qwen-Image-Edit [2509]: 18 | 19 | ``` 20 | python reproduction/sampling/sampling_qwen_gedit.py \ 21 | --pretrained_name_or_path [pretrained_model] \ 22 | --gedit_bench_path [gedit_bench_path] \ 23 | --output_dir [absolute_output_path] \ 24 | --seed [seed] \ 25 | --lora_path [our_lora] 26 | ``` 27 | 28 | ### Evaluation 29 | 30 | Refer to official evaluation code in [GEdit-Bench](https://github.com/stepfun-ai/Step1X-Edit/tree/main/GEdit-Bench). We **highly recommend** that set `temperature=0.0` [here](https://github.com/stepfun-ai/Step1X-Edit/blob/main/GEdit-Bench/viescore/mllm_tools/openai.py#L137) before evaluation. 31 | 32 | ## ImgEdit 33 | 34 | ### Sampling 35 | 36 | Qwen Image Edit [2509] Baseline: 37 | 38 | ``` 39 | python reproduction/sampling/sampling_qwen_imgedit.py \ 40 | --pretrained_name_or_path [pretrained_model] \ 41 | --input_path "[singleturn_json]" \ 42 | --output_dir "[absolute_output_path]" \ 43 | --root_path "[singleturn_dir]" \ 44 | --seed [seed] 45 | ``` 46 | 47 | UniWorld-Qwen-Image-Edit [2509]: 48 | 49 | ``` 50 | python reproduction/sampling/sampling_qwen_imgedit.py \ 51 | --pretrained_name_or_path [pretrained_model] \ 52 | --input_path "[singleturn_json]" \ 53 | --output_dir "[absolute_output_path]" \ 54 | --root_path "[singleturn_dir]" \ 55 | --seed [seed] \ 56 | --lora_path [our_lora] 57 | ``` 58 | 59 | 60 | ### Evaluation 61 | 62 | Refer to official evaluation code in [ImgEdit](https://github.com/PKU-YuanGroup/ImgEdit). We **highly recommend** that set `temperature=0.0` [here](https://github.com/PKU-YuanGroup/ImgEdit/blob/main/Benchmark/Basic/basic_bench.py#L41) before evaluation. 
63 | -------------------------------------------------------------------------------- /flow_grpo/assets/simple_ocr_animals_digit5.txt: -------------------------------------------------------------------------------- 1 | A cat holding a sign that says '12345' 2 | A dog holding a sign that says '23456' 3 | A horse holding a sign that says '34567' 4 | A monkey holding a sign that says '45678' 5 | A rabbit holding a sign that says '56789' 6 | A cat holding a sign that says '54321' 7 | A dog holding a sign that says '65432' 8 | A horse holding a sign that says '76543' 9 | A monkey holding a sign that says '87654' 10 | A rabbit holding a sign that says '98765' 11 | A cat holding a sign that says '12345' 12 | A dog holding a sign that says '23456' 13 | A horse holding a sign that says '34567' 14 | A monkey holding a sign that says '45678' 15 | A rabbit holding a sign that says '56789' 16 | A cat holding a sign that says '54321' 17 | A dog holding a sign that says '65432' 18 | A horse holding a sign that says '76543' 19 | A monkey holding a sign that says '87654' 20 | A rabbit holding a sign that says '98765' 21 | A cat holding a sign that says '12345' 22 | A dog holding a sign that says '23456' 23 | A horse holding a sign that says '34567' 24 | A monkey holding a sign that says '45678' 25 | A rabbit holding a sign that says '56789' 26 | A cat holding a sign that says '54321' 27 | A dog holding a sign that says '65432' 28 | A horse holding a sign that says '76543' 29 | A monkey holding a sign that says '87654' 30 | A rabbit holding a sign that says '98765' 31 | A cat holding a sign that says '12345' 32 | A dog holding a sign that says '23456' 33 | A horse holding a sign that says '34567' 34 | A monkey holding a sign that says '45678' 35 | A rabbit holding a sign that says '56789' 36 | A cat holding a sign that says '54321' 37 | A dog holding a sign that says '65432' 38 | A horse holding a sign that says '76543' 39 | A monkey holding a sign that says '87654' 40 | A rabbit holding a sign that says '98765' 41 | A cat holding a sign that says '12345' 42 | A dog holding a sign that says '23456' 43 | A horse holding a sign that says '34567' 44 | A monkey holding a sign that says '45678' 45 | A rabbit holding a sign that says '56789' 46 | A cat holding a sign that says '54321' 47 | A dog holding a sign that says '65432' 48 | A horse holding a sign that says '76543' 49 | A monkey holding a sign that says '87654' 50 | A rabbit holding a sign that says '98765' -------------------------------------------------------------------------------- /flow_grpo/aesthetic_scorer.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py 2 | 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | from transformers import CLIPModel, CLIPProcessor 7 | from flow_grpo.reward_ckpt_path import CKPT_PATH 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | self.layers = nn.Sequential( 16 | nn.Linear(768, 1024), 17 | nn.Dropout(0.2), 18 | nn.Linear(1024, 128), 19 | nn.Dropout(0.2), 20 | nn.Linear(128, 64), 21 | nn.Dropout(0.1), 22 | nn.Linear(64, 16), 23 | nn.Linear(16, 1), 24 | ) 25 | 26 | @torch.no_grad() 27 | def forward(self, embed): 28 | return self.layers(embed) 29 | 30 | 31 | class AestheticScorer(torch.nn.Module): 32 | def __init__(self, dtype, device): 33 | 
super().__init__() 34 | self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device) 35 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 36 | self.mlp = MLP().to(device) 37 | state_dict = torch.load(os.path.join(CKPT_PATH, "sac+logos+ava1-l14-linearMSE.pth"), map_location="cpu") 38 | self.mlp.load_state_dict(state_dict) 39 | self.dtype = dtype 40 | self.device = device 41 | self.eval() 42 | 43 | @torch.no_grad() 44 | def __call__(self, images): 45 | inputs = self.processor(images=images, return_tensors="pt") 46 | inputs = {k: v.to(self.dtype).to(self.device) for k, v in inputs.items()} 47 | embed = self.clip.get_image_features(**inputs) 48 | # normalize embedding 49 | embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) 50 | return self.mlp(embed).squeeze(1) 51 | 52 | 53 | # Usage example 54 | def main(): 55 | scorer = AestheticScorer(device="cuda", dtype=torch.float32) 56 | 57 | images = [ 58 | "test_cases/nasa.jpg", 59 | ] 60 | pil_images = np.stack([np.array(Image.open(img)) for img in images]) 61 | images = pil_images.transpose(0, 3, 1, 2) # NHWC -> NCHW 62 | images = torch.tensor(images, dtype=torch.uint8) 63 | print(scorer(images)) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /flow_grpo/pickscore_scorer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoProcessor, AutoModel 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | class PickScoreScorer(torch.nn.Module): 7 | def __init__(self, device="cuda", dtype=torch.float32): 8 | super().__init__() 9 | processor_path = "/mnt/data/checkpoints/laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 10 | model_path = "/mnt/data/checkpoints/yuvalkirstain/PickScore_v1" 11 | self.device = device 12 | self.dtype = dtype 13 | self.processor = AutoProcessor.from_pretrained(processor_path) 14 | self.model = AutoModel.from_pretrained(model_path).eval().to(device) 15 | self.model = self.model.to(dtype=dtype) 16 | 17 | @torch.no_grad() 18 | def __call__(self, prompt, images): 19 | # Preprocess images 20 | image_inputs = self.processor( 21 | images=images, 22 | padding=True, 23 | truncation=True, 24 | max_length=77, 25 | return_tensors="pt", 26 | ) 27 | image_inputs = {k: v.to(device=self.device) for k, v in image_inputs.items()} 28 | # Preprocess text 29 | text_inputs = self.processor( 30 | text=prompt, 31 | padding=True, 32 | truncation=True, 33 | max_length=77, 34 | return_tensors="pt", 35 | ) 36 | text_inputs = {k: v.to(device=self.device) for k, v in text_inputs.items()} 37 | 38 | # Get embeddings 39 | image_embs = self.model.get_image_features(**image_inputs) 40 | image_embs = image_embs / image_embs.norm(p=2, dim=-1, keepdim=True) 41 | 42 | text_embs = self.model.get_text_features(**text_inputs) 43 | text_embs = text_embs / text_embs.norm(p=2, dim=-1, keepdim=True) 44 | 45 | # Calculate scores 46 | logit_scale = self.model.logit_scale.exp() 47 | scores = logit_scale * (text_embs @ image_embs.T) 48 | scores = scores.diag() 49 | # norm到0-1 50 | scores = scores / 26 51 | return scores 52 | 53 | 54 | # Usage example 55 | def main(): 56 | scorer = PickScoreScorer(device="cuda", dtype=torch.float32) 57 | images = [ 58 | "test_cases/nasa.jpg", 59 | ] 60 | pil_images = [Image.open(img) for img in images] 61 | prompts = [ 62 | 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', 63 | ] 64 | print(scorer(prompts, 
pil_images)) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /flow_grpo/prompts.py: -------------------------------------------------------------------------------- 1 | from importlib import resources 2 | import os 3 | import functools 4 | import random 5 | 6 | # import inflect 7 | 8 | # IE = inflect.engine() 9 | IE = None 10 | ASSETS_PATH = resources.files("flow_grpo.assets") 11 | 12 | 13 | @functools.cache 14 | def _load_lines(path): 15 | """ 16 | Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the 17 | `flow_grpo/assets` directory for a file named `path`. 18 | """ 19 | if not os.path.exists(path): 20 | newpath = ASSETS_PATH.joinpath(path) 21 | if not os.path.exists(newpath): 22 | raise FileNotFoundError(f"Could not find {path} or flow_grpo.assets/{path}") 23 | path = newpath 24 | with open(path, "r") as f: 25 | return [line.strip() for line in f.readlines()] 26 | 27 | 28 | def from_file(path, low=None, high=None): 29 | prompts = _load_lines(path)[low:high] 30 | return random.choice(prompts), {} 31 | 32 | 33 | def imagenet_all(): 34 | return from_file("imagenet_classes.txt") 35 | 36 | 37 | def imagenet_animals(): 38 | return from_file("imagenet_classes.txt", 0, 398) 39 | 40 | 41 | def imagenet_dogs(): 42 | return from_file("imagenet_classes.txt", 151, 269) 43 | 44 | 45 | def simple_animals(): 46 | return from_file("simple_animals.txt") 47 | 48 | 49 | def general_ocr(): 50 | return from_file("general_ocr_train.txt") 51 | 52 | 53 | def simple_ocr_animals(): 54 | animals = _load_lines("simple_ocr_animals.txt") 55 | # random_number = random.randint(100, 999) 56 | # random_number = ''.join([str(random.randint(0, 9)) for _ in range(10)]) 57 | num = random.randint(1, 9) 58 | random_number = "".join([str(6) for _ in range(num)]) 59 | return f'A {random.choice(animals)} holding a sign that says "{random_number}"', {} 60 | 61 | 62 | def nouns_activities(nouns_file, activities_file): 63 | nouns = _load_lines(nouns_file) 64 | activities = _load_lines(activities_file) 65 | return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} 66 | 67 | 68 | def counting(nouns_file, low, high): 69 | nouns = _load_lines(nouns_file) 70 | number = IE.number_to_words(random.randint(low, high)) 71 | noun = random.choice(nouns) 72 | plural_noun = IE.plural(noun) 73 | prompt = f"{number} {plural_noun}" 74 | metadata = { 75 | "questions": [ 76 | f"How many {plural_noun} are there in this image?", 77 | f"What animal is in this image?", 78 | ], 79 | "answers": [ 80 | number, 81 | noun, 82 | ], 83 | } 84 | return prompt, metadata 85 | -------------------------------------------------------------------------------- /flow_grpo/clip_scorer.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/RE-N-Y/imscore/blob/main/src/imscore/preference/model.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as T 6 | from transformers import AutoImageProcessor, CLIPProcessor, CLIPModel 7 | import numpy as np 8 | from PIL import Image 9 | 10 | 11 | def get_size(size): 12 | if isinstance(size, int): 13 | return (size, size) 14 | elif "height" in size and "width" in size: 15 | return (size["height"], size["width"]) 16 | elif "shortest_edge" in size: 17 | return size["shortest_edge"] 18 | else: 19 | raise ValueError(f"Invalid size: {size}") 20 | 21 | 22 | def get_image_transform(processor: 
AutoImageProcessor): 23 | config = processor.to_dict() 24 | resize = T.Resize(get_size(config.get("size"))) if config.get("do_resize") else nn.Identity() 25 | crop = T.CenterCrop(get_size(config.get("crop_size"))) if config.get("do_center_crop") else nn.Identity() 26 | normalise = ( 27 | T.Normalize(mean=processor.image_mean, std=processor.image_std) if config.get("do_normalize") else nn.Identity() 28 | ) 29 | 30 | return T.Compose([resize, crop, normalise]) 31 | 32 | 33 | class ClipScorer(torch.nn.Module): 34 | def __init__(self, device): 35 | super().__init__() 36 | self.device = device 37 | self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device) 38 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") 39 | self.tform = get_image_transform(self.processor.image_processor) 40 | self.eval() 41 | 42 | def _process(self, pixels): 43 | dtype = pixels.dtype 44 | pixels = self.tform(pixels) 45 | pixels = pixels.to(dtype=dtype) 46 | 47 | return pixels 48 | 49 | @torch.no_grad() 50 | def __call__(self, pixels, prompts, return_img_embedding=False): 51 | texts = self.processor(text=prompts, padding="max_length", truncation=True, return_tensors="pt").to(self.device) 52 | pixels = self._process(pixels).to(self.device) 53 | outputs = self.model(pixel_values=pixels, **texts) 54 | if return_img_embedding: 55 | return outputs.logits_per_image.diagonal() / 100, outputs.image_embeds 56 | return outputs.logits_per_image.diagonal() / 100 57 | 58 | 59 | def main(): 60 | scorer = ClipScorer(device="cuda") 61 | 62 | images = ["test_cases/cat.jpg", "test_cases/cat.jpg"] 63 | pil_images = [Image.open(img) for img in images] 64 | prompts = ["an image of cat", "not an image of cat"] 65 | images = [np.array(img) for img in pil_images] 66 | images = np.array(images) 67 | images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW 68 | images = torch.tensor(images, dtype=torch.uint8) / 255.0 69 | print(scorer(images, prompts)) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /flow_grpo/ocr.py: -------------------------------------------------------------------------------- 1 | from paddleocr import PaddleOCR 2 | import torch 3 | import numpy as np 4 | from Levenshtein import distance 5 | from typing import List, Union 6 | from PIL import Image 7 | 8 | 9 | class OcrScorer: 10 | def __init__(self, use_gpu: bool = False): 11 | """ 12 | OCR reward calculator 13 | :param use_gpu: Whether to use GPU acceleration for PaddleOCR 14 | """ 15 | self.ocr = PaddleOCR( 16 | use_angle_cls=False, lang="en", use_gpu=use_gpu, show_log=False # Disable unnecessary log output 17 | ) 18 | 19 | @torch.no_grad() 20 | def __call__(self, images: Union[List[Image.Image], List[np.ndarray]], prompts: List[str]) -> torch.Tensor: 21 | """ 22 | Calculate OCR reward 23 | :param images: List of input images (PIL or numpy format) 24 | :param prompts: Corresponding target text list 25 | :return: Reward tensor (CPU) 26 | """ 27 | prompts = [prompt.split('"')[1] for prompt in prompts] 28 | rewards = [] 29 | # Ensure input lengths are consistent 30 | assert len(images) == len(prompts), "Images and prompts must have the same length" 31 | for img, prompt in zip(images, prompts): 32 | # Convert image format 33 | if isinstance(img, Image.Image): 34 | img = np.array(img) 35 | 36 | try: 37 | # OCR recognition 38 | result = self.ocr.ocr(img, cls=False) 39 | # Extract recognized text (handle possible multi-line results) 40 | 
recognized_text = ( 41 | "".join([res[1][0] if res[1][1] > 0 else "" for res in result[0]]) if result[0] else "" 42 | ) 43 | 44 | recognized_text = recognized_text.replace(" ", "").lower() 45 | prompt = prompt.replace(" ", "").lower() 46 | if prompt in recognized_text: 47 | dist = 0 48 | else: 49 | dist = distance(recognized_text, prompt) 50 | # Recognized many unrelated characters, only add one character penalty 51 | if dist > len(prompt): 52 | dist = len(prompt) 53 | 54 | except Exception as e: 55 | # Error handling (e.g., OCR parsing failure) 56 | print(f"OCR processing failed: {str(e)}") 57 | dist = len(prompt) # Maximum penalty 58 | reward = 1 - dist / (len(prompt)) 59 | rewards.append(reward) 60 | 61 | return rewards 62 | 63 | 64 | if __name__ == "__main__": 65 | example_image_path = "test_cases/hello world.jpg" 66 | example_image = Image.open(example_image_path) 67 | example_prompt = 'New York Skyline with "Hello World" written with fireworks on the sky' 68 | # Instantiate scorer 69 | scorer = OcrScorer(use_gpu=False) 70 | 71 | # Call scorer and print result 72 | reward = scorer([example_image], [example_prompt]) 73 | print(f"OCR Reward: {reward}") 74 | -------------------------------------------------------------------------------- /reproduction/sampling/sampling_kontext_imgedit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import torch 5 | import ray 6 | from diffusers import FluxKontextPipeline 7 | from diffusers.utils import load_image 8 | from tqdm import tqdm 9 | 10 | @ray.remote(num_gpus=1) 11 | def process_slice(slice_items, pretrained_name_or_path, lora_path, output_dir, root_path, seed): 12 | pipe = FluxKontextPipeline.from_pretrained( 13 | pretrained_name_or_path, 14 | torch_dtype=torch.bfloat16, 15 | ) 16 | pipe.to("cuda") 17 | 18 | if lora_path: 19 | print("Load lora", lora_path) 20 | pipe.load_lora_weights( 21 | lora_path, 22 | weight_name="adapter_model.safetensors", 23 | adapter_name="lora", 24 | ) 25 | pipe.set_adapters(["lora"], adapter_weights=[1]) 26 | 27 | for key, item in tqdm(slice_items): 28 | try: 29 | relative_image_path = item["id"] 30 | prompt = item["prompt"] 31 | absolute_image_path = os.path.normpath(os.path.join(root_path, relative_image_path)) 32 | output_filename = f"{key}.jpg" 33 | output_filepath = os.path.join(output_dir, output_filename) 34 | if os.path.exists(output_filepath): 35 | continue 36 | 37 | input_image = load_image(absolute_image_path) 38 | generator = torch.Generator(device="cuda").manual_seed(seed) 39 | output_image = pipe( 40 | num_inference_steps=28, 41 | image=input_image, 42 | prompt=prompt, 43 | generator=generator, 44 | ).images[0] 45 | 46 | output_image.save(output_filepath) 47 | except Exception as e: 48 | continue 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--input_path", type=str, required=True) 53 | parser.add_argument("--output_dir", type=str, default="output_images") 54 | parser.add_argument("--pretrained_name_or_path", type=str, required=True) 55 | parser.add_argument("--root_path", type=str, required=True) 56 | parser.add_argument("--lora_path", type=str, required=False, default=None) 57 | parser.add_argument("--seed", type=int, default=42) 58 | args = parser.parse_args() 59 | os.makedirs(args.output_dir, exist_ok=True) 60 | root_path = os.path.normpath(args.root_path) 61 | ray.init() 62 | gpu_count = int(ray.available_resources().get("GPU", 1)) 63 | 64 | def 
load_json(path): 65 | try: 66 | with open(path, "r", encoding="utf-8") as f: 67 | return json.load(f) 68 | except Exception as e: 69 | exit(1) 70 | 71 | ds = load_json(args.input_path) 72 | all_items = list(ds.items()) 73 | 74 | slices = [all_items[i::gpu_count] for i in range(gpu_count)] 75 | ray.get([ 76 | process_slice.remote( 77 | slices[i], 78 | args.pretrained_name_or_path, 79 | args.lora_path, 80 | args.output_dir, 81 | root_path, 82 | args.seed, 83 | ) for i in range(gpu_count) 84 | ]) 85 | -------------------------------------------------------------------------------- /config/qwen_image_edit_nft.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | 4 | base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py")) 5 | 6 | 7 | def get_config(name): 8 | return globals()[name]() 9 | 10 | def _get_config(base_model="qwen_image_edit", n_gpus=1, gradient_step_per_epoch=1, reward_fn={}, name=""): 11 | config = base.get_config() 12 | 13 | config.base_model = base_model 14 | config.transformer_path = None 15 | config.dataset = "../edit-r1-dataset" 16 | 17 | config.pretrained.model = "Qwen/Qwen-Image-Edit-2509" 18 | config.sample.num_steps = 6 19 | config.sample.eval_num_steps = 15 20 | config.sample.guidance_scale = 1.0 21 | config.resolution = 512 22 | config.train.beta = 0.0001 23 | config.sample.noise_level = 0.7 24 | bsz = 3 25 | 26 | config.sample.num_image_per_prompt = 12 27 | 28 | config.sample.ban_std_thres = 0.05 29 | config.sample.ban_mean_thres = 0.9 30 | config.sample.ban_prompt = False 31 | num_groups = 24 32 | 33 | while True: 34 | if bsz < 1: 35 | assert False, "Cannot find a proper batch size." 36 | if ( 37 | num_groups * config.sample.num_image_per_prompt % (n_gpus * bsz) == 0 38 | and bsz * n_gpus % config.sample.num_image_per_prompt == 0 39 | ): 40 | n_batch_per_epoch = num_groups * config.sample.num_image_per_prompt // (n_gpus * bsz) 41 | if n_batch_per_epoch % gradient_step_per_epoch == 0: 42 | config.sample.train_batch_size = bsz 43 | config.sample.num_batches_per_epoch = n_batch_per_epoch 44 | config.train.batch_size = config.sample.train_batch_size 45 | config.train.gradient_accumulation_steps = ( 46 | config.sample.num_batches_per_epoch // gradient_step_per_epoch 47 | ) 48 | break 49 | bsz -= 1 50 | 51 | # special design, the test set has a total of 1018/2212/2048 for ocr/geneval/pickscore, to make gpu_num*bs*n as close as possible to it, because when the number of samples cannot be divided evenly by the number of cards, multi-card will fill the last batch to ensure each card has the same number of samples, affecting gradient synchronization. 
52 | config.sample.test_batch_size = bsz 53 | if n_gpus > 32: 54 | config.sample.test_batch_size = config.sample.test_batch_size // 2 55 | 56 | config.prompt_fn = "geneval" 57 | 58 | config.run_name = f"nft_{base_model}_{name}" 59 | config.save_dir = f"logs/nft/{base_model}/{name}" 60 | config.reward_fn = reward_fn 61 | 62 | config.decay_type = 1 63 | config.beta = 1.0 64 | config.train.adv_mode = "all" 65 | 66 | # config.sample.guidance_scale = 1.0 67 | config.sample.deterministic = True 68 | config.sample.solver = "dpm2" 69 | return config 70 | 71 | def qwen_mllm_reward(): 72 | reward_fn = { 73 | "mllm_score_continue": 1.0, 74 | } 75 | config = _get_config( 76 | base_model="qwen_image_edit", 77 | n_gpus=48, 78 | gradient_step_per_epoch=1, 79 | reward_fn=reward_fn, 80 | name="mllm_score_continue", 81 | ) 82 | config.sample.ban_prompt = True 83 | config.sample.ban_std_thres = 0.05 84 | return config 85 | 86 | 87 | -------------------------------------------------------------------------------- /reproduction/sampling/sampling_qwen_imgedit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import torch 5 | import ray 6 | from diffusers import QwenImageEditPlusPipeline 7 | from diffusers.utils import load_image 8 | from tqdm import tqdm 9 | 10 | @ray.remote(num_gpus=1) 11 | def process_slice(slice_items, pretrained_name_or_path, lora_path, output_dir, root_path, seed): 12 | pipe = QwenImageEditPlusPipeline.from_pretrained( 13 | pretrained_name_or_path, 14 | torch_dtype=torch.bfloat16, 15 | ) 16 | pipe.to("cuda") 17 | 18 | if lora_path: 19 | print("Load lora", lora_path) 20 | pipe.load_lora_weights( 21 | lora_path, 22 | weight_name="adapter_model.safetensors", 23 | adapter_name="lora", 24 | ) 25 | pipe.set_adapters(["lora"], adapter_weights=[1]) 26 | 27 | for key, item in tqdm(slice_items): 28 | try: 29 | relative_image_path = item["id"] 30 | prompt = item["prompt"] 31 | absolute_image_path = os.path.normpath(os.path.join(root_path, relative_image_path)) 32 | output_filename = f"{key}.jpg" 33 | output_filepath = os.path.join(output_dir, output_filename) 34 | if os.path.exists(output_filepath): 35 | continue 36 | 37 | input_image = load_image(absolute_image_path) 38 | generator = torch.Generator(device="cuda").manual_seed(seed) 39 | output_image = pipe( 40 | num_inference_steps=28, 41 | image=input_image, 42 | prompt=prompt, 43 | negative_prompt=" ", 44 | true_cfg_scale=4.0, 45 | guidance_scale=1.0, 46 | generator=generator, 47 | ).images[0] 48 | 49 | output_image.save(output_filepath) 50 | except Exception as e: 51 | continue 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--input_path", type=str, required=True) 56 | parser.add_argument("--pretrained_name_or_path", type=str, required=True) 57 | parser.add_argument("--output_dir", type=str, default="output_images") 58 | parser.add_argument("--root_path", type=str, required=True) 59 | parser.add_argument("--lora_path", type=str, required=False, default=None) 60 | parser.add_argument("--seed", type=int, default=42) 61 | args = parser.parse_args() 62 | os.makedirs(args.output_dir, exist_ok=True) 63 | root_path = os.path.normpath(args.root_path) 64 | ray.init() 65 | gpu_count = int(ray.available_resources().get("GPU", 1)) 66 | 67 | def load_json(path): 68 | try: 69 | with open(path, "r", encoding="utf-8") as f: 70 | return json.load(f) 71 | except Exception as e: 72 | exit(1) 73 | 74 | ds = 
load_json(args.input_path) 75 | all_items = list(ds.items()) 76 | 77 | slices = [all_items[i::gpu_count] for i in range(gpu_count)] 78 | ray.get([ 79 | process_slice.remote( 80 | slices[i], 81 | args.pretrained_name_or_path, 82 | args.lora_path, 83 | args.output_dir, 84 | root_path, 85 | args.seed, 86 | ) for i in range(gpu_count) 87 | ]) 88 | -------------------------------------------------------------------------------- /flow_grpo/unifiedreward_scorer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from openai import AsyncOpenAI 3 | import base64 4 | from io import BytesIO 5 | import re 6 | from PIL import Image 7 | 8 | 9 | def pil_image_to_base64(image): 10 | buffered = BytesIO() 11 | image.save(buffered, format="PNG") 12 | encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") 13 | base64_qwen = f"data:image;base64,{encoded_image_text}" 14 | return base64_qwen 15 | 16 | 17 | def _extract_scores(text_outputs): 18 | scores = [] 19 | pattern = r"Final Score:\s*([1-5](?:\.\d+)?)" 20 | for text in text_outputs: 21 | match = re.search(pattern, text) 22 | if match: 23 | try: 24 | scores.append(float(match.group(1))) 25 | except ValueError: 26 | scores.append(0.0) 27 | else: 28 | scores.append(0.0) 29 | return scores 30 | 31 | 32 | client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo") 33 | 34 | 35 | async def evaluate_image(prompt, image): 36 | question = f"\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. 
Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after 'Final Score:'.\nYour task is provided as follows:\nText Caption: [{prompt}]" 37 | images_base64 = pil_image_to_base64(image) 38 | response = await client.chat.completions.create( 39 | model="UnifiedReward-7b-v1.5", 40 | messages=[ 41 | { 42 | "role": "user", 43 | "content": [ 44 | { 45 | "type": "image_url", 46 | "image_url": {"url": images_base64}, 47 | }, 48 | { 49 | "type": "text", 50 | "text": question, 51 | }, 52 | ], 53 | }, 54 | ], 55 | temperature=0, 56 | ) 57 | return response.choices[0].message.content 58 | 59 | 60 | async def evaluate_batch_image(images, prompts): 61 | tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)] 62 | results = await asyncio.gather(*tasks) 63 | return results 64 | 65 | 66 | # Usage example 67 | def main(): 68 | images = [ 69 | "test_cases/nasa.jpg", 70 | "test_cases/hello world.jpg", 71 | "test_cases/a photo of a brown giraffe and a white stop sign.png", 72 | ] 73 | pil_images = [Image.open(img) for img in images] 74 | prompts = [ 75 | 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', 76 | 'New York Skyline with "Hello World" written with fireworks on the sky', 77 | "a photo of a brown giraffe and a white stop sign", 78 | ] 79 | text_outputs = asyncio.run(evaluate_batch_image(pil_images, prompts)) 80 | print(text_outputs) 81 | score = _extract_scores(text_outputs) 82 | score = [sc / 5.0 for sc in score] 83 | print(score) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /config/kontext_nft.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | 4 | base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py")) 5 | 6 | 7 | def get_config(name): 8 | return globals()[name]() 9 | 10 | def _get_config(base_model="kontext", n_gpus=1, gradient_step_per_epoch=1, reward_fn={}, name=""): 11 | config = base.get_config() 12 | 13 | config.base_model = base_model 14 | config.dataset = "../edit-r1-dataset" 15 | 16 | config.pretrained.model = "black-forest-labs/FLUX.1-Kontext-dev" 17 | config.sample.num_steps = 6 18 | config.sample.eval_num_steps = 15 19 | config.sample.guidance_scale = 2.5 20 | config.resolution = 512 21 | config.train.beta = 0.0001 22 | config.sample.noise_level = 0.7 23 | bsz = 3 24 | 25 | config.sample.num_image_per_prompt = 12 26 | 27 | config.sample.ban_std_thres = 0.05 28 | config.sample.ban_prompt = False 29 | 30 | num_groups = 24 31 | 32 | while True: 33 | if bsz < 1: 34 | assert False, "Cannot find a proper batch size." 
35 | if ( 36 | num_groups * config.sample.num_image_per_prompt % (n_gpus * bsz) == 0 37 | and bsz * n_gpus % config.sample.num_image_per_prompt == 0 38 | ): 39 | n_batch_per_epoch = num_groups * config.sample.num_image_per_prompt // (n_gpus * bsz) 40 | if n_batch_per_epoch % gradient_step_per_epoch == 0: 41 | config.sample.train_batch_size = bsz 42 | config.sample.num_batches_per_epoch = n_batch_per_epoch 43 | config.train.batch_size = config.sample.train_batch_size 44 | config.train.gradient_accumulation_steps = ( 45 | config.sample.num_batches_per_epoch // gradient_step_per_epoch 46 | ) 47 | break 48 | bsz -= 1 49 | 50 | # special design, the test set has a total of 1018/2212/2048 for ocr/geneval/pickscore, to make gpu_num*bs*n as close as possible to it, because when the number of samples cannot be divided evenly by the number of cards, multi-card will fill the last batch to ensure each card has the same number of samples, affecting gradient synchronization. 51 | config.sample.test_batch_size = bsz 52 | if n_gpus > 32: 53 | config.sample.test_batch_size = config.sample.test_batch_size // 2 54 | 55 | config.prompt_fn = "geneval" 56 | 57 | config.run_name = f"nft_{base_model}_{name}" 58 | config.save_dir = f"logs/nft/{base_model}/{name}" 59 | config.reward_fn = reward_fn 60 | 61 | config.decay_type = 1 62 | config.beta = 1.0 63 | config.train.adv_mode = "all" 64 | 65 | # config.sample.guidance_scale = 1.0 66 | config.sample.deterministic = True 67 | config.sample.solver = "dpm2" 68 | return config 69 | 70 | def kontext_mllm_reward(): 71 | reward_fn = { 72 | "mllm_score_continue": 1.0, 73 | } 74 | config = _get_config( 75 | base_model="kontext", 76 | n_gpus=24, 77 | gradient_step_per_epoch=1, 78 | reward_fn=reward_fn, 79 | name="mllm_score_continue", 80 | ) 81 | return config 82 | 83 | def kontext_mllm_reward_ban_prompt(): 84 | reward_fn = { 85 | "mllm_score_continue": 1.0, 86 | } 87 | config = _get_config( 88 | base_model="kontext", 89 | n_gpus=24, 90 | gradient_step_per_epoch=1, 91 | reward_fn=reward_fn, 92 | name="mllm_score_continue_ban_prompt", 93 | ) 94 | config.sample.ban_prompt = True 95 | config.sample.ban_std_thres = 0.05 96 | return config -------------------------------------------------------------------------------- /reproduction/sampling/sampling_kontext_gedit.py: -------------------------------------------------------------------------------- 1 | from datasets import load_from_disk 2 | from diffusers import FluxKontextPipeline 3 | import ray 4 | import torch 5 | from typing import Optional 6 | import argparse 7 | import os 8 | from tqdm import tqdm 9 | 10 | def load_pipeline(pretrained_name_or_path: str, lora_path: Optional[str] = None): 11 | pipeline = FluxKontextPipeline.from_pretrained( 12 | pretrained_name_or_path, 13 | torch_dtype=torch.bfloat16, 14 | ) 15 | pipeline.to("cuda") 16 | 17 | if lora_path: 18 | pipeline.load_lora_weights( 19 | lora_path, 20 | weight_name="adapter_model.safetensors", 21 | adapter_name="lora", 22 | ) 23 | pipeline.set_adapters(["lora"], adapter_weights=[1]) 24 | print("Lora path provided") 25 | else: 26 | print("No lora path provided, using origin model") 27 | return pipeline 28 | 29 | 30 | @ray.remote(num_gpus=1) 31 | def sample(sliced_data, pretrained_name_or_path, lora_path, output_dir, seed): 32 | pipeline = load_pipeline(pretrained_name_or_path, lora_path) 33 | for item in tqdm(sliced_data): 34 | if item["instruction_language"] != "en": 35 | continue 36 | key = item["key"] 37 | prompt =
item["instruction"] 38 | task_type = item["task_type"] 39 | input_image = item["input_image_raw"].convert("RGB") 40 | image_output_dir = os.path.join(output_dir, task_type) 41 | 42 | if os.path.exists(os.path.join(image_output_dir, f"{key}.png")): 43 | continue 44 | 45 | os.makedirs(image_output_dir, exist_ok=True) 46 | 47 | width, height = input_image.size 48 | 49 | generator = torch.Generator(device="cuda").manual_seed(seed) 50 | 51 | output_image = pipeline( 52 | prompt=prompt, 53 | guidance_scale=3.5, 54 | image=input_image, 55 | width=width, 56 | height=height, 57 | num_inference_steps=28, 58 | generator=generator, 59 | ).images[0] 60 | 61 | output_image.save(f"{image_output_dir}/{key}.png") 62 | 63 | 64 | def main(): 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--pretrained_name_or_path", type=str, default="/mnt/data/checkpoints/black-forest-labs/FLUX.1-Kontext-dev") 67 | parser.add_argument("--gedit_bench_path", type=str, default="/mnt/data/datasets/GEdit-Bench") 68 | parser.add_argument("--lora_path", type=str, default=None) 69 | parser.add_argument("--seed", type=int, default=42) 70 | parser.add_argument( 71 | "--output_dir", 72 | type=str, 73 | default="results/no_name", 74 | help="path to save the output images", 75 | ) 76 | args = parser.parse_args() 77 | pretrained_name_or_path = args.pretrained_name_or_path 78 | lora_path = args.lora_path 79 | seed = args.seed 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed(seed) 82 | ray.init() 83 | os.makedirs(args.output_dir, exist_ok=True) 84 | dataset = load_from_disk(args.gedit_bench_path) 85 | gpu_count = int(ray.available_resources().get("GPU", 1)) 86 | print(f"GPU count: {gpu_count}") 87 | ray.get( 88 | [ 89 | sample.remote( 90 | dataset.select(range(i, len(dataset), gpu_count)), 91 | pretrained_name_or_path, 92 | lora_path, 93 | args.output_dir, 94 | seed, 95 | ) 96 | for i in range(gpu_count) 97 | ] 98 | ) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /reproduction/sampling/sampling_qwen_gedit.py: -------------------------------------------------------------------------------- 1 | from datasets import load_from_disk 2 | from diffusers import QwenImageEditPlusPipeline, QwenImageEditPipeline, QwenImageTransformer2DModel 3 | import ray 4 | import torch 5 | from typing import Optional 6 | import argparse 7 | import os 8 | from tqdm import tqdm 9 | 10 | def load_pipeline(pretrained_name_or_path: str, lora_path: Optional[str] = None): 11 | pipeline = QwenImageEditPlusPipeline.from_pretrained( 12 | pretrained_name_or_path, 13 | torch_dtype=torch.bfloat16, 14 | ) 15 | pipeline.to("cuda") 16 | 17 | if lora_path: 18 | pipeline.load_lora_weights( 19 | lora_path, 20 | weight_name="adapter_model.safetensors", 21 | adapter_name="lora", 22 | ) 23 | pipeline.set_adapters(["lora"], adapter_weights=[1]) 24 | print("Lora path provided") 25 | else: 26 | print("No lora path provided, using origin model") 27 | return pipeline 28 | 29 | 30 | @ray.remote(num_gpus=1) 31 | def sample(sliced_data, pretrained_name_or_path, lora_path, output_dir, seed): 32 | pipeline = load_pipeline(pretrained_name_or_path, lora_path) 33 | for item in tqdm(sliced_data): 34 | if item["instruction_language"] != "en": 35 | continue 36 | key = item["key"] 37 | prompt = item["instruction"] 38 | task_type = item["task_type"] 39 | input_image = item["input_image_raw"].convert("RGB") 40 | image_output_dir = os.path.join(output_dir, task_type) 41 | 42 | if 
os.path.exists(os.path.join(image_output_dir, f"{key}.png")): 43 | continue 44 | 45 | os.makedirs(image_output_dir, exist_ok=True) 46 | 47 | width, height = input_image.size 48 | 49 | generator = torch.Generator(device="cuda").manual_seed(seed) 50 | 51 | output_image = pipeline( 52 | prompt=prompt, 53 | true_cfg_scale=4.0, 54 | guidance_scale=1.0, 55 | image=input_image, 56 | negative_prompt=" ", 57 | num_inference_steps=28, 58 | generator=generator, 59 | ).images[0] 60 | 61 | output_image.save(f"{image_output_dir}/{key}.png") 62 | 63 | 64 | def main(): 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--pretrained_name_or_path", type=str, default="/mnt/data/checkpoints/") 67 | parser.add_argument("--gedit_bench_path", type=str, default="/mnt/data/datasets/GEdit-Bench") 68 | parser.add_argument("--lora_path", type=str, default=None) 69 | parser.add_argument("--seed", type=int, default=42) 70 | parser.add_argument( 71 | "--output_dir", 72 | type=str, 73 | default="results/no_name", 74 | help="path to save the output images", 75 | ) 76 | args = parser.parse_args() 77 | pretrained_name_or_path = args.pretrained_name_or_path 78 | lora_path = args.lora_path 79 | seed = args.seed 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed(seed) 82 | ray.init() 83 | os.makedirs(args.output_dir, exist_ok=True) 84 | dataset = load_from_disk(args.gedit_bench_path) 85 | gpu_count = int(ray.available_resources().get("GPU", 1)) 86 | print(f"GPU count: {gpu_count}") 87 | ray.get( 88 | [ 89 | sample.remote( 90 | dataset.select(range(i, len(dataset), gpu_count)), 91 | pretrained_name_or_path, 92 | lora_path, 93 | args.output_dir, 94 | seed, 95 | ) 96 | for i in range(gpu_count) 97 | ] 98 | ) 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /flow_grpo/ema.py: -------------------------------------------------------------------------------- 1 | # Copied from another repo, but I can't remember exactly which one. 
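# Illustrative usage sketch (annotation only; `model`, `optimizer` and `loader` are assumed
# to exist in the caller). The parameters are materialized into a list once, because the
# methods below may iterate over them more than once:
#
#   params = list(model.parameters())
#   ema = EMAModuleWrapper(params, decay=0.9999, update_step_interval=1, device="cpu")
#   for step, batch in enumerate(loader):
#       ...                                        # forward / backward / optimizer.step()
#       ema.step(params, optimization_step=step)
#   ema.copy_ema_to(params, store_temp=True)       # swap EMA weights in for evaluation
#   ...                                            # evaluate
#   ema.copy_temp_to(params)                       # restore the training weights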
2 | 3 | from collections.abc import Iterable 4 | 5 | import torch 6 | 7 | 8 | class EMAModuleWrapper: 9 | def __init__( 10 | self, 11 | parameters: Iterable[torch.nn.Parameter], 12 | decay: float = 0.9999, 13 | update_step_interval: int = 1, 14 | device: torch.device | None = None, 15 | ): 16 | parameters = list(parameters) 17 | self.ema_parameters = [p.clone().detach().to(device) for p in parameters] 18 | 19 | self.temp_stored_parameters = None 20 | 21 | self.decay = decay 22 | self.update_step_interval = update_step_interval 23 | self.device = device 24 | 25 | def get_current_decay(self, optimization_step) -> float: 26 | return min((1 + optimization_step) / (10 + optimization_step), self.decay) 27 | 28 | @torch.no_grad() 29 | def step(self, parameters: Iterable[torch.nn.Parameter], optimization_step): 30 | parameters = list(parameters) 31 | 32 | one_minus_decay = 1 - self.get_current_decay(optimization_step) 33 | 34 | if (optimization_step + 1) % self.update_step_interval == 0: 35 | for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True): 36 | if parameter.requires_grad: 37 | if ema_parameter.device == parameter.device: 38 | ema_parameter.add_(one_minus_decay * (parameter - ema_parameter)) 39 | else: 40 | # in place calculations to save memory 41 | parameter_copy = parameter.detach().to(ema_parameter.device) 42 | parameter_copy.sub_(ema_parameter) 43 | parameter_copy.mul_(one_minus_decay) 44 | ema_parameter.add_(parameter_copy) 45 | del parameter_copy 46 | 47 | def to(self, device: torch.device = None, dtype: torch.dtype = None) -> None: 48 | self.device = device 49 | self.ema_parameters = [ 50 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 51 | for p in self.ema_parameters 52 | ] 53 | 54 | @torch.no_grad() 55 | def sync_with_model(self, parameters: Iterable[torch.nn.Parameter]) -> None: 56 | """ 57 | Force the EMA parameters to be a direct copy of the given model parameters. 58 | This is used to create a snapshot for the rollout policy. 
59 | """ 60 | parameters = list(parameters) 61 | for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True): 62 | ema_parameter.data.copy_(parameter.detach().data) 63 | 64 | def copy_ema_to(self, parameters: Iterable[torch.nn.Parameter], store_temp: bool = True, grad=False) -> None: 65 | if store_temp: 66 | if grad: 67 | self.temp_stored_parameters = [parameter.data.clone() for parameter in parameters] 68 | else: 69 | self.temp_stored_parameters = [parameter.detach().cpu() for parameter in parameters] 70 | 71 | parameters = list(parameters) 72 | for ema_parameter, parameter in zip(self.ema_parameters, parameters, strict=True): 73 | parameter.data.copy_(ema_parameter.to(parameter.device).data) 74 | 75 | def copy_temp_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 76 | for temp_parameter, parameter in zip(self.temp_stored_parameters, parameters, strict=True): 77 | # Ensure the temp parameter is on the right device 78 | parameter.data.copy_(temp_parameter.to(parameter.device)) 79 | 80 | self.temp_stored_parameters = None 81 | 82 | def load_state_dict(self, state_dict: dict) -> None: 83 | self.decay = self.decay if self.decay else state_dict.get("decay", self.decay) 84 | self.ema_parameters = state_dict.get("ema_parameters") 85 | self.to(self.device) 86 | 87 | def state_dict(self) -> dict: 88 | return { 89 | "decay": self.decay, 90 | "ema_parameters": self.ema_parameters, 91 | } 92 | -------------------------------------------------------------------------------- /flow_grpo/stat_tracking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | 4 | 5 | class PerPromptStatTracker: 6 | def __init__(self, global_std=False, ban_std_thres=0.05, ban_mean_thres=0.9): 7 | self.global_std = global_std 8 | self.stats = {} 9 | self.history_prompts = set() 10 | 11 | # Banned prompt 12 | self.ban_std_thres = ban_std_thres 13 | self.ban_mean_thres = ban_mean_thres 14 | self.banned_prompts = set() 15 | 16 | # exp reward is for rwr 17 | def update(self, prompts, rewards, exp=False): 18 | prompts = np.array(prompts) 19 | rewards = np.array(rewards, dtype=np.float64) 20 | unique = np.unique(prompts) 21 | advantages = np.empty_like(rewards) * 0.0 22 | stds = np.empty_like(rewards) * 0.0 23 | means = np.empty_like(rewards) * 0.0 24 | 25 | for prompt in unique: 26 | prompt_rewards = rewards[prompts == prompt] 27 | if prompt not in self.stats: 28 | self.stats[prompt] = [] 29 | self.stats[prompt].extend(prompt_rewards) 30 | self.history_prompts.add( 31 | hash(prompt) 32 | ) # Add hash of prompt to history_prompts 33 | for prompt in unique: 34 | self.stats[prompt] = np.stack(self.stats[prompt]) 35 | prompt_rewards = rewards[ 36 | prompts == prompt 37 | ] # Fix: Recalculate prompt_rewards for each prompt 38 | mean = np.mean(self.stats[prompt], axis=0, keepdims=True) 39 | 40 | if self.global_std: 41 | std = ( 42 | np.std(rewards, axis=0, keepdims=True) + 1e-4 43 | ) # Use global std of all rewards 44 | else: 45 | std = np.std(self.stats[prompt], axis=0, keepdims=True) + 1e-4 46 | 47 | prompt_std = np.std(self.stats[prompt], axis=0, keepdims=True).mean() 48 | prompt_mean = np.mean(self.stats[prompt], axis=0, keepdims=True).mean() 49 | 50 | if prompt_std < self.ban_std_thres and prompt_mean > self.ban_mean_thres: 51 | self.banned_prompts.add(prompt) 52 | 53 | advantages[prompts == prompt] = (prompt_rewards - mean) / std 54 | stds[prompts == prompt] = prompt_std 55 | means[prompts == prompt] = mean 
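# Annotation (tiny worked example): with global_std=False, a prompt whose group rewards
# are [1, 3, 6] has mean 10/3 and population std of about 2.06, so its group-normalized
# advantages come out to roughly [-1.14, -0.16, 1.30].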
56 | 57 | return advantages, stds, means 58 | 59 | def get_stats(self): 60 | avg_group_size = ( 61 | sum(len(v) for v in self.stats.values()) / len(self.stats) 62 | if self.stats 63 | else 0 64 | ) 65 | history_prompts = len(self.history_prompts) 66 | return avg_group_size, history_prompts 67 | 68 | def clear(self): 69 | self.stats = {} 70 | 71 | def get_mean_of_top_rewards(self, top_percentage): 72 | if not self.stats: 73 | return 0.0 74 | 75 | assert 0 <= top_percentage <= 100 76 | 77 | per_prompt_top_means = [] 78 | for prompt_rewards in self.stats.values(): 79 | if isinstance(prompt_rewards, list): 80 | rewards = np.array(prompt_rewards) 81 | else: 82 | rewards = prompt_rewards 83 | 84 | if rewards.size == 0: 85 | continue 86 | 87 | if top_percentage == 100: 88 | per_prompt_top_means.append(np.mean(rewards)) 89 | continue 90 | 91 | lower_bound_percentile = 100 - top_percentage 92 | threshold = np.percentile(rewards, lower_bound_percentile) 93 | 94 | top_rewards = rewards[rewards >= threshold] 95 | 96 | if top_rewards.size > 0: 97 | per_prompt_top_means.append(np.mean(top_rewards)) 98 | 99 | if not per_prompt_top_means: 100 | return 0.0 101 | 102 | return np.mean(per_prompt_top_means) 103 | 104 | 105 | def main(): 106 | tracker = PerPromptStatTracker() 107 | prompts = ["a", "b", "a", "c", "b", "a"] 108 | rewards = [1, 2, 3, 4, 5, 6] 109 | advantages = tracker.update(prompts, rewards) 110 | print("Advantages:", advantages) 111 | avg_group_size, history_prompts = tracker.get_stats() 112 | print("Average Group Size:", avg_group_size) 113 | print("History Prompts:", history_prompts) 114 | tracker.clear() 115 | print("Stats after clear:", tracker.stats) 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /config/base.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_config(): 5 | config = ml_collections.ConfigDict() 6 | 7 | ###### General ###### 8 | # run name for wandb logging and checkpoint saving -- if not provided, will be auto-generated based on the datetime. 9 | config.run_name = "" 10 | config.debug = False 11 | 12 | # random seed for reproducibility. 13 | config.seed = 42 14 | # top-level logging directory for checkpoint saving. 15 | config.logdir = "logs" 16 | # number of epochs to train for. each epoch is one round of sampling from the model followed by training on those 17 | # samples. 18 | config.num_epochs = 100000 19 | # number of epochs between saving model checkpoints. 20 | config.save_freq = 30 21 | config.eval_freq = 10 22 | # mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly. 23 | config.mixed_precision = "bf16" 24 | # allow tf32 on Ampere GPUs, which can speed up training. 25 | config.allow_tf32 = True 26 | # resume training from a checkpoint. either an exact checkpoint directory (e.g. checkpoint_50), or a directory 27 | # containing checkpoints, in which case the latest one will be used. `config.use_lora` must be set to the same value 28 | # as the run that generated the saved checkpoint. 29 | config.resume_from = "" 30 | # whether or not to use LoRA. 31 | config.use_lora = True 32 | config.dataset = "" 33 | config.resolution = 768 34 | 35 | ###### Pretrained Model ###### 36 | config.pretrained = pretrained = ml_collections.ConfigDict() 37 | # base model to load. 
either a path to a local directory, or a model name from the HuggingFace model hub. 38 | pretrained.model = "" 39 | # revision of the model to load. 40 | pretrained.revision = "" 41 | 42 | ###### Sampling ###### 43 | config.sample = sample = ml_collections.ConfigDict() 44 | # number of sampler inference steps. 45 | sample.num_steps = 40 46 | sample.eval_num_steps = 40 47 | # classifier-free guidance weight. 1.0 is no guidance. 48 | sample.guidance_scale = 4.5 49 | # batch size (per GPU!) to use for sampling. 50 | sample.train_batch_size = 1 51 | sample.num_image_per_prompt = 1 52 | sample.test_batch_size = 1 53 | # number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch * 54 | # batch_size * num_gpus`. 55 | sample.num_batches_per_epoch = 2 56 | # Whether use all samples in a batch to compute std 57 | sample.global_std = True 58 | # noise level 59 | sample.noise_level = 1.0 60 | 61 | ###### Training ###### 62 | config.train = train = ml_collections.ConfigDict() 63 | # batch size (per GPU!) to use for training. 64 | train.batch_size = 1 65 | # learning rate. 66 | train.learning_rate = 3e-4 67 | # Adam beta1. 68 | train.adam_beta1 = 0.9 69 | # Adam beta2. 70 | train.adam_beta2 = 0.999 71 | # Adam weight decay. 72 | train.adam_weight_decay = 1e-4 73 | # Adam epsilon. 74 | train.adam_epsilon = 1e-8 75 | # number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus * 76 | # gradient_accumulation_steps`. 77 | train.gradient_accumulation_steps = 1 78 | # maximum gradient norm for gradient clipping. 79 | train.max_grad_norm = 1.0 80 | # number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one 81 | # outer epoch's round of sampling. 82 | train.num_inner_epochs = 1 83 | # clip advantages to the range [-adv_clip_max, adv_clip_max]. 84 | train.adv_clip_max = 5 85 | # the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the 86 | # timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates. 87 | train.timestep_fraction = 0.99 88 | # kl ratio 89 | train.beta = 0.0001 90 | # pretrained lora path 91 | train.lora_path = None 92 | train.ema = True 93 | 94 | ###### Prompt Function ###### 95 | # prompt function to use. see `prompts.py` for available prompt functions. 96 | config.prompt_fn = "" 97 | # kwargs to pass to the prompt function. 98 | config.prompt_fn_kwargs = {} 99 | 100 | ###### Reward Function ###### 101 | # reward function to use. see `rewards.py` for available reward functions. 102 | config.reward_fn = ml_collections.ConfigDict() 103 | config.save_dir = "" 104 | 105 | ###### Per-Prompt Stat Tracking ###### 106 | config.per_prompt_stat_tracking = True 107 | 108 | return config 109 | -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/train_dreambooth_lora_flux.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import torch 17 | 18 | 19 | def _encode_prompt_with_t5( 20 | text_encoder, 21 | tokenizer, 22 | max_sequence_length=512, 23 | prompt=None, 24 | num_images_per_prompt=1, 25 | device=None, 26 | text_input_ids=None, 27 | ): 28 | prompt = [prompt] if isinstance(prompt, str) else prompt 29 | batch_size = len(prompt) 30 | 31 | if tokenizer is not None: 32 | text_inputs = tokenizer( 33 | prompt, 34 | padding="max_length", 35 | max_length=max_sequence_length, 36 | truncation=True, 37 | return_length=False, 38 | return_overflowing_tokens=False, 39 | return_tensors="pt", 40 | ) 41 | text_input_ids = text_inputs.input_ids 42 | else: 43 | if text_input_ids is None: 44 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 45 | 46 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 47 | 48 | if hasattr(text_encoder, "module"): 49 | dtype = text_encoder.module.dtype 50 | else: 51 | dtype = text_encoder.dtype 52 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 53 | 54 | _, seq_len, _ = prompt_embeds.shape 55 | 56 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 57 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 58 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 59 | 60 | return prompt_embeds 61 | 62 | 63 | def _encode_prompt_with_clip( 64 | text_encoder, 65 | tokenizer, 66 | prompt: str, 67 | device=None, 68 | text_input_ids=None, 69 | num_images_per_prompt: int = 1, 70 | ): 71 | prompt = [prompt] if isinstance(prompt, str) else prompt 72 | batch_size = len(prompt) 73 | 74 | if tokenizer is not None: 75 | text_inputs = tokenizer( 76 | prompt, 77 | padding="max_length", 78 | max_length=77, 79 | truncation=True, 80 | return_overflowing_tokens=False, 81 | return_length=False, 82 | return_tensors="pt", 83 | ) 84 | 85 | text_input_ids = text_inputs.input_ids 86 | else: 87 | if text_input_ids is None: 88 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 89 | 90 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) 91 | 92 | if hasattr(text_encoder, "module"): 93 | dtype = text_encoder.module.dtype 94 | else: 95 | dtype = text_encoder.dtype 96 | # Use pooled output of CLIPTextModel 97 | prompt_embeds = prompt_embeds.pooler_output 98 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 99 | 100 | # duplicate text embeddings for each generation per prompt, using mps friendly method 101 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 102 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 103 | 104 | return prompt_embeds 105 | 106 | 107 | def encode_prompt( 108 | text_encoders, 109 | tokenizers, 110 | prompt: str, 111 | max_sequence_length, 112 | device=None, 113 | num_images_per_prompt: int = 1, 114 | text_input_ids_list=None, 115 | ): 116 | prompt = [prompt] if isinstance(prompt, str) else prompt 117 | 118 | if hasattr(text_encoders[0], "module"): 119 | dtype = text_encoders[0].module.dtype 120 | else: 
121 | dtype = text_encoders[0].dtype 122 | 123 | pooled_prompt_embeds = _encode_prompt_with_clip( 124 | text_encoder=text_encoders[0], 125 | tokenizer=tokenizers[0], 126 | prompt=prompt, 127 | device=device if device is not None else text_encoders[0].device, 128 | num_images_per_prompt=num_images_per_prompt, 129 | text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, 130 | ) 131 | 132 | prompt_embeds = _encode_prompt_with_t5( 133 | text_encoder=text_encoders[1], 134 | tokenizer=tokenizers[1], 135 | max_sequence_length=max_sequence_length, 136 | prompt=prompt, 137 | num_images_per_prompt=num_images_per_prompt, 138 | device=device if device is not None else text_encoders[1].device, 139 | text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, 140 | ) 141 | 142 | text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) 143 | 144 | return prompt_embeds, pooled_prompt_embeds, text_ids 145 | -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/train_dreambooth_lora_sd3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import torch 17 | 18 | 19 | def _encode_prompt_with_t5( 20 | text_encoder, 21 | tokenizer, 22 | max_sequence_length, 23 | prompt=None, 24 | num_images_per_prompt=1, 25 | device=None, 26 | text_input_ids=None, 27 | ): 28 | prompt = [prompt] if isinstance(prompt, str) else prompt 29 | batch_size = len(prompt) 30 | 31 | if tokenizer is not None: 32 | text_inputs = tokenizer( 33 | prompt, 34 | padding="max_length", 35 | max_length=max_sequence_length, 36 | truncation=True, 37 | add_special_tokens=True, 38 | return_tensors="pt", 39 | ) 40 | text_input_ids = text_inputs.input_ids 41 | else: 42 | if text_input_ids is None: 43 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 44 | 45 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 46 | 47 | dtype = text_encoder.dtype 48 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 49 | 50 | _, seq_len, _ = prompt_embeds.shape 51 | 52 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 53 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 54 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 55 | 56 | return prompt_embeds 57 | 58 | 59 | def _encode_prompt_with_clip( 60 | text_encoder, 61 | tokenizer, 62 | prompt: str, 63 | device=None, 64 | text_input_ids=None, 65 | num_images_per_prompt: int = 1, 66 | ): 67 | prompt = [prompt] if isinstance(prompt, str) else prompt 68 | batch_size = len(prompt) 69 | 70 | if tokenizer is not None: 71 | text_inputs = tokenizer( 72 | prompt, 73 | padding="max_length", 74 | max_length=77, 75 | truncation=True, 76 | return_tensors="pt", 77 
| ) 78 | 79 | text_input_ids = text_inputs.input_ids 80 | else: 81 | if text_input_ids is None: 82 | raise ValueError("text_input_ids must be provided when the tokenizer is not specified") 83 | 84 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 85 | 86 | pooled_prompt_embeds = prompt_embeds[0] 87 | prompt_embeds = prompt_embeds.hidden_states[-2] 88 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 89 | 90 | _, seq_len, _ = prompt_embeds.shape 91 | # duplicate text embeddings for each generation per prompt, using mps friendly method 92 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 93 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 94 | 95 | return prompt_embeds, pooled_prompt_embeds 96 | 97 | 98 | def encode_prompt( 99 | text_encoders, 100 | tokenizers, 101 | prompt: str, 102 | max_sequence_length, 103 | device=None, 104 | num_images_per_prompt: int = 1, 105 | text_input_ids_list=None, 106 | ): 107 | prompt = [prompt] if isinstance(prompt, str) else prompt 108 | 109 | clip_tokenizers = tokenizers[:2] 110 | clip_text_encoders = text_encoders[:2] 111 | 112 | clip_prompt_embeds_list = [] 113 | clip_pooled_prompt_embeds_list = [] 114 | for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)): 115 | prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( 116 | text_encoder=text_encoder, 117 | tokenizer=tokenizer, 118 | prompt=prompt, 119 | device=device if device is not None else text_encoder.device, 120 | num_images_per_prompt=num_images_per_prompt, 121 | text_input_ids=text_input_ids_list[i] if text_input_ids_list else None, 122 | ) 123 | clip_prompt_embeds_list.append(prompt_embeds) 124 | clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) 125 | 126 | clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) 127 | pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) 128 | 129 | t5_prompt_embed = _encode_prompt_with_t5( 130 | text_encoders[-1], 131 | tokenizers[-1], 132 | max_sequence_length, 133 | prompt=prompt, 134 | num_images_per_prompt=num_images_per_prompt, 135 | text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None, 136 | device=device if device is not None else text_encoders[-1].device, 137 | ) 138 | 139 | clip_prompt_embeds = torch.nn.functional.pad( 140 | clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) 141 | ) 142 | prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) 143 | 144 | return prompt_embeds, pooled_prompt_embeds 145 | -------------------------------------------------------------------------------- /flow_grpo/hpsv2_scorer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor 5 | import torchvision.transforms.functional as F 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from hpsv2.src.open_clip import create_model, get_tokenizer 10 | from flow_grpo.reward_ckpt_path import CKPT_PATH 11 | 12 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 13 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 14 | 15 | 16 | class ResizeMaxSize(nn.Module): 17 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): 18 | super().__init__() 19 | if not isinstance(max_size, int): 20 | raise 
TypeError(f"Size should be int. Got {type(max_size)}") 21 | self.max_size = max_size 22 | self.interpolation = interpolation 23 | self.fn = min if fn == "min" else min # Note: both 'min' and 'max' map to min 24 | self.fill = fill 25 | 26 | def forward(self, img): 27 | if isinstance(img, torch.Tensor): 28 | # Assuming NCHW, get H and W from the last two dimensions 29 | height, width = img.shape[-2:] 30 | else: 31 | width, height = img.size 32 | scale = self.max_size / float(max(height, width)) 33 | if scale != 1.0: 34 | new_size = tuple(round(dim * scale) for dim in (height, width)) 35 | img = F.resize(img, new_size, self.interpolation) 36 | pad_h = self.max_size - new_size[0] 37 | pad_w = self.max_size - new_size[1] 38 | img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) 39 | return img 40 | 41 | 42 | class MaskAwareNormalize(nn.Module): 43 | def __init__(self, mean, std): 44 | super().__init__() 45 | self.normalize = Normalize(mean=mean, std=std) 46 | 47 | def forward(self, tensor): 48 | # Assuming NCHW, check the channel dimension 49 | if tensor.shape[1] == 4: 50 | # Process each image in the batch 51 | normalized_parts = [] 52 | for i in range(tensor.shape[0]): 53 | img_slice = tensor[i] 54 | normalized_rgb = self.normalize(img_slice[:3]) 55 | alpha_channel = img_slice[3:] 56 | normalized_parts.append(torch.cat([normalized_rgb, alpha_channel], dim=0)) 57 | return torch.stack(normalized_parts, dim=0) 58 | else: 59 | return self.normalize(tensor) 60 | 61 | 62 | def image_transform_tensor( 63 | image_size: int, 64 | mean: tuple = None, 65 | std: tuple = None, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | std = std or OPENAI_DATASET_STD 70 | 71 | if not isinstance(mean, (list, tuple)): 72 | mean = (mean,) * 3 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | normalize = MaskAwareNormalize(mean=mean, std=std) 77 | 78 | transforms = [ 79 | ResizeMaxSize(image_size, fill=fill_color), 80 | normalize, 81 | ] 82 | return Compose(transforms) 83 | 84 | 85 | class HPSv2Scorer(nn.Module): 86 | def __init__(self, dtype, device): 87 | super().__init__() 88 | self.dtype = dtype 89 | self.device = device 90 | model = create_model( 91 | "ViT-H-14", 92 | os.path.join(CKPT_PATH, "open_clip_pytorch_model.bin"), 93 | precision="amp", 94 | device=device, 95 | jit=False, 96 | force_quick_gelu=False, 97 | force_custom_text=False, 98 | force_patch_dropout=False, 99 | force_image_size=None, 100 | pretrained_image=False, 101 | output_dict=True, 102 | ) 103 | 104 | image_mean = getattr(model.visual, "image_mean", None) 105 | image_std = getattr(model.visual, "image_std", None) 106 | image_size = model.visual.image_size 107 | if isinstance(image_size, tuple): 108 | image_size = image_size[0] 109 | preprocess_val = image_transform_tensor( 110 | image_size, 111 | mean=image_mean, 112 | std=image_std, 113 | ) 114 | 115 | self.model = model.to(device) 116 | self.preprocess_val = preprocess_val 117 | checkpoint = torch.load(os.path.join(CKPT_PATH, "HPS_v2.1_compressed.pt"), map_location="cpu") 118 | self.model.load_state_dict(checkpoint["state_dict"]) 119 | self.processor = get_tokenizer("ViT-H-14") 120 | self.eval() 121 | 122 | @torch.no_grad() 123 | def __call__(self, images, prompts): 124 | image = self.preprocess_val(images.to(self.dtype).to(device=self.device, non_blocking=True)) 125 | # Process the prompt 126 | text = self.processor(prompts).to(device=self.device, non_blocking=True) 127 | outputs = 
self.model(image, text) 128 | image_features, text_features = outputs["image_features"], outputs["text_features"] 129 | logits_per_image = image_features @ text_features.T 130 | hps_score = torch.diagonal(logits_per_image, 0) 131 | return hps_score.contiguous() 132 | 133 | 134 | def main(): 135 | scorer = HPSv2Scorer(dtype=torch.float32, device="cuda") 136 | 137 | images = [ 138 | "test_cases/nasa.jpg", 139 | "test_cases/hello world.jpg", 140 | ] 141 | pil_images = [Image.open(img) for img in images] 142 | prompts = [ 143 | 'An astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', 144 | 'New York Skyline with "Hello World" written with fireworks on the sky', 145 | ] 146 | images = [np.array(img) for img in pil_images] 147 | images = np.array(images) 148 | images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW 149 | images = torch.tensor(images, dtype=torch.uint8) / 255.0 150 | print(scorer(images, prompts)) 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 |

5 | 6 | Edit-R1: Reinforce Image Editing with Diffusion Negative-Aware Finetuning and 7 | MLLM Implicit Feedback 8 | 9 |

10 | 11 | [![UniWorld-V2](https://img.shields.io/badge/Arxiv-UniWorldV2-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2510.16888) 12 | [![UniWorld-V1](https://img.shields.io/badge/Arxiv-UniWorldV1-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2506.03147) 13 | [![ImgEdit](https://img.shields.io/badge/Arxiv-ImgEdit-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2506.03147) 14 | [![Collection](https://img.shields.io/badge/🤗-Collection-blue.svg)](https://huggingface.co/collections/chestnutlzj/edit-r1-68dc3ecce74f5d37314d59f4) 15 | [![License](https://img.shields.io/badge/License-Apache-yellow)](https://github.com/PKU-YuanGroup/UniWorld-V2/blob/main/LICENSE) 16 | 17 | ## 📣 News 18 | 19 | **[2025/10/19]**: We release **Edit-R1**, which employs [DiffusionNFT](https://github.com/NVlabs/DiffusionNFT) and a training-free reward 20 | model derived from pretrained MLLMs to fine-tune diffusion models for image editing. [UniWorld-Qwen-Image-Edit-2509](https://huggingface.co/collections/chestnutlzj/edit-r1-68dc3ecce74f5d37314d59f4) and [UniWorld-FLUX.1-Kontext-Dev](https://huggingface.co/collections/chestnutlzj/edit-r1-68dc3ecce74f5d37314d59f4) are open-sourced. 21 | 22 | ## 🚀 Environment Set Up 23 | Clone this repository and install packages. 24 | ```bash 25 | git clone https://github.com/PKU-YuanGroup/Edit-R1.git 26 | cd Edit-R1 27 | conda create -n Edit-R1 python=3.10.16 28 | pip install -e . 29 | ``` 30 | 31 | ## 🗝️ Train 32 | 33 | ### Deploy vLLM Reward Server 34 | 35 | Start the reward server: 36 | 37 | ``` 38 | python reward_server/reward_server.py 39 | ``` 40 | 41 | If you want to check the status of the reward server, you can test it by running: 42 | 43 | ``` 44 | python reward_server/test_reward_server.py 45 | ``` 46 | 47 | ### Data Format 48 | 49 | Directory structure: 50 | 51 | ``` 52 | - dataset-dir 53 | - images/ 54 | - YOUR_IMAGE_DATA 55 | - ... 56 | - train_metadata.jsonl 57 | - test_metadata.jsonl 58 | ``` 59 | 60 | `train_metadata.jsonl` and `test_metadata.jsonl` format: 61 | 62 | ``` 63 | {"prompt": "PROMPT", "image": "IMAGE_RELATIVE_PATH", "requirement": "TASK_REQUIREMENT"} 64 | ... 65 | ``` 66 | 67 | ### Configure Training 68 | 69 | See `config/qwen_image_edit_nft.py` and `config/kontext_nft.py` for available configurations. 70 | 71 | ### Run Training 72 | 73 | ```shell 74 | export REWARD_SERVER=[YOUR_REWARD_SERVICE_IP_ADDR]:12341 75 | 76 | torchrun --nproc_per_node=8 \ 77 | scripts/train_nft_qwen_image_edit.py --config config/qwen_image_edit_nft.py:config_name 78 | ``` 79 | 80 | And you can also refer to the example scripts in `examples/`. 81 | 82 | ## ⚡️ Reproduction 83 | 84 | For reproducibility, we provide the reproduction scripts in `reproduction/`. 85 | 86 | See [Reproduction Details](reproduction/README.md) for more details. 87 | 88 | ## 👍 Acknowledgement 89 | 90 | - [**DiffusionNFT**](https://github.com/NVlabs/DiffusionNFT): Huge thanks for their elegant codebase 🤩! 91 | - [Flow-GRPO](https://github.com/yifan123/flow_grpo) 92 | - [ImgEdit](https://github.com/PKU-YuanGroup/ImgEdit) 93 | - [UniWorld-V1](https://github.com/PKU-YuanGroup/UniWorld-V1) 94 | 95 | ## 🔒 License 96 | 97 | See [LICENSE](LICENSE) for details. The FLUX weights fall under the [FLUX.1 [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 
98 | 99 | ## ✏️ Citation 100 | 101 | ``` 102 | @article{li2025uniworldv2, 103 | title={Uniworld-V2: Reinforce Image Editing with Diffusion Negative-aware Finetuning and MLLM Implicit Feedback}, 104 | author={Li, Zongjian and Liu, Zheyuan and Zhang, Qihui and Lin, Bin and Yuan, Shenghai and Yan, Zhiyuan and Ye, Yang and Yu, Wangbo and Niu, Yuwei and Yuan, Li}, 105 | journal={arXiv preprint arXiv:2510.16888}, 106 | year={2025} 107 | } 108 | 109 | @article{lin2025uniworld, 110 | title={Uniworld: High-resolution semantic encoders for unified visual understanding and generation}, 111 | author={Lin, Bin and Li, Zongjian and Cheng, Xinhua and Niu, Yuwei and Ye, Yang and He, Xianyi and Yuan, Shenghai and Yu, Wangbo and Wang, Shaodong and Ge, Yunyang and others}, 112 | journal={arXiv preprint arXiv:2506.03147}, 113 | year={2025} 114 | } 115 | 116 | @article{ye2025imgedit, 117 | title={Imgedit: A unified image editing dataset and benchmark}, 118 | author={Ye, Yang and He, Xianyi and Li, Zongjian and Lin, Bin and Yuan, Shenghai and Yan, Zhiyuan and Hou, Bohan and Yuan, Li}, 119 | journal={arXiv preprint arXiv:2505.20275}, 120 | year={2025} 121 | } 122 | ``` 123 | 124 | ## 🎨 Case Comparisons 125 | 126 | | Original | Prompt | Nano-banana | GPT-4o | Qwen-Image-Edit | **UniWorld-V2 (Ours)** | 127 | | :---: | :---: | :---: | :---: | :---: | :---: | 128 | | | **Case 1:** `把鸟移动到红框里,删除掉现在的鸟,最后移除红框` | | | | (✅正确执行指令)| 129 | | | **Case 2:** `把中间白色衣服戴口罩女生的手势改成OK` | | | | (✅OK手势 )| 130 | | | **Case 3:** `提取画面中的吉他` | | | | (✅弦钮上二下三 ) | 131 | | | **Case 4:** `把下面的所有文字并改用书法体。中间的“月满中秋”改成“千里团圆”。并且把月亮改成模糊的月饼。` | | | | (✅模糊月饼,✅书法字体)| 132 | | | **Case 5:** `让画面中的形象坐在高档西餐厅,双手拿刀叉吃牛排` | | | | (✅人物特征,✅刀叉)| 133 | | | **Case 6:** `在中间人物身上添加 3D 网格,精确覆盖衣服褶皱、头发和细节 ` | | | | (✅精确覆盖)| 134 | -------------------------------------------------------------------------------- /reward_server/reward_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import List 4 | from vllm import LLM, SamplingParams 5 | import vllm 6 | from PIL import Image 7 | from io import BytesIO 8 | import base64 9 | import pickle 10 | import traceback 11 | from flask import Flask, request 12 | import ray 13 | import asyncio 14 | import prompt_template 15 | 16 | if vllm.__version__ != "0.9.2": 17 | raise ValueError("vLLM version must be 0.9.2") 18 | 19 | os.environ["VLLM_USE_V1"] = "0" # IMPORTANT 20 | 21 | app = Flask(__name__) 22 | 23 | # Global variables 24 | score_idx = [15, 16, 17, 18, 19, 20] 25 | workers = [] # Ray actors for each GPU 26 | MODEL_PATH = "/mnt/data/checkpoints/Qwen/Qwen2.5-VL-32B-Instruct/" 27 | NUM_GPUS = 8 28 | NUM_TP = 2 29 | 30 | def get_base64(image): 31 | image_data = BytesIO() 32 | image.save(image_data, format="JPEG") 33 | image_data_bytes = image_data.getvalue() 34 | encoded_image = base64.b64encode(image_data_bytes).decode("utf-8") 35 | return encoded_image 36 | 37 | 38 | class LogitsSpy: 39 | def __init__(self): 40 | self.processed_logits: list[torch.Tensor] = [] 41 | 42 | def __call__(self, token_ids: list[int], logits: torch.Tensor): 43 | self.processed_logits.append(logits) 44 | return logits 45 | 46 | 47 | @ray.remote(num_gpus=NUM_TP) 48 | class ModelWorker: 49 | def __init__(self): 50 | self.llm = None 51 | self.load_model() 52 | 53 | def load_model(self): 54 | """Load the Qwen2-VL model using vLLM on specific GPU""" 55 | self.llm = LLM( 56 | MODEL_PATH, limit_mm_per_prompt={"image": 3}, tensor_parallel_size=NUM_TP 57 | ) 58 | 59 | def 
evaluate_image( 60 | self, image_bytes, prompt, ref_image_bytes=None, requirement: str = "" 61 | ): 62 | # Convert bytes to PIL Image 63 | image = Image.open(BytesIO(image_bytes), formats=["jpeg"]) 64 | ref_image = Image.open(BytesIO(ref_image_bytes), formats=["jpeg"]) 65 | conversation = [ 66 | { 67 | "role": "user", 68 | "content": [ 69 | {"type": "image_pil", "image_pil": ref_image}, 70 | {"type": "image_pil", "image_pil": image}, 71 | { 72 | "type": "text", 73 | "text": prompt_template.SCORE_LOGIT.format( 74 | prompt=prompt, requirement=requirement 75 | ), 76 | }, 77 | ], 78 | }, 79 | ] 80 | return self._vllm_evaluate(conversation) 81 | 82 | def _vllm_evaluate(self, conversation, max_tokens=3, max_score=5): 83 | logits_spy = LogitsSpy() 84 | sampling_params = SamplingParams( 85 | max_tokens=max_tokens, logits_processors=[logits_spy] 86 | ) 87 | self.llm.chat(conversation, sampling_params=sampling_params) 88 | try: 89 | if logits_spy.processed_logits: 90 | probs = torch.softmax(logits_spy.processed_logits[0][score_idx], dim=-1) 91 | score_prob = ( 92 | torch.sum( 93 | probs * torch.arange(len(score_idx)).to(probs.device) 94 | ).item() 95 | / max_score 96 | ) 97 | print(f"Score: {score_prob:.4f}") 98 | return score_prob 99 | else: 100 | print("No outputs received") 101 | return 0.0 102 | except Exception as e: 103 | print(f"Error in _vllm_evaluate: {e}") 104 | score = 0.0 105 | 106 | return score 107 | 108 | 109 | def initialize_ray_workers(num_gpus=8, num_tp=4): 110 | global workers 111 | # Initialize Ray 112 | if not ray.is_initialized(): 113 | ray.init() 114 | 115 | # Create workers for each GPU 116 | workers = [] 117 | for _ in range(num_gpus // num_tp): 118 | worker = ModelWorker.remote() 119 | workers.append(worker) 120 | 121 | print(f"Initialized {num_gpus//num_tp} Ray workers") 122 | return workers 123 | 124 | 125 | async def evaluate_images_async( 126 | image_bytes_list, prompts, ref_image_bytes_list=None, requirements: List[str] = [] 127 | ): 128 | global workers 129 | 130 | if not workers: 131 | raise RuntimeError("Ray workers not initialized") 132 | 133 | tasks = [] 134 | if not requirements: 135 | requirements = [""] * len(prompts) 136 | if ref_image_bytes_list is None: 137 | ref_image_bytes_list = [None] * len(prompts) 138 | for i, (image_bytes, prompt, ref_image_bytes, requirement) in enumerate( 139 | zip(image_bytes_list, prompts, ref_image_bytes_list, requirements) 140 | ): 141 | worker_idx = i % len(workers) 142 | worker = workers[worker_idx] 143 | task = worker.evaluate_image.remote( 144 | image_bytes, prompt, ref_image_bytes, requirement 145 | ) 146 | tasks.append(task) 147 | 148 | scores = ray.get(tasks) 149 | return scores 150 | 151 | 152 | def evaluate_images( 153 | image_bytes_list, prompts, ref_image_bytes_list=None, requirements=[] 154 | ): 155 | loop = asyncio.new_event_loop() 156 | asyncio.set_event_loop(loop) 157 | try: 158 | scores = loop.run_until_complete( 159 | evaluate_images_async( 160 | image_bytes_list, prompts, ref_image_bytes_list, requirements 161 | ) 162 | ) 163 | return scores 164 | finally: 165 | loop.close() 166 | 167 | 168 | @app.route("/mode/", methods=["POST"]) 169 | def inference_mode(mode): 170 | data = request.get_data() 171 | 172 | assert mode in ["logits_non_cot"], "Invalid mode" 173 | 174 | try: 175 | data = pickle.loads(data) 176 | image_bytes_list = data["images"] 177 | ref_image_bytes_list = data.get("ref_images", None) 178 | prompts = data["prompts"] 179 | metadatas = data.get("metadatas", []) 180 | requirements = [] 181 | for 
metadata in metadatas: 182 | requirements.append(metadata.get("requirement", "")) 183 | 184 | scores = evaluate_images( 185 | image_bytes_list, prompts, ref_image_bytes_list, requirements 186 | ) 187 | 188 | response = {"scores": scores} 189 | response = pickle.dumps(response) 190 | returncode = 200 191 | except KeyError as e: 192 | response = f"KeyError: {str(e)}" 193 | response = response.encode("utf-8") 194 | returncode = 500 195 | except Exception as e: 196 | response = traceback.format_exc() 197 | response = response.encode("utf-8") 198 | returncode = 500 199 | 200 | return response, returncode 201 | 202 | 203 | if __name__ == "__main__": 204 | initialize_ray_workers(NUM_GPUS, NUM_TP) 205 | print(f"Starting Flask server with {NUM_GPUS//NUM_TP} Ray workers...") 206 | app.run(host="0.0.0.0", port=12341, debug=False) 207 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim 3 | 4 | ### Intellij+all ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # AWS User-specific 16 | .idea/**/aws.xml 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 38 | # .idea/artifacts 39 | # .idea/compiler.xml 40 | # .idea/jarRepositories.xml 41 | # .idea/modules.xml 42 | # .idea/*.iml 43 | # .idea/modules 44 | # *.iml 45 | # *.ipr 46 | 47 | # CMake 48 | cmake-build-*/ 49 | 50 | # Mongo Explorer plugin 51 | .idea/**/mongoSettings.xml 52 | 53 | # File-based project format 54 | *.iws 55 | 56 | # IntelliJ 57 | out/ 58 | 59 | # mpeltonen/sbt-idea plugin 60 | .idea_modules/ 61 | 62 | # JIRA plugin 63 | atlassian-ide-plugin.xml 64 | 65 | # Cursive Clojure plugin 66 | .idea/replstate.xml 67 | 68 | # SonarLint plugin 69 | .idea/sonarlint/ 70 | 71 | # Crashlytics plugin (for Android Studio and IntelliJ) 72 | com_crashlytics_export_strings.xml 73 | crashlytics.properties 74 | crashlytics-build.properties 75 | fabric.properties 76 | 77 | # Editor-based Rest Client 78 | .idea/httpRequests 79 | 80 | # Android studio 3.1+ serialized cache file 81 | .idea/caches/build_file_checksums.ser 82 | 83 | ### Intellij+all Patch ### 84 | # Ignore everything but code style settings and run configurations 85 | # that are supposed to be shared within teams. 
86 | 87 | .idea/* 88 | 89 | !.idea/codeStyles 90 | !.idea/runConfigurations 91 | 92 | ### Python ### 93 | # Byte-compiled / optimized / DLL files 94 | __pycache__/ 95 | *.py[cod] 96 | *$py.class 97 | 98 | # C extensions 99 | *.so 100 | 101 | # Distribution / packaging 102 | .Python 103 | build/ 104 | develop-eggs/ 105 | dist/ 106 | downloads/ 107 | eggs/ 108 | .eggs/ 109 | lib/ 110 | lib64/ 111 | parts/ 112 | sdist/ 113 | var/ 114 | wheels/ 115 | share/python-wheels/ 116 | *.egg-info/ 117 | .installed.cfg 118 | *.egg 119 | MANIFEST 120 | 121 | # PyInstaller 122 | # Usually these files are written by a python script from a template 123 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 124 | *.manifest 125 | *.spec 126 | 127 | # Installer logs 128 | pip-log.txt 129 | pip-delete-this-directory.txt 130 | 131 | # Unit test / coverage reports 132 | htmlcov/ 133 | .tox/ 134 | .nox/ 135 | .coverage 136 | .coverage.* 137 | .cache 138 | nosetests.xml 139 | coverage.xml 140 | *.cover 141 | *.py,cover 142 | .hypothesis/ 143 | .pytest_cache/ 144 | cover/ 145 | 146 | # Translations 147 | *.mo 148 | *.pot 149 | 150 | # Django stuff: 151 | *.log 152 | local_settings.py 153 | db.sqlite3 154 | db.sqlite3-journal 155 | 156 | # Flask stuff: 157 | instance/ 158 | .webassets-cache 159 | 160 | # Scrapy stuff: 161 | .scrapy 162 | 163 | # Sphinx documentation 164 | docs/_build/ 165 | 166 | # PyBuilder 167 | .pybuilder/ 168 | target/ 169 | 170 | # Jupyter Notebook 171 | .ipynb_checkpoints 172 | 173 | # IPython 174 | profile_default/ 175 | ipython_config.py 176 | 177 | # pyenv 178 | # For a library or package, you might want to ignore these files since the code is 179 | # intended to run in multiple environments; otherwise, check them in: 180 | # .python-version 181 | 182 | # pipenv 183 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 184 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 185 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 186 | # install all needed dependencies. 187 | #Pipfile.lock 188 | 189 | # poetry 190 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 191 | # This is especially recommended for binary packages to ensure reproducibility, and is more 192 | # commonly ignored for libraries. 193 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 194 | #poetry.lock 195 | 196 | # pdm 197 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 198 | #pdm.lock 199 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 200 | # in version control. 201 | # https://pdm.fming.dev/#use-with-ide 202 | .pdm.toml 203 | 204 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 205 | __pypackages__/ 206 | 207 | # Celery stuff 208 | celerybeat-schedule 209 | celerybeat.pid 210 | 211 | # SageMath parsed files 212 | *.sage.py 213 | 214 | # Environments 215 | .env 216 | .venv 217 | env/ 218 | venv/ 219 | ENV/ 220 | env.bak/ 221 | venv.bak/ 222 | 223 | # Spyder project settings 224 | .spyderproject 225 | .spyproject 226 | 227 | # Rope project settings 228 | .ropeproject 229 | 230 | # mkdocs documentation 231 | /site 232 | 233 | # mypy 234 | .mypy_cache/ 235 | .dmypy.json 236 | dmypy.json 237 | 238 | # Pyre type checker 239 | .pyre/ 240 | 241 | # pytype static type analyzer 242 | .pytype/ 243 | 244 | # Cython debug symbols 245 | cython_debug/ 246 | 247 | # PyCharm 248 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 249 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 250 | # and can be added to the global gitignore or merged into this file. For a more nuclear 251 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 252 | #.idea/ 253 | 254 | ### Python Patch ### 255 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 256 | poetry.toml 257 | 258 | # ruff 259 | .ruff_cache/ 260 | 261 | # LSP config files 262 | pyrightconfig.json 263 | 264 | ### Vim ### 265 | # Swap 266 | [._]*.s[a-v][a-z] 267 | !*.svg # comment out if you don't need vector files 268 | [._]*.sw[a-p] 269 | [._]s[a-rt-v][a-z] 270 | [._]ss[a-gi-z] 271 | [._]sw[a-p] 272 | 273 | # Session 274 | Session.vim 275 | Sessionx.vim 276 | 277 | # Temporary 278 | .netrwhist 279 | *~ 280 | # Auto-generated tag files 281 | tags 282 | # Persistent undo 283 | [._]*.un~ 284 | 285 | ### VisualStudioCode ### 286 | .vscode/* 287 | !.vscode/settings.json 288 | !.vscode/tasks.json 289 | !.vscode/launch.json 290 | !.vscode/extensions.json 291 | !.vscode/*.code-snippets 292 | 293 | # Local History for Visual Studio Code 294 | .history/ 295 | 296 | # Built Visual Studio Code Extensions 297 | *.vsix 298 | 299 | ### VisualStudioCode Patch ### 300 | # Ignore all local history of files 301 | .history 302 | .ionide 303 | 304 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim 305 | 306 | wandb/ 307 | logs/ 308 | notebooks/ 309 | 310 | *.pth 311 | 312 | mmcv 313 | mmdetection 314 | HPSv2 315 | 316 | *.err 317 | reward_ckpts/ 318 | *.jpg 319 | *.png 320 | results/ 321 | tests/ 322 | evaluations/ 323 | experiments 324 | evaluations 325 | _legacy 326 | *_reproduction/ -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/kontext_pipeline_with_logprob.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 2 | # with the following modifications: 3 | # - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`. 4 | # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. 
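# Annotation: the tuple returned at the bottom of this file unpacks as
#   image, all_latents, latent_ids, text_ids, image_latents, all_log_probs = pipeline_with_logprob(...)
# where `all_latents` holds the latent at every denoising step and `all_log_probs` the
# per-step log-probabilities that the training code consumes.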
5 | from typing import Any, Dict, List, Optional, Union 6 | import torch 7 | import numpy as np 8 | from diffusers.pipelines.flux.pipeline_flux_kontext import ( 9 | retrieve_timesteps, 10 | calculate_shift, 11 | PipelineImageInput, 12 | ) 13 | from .solver import run_sampling 14 | 15 | 16 | @torch.no_grad() 17 | def pipeline_with_logprob( 18 | self, 19 | image: Optional[PipelineImageInput] = None, 20 | prompt: Union[str, List[str]] = None, 21 | prompt_2: Optional[Union[str, List[str]]] = None, 22 | negative_prompt: Union[str, List[str]] = None, 23 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 24 | height: Optional[int] = None, 25 | width: Optional[int] = None, 26 | num_inference_steps: int = 28, 27 | guidance_scale: float = 3.5, 28 | num_images_per_prompt: Optional[int] = 1, 29 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 30 | latents: Optional[torch.FloatTensor] = None, 31 | prompt_embeds: Optional[torch.FloatTensor] = None, 32 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 33 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 34 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 35 | output_type: Optional[str] = "pil", 36 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 37 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 38 | max_sequence_length: int = 512, 39 | noise_level: float = 0.7, 40 | deterministic: bool = False, 41 | solver: str = "flow", 42 | max_area: int = 1024**2, 43 | _auto_resize: bool = True, 44 | ): 45 | height = height or self.default_sample_size * self.vae_scale_factor 46 | width = width or self.default_sample_size * self.vae_scale_factor 47 | 48 | aspect_ratio = width / height 49 | width = round((max_area * aspect_ratio) ** 0.5) 50 | height = round((max_area / aspect_ratio) ** 0.5) 51 | 52 | multiple_of = self.vae_scale_factor * 2 53 | width = width // multiple_of * multiple_of 54 | height = height // multiple_of * multiple_of 55 | 56 | self.check_inputs( 57 | prompt, 58 | prompt_2, 59 | height, 60 | width, 61 | negative_prompt=negative_prompt, 62 | negative_prompt_2=negative_prompt_2, 63 | prompt_embeds=prompt_embeds, 64 | negative_prompt_embeds=negative_prompt_embeds, 65 | pooled_prompt_embeds=pooled_prompt_embeds, 66 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 67 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 68 | max_sequence_length=max_sequence_length, 69 | ) 70 | 71 | self._guidance_scale = guidance_scale 72 | self._joint_attention_kwargs = joint_attention_kwargs 73 | self._current_timestep = None 74 | self._interrupt = False 75 | 76 | # 2. Define call parameters 77 | if prompt is not None and isinstance(prompt, str): 78 | batch_size = 1 79 | elif prompt is not None and isinstance(prompt, list): 80 | batch_size = len(prompt) 81 | else: 82 | batch_size = prompt_embeds.shape[0] 83 | 84 | device = self._execution_device 85 | 86 | lora_scale = ( 87 | self.joint_attention_kwargs.get("scale", None) 88 | if self.joint_attention_kwargs is not None 89 | else None 90 | ) 91 | 92 | ( 93 | prompt_embeds, 94 | pooled_prompt_embeds, 95 | text_ids, 96 | ) = self.encode_prompt( 97 | prompt=prompt, 98 | prompt_2=prompt_2, 99 | prompt_embeds=prompt_embeds, 100 | pooled_prompt_embeds=pooled_prompt_embeds, 101 | device=device, 102 | num_images_per_prompt=num_images_per_prompt, 103 | max_sequence_length=max_sequence_length, 104 | lora_scale=lora_scale, 105 | ) 106 | 107 | # 3. 
Preprocess image 108 | if image is not None and not ( 109 | isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels 110 | ): 111 | image = self.image_processor.resize(image, height, width) 112 | image = self.image_processor.preprocess(image, height, width) 113 | 114 | # 4. Prepare latent variables 115 | num_channels_latents = self.transformer.config.in_channels // 4 116 | latents, image_latents, latent_ids, image_ids = self.prepare_latents( 117 | image.float(), 118 | batch_size * num_images_per_prompt, 119 | num_channels_latents, 120 | height, 121 | width, 122 | prompt_embeds.dtype, 123 | device, 124 | generator, 125 | latents, 126 | ) 127 | if image_ids is not None: 128 | latent_ids = torch.cat( 129 | [latent_ids, image_ids], dim=0 130 | ) # dim 0 is sequence dimension 131 | 132 | # 5. Prepare timesteps 133 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 134 | if ( 135 | hasattr(self.scheduler.config, "use_flow_sigmas") 136 | and self.scheduler.config.use_flow_sigmas 137 | ): 138 | sigmas = None 139 | image_seq_len = latents.shape[1] 140 | mu = calculate_shift( 141 | image_seq_len, 142 | self.scheduler.config.get("base_image_seq_len", 256), 143 | self.scheduler.config.get("max_image_seq_len", 4096), 144 | self.scheduler.config.get("base_shift", 0.5), 145 | self.scheduler.config.get("max_shift", 1.15), 146 | ) 147 | timesteps, num_inference_steps = retrieve_timesteps( 148 | self.scheduler, 149 | num_inference_steps, 150 | device, 151 | sigmas=sigmas, 152 | mu=mu, 153 | ) 154 | self._num_timesteps = len(timesteps) 155 | 156 | # handle guidance 157 | if self.transformer.config.guidance_embeds: 158 | guidance = torch.full( 159 | [1], guidance_scale, device=device, dtype=torch.float32 160 | ) 161 | guidance = guidance.expand(latents.shape[0]) 162 | else: 163 | guidance = None 164 | 165 | sigmas = self.scheduler.sigmas.float() 166 | 167 | def v_pred_fn(z, sigma): 168 | latent_model_input = z 169 | if image_latents is not None: 170 | latent_model_input = torch.cat([z, image_latents], dim=1) 171 | 172 | timesteps = torch.full( 173 | [latent_model_input.shape[0]], sigma, device=z.device, dtype=torch.float32 174 | ) 175 | noise_pred = self.transformer( 176 | hidden_states=latent_model_input, 177 | timestep=timesteps, 178 | guidance=guidance, 179 | pooled_projections=pooled_prompt_embeds, 180 | encoder_hidden_states=prompt_embeds, 181 | txt_ids=text_ids, 182 | img_ids=latent_ids, 183 | joint_attention_kwargs=self.joint_attention_kwargs, 184 | return_dict=False, 185 | )[0] 186 | noise_pred = noise_pred[:, : latents.size(1)] 187 | return noise_pred 188 | 189 | # 6. Prepare image embeddings 190 | all_latents = [latents] 191 | all_log_probs = [] 192 | 193 | # 7. 
Denoising loop 194 | latents, all_latents, all_log_probs = run_sampling( 195 | v_pred_fn, latents, sigmas, solver, deterministic, noise_level 196 | ) 197 | 198 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 199 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 200 | latents = latents.to(dtype=self.vae.dtype) 201 | image = self.vae.decode(latents, return_dict=False)[0] 202 | image = self.image_processor.postprocess(image, output_type=output_type) 203 | 204 | # Offload all models 205 | self.maybe_free_model_hooks() 206 | 207 | return image, all_latents, latent_ids, text_ids, image_latents, all_log_probs 208 | -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/pipeline_with_logprob.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py 2 | # with the following modifications: 3 | # - It uses the patched version of `sde_step_with_logprob` from `sd3_sde_with_logprob.py`. 4 | # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. 5 | from typing import Any, Dict, List, Optional, Union 6 | import torch 7 | import numpy as np 8 | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps 9 | from .solver import run_sampling 10 | 11 | 12 | def calculate_shift( 13 | image_seq_len, 14 | base_seq_len: int = 256, 15 | max_seq_len: int = 4096, 16 | base_shift: float = 0.5, 17 | max_shift: float = 1.15, 18 | ): 19 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 20 | b = base_shift - m * base_seq_len 21 | mu = image_seq_len * m + b 22 | return mu 23 | 24 | 25 | @torch.no_grad() 26 | def pipeline_with_logprob( 27 | self, 28 | prompt: Union[str, List[str]] = None, 29 | prompt_2: Optional[Union[str, List[str]]] = None, 30 | prompt_3: Optional[Union[str, List[str]]] = None, 31 | height: Optional[int] = None, 32 | width: Optional[int] = None, 33 | num_inference_steps: int = 28, 34 | guidance_scale: float = 7.0, 35 | negative_prompt: Optional[Union[str, List[str]]] = None, 36 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 37 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 38 | num_images_per_prompt: Optional[int] = 1, 39 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 40 | latents: Optional[torch.FloatTensor] = None, 41 | prompt_embeds: Optional[torch.FloatTensor] = None, 42 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 43 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 44 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 45 | output_type: Optional[str] = "pil", 46 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 47 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 48 | max_sequence_length: int = 256, 49 | noise_level: float = 0.7, 50 | deterministic: bool = False, 51 | solver: str = "flow", 52 | model_type: str = "sd3", 53 | ): 54 | height = height or self.default_sample_size * self.vae_scale_factor 55 | width = width or self.default_sample_size * self.vae_scale_factor 56 | 57 | assert model_type in ["sd3", "flux"] 58 | flux = model_type == "flux" 59 | # 1. Check inputs. 
Raise error if not correct 60 | if not flux: 61 | self.check_inputs( 62 | prompt, 63 | prompt_2, 64 | prompt_3, 65 | height, 66 | width, 67 | negative_prompt=negative_prompt, 68 | negative_prompt_2=negative_prompt_2, 69 | negative_prompt_3=negative_prompt_3, 70 | prompt_embeds=prompt_embeds, 71 | negative_prompt_embeds=negative_prompt_embeds, 72 | pooled_prompt_embeds=pooled_prompt_embeds, 73 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 74 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 75 | max_sequence_length=max_sequence_length, 76 | ) 77 | else: 78 | self.check_inputs( 79 | prompt, 80 | prompt_2, 81 | height, 82 | width, 83 | negative_prompt=negative_prompt, 84 | negative_prompt_2=negative_prompt_2, 85 | prompt_embeds=prompt_embeds, 86 | negative_prompt_embeds=negative_prompt_embeds, 87 | pooled_prompt_embeds=pooled_prompt_embeds, 88 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 89 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 90 | max_sequence_length=max_sequence_length, 91 | ) 92 | 93 | self._guidance_scale = guidance_scale 94 | self._joint_attention_kwargs = joint_attention_kwargs 95 | self._current_timestep = None 96 | self._interrupt = False 97 | 98 | # 2. Define call parameters 99 | if prompt is not None and isinstance(prompt, str): 100 | batch_size = 1 101 | elif prompt is not None and isinstance(prompt, list): 102 | batch_size = len(prompt) 103 | else: 104 | batch_size = prompt_embeds.shape[0] 105 | 106 | device = self._execution_device 107 | 108 | lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 109 | if not flux: 110 | ( 111 | prompt_embeds, 112 | negative_prompt_embeds, 113 | pooled_prompt_embeds, 114 | negative_pooled_prompt_embeds, 115 | ) = self.encode_prompt( 116 | prompt=prompt, 117 | prompt_2=prompt_2, 118 | prompt_3=prompt_3, 119 | negative_prompt=negative_prompt, 120 | negative_prompt_2=negative_prompt_2, 121 | negative_prompt_3=negative_prompt_3, 122 | do_classifier_free_guidance=self.do_classifier_free_guidance, 123 | prompt_embeds=prompt_embeds, 124 | negative_prompt_embeds=negative_prompt_embeds, 125 | pooled_prompt_embeds=pooled_prompt_embeds, 126 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 127 | device=device, 128 | num_images_per_prompt=num_images_per_prompt, 129 | max_sequence_length=max_sequence_length, 130 | lora_scale=lora_scale, 131 | ) 132 | if self.do_classifier_free_guidance: 133 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 134 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 135 | else: 136 | ( 137 | prompt_embeds, 138 | pooled_prompt_embeds, 139 | text_ids, 140 | ) = self.encode_prompt( 141 | prompt=prompt, 142 | prompt_2=prompt_2, 143 | prompt_embeds=prompt_embeds, 144 | pooled_prompt_embeds=pooled_prompt_embeds, 145 | device=device, 146 | num_images_per_prompt=num_images_per_prompt, 147 | max_sequence_length=max_sequence_length, 148 | lora_scale=lora_scale, 149 | ) 150 | 151 | # 4. 
Prepare latent variables 152 | if not flux: 153 | num_channels_latents = self.transformer.config.in_channels 154 | latents = self.prepare_latents( 155 | batch_size * num_images_per_prompt, 156 | num_channels_latents, 157 | height, 158 | width, 159 | prompt_embeds.dtype, 160 | device, 161 | generator, 162 | latents, 163 | ) 164 | else: 165 | num_channels_latents = self.transformer.config.in_channels // 4 166 | latents, latent_image_ids = self.prepare_latents( 167 | batch_size * num_images_per_prompt, 168 | num_channels_latents, 169 | height, 170 | width, 171 | prompt_embeds.dtype, 172 | device, 173 | generator, 174 | latents, 175 | ) 176 | 177 | # 5. Prepare timesteps 178 | if not flux: 179 | timesteps, num_inference_steps = retrieve_timesteps( 180 | self.scheduler, 181 | num_inference_steps, 182 | device, 183 | sigmas=None, 184 | ) 185 | self._num_timesteps = len(timesteps) 186 | else: 187 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 188 | if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: 189 | sigmas = None 190 | image_seq_len = latents.shape[1] 191 | mu = calculate_shift( 192 | image_seq_len, 193 | self.scheduler.config.get("base_image_seq_len", 256), 194 | self.scheduler.config.get("max_image_seq_len", 4096), 195 | self.scheduler.config.get("base_shift", 0.5), 196 | self.scheduler.config.get("max_shift", 1.15), 197 | ) 198 | timesteps, num_inference_steps = retrieve_timesteps( 199 | self.scheduler, 200 | num_inference_steps, 201 | device, 202 | sigmas=sigmas, 203 | mu=mu, 204 | ) 205 | self._num_timesteps = len(timesteps) 206 | 207 | sigmas = self.scheduler.sigmas.float() 208 | 209 | 210 | def v_pred_fn(z, sigma): 211 | if not flux: 212 | latent_model_input = torch.cat([z] * 2) if self.do_classifier_free_guidance else z 213 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 214 | timesteps = torch.full([latent_model_input.shape[0]], sigma * 1000, device=z.device, dtype=torch.long) 215 | noise_pred = self.transformer( 216 | hidden_states=latent_model_input, 217 | timestep=timesteps, 218 | encoder_hidden_states=prompt_embeds, 219 | pooled_projections=pooled_prompt_embeds, 220 | joint_attention_kwargs=self.joint_attention_kwargs, 221 | return_dict=False, 222 | )[0] 223 | noise_pred = noise_pred.to(prompt_embeds.dtype) 224 | # perform guidance 225 | if self.do_classifier_free_guidance: 226 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 227 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 228 | else: 229 | latent_model_input = z 230 | # handle guidance 231 | if self.transformer.config.guidance_embeds: 232 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 233 | guidance = guidance.expand(latent_model_input.shape[0]) 234 | else: 235 | guidance = None 236 | 237 | timesteps = torch.full([latent_model_input.shape[0]], sigma, device=z.device, dtype=torch.long) 238 | noise_pred = self.transformer( 239 | hidden_states=latent_model_input, 240 | timestep=timesteps, 241 | guidance=guidance, 242 | pooled_projections=pooled_prompt_embeds, 243 | encoder_hidden_states=prompt_embeds, 244 | txt_ids=text_ids, 245 | img_ids=latent_image_ids, 246 | joint_attention_kwargs=self.joint_attention_kwargs, 247 | return_dict=False, 248 | )[0] 249 | return noise_pred 250 | 251 | # 6. Prepare image embeddings 252 | all_latents = [latents] 253 | all_log_probs = [] 254 | 255 | # 7. 
Denoising loop 256 | latents, all_latents, all_log_probs = run_sampling(v_pred_fn, latents, sigmas, solver, deterministic, noise_level) 257 | 258 | if flux: 259 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 260 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 261 | latents = latents.to(dtype=self.vae.dtype) 262 | image = self.vae.decode(latents, return_dict=False)[0] 263 | image = self.image_processor.postprocess(image, output_type=output_type) 264 | 265 | # Offload all models 266 | self.maybe_free_model_hooks() 267 | 268 | if not flux: 269 | return image, all_latents, all_log_probs 270 | else: 271 | return image, all_latents, latent_image_ids, text_ids, all_log_probs 272 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/solver.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from diffusers.utils.torch_utils import randn_tensor 4 | from typing import Optional, List 5 | from dataclasses import dataclass 6 | import torch.distributed as dist 7 | import tqdm 8 | from functools import partial 9 | 10 | tqdm = partial(tqdm.tqdm, dynamic_ncols=True) 11 | 12 | 13 | # Modified from MixGRPO 14 | def run_sampling( 15 | v_pred_fn, 16 | z, 17 | sigma_schedule, 18 | solver="flow", 19 | determistic=False, 20 | eta=0.7, 21 | ): 22 | assert solver in ["flow", "dance", "ddim", "dpm1", "dpm2"] 23 | dtype = z.dtype 24 | all_latents = [z] 25 | all_log_probs = [] 26 | 27 | if "dpm" in solver: 28 | order = int(solver[-1]) 29 | dpm_state = DPMState(order=order) 30 | for i in tqdm( 31 | range(len(sigma_schedule) - 1), 32 | desc="Sampling Progress", 33 | disable=not dist.is_initialized() or dist.get_rank() != 0, 34 | ): 35 | sigma = sigma_schedule[i] 36 | 37 | pred = v_pred_fn(z.to(dtype), sigma) 38 | if solver == "flow": 39 | z, pred_original, log_prob = flow_grpo_step( 40 | model_output=pred.float(), 41 | latents=z.float(), 42 | eta=eta if not determistic else 0, 43 | sigmas=sigma_schedule, 44 | index=i, 45 | prev_sample=None, 46 | ) 47 | elif solver == "dance": 48 | z, pred_original, log_prob = dance_grpo_step( 49 | pred.float(), z.float(), eta if not determistic else 0, sigmas=sigma_schedule, index=i, prev_sample=None 50 | ) 51 | elif solver == "ddim": 52 | z, pred_original, log_prob = ddim_step( 53 | pred.float(), z.float(), eta if not determistic else 0, sigmas=sigma_schedule, index=i, prev_sample=None 54 | ) 55 | elif "dpm" in solver: 56 | assert determistic 57 | z, pred_original, log_prob = dpm_step( 58 | order, 59 | model_output=pred.float(), 60 | sample=z.float(), 61 | step_index=i, 62 | timesteps=sigma_schedule[:-1], 63 | sigmas=sigma_schedule, 64 | dpm_state=dpm_state, 65 | ) 66 | else: 67 | assert False 68 | z = z.to(dtype) 69 | all_latents.append(z) 70 | all_log_probs.append(log_prob) 71 | 72 | latents = z.to(dtype) 73 | # all_latents = torch.stack(all_latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64) 74 | # all_log_probs = torch.stack(all_log_probs, dim=1) # (batch_size, num_steps, 1) 75 | return latents, all_latents, all_log_probs 76 | 77 | 78 | def flow_grpo_step( 79 | model_output: torch.Tensor, 80 | latents: torch.Tensor, 81 | eta: float, 82 | sigmas: torch.Tensor, 83 | index: int, 84 | prev_sample: torch.Tensor, 85 | generator: Optional[torch.Generator] = None, 86 | ): 87 | device = model_output.device 88 | sigma = sigmas[index].to(device) 89 | sigma_prev = sigmas[index + 1].to(device) 90 | sigma_max = sigmas[1].item() 91 | dt = sigma_prev - sigma # neg dt 92 | 93 | pred_original_sample = latents - sigma * model_output 94 | 95 | std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * eta 96 | 97 | if prev_sample is not None and generator is not None: 98 | raise ValueError( 99 | "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" 100 | " `prev_sample` stays `None`." 
101 | ) 102 | 103 | prev_sample_mean = ( 104 | latents * (1 + std_dev_t**2 / (2 * sigma) * dt) 105 | + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt 106 | ) 107 | 108 | if prev_sample is None: 109 | variance_noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) 110 | prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1 * dt) * variance_noise 111 | 112 | log_prob = ( 113 | -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1 * dt)) ** 2)) 114 | - torch.log(std_dev_t * torch.sqrt(-1 * dt)) 115 | - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) 116 | ) 117 | 118 | # mean along all but batch dimension 119 | log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) 120 | 121 | return prev_sample, pred_original_sample, log_prob 122 | 123 | 124 | def dance_grpo_step( 125 | model_output: torch.Tensor, 126 | latents: torch.Tensor, 127 | eta: float, 128 | sigmas: torch.Tensor, 129 | index: int, 130 | prev_sample: torch.Tensor, 131 | ): 132 | sigma = sigmas[index] 133 | dsigma = sigmas[index + 1] - sigma # neg dt 134 | prev_sample_mean = latents + dsigma * model_output 135 | 136 | pred_original_sample = latents - sigma * model_output 137 | 138 | delta_t = sigma - sigmas[index + 1] # pos -dt 139 | std_dev_t = eta * math.sqrt(delta_t) 140 | 141 | score_estimate = -(latents - pred_original_sample * (1 - sigma)) / sigma**2 142 | log_term = -0.5 * eta**2 * score_estimate 143 | prev_sample_mean = prev_sample_mean + log_term * dsigma 144 | 145 | if prev_sample is None: 146 | prev_sample = prev_sample_mean + torch.randn_like(prev_sample_mean) * std_dev_t 147 | 148 | # log prob of prev_sample given prev_sample_mean and std_dev_t 149 | log_prob = ( 150 | -((prev_sample.detach().to(torch.float32) - prev_sample_mean.to(torch.float32)) ** 2) / (2 * (std_dev_t**2)) 151 | - math.log(std_dev_t) - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) 152 | ) 153 | 154 | # mean along all but batch dimension 155 | log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) 156 | return prev_sample, pred_original_sample, log_prob 157 | 158 | 159 | def ddim_step( 160 | model_output: torch.Tensor, 161 | latents: torch.Tensor, 162 | eta: float, 163 | sigmas: torch.Tensor, 164 | index: int, 165 | prev_sample: torch.Tensor, 166 | ): 167 | model_output = convert_model_output(model_output, latents, sigmas, step_index=index) 168 | prev_sample, prev_sample_mean, std_dev_t, dt_sqrt = ddim_update( 169 | model_output, 170 | sigmas.to(torch.float64), 171 | index, 172 | latents, 173 | eta=eta, 174 | ) 175 | 176 | # Compute log_prob 177 | log_prob = ( 178 | -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * dt_sqrt) ** 2)) 179 | - torch.log(std_dev_t * dt_sqrt) 180 | - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) 181 | ) 182 | 183 | # mean along all but batch dimension 184 | log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) 185 | return prev_sample, model_output, log_prob 186 | 187 | 188 | @dataclass 189 | class DPMState: 190 | order: int 191 | model_outputs: List[torch.Tensor] = None 192 | lower_order_nums = 0 193 | 194 | def __post_init__(self): 195 | self.model_outputs = [None] * self.order 196 | 197 | def update(self, model_output: torch.Tensor): 198 | for i in range(self.order - 1): 199 | self.model_outputs[i] = self.model_outputs[i + 1] 200 | self.model_outputs[-1] = model_output 201 | 202 | def update_lower_order(self): 203 | if self.lower_order_nums < self.order: 204 |
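# Warm-up bookkeeping: the multistep DPM solver needs previous model outputs before a
# higher-order update is valid, so `dpm_step` uses lower-order updates for the first step(s)
# until enough history has been accumulated in this counter.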
self.lower_order_nums += 1 205 | 206 | 207 | def dpm_step( 208 | order, 209 | model_output: torch.Tensor, 210 | sample: torch.Tensor, 211 | step_index: int, 212 | timesteps: list, 213 | sigmas: torch.Tensor, 214 | dpm_state: DPMState = None, 215 | ) -> torch.Tensor: 216 | 217 | # Improve numerical stability for small number of steps 218 | lower_order_final = step_index == len(timesteps) - 1 219 | lower_order_second = (step_index == len(timesteps) - 2) and len(timesteps) < 15 220 | 221 | model_output = convert_model_output(model_output, sample, sigmas, step_index=step_index) 222 | 223 | assert dpm_state is not None 224 | dpm_state.update(model_output) 225 | 226 | # Upcast to avoid precision issues when computing prev_sample 227 | sample = sample.to(torch.float32) 228 | 229 | if order == 1 or dpm_state.lower_order_nums < 1 or lower_order_final: 230 | if step_index == 0 or lower_order_final: 231 | prev_sample, _, _, _ = ddim_update( 232 | model_output, 233 | sigmas.to(torch.float64), 234 | step_index, 235 | sample, 236 | eta=0.0, 237 | ) 238 | else: 239 | prev_sample = dpm_solver_first_order_update( 240 | model_output, 241 | sigmas.to(torch.float64), 242 | step_index, 243 | sample, 244 | ) 245 | elif order == 2 or dpm_state.lower_order_nums < 2 or lower_order_second: 246 | prev_sample = multistep_dpm_solver_second_order_update( 247 | dpm_state.model_outputs, 248 | sigmas.to(torch.float64), 249 | step_index, 250 | sample, 251 | ) 252 | else: 253 | assert False 254 | 255 | dpm_state.update_lower_order() 256 | 257 | # Cast sample back to expected dtype 258 | prev_sample = prev_sample.to(model_output.dtype) 259 | 260 | return prev_sample, model_output, None 261 | 262 | 263 | def convert_model_output( 264 | model_output, 265 | sample, 266 | sigmas, 267 | step_index, 268 | ) -> torch.Tensor: 269 | sigma_t = sigmas[step_index] 270 | x0_pred = sample - sigma_t * model_output 271 | 272 | return x0_pred 273 | 274 | 275 | def ddim_update( 276 | model_output: torch.Tensor, 277 | sigmas, 278 | step_index, 279 | sample: torch.Tensor = None, 280 | noise: Optional[torch.Tensor] = None, 281 | eta: float = 1.0, 282 | ) -> torch.Tensor: 283 | 284 | t, s = sigmas[step_index + 1], sigmas[step_index] 285 | 286 | std_dev_t = eta * t 287 | dt_sqrt = torch.sqrt(1.0 - t**2 * (1 - s) ** 2 / (s**2 * (1 - t) ** 2)) 288 | rho_t = std_dev_t * dt_sqrt 289 | noise_pred = (sample - (1 - s) * model_output) / s 290 | if noise is None: 291 | noise = torch.randn_like(model_output) 292 | prev_mean = (1 - t) * model_output + torch.sqrt(t**2 - rho_t**2) * noise_pred 293 | x_t = prev_mean + rho_t * noise 294 | 295 | return x_t, prev_mean, std_dev_t, dt_sqrt 296 | 297 | 298 | def dpm_solver_first_order_update( 299 | model_output: torch.Tensor, 300 | sigmas, 301 | step_index, 302 | sample: torch.Tensor = None, 303 | ) -> torch.Tensor: 304 | 305 | sigma_t, sigma_s = sigmas[step_index + 1], sigmas[step_index] 306 | alpha_t, sigma_t = _sigma_to_alpha_sigma_t(sigma_t) 307 | alpha_s, sigma_s = _sigma_to_alpha_sigma_t(sigma_s) 308 | lambda_t = torch.log(alpha_t) - torch.log(sigma_t) 309 | lambda_s = torch.log(alpha_s) - torch.log(sigma_s) 310 | 311 | h = lambda_t - lambda_s 312 | x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output 313 | 314 | return x_t 315 | 316 | 317 | def multistep_dpm_solver_second_order_update( 318 | model_output_list: List[torch.Tensor], 319 | sigmas, 320 | step_index, 321 | sample: torch.Tensor = None, 322 | ) -> torch.Tensor: 323 | 324 | sigma_t, sigma_s0, sigma_s1 = ( 325 | 
sigmas[step_index + 1], 326 | sigmas[step_index], 327 | sigmas[step_index - 1], 328 | ) 329 | 330 | alpha_t, sigma_t = _sigma_to_alpha_sigma_t(sigma_t) 331 | alpha_s0, sigma_s0 = _sigma_to_alpha_sigma_t(sigma_s0) 332 | alpha_s1, sigma_s1 = _sigma_to_alpha_sigma_t(sigma_s1) 333 | 334 | lambda_t = torch.log(alpha_t) - torch.log(sigma_t) 335 | lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) 336 | lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) 337 | 338 | m0, m1 = model_output_list[-1], model_output_list[-2] 339 | 340 | h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 341 | r0 = h_0 / h 342 | D0, D1 = m0, (1.0 / r0) * (m0 - m1) 343 | 344 | x_t = ( 345 | (sigma_t / sigma_s0) * sample 346 | - (alpha_t * (torch.exp(-h) - 1.0)) * D0 347 | - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 348 | ) 349 | 350 | return x_t 351 | 352 | 353 | def _sigma_to_alpha_sigma_t(sigma): 354 | alpha_t = 1 - sigma 355 | sigma_t = sigma 356 | return alpha_t, sigma_t 357 | -------------------------------------------------------------------------------- /flow_grpo/fsdp2_utils.py: -------------------------------------------------------------------------------- 1 | """reference: https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py""" 2 | 3 | import os 4 | from typing import Any, Callable, Dict, List, Optional, cast 5 | 6 | import safetensors 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn as nn 10 | from torch.distributed._composable.fsdp import ( 11 | CPUOffloadPolicy, 12 | MixedPrecisionPolicy, 13 | fully_shard, 14 | ) 15 | from torch.distributed._tensor import DTensor, distribute_tensor 16 | from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta 17 | from torch.distributed.checkpoint.state_dict import ( 18 | StateDictOptions, 19 | _init_optim_state, 20 | get_optimizer_state_dict, 21 | ) 22 | from torch.distributed.device_mesh import DeviceMesh 23 | from torch.nn.parallel import DistributedDataParallel 24 | from torch.optim import Optimizer 25 | from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 26 | 27 | 28 | def unwrap_model(model: nn.Module) -> nn.Module: 29 | if isinstance(model, DistributedDataParallel): 30 | return model.module 31 | return model 32 | 33 | 34 | def prepare_fsdp_model( 35 | model: nn.Module, 36 | shard_conditions: List[Callable[[str, nn.Module], bool]], 37 | reshard_after_forward: bool = True, 38 | dp_mesh: Optional[DeviceMesh] = None, 39 | cpu_offload: bool = False, 40 | weight_dtype: torch.dtype = torch.bfloat16, 41 | ): 42 | fsdp_kwargs = { 43 | "reshard_after_forward": reshard_after_forward, 44 | "mesh": dp_mesh, 45 | } # dp_mesh is None means distributed to all nodes. 46 | 47 | if cpu_offload: 48 | fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() 49 | 50 | fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy( 51 | param_dtype=weight_dtype, 52 | reduce_dtype=weight_dtype, 53 | output_dtype=weight_dtype, 54 | ) 55 | 56 | num_layers_sharded = 0 57 | for n, m in reversed(list(model.named_modules())): 58 | if any(shard_condition(n, m) for shard_condition in shard_conditions): 59 | fully_shard(m, **fsdp_kwargs) 60 | num_layers_sharded += 1 61 | 62 | if num_layers_sharded == 0: 63 | raise ValueError( 64 | "No layer modules were sharded. Please check if shard conditions are working as expected." 
65 | ) 66 | 67 | fully_shard(model, **fsdp_kwargs) 68 | 69 | 70 | def save_state( 71 | output_dir, 72 | global_step, 73 | model, 74 | is_fsdp, 75 | save_key_filter=None, 76 | optimizer=None, 77 | dataloader=None, 78 | sampler=None, 79 | scaler=None, 80 | lr_scheduler=None, 81 | ): 82 | """Save FSDP2 state dict to file.""" 83 | full_state_dict = {} 84 | full_optimizer_state_dict = {} 85 | full_dataloader_state_dict = {} 86 | full_sampler_state_dict = {} 87 | full_scaler_state_dict = {} 88 | lr_scheduler_state_dict = {} 89 | 90 | training_state = {"global_step": global_step} 91 | 92 | for key, value in model.state_dict().items(): 93 | if save_key_filter is None or save_key_filter in key: 94 | if not is_fsdp: 95 | full_state_dict[key] = value 96 | else: 97 | full_state_dict[key] = value.full_tensor() 98 | 99 | if not is_fsdp: 100 | full_optimizer_state_dict = optimizer.state_dict() 101 | else: 102 | options = StateDictOptions( 103 | full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True 104 | ) 105 | full_optimizer_state_dict = get_optimizer_state_dict( 106 | model=model, optimizers=optimizer, options=options 107 | ) 108 | 109 | if dataloader is not None: 110 | full_dataloader_state_dict = dataloader.state_dict() 111 | 112 | if sampler is not None: 113 | full_sampler_state_dict = sampler.state_dict() 114 | 115 | if scaler is not None: 116 | full_scaler_state_dict = scaler.state_dict() 117 | 118 | if lr_scheduler is not None: 119 | lr_scheduler_state_dict = lr_scheduler.state_dict() 120 | 121 | if dist.get_rank() == 0: 122 | if not os.path.exists(output_dir): 123 | os.makedirs(output_dir) 124 | safetensors.torch.save_file( 125 | full_state_dict, os.path.join(output_dir, "model.safetensors") 126 | ) 127 | torch.save(training_state, os.path.join(output_dir, "training_state.pth")) 128 | if len(full_optimizer_state_dict) > 0: # optimizer is not None 129 | torch.save( 130 | full_optimizer_state_dict, os.path.join(output_dir, "optimizer.pth") 131 | ) 132 | if len(full_dataloader_state_dict) > 0: # dataloader is not None 133 | torch.save( 134 | full_dataloader_state_dict, os.path.join(output_dir, "dataloader.pth") 135 | ) 136 | if len(full_sampler_state_dict) > 0: # sampler is not None 137 | torch.save(full_sampler_state_dict, os.path.join(output_dir, "sampler.pth")) 138 | 139 | if len(full_scaler_state_dict) > 0: # scaler is not None 140 | torch.save(full_scaler_state_dict, os.path.join(output_dir, "scaler.pth")) 141 | 142 | if len(lr_scheduler_state_dict) > 0: # lr_scheduler is not None 143 | torch.save( 144 | lr_scheduler_state_dict, os.path.join(output_dir, "lr_scheduler.pth") 145 | ) 146 | 147 | 148 | def load_model_state(model, path, device, is_fsdp=False, fsdp_cpu_offload=False): 149 | """Load FSDP2 state dict from file.""" 150 | full_state_dict = safetensors.torch.load_file( 151 | os.path.join(path, "model.safetensors") 152 | ) 153 | 154 | # # Replace all lora_*.default with lora_* in key 155 | # keys_to_replace = [ 156 | # key for key in full_state_dict.keys() if "lora" in key and "default" in key 157 | # ] 158 | # for key in keys_to_replace: 159 | # new_key = key.replace("default.", "") 160 | # full_state_dict[new_key] = full_state_dict[key] 161 | # del full_state_dict[key] 162 | 163 | if is_fsdp: 164 | _load_from_full_model_state_dict( 165 | model, full_state_dict, device=device, cpu_offload=fsdp_cpu_offload 166 | ) 167 | else: 168 | model.load_state_dict(full_state_dict, strict=False) 169 | 170 | 171 | def load_optimizer_state(optimizer, path, is_fsdp=False): 172 | 
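# Restores the full (unsharded) optimizer state produced by `save_state`; for FSDP2 runs the
# full tensors are re-distributed onto each rank's DTensor shards via
# `_load_from_full_optimizer_state_dict` below.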
full_optimizer_state_dict = torch.load(os.path.join(path, "optimizer.pth")) 173 | if is_fsdp: 174 | _load_from_full_optimizer_state_dict(optimizer, full_optimizer_state_dict) 175 | else: 176 | optimizer.load_state_dict(full_optimizer_state_dict) 177 | 178 | 179 | def load_state(path, dataloader=None, sampler=None, scaler=None, lr_scheduler=None): 180 | training_state = torch.load(os.path.join(path, "training_state.pth")) 181 | if dataloader is not None: 182 | dataloader.load_state_dict(torch.load(os.path.join(path, "dataloader.pth"))) 183 | if sampler is not None: 184 | sampler.load_state_dict(torch.load(os.path.join(path, "sampler.pth"))) 185 | if scaler is not None: 186 | scaler.load_state_dict(torch.load(os.path.join(path, "scaler.pth"))) 187 | if lr_scheduler is not None: 188 | lr_scheduler.load_state_dict(torch.load(os.path.join(path, "lr_scheduler.pth"))) 189 | return training_state["global_step"] 190 | 191 | 192 | def _load_from_full_model_state_dict( 193 | model: "FSDPModule", # noqa 194 | full_sd: Dict[str, Any], 195 | device: torch.device, 196 | strict: bool = False, 197 | cpu_offload: bool = False, 198 | ): 199 | # has_nf4 = any( 200 | # hasattr(param, "_local_tensor") and isinstance(param._local_tensor, NF4Tensor) 201 | # for param in model.parameters() 202 | # ) 203 | meta_sharded_sd = model.state_dict() 204 | 205 | sharded_sd = {} 206 | for param_name, full_tensor in full_sd.items(): 207 | sharded_meta_param = meta_sharded_sd.get(param_name) 208 | full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) 209 | if hasattr(sharded_meta_param, "_local_tensor") and isinstance( 210 | sharded_meta_param._local_tensor, NF4Tensor 211 | ): 212 | block_size = sharded_meta_param._local_tensor.block_size 213 | scaler_block_size = sharded_meta_param._local_tensor.scaler_block_size 214 | full_tensor = to_nf4( 215 | full_tensor, 216 | block_size=block_size, 217 | scaler_block_size=scaler_block_size, 218 | ) 219 | # replicating logic from `_fsdp_param.py`` `_init_sharded_param` 220 | # otherwise `distribute_tensor(DTensor(local=NF4))` 221 | # requires dispatching `c10d.scatter_`` 222 | # long-term solution is `swap_tensor` 223 | mesh = sharded_meta_param.device_mesh 224 | if mesh.ndim > 1: 225 | raise NotImplementedError(f"only support 1D FSDP but got {mesh.ndim=}") 226 | shard_mesh_dim = 0 227 | shard_world_size = mesh.size(shard_mesh_dim) 228 | shard_rank = cast( 229 | torch.distributed.ProcessGroup, mesh.get_group(shard_mesh_dim) 230 | ).rank() 231 | chunk = list(torch.chunk(full_tensor, shard_world_size, dim=0))[shard_rank] 232 | sharded_param = full_tensor.new_zeros(chunk.size()) 233 | sharded_param[: chunk.size(0)].copy_(chunk) 234 | # TODO: change to from_local API (need to add view support for NF4) 235 | sharded_tensor = DTensor( 236 | local_tensor=sharded_param, 237 | spec=DTensorSpec( 238 | mesh=sharded_meta_param.device_mesh, 239 | placements=sharded_meta_param.placements, 240 | tensor_meta=TensorMeta( 241 | shape=sharded_meta_param.size(), 242 | dtype=sharded_meta_param.dtype, 243 | stride=sharded_meta_param.stride(), 244 | ), 245 | ), 246 | requires_grad=sharded_meta_param.requires_grad, 247 | ) 248 | elif not hasattr(sharded_meta_param, "device_mesh"): 249 | # In cases where parts of the model aren't sharded, some parameters will be plain tensors 250 | sharded_tensor = full_tensor 251 | else: 252 | sharded_tensor = distribute_tensor( 253 | full_tensor, 254 | sharded_meta_param.device_mesh, 255 | sharded_meta_param.placements, 256 | ) 257 | if cpu_offload: 258 | 
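# With `CPUOffloadPolicy`, FSDP2 keeps sharded parameters on CPU between uses, so the freshly
# built shard is moved to CPU to match that placement.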
sharded_tensor = sharded_tensor.cpu() 259 | sharded_sd[param_name] = nn.Parameter(sharded_tensor) 260 | # choose `assign=True` since we cannot call `copy_` on meta tensor 261 | return model.load_state_dict(sharded_sd, strict=strict, assign=True) 262 | 263 | 264 | def _load_from_full_optimizer_state_dict( 265 | opt: Optimizer, 266 | full_sd: Dict[str, Any], 267 | ) -> None: 268 | PARAMS = "params" # noqa: N806 269 | _init_optim_state(opt) 270 | param_groups = opt.state_dict()["param_groups"] 271 | state = opt.state_dict()["state"] 272 | full_param_groups = full_sd["param_groups"] 273 | full_state = full_sd["state"] 274 | for param_group, full_param_group in zip(param_groups, full_param_groups): 275 | for key, value in full_param_group.items(): 276 | if key == PARAMS: 277 | continue 278 | param_group[key] = value 279 | for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]): 280 | if pid not in state: 281 | continue 282 | param_state = state[pid] 283 | full_param_state = full_state[full_pid] 284 | for attr, full_tensor in full_param_state.items(): 285 | sharded_tensor = param_state[attr] 286 | if isinstance(sharded_tensor, DTensor): 287 | # exp_avg is DTensor 288 | param_state[attr] = distribute_tensor( 289 | full_tensor, 290 | sharded_tensor.device_mesh, 291 | sharded_tensor.placements, 292 | ) 293 | else: 294 | # step is plain tensor 295 | param_state[attr] = full_tensor 296 | opt.load_state_dict( 297 | { 298 | "param_groups": param_groups, 299 | "state": state, 300 | } 301 | ) -------------------------------------------------------------------------------- /flow_grpo/diffusers_patch/qwen_image_edit_old_pipeline_with_logprob.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | import torch 3 | import numpy as np 4 | from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import ( 5 | retrieve_timesteps, 6 | ) 7 | from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import ( 8 | calculate_shift, 9 | calculate_dimensions, 10 | ) 11 | 12 | from diffusers.image_processor import PipelineImageInput 13 | from flow_grpo.diffusers_patch.solver import run_sampling 14 | 15 | 16 | CONDITION_IMAGE_SIZE = 384 * 384 17 | VAE_IMAGE_SIZE = 1024 * 1024 18 | 19 | def _get_qwen_prompt_embeds( 20 | self, 21 | prompt: Union[str, List[str]] = None, 22 | image: Optional[torch.Tensor] = None, 23 | device: Optional[torch.device] = None, 24 | dtype: Optional[torch.dtype] = None, 25 | max_seq_len: int = 1024, 26 | ): 27 | device = device or self._execution_device 28 | dtype = dtype or self.text_encoder.dtype 29 | prompt = [prompt] if isinstance(prompt, str) else prompt 30 | template = self.prompt_template_encode 31 | drop_idx = self.prompt_template_encode_start_idx 32 | txt = [template.format(e) for e in prompt] 33 | 34 | model_inputs = self.processor( 35 | text=txt, 36 | images=image, 37 | padding=True, 38 | return_tensors="pt", 39 | ).to(device) 40 | 41 | outputs = self.text_encoder( 42 | input_ids=model_inputs.input_ids, 43 | attention_mask=model_inputs.attention_mask, 44 | # pixel_values=model_inputs.pixel_values, 45 | # image_grid_thw=model_inputs.image_grid_thw, 46 | output_hidden_states=True, 47 | ) 48 | 49 | hidden_states = outputs.hidden_states[-1] 50 | split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) 51 | split_hidden_states = [e[drop_idx:] for e in split_hidden_states] 52 | attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, 
device=e.device) for e in split_hidden_states] 53 | # max_seq_len = max([e.size(0) for e in split_hidden_states]) 54 | prompt_embeds = torch.stack([ 55 | torch.cat([ 56 | u[:max_seq_len] if u.size(0) > max_seq_len else u, 57 | u.new_zeros(max(0, max_seq_len - u.size(0)), u.size(1)) 58 | ]) 59 | for u in split_hidden_states 60 | ]) 61 | encoder_attention_mask = torch.stack([ 62 | torch.cat([ 63 | u[:max_seq_len] if u.size(0) > max_seq_len else u, 64 | u.new_zeros(max(0, max_seq_len - u.size(0))) 65 | ]) 66 | for u in attn_mask_list 67 | ]) 68 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 69 | return prompt_embeds, encoder_attention_mask 70 | 71 | def encode_prompt( 72 | self, 73 | prompt: Union[str, List[str]], 74 | image: Optional[torch.Tensor] = None, 75 | device: Optional[torch.device] = None, 76 | num_images_per_prompt: int = 1, 77 | prompt_embeds: Optional[torch.Tensor] = None, 78 | prompt_embeds_mask: Optional[torch.Tensor] = None, 79 | max_sequence_length: int = 1024, 80 | ): 81 | device = device or self._execution_device 82 | prompt = [prompt] if isinstance(prompt, str) else prompt 83 | batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] 84 | if prompt_embeds is None: 85 | prompt_embeds, prompt_embeds_mask = _get_qwen_prompt_embeds(self, prompt, image, device, max_seq_len=max_sequence_length) 86 | _, seq_len, _ = prompt_embeds.shape 87 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 88 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 89 | prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) 90 | prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) 91 | return prompt_embeds, prompt_embeds_mask 92 | 93 | @torch.no_grad() 94 | def pipeline_with_logprob( 95 | self, 96 | image: Optional[PipelineImageInput] = None, 97 | prompt: Union[str, List[str]] = None, 98 | negative_prompt: Union[str, List[str]] = None, 99 | height: Optional[int] = None, 100 | width: Optional[int] = None, 101 | num_inference_steps: int = 28, 102 | true_cfg_scale: float = 4.0, 103 | guidance_scale: Optional[float] = None, 104 | num_images_per_prompt: Optional[int] = 1, 105 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 106 | latents: Optional[torch.FloatTensor] = None, 107 | prompt_embeds: Optional[torch.FloatTensor] = None, 108 | prompt_embeds_mask: Optional[torch.Tensor] = None, 109 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 110 | negative_prompt_embeds_mask: Optional[torch.Tensor] = None, 111 | output_type: Optional[str] = "pil", 112 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 113 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 114 | max_sequence_length: int = 256, 115 | noise_level: float = 0.7, 116 | deterministic: bool = False, 117 | max_area: Optional[int] = None, 118 | solver: str = "flow", 119 | ): 120 | max_area = VAE_IMAGE_SIZE if max_area is None else max_area 121 | image_size = image[0].size if isinstance(image, list) else image.size 122 | calculated_width, calculated_height = calculate_dimensions(max_area, image_size[0] / image_size[1]) 123 | height = height or calculated_height 124 | width = width or calculated_width 125 | 126 | multiple_of = self.vae_scale_factor * 2 127 | width = width // multiple_of * multiple_of 128 | height = height // multiple_of * multiple_of 129 | 130 | self.check_inputs( 131 | prompt, 132 | height, 133 | width, 134 | 
negative_prompt=negative_prompt, 135 | prompt_embeds=prompt_embeds, 136 | negative_prompt_embeds=negative_prompt_embeds, 137 | prompt_embeds_mask=prompt_embeds_mask, 138 | negative_prompt_embeds_mask=negative_prompt_embeds_mask, 139 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 140 | max_sequence_length=max_sequence_length, 141 | ) 142 | 143 | self._guidance_scale = guidance_scale 144 | self._joint_attention_kwargs = joint_attention_kwargs 145 | self._current_timestep = None 146 | self._interrupt = False 147 | 148 | # Define call parameters 149 | if prompt is not None and isinstance(prompt, str): 150 | batch_size = 1 151 | elif prompt is not None and isinstance(prompt, list): 152 | batch_size = len(prompt) 153 | else: 154 | batch_size = prompt_embeds.shape[0] 155 | 156 | device = self._execution_device 157 | 158 | if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): 159 | image = self.image_processor.resize(image, calculated_height, calculated_width) 160 | prompt_image = image 161 | image = self.image_processor.preprocess(image, calculated_height, calculated_width) 162 | image = image.unsqueeze(2) 163 | 164 | has_neg_prompt = negative_prompt is not None or ( 165 | negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None 166 | ) 167 | 168 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt 169 | 170 | prompt_embeds, prompt_embeds_mask = encode_prompt( 171 | self, 172 | prompt=prompt, 173 | prompt_embeds=prompt_embeds, 174 | prompt_embeds_mask=prompt_embeds_mask, 175 | device=device, 176 | num_images_per_prompt=num_images_per_prompt, 177 | max_sequence_length=max_sequence_length, 178 | ) 179 | if do_true_cfg: 180 | negative_prompt_embeds, negative_prompt_embeds_mask = encode_prompt( 181 | self, 182 | prompt=negative_prompt, 183 | prompt_embeds=negative_prompt_embeds, 184 | prompt_embeds_mask=negative_prompt_embeds_mask, 185 | device=device, 186 | num_images_per_prompt=num_images_per_prompt, 187 | max_sequence_length=max_sequence_length, 188 | ) 189 | 190 | # Preprocess image 191 | num_channels_latents = self.transformer.config.in_channels // 4 192 | latents, image_latents = self.prepare_latents( 193 | image, 194 | batch_size * num_images_per_prompt, 195 | num_channels_latents, 196 | height, 197 | width, 198 | prompt_embeds.dtype, 199 | device, 200 | generator, 201 | latents, 202 | ) 203 | img_shapes = [ 204 | [ 205 | (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), 206 | (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), 207 | ] 208 | ] * batch_size 209 | 210 | # Prepare timesteps 211 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) 212 | if ( 213 | hasattr(self.scheduler.config, "use_flow_sigmas") 214 | and self.scheduler.config.use_flow_sigmas 215 | ): 216 | sigmas = None 217 | image_seq_len = latents.shape[1] 218 | mu = calculate_shift( 219 | image_seq_len, 220 | self.scheduler.config.get("base_image_seq_len", 256), 221 | self.scheduler.config.get("max_image_seq_len", 4096), 222 | self.scheduler.config.get("base_shift", 0.5), 223 | self.scheduler.config.get("max_shift", 1.15), 224 | ) 225 | timesteps, num_inference_steps = retrieve_timesteps( 226 | self.scheduler, 227 | num_inference_steps, 228 | device, 229 | sigmas=sigmas, 230 | mu=mu, 231 | ) 232 | self._num_timesteps = len(timesteps) 233 | 234 | # handle guidance 235 | if self.transformer.config.guidance_embeds and 
guidance_scale is None: 236 | raise ValueError("guidance_scale is required for guidance-distilled model.") 237 | elif self.transformer.config.guidance_embeds: 238 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 239 | guidance = guidance.expand(latents.shape[0]) 240 | elif not self.transformer.config.guidance_embeds and guidance_scale is not None: 241 | guidance = None 242 | elif not self.transformer.config.guidance_embeds and guidance_scale is None: 243 | guidance = None 244 | 245 | self._attention_kwargs = {} 246 | 247 | sigmas = self.scheduler.sigmas.float() 248 | 249 | txt_seq_lens = [max_sequence_length] * batch_size 250 | negative_txt_seq_lens = [max_sequence_length] * batch_size 251 | 252 | def v_pred_fn(z, sigma): 253 | latent_model_input = z 254 | if image_latents is not None: 255 | latent_model_input = torch.cat([z, image_latents], dim=1) 256 | 257 | timesteps = torch.full( 258 | [latent_model_input.shape[0]], sigma, device=z.device, dtype=torch.float32 259 | ) 260 | noise_pred = self.transformer( 261 | hidden_states=latent_model_input, 262 | timestep=timesteps, 263 | guidance=guidance, 264 | encoder_hidden_states_mask=prompt_embeds_mask, 265 | encoder_hidden_states=prompt_embeds, 266 | img_shapes=img_shapes, 267 | txt_seq_lens=txt_seq_lens, 268 | attention_kwargs=self.attention_kwargs, 269 | return_dict=False, 270 | )[0] 271 | noise_pred = noise_pred[:, : latents.size(1)] 272 | 273 | if do_true_cfg: 274 | neg_noise_pred = self.transformer( 275 | hidden_states=latent_model_input, 276 | timestep=timesteps, 277 | guidance=guidance, 278 | encoder_hidden_states_mask=negative_prompt_embeds_mask, 279 | encoder_hidden_states=negative_prompt_embeds, 280 | img_shapes=img_shapes, 281 | txt_seq_lens=negative_txt_seq_lens, 282 | attention_kwargs=self.attention_kwargs, 283 | return_dict=False, 284 | )[0] 285 | neg_noise_pred = neg_noise_pred[:, : latents.size(1)] 286 | 287 | comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) 288 | cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) 289 | noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) 290 | noise_pred = comb_pred * (cond_norm / noise_norm) 291 | 292 | return noise_pred 293 | 294 | # 6. Prepare image embeddings 295 | all_latents = [latents] 296 | all_log_probs = [] 297 | 298 | # 7. 
Denoising loop 299 | latents, all_latents, all_log_probs = run_sampling( 300 | v_pred_fn, latents, sigmas, solver, deterministic, noise_level 301 | ) 302 | 303 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 304 | latents = latents.to(self.vae.dtype) 305 | latents_mean = ( 306 | torch.tensor(self.vae.config.latents_mean) 307 | .view(1, self.vae.config.z_dim, 1, 1, 1) 308 | .to(latents.device, latents.dtype) 309 | ) 310 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 311 | latents.device, latents.dtype 312 | ) 313 | latents = latents / latents_std + latents_mean 314 | image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] 315 | image = self.image_processor.postprocess(image, output_type=output_type) 316 | 317 | # Offload all models 318 | self.maybe_free_model_hooks() 319 | 320 | return ( 321 | image, 322 | all_latents, 323 | image_latents, 324 | img_shapes, 325 | txt_seq_lens, 326 | prompt_embeds, 327 | prompt_embeds_mask, 328 | all_log_probs, 329 | ) 330 | -------------------------------------------------------------------------------- /scripts/evaluation.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
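A note on launching the evaluation script below: it reads RANK, WORLD_SIZE, and LOCAL_RANK from the environment and initializes an NCCL process group, so it is meant to be started through a distributed launcher rather than with plain python. A minimal single-node sketch (the GPU count, checkpoint path, and output directory are illustrative placeholders, not values shipped with the repository) would be: torchrun --nproc_per_node=8 scripts/evaluation.py --model_type sd3 --dataset drawbench --checkpoint_path ./save/run_name/checkpoints/checkpoint-5000 --mixed_precision bf16 --save_images --output_dir ./evaluation_output. Prompts are expected under dataset/<dataset_name>: test.txt for the text-prompt datasets and test_metadata.jsonl for geneval.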
15 | 16 | import argparse 17 | import os 18 | import json 19 | import torch 20 | import numpy as np 21 | from PIL import Image 22 | from tqdm import tqdm 23 | 24 | from diffusers import StableDiffusion3Pipeline 25 | from torch.utils.data import DataLoader, Dataset 26 | from peft import LoraConfig, get_peft_model 27 | 28 | from flow_grpo.rewards import multi_score 29 | 30 | import torch.distributed as dist 31 | from torch.utils.data.distributed import DistributedSampler 32 | from collections import defaultdict 33 | from peft import PeftModel 34 | 35 | import logging 36 | 37 | logging.getLogger("openai").setLevel(logging.ERROR) 38 | logging.getLogger("httpx").setLevel(logging.ERROR) 39 | 40 | 41 | def setup_distributed(rank, world_size): 42 | """Initializes the distributed process group.""" 43 | os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") 44 | os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355") 45 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 46 | 47 | 48 | def cleanup_distributed(): 49 | """Destroys the distributed process group.""" 50 | dist.destroy_process_group() 51 | 52 | 53 | def is_main_process(rank): 54 | """Checks if the current process is the main one (rank 0).""" 55 | return rank == 0 56 | 57 | 58 | class TextPromptDataset(Dataset): 59 | def __init__(self, dataset_path, split="test"): 60 | self.file_path = os.path.join(dataset_path, f"{split}.txt") 61 | if not os.path.exists(self.file_path): 62 | raise FileNotFoundError(f"Dataset file not found at {self.file_path}") 63 | with open(self.file_path, "r") as f: 64 | self.prompts = [line.strip() for line in f.readlines()] 65 | 66 | def __len__(self): 67 | return len(self.prompts) 68 | 69 | def __getitem__(self, idx): 70 | return {"prompt": self.prompts[idx], "metadata": {}, "original_index": idx} 71 | 72 | 73 | class GenevalPromptDataset(Dataset): 74 | def __init__(self, dataset_path, split="test"): 75 | self.file_path = os.path.join(dataset_path, f"{split}_metadata.jsonl") 76 | if not os.path.exists(self.file_path): 77 | raise FileNotFoundError(f"Dataset file not found at {self.file_path}") 78 | with open(self.file_path, "r", encoding="utf-8") as f: 79 | self.metadatas = [json.loads(line) for line in f] 80 | self.prompts = [item["prompt"] for item in self.metadatas] 81 | 82 | def __len__(self): 83 | return len(self.prompts) 84 | 85 | def __getitem__(self, idx): 86 | return {"prompt": self.prompts[idx], "metadata": self.metadatas[idx], "original_index": idx} 87 | 88 | 89 | def collate_fn(examples): 90 | prompts = [example["prompt"] for example in examples] 91 | metadatas = [example["metadata"] for example in examples] 92 | indices = [example["original_index"] for example in examples] 93 | return prompts, metadatas, indices 94 | 95 | 96 | def main(args): 97 | # --- Distributed Setup --- 98 | rank = int(os.environ.get("RANK", "0")) 99 | world_size = int(os.environ.get("WORLD_SIZE", "1")) 100 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 101 | 102 | setup_distributed(rank, world_size) 103 | device = torch.device(f"cuda:{local_rank}") 104 | torch.cuda.set_device(device) 105 | 106 | # --- Mixed Precision Setup --- 107 | mixed_precision_dtype = None 108 | if args.mixed_precision == "fp16": 109 | mixed_precision_dtype = torch.float16 110 | elif args.mixed_precision == "bf16": 111 | mixed_precision_dtype = torch.bfloat16 112 | 113 | enable_amp = mixed_precision_dtype is not None 114 | 115 | if is_main_process(rank): 116 | print(f"Running distributed evaluation with {world_size} 
GPUs.") 117 | if enable_amp: 118 | print(f"Using mixed precision: {args.mixed_precision}") 119 | os.makedirs(args.output_dir, exist_ok=True) 120 | if args.save_images: 121 | os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True) 122 | 123 | results_filepath = os.path.join(args.output_dir, "evaluation_results.jsonl") 124 | 125 | # --- Load Model and Pipeline --- 126 | if is_main_process(rank): 127 | print("Loading model and pipeline...") 128 | 129 | if args.model_type == "sd3": 130 | pipeline = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium") 131 | target_modules = [ 132 | "attn.add_k_proj", 133 | "attn.add_q_proj", 134 | "attn.add_v_proj", 135 | "attn.to_add_out", 136 | "attn.to_k", 137 | "attn.to_out.0", 138 | "attn.to_q", 139 | "attn.to_v", 140 | ] 141 | transformer_lora_config = LoraConfig( 142 | r=32, lora_alpha=64, init_lora_weights="gaussian", target_modules=target_modules 143 | ) 144 | else: 145 | raise ValueError(f"Unsupported model type: {args.model_type}") 146 | 147 | torch.backends.cuda.matmul.allow_tf32 = True 148 | torch.backends.cudnn.allow_tf32 = True 149 | 150 | if args.lora_hf_path: 151 | pipeline.transformer = PeftModel.from_pretrained(pipeline.transformer, args.lora_hf_path) 152 | pipeline.transformer = pipeline.transformer.merge_and_unload() 153 | elif args.checkpoint_path: 154 | lora_path = os.path.join(args.checkpoint_path, "lora") 155 | if is_main_process(rank): 156 | print(f"Loading LoRA weights from: {lora_path}") 157 | if not os.path.exists(lora_path): 158 | raise FileNotFoundError( 159 | f"LoRA directory not found at {lora_path}. Ensure your checkpoint has a 'lora' subdirectory." 160 | ) 161 | 162 | pipeline.transformer = get_peft_model(pipeline.transformer, transformer_lora_config) 163 | pipeline.transformer.load_adapter(lora_path, adapter_name="default", is_trainable=False) 164 | 165 | pipeline.transformer.eval() 166 | text_encoder_dtype = mixed_precision_dtype if enable_amp else torch.float32 167 | 168 | pipeline.transformer.to(device, dtype=text_encoder_dtype) 169 | pipeline.vae.to(device, dtype=torch.float32) # VAE usually fp32 170 | pipeline.text_encoder.to(device, dtype=text_encoder_dtype) 171 | pipeline.text_encoder_2.to(device, dtype=text_encoder_dtype) 172 | pipeline.text_encoder_3.to(device, dtype=text_encoder_dtype) 173 | 174 | pipeline.safety_checker = None 175 | pipeline.set_progress_bar_config( 176 | position=1, 177 | disable=not is_main_process(rank), 178 | leave=False, 179 | desc="Timestep", 180 | dynamic_ncols=True, 181 | ) 182 | 183 | # --- Load Dataset with Distributed Sampler --- 184 | dataset_path = f"dataset/{args.dataset}" 185 | if is_main_process(rank): 186 | print(f"Loading dataset from: {dataset_path}") 187 | 188 | if args.dataset == "geneval": 189 | dataset = GenevalPromptDataset(dataset_path, split="test") 190 | all_reward_scorers = {"geneval": 1.0} 191 | eval_batch_size = 14 192 | elif args.dataset == "ocr": 193 | dataset = TextPromptDataset(dataset_path, split="test") 194 | all_reward_scorers = {"ocr": 1.0} 195 | eval_batch_size = 16 196 | elif args.dataset == "pickscore": 197 | dataset = TextPromptDataset(dataset_path, split="test") 198 | all_reward_scorers = { 199 | "imagereward": 1.0, 200 | "pickscore": 1.0, 201 | "aesthetic": 1.0, 202 | "unifiedreward": 1.0, 203 | "clipscore": 1.0, 204 | "hpsv2": 1.0, 205 | } 206 | eval_batch_size = 16 207 | elif args.dataset == "drawbench": 208 | dataset = TextPromptDataset(dataset_path, split="test") 209 | all_reward_scorers = { 210 | 
"imagereward": 1.0, 211 | "pickscore": 1.0, 212 | "aesthetic": 1.0, 213 | "unifiedreward": 1.0, 214 | "clipscore": 1.0, 215 | "hpsv2": 1.0, 216 | } 217 | eval_batch_size = 5 218 | 219 | sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False) 220 | dataloader = DataLoader( 221 | dataset, 222 | batch_size=eval_batch_size, 223 | sampler=sampler, 224 | collate_fn=collate_fn, 225 | shuffle=False, 226 | ) 227 | 228 | # --- Instantiate Reward Models --- 229 | if is_main_process(rank): 230 | print("Initializing reward models...") 231 | scoring_fn = multi_score(device, all_reward_scorers) 232 | 233 | # --- Evaluation Loop --- 234 | results_this_rank = [] 235 | 236 | for batch in tqdm(dataloader, desc=f"Evaluating (Rank {rank})", disable=not is_main_process(rank)): 237 | prompts, metadata, indices = batch 238 | current_batch_size = len(prompts) 239 | 240 | with torch.cuda.amp.autocast(enabled=enable_amp, dtype=mixed_precision_dtype): 241 | with torch.no_grad(): 242 | images = pipeline( 243 | prompts, 244 | num_inference_steps=args.num_inference_steps, 245 | guidance_scale=args.guidance_scale, 246 | output_type="pt", 247 | height=args.resolution, 248 | width=args.resolution, 249 | )[0] 250 | 251 | all_scores, _ = scoring_fn(images, prompts, metadata, only_strict=False) 252 | 253 | for i in range(current_batch_size): 254 | sample_idx = indices[i] 255 | result_item = { 256 | "sample_id": sample_idx, 257 | "prompt": prompts[i], 258 | "metadata": metadata[i] if metadata else {}, 259 | "scores": {}, 260 | } 261 | 262 | if args.save_images: 263 | image_path = os.path.join(args.output_dir, "images", f"{sample_idx:05d}.jpg") 264 | pil_image = Image.fromarray((images[i].cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)) 265 | pil_image.save(image_path) 266 | result_item["image_path"] = image_path 267 | 268 | for score_name, score_values in all_scores.items(): 269 | if isinstance(score_values, torch.Tensor): 270 | result_item["scores"][score_name] = score_values[i].detach().cpu().item() 271 | else: 272 | result_item["scores"][score_name] = float(score_values[i]) 273 | 274 | results_this_rank.append(result_item) 275 | 276 | del images, all_scores 277 | torch.cuda.empty_cache() 278 | 279 | # --- Gather and Save Results --- 280 | dist.barrier() 281 | 282 | all_gathered_results = [None] * world_size 283 | dist.all_gather_object(all_gathered_results, results_this_rank) 284 | 285 | if is_main_process(rank): 286 | flat_results = [item for sublist in all_gathered_results for item in sublist] 287 | 288 | flat_results.sort(key=lambda x: x["sample_id"]) 289 | 290 | with open(results_filepath, "w") as f_out: 291 | for result_item in flat_results: 292 | f_out.write(json.dumps(result_item) + "\n") 293 | 294 | print(f"\nEvaluation finished. 
All {len(flat_results)} results saved to {results_filepath}") 295 | 296 | all_scores_agg = defaultdict(list) 297 | 298 | for result in flat_results: 299 | for score_name, score_value in result["scores"].items(): 300 | if isinstance(score_value, (int, float)): 301 | all_scores_agg[score_name].append(score_value) 302 | 303 | average_scores = { 304 | name: np.mean(list(filter(lambda score: score != -10.0, scores))) for name, scores in all_scores_agg.items() 305 | } 306 | 307 | print("\n--- Average Scores ---") 308 | if not average_scores: 309 | print("No scores were found to average.") 310 | else: 311 | for name, avg_score in sorted(average_scores.items()): 312 | print(f"{name:<20}: {avg_score:.4f}") 313 | print("----------------------") 314 | 315 | avg_scores_filepath = os.path.join(args.output_dir, "average_scores.json") 316 | with open(avg_scores_filepath, "w") as f_avg: 317 | json.dump(average_scores, f_avg, indent=4) 318 | print(f"Average scores also saved to {avg_scores_filepath}") 319 | 320 | cleanup_distributed() 321 | 322 | 323 | if __name__ == "__main__": 324 | parser = argparse.ArgumentParser(description="Evaluate a trained diffusion model in a distributed manner.") 325 | parser.add_argument( 326 | "--lora_hf_path", 327 | type=str, 328 | default="", 329 | help="Huggingface path for LoRA.", 330 | ) 331 | parser.add_argument( 332 | "--checkpoint_path", 333 | type=str, 334 | default="", 335 | help="Local path to the LoRA checkpoint directory (e.g., './save/run_name/checkpoints/checkpoint-5000').", 336 | ) 337 | parser.add_argument( 338 | "--model_type", 339 | type=str, 340 | required=True, 341 | choices=["sd3"], 342 | help="Type of the base model ('sd3').", 343 | ) 344 | parser.add_argument( 345 | "--dataset", type=str, required=True, choices=["geneval", "ocr", "pickscore", "drawbench"], help="Dataset type." 346 | ) 347 | parser.add_argument( 348 | "--output_dir", 349 | type=str, 350 | default="./evaluation_output", 351 | help="Directory to save evaluation results and generated images.", 352 | ) 353 | parser.add_argument( 354 | "--num_inference_steps", type=int, default=40, help="Number of inference steps for the diffusion pipeline." 355 | ) 356 | parser.add_argument("--guidance_scale", type=float, default=1.0, help="Classifier-free guidance scale.") 357 | parser.add_argument("--resolution", type=int, default=512, help="Resolution of the generated images.") 358 | parser.add_argument( 359 | "--save_images", action="store_true", help="Include this flag to save generated images to the output directory." 360 | ) 361 | parser.add_argument( 362 | "--mixed_precision", 363 | type=str, 364 | default="no", 365 | choices=["no", "fp16", "bf16"], 366 | help="Whether to use mixed precision. 
Choose between 'no', 'fp16', or 'bf16'.",
367 |     )
368 | 
369 |     args = parser.parse_args()
370 |     main(args)
371 | 
--------------------------------------------------------------------------------
/flow_grpo/rewards.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import io
3 | import os
4 | import numpy as np
5 | import torch
6 | from collections import defaultdict
7 | import random
8 | 
9 | 
10 | def jpeg_incompressibility():
11 |     def _fn(images, prompts, metadata):
12 |         if isinstance(images, torch.Tensor):
13 |             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
14 |             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
15 |         images = [Image.fromarray(image) for image in images]
16 |         buffers = [io.BytesIO() for _ in images]
17 |         for image, buffer in zip(images, buffers):
18 |             image.save(buffer, format="JPEG", quality=95)
19 |         sizes = [buffer.tell() / 1000 for buffer in buffers]
20 |         return np.array(sizes), {}
21 | 
22 |     return _fn
23 | 
24 | 
25 | def jpeg_compressibility():
26 |     jpeg_fn = jpeg_incompressibility()
27 | 
28 |     def _fn(images, prompts, metadata):
29 |         rew, meta = jpeg_fn(images, prompts, metadata)
30 |         return -rew / 500, meta
31 | 
32 |     return _fn
33 | 
34 | 
35 | def mllm_score_continue(device):
36 |     """Sends generated images, reference images, prompts, and metadata to the
37 |     external MLLM reward server and returns the continuous scores it reports."""
38 |     import requests
39 |     from requests.adapters import HTTPAdapter, Retry
40 |     from io import BytesIO
41 |     import pickle
42 | 
43 |     batch_size = 64
44 |     url = f"http://{os.getenv('REWARD_SERVER', 'localhost:12341')}/mode/logits_non_cot"
45 |     sess = requests.Session()
46 |     retries = Retry(
47 |         total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
48 |     )
49 |     sess.mount("http://", HTTPAdapter(max_retries=retries))
50 | 
51 |     def _fn(ref_images, images, prompts, metadatas):
52 |         if isinstance(images, torch.Tensor):
53 |             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
54 |             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
55 |         images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
56 | 
57 |         if not isinstance(ref_images, torch.Tensor):
58 |             ref_images = np.array([np.array(img) for img in ref_images])
59 |         ref_images_batched = np.array_split(ref_images, np.ceil(len(ref_images) / batch_size))
60 | 
61 |         all_scores = []
62 |         for image_batch, ref_image_batch in zip(images_batched, ref_images_batched):
63 | 
64 |             jpeg_images = []
65 |             for image in image_batch:
66 |                 img = Image.fromarray(image)
67 |                 buffer = BytesIO()
68 |                 img.save(buffer, format="JPEG")
69 |                 jpeg_images.append(buffer.getvalue())
70 | 
71 |             ref_jpeg_images = []
72 |             for ref_image in ref_image_batch:
73 |                 img = Image.fromarray(ref_image)
74 |                 buffer = BytesIO()
75 |                 img.save(buffer, format="JPEG")
76 |                 ref_jpeg_images.append(buffer.getvalue())
77 | 
78 |             # build the request payload for the reward server
79 |             data = {
80 |                 "ref_images": ref_jpeg_images,
81 |                 "images": jpeg_images,
82 |                 "prompts": prompts,
83 |                 "metadatas": metadatas,
84 |             }
85 |             data_bytes = pickle.dumps(data)
86 | 
87 |             # send the request to the reward server
88 |             response = sess.post(url, data=data_bytes, timeout=360)
89 |             response_data = pickle.loads(response.content)
90 | 
91 |             all_scores += response_data["scores"]
92 | 
93 |         return all_scores, {}
94 | 
95 |     return _fn
96 | 
97 | def aesthetic_score(device):
98 |     from flow_grpo.aesthetic_scorer import AestheticScorer
99 | 
100 |     scorer = AestheticScorer(dtype=torch.float32, device=device)
101 | 
102 |     def _fn(images, prompts,
metadata): 103 | if isinstance(images, torch.Tensor): 104 | images = (images * 255).round().clamp(0, 255).to(torch.uint8) 105 | else: 106 | images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW 107 | images = torch.tensor(images, dtype=torch.uint8) 108 | scores = scorer(images) 109 | return scores, {} 110 | 111 | return _fn 112 | 113 | 114 | def clip_score(device): 115 | from flow_grpo.clip_scorer import ClipScorer 116 | 117 | scorer = ClipScorer(device=device) 118 | 119 | def _fn(images, prompts, metadata): 120 | if not isinstance(images, torch.Tensor): 121 | images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW 122 | images = torch.tensor(images, dtype=torch.uint8) / 255.0 123 | scores = scorer(images, prompts) 124 | return scores, {} 125 | 126 | return _fn 127 | 128 | 129 | def hpsv2_score(device): 130 | from flow_grpo.hpsv2_scorer import HPSv2Scorer 131 | 132 | scorer = HPSv2Scorer(dtype=torch.float32, device=device) 133 | 134 | def _fn(images, prompts, metadata): 135 | if not isinstance(images, torch.Tensor): 136 | images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW 137 | images = torch.tensor(images, dtype=torch.uint8) / 255.0 138 | scores = scorer(images, prompts) 139 | return scores, {} 140 | 141 | return _fn 142 | 143 | 144 | def pickscore_score(device): 145 | from flow_grpo.pickscore_scorer import PickScoreScorer 146 | 147 | scorer = PickScoreScorer(dtype=torch.float32, device=device) 148 | 149 | def _fn(images, prompts, metadata): 150 | if isinstance(images, torch.Tensor): 151 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 152 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC 153 | images = [Image.fromarray(image) for image in images] 154 | scores = scorer(prompts, images) 155 | return scores, {} 156 | 157 | return _fn 158 | 159 | 160 | def imagereward_score(device): 161 | from flow_grpo.imagereward_scorer import ImageRewardScorer 162 | 163 | scorer = ImageRewardScorer(dtype=torch.float32, device=device) 164 | 165 | def _fn(images, prompts, metadata): 166 | if isinstance(images, torch.Tensor): 167 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 168 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC 169 | images = [Image.fromarray(image) for image in images] 170 | prompts = [prompt for prompt in prompts] 171 | scores = scorer(prompts, images) 172 | return scores, {} 173 | 174 | return _fn 175 | 176 | 177 | def geneval_score(device): 178 | from flow_grpo.gen_eval import load_geneval 179 | 180 | batch_size = 64 181 | compute_geneval = load_geneval(device) 182 | 183 | def _fn(images, prompts, metadatas, only_strict): 184 | del prompts 185 | if isinstance(images, torch.Tensor): 186 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 187 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC 188 | images_batched = np.array_split(images, np.ceil(len(images) / batch_size)) 189 | metadatas_batched = np.array_split(metadatas, np.ceil(len(metadatas) / batch_size)) 190 | all_scores = [] 191 | all_rewards = [] 192 | all_strict_rewards = [] 193 | all_group_strict_rewards = [] 194 | all_group_rewards = [] 195 | for image_batch, metadata_batched in zip(images_batched, metadatas_batched): 196 | pil_images = [Image.fromarray(image) for image in image_batch] 197 | 198 | data = { 199 | "images": pil_images, 200 | "metadatas": list(metadata_batched), 201 | "only_strict": only_strict, 202 | } 203 | scores, rewards, strict_rewards, group_rewards, group_strict_rewards = compute_geneval(**data) 204 | 205 | 
all_scores += scores 206 | all_rewards += rewards 207 | all_strict_rewards += strict_rewards 208 | all_group_strict_rewards.append(group_strict_rewards) 209 | all_group_rewards.append(group_rewards) 210 | all_group_strict_rewards_dict = defaultdict(list) 211 | all_group_rewards_dict = defaultdict(list) 212 | for current_dict in all_group_strict_rewards: 213 | for key, value in current_dict.items(): 214 | all_group_strict_rewards_dict[key].extend(value) 215 | all_group_strict_rewards_dict = dict(all_group_strict_rewards_dict) 216 | 217 | for current_dict in all_group_rewards: 218 | for key, value in current_dict.items(): 219 | all_group_rewards_dict[key].extend(value) 220 | all_group_rewards_dict = dict(all_group_rewards_dict) 221 | 222 | return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict 223 | 224 | return _fn 225 | 226 | 227 | def ocr_score(device): 228 | from flow_grpo.ocr import OcrScorer 229 | 230 | scorer = OcrScorer() 231 | 232 | def _fn(images, prompts, metadata): 233 | if isinstance(images, torch.Tensor): 234 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 235 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC 236 | scores = scorer(images, prompts) 237 | # change tensor to list 238 | return scores, {} 239 | 240 | return _fn 241 | 242 | 243 | def unifiedreward_score_sglang(device): 244 | import asyncio 245 | from openai import AsyncOpenAI 246 | import base64 247 | from io import BytesIO 248 | import re 249 | 250 | def pil_image_to_base64(image): 251 | buffered = BytesIO() 252 | image.save(buffered, format="PNG") 253 | encoded_image_text = base64.b64encode(buffered.getvalue()).decode("utf-8") 254 | base64_qwen = f"data:image;base64,{encoded_image_text}" 255 | return base64_qwen 256 | 257 | def _extract_scores(text_outputs): 258 | scores = [] 259 | pattern = r"Final Score:\s*([1-5](?:\.\d+)?)" 260 | for text in text_outputs: 261 | match = re.search(pattern, text) 262 | if match: 263 | try: 264 | scores.append(float(match.group(1))) 265 | except ValueError: 266 | scores.append(0.0) 267 | else: 268 | scores.append(0.0) 269 | return scores 270 | 271 | client = AsyncOpenAI(base_url="http://127.0.0.1:17140/v1", api_key="flowgrpo") 272 | 273 | async def evaluate_image(prompt, image): 274 | question = f"\nYou are given a text caption and a generated image based on that caption. Your task is to evaluate this image based on two key criteria:\n1. Alignment with the Caption: Assess how well this image aligns with the provided caption. Consider the accuracy of depicted objects, their relationships, and attributes as described in the caption.\n2. 
Overall Image Quality: Examine the visual quality of this image, including clarity, detail preservation, color accuracy, and overall aesthetic appeal.\nBased on the above criteria, assign a score from 1 to 5 after 'Final Score:'.\nYour task is provided as follows:\nText Caption: [{prompt}]"
275 |         images_base64 = pil_image_to_base64(image)
276 |         response = await client.chat.completions.create(
277 |             model="UnifiedReward-7b-v1.5",
278 |             messages=[
279 |                 {
280 |                     "role": "user",
281 |                     "content": [
282 |                         {
283 |                             "type": "image_url",
284 |                             "image_url": {"url": images_base64},
285 |                         },
286 |                         {
287 |                             "type": "text",
288 |                             "text": question,
289 |                         },
290 |                     ],
291 |                 },
292 |             ],
293 |             temperature=0,
294 |         )
295 |         return response.choices[0].message.content
296 | 
297 |     async def evaluate_batch_image(images, prompts):
298 |         tasks = [evaluate_image(prompt, img) for prompt, img in zip(prompts, images)]
299 |         results = await asyncio.gather(*tasks)
300 |         return results
301 | 
302 |     def _fn(images, prompts, metadata):
303 |         # convert tensor inputs to uint8 NumPy arrays
304 |         if isinstance(images, torch.Tensor):
305 |             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
306 |             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
307 | 
308 |         # convert to PIL images and resize to 512x512
309 |         images = [Image.fromarray(image).resize((512, 512)) for image in images]
310 | 
311 |         # run the asynchronous batch evaluation
312 |         text_outputs = asyncio.run(evaluate_batch_image(images, prompts))
313 |         score = _extract_scores(text_outputs)
314 |         score = [sc / 5.0 for sc in score]
315 |         return score, {}
316 | 
317 |     return _fn
318 | 
319 | def dummy():
320 |     def _fn(images, prompts, metadata):
321 |         return [random.random() for _ in range(len(images))], {}
322 |     return _fn
323 | 
324 | 
325 | def multi_score(device, score_dict):
326 |     score_functions = {
327 |         "ocr": ocr_score,
328 |         "imagereward": imagereward_score,
329 |         "pickscore": pickscore_score,
330 |         "aesthetic": aesthetic_score,
331 |         "jpeg_compressibility": jpeg_compressibility,
332 |         "unifiedreward": unifiedreward_score_sglang,
333 |         "geneval": geneval_score,
334 |         "clipscore": clip_score,
335 |         "hpsv2": hpsv2_score,
336 |         "mllm_score_continue": mllm_score_continue,
337 |         "dummy": dummy
338 |     }
339 |     score_fns = {}
340 |     for score_name, weight in score_dict.items():
341 |         score_fns[score_name] = (
342 |             score_functions[score_name](device)
343 |             if "device" in score_functions[score_name].__code__.co_varnames
344 |             else score_functions[score_name]()
345 |         )
346 | 
347 |     # only_strict applies only to geneval. During training only the strict reward is needed, so non-strict rewards are skipped to reduce reward computation time.
348 | def _fn(images, prompts, metadata, ref_images=None, only_strict=True): 349 | total_scores = [] 350 | score_details = {} 351 | 352 | for score_name, weight in score_dict.items(): 353 | if score_name == "geneval": 354 | scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name]( 355 | images, prompts, metadata, only_strict 356 | ) 357 | score_details["accuracy"] = rewards 358 | score_details["strict_accuracy"] = strict_rewards 359 | for key, value in group_strict_rewards.items(): 360 | score_details[f"{key}_strict_accuracy"] = value 361 | for key, value in group_rewards.items(): 362 | score_details[f"{key}_accuracy"] = value 363 | elif score_name.startswith("mllm_"): 364 | scores, rewards = score_fns[score_name](ref_images, images, prompts, metadata) 365 | else: 366 | scores, rewards = score_fns[score_name](images, prompts, metadata) 367 | score_details[score_name] = scores 368 | weighted_scores = [weight * score for score in scores] 369 | 370 | if not total_scores: 371 | total_scores = weighted_scores 372 | else: 373 | total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)] 374 | 375 | score_details["avg"] = total_scores 376 | return score_details, {} 377 | 378 | return _fn 379 | 380 | 381 | def main(): 382 | import torchvision.transforms as transforms 383 | 384 | image_paths = [ 385 | "test_cases/nasa.jpg", 386 | ] 387 | 388 | transform = transforms.Compose( 389 | [ 390 | transforms.ToTensor(), # Convert to tensor 391 | ] 392 | ) 393 | 394 | images = torch.stack([transform(Image.open(image_path).convert("RGB")) for image_path in image_paths]) 395 | prompts = [ 396 | 'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist', 397 | ] 398 | metadata = {} # Example metadata 399 | score_dict = {"unifiedreward": 1.0} 400 | # Initialize the multi_score function with a device and score_dict 401 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 402 | scoring_fn = multi_score(device, score_dict) 403 | # Get the scores 404 | scores, _ = scoring_fn(images, prompts, metadata) 405 | # Print the scores 406 | print("Scores:", scores) 407 | 408 | 409 | if __name__ == "__main__": 410 | main() 411 | --------------------------------------------------------------------------------
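As a usage note for the module above: every scorer follows the same closure convention, a factory that (optionally) takes a device and returns _fn(images, prompts, metadata) -> (per_image_scores, extra_info). The sketch below shows a custom scorer written in that style; brightness_score and its registration key are hypothetical names used only for illustration and are not part of the repository.

    import torch

    def brightness_score(device):
        # Hypothetical scorer: rewards brighter images. Because the factory's
        # signature names "device", multi_score would pass the device in
        # automatically (it checks __code__.co_varnames); it is unused here.
        def _fn(images, prompts, metadata):
            if isinstance(images, torch.Tensor):
                imgs = images.float()  # NCHW floats in [0, 1], as elsewhere in this module
            else:
                # NHWC uint8 NumPy input -> NCHW float tensor in [0, 1]
                imgs = torch.as_tensor(images).permute(0, 3, 1, 2).float() / 255.0
            scores = imgs.mean(dim=(1, 2, 3))  # mean pixel intensity per image
            return scores.cpu().tolist(), {}

        return _fn

To plug such a scorer in, the factory would be added to the score_functions mapping inside multi_score and referenced by its key in score_dict (e.g. {"brightness": 0.5, "pickscore": 0.5}); multi_score then sums the weighted per-image scores into score_details["avg"].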