├── src ├── open-r1-multimodal │ ├── src │ │ └── open_r1 │ │ │ ├── __init__.py │ │ │ ├── trainer │ │ │ ├── __init__.py │ │ │ ├── InternVL2.py │ │ │ └── grpo_trainer.py │ │ │ ├── evaluate.py │ │ │ ├── generate.py │ │ │ ├── grpo.py │ │ │ └── sft.py │ ├── prepare_2B_base.sh │ ├── configs │ │ ├── ddp.yaml │ │ ├── zero2.yaml │ │ └── zero3.yaml │ ├── Makefile │ ├── run_sft.sh │ ├── run_sft_SAT.sh │ ├── run_grpo.sh │ ├── setup.cfg │ ├── run_grpo_SAT.sh │ ├── test.py │ ├── setup.py │ ├── README.md │ └── LICENSE ├── data │ └── SAT │ │ ├── prepare_dataset.sh │ │ └── process_dataset.py └── eval │ ├── evaluate_Qwen2_VL_CVBench.py │ └── evaluate_Qwen2_VL_CVBench-base.py ├── .gitignore ├── setup.sh └── README.md /src/open-r1-multimodal/src/open_r1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_trainer import Qwen2VLGRPOTrainer 2 | 3 | 4 | __all__ = ["Qwen2VLGRPOTrainer"] 5 | -------------------------------------------------------------------------------- /src/data/SAT/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | # Download the dataset parquet and rename it 2 | wget -O SAT_train.parquet "https://huggingface.co/datasets/array/SAT/resolve/main/SAT_train.parquet?download=true" 3 | 4 | # Create the dataset directory 5 | mkdir -p SAT_images_train 6 | 7 | # Process the dataset 8 | python process_dataset.py -------------------------------------------------------------------------------- /src/open-r1-multimodal/prepare_2B_base.sh: -------------------------------------------------------------------------------- 1 | # Prepare base model with chat template for SFT training 2 | git lfs install 3 | git clone https://huggingface.co/Qwen/Qwen2-VL-2B 4 | mv Qwen2-VL-2B Qwen2-VL-2B-Base 5 | 6 | huggingface-cli download Qwen/Qwen2-VL-2B-Instruct chat_template.json tokenizer_config.json --local-dir ./Qwen2-VL-2B-Base 7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.tmp 3 | *.csv 4 | *.json 5 | *.parquet 6 | *.png 7 | *.jpg 8 | 9 | # dependency directories 10 | src/open-r1-multimodal/src/open_r1/__pycache__/ 11 | 12 | # Python cache 13 | __pycache__/ 14 | 15 | # Egg info 16 | *.egg-info/ 17 | 18 | # wandb 19 | src/open-r1-multimodal/wandb/ 20 | 21 | # folder 22 | src/open-r1-multimodal/trajectories/ 23 | 24 | # outputs 25 | src/open-r1-multimodal/outputs/ 26 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/configs/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/Makefile: 
-------------------------------------------------------------------------------- 1 | .PHONY: style quality 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 4 | export PYTHONPATH = src 5 | 6 | check_dirs := src 7 | 8 | style: 9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py 10 | isort $(check_dirs) setup.py 11 | 12 | quality: 13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py 14 | isort --check-only $(check_dirs) setup.py 15 | flake8 --max-line-length 119 $(check_dirs) setup.py 16 | 17 | 18 | # Evaluation 19 | 20 | evaluate: 21 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 4 16 | main_process_port: 44326 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | conda create -n VisualThinker python=3.11 2 | conda activate VisualThinker 3 | 4 | # Install the packages in open-r1-multimodal . 5 | cd src/open-r1-multimodal 6 | pip install -e ".[dev]" 7 | 8 | # Addtional modules 9 | pip install wandb==0.18.3 10 | pip install tensorboardx 11 | pip install qwen_vl_utils torchvision 12 | pip install flash-attn --no-build-isolation 13 | 14 | pip install transformers==4.49.0 # correct deepspeed support 15 | pip install duckdb 16 | pip install opencv-python 17 | pip install pandas 18 | pip install math_verify==0.5.2 19 | pip install datasets 20 | pip install accelerate 21 | pip install deepspeed 22 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 4 17 | main_process_port: 22316 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/run_sft.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | 3 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \ 4 | --model_name_or_path \ 5 | --dataset_name \ 6 | --learning_rate 2.0e-5 \ 7 | --num_train_epochs 2 \ 8 | --packing True \ 9 | --max_seq_length 1024 
\ 10 | --per_device_train_batch_size 1 \ 11 | --per_device_eval_batch_size 4 \ 12 | --gradient_accumulation_steps 2 \ 13 | --gradient_checkpointing True \ 14 | --report_to wandb \ 15 | --bf16 True \ 16 | --logging_steps 5 \ 17 | --eval_strategy no \ 18 | --output_dir \ 19 | --run_name -------------------------------------------------------------------------------- /src/open-r1-multimodal/run_sft_SAT.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | 3 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \ 4 | --model_name_or_path Qwen2-VL-2B-Base \ 5 | --dataset_name SAT \ 6 | --learning_rate 2.0e-5 \ 7 | --num_train_epochs 2 \ 8 | --packing True \ 9 | --max_seq_length 1024 \ 10 | --per_device_train_batch_size 1 \ 11 | --per_device_eval_batch_size 4 \ 12 | --gradient_accumulation_steps 2 \ 13 | --gradient_checkpointing True \ 14 | --report_to wandb \ 15 | --bf16 True \ 16 | --logging_steps 5 \ 17 | --eval_strategy no \ 18 | --save_steps 300 \ 19 | --output_dir outputs/Qwen2_VL-2B-SFT \ 20 | --run_name Qwen2_VL-2B-SFT-SAT -------------------------------------------------------------------------------- /src/open-r1-multimodal/run_grpo.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" 2 | export LOG_PATH="./debug_log_2b.txt" 3 | 4 | 5 | 6 | torchrun --nproc_per_node="8" \ 7 | --nnodes="1" \ 8 | --node_rank="0" \ 9 | --master_addr="127.0.0.1" \ 10 | --master_port="12345" \ 11 | src/open_r1/grpo.py \ 12 | --output_dir \ 13 | --model_name_or_path \ 14 | --dataset_name \ 15 | --max_prompt_length 1024 \ 16 | --per_device_train_batch_size 1 \ 17 | --gradient_accumulation_steps 2 \ 18 | --logging_steps 1 \ 19 | --bf16 \ 20 | --report_to wandb \ 21 | --gradient_checkpointing false \ 22 | --attn_implementation flash_attention_2 \ 23 | --max_pixels 401408 \ 24 | --num_train_epochs 2 \ 25 | --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \ 26 | --save_steps 100 \ 27 | --save_only_model true -------------------------------------------------------------------------------- /src/open-r1-multimodal/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = open_r1 7 | known_third_party = 8 | transformers 9 | datasets 10 | fugashi 11 | git 12 | h5py 13 | matplotlib 14 | nltk 15 | numpy 16 | packaging 17 | pandas 18 | psutil 19 | pytest 20 | rouge_score 21 | sacrebleu 22 | seqeval 23 | sklearn 24 | streamlit 25 | torch 26 | tqdm 27 | 28 | line_length = 119 29 | lines_after_imports = 2 30 | multi_line_output = 3 31 | use_parentheses = True 32 | 33 | [flake8] 34 | ignore = E203, E501, E741, W503, W605 35 | max-line-length = 119 36 | per-file-ignores = 37 | # imported but unused 38 | __init__.py: F401 39 | 40 | [tool:pytest] 41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /src/open-r1-multimodal/run_grpo_SAT.sh: -------------------------------------------------------------------------------- 1 | export DEBUG_MODE="true" # Enable Debug if you want to see the rollout of model during RL 2 | export LOG_PATH="./debug_log_2b.txt" 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | export MAIN_PROCESS_PORT=29507 # Change this to an available port 5 | 6 | accelerate launch 
--config_file=configs/zero2.yaml src/open_r1/grpo.py \ 7 | --output_dir outputs/Qwen2-VL-2B-GRPO-Base-SAT \ 8 | --model_name_or_path Qwen/Qwen2-VL-2B \ 9 | --dataset_name SAT \ 10 | --max_prompt_length 1024 \ 11 | --max_completion_length 700 \ 12 | --per_device_train_batch_size 1 \ 13 | --gradient_accumulation_steps 2 \ 14 | --logging_steps 1 \ 15 | --bf16 \ 16 | --gradient_checkpointing 1 \ 17 | --attn_implementation flash_attention_2 \ 18 | --max_pixels 401408 \ 19 | --num_train_epochs 2 \ 20 | --run_name Qwen2-VL-2B-GRPO-SAT \ 21 | --save_steps 100 \ 22 | --save_only_model true \ 23 | --report_to wandb \ 24 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom evaluation tasks for LightEval.""" 16 | 17 | from lighteval.metrics.dynamic_metrics import ( 18 | ExprExtractionConfig, 19 | LatexExtractionConfig, 20 | multilingual_extractive_match_metric, 21 | ) 22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 23 | from lighteval.tasks.requests import Doc 24 | from lighteval.utils.language import Language 25 | 26 | 27 | metric = multilingual_extractive_match_metric( 28 | language=Language.ENGLISH, 29 | fallback_mode="first_match", 30 | precision=5, 31 | gold_extraction_target=(LatexExtractionConfig(),), 32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 33 | aggregation_function=max, 34 | ) 35 | 36 | 37 | def prompt_fn(line, task_name: str = None): 38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically""" 39 | return Doc( 40 | task_name=task_name, 41 | query=line["problem"], 42 | choices=[line["solution"]], 43 | gold_index=0, 44 | ) 45 | 46 | 47 | # Define tasks 48 | aime24 = LightevalTaskConfig( 49 | name="aime24", 50 | suite=["custom"], 51 | prompt_function=prompt_fn, 52 | hf_repo="HuggingFaceH4/aime_2024", 53 | hf_subset="default", 54 | hf_avail_splits=["train"], 55 | evaluation_splits=["train"], 56 | few_shots_split=None, 57 | few_shots_select=None, 58 | generation_size=32768, 59 | metric=[metric], 60 | version=1, 61 | ) 62 | math_500 = LightevalTaskConfig( 63 | name="math_500", 64 | suite=["custom"], 65 | prompt_function=prompt_fn, 66 | hf_repo="HuggingFaceH4/MATH-500", 67 | hf_subset="default", 68 | hf_avail_splits=["test"], 69 | evaluation_splits=["test"], 70 | few_shots_split=None, 71 | few_shots_select=None, 72 | generation_size=32768, 73 | metric=[metric], 74 | version=1, 75 | ) 76 | 77 | # Add tasks to the table 78 | TASKS_TABLE = [] 79 | TASKS_TABLE.append(aime24) 80 | TASKS_TABLE.append(math_500) 81 | 82 | # MODULE LOGIC 83 | if __name__ == "__main__": 84 | print([t["name"] for t in TASKS_TABLE]) 85 | print(len(TASKS_TABLE)) 86 | 
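# ---------------------------------------------------------------------------
# Hedged usage note (added commentary, not part of the original file):
# TASKS_TABLE exposes the custom tasks defined above (aime24, math_500) so that
# LightEval can discover them through its --custom-tasks mechanism. A typical
# invocation looks roughly like the sketch below; the exact CLI syntax and the
# "suite|task|few_shot|truncate" spec depend on the pinned lighteval revision,
# so treat this as an assumption rather than a verified command:
#
#   lighteval vllm "pretrained=<model_path>" "custom|math_500|0|0" \
#       --custom-tasks src/open_r1/evaluate.py
# ---------------------------------------------------------------------------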
-------------------------------------------------------------------------------- /src/data/SAT/process_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import re 4 | import json 5 | import random 6 | import argparse 7 | from PIL import Image 8 | import pandas as pd 9 | import duckdb 10 | from tqdm import tqdm 11 | 12 | def ensure_csv_exists(fold): 13 | """Ensure the CSV file exists by converting from Parquet if necessary.""" 14 | csv_file = f'SAT_{fold}.csv' 15 | if not os.path.exists(csv_file): 16 | duckdb.sql(f"""COPY (SELECT * FROM 'SAT_{fold}.parquet') TO '{csv_file}' (HEADER, FORMAT 'csv')""") 17 | return csv_file 18 | 19 | def extract_images(image_bytes): 20 | """Extract image bytes from string using regex.""" 21 | pattern = r"\\xFF\\xD8.*?\\xFF\\xD9" 22 | return re.findall(pattern, image_bytes.strip("[]")) 23 | 24 | def save_images(images_list, fold, index): 25 | """Save images from byte format to PNG files.""" 26 | image_paths = [] 27 | image_folder = f'SAT_images_{fold}' 28 | os.makedirs(image_folder, exist_ok=True) 29 | 30 | for idx, im_bytes in enumerate(images_list): 31 | im_bytes = im_bytes.strip().encode().decode('unicode_escape').encode('raw_unicode_escape') 32 | image = Image.open(io.BytesIO(im_bytes)) 33 | image_path = os.path.join(image_folder, f'{index}_{idx}.png') 34 | image.save(image_path) 35 | image_paths.append(image_path) 36 | 37 | return image_paths 38 | 39 | def process_data(df, fold, total_num): 40 | """Process dataset and generate conversation JSON files.""" 41 | conversations = [] 42 | 43 | for index, example in tqdm(df.iterrows(), total=total_num, desc="Processing indices"): 44 | if index >= total_num: 45 | break 46 | 47 | images_list = extract_images(example['image_bytes']) 48 | if len(images_list) > 1: 49 | continue # Skip multiple image cases 50 | 51 | images = save_images(images_list, fold, index) 52 | image_token = "" if images else "" 53 | 54 | question = example['question'] 55 | answer_choices = list(map(str, example['answers'].strip('[]').split(', '))) 56 | random.shuffle(answer_choices) 57 | correct_answer = example['correct_answer'] 58 | 59 | answer = ", ".join(answer_choices[:-1]) + " or " + answer_choices[-1] 60 | prompt = f"{question} Choose between the following options: {answer}" 61 | messages = [ 62 | {"role": "user", "content": f"{image_token} Answer in natural language. 
{prompt}"}, 63 | {"role": "assistant", "content": correct_answer} 64 | ] 65 | 66 | conversation = {"messages": messages, "images": images} 67 | 68 | conversations.append(conversation) 69 | 70 | with open(f'SAT_{fold}_{total_num}.json', 'w') as f: 71 | json.dump(conversations, f, indent=4) 72 | 73 | 74 | def main(): 75 | parser = argparse.ArgumentParser(description="Process SAT dataset and generate JSON conversations.") 76 | parser.add_argument('--fold', type=str, default='train', help="Dataset fold to process (e.g., train, val, test)") 77 | parser.add_argument('--total_num', type=int, default=15000, help="Maximum number of examples to process") 78 | args = parser.parse_args() 79 | 80 | csv_file = ensure_csv_exists(args.fold) 81 | df = pd.read_csv(csv_file) 82 | process_data(df, args.fold, args.total_num) 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/trainer/InternVL2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as T 4 | from decord import VideoReader, cpu 5 | from PIL import Image 6 | from torchvision.transforms.functional import InterpolationMode 7 | 8 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 9 | IMAGENET_STD = (0.229, 0.224, 0.225) 10 | 11 | def build_transform(input_size): 12 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 13 | transform = T.Compose([ 14 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 15 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 16 | T.ToTensor(), 17 | T.Normalize(mean=MEAN, std=STD) 18 | ]) 19 | return transform 20 | 21 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 22 | best_ratio_diff = float('inf') 23 | best_ratio = (1, 1) 24 | area = width * height 25 | for ratio in target_ratios: 26 | target_aspect_ratio = ratio[0] / ratio[1] 27 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 28 | if ratio_diff < best_ratio_diff: 29 | best_ratio_diff = ratio_diff 30 | best_ratio = ratio 31 | elif ratio_diff == best_ratio_diff: 32 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 33 | best_ratio = ratio 34 | return best_ratio 35 | 36 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): 37 | orig_width, orig_height = image.size 38 | aspect_ratio = orig_width / orig_height 39 | 40 | # calculate the existing image aspect ratio 41 | target_ratios = set( 42 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 43 | i * j <= max_num and i * j >= min_num) 44 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 45 | 46 | # find the closest aspect ratio to the target 47 | target_aspect_ratio = find_closest_aspect_ratio( 48 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 49 | 50 | # calculate the target width and height 51 | target_width = image_size * target_aspect_ratio[0] 52 | target_height = image_size * target_aspect_ratio[1] 53 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 54 | 55 | # resize the image 56 | resized_img = image.resize((target_width, target_height)) 57 | processed_images = [] 58 | for i in range(blocks): 59 | box = ( 60 | (i % (target_width // image_size)) * image_size, 61 | (i // (target_width // image_size)) * image_size, 62 | ((i % (target_width // image_size)) + 1) * image_size, 63 | ((i // 
(target_width // image_size)) + 1) * image_size 64 | ) 65 | # split the image 66 | split_img = resized_img.crop(box) 67 | processed_images.append(split_img) 68 | assert len(processed_images) == blocks 69 | if use_thumbnail and len(processed_images) != 1: 70 | thumbnail_img = image.resize((image_size, image_size)) 71 | processed_images.append(thumbnail_img) 72 | return processed_images 73 | 74 | def load_image(image_file, input_size=448, max_num=12): 75 | image = Image.open(image_file).convert('RGB') 76 | transform = build_transform(input_size=input_size) 77 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 78 | pixel_values = [transform(image) for image in images] 79 | pixel_values = torch.stack(pixel_values) 80 | return pixel_values -------------------------------------------------------------------------------- /src/open-r1-multimodal/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as T 4 | from decord import VideoReader, cpu 5 | from PIL import Image 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers import AutoModel, AutoTokenizer 8 | 9 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 10 | IMAGENET_STD = (0.229, 0.224, 0.225) 11 | 12 | def build_transform(input_size): 13 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 14 | transform = T.Compose([ 15 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 16 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 17 | T.ToTensor(), 18 | T.Normalize(mean=MEAN, std=STD) 19 | ]) 20 | return transform 21 | 22 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 23 | best_ratio_diff = float('inf') 24 | best_ratio = (1, 1) 25 | area = width * height 26 | for ratio in target_ratios: 27 | target_aspect_ratio = ratio[0] / ratio[1] 28 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 29 | if ratio_diff < best_ratio_diff: 30 | best_ratio_diff = ratio_diff 31 | best_ratio = ratio 32 | elif ratio_diff == best_ratio_diff: 33 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 34 | best_ratio = ratio 35 | return best_ratio 36 | 37 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): 38 | orig_width, orig_height = image.size 39 | aspect_ratio = orig_width / orig_height 40 | 41 | # calculate the existing image aspect ratio 42 | target_ratios = set( 43 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 44 | i * j <= max_num and i * j >= min_num) 45 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 46 | 47 | # find the closest aspect ratio to the target 48 | target_aspect_ratio = find_closest_aspect_ratio( 49 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 50 | 51 | # calculate the target width and height 52 | target_width = image_size * target_aspect_ratio[0] 53 | target_height = image_size * target_aspect_ratio[1] 54 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 55 | 56 | # resize the image 57 | resized_img = image.resize((target_width, target_height)) 58 | processed_images = [] 59 | for i in range(blocks): 60 | box = ( 61 | (i % (target_width // image_size)) * image_size, 62 | (i // (target_width // image_size)) * image_size, 63 | ((i % (target_width // image_size)) + 1) * image_size, 64 | ((i // (target_width // image_size)) + 1) * 
image_size 65 | ) 66 | # split the image 67 | split_img = resized_img.crop(box) 68 | processed_images.append(split_img) 69 | assert len(processed_images) == blocks 70 | if use_thumbnail and len(processed_images) != 1: 71 | thumbnail_img = image.resize((image_size, image_size)) 72 | processed_images.append(thumbnail_img) 73 | return processed_images 74 | 75 | def load_image(image_file, input_size=448, max_num=12): 76 | image = Image.open(image_file).convert('RGB') 77 | transform = build_transform(input_size=input_size) 78 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 79 | pixel_values = [transform(image) for image in images] 80 | pixel_values = torch.stack(pixel_values) 81 | return pixel_values 82 | 83 | # If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section. 84 | path = 'OpenGVLab/InternVL2-2B' 85 | model = AutoModel.from_pretrained( 86 | path, 87 | torch_dtype=torch.bfloat16, 88 | low_cpu_mem_usage=True, 89 | use_flash_attn=True, 90 | trust_remote_code=True).eval().cuda() 91 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) 92 | 93 | # set the max number of tiles in `max_num` 94 | pixel_values = load_image('temp_image.png', max_num=12).to(torch.bfloat16).cuda() 95 | generation_config = dict(max_new_tokens=1024, do_sample=True) 96 | 97 | # pure-text conversation (纯文本对话) 98 | question = 'Hello, who are you?' 99 | response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True) 100 | print(f'User: {question}\nAssistant: {response}') 101 | 102 | question = 'Can you tell me a story?' 103 | response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True) 104 | print(f'User: {question}\nAssistant: {response}') 105 | 106 | # single-image single-round conversation (单图单轮对话) 107 | question = '\nPlease describe the image shortly.' 108 | response = model.chat(tokenizer, pixel_values, question, generation_config) 109 | print(f'User: {question}\nAssistant: {response}') 110 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py 16 | 17 | 18 | import re 19 | import shutil 20 | from pathlib import Path 21 | 22 | from setuptools import find_packages, setup 23 | 24 | 25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info" 27 | if stale_egg_info.exists(): 28 | print( 29 | ( 30 | "Warning: {} exists.\n\n" 31 | "If you recently updated open_r1, this is expected,\n" 32 | "but it may prevent open_r1 from installing in editable mode.\n\n" 33 | "This directory is automatically generated by Python's packaging tools.\n" 34 | "I will remove it now.\n\n" 35 | "See https://github.com/pypa/pip/issues/5466 for details.\n" 36 | ).format(stale_egg_info) 37 | ) 38 | shutil.rmtree(stale_egg_info) 39 | 40 | 41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any. 42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version 43 | _deps = [ 44 | "accelerate>=1.2.1", 45 | "bitsandbytes>=0.43.0", 46 | "black>=24.4.2", 47 | "datasets>=3.2.0", 48 | "deepspeed==0.15.4", 49 | "distilabel[vllm,ray,openai]>=1.5.2", 50 | "einops>=0.8.0", 51 | "flake8>=6.0.0", 52 | "hf_transfer>=0.1.4", 53 | "huggingface-hub[cli]>=0.19.2,<1.0", 54 | "isort>=5.12.0", 55 | "liger_kernel==0.5.2", 56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", 57 | "math-verify", # Used for math verification in grpo 58 | "packaging>=23.0", 59 | "parameterized>=0.9.0", 60 | "pytest", 61 | "safetensors>=0.3.3", 62 | "sentencepiece>=0.1.99", 63 | "torch>=2.5.1", 64 | "transformers @ git+https://github.com/huggingface/transformers.git@main", 65 | "trl==0.14.0", 66 | "vllm==0.6.6.post1", 67 | "wandb>=0.19.1", 68 | "pillow", 69 | "timm", 70 | ] 71 | 72 | # this is a lookup table with items like: 73 | # 74 | # tokenizers: "tokenizers==0.9.4" 75 | # packaging: "packaging" 76 | # 77 | # some of the values are versioned whereas others aren't. 
78 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} 79 | 80 | 81 | def deps_list(*pkgs): 82 | return [deps[pkg] for pkg in pkgs] 83 | 84 | 85 | extras = {} 86 | extras["tests"] = deps_list("pytest", "parameterized") 87 | extras["torch"] = deps_list("torch") 88 | extras["quality"] = deps_list("black", "isort", "flake8") 89 | extras["eval"] = deps_list("lighteval", "math-verify") 90 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] 91 | 92 | # core dependencies shared across the whole project - keep this to a bare minimum :) 93 | install_requires = [ 94 | deps["accelerate"], 95 | deps["bitsandbytes"], 96 | deps["einops"], 97 | deps["datasets"], 98 | deps["deepspeed"], 99 | deps["hf_transfer"], 100 | deps["huggingface-hub"], 101 | deps["liger_kernel"], 102 | deps["packaging"], # utilities from PyPA to e.g., compare versions 103 | deps["safetensors"], 104 | deps["sentencepiece"], 105 | deps["transformers"], 106 | deps["trl"], 107 | ] 108 | 109 | setup( 110 | name="open-r1", 111 | version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 112 | author="The Hugging Face team (past and future)", 113 | author_email="lewis@huggingface.co", 114 | description="Open R1", 115 | long_description=open("README.md", "r", encoding="utf-8").read(), 116 | long_description_content_type="text/markdown", 117 | keywords="llm inference-time compute reasoning", 118 | license="Apache", 119 | url="https://github.com/huggingface/open-r1", 120 | package_dir={"": "src"}, 121 | packages=find_packages("src"), 122 | zip_safe=False, 123 | extras_require=extras, 124 | python_requires=">=3.10.9", 125 | install_requires=install_requires, 126 | classifiers=[ 127 | "Development Status :: 3 - Alpha", 128 | "Intended Audience :: Developers", 129 | "Intended Audience :: Education", 130 | "Intended Audience :: Science/Research", 131 | "License :: OSI Approved :: Apache Software License", 132 | "Operating System :: OS Independent", 133 | "Programming Language :: Python :: 3", 134 | "Programming Language :: Python :: 3.10", 135 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 136 | ], 137 | ) 138 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Optional 16 | 17 | from distilabel.llms import OpenAILLM 18 | from distilabel.pipeline import Pipeline 19 | from distilabel.steps.tasks import TextGeneration 20 | 21 | 22 | def build_distilabel_pipeline( 23 | model: str, 24 | base_url: str = "http://localhost:8000/v1", 25 | prompt_column: Optional[str] = None, 26 | temperature: Optional[float] = None, 27 | top_p: Optional[float] = None, 28 | max_new_tokens: int = 8192, 29 | num_generations: int = 1, 30 | ) -> Pipeline: 31 | generation_kwargs = {"max_new_tokens": max_new_tokens} 32 | 33 | if temperature is not None: 34 | generation_kwargs["temperature"] = temperature 35 | 36 | if top_p is not None: 37 | generation_kwargs["top_p"] = top_p 38 | 39 | with Pipeline().ray() as pipeline: 40 | TextGeneration( 41 | llm=OpenAILLM( 42 | base_url=base_url, 43 | api_key="something", 44 | model=model, 45 | # thinking can take some time... 46 | timeout=10 * 60, 47 | generation_kwargs=generation_kwargs, 48 | ), 49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, 50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion 51 | num_generations=num_generations, 52 | ) 53 | 54 | return pipeline 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | 62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") 63 | parser.add_argument( 64 | "--hf-dataset", 65 | type=str, 66 | required=True, 67 | help="HuggingFace dataset to load", 68 | ) 69 | parser.add_argument( 70 | "--hf-dataset-config", 71 | type=str, 72 | required=False, 73 | help="Dataset config to use", 74 | ) 75 | parser.add_argument( 76 | "--hf-dataset-split", 77 | type=str, 78 | default="train", 79 | help="Dataset split to use", 80 | ) 81 | parser.add_argument("--prompt-column", type=str, default="prompt") 82 | parser.add_argument( 83 | "--model", 84 | type=str, 85 | required=True, 86 | help="Model name to use for generation", 87 | ) 88 | parser.add_argument( 89 | "--vllm-server-url", 90 | type=str, 91 | default="http://localhost:8000/v1", 92 | help="URL of the vLLM server", 93 | ) 94 | parser.add_argument( 95 | "--temperature", 96 | type=float, 97 | help="Temperature for generation", 98 | ) 99 | parser.add_argument( 100 | "--top-p", 101 | type=float, 102 | help="Top-p value for generation", 103 | ) 104 | parser.add_argument( 105 | "--max-new-tokens", 106 | type=int, 107 | default=8192, 108 | help="Maximum number of new tokens to generate", 109 | ) 110 | parser.add_argument( 111 | "--num-generations", 112 | type=int, 113 | default=1, 114 | help="Number of generations per problem", 115 | ) 116 | parser.add_argument( 117 | "--hf-output-dataset", 118 | type=str, 119 | required=False, 120 | help="HuggingFace repo to push results to", 121 | ) 122 | parser.add_argument( 123 | "--private", 124 | action="store_true", 125 | help="Whether to make the output dataset private when pushing to HF Hub", 126 | ) 127 | 128 | args = parser.parse_args() 129 | 130 | print("\nRunning with arguments:") 131 | for arg, value in vars(args).items(): 132 | print(f" {arg}: {value}") 133 | print() 134 | 135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") 136 | dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split) 137 | print("Dataset loaded!") 138 | 139 | pipeline = build_distilabel_pipeline( 140 | model=args.model, 141 | 
base_url=args.vllm_server_url, 142 | prompt_column=args.prompt_column, 143 | temperature=args.temperature, 144 | top_p=args.top_p, 145 | max_new_tokens=args.max_new_tokens, 146 | num_generations=args.num_generations, 147 | ) 148 | 149 | print("Running generation pipeline...") 150 | distiset = pipeline.run(dataset=dataset, use_cache=False) 151 | print("Generation pipeline finished!") 152 | 153 | if args.hf_output_dataset: 154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") 155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private) 156 | print("Dataset pushed!") 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisualThinker-R1-Zero: First ever R1-Zero's Aha Moment on just a 2B non-SFT Model 2 | [![Notion](https://img.shields.io/badge/Notion-%23000000.svg?style=for-the-badge&logo=notion&logoColor=white)](https://turningpointai.notion.site/the-multimodal-aha-moment-on-2b-model) 3 | 4 | ![Reinforcement Learning](https://img.shields.io/badge/Algo-Reinforcement--Learning-red) 5 | ![R1](https://img.shields.io/badge/Algo-R1-red) 6 | ![Vision-Centric](https://img.shields.io/badge/Task-Vision--Centric-yellow) 7 | ![Qwen2-VL-2B](https://img.shields.io/badge/Model-Qwen2--VL--2B-green) 8 | ![Aha-Moment](https://img.shields.io/badge/Analysis-Aha--moment-blue) 9 | 10 | VisualThinker-R1-Zero is a replication of [DeepSeek-R1-Zero](https://arxiv.org/abs/2501.12948) in visual reasoning. We are **the first** to successfully observe **the emergent “aha moment”** and **increased response length** in **visual reasoning** on just a **2B non-SFT models**. 11 | 12 | For more details, please refer to the notion [report](https://turningpointai.notion.site/the-multimodal-aha-moment-on-2b-model). 13 | 14 |
15 | *(figure: visualthinking-intro-figure_00)* 16 |
17 | 18 | > Training dynamics of our VisualThinker-R1-Zero training starting from the Qwen-VL-2B, without SFT or reward models. An aha moment and increasing response length is ever observed at a multimodal model. 19 | 20 | ## 🔮 Highlights 21 | 1. We are the **first to successfully produce the emergent “aha moment” and increased response length** for multimodal reasoning on just a **non-SFT 2B model**. 22 | 2. We showed that **vision-centric** tasks could also benefit from improved reasoning capabilities. 23 | 24 | Similar to DeepSeek R1, self reflection behavior is also observed during our RL training on vision-centric reasoning tasks. The model exhibits an emergent ability to rethink and correct its mistakes: 25 | 26 | ``` 27 | . . . 28 | Therefore, dark brown wooden bed with white blanket is not above the doorway. 29 | But wait! I can think of something else. 30 | Maybe it's just higher than above the doorway, but slightly lower than above the doorway. 31 | . . . 32 | ``` 33 | 34 | ## 📢 Updates 35 | - 2025-03-16: 🤗We released the model [checkpoint](https://huggingface.co/turningpoint-ai/VisualThinker-R1-Zero) at huggingface! 36 | - 2025-02-26: 🔥We share our main findings in this [notion blog](https://turningpointai.notion.site/the-multimodal-aha-moment-on-2b-model). 37 | - 2025-02-26: 🔥We release the VisualThinker R1 Zero repo. 38 | 39 | ## 💻 Hardware Requirements 40 | 41 | \* *estimated* 42 | 43 | | Method | Bits | 2B | 44 | | ------------------------ | ---- | ------ | 45 | | GRPO Full Fine-Tuning | AMP | 4*80GB | 46 | 47 | ## 🧱 Setup 48 | 49 | ```bash 50 | bash setup.sh 51 | ``` 52 | ## 🤗 Prepare Dataset 53 | 54 | ```bash 55 | cd src/data/SAT 56 | bash prepare_dataset.sh 57 | ``` 58 | 59 | ## 🏋️ Training 60 | 61 | ### GRPO Training 62 | To reproduce the multimodal aha moment, run the following code to train the non-SFT model with GRPO on SAT: 63 | ```bash 64 | cd src/open-r1-multimodal 65 | bash run_grpo_SAT.sh # Adjust open-r1-multimodal/configs/zero3.yaml or zero2.yaml accordingly 66 | ``` 67 | 68 | ### SFT Training 69 | To obtain SFT model for comparison, run the following code to train the non-SFT model on SAT: 70 | ```bash 71 | cd src/open-r1-multimodal 72 | bash run_sft.sh # Adjust open-r1-multimodal/configs/zero3.yaml or zero2.yaml accordingly 73 | ``` 74 | 75 | ## 📈 Evaluation 76 | 77 | ### CVBench Evaluation 78 | We provide following commands to reproduce our evaluation results on the CVBench. First change to evaluation directory: 79 | ```bash 80 | cd src/eval 81 | ``` 82 | 83 | To evaluate Base + GRPO (VisualThinker R1 Zero) model: 84 | ```bash 85 | python evaluate_Qwen2_VL_CVBench-base.py --model_path \ 86 | --bs 8 \ 87 | --use_reasoning_prompt 88 | ``` 89 | To evaluate Base model: 90 | ```bash 91 | python evaluate_Qwen2_VL_CVBench-base.py --model_path \ 92 | --bs 8 \ 93 | --no-use_reasoning_prompt 94 | ``` 95 | To evaluate Instruct + GRPO model: 96 | ```bash 97 | python evaluate_Qwen2_VL_CVBench.py --model_path \ 98 | --bs 8 \ 99 | --use_reasoning_prompt 100 | ``` 101 | To evaluate Instruct model: 102 | ```bash 103 | python evaluate_Qwen2_VL_CVBench.py --model_path \ 104 | --bs 8 \ 105 | --no-use_reasoning_prompt 106 | ``` 107 | ## 🔍 Resources 108 | 109 | **Full experiment log:** Upcoming 110 | 111 | **Models CKPT:** [🤗VisualThinker-R1-Zero](https://huggingface.co/turningpoint-ai/VisualThinker-R1-Zero) at huggingface 112 | 113 | ## :coffee: Stay Connected! 114 | 115 | We are always open to engaging discussions, collaborations, or even just sharing a virtual coffee. 
To get in touch or join our team, visit [TurningPoint AI](https://www.turningpoint-ai.com/)'s homepage for contact information. 116 | 117 | ## 📖 Acknowledgements 118 | 119 | We sincerely thank [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1), [Open-R1](https://github.com/huggingface/open-r1), [QwenVL](https://github.com/QwenLM/Qwen2.5-VL), [Open-R1-Multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal), [R1-V](https://github.com/Deep-Agent/R1-V), [SAT](https://arxiv.org/abs/2412.07755), and [CV-Bench](https://cambrian-mllm.github.io/) for providing open source resources that laid the foundation of our project. 120 | 121 | ## 🤝 Contributors 122 | 123 | Here are the key contributors from [TurningPoint AI](https://www.turningpoint-ai.com/) to this project: 124 | 125 | [Hengguang Zhou](https://hengguangzhou.github.io/)1* , [Xirui Li](https://xirui-li.github.io/)1* , [Ruochen Wang](https://ruocwang.github.io/)1, [Minhao Cheng](https://cmhcbb.github.io/)2, [Tianyi Zhou](https://tianyizhou.github.io/)3 and [Cho-Jui Hsieh](https://web.cs.ucla.edu/~chohsieh/)14 126 | 127 | * Project Leads, Main Advisor 128 | 1University of California, Los Angeles, 2Penn State University, 3University of Maryland and 4Google Research 129 | 130 | 131 | ## :white_check_mark: Cite 132 | 133 | If you find our work useful for your projects, please kindly cite the following BibTeX: 134 | 135 | ```latex 136 | @misc{zhou2025r1zerosahamomentvisual, 137 | title={R1-Zero's "Aha Moment" in Visual Reasoning on a 2B Non-SFT Model}, 138 | author={Hengguang Zhou and Xirui Li and Ruochen Wang and Minhao Cheng and Tianyi Zhou and Cho-Jui Hsieh}, 139 | year={2025}, 140 | eprint={2503.05132}, 141 | archivePrefix={arXiv}, 142 | primaryClass={cs.AI}, 143 | url={https://arxiv.org/abs/2503.05132}, 144 | } 145 | ``` 146 | 147 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Open R1 2 | 3 | We conducted a speed-run on to investigate R1's paradigm in multimodal models after observing growing interest in R1 and studying the elegant implementation of the GRPO algorithm in `open-r1` and `trl`. 4 | 5 | [🤗 Models](https://huggingface.co/lmms-lab/Qwen2-VL-2B-GRPO-8k) | [🤗 Datasets](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) | [Wandb Logs](https://api.wandb.ai/links/libo0013/lz60ml8h) 6 | 7 | > [!NOTE] 8 | > Although our insights may not be guaranteed to be correct, we commit to sharing them truthfully and honestly. We welcome community feedback and discussions to improve our understanding on multimodal reasoning models. We will PR to `open-r1` later to better support community study on multimodal RL. 9 | 10 | ![alt text](assets/lmm_r1.png) 11 | 12 | **What We Did** 13 | - Implemented Multimodal R1 based on [huggingface/open-r1](https://github.com/huggingface/open-r1) and [deepseek-ai/DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1). 14 | - Integrated Qwen2-VL series, Aria-MoE, and other VLMs available in `transformers`. 15 | - Open-sourced the first batch of `8k` multimodal RL training examples focused on Math reasoning. The data is created by GPT4o with reasoning paths and verifiable answers, based on `Math360K` and `Geo170K`. We provide a [script](local_scripts/create_vision_cot_data.py) for users to inspect and create their own data. 
16 | - The dataset is available in [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified). 17 | - Open-sourced models trained with GRPO. 18 | - The models are available in [lmms-lab/Qwen2-VL-2B-GRPO-8k](https://huggingface.co/lmms-lab/Qwen2-VL-2B-GRPO-8k) | [lmms-lab/Qwen2-VL-7B-GRPO-8k](https://huggingface.co/lmms-lab/Qwen2-VL-7B-GRPO-8k). 19 | 20 | **Insights and Future Plans** 21 | - Multiple-choice option verification is necessary since many math multimodal problems are MCQs. Discussed in [issue#56](https://github.com/huggingface/open-r1/issues/56) and we customize the verification logic in [src/open_r1/grpo.py](src/open_r1/grpo.py). 22 | - Need to curate RL data to be verifiable, requiring further exploration on effectively converting existing data into RL data and validating GPT4o's curation reliability. 23 | - Current framework is not efficient for large-scale training. Qwen2-VL-2B model takes `10 hours` to train `1 epoch` on `8 H100 GPUs` for `8k samples`. So it's necessary to investigate how to efficiently scale up the training. 24 | - Our init model (Qwen2-VL-2/7B-Instruct) do not show good reasoning ability in our experiments, and during training, the model quickly gather rewards from `format` but not `accuracy`, which is not a good sign for whole RL training. We release our [wandb logs](https://api.wandb.ai/links/libo0013/lz60ml8h) for reference. 25 | 26 | ![image](https://github.com/user-attachments/assets/e0cfca59-3403-4776-97e9-090f2972b903) 27 | 28 | - The community may need to curate better multimodal dataset for RL training. Current dataset is limited to math scenarios since it has verifiable answers. It's unclear how to expand the RL dataset to other general domains with open-ended answer. We welcome community feedback on our current strategy and plan to release a larger dataset if we get clear scaling insights through community discussions. 29 | 30 | 31 | ## Training Models 32 | 33 | > [!NOTE] 34 | > The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps. 35 | 36 | ### GRPO on Qwen2-VL-2/7B 37 | 38 | To run GRPO on Qwen2-VL-2B: 39 | 40 | ``` 41 | cd /home/tiger/multimodal-open-r1 42 | # pip3 install vllm==0.6.6.post1 43 | pip3 install -e ".[dev]" 44 | 45 | pip3 install wandb==0.18.3 46 | 47 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \ # 8 48 | --nnodes="${ARNOLD_WORKER_NUM}" \ # 1 49 | --node_rank="${ARNOLD_ID}" \ # 0 50 | --master_addr="${METIS_WORKER_0_HOST}" \ # 127.0.0.1 51 | --master_port="${port_in_cmd}" \ # 12345 52 | src/open_r1/grpo.py \ 53 | --deepspeed scripts/zero3.json \ 54 | --output_dir checkpoints/Qwen2-VL-2B-GRPO-8k \ 55 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 56 | --dataset_name lmms-lab/multimodal-open-r1-8k-verified \ 57 | --max_prompt_length 8192 \ 58 | --per_device_train_batch_size 1 \ 59 | --gradient_accumulation_steps 1 \ 60 | --logging_steps 1 \ 61 | --bf16 \ 62 | --report_to wandb \ 63 | --gradient_checkpointing true \ 64 | --attn_implementation flash_attention_2 \ 65 | --max_pixels 2359296 \ 66 | --save_total_limit 8 \ 67 | --num_train_epochs 1 \ 68 | --run_name Qwen2-VL-2B-GRPO-8k 69 | ``` 70 | 71 | Please refer to [local_scripts/train_qwen2_vl.sh](local_scripts/train_qwen2_vl.sh) for more details. 72 | 73 | Above scripts are naively for `multi-gpu/multi-node` training. 
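To make the note above concrete, here is a hedged sketch (not a script shipped in this repository) of scaling the same command down to a single node with 4 GPUs: only `--nproc_per_node` and `--gradient_accumulation_steps` change, keeping the effective batch size roughly constant.

```shell
# Hypothetical smaller launch (illustrative only): same arguments as above, but
# 4 GPUs on one node, with gradient accumulation doubled so the effective batch
# size stays roughly the same.
torchrun --nproc_per_node="4" \
    --nnodes="1" \
    --node_rank="0" \
    --master_addr="127.0.0.1" \
    --master_port="12345" \
    src/open_r1/grpo.py \
    --deepspeed scripts/zero3.json \
    --output_dir checkpoints/Qwen2-VL-2B-GRPO-8k \
    --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
    --dataset_name lmms-lab/multimodal-open-r1-8k-verified \
    --max_prompt_length 8192 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --logging_steps 1 \
    --bf16 \
    --report_to wandb \
    --gradient_checkpointing true \
    --attn_implementation flash_attention_2 \
    --max_pixels 2359296 \
    --num_train_epochs 1 \
    --run_name Qwen2-VL-2B-GRPO-8k
```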
74 | 75 | ### Reasoning matters for evaluation 76 | 77 | Many benchmarks, such as MMMU and AI2D, require the model to directly output an answer without providing reasoning steps. This raises a critical issue for evaluation: does the model truly understand how to derive the answer, or is it just guessing or relying on memorization? To address this, we require the model to first generate its reasoning steps before providing the final answer. We then use GPT-4o to extract and score the responses. 78 | 79 | We tested the original Qwen2-VL-2B-Instruct and Qwen2-VL-7B-Instruct models and observed that their scores decreased on certain benchmarks when reasoning steps were included. Subsequently, we compared the scores of our model using the same evaluation method. Our model performed better under the reasoning-based chain-of-thought (CoT) setting. We attribute this improvement to our model's training with GRPO, which appears to enhance its ability to handle reasoning formats and consequently achieve higher scores. 80 | 81 | | Benchmarks | Qwen2-VL-2B-Instruct (w/o reasoning) | Qwen2-VL-2B-Instruct (w/ reasoning) | Qwen2-VL-2B-GRPO-8k (w/ reasoning) | Qwen2-VL-7B-Instruct (w/o reasoning) | Qwen2-VL-7B-Instruct (w/ reasoning) | Qwen2-VL-7B-GRPO-8k (w/ reasoning) | 82 | |----------------|-------------------------------------|------------------------------------|-----------------------------------|-------------------------------------|------------------------------------|-----------------------------------| 83 | | MMMU | 39.7 | 31.2 | 35.22 | 50.8 | 41.9 | 49.4 | 84 | | Mathvista-mini | 51.6 | 48.6 | 49.4 | 57.1 | 60.9 | 60.6 | 85 | 86 | In our logs, we sometimes find that the model still outputs only the answer without the reasoning steps (even for our trained models). We believe this could be because the model is not yet familiar with the reasoning format and cannot decide how to generate it. 87 | 88 | ### Evaluating models 89 | 90 | We use [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) to evaluate models; please run: 91 | 92 | ```shell 93 | bash local_scripts/lmms_eval_qwen2vl.sh 94 | ``` 95 | 96 | To reproduce our results on the above benchmarks, please check out the `dev/qwen_cot` branch. 97 | 98 | Visual reasoning task evaluation is currently limited to direct-answer formats and simple parsing logic. Tasks like `mmmu_val`, `mathvista_testmini`, and `mmmu_pro` expect direct answers rather than reasoning traces, and the current parsing logic cannot process step-by-step reasoning. We are actively working on improving this limitation and welcome community contributions to develop a more comprehensive evaluation framework for visual reasoning models. 99 | 100 | ### RL Data Generation 101 | 102 | We provide the first batch of `8k` multimodal RL training examples focused on math reasoning. The data is generated by GPT4o. We provide the [script](local_scripts/create_vision_cot_data.py) for users to inspect and create their own data. 103 | 104 | Users can view the data in [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified). The `problem`/`solution` fields are generated by GPT4o with a reasoning path and a verifiable answer. The `original question`/`original answer` fields come from the original dataset.
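For a quick look at the released data, the following is a minimal sketch (the printed column names are authoritative; the specific fields mentioned in the comments are assumptions based on the description above, not verified against the dataset card):

```python
from datasets import load_dataset

# Minimal sketch: load the verified 8k multimodal RL split and inspect one record.
ds = load_dataset("lmms-lab/multimodal-open-r1-8k-verified", split="train")

sample = ds[0]
print(sample.keys())  # the actual column names live here
# Per the description above, each record is expected to carry a GPT4o-generated
# problem/solution pair (reasoning path + verifiable answer) together with the
# original question/answer from the source dataset and the associated image.
```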
105 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/grpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import re 17 | from datetime import datetime 18 | from dataclasses import dataclass, field 19 | from typing import Optional 20 | import cv2 21 | import numpy as np 22 | 23 | from datasets import load_dataset, load_from_disk, concatenate_datasets 24 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 25 | 26 | from math_verify import parse, verify 27 | from open_r1.trainer import Qwen2VLGRPOTrainer 28 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config 29 | from PIL import Image 30 | 31 | @dataclass 32 | class GRPOScriptArguments(ScriptArguments): 33 | """ 34 | Script arguments for the GRPO training script. 35 | 36 | Args: 37 | reward_funcs (`list[str]`): 38 | List of reward functions. Possible values: 'accuracy', 'format'. 39 | """ 40 | reward_funcs: list[str] = field( 41 | default_factory=lambda: ["accuracy", "format"], 42 | metadata={"help": "List of reward functions. 
Possible values: 'accuracy', 'format'"}, 43 | ) 44 | max_pixels: Optional[int] = field( 45 | default=12845056, 46 | metadata={"help": "Maximum number of pixels for the image"}, 47 | ) 48 | min_pixels: Optional[int] = field( 49 | default=3136, 50 | metadata={"help": "Minimum number of pixels for the image"}, 51 | ) 52 | freeze_llm: bool = field( 53 | default=False, 54 | metadata={"help": "Whether to freeze the LLM parameters during training"}, 55 | ) 56 | freeze_vision: bool = field( 57 | default=False, 58 | metadata={"help": "Whether to freeze the vision model parameters during training"}, 59 | ) 60 | 61 | 62 | def extract_letters(text): # for RAVEN 63 | pattern = r'(^|\s|\[|\()([A-H])(\s|\]|\)|$)' 64 | matches = re.findall(pattern, text) 65 | return [match[1] for match in matches] 66 | 67 | def accuracy_reward(completions, solution, **kwargs): 68 | """Reward function that checks if the completion is correct using either symbolic verification or exact string matching.""" 69 | if isinstance(completions[0],str): 70 | contents = [completion for completion in completions] 71 | else: 72 | contents = [completion[0]["content"] for completion in completions] 73 | rewards = [] 74 | current_time = datetime.now().strftime("%d-%H-%M-%S-%f") 75 | for content, sol in zip(contents, solution): 76 | reward = 0.0 77 | # Try symbolic verification first 78 | try: 79 | answer = parse(content) 80 | if float(verify(answer, parse(sol))) > 0: 81 | reward = 1.0 82 | except Exception: 83 | pass # Continue to next verification method if this fails 84 | 85 | # If symbolic verification failed, try string matching 86 | if reward == 0.0: 87 | try: 88 | # Extract answer from solution if it has think/answer tags 89 | sol_match = re.search(r'(.*?)', sol) 90 | ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() 91 | 92 | # Extract answer from content if it has think/answer tags 93 | content_match = re.search(r'(.*?)', content) 94 | student_answer = content_match.group(1).strip() if content_match else content.strip() 95 | 96 | if student_answer == ground_truth: 97 | reward = 1.0 98 | 99 | if extract_letters(student_answer)[-1] == ground_truth: 100 | reward = 1.0 101 | # Compare the extracted answers 102 | 103 | except Exception: 104 | pass # Keep reward as 0.0 if both methods fail 105 | 106 | rewards.append(reward) 107 | if os.getenv("DEBUG_MODE") == "true": 108 | log_path = os.getenv("LOG_PATH") 109 | # local_rank = int(os.getenv("LOCAL_RANK", 0)) 110 | with open(log_path, "a") as f: 111 | f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n") 112 | f.write(f"Content: {content}\n") 113 | f.write(f"Solution: {sol}\n") 114 | return rewards 115 | 116 | def length_reward(completions, **kwargs): 117 | processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct") 118 | rewards = [] 119 | for completion in completions: 120 | if isinstance(completions[0],str): 121 | rewards.append(len(processor.tokenizer(completion)['input_ids']) * 0.001) 122 | else: 123 | rewards.append(len(processor.tokenizer(completion[0]["content"])['input_ids']) * 0.001) 124 | return rewards 125 | 126 | def format_reward(completions, **kwargs): 127 | """Reward function that checks if the completion has a specific format.""" 128 | pattern = r".*?\s*.*?" 
129 | if isinstance(completions[0],str): 130 | completion_contents = ["" + completion for completion in completions] 131 | else: 132 | completion_contents = [completion[0]["content"] for completion in completions] 133 | matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents] 134 | return [1.0 if match else 0.0 for match in matches] 135 | 136 | 137 | reward_funcs_registry = { 138 | "accuracy": accuracy_reward, 139 | "format": format_reward, 140 | "length": length_reward, 141 | } 142 | 143 | def main(script_args, training_args, model_args): 144 | # Get reward functions 145 | reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] 146 | 147 | # Check if the model is a base model 148 | base_model_prompt = False 149 | if model_args.model_name_or_path.split("/")[-1] == "Qwen2-VL-2B" or "Base" in model_args.model_name_or_path: 150 | base_model_prompt = True 151 | 152 | QUESTION_TEMPLATE = "{Question} Output the thinking process in and final answer (option) in tags." 153 | 154 | def make_conversation_sat(example, base_model_prompt=False): 155 | if base_model_prompt: 156 | image = Image.open(dataset_prefix + example["images"][0]) 157 | question = example["messages"][0]["content"] 158 | question = question.replace("", "") 159 | prompt = f'A conversation between User and Assistant. The user asks a question about the image, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: {question} \nAssistant: Let me solve this step by step.\n' 160 | 161 | return {"image": image, 162 | "prompt": [ 163 | {"type": "image"}, 164 | {"type": "text", "text": "" + prompt}], 165 | "solution": "" + example["messages"][1]["content"] + "", 166 | } 167 | else: 168 | image = Image.open(dataset_prefix + example["images"][0]) 169 | return {"image": image, 170 | "image_path": example["images"][0], 171 | "prompt": [ 172 | { 173 | "role": "user", 174 | "content": [ 175 | {"type": "image"}, 176 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["messages"][0]["content"])}, 177 | ], 178 | }, 179 | ], 180 | "solution": "" + example["messages"][1]["content"] + "", 181 | } 182 | dataset_prefix = "../data/SAT/" 183 | dataset_path = "SAT_train_15000.json" 184 | 185 | import json 186 | # load json file 187 | with open(dataset_prefix + dataset_path, 'r') as f: 188 | sat_dataset = json.load(f) 189 | 190 | dataset = [make_conversation_sat(sample, base_model_prompt) for sample in sat_dataset] 191 | dataset = {'train': dataset} 192 | 193 | trainer_cls = Qwen2VLGRPOTrainer 194 | 195 | # Initialize the GRPO trainer 196 | trainer = trainer_cls( 197 | model=model_args.model_name_or_path, 198 | reward_funcs=reward_funcs, 199 | args=training_args, 200 | train_dataset=dataset[script_args.dataset_train_split], 201 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 202 | peft_config=get_peft_config(model_args), 203 | attn_implementation=model_args.attn_implementation, 204 | max_pixels=script_args.max_pixels, 205 | min_pixels=script_args.min_pixels, 206 | ) 207 | 208 | if script_args.freeze_vision: 209 | trainer.model.visual.requires_grad_ = False 210 | elif script_args.freeze_llm: 211 | trainer.model.model.requires_grad_ = False 212 | # Train and push the model to the Hub 213 | trainer.train() 214 | 215 | # Save and push to hub 216 | trainer.save_model(training_args.output_dir) 217 | if training_args.push_to_hub: 218 | 
trainer.push_to_hub(dataset_name=script_args.dataset_name) 219 | 220 | 221 | if __name__ == "__main__": 222 | parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig)) 223 | script_args, training_args, model_args = parser.parse_args_and_config() 224 | main(script_args, training_args, model_args) 225 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Supervised fine-tuning script for decoder language models. 17 | 18 | Usage: 19 | 20 | # One 1 node of 8 x H100s 21 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \ 22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ 23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \ 24 | --learning_rate 2.0e-5 \ 25 | --num_train_epochs 1 \ 26 | --packing \ 27 | --max_seq_length 4096 \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --gradient_checkpointing \ 31 | --bf16 \ 32 | --logging_steps 5 \ 33 | --eval_strategy steps \ 34 | --eval_steps 100 \ 35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill 36 | """ 37 | import torch 38 | 39 | from datasets import load_dataset 40 | from transformers import AutoTokenizer, AutoProcessor, default_data_collator 41 | from qwen_vl_utils import process_vision_info 42 | 43 | from trl import ( 44 | ModelConfig, 45 | ScriptArguments, 46 | SFTConfig, 47 | SFTTrainer, 48 | TrlParser, 49 | get_kbit_device_map, 50 | get_peft_config, 51 | get_quantization_config, 52 | ) 53 | 54 | from transformers import AutoModelForCausalLM, Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2VLConfig, Qwen2VLForConditionalGeneration 55 | from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor 56 | AutoModelForCausalLM.register(config_class=Qwen2_5_VLConfig, model_class=Qwen2_5_VLForConditionalGeneration) 57 | AutoModelForCausalLM.register(config_class=Qwen2VLConfig, model_class=Qwen2VLForConditionalGeneration) 58 | 59 | from torch.utils.data import Dataset 60 | 61 | from PIL import Image 62 | 63 | class CustomDataset(Dataset): 64 | def __init__(self, data_list): 65 | self.data = data_list 66 | 67 | def __len__(self): 68 | return len(self.data) 69 | 70 | def __getitem__(self, idx): 71 | return self.data[idx] 72 | 73 | SYSTEM_PROMPT = ( 74 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " 75 | "first thinks about the reasoning process in the mind and then provides the user with the answer. 
The reasoning " 76 | "process and answer are enclosed within and tags, respectively, i.e., " 77 | " reasoning process here answer here " 78 | ) 79 | CHAT_TEMPLATE = { 80 | "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" 81 | } 82 | 83 | # oracle answer 84 | 85 | def main(script_args, training_args, model_args): 86 | ################ 87 | # Model init kwargs & Tokenizer 88 | ################ 89 | quantization_config = get_quantization_config(model_args) 90 | model_kwargs = dict( 91 | revision=model_args.model_revision, 92 | trust_remote_code=model_args.trust_remote_code, 93 | attn_implementation=model_args.attn_implementation, 94 | torch_dtype=model_args.torch_dtype, 95 | use_cache=False if training_args.gradient_checkpointing else True, 96 | device_map=get_kbit_device_map() if quantization_config is not None else None, 97 | quantization_config=quantization_config, 98 | ) 99 | training_args.model_init_kwargs = model_kwargs 100 | tokenizer = AutoTokenizer.from_pretrained( 101 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True 102 | ) 103 | tokenizer.pad_token = tokenizer.eos_token 104 | 105 | ################ 106 | # Dataset 107 | ################ 108 | # dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 109 | 110 | if script_args.dataset_name == "SAT": 111 | def make_conversation_sat(example): 112 | return [ 113 | { 114 | "role": "user", 115 | "content": [ 116 | { 117 | "type": "image", 118 | "image": dataset_prefix + example["images"][0], 119 | }, 120 | { 121 | "type": "text", 122 | "text": example["messages"][0]["content"], 123 | }, 124 | ], 125 | }, 126 | { 127 | "role": "assistant", 128 | "content": [{"type": "text", "text": example["messages"][1]["content"]}], 129 | }, 130 | ] 131 | 132 | dataset_prefix = "../data/SAT/" 133 | dataset_path = "SAT_train_15000.json" 134 | 135 | import json 136 | # load json file 137 | with open(dataset_prefix + dataset_path, 'r') as f: 138 | sat_dataset = json.load(f) 139 | # import pdb; pdb.set_trace() 140 | dataset = [make_conversation_sat(sample) for sample in sat_dataset] 141 | 142 | print("Dataset is ready") 143 | 144 | dataset = CustomDataset(dataset) 145 | 146 | # import pdb; pdb.set_trace() 147 | 148 | ################ 149 | # Define processor 150 | ################ 151 | def collate_fn(examples): 152 | # Get the texts and images, and apply the chat template 153 | texts = [ 154 | processor.apply_chat_template(example, CHAT_TEMPLATE['chat_template'], tokenize=False) for example in examples 155 | ] 
# Prepare texts for processing 156 | image_inputs = [process_vision_info(example)[0] for example in examples] 157 | 158 | # Tokenize the texts and process the images 159 | batch = processor( 160 | text=texts, images=image_inputs, return_tensors="pt", padding=True 161 | ) # Encode texts and images into tensors 162 | 163 | # The labels are the input_ids, and we mask the padding tokens in the loss computation 164 | labels = batch["input_ids"].clone() # Clone input IDs for labels 165 | labels[labels == processor.tokenizer.pad_token_id] = -100 # Mask padding tokens in labels 166 | 167 | # Ignore the image token index in the loss computation (model specific) 168 | if isinstance(processor, Qwen2VLProcessor): # Check if the processor is Qwen2VLProcessor 169 | image_tokens = [151652, 151653, 151655] # Specific image token IDs for Qwen2VLProcessor 170 | else: 171 | image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] # Convert image token to ID 172 | 173 | # Mask image token IDs in the labels 174 | for image_token_id in image_tokens: 175 | labels[labels == image_token_id] = -100 # Mask image token IDs in labels 176 | 177 | batch["labels"] = labels # Add labels to the batch 178 | 179 | return batch 180 | 181 | ################ 182 | # Training 183 | ################ 184 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, 185 | torch_dtype=torch.bfloat16, 186 | attn_implementation="flash_attention_2", 187 | ) 188 | 189 | max_pixels = 512*28*28 190 | model.visual.requires_grad_ = True 191 | processor = Qwen2VLProcessor.from_pretrained(model_args.model_name_or_path, max_pixels=max_pixels, padding_side='right') 192 | processor.chat_template = CHAT_TEMPLATE["chat_template"] 193 | 194 | training_args.model_init_kwargs = None 195 | training_args.dataset_text_field = "" 196 | training_args.dataset_kwargs = {"skip_prepare_dataset": True} 197 | 198 | trainer = SFTTrainer( 199 | model=model, 200 | args=training_args, 201 | train_dataset=dataset, 202 | eval_dataset=None, 203 | peft_config=get_peft_config(model_args), 204 | tokenizer=processor.tokenizer, 205 | data_collator=collate_fn, 206 | ) 207 | 208 | trainer.train() 209 | 210 | # Save and push to hub 211 | trainer.save_model(training_args.output_dir) 212 | if training_args.push_to_hub: 213 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 214 | 215 | 216 | if __name__ == "__main__": 217 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) 218 | script_args, training_args, model_args = parser.parse_args_and_config() 219 | # print(training_args) 220 | main(script_args, training_args, model_args) 221 | -------------------------------------------------------------------------------- /src/eval/evaluate_Qwen2_VL_CVBench.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 4 | from qwen_vl_utils import process_vision_info 5 | import torch 6 | import json 7 | from tqdm import tqdm 8 | import re 9 | from PIL import Image 10 | from math_verify import parse, verify, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig 11 | from datasets import load_dataset 12 | 13 | import argparse 14 | import os 15 | 16 | def parse_arguments(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--model_path', required=True, type=str) # Base model: "Qwen/Qwen2-VL-2B-Instruct" 19 | parser.add_argument('--bs', default=8, 
type=int) # batch size: reduce it if GPU OOM 20 | parser.add_argument('--output_dir', default="results", type=str) 21 | parser.add_argument("--precomputed_json", type=str) 22 | parser.add_argument("--use_reasoning_prompt", default=True, action=argparse.BooleanOptionalAction) 23 | 24 | return parser.parse_args() 25 | 26 | def extract_answer(output_str): 27 | # Try to find the answer within <answer> tags; if it cannot be found, return None 28 | answer_pattern = r"<answer>\s*(.*?)\s*</answer>" 29 | match = re.search(answer_pattern, output_str) 30 | 31 | if match: 32 | return match.group(1) 33 | return None 34 | 35 | def extract_characters_regex(s, choices=['(A)', '(B)', '(C)', '(D)', '(E)', '(F)']): 36 | if type(s) is dict: 37 | s = '' 38 | s = s.strip() 39 | answer_prefixes = [ 40 | 'The best answer is', 41 | 'The correct answer is', 42 | 'The answer is', 43 | 'The answer', 44 | 'The best option is', 45 | 'The correct option is', 46 | 'Best answer:', 47 | 'Best option:', 48 | ] 49 | for answer_prefix in answer_prefixes: 50 | s = s.replace(answer_prefix, '') 51 | 52 | if len(s.split()) > 10 and not re.search('[ABCDEF]', s): 53 | return '' 54 | matches = re.search(r'[ABCDEF]', s) 55 | if matches is None: 56 | for choice in choices: 57 | if s.lower() in choice.lower(): 58 | return choice[1] 59 | return '' 60 | return matches[0] 61 | 62 | if __name__ == "__main__": 63 | cv_bench = load_dataset("nyu-visionx/CV-Bench", split="test") 64 | 65 | args = parse_arguments() 66 | MODEL_PATH = args.model_path 67 | BSZ = args.bs 68 | OUTPUT_DIR = args.output_dir 69 | PRECOMPUTED_RESULT = args.precomputed_json 70 | 71 | correct_counter = 0 72 | counter_task = { 73 | 'Count': 0, 74 | 'Relation': 0, 75 | 'Depth': 0, 76 | 'Distance': 0 77 | } 78 | 79 | counter_correct = { 80 | 'Count': 0, 81 | 'Relation': 0, 82 | 'Depth': 0, 83 | 'Distance': 0 84 | } 85 | final_output = [] 86 | if not PRECOMPUTED_RESULT: 87 | # After SFT on the base model the chat template is not saved (bug in the current transformers version), so fetch it if it is missing. 88 | # Check template 89 | template_path = os.path.join(MODEL_PATH, 'chat_template.json') # a leading slash here would make os.path.join discard MODEL_PATH 90 | if os.path.exists(template_path): 91 | print(f"Template found at {template_path}.") 92 | else: 93 | from huggingface_hub import hf_hub_download 94 | hf_hub_download(repo_id="Qwen/Qwen2-VL-2B-Instruct", filename='chat_template.json', local_dir=MODEL_PATH) 95 | print(f"Template downloaded for {MODEL_PATH}.") 96 | 97 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. 98 | model = Qwen2VLForConditionalGeneration.from_pretrained( 99 | MODEL_PATH, 100 | torch_dtype=torch.bfloat16, 101 | attn_implementation="flash_attention_2", 102 | device_map="auto", 103 | ) 104 | 105 | model.eval() 106 | # default processor 107 | processor = AutoProcessor.from_pretrained(MODEL_PATH) 108 | processor.tokenizer.padding_side = 'left' 109 | 110 | resp_messages = [] 111 | 112 | for i, example in tqdm(enumerate(cv_bench)): 113 | if args.use_reasoning_prompt: 114 | resp_prompt = example['prompt'] + "\nOutput the thinking process in <think> </think> and final answer in <answer> </answer> tags."
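For reference, the two extraction helpers defined earlier in this file behave roughly as follows on typical model outputs; the sample strings are invented for illustration and are not CV-Bench data.

# Illustrative behaviour of the answer-extraction helpers (made-up strings)
sample_output = "<think>The red cube is closer to the camera.</think> <answer>(A) Red cube</answer>"
print(extract_answer(sample_output))                          # "(A) Red cube"
print(extract_characters_regex("The answer is (A)"))          # "A"
print(extract_characters_regex("I would go with option B."))  # "B"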
115 | else: 116 | resp_prompt = example['prompt'] 117 | 118 | resp_message = [ 119 | { 120 | "role": "user", 121 | "content": [ 122 | { 123 | "type": "image", 124 | "image": example['image'], 125 | }, 126 | {"type": "text", "text": resp_prompt}, 127 | ], 128 | } 129 | ] 130 | 131 | resp_messages.append(resp_message) 132 | 133 | 134 | # List to store all answers 135 | all_resp_outputs = [] 136 | # Process data in batches 137 | 138 | def generate_batch(batch_messages): 139 | # Preparation for inference 140 | text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages] 141 | 142 | image_inputs, video_inputs = process_vision_info(batch_messages) 143 | 144 | inputs = processor( 145 | text=text, 146 | images=image_inputs, 147 | videos=video_inputs, 148 | padding=True, 149 | return_tensors="pt", 150 | ) 151 | inputs = inputs.to("cuda") 152 | 153 | # Inference: Generation of the output 154 | with torch.no_grad(): 155 | generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False) 156 | 157 | generated_ids_trimmed = [ 158 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 159 | ] 160 | batch_output_text = processor.batch_decode( 161 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 162 | ) 163 | return batch_output_text 164 | 165 | for i in tqdm(range(0, len(resp_messages), BSZ)): 166 | 167 | batch_resp_output = generate_batch(resp_messages[i:i + BSZ]) 168 | 169 | all_resp_outputs.extend(batch_resp_output) 170 | 171 | else: 172 | with open(PRECOMPUTED_RESULT, "r") as f: 173 | result = json.load(f)['results'][:-1] 174 | all_resp_outputs = [r['response'] for r in result] 175 | 176 | for i, (input_example, model_resp_output) in enumerate(zip(cv_bench, all_resp_outputs)): 177 | # Count correct answers 178 | ground_truth = input_example['answer'] 179 | model_answer = extract_answer(model_resp_output) 180 | 181 | if not model_answer: 182 | short_response = model_resp_output 183 | else: 184 | short_response = model_answer 185 | 186 | if input_example['answer'] == '(A)': 187 | example_answer = input_example['choices'][0] 188 | elif input_example['answer'] == '(B)': 189 | example_answer = input_example['choices'][1] 190 | elif input_example['answer'] == '(C)': 191 | example_answer = input_example['choices'][2] 192 | elif input_example['answer'] == '(D)': 193 | example_answer = input_example['choices'][3] 194 | elif input_example['answer'] == '(E)': 195 | example_answer = input_example['choices'][4] 196 | elif input_example['answer'] == '(F)': 197 | example_answer = input_example['choices'][5] 198 | 199 | parsed_response = parse(short_response, extraction_config=[LatexExtractionConfig(), ExprExtractionConfig(), StringExtractionConfig(strings=("A", "B", "C", "D", "E", "F", 'a', 'b', 'c', 'd', 'e', 'f', '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(A)', '(B)', '(C)', '(D)', '(E)', '(F)') + tuple(input_example['choices']))]) 200 | parsed_answer = [example_answer.lower(), example_answer, ground_truth[1], ground_truth[1].lower(), ground_truth.lower(), ground_truth] 201 | 202 | if verify(target=parsed_response, gold=parsed_answer): 203 | correct = 1 204 | correct_counter += 1 205 | counter_correct[input_example['task']] +=1 206 | else: 207 | correct = 0 208 | 209 | counter_task[input_example['task']] +=1 210 | 211 | result = { 212 | 'question': input_example['question'], 213 | 'options': input_example['choices'], 214 | "task": input_example['task'], 215 | 
'ground_truth': ground_truth, 216 | 'response': model_resp_output, 217 | "model_answer": short_response, 218 | "correct": correct, 219 | } 220 | final_output.append(result) 221 | 222 | acc = {"Total Accuracy:": correct_counter / len(cv_bench)} 223 | acc['Count'] = counter_task['Count'] 224 | acc['Relation'] = counter_task['Relation'] 225 | acc['Depth'] = counter_task['Depth'] 226 | acc['Distance'] = counter_task['Distance'] 227 | 228 | acc['Count_acc'] = counter_correct['Count'] / counter_task['Count'] 229 | acc['Relation_acc'] = counter_correct['Relation'] / counter_task['Relation'] 230 | acc['Depth_acc'] = counter_correct['Depth'] / counter_task['Depth'] 231 | acc['Distance_acc'] = counter_correct['Distance'] / counter_task['Distance'] 232 | 233 | print(acc) 234 | 235 | if not os.path.exists(OUTPUT_DIR): 236 | os.makedirs(OUTPUT_DIR) 237 | 238 | model_name = MODEL_PATH.split('/')[-1] 239 | reasoning_tag = "reasoning" if args.use_reasoning_prompt else "no_reasoning" 240 | with open(os.path.join(OUTPUT_DIR, f"CVBench_result_{model_name}_{reasoning_tag}.json"), "w") as f: 241 | json.dump({ 242 | 'results': final_output, 243 | "accuracy": acc, 244 | "args": vars(args) 245 | }, f, indent=2) 246 | 247 | 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /src/eval/evaluate_Qwen2_VL_CVBench-base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import torch 4 | import json 5 | import argparse 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from PIL import Image 9 | from math_verify import parse, verify, LatexExtractionConfig, ExprExtractionConfig, StringExtractionConfig 10 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 11 | from qwen_vl_utils import process_vision_info 12 | from datasets import load_dataset 13 | 14 | def parse_arguments(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model_path', required=True, type=str) # Base model: "Qwen/Qwen2-VL-2B-Instruct" 17 | parser.add_argument('--bs', default=32, type=int) # Batch size: reduce it if GPU OOM 18 | parser.add_argument('--output_dir', default="results", type=str) 19 | parser.add_argument("--precomputed_json", type=str) 20 | parser.add_argument("--use_reasoning_prompt", default=True, action=argparse.BooleanOptionalAction) 21 | 22 | return parser.parse_args() 23 | 24 | def extract_answer(output_str): 25 | # Try to find the number within tags, if can not find, return None 26 | answer_pattern = r"\s*(.*?)\s*" 27 | match = re.search(answer_pattern, output_str) 28 | 29 | if match: 30 | return match.group(1) 31 | return None 32 | 33 | def extract_characters_regex(s, choices=['(A)', '(B)', '(C)', '(D)', '(E)', '(F)']): 34 | if type(s) is dict: 35 | s = '' 36 | s = s.strip() 37 | answer_prefixes = [ 38 | 'The best answer is', 39 | 'The correct answer is', 40 | 'The answer is', 41 | 'The answer', 42 | 'The best option is' 43 | 'The correct option is', 44 | 'Best answer:' 45 | 'Best option:', 46 | ] 47 | for answer_prefix in answer_prefixes: 48 | s = s.replace(answer_prefix, '') 49 | 50 | if len(s.split()) > 10 and not re.search('[ABCDEF]', s): 51 | return '' 52 | matches = re.search(r'[ABCDEF]', s) 53 | if matches is None: 54 | for choice in choices: 55 | if s.lower() in choice.lower(): 56 | return choice[1] 57 | return '' 58 | return matches[0] 59 | def load_images(messsages): 60 | images = [] 61 | for message in messsages: 62 | for item in message: 63 | if 
item['type'] == 'image': 64 | if type(item['image']) == str: 65 | image_path = item['image'] 66 | image = Image.open(image_path) 67 | images.append(image) 68 | else: 69 | images.append(item['image']) 70 | return images 71 | if __name__ == "__main__": 72 | cv_bench = load_dataset("nyu-visionx/CV-Bench", split="test") 73 | 74 | args = parse_arguments() 75 | MODEL_PATH=args.model_path 76 | BSZ=args.bs 77 | OUTPUT_DIR=args.output_dir 78 | PRECOMPUTED_RESULT=args.precomputed_json 79 | 80 | correct_counter = 0 81 | counter_task = { 82 | 'Count': 0, 83 | 'Relation': 0, 84 | 'Depth': 0, 85 | 'Distance': 0 86 | } 87 | 88 | counter_correct = { 89 | 'Count': 0, 90 | 'Relation': 0, 91 | 'Depth': 0, 92 | 'Distance': 0 93 | } 94 | final_output = [] 95 | if not PRECOMPUTED_RESULT: 96 | #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. 97 | model = Qwen2VLForConditionalGeneration.from_pretrained( 98 | MODEL_PATH, 99 | torch_dtype=torch.bfloat16, 100 | attn_implementation="flash_attention_2", 101 | device_map="auto", 102 | ) 103 | 104 | model.eval() 105 | model = torch.nn.DataParallel(model) 106 | 107 | model.module = torch.compile(model.module) 108 | 109 | # default processer 110 | processor = AutoProcessor.from_pretrained(MODEL_PATH) 111 | # processor.tokenizer.padding_side = 'left' 112 | 113 | resp_messages = [] 114 | 115 | for i, example in tqdm(enumerate(cv_bench)): 116 | 117 | question = example['prompt'] 118 | 119 | if args.use_reasoning_prompt: 120 | res_prompt = f'A conversation between User and Assistant. The user asks a question about the image, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: {question} \nAssistant: Let me solve this step by step.\n' 121 | else: 122 | res_prompt = f'A conversation between User and Assistant. 
The user asks a question about the image, and the Assistant solves it.\nUser: {question} \nAssistant: ' 123 | 124 | resp_message = [ 125 | 126 | { 127 | "type": "image", 128 | "image": example['image'], 129 | }, 130 | {"type": "text", "text": "" + res_prompt}, 131 | ] 132 | 133 | resp_messages.append(resp_message) 134 | 135 | # List to store all answers 136 | all_resp_outputs = [] 137 | 138 | # Process data in batches 139 | def generate_batch(batch_messages): 140 | 141 | # Preparation for inference 142 | text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages] 143 | 144 | images = load_images(batch_messages) 145 | inputs = processor( 146 | text=text, 147 | images=images, 148 | padding=True, 149 | return_tensors="pt", 150 | ) 151 | inputs = inputs.to(model.module.device) 152 | 153 | with torch.no_grad(): 154 | generated_ids = model.module.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False, temperature=1) 155 | 156 | generated_ids_trimmed = [ 157 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 158 | ] 159 | batch_output_text = processor.batch_decode( 160 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 161 | ) 162 | return batch_output_text 163 | 164 | for i in tqdm(range(0, len(resp_messages), BSZ)): 165 | batch_messages = resp_messages[i:i + BSZ] 166 | 167 | batch_resp_output = generate_batch(resp_messages[i:i + BSZ]) 168 | 169 | all_resp_outputs.extend(batch_resp_output) 170 | print(f"Processed batch {i//BSZ + 1}/{(len(resp_messages) + BSZ - 1)//BSZ}") 171 | 172 | else: 173 | 174 | with open(PRECOMPUTED_RESULT, "r") as f: 175 | result = json.load(f)['results'][:-1] 176 | all_resp_outputs = [r['response'] for r in result] 177 | 178 | for i, (input_example, model_resp_output) in enumerate(zip(cv_bench, all_resp_outputs)): 179 | # Count correct answers 180 | ground_truth = input_example['answer'] 181 | model_answer = extract_answer(model_resp_output) 182 | 183 | if not model_answer: 184 | short_response = model_resp_output 185 | else: 186 | short_response = model_answer 187 | 188 | if input_example['answer'] == '(A)': 189 | example_answer = input_example['choices'][0] 190 | elif input_example['answer'] == '(B)': 191 | example_answer = input_example['choices'][1] 192 | elif input_example['answer'] == '(C)': 193 | example_answer = input_example['choices'][2] 194 | elif input_example['answer'] == '(D)': 195 | example_answer = input_example['choices'][3] 196 | elif input_example['answer'] == '(E)': 197 | example_answer = input_example['choices'][4] 198 | elif input_example['answer'] == '(F)': 199 | example_answer = input_example['choices'][5] 200 | 201 | parsed_response = parse(short_response, extraction_config=[LatexExtractionConfig(), ExprExtractionConfig(), StringExtractionConfig(strings=("A", "B", "C", "D", "E", "F", 'a', 'b', 'c', 'd', 'e', 'f', '(a)', '(b)', '(c)', '(d)', '(e)', '(f)', '(A)', '(B)', '(C)', '(D)', '(E)', '(F)') + tuple(input_example['choices']))]) 202 | parsed_answer = [example_answer.lower(), example_answer, ground_truth[1], ground_truth[1].lower(), ground_truth.lower(), ground_truth] 203 | 204 | if verify(target=parsed_response, gold=parsed_answer): 205 | correct = 1 206 | correct_counter += 1 207 | counter_correct[input_example['task']] +=1 208 | else: 209 | correct = 0 210 | 211 | counter_task[input_example['task']] +=1 212 | 213 | result = { 214 | 'question': input_example['question'], 215 | 'options': 
input_example['choices'], 216 | "task": input_example['task'], 217 | 'ground_truth': ground_truth, 218 | 'response': model_resp_output, 219 | "model_answer": short_response, 220 | "correct": correct, 221 | } 222 | final_output.append(result) 223 | 224 | acc = {"Total Accuracy:": correct_counter / len(cv_bench)} 225 | acc['Count'] = counter_task['Count'] 226 | acc['Relation'] = counter_task['Relation'] 227 | acc['Depth'] = counter_task['Depth'] 228 | acc['Distance'] = counter_task['Distance'] 229 | 230 | acc['Count_acc'] = counter_correct['Count'] / counter_task['Count'] 231 | acc['Relation_acc'] = counter_correct['Relation'] / counter_task['Relation'] 232 | acc['Depth_acc'] = counter_correct['Depth'] / counter_task['Depth'] 233 | acc['Distance_acc'] = counter_correct['Distance'] / counter_task['Distance'] 234 | 235 | print(acc) 236 | 237 | if not os.path.exists(OUTPUT_DIR): 238 | os.makedirs(OUTPUT_DIR) 239 | 240 | model_name = MODEL_PATH.split('/')[-1] 241 | reasoning_tag = "reasoning" if args.use_reasoning_prompt else "no_reasoning" 242 | with open(os.path.join(OUTPUT_DIR, f"CVBench_result_{model_name}_{reasoning_tag}.json"), "w") as f: 243 | json.dump({ 244 | 'results': final_output, 245 | "accuracy": acc, 246 | "args": vars(args) 247 | }, f, indent=2) 248 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import textwrap 17 | from collections import defaultdict 18 | from typing import Any, Callable, Optional, Union 19 | 20 | import torch 21 | import torch.utils.data 22 | import transformers 23 | from datasets import Dataset, IterableDataset 24 | from packaging import version 25 | from transformers import ( 26 | AutoModel, 27 | AriaForConditionalGeneration, 28 | AriaProcessor, 29 | AutoModelForCausalLM, 30 | AutoModelForSequenceClassification, 31 | AutoProcessor, 32 | AutoTokenizer, 33 | GenerationConfig, 34 | PreTrainedModel, 35 | PreTrainedTokenizerBase, 36 | Qwen2VLForConditionalGeneration, 37 | Trainer, 38 | TrainerCallback, 39 | is_wandb_available, 40 | ) 41 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 42 | from transformers.utils import is_peft_available 43 | 44 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template 45 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation 46 | from trl.trainer.grpo_config import GRPOConfig 47 | from trl.trainer.utils import generate_model_card, get_comet_experiment_url 48 | 49 | import copy 50 | import json 51 | 52 | # from InternVL2 import load_image 53 | 54 | 55 | if is_peft_available(): 56 | from peft import PeftConfig, get_peft_model 57 | 58 | if is_wandb_available(): 59 | import wandb 60 | 61 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of 62 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. 63 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] 64 | 65 | 66 | class Qwen2VLGRPOTrainer(Trainer): 67 | """ 68 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the 69 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). 70 | 71 | Example: 72 | 73 | ```python 74 | from datasets import load_dataset 75 | from trl import GRPOTrainer 76 | 77 | dataset = load_dataset("trl-lib/tldr", split="train") 78 | 79 | trainer = GRPOTrainer( 80 | model="Qwen/Qwen2-0.5B-Instruct", 81 | reward_funcs="weqweasdas/RM-Gemma-2B", 82 | train_dataset=dataset, 83 | ) 84 | 85 | trainer.train() 86 | ``` 87 | 88 | Args: 89 | model (`Union[str, PreTrainedModel]`): 90 | Model to be trained. Can be either: 91 | 92 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or 93 | a path to a *directory* containing model weights saved using 94 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is 95 | loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments 96 | in `args.model_init_kwargs`. 97 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. 98 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): 99 | Reward functions to be used for computing the rewards. 
To compute the rewards, we call all the reward 100 | functions with the prompts and completions and sum the rewards. Can be either: 101 | 102 | - A single reward function, such as: 103 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a 104 | path to a *directory* containing model weights saved using 105 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded 106 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the 107 | keyword arguments in `args.model_init_kwargs`. 108 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. 109 | - A custom reward function: The function is provided with the prompts and the generated completions, 110 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see 111 | [Using a custom reward function](#using-a-custom-reward-function). 112 | - A list of reward functions, where each item can independently be any of the above types. Mixing different 113 | types within the list (e.g., a string model ID and a custom reward function) is allowed. 114 | args ([`GRPOConfig`], *optional*, defaults to `None`): 115 | Configuration for this trainer. If `None`, a default configuration is used. 116 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): 117 | Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is 118 | ignored. The format of the samples can be either: 119 | 120 | - [Standard](dataset_formats#standard): Each sample contains plain text. 121 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role 122 | and content). 123 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): 124 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. 125 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): 126 | Processing class used to process the data. The padding side must be set to "left". If `None`, the 127 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. 128 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): 129 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: 130 | 131 | - A single processing class: Used when `reward_funcs` contains only one reward function. 132 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. 133 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is 134 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. 135 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), 136 | the corresponding entries in `reward_processing_classes` are ignored. 137 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): 138 | List of callbacks to customize the training loop. Will add those to the list of default callbacks 139 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). 
140 | 141 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] 142 | method. 143 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): 144 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your 145 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. 146 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): 147 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped. 148 | """ 149 | 150 | def __init__( 151 | self, 152 | model: Union[str, PreTrainedModel], 153 | reward_funcs: Union[RewardFunc, list[RewardFunc]], 154 | args: GRPOConfig = None, 155 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None, 156 | eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, 157 | processing_class: Optional[PreTrainedTokenizerBase] = None, 158 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, 159 | callbacks: Optional[list[TrainerCallback]] = None, 160 | optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), 161 | peft_config: Optional["PeftConfig"] = None, 162 | max_pixels: Optional[int] = 12845056, 163 | min_pixels: Optional[int] = 3136, 164 | attn_implementation: str = "flash_attention_2", 165 | ): 166 | # Args 167 | if args is None: 168 | model_name = model if isinstance(model, str) else model.config._name_or_path 169 | model_name = model_name.split("/")[-1] 170 | args = GRPOConfig(f"{model_name}-GRPO") 171 | 172 | # Models 173 | # Trained model 174 | model_init_kwargs = args.model_init_kwargs or {} 175 | model_init_kwargs["attn_implementation"] = attn_implementation 176 | if isinstance(model, str): 177 | model_id = model 178 | torch_dtype = model_init_kwargs.get("torch_dtype") 179 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: 180 | pass # torch_dtype is already a torch.dtype or "auto" or None 181 | elif isinstance(torch_dtype, str): # it's a str, but not "auto" 182 | torch_dtype = getattr(torch, torch_dtype) 183 | model_init_kwargs["torch_dtype"] = torch_dtype 184 | else: 185 | raise ValueError( 186 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " 187 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 188 | ) 189 | # Disable caching if gradient checkpointing is enabled (not supported) 190 | model_init_kwargs["use_cache"] = ( 191 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") 192 | ) 193 | if "Qwen2-VL" in model_id: 194 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 195 | elif "Aria" in model_id: 196 | model_init_kwargs.pop("use_cache") 197 | model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 198 | elif "InternVL2" in model_id: 199 | model_init_kwargs.pop("use_cache") 200 | model = AutoModel.from_pretrained(model, trust_remote_code=True, **model_init_kwargs) 201 | else: 202 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) 203 | else: 204 | model_id = model.config._name_or_path 205 | if args.model_init_kwargs is not None: 206 | raise ValueError( 207 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. 
" 208 | "This argument can only be used when the `model` argument is a string." 209 | ) 210 | 211 | if peft_config is not None: 212 | model = get_peft_model(model, peft_config) 213 | 214 | # Reference model 215 | if is_deepspeed_zero3_enabled(): 216 | if "Qwen2-VL" in model_id: 217 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 218 | elif "Aria" in model_id: 219 | self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 220 | elif "InternVL2" in model_id: 221 | self.ref_model = AutoModel.from_pretrained(model_id, trust_remote_code=True, **model_init_kwargs) 222 | else: 223 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) 224 | elif peft_config is None: 225 | # If PEFT configuration is not provided, create a reference model based on the initial model. 226 | self.ref_model = create_reference_model(model) 227 | else: 228 | # If PEFT is used, the reference model is not needed since the adapter can be disabled 229 | # to revert to the initial model. 230 | self.ref_model = None 231 | 232 | # Processing class 233 | if processing_class is None: 234 | if "Qwen2-VL" in model_id or "Aria" in model_id: 235 | processing_class = AutoProcessor.from_pretrained(model_id) 236 | pad_token_id = processing_class.tokenizer.pad_token_id 237 | processing_class.pad_token_id = pad_token_id 238 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id 239 | if "Qwen2-VL" in model_id: 240 | processing_class.image_processor.max_pixels = max_pixels 241 | processing_class.image_processor.min_pixels = min_pixels 242 | elif "InternVL2" in model_id: 243 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, trust_remote_code=True, padding_side="left") 244 | pad_token_id = processing_class.pad_token_id 245 | else: 246 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") 247 | pad_token_id = processing_class.pad_token_id 248 | 249 | # Reward functions 250 | if not isinstance(reward_funcs, list): 251 | reward_funcs = [reward_funcs] 252 | for i, reward_func in enumerate(reward_funcs): 253 | if isinstance(reward_func, str): 254 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( 255 | reward_func, num_labels=1, **model_init_kwargs 256 | ) 257 | self.reward_funcs = reward_funcs 258 | 259 | # Reward processing class 260 | if reward_processing_classes is None: 261 | reward_processing_classes = [None] * len(reward_funcs) 262 | elif not isinstance(reward_processing_classes, list): 263 | reward_processing_classes = [reward_processing_classes] 264 | else: 265 | if len(reward_processing_classes) != len(reward_funcs): 266 | raise ValueError("The number of reward processing classes must match the number of reward functions.") 267 | 268 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): 269 | if isinstance(reward_func, PreTrainedModel): 270 | if reward_processing_class is None: 271 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) 272 | if reward_processing_class.pad_token_id is None: 273 | reward_processing_class.pad_token = reward_processing_class.eos_token 274 | # The reward model computes the reward for the latest non-padded token in the input sequence. 275 | # So it's important to set the pad token ID to the padding token ID of the processing class. 
276 | reward_func.config.pad_token_id = reward_processing_class.pad_token_id 277 | reward_processing_classes[i] = reward_processing_class 278 | self.reward_processing_classes = reward_processing_classes 279 | 280 | # Data collator 281 | def data_collator(features): # No data collation is needed in GRPO 282 | return features 283 | 284 | # Training arguments 285 | self.max_prompt_length = args.max_prompt_length 286 | self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper 287 | self.num_generations = args.num_generations # = G in the GRPO paper 288 | self.generation_config = GenerationConfig( 289 | max_new_tokens=self.max_completion_length, 290 | do_sample=True, 291 | temperature=1, # HACK 292 | num_return_sequences=self.num_generations, 293 | pad_token_id=pad_token_id, 294 | ) 295 | self.beta = args.beta 296 | 297 | # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the 298 | # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the 299 | # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: 300 | # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To 301 | # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. 302 | # This acts as a flag to indicate that the warning has already been issued. 303 | model.warnings_issued["estimate_tokens"] = True 304 | 305 | # Initialize the metrics 306 | self._metrics = defaultdict(list) 307 | 308 | super().__init__( 309 | model=model, 310 | args=args, 311 | data_collator=data_collator, 312 | train_dataset=train_dataset, 313 | eval_dataset=eval_dataset, 314 | processing_class=processing_class, 315 | callbacks=callbacks, 316 | optimizers=optimizers, 317 | ) 318 | 319 | # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the 320 | # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set 321 | # self.model_accepts_loss_kwargs to False to enable scaling. 322 | self.model_accepts_loss_kwargs = False 323 | 324 | if self.ref_model is not None: 325 | if self.is_deepspeed_enabled: 326 | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) 327 | else: 328 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 329 | 330 | for i, reward_func in enumerate(self.reward_funcs): 331 | if isinstance(reward_func, PreTrainedModel): 332 | self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) 333 | 334 | def _set_signature_columns_if_needed(self): 335 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed. 336 | # By default, this method sets `self._signature_columns` to the model's expected inputs. 337 | # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. 338 | # Instead, we set them to the columns expected by the `training_step` method, hence the override. 
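A small, self-contained sketch of the shift-and-gather computation that the _get_per_token_logps helper below performs: the logits at position t score token t+1, so the logits are shifted by one and the log-softmax is gathered at the realized token ids. Shapes and values here are illustrative only.

# Toy version of the shift-and-gather used by _get_per_token_logps (illustrative shapes)
import torch

logits = torch.randn(2, 5, 11)               # (B, L, V) as returned by the model
input_ids = torch.randint(0, 11, (2, 5))     # (B, L)

shifted_logits = logits[:, :-1, :]           # position t predicts token t+1
targets = input_ids[:, 1:]                   # align targets with the shifted logits
log_probs = shifted_logits.log_softmax(dim=-1)
per_token_logps = torch.gather(log_probs, dim=2, index=targets.unsqueeze(-1)).squeeze(-1)  # (B, L-1)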
339 | if self._signature_columns is None: 340 | self._signature_columns = ["prompt"] 341 | 342 | 343 | # Get the per-token log probabilities for the completions for the model and the reference model 344 | def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw): 345 | logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V) 346 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 347 | input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it 348 | # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. 349 | per_token_logps = [] 350 | for logits_row, input_ids_row in zip(logits, input_ids): 351 | log_probs = logits_row.log_softmax(dim=-1) 352 | token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) 353 | per_token_logps.append(token_log_prob) 354 | return torch.stack(per_token_logps) 355 | 356 | 357 | # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device. 358 | # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step. 359 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: 360 | return inputs 361 | 362 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 363 | if return_outputs: 364 | raise ValueError("The GRPOTrainer does not support returning outputs") 365 | 366 | prompts = [x["prompt"] for x in inputs] 367 | prompts_text = [apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] 368 | images = [x["image"] for x in inputs] 369 | prompt_inputs = self.processing_class( 370 | text=prompts_text, 371 | images=images, 372 | return_tensors="pt", 373 | padding=True, 374 | padding_side="left", 375 | add_special_tokens=False, 376 | ) 377 | prompt_inputs = super()._prepare_inputs(prompt_inputs) 378 | 379 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] 380 | pixel_values = prompt_inputs["pixel_values"] 381 | image_grid_thw = prompt_inputs["image_grid_thw"] 382 | 383 | 384 | if self.max_prompt_length is not None: 385 | prompt_ids = prompt_ids[:, -self.max_prompt_length :] 386 | prompt_mask = prompt_mask[:, -self.max_prompt_length :] 387 | 388 | # Generate completions 389 | with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: 390 | #prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config) 391 | # Generate N times, each generate one with the temp_generation_config , stack the output_ids to prompt_completion_ids, pad the empty places with number 151613 392 | num_generations = self.generation_config.num_return_sequences 393 | temp_generation_config = copy.deepcopy(self.generation_config) 394 | temp_generation_config.num_return_sequences = 1 395 | 396 | all_completions = [] 397 | 398 | for i in range(num_generations): # -1 because we already have one generation 399 | completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config) 400 | all_completions.append(completion) 401 | 402 | # Stack all completions and pad if needed 403 | max_length = max(completion.size(1) for completion in all_completions) 404 | padded_completions = [] 405 | 406 | for completion in 
all_completions: 407 | if completion.size(1) < max_length: 408 | padding = torch.full((completion.size(0), max_length - completion.size(1)), 409 | self.processing_class.tokenizer.pad_token_id, 410 | dtype=completion.dtype, 411 | device=completion.device) 412 | padded_completion = torch.cat([completion, padding], dim=1) 413 | else: 414 | padded_completion = completion 415 | padded_completions.append(padded_completion) 416 | 417 | # Stack all padded completions 418 | prompt_completion_ids = torch.cat(padded_completions, dim=0) 419 | 420 | prompt_length = prompt_ids.size(1) 421 | prompt_ids = prompt_completion_ids[:, :prompt_length] 422 | completion_ids = prompt_completion_ids[:, prompt_length:] 423 | prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0) 424 | 425 | # Mask everything after the first EOS token 426 | is_eos = completion_ids == self.processing_class.eos_token_id 427 | device = self.accelerator.device 428 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) 429 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] 430 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) 431 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() 432 | 433 | # Concatenate prompt_mask with completion_mask for logit computation 434 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) 435 | pixel_values = prompt_inputs["pixel_values"][None].repeat_interleave(self.num_generations, dim=0) 436 | image_grid_thw = prompt_inputs["image_grid_thw"].repeat_interleave(self.num_generations, dim=0) 437 | 438 | per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw) 439 | # Get rid of the prompt (-1 because of the shift done in get_per_token_logps) 440 | per_token_logps = per_token_logps[:, prompt_length - 1 :] 441 | 442 | with torch.inference_mode(): 443 | if self.ref_model is not None: 444 | ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw) 445 | else: 446 | with self.accelerator.unwrap_model(model).disable_adapter(): 447 | ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw) 448 | ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :] 449 | 450 | # Compute the KL divergence between the model and the reference model 451 | per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 452 | 453 | # Decode the generated completions 454 | completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) 455 | if is_conversational(inputs[0]): 456 | completions = [[{"role": "assistant", "content": completion}] for completion in completions] 457 | 458 | # Compute the rewards 459 | prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] 460 | 461 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) 462 | for i, (reward_func, reward_processing_class) in enumerate( 463 | zip(self.reward_funcs, self.reward_processing_classes) 464 | ): 465 | if isinstance(reward_func, PreTrainedModel): 466 | if is_conversational(inputs[0]): 467 | messages = [{"messages": p + c} for p, c in zip(prompts, completions)] 468 | texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] 469 | else: 470 | texts = [p 
+ c for p, c in zip(prompts, completions)] 471 | reward_inputs = reward_processing_class( 472 | texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False 473 | ) 474 | reward_inputs = super()._prepare_inputs(reward_inputs) 475 | with torch.inference_mode(): 476 | rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) 477 | else: 478 | # Repeat all input columns (but "prompt" and "completion") to match the number of generations 479 | reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]} 480 | for key in reward_kwargs: 481 | for example in inputs: 482 | # Repeat each value in the column for `num_generations` times 483 | reward_kwargs[key].extend([example[key]] * self.num_generations) 484 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) 485 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) 486 | 487 | # Sum the rewards from all reward functions 488 | rewards = rewards_per_func.sum(dim=1) 489 | 490 | # Compute grouped-wise rewards 491 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) 492 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) 493 | 494 | # Normalize the rewards to compute the advantages 495 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 496 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 497 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) 498 | 499 | log_trajectories = True 500 | if log_trajectories: 501 | save_dir = f"trajectories/trajectories_{self.args.run_name}/step{self.state.global_step}" 502 | os.makedirs(save_dir, exist_ok=True) 503 | save_path = os.path.join(save_dir, f"rank{self.accelerator.process_index}.jsonl") 504 | with open(save_path, "w") as f: 505 | json.dump({ 506 | 'trajectories': [{"messages": {"prompt": p[0], "response": c[0]} if len(p) == 1 else {"prompt": p[1]['text'], "response":c}, "solution": inputs[0]['solution'], "reward": r} for p, c, r in zip(prompts, completions, rewards.view(self.num_generations).tolist())], 507 | }, f, indent=2) 508 | # x - x.detach() allows for preserving gradients from x 509 | per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) 510 | per_token_loss = -(per_token_loss - self.beta * per_token_kl) 511 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 512 | 513 | # Log the metrics 514 | completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() 515 | self._metrics["completion_length"].append(completion_length) 516 | 517 | reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0) 518 | for i, reward_func in enumerate(self.reward_funcs): 519 | if isinstance(reward_func, PreTrainedModel): 520 | reward_func_name = reward_func.config._name_or_path.split("/")[-1] 521 | else: 522 | reward_func_name = reward_func.__name__ 523 | self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) 524 | 525 | self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) 526 | 527 | self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()) 528 | 529 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 530 | 
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) 531 | 532 | return loss 533 | 534 | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: 535 | metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics 536 | logs = {**logs, **metrics} 537 | if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): 538 | super().log(logs, start_time) 539 | else: # transformers<=4.46 540 | super().log(logs) 541 | self._metrics.clear() 542 | 543 | def create_model_card( 544 | self, 545 | model_name: Optional[str] = None, 546 | dataset_name: Optional[str] = None, 547 | tags: Union[str, list[str], None] = None, 548 | ): 549 | """ 550 | Creates a draft of a model card using the information available to the `Trainer`. 551 | 552 | Args: 553 | model_name (`str` or `None`, *optional*, defaults to `None`): 554 | Name of the model. 555 | dataset_name (`str` or `None`, *optional*, defaults to `None`): 556 | Name of the dataset used for training. 557 | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): 558 | Tags to be associated with the model card. 559 | """ 560 | if not self.is_world_process_zero(): 561 | return 562 | 563 | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): 564 | base_model = self.model.config._name_or_path 565 | else: 566 | base_model = None 567 | 568 | tags = tags or [] 569 | if isinstance(tags, str): 570 | tags = [tags] 571 | 572 | if hasattr(self.model.config, "unsloth_version"): 573 | tags.append("unsloth") 574 | 575 | citation = textwrap.dedent( 576 | """\ 577 | @article{zhihong2024deepseekmath, 578 | title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, 579 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, 580 | year = 2024, 581 | eprint = {arXiv:2402.03300}, 582 | """ 583 | ) 584 | 585 | model_card = generate_model_card( 586 | base_model=base_model, 587 | model_name=model_name, 588 | hub_model_id=self.hub_model_id, 589 | dataset_name=dataset_name, 590 | tags=tags, 591 | wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, 592 | comet_url=get_comet_experiment_url(), 593 | trainer_name="GRPO", 594 | trainer_citation=citation, 595 | paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", 596 | paper_id="2402.03300", 597 | ) 598 | 599 | model_card.save(os.path.join(self.args.output_dir, "README.md")) 600 | --------------------------------------------------------------------------------