├── src ├── open_r1_egoplan │ ├── __init__.py │ ├── trainer │ │ ├── __init__.py │ │ └── grpo_trainer.py │ ├── grpo.py │ └── sft.py └── open_r1.egg-info │ ├── not-zip-safe │ ├── dependency_links.txt │ ├── top_level.txt │ ├── requires.txt │ ├── SOURCES.txt │ └── PKG-INFO ├── qwen-vl-utils ├── .python-version ├── src │ └── qwen_vl_utils │ │ ├── __init__.py │ │ └── vision_process_egoplan.py ├── requirements.lock ├── pyproject.toml ├── requirements-dev.lock └── README.md ├── assets ├── teaser.png ├── data_statistics.png ├── evaluation_results.png └── question_examples.png ├── scripts ├── sft.sh ├── zero3.yaml ├── eval.sh ├── qwen2vl_sft_config.yaml ├── zero2.json ├── zero3.json ├── grpo.sh └── zero3_offload.json ├── Makefile ├── setup.cfg ├── setup.py ├── README.md ├── infer.py └── LICENSE /src/open_r1_egoplan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /qwen-vl-utils/.python-version: -------------------------------------------------------------------------------- 1 | 3.8.19 2 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | open_r1_egoplan 2 | open_r1_video 3 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/data_statistics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/data_statistics.png -------------------------------------------------------------------------------- /assets/evaluation_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/evaluation_results.png -------------------------------------------------------------------------------- /assets/question_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/question_examples.png -------------------------------------------------------------------------------- /src/open_r1_egoplan/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_trainer import Qwen2VLGRPOTrainer 2 | 3 | 4 | __all__ = ["Qwen2VLGRPOTrainer"] 5 | -------------------------------------------------------------------------------- /qwen-vl-utils/src/qwen_vl_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .vision_process_egoplan import ( 2 | extract_vision_info, 3 | fetch_image, 4 | fetch_video, 5 | process_vision_info, 6 | 
smart_resize, 7 | ) 8 | -------------------------------------------------------------------------------- /scripts/sft.sh: -------------------------------------------------------------------------------- 1 | export PROJECT_ROOT="/group/40101/milkcychen/SEED-Bench-R1" 2 | 3 | accelerate launch --config_file ${PROJECT_ROOT}/scripts/zero3.yaml \ 4 | src/open_r1_egoplan/sft.py \ 5 | --config ${PROJECT_ROOT}/scripts/qwen2vl_sft_config.yaml 6 | 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: style quality 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 4 | export PYTHONPATH = src 5 | 6 | check_dirs := src 7 | 8 | style: 9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py 10 | isort $(check_dirs) setup.py 11 | 12 | quality: 13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py 14 | isort --check-only $(check_dirs) setup.py 15 | flake8 --max-line-length 119 $(check_dirs) setup.py 16 | 17 | 18 | # Evaluation 19 | 20 | evaluate: 21 | -------------------------------------------------------------------------------- /scripts/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | accelerate>=1.2.1 2 | bitsandbytes>=0.43.0 3 | einops>=0.8.0 4 | datasets>=3.2.0 5 | deepspeed==0.15.4 6 | hf_transfer>=0.1.4 7 | huggingface-hub[cli]<1.0,>=0.19.2 8 | liger_kernel==0.5.2 9 | packaging>=23.0 10 | safetensors>=0.3.3 11 | sentencepiece>=0.1.99 12 | transformers 13 | trl 14 | 15 | [dev] 16 | black>=24.4.2 17 | isort>=5.12.0 18 | flake8>=6.0.0 19 | pytest 20 | parameterized>=0.9.0 21 | math-verify 22 | 23 | [eval] 24 | math-verify 25 | 26 | [quality] 27 | black>=24.4.2 28 | isort>=5.12.0 29 | flake8>=6.0.0 30 | 31 | [tests] 32 | pytest 33 | parameterized>=0.9.0 34 | 35 | [torch] 36 | torch>=2.5.1 37 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | python -u infer.py \ 2 | --model_path "checkpoint_path" \ 3 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L1.jsonl" \ 4 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \ 5 | --test_batch_size 1 6 | 7 | python -u infer.py \ 8 | --model_path "checkpoint_path" \ 9 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L2.jsonl" \ 10 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \ 11 | --test_batch_size 1 12 | 13 | 14 | python -u infer.py \ 15 | --model_path "checkpoint_path" \ 
16 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L3.jsonl" \ 17 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \ 18 | --test_batch_size 1 -------------------------------------------------------------------------------- /qwen-vl-utils/requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: ["decord"] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | # universal: false 11 | 12 | -e file:. 13 | av==12.3.0 14 | # via qwen-vl-utils 15 | certifi==2022.12.7 16 | # via requests 17 | charset-normalizer==2.1.1 18 | # via requests 19 | decord==0.6.0 20 | # via qwen-vl-utils 21 | idna==3.4 22 | # via requests 23 | numpy==1.24.4 24 | # via decord 25 | packaging==24.1 26 | # via qwen-vl-utils 27 | pillow==10.2.0 28 | # via qwen-vl-utils 29 | requests==2.28.1 30 | # via qwen-vl-utils 31 | urllib3==1.26.13 32 | # via requests 33 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | default_section = FIRSTPARTY 3 | ensure_newline_before_comments = True 4 | force_grid_wrap = 0 5 | include_trailing_comma = True 6 | known_first_party = open_r1 7 | known_third_party = 8 | transformers 9 | datasets 10 | fugashi 11 | git 12 | h5py 13 | matplotlib 14 | nltk 15 | numpy 16 | packaging 17 | pandas 18 | psutil 19 | pytest 20 | rouge_score 21 | sacrebleu 22 | seqeval 23 | sklearn 24 | streamlit 25 | torch 26 | tqdm 27 | 28 | line_length = 119 29 | lines_after_imports = 2 30 | multi_line_output = 3 31 | use_parentheses = True 32 | 33 | [flake8] 34 | ignore = E203, E501, E741, W503, W605 35 | max-line-length = 119 36 | per-file-ignores = 37 | # imported but unused 38 | __init__.py: F401 39 | 40 | [tool:pytest] 41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS -------------------------------------------------------------------------------- /src/open_r1.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.cfg 4 | setup.py 5 | src/open_r1.egg-info/PKG-INFO 6 | src/open_r1.egg-info/SOURCES.txt 7 | src/open_r1.egg-info/dependency_links.txt 8 | src/open_r1.egg-info/not-zip-safe 9 | src/open_r1.egg-info/requires.txt 10 | src/open_r1.egg-info/top_level.txt 11 | src/open_r1_egoplan/__init__.py 12 | src/open_r1_egoplan/evaluate.py 13 | src/open_r1_egoplan/generate.py 14 | src/open_r1_egoplan/grpo.py 15 | src/open_r1_egoplan/sft.py 16 | src/open_r1_egoplan/sft_orig.py 17 | src/open_r1_egoplan/trainer/__init__.py 18 | src/open_r1_egoplan/trainer/grpo_trainer.py 19 | src/open_r1_video/__init__.py 20 | src/open_r1_video/evaluate.py 21 | src/open_r1_video/generate.py 22 | src/open_r1_video/grpo.py 23 | src/open_r1_video/sft.py 24 | src/open_r1_video/trainer/__init__.py 25 | src/open_r1_video/trainer/grpo_trainer.py 26 | src/open_r1_video/trainer/grpo_trainer_w2s.py -------------------------------------------------------------------------------- /scripts/qwen2vl_sft_config.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: /group/40101/milkcychen/SEED-Bench-R1/pretrained_ckpt/Qwen2-VL-7B-Instruct 3 | model_revision: main 4 | 
torch_dtype: bfloat16 5 | 6 | # Data training arguments 7 | dataset_name: xxx 8 | jsonl_path: /group/40101/milkcychen/SEED-Bench-R1/data/annotations/training_with_cot_6k.jsonl 9 | data_root_dir: /group/40101/milkcychen/SEED-Bench-R1/data 10 | 11 | # SFT trainer config 12 | bf16: true 13 | do_eval: true 14 | eval_strategy: "no" 15 | gradient_accumulation_steps: 16 16 | gradient_checkpointing: true 17 | gradient_checkpointing_kwargs: 18 | use_reentrant: false 19 | learning_rate: 2.0e-05 20 | log_level: info 21 | logging_steps: 1 22 | logging_strategy: steps 23 | lr_scheduler_type: cosine 24 | packing: true 25 | max_seq_length: 8192 26 | max_steps: -1 27 | num_train_epochs: 1 28 | output_dir: /group/40101/milkcychen/SEED-Bench-R1/ckpt/Qwen2-VL-7B-Instruct-SFT/training_with_cot_6k 29 | overwrite_output_dir: true 30 | per_device_eval_batch_size: 1 31 | per_device_train_batch_size: 1 32 | report_to: 33 | - wandb 34 | save_strategy: steps 35 | seed: 42 36 | warmup_ratio: 0.1 37 | save_steps: 1 38 | save_only_model: true -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /scripts/grpo.sh: -------------------------------------------------------------------------------- 1 
| export WANDB_PROJECT=Qwen2-VL-7B-GRPO 2 | export WANDB_NAME=training_6k-remove-formatreward-matchletterreward-f16 3 | export DEBUG_MODE=true 4 | export PROJECT_ROOT=/group/40101/milkcychen/SEED-Bench-R1 5 | export LOG_PATH=${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME/completions_log.txt 6 | 7 | mkdir -p ${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME 8 | 9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node="8" \ 10 | --nnodes="1" \ 11 | --node_rank="0" \ 12 | --master_addr="127.0.0.1" \ 13 | --master_port="12352" \ 14 | src/open_r1_egoplan/grpo.py \ 15 | --deepspeed scripts/zero3_offload.json \ 16 | --output_dir ${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME \ 17 | --model_name_or_path ${PROJECT_ROOT}/pretrained_ckpt/Qwen2-VL-7B-Instruct \ 18 | --dataset_name xxx \ 19 | --jsonl_path ${PROJECT_ROOT}/data/annotations/training_6k.jsonl \ 20 | --data_root_dir ${PROJECT_ROOT}/data \ 21 | --max_prompt_length 8192 \ 22 | --learning_rate 1e-6 \ 23 | --beta 0.1 \ 24 | --per_device_train_batch_size 1 \ 25 | --gradient_accumulation_steps 1 \ 26 | --logging_steps 1 \ 27 | --bf16 \ 28 | --torch_dtype bfloat16 \ 29 | --data_seed 42 \ 30 | --report_to wandb \ 31 | --gradient_checkpointing true \ 32 | --attn_implementation flash_attention_2 \ 33 | --num_train_epochs 1 \ 34 | --run_name $WANDB_NAME \ 35 | --save_steps 10 \ 36 | --save_only_model true 37 | 38 | -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "steps_per_print": 1e5, 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /qwen-vl-utils/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "qwen-vl-utils" 3 | version = "0.0.10" 4 | description = "Qwen Vision Language Model Utils - PyTorch" 5 | authors = [ 6 | { name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" }, 7 | ] 8 | dependencies = [ 9 | "requests", 10 | "pillow", 11 | "av", 12 | "packaging", 13 | ] 14 | readme = "README.md" 15 | requires-python = ">= 3.8" 16 | license = {text = "Apache-2.0"} 17 | keywords = [ 18 | 'large language model', 19 | 'vision language model', 20 | 'qwen-vl', 21 | 'pytorch', 22 | ] 23 | classifiers = [ 24 | 'Development Status :: 4 - Beta', 25 | 
'Topic :: Scientific/Engineering :: Artificial Intelligence', 26 | 'Programming Language :: Python :: 3', 27 | 'License :: OSI Approved :: Apache Software License', 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils" 32 | Repository = "https://github.com/QwenLM/Qwen2-VL.git" 33 | Issues = "https://github.com/QwenLM/Qwen2-VL/issues" 34 | 35 | [project.optional-dependencies] 36 | decord = [ 37 | "decord", 38 | ] 39 | 40 | [build-system] 41 | requires = ["hatchling"] 42 | build-backend = "hatchling.build" 43 | 44 | [tool.rye] 45 | managed = true 46 | dev-dependencies = [ 47 | "torch", 48 | "torchvision", 49 | ] 50 | 51 | [tool.hatch.metadata] 52 | allow-direct-references = true 53 | 54 | [tool.hatch.build.targets.wheel] 55 | packages = ["src/qwen_vl_utils"] 56 | 57 | [tool.ruff] 58 | line-length = 119 59 | 60 | [tool.ruff.lint] 61 | ignore = ["C408", "C901", "E501", "E731", "E741", "W605"] 62 | select = ["C", "E", "F", "I", "W"] 63 | 64 | [tool.ruff.lint.per-file-ignores] 65 | "__init__.py" = ["E402", "F401", "F403", "F811"] 66 | 67 | [tool.ruff.lint.isort] 68 | lines-after-imports = 2 69 | known-first-party = ["qwen_vl_utils"] 70 | 71 | [tool.ruff.format] 72 | quote-style = "double" 73 | indent-style = "space" 74 | skip-magic-trailing-comma = false 75 | line-ending = "auto" 76 | -------------------------------------------------------------------------------- /qwen-vl-utils/requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: ["decord"] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | # universal: false 11 | 12 | -e file:. 
13 | av==12.3.0 14 | # via qwen-vl-utils 15 | certifi==2022.12.7 16 | # via requests 17 | charset-normalizer==2.1.1 18 | # via requests 19 | decord==0.6.0 20 | # via qwen-vl-utils 21 | filelock==3.13.1 22 | # via torch 23 | # via triton 24 | fsspec==2024.2.0 25 | # via torch 26 | idna==3.4 27 | # via requests 28 | jinja2==3.1.3 29 | # via torch 30 | markupsafe==2.1.5 31 | # via jinja2 32 | mpmath==1.3.0 33 | # via sympy 34 | networkx==3.1 35 | # via torch 36 | numpy==1.24.1 37 | # via decord 38 | # via torchvision 39 | nvidia-cublas-cu12==12.1.3.1 40 | # via nvidia-cudnn-cu12 41 | # via nvidia-cusolver-cu12 42 | # via torch 43 | nvidia-cuda-cupti-cu12==12.1.105 44 | # via torch 45 | nvidia-cuda-nvrtc-cu12==12.1.105 46 | # via torch 47 | nvidia-cuda-runtime-cu12==12.1.105 48 | # via torch 49 | nvidia-cudnn-cu12==9.1.0.70 50 | # via torch 51 | nvidia-cufft-cu12==11.0.2.54 52 | # via torch 53 | nvidia-curand-cu12==10.3.2.106 54 | # via torch 55 | nvidia-cusolver-cu12==11.4.5.107 56 | # via torch 57 | nvidia-cusparse-cu12==12.1.0.106 58 | # via nvidia-cusolver-cu12 59 | # via torch 60 | nvidia-nccl-cu12==2.20.5 61 | # via torch 62 | nvidia-nvjitlink-cu12==12.6.68 63 | # via nvidia-cusolver-cu12 64 | # via nvidia-cusparse-cu12 65 | nvidia-nvtx-cu12==12.1.105 66 | # via torch 67 | packaging==24.1 68 | # via qwen-vl-utils 69 | pillow==10.2.0 70 | # via qwen-vl-utils 71 | # via torchvision 72 | requests==2.28.1 73 | # via qwen-vl-utils 74 | sympy==1.12 75 | # via torch 76 | torch==2.4.0 77 | # via torchvision 78 | torchvision==0.19.0 79 | triton==3.0.0 80 | # via torch 81 | typing-extensions==4.9.0 82 | # via torch 83 | urllib3==1.26.13 84 | # via requests 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py 16 | 17 | 18 | import re 19 | import shutil 20 | from pathlib import Path 21 | 22 | from setuptools import find_packages, setup 23 | 24 | 25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466 26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info" 27 | if stale_egg_info.exists(): 28 | print( 29 | ( 30 | "Warning: {} exists.\n\n" 31 | "If you recently updated open_r1, this is expected,\n" 32 | "but it may prevent open_r1 from installing in editable mode.\n\n" 33 | "This directory is automatically generated by Python's packaging tools.\n" 34 | "I will remove it now.\n\n" 35 | "See https://github.com/pypa/pip/issues/5466 for details.\n" 36 | ).format(stale_egg_info) 37 | ) 38 | shutil.rmtree(stale_egg_info) 39 | 40 | 41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any. 
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version 43 | _deps = [ 44 | "accelerate>=1.2.1", 45 | "bitsandbytes>=0.43.0", 46 | "black>=24.4.2", 47 | "datasets>=3.2.0", 48 | "deepspeed==0.15.4", 49 | "distilabel[vllm,ray,openai]>=1.5.2", 50 | "einops>=0.8.0", 51 | "flake8>=6.0.0", 52 | "hf_transfer>=0.1.4", 53 | "huggingface-hub[cli]>=0.19.2,<1.0", 54 | "isort>=5.12.0", 55 | "liger_kernel==0.5.2", 56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]", 57 | "math-verify", # Used for math verification in grpo 58 | "packaging>=23.0", 59 | "parameterized>=0.9.0", 60 | "pytest", 61 | "safetensors>=0.3.3", 62 | "sentencepiece>=0.1.99", 63 | "torch>=2.5.1", 64 | "transformers", 65 | "trl", 66 | "vllm==0.6.6.post1", 67 | "wandb>=0.19.1", 68 | "pillow", 69 | ] 70 | 71 | # this is a lookup table with items like: 72 | # 73 | # tokenizers: "tokenizers==0.9.4" 74 | # packaging: "packaging" 75 | # 76 | # some of the values are versioned whereas others aren't. 77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} 78 | 79 | 80 | def deps_list(*pkgs): 81 | return [deps[pkg] for pkg in pkgs] 82 | 83 | 84 | extras = {} 85 | extras["tests"] = deps_list("pytest", "parameterized") 86 | extras["torch"] = deps_list("torch") 87 | extras["quality"] = deps_list("black", "isort", "flake8") 88 | extras["eval"] = deps_list("math-verify") 89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] 90 | 91 | # core dependencies shared across the whole project - keep this to a bare minimum :) 92 | install_requires = [ 93 | deps["accelerate"], 94 | deps["bitsandbytes"], 95 | deps["einops"], 96 | deps["datasets"], 97 | deps["deepspeed"], 98 | deps["hf_transfer"], 99 | deps["huggingface-hub"], 100 | deps["liger_kernel"], 101 | deps["packaging"], # utilities from PyPA to e.g., compare versions 102 | deps["safetensors"], 103 | deps["sentencepiece"], 104 | deps["transformers"], 105 | deps["trl"], 106 | ] 107 | 108 | setup( 109 | name="open-r1", 110 | version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 111 | author="The Hugging Face team (past and future)", 112 | author_email="lewis@huggingface.co", 113 | description="Open R1", 114 | long_description=open("README.md", "r", encoding="utf-8").read(), 115 | long_description_content_type="text/markdown", 116 | keywords="llm inference-time compute reasoning", 117 | license="Apache", 118 | url="https://github.com/huggingface/open-r1", 119 | package_dir={"": "src"}, 120 | packages=find_packages("src"), 121 | zip_safe=False, 122 | extras_require=extras, 123 | python_requires=">=3.10.9", 124 | install_requires=install_requires, 125 | classifiers=[ 126 | "Development Status :: 3 - Alpha", 127 | "Intended Audience :: Developers", 128 | "Intended Audience :: Education", 129 | "Intended Audience :: Science/Research", 130 | "License :: OSI Approved :: Apache Software License", 131 | "Operating System :: OS Independent", 132 | "Programming Language :: Python :: 3", 133 | "Programming Language :: Python :: 3.10", 134 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 135 | ], 136 | ) 137 | -------------------------------------------------------------------------------- /qwen-vl-utils/README.md: -------------------------------------------------------------------------------- 1 | # qwen-vl-utils 2 | 3 | Qwen-VL Utils contains a set of 
helper functions for processing and integrating visual language information with Qwen-VL Series Model. 4 | 5 | ## Install 6 | 7 | ```bash 8 | pip install qwen-vl-utils 9 | ``` 10 | 11 | ## Usage 12 | 13 | ### Qwen2VL 14 | 15 | ```python 16 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 17 | from qwen_vl_utils import process_vision_info 18 | 19 | 20 | # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text. 21 | messages = [ 22 | # Image 23 | ## Local file path 24 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}], 25 | ## Image URL 26 | [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}], 27 | ## Base64 encoded image 28 | [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}], 29 | ## PIL.Image.Image 30 | [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}], 31 | ## Model dynamically adjusts image size, specify dimensions if required. 32 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}], 33 | # Video 34 | ## Local video path 35 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}], 36 | ## Local video frames 37 | [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}], 38 | ## Model dynamically adjusts video nframes, video height and width. specify args if required. 39 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}], 40 | ] 41 | 42 | processor = AutoProcessor.from_pretrained(model_path) 43 | model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto") 44 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 45 | images, videos = process_vision_info(messages) 46 | inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt") 47 | print(inputs) 48 | generated_ids = model.generate(**inputs) 49 | print(generated_ids) 50 | ``` 51 | 52 | ### Qwen2.5VL 53 | 54 | ```python 55 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 56 | from qwen_vl_utils import process_vision_info 57 | 58 | 59 | # You can set the maximum tokens for a video through the environment variable VIDEO_MAX_PIXELS 60 | # based on the maximum tokens that the model can accept. 61 | # export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9 62 | 63 | 64 | # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text. 
65 | messages = [ 66 | # Image 67 | ## Local file path 68 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}], 69 | ## Image URL 70 | [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}], 71 | ## Base64 encoded image 72 | [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}], 73 | ## PIL.Image.Image 74 | [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}], 75 | ## Model dynamically adjusts image size, specify dimensions if required. 76 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}], 77 | # Video 78 | ## Local video path 79 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}], 80 | ## Local video frames 81 | [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}], 82 | ## Model dynamically adjusts video nframes, video height and width. specify args if required. 83 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}], 84 | ] 85 | 86 | processor = AutoProcessor.from_pretrained(model_path) 87 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto") 88 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 89 | images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True) 90 | inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs) 91 | print(inputs) 92 | generated_ids = model.generate(**inputs) 93 | print(generated_ids) 94 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 
4 | # Exploring the Effect of Reinforcement Learning on Video Understanding: Insights from SEED-Bench-R1
5 | 
6 | ![image](assets/teaser.png)
7 | 
17 | 18 | ## 🚀Introduction 19 | 20 | Recent advancements in Chain of Thought (COT) generation have significantly improved the reasoning capabilities of Large Language Models (LLMs), with reinforcement learning (RL) emerging as an effective post-training approach. Multimodal Large Language Models (MLLMs) inherit this reasoning potential but remain underexplored in tasks requiring both perception and logical reasoning. To address this, we introduce SEED-Bench-R1, a benchmark designed to systematically evaluate post-training methods for MLLMs in video understanding. It includes intricate real-world videos and complex everyday planning tasks in the format of multiple-choice questions, requiring sophisticated perception and reasoning. SEED-Bench-R1 assesses generalization through a three-level hierarchy: in-distribution, cross-environment, and cross-environment-task scenarios, equipped with a large-scale training dataset with easily verifiable ground-truth answers. Using Qwen2-VL-Instruct-7B as a base model, we compare RL with supervised fine-tuning (SFT), demonstrating RL's data efficiency and superior performance on both in-distribution and out-of-distribution tasks, even outperforming SFT on general video understanding benchmarks like LongVideoBench. Our detailed analysis reveals that RL enhances visual perception but often produces less logically coherent reasoning chains. We identify key limitations such as inconsistent reasoning and overlooked visual cues, and suggest future improvements in base model reasoning, reward modeling, and RL robustness against noisy signals. 21 | 22 | ## 🚩News 23 | - [2025/06/18]💥We release the training code of [GRPO-CARE](https://github.com/TencentARC/GRPO-CARE), a novel consistency-aware RL framework, and the corresponding model checkpoints! 24 | - [2025/03/31] We release the datasets of SEED-Bench-R1 and the training / evaluation codes. 25 | 26 | 27 | 28 | ## 📝Data 29 | 30 | SEED-Bench-R1 consists of a large-scale training set and a hierarchical three-level validation set for in-distribution, cross-environment, and cross-environment-task evaluations. The datasets can be downloaded from [HuggingFace](https://huggingface.co/datasets/TencentARC/SEED-Bench-R1). 31 | 32 | Specifically, SEED-Bench-R1 is built on our prior works, reusing the training and validation data from our [EgoPlan-Bench](https://github.com/ChenYi99/EgoPlan), as well as the test data from our [EgoPlan-Bench2](https://github.com/qiulu66/EgoPlan-Bench2). The validation data from EgoPlan-Bench are used for Level-1 (in-distribution) and Level-2 (OOD, cross-environment) evaluation, while the test data from EgoPlan-Bench2 cover more general domains and are used for Level-3 (OOD, cross-environment-task) evaluation. 33 | 34 |
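As a quick orientation (inferred from the paths used in `infer.py`, `scripts/eval.sh`, `scripts/grpo.sh`, and `scripts/qwen2vl_sft_config.yaml`, so treat it as a sketch rather than a spec), the downloaded data is expected under `data/`, and the code composes paths along these lines:

```python
import os

data_root_dir = "data"  # --data_root_dir in the scripts points here (under the project root)

# Annotation files referenced by the scripts: training_6k.jsonl, training_with_cot_6k.jsonl,
# validation_L1.jsonl / validation_L2.jsonl / validation_L3.jsonl.
annotation_path = os.path.join(data_root_dir, "annotations", "validation_L1.jsonl")

# infer.py joins video/image paths from per-record fields (placeholders shown in angle brackets).
video_path = os.path.join(data_root_dir, "videos", "<video_source>", "<video_basename>")
image_path = os.path.join(data_root_dir, "images", "<video_source>", "<current_observation_basename>")

print(annotation_path, video_path, image_path, sep="\n")
```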

35 | ![image](assets/data_statistics.png)
36 | 

37 | 38 | Questions from the human-verified validation data are formatted as multiple-choice problems. MLLMs need to select the most reasonable answer from four candidate choices. The primary metric is Accuracy. 39 | 40 |

41 | ![image](assets/question_examples.png)
42 | 
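Concretely, each line of an annotation `.jsonl` file is one record. The sketch below uses invented values but mirrors the field names that `infer.py` reads in `make_conversation_egoplan`, and shows how a record becomes the multiple-choice prompt the model must answer with a single option letter.

```python
# Illustrative record: values are placeholders; field names follow infer.py.
record = {
    "sample_id": "validation_L1_example",
    "question": "Considering the task progress and my current observation, what should I do next?",
    "choice_a": "pick up the kettle",
    "choice_b": "open the fridge",
    "choice_c": "turn on the tap",
    "choice_d": "cut the carrot",
    "golden_choice_idx": "A",                             # letter of the correct option
    "video_source": "example_source",                     # subfolder under data/videos and data/images
    "video_basename": "example_clip.mp4",                 # task-progress video
    "current_observation_basename": "example_frame.jpg",  # current observation image
    "task_progress_metadata": [],                         # empty -> no progress video is attached
}

# Build the question text with lettered options, as infer.py does.
options = [record["choice_a"], record["choice_b"], record["choice_c"], record["choice_d"]]
problem = record["question"] + "\n" + "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))
print(problem)
print("Ground truth:", record["golden_choice_idx"])
```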

43 | 44 | 45 | 46 | ## 🔥Training Models 47 | 48 | > [!NOTE] The training code is modified from [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video). 49 | > The training commands below are configured for a node of 8 x A100 (40GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps. 50 | 51 | ### Set up 52 | ``` 53 | git clone https://github.com/TencentARC/SEED-Bench-R1.git 54 | cd SEED-Bench-R1 55 | conda create -n r1 python=3.10 56 | conda activate r1 57 | pip3 install -e ".[dev]" 58 | pip3 install flash_attn --no-build-isolation 59 | cd qwen-vl-utils 60 | pip install -e . 61 | cd .. 62 | 63 | # download data and put in data/ 64 | git lfs install 65 | git clone https://huggingface.co/datasets/TencentARC/SEED-Bench-R1 66 | 67 | # download the model checkpoint of Qwen2-VL-7B-Instruct 68 | git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct 69 | ``` 70 | 71 | 72 | ### Post-training Qwen2-VL-7B-Instruct 73 | 74 | - To run GRPO on Qwen2-VL-7B with [grpo.sh](scripts/grpo.sh): 75 | 76 | ``` 77 | bash scripts/grpo.sh 78 | ``` 79 | 80 | 81 | - To run SFT on Qwen2-VL-7B with [sft.sh](scripts/sft.sh): 82 | 83 | ``` 84 | bash scripts/sft.sh 85 | ``` 86 | 87 | 88 | ## 🤖Evaluating Models 89 | 90 | ### Inference 91 | 92 | Inference with the post-trained models 93 | ``` 94 | bash scripts/eval.sh 95 | ``` 96 | 97 | 98 | Evaluation results: 99 | 100 | ![image](assets/evaluation_results.png) 101 | 102 | 103 | ## 🙌References & Acknowledgements 104 | We sincerely thank the contributions from the open source community. The related projects are as follows: 105 | - [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video) 106 | - [EgoPlan](https://github.com/ChenYi99/EgoPlan) 107 | - [EgoPlan-Bench2](https://github.com/qiulu66/EgoPlan-Bench2) 108 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal) 109 | - [lmm-r1](https://github.com/TideDra/lmm-r1) 110 | - [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1) 111 | - [Open-R1](https://github.com/huggingface/open-r1) 112 | - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) 113 | 114 | ## ⭐License 115 | The video samples in SEED-Bench-R1 are collected from [Epic-Kitchens](https://epic-kitchens.github.io/2025) and [Ego4D](https://ego4d-data.org/). Users must follow the related licenses ([Epic-Kitchens](https://creativecommons.org/licenses/by-nc/4.0/) and [Ego4D](https://ego4ddataset.com/ego4d-license/)) to use these video samples for training and validation. SEED-Bench-R1 does not hold the copyright for these videos and the copyright belongs to the original owner of these datasets. 
116 | 117 | 118 | ## 📚Citation 119 | If you find our project helpful, hope you can star our repository and cite our paper as follows: 120 | 121 | ```bibtex 122 | @article{chen2025exploring, 123 | title={Exploring the Effect of Reinforcement Learning on Video Understanding: Insights from SEED-Bench-R1}, 124 | author={Chen, Yi and Ge, Yuying and Wang, Rui and Ge, Yixiao and Qiu, Lu and Shan, Ying and Liu, Xihui}, 125 | journal={arXiv preprint arXiv:2503.24376}, 126 | year={2025} 127 | } 128 | 129 | @article{chen2023egoplan, 130 | title={Egoplan-bench: Benchmarking multimodal large language models for human-level planning}, 131 | author={Chen, Yi and Ge, Yuying and Ge, Yixiao and Ding, Mingyu and Li, Bohao and Wang, Rui and Xu, Ruifeng and Shan, Ying and Liu, Xihui}, 132 | journal={arXiv preprint arXiv:2312.06722}, 133 | year={2023} 134 | } 135 | 136 | @article{qiu2024egoplan, 137 | title={Egoplan-bench2: A benchmark for multimodal large language model planning in real-world scenarios}, 138 | author={Qiu, Lu and Ge, Yuying and Chen, Yi and Ge, Yixiao and Shan, Ying and Liu, Xihui}, 139 | journal={arXiv preprint arXiv:2412.04447}, 140 | year={2024} 141 | } 142 | ``` 143 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.2 2 | Name: open-r1 3 | Version: 0.1.0.dev0 4 | Summary: Open R1 5 | Home-page: https://github.com/huggingface/open-r1 6 | Author: The Hugging Face team (past and future) 7 | Author-email: lewis@huggingface.co 8 | License: Apache 9 | Keywords: llm inference-time compute reasoning 10 | Classifier: Development Status :: 3 - Alpha 11 | Classifier: Intended Audience :: Developers 12 | Classifier: Intended Audience :: Education 13 | Classifier: Intended Audience :: Science/Research 14 | Classifier: License :: OSI Approved :: Apache Software License 15 | Classifier: Operating System :: OS Independent 16 | Classifier: Programming Language :: Python :: 3 17 | Classifier: Programming Language :: Python :: 3.10 18 | Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence 19 | Requires-Python: >=3.10.9 20 | Description-Content-Type: text/markdown 21 | License-File: LICENSE 22 | Requires-Dist: accelerate>=1.2.1 23 | Requires-Dist: bitsandbytes>=0.43.0 24 | Requires-Dist: einops>=0.8.0 25 | Requires-Dist: datasets>=3.2.0 26 | Requires-Dist: deepspeed==0.15.4 27 | Requires-Dist: hf_transfer>=0.1.4 28 | Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2 29 | Requires-Dist: liger_kernel==0.5.2 30 | Requires-Dist: packaging>=23.0 31 | Requires-Dist: safetensors>=0.3.3 32 | Requires-Dist: sentencepiece>=0.1.99 33 | Requires-Dist: transformers 34 | Requires-Dist: trl 35 | Provides-Extra: tests 36 | Requires-Dist: pytest; extra == "tests" 37 | Requires-Dist: parameterized>=0.9.0; extra == "tests" 38 | Provides-Extra: torch 39 | Requires-Dist: torch>=2.5.1; extra == "torch" 40 | Provides-Extra: quality 41 | Requires-Dist: black>=24.4.2; extra == "quality" 42 | Requires-Dist: isort>=5.12.0; extra == "quality" 43 | Requires-Dist: flake8>=6.0.0; extra == "quality" 44 | Provides-Extra: eval 45 | Requires-Dist: math-verify; extra == "eval" 46 | Provides-Extra: dev 47 | Requires-Dist: black>=24.4.2; extra == "dev" 48 | Requires-Dist: isort>=5.12.0; extra == "dev" 49 | Requires-Dist: flake8>=6.0.0; extra == "dev" 50 | Requires-Dist: pytest; extra == "dev" 51 | Requires-Dist: parameterized>=0.9.0; extra == "dev" 52 | 
Requires-Dist: math-verify; extra == "dev" 53 | Dynamic: author 54 | Dynamic: author-email 55 | Dynamic: classifier 56 | Dynamic: description 57 | Dynamic: description-content-type 58 | Dynamic: home-page 59 | Dynamic: keywords 60 | Dynamic: license 61 | Dynamic: provides-extra 62 | Dynamic: requires-dist 63 | Dynamic: requires-python 64 | Dynamic: summary 65 | 66 | # Open R1 Video 67 | 68 | We introduce R1's paradigm to video understanding tasks and open-sourced the training code and data. 69 | 70 | [🤗 Models](https://huggingface.co/Xiaodong/Open-R1-Video-7B) | [🤗 Datasets](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k) | [Wandb Logs](https://wandb.ai/xiaodongwang/Qwen2-VL-7B-Video-GRPO/runs/mb6ued4m?nw=nwuserxiaodongwang) 71 | 72 | > [!NOTE] 73 | > Although our insights may not be guaranteed to be correct, we commit to sharing them truthfully and honestly. We welcome community feedback and discussions to improve our understanding on multimodal reasoning models. 74 | 75 | ## News 76 | - [2025/02/22] We release a provisional model [Open-R1-Video-7B](https://huggingface.co/Xiaodong/Open-R1-Video-7B), inference scripts, and evaluation results. 77 | - [2025/02/18] We release training code and data of Open-R1-Video! 78 | 79 | ## Our Findings 80 | ### GRPO training that forces thinking can improve video understanding 81 | 82 | ![image](assets/longvb.png) 83 | 84 | We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) on simple video dataset [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k) using 4 x A100 (80G) GPUs, and the training only utilize video, query, and the ground truth answer (the letter of the correct answer). We only used GRPO (pure reinforcement learning without labeled reasoning trajectories) to train the model and achieved promising rewards during model training. We release our [wandb logs](https://wandb.ai/xiaodongwang/Qwen2-VL-7B-Video-GRPO/runs/mb6ued4m?nw=nwuserxiaodongwang) for reference. 85 | ![image](assets/log.png) 86 | 87 | **What We Did** 88 | - Introduce R1 to Video-LMM (e.g., Qwen2-VL) based on [huggingface/open-r1](https://github.com/huggingface/open-r1) and [deepseek-ai/DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1). 89 | - Open-sourced the simple training data [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k). 90 | - The simple reformat data is available in [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k). 91 | - The video data is available in [LLaVA-Video-large-swift](https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift). 92 | 93 | 94 | 95 | ## Training Models 96 | 97 | > [!NOTE] 98 | > The training commands below are configured for a node of 4 x A100 (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps. 99 | 100 | ### Set up 101 | ``` 102 | git clone https://github.com/Wang-Xiaodong1899/Open-R1-Video.git 103 | cd Open-R1-Video 104 | conda create -n r1 python=3.10 105 | conda activate r1 106 | pip3 install -e ".[dev]" 107 | pip3 install flash_attn --no-build-isolation 108 | cd qwen-vl-utils 109 | pip install -e . 110 | cd .. 
111 | 112 | # download data and put in data/ 113 | wget https://huggingface.co/datasets/Xiaodong/open-r1-video-4k/resolve/main/LLaVA-Video-large-swift-origin.jsonl 114 | # like: data/LLaVA-Video-large-swift-origin.jsonl 115 | 116 | # download videos 117 | git lfs install 118 | git clone https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift 119 | 120 | ``` 121 | 122 | 123 | ### GRPO on Qwen2-VL/7B 124 | > [!NOTE] 125 | > Our training also support single A100 (80G) GPU training. Just modify the GPU and you’re good to go! 126 | 127 | > We removed format accuracy during 7B model training and slightly modified the final answer matching to calculate the accuracy reward. See this [commit](https://github.com/Wang-Xiaodong1899/Open-R1-Video/commit/2679e082aaf608fd167a0ad5e6f2afb4f548e25f#diff-d6985fa15a3c7864e723ebd4c04bfdc2f13c5e87af36f87d656e32666f8e0eeb). 128 | 129 | To run GRPO on Qwen2-VL-7B: 130 | 131 | ``` 132 | bash qwen-7b.sh 133 | ``` 134 | 135 | Please refer to [qwen-7b.sh](qwen-7b.sh) for more details. 136 | 137 | 138 | ## Evaluating models 139 | 140 | ### Inference 141 | 142 | Infer the video reasoning model! 143 | ``` 144 | python infer.py 145 | ``` 146 | 147 | 148 | ![video](assets/split_5.gif) 149 | 150 | [Video link](https://youtu.be/2evryGv-oZ4) 151 | 152 | Inference results: 153 | 154 | ![image](assets/sample.png) 155 | 156 | ### Evaluation 157 | 158 | > [!NOTE] 159 | > We use [Lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) to evaluate models. 160 | 161 | 162 | | Benchmarks | Qwen2-VL-7B-Instruct(w.o reasoning) | Qwen2-VL-7B-Instruct(w. reasoning) | Open-R1-Video-7B(w. reasoning) | 163 | |---------------------------|-------------------------------------|------------------------------------|---------------------------------| 164 | | [LongVideoBench](https://longvideobench.github.io/) (16 frames) | 53.33 | 41.89 | 48.99 (↑7.1) | 165 | 166 | 167 | ## RL Data Reformat 168 | 169 | We provide the easy reformat method to obtain the data for GRPO training, which only utilize video, query, and final answer. Please refer to [format_video_data.py](scripts/format_video_data.py) for more details. 170 | 171 | Users can view data in [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k). The `original question`/`original answer` are from the original dataset. 172 | 173 | ## References & Acknowledgements 174 | We sincerely thank the contributions from the open source community, including the reproduction of [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1), [Open-R1](https://github.com/huggingface/open-r1), and [R1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal), etc. 175 | 176 | The related projects are as follows: 177 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal) 178 | - [lmm-r1](https://github.com/TideDra/lmm-r1) 179 | - [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1) 180 | - [open-r1](https://github.com/huggingface/open-r1) 181 | - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) 182 | - [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) 183 | - [LLaVA-Video-large-swift](https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift) 184 | 185 | ## Citation 186 | If you find this useful, you can choose to cite us. 
187 | 188 | ```bibtex 189 | @misc{wang-2025-open-r1-video, 190 | author = {Xiaodong Wang and Peixi Peng}, 191 | title = {Open-R1-Video}, 192 | year = {2025}, 193 | publisher = {GitHub}, 194 | journal = {GitHub repository}, 195 | howpublished = {\url{https://github.com/Wang-Xiaodong1899/Open-R1-Video}} 196 | } 197 | ``` 198 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from qwen_vl_utils import process_vision_info 2 | import torch 3 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor 4 | import os 5 | import argparse 6 | import json 7 | from tqdm import tqdm 8 | from math_verify import parse, verify 9 | import re 10 | import numpy as np 11 | 12 | QUESTION_TEMPLATE = "{Question} Output the thinking process in and final answer in tags, i.e., reasoning process here answer here . " 13 | 14 | def make_conversation_egoplan(example, data_root_dir): 15 | options = [ 16 | example['choice_a'], 17 | example['choice_b'], 18 | example['choice_c'], 19 | example['choice_d'], 20 | ] 21 | 22 | answer_index = example['golden_choice_idx'] 23 | 24 | problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n" 25 | solution = f"{answer_index}." 26 | 27 | content = [] 28 | if len(example['task_progress_metadata']) > 0: 29 | video_path = os.path.join(data_root_dir, 'videos', example['video_source'], example['video_basename']) 30 | content.append({"type": "video", "video": video_path}) 31 | 32 | image_path = os.path.join(data_root_dir, 'images', example['video_source'], example['current_observation_basename']) 33 | content.extend([ 34 | {"type": "image", "image": image_path}, 35 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)}, 36 | ]) 37 | 38 | feature = { 39 | "sample_id": example["sample_id"], 40 | "prompt": [ 41 | { 42 | "role": "user", 43 | "content": content, 44 | }, 45 | ], 46 | 'problem': problem, 47 | 'solution': solution, 48 | } 49 | return feature 50 | 51 | 52 | 53 | def make_conversation_longvideobench(example, data_root_dir): 54 | options = example['candidates'] 55 | answer_index = chr(65 + example['correct_choice']) 56 | 57 | problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n" 58 | solution = f"{answer_index}." 
59 | video_path = os.path.join(data_root_dir, "videos", example["video_path"]) 60 | content = [ 61 | {"type": "video", "video": video_path}, 62 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)}, 63 | ] 64 | 65 | feature = { 66 | "sample_id": example["sample_id"], 67 | "prompt": [ 68 | { 69 | "role": "user", 70 | "content": content, 71 | }, 72 | ], 73 | 'problem': problem, 74 | 'solution': solution, 75 | } 76 | return feature 77 | 78 | 79 | def accuracy_reward(content, sol): 80 | reward = 0.0 81 | # Try symbolic verification first 82 | try: 83 | answer = parse(content) 84 | if float(verify(answer, parse(sol))) > 0: 85 | reward = 1.0 86 | except Exception: 87 | pass # Continue to next verification method if this fails 88 | 89 | # If symbolic verification failed, try string matching 90 | if reward == 0.0: 91 | try: 92 | # Extract answer from solution if it has think/answer tags 93 | sol_match = re.search(r"(.*?)", sol, re.DOTALL) 94 | ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() 95 | 96 | # Extract answer from content if it has think/answer tags 97 | content_match = re.search(r"(.*?)", content, re.DOTALL) 98 | # student_answer = content_match.group(1).strip() if content_match else content.strip() 99 | if content_match: 100 | student_answer = content_match.group(1).strip() 101 | # HACK, if First letter is correct reward 1 102 | # Compare the extracted answers 103 | if student_answer[0] == ground_truth[0]: 104 | reward = 1.0 105 | else: 106 | reward = 0.0 107 | except Exception: 108 | pass # Keep reward as 0.0 if both methods fail 109 | return reward 110 | 111 | 112 | 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--model_path", default="/group/40101/milkcychen/Open-R1-Video/ckpt/Qwen2-VL-7B-EgoPlan-GRPO/egoplan-it-8k-remove-formatreward-matchletterreward-f16export/checkpoint-850") 118 | parser.add_argument("--test_file_path", default="/group/40121/public_datasets/SEED-Bench-R1/annotations/validation_L1.jsonl") 119 | parser.add_argument("--data_root_dir", default="/group/40121/public_datasets/SEED-Bench-R1") 120 | parser.add_argument("--test_batch_size", type=int, default=1) 121 | parser.add_argument("--do_sampling", type=int, default=1) 122 | # parser.add_argument("--") 123 | args = parser.parse_args() 124 | 125 | # default: Load the model on the available device(s) 126 | model = Qwen2VLForConditionalGeneration.from_pretrained( 127 | args.model_path, 128 | torch_dtype=torch.bfloat16, 129 | attn_implementation="flash_attention_2", 130 | device_map="auto", 131 | ) 132 | 133 | processor = AutoProcessor.from_pretrained(args.model_path) 134 | 135 | test_file_basename = os.path.basename(args.test_file_path) 136 | output_eval_results_path = os.path.join(args.model_path, f"eval_results_for_{test_file_basename}") 137 | 138 | rewards = [] 139 | evaluated_example_ids = set() 140 | if os.path.exists(output_eval_results_path): 141 | with open(output_eval_results_path) as fi: 142 | for line in tqdm(fi): 143 | example = json.loads(line) 144 | evaluated_example_ids.add(example['sample_id']) 145 | rewards.append(example['reward']) 146 | 147 | test_examples = [] 148 | if args.test_file_path.endswith('.jsonl'): 149 | with open(args.test_file_path) as f: 150 | examples = [] 151 | for line in f: 152 | example = json.loads(line) 153 | examples.append(example) 154 | else: 155 | with open(args.test_file_path) as f: 156 | examples = json.load(f) 157 | 158 | 159 | for example in examples: 160 | if 
'sample_id' not in example: 161 | example['sample_id'] = example['id'] 162 | if example['sample_id'] not in evaluated_example_ids: 163 | test_examples.append(example) 164 | else: 165 | print(f"skip {example['sample_id']}") 166 | 167 | 168 | t = tqdm(total=len(test_examples)) 169 | 170 | with open(output_eval_results_path, 'a') as fo: 171 | for i in range(0, len(test_examples), args.test_batch_size): 172 | batch = test_examples[i:i+args.test_batch_size] 173 | if 'EgoPlan' in args.data_root_dir: 174 | features = [make_conversation_egoplan(example, args.data_root_dir) for example in batch] 175 | elif 'LongVideoBench' in args.data_root_dir: 176 | features = [make_conversation_longvideobench(example, args.data_root_dir) for example in batch] 177 | else: 178 | raise NotImplementedError 179 | prompts = [feature['prompt'] for feature in features] 180 | prompt_texts = [ 181 | processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) 182 | for msg in prompts 183 | ] 184 | image_inputs, video_inputs, video_kwargs = process_vision_info(prompts, return_video_kwargs=True) 185 | prompt_inputs = processor( 186 | text=prompt_texts, 187 | images=image_inputs, 188 | videos=video_inputs, 189 | return_tensors="pt", 190 | padding=True, 191 | padding_side="left", 192 | add_special_tokens=False, 193 | ) 194 | prompt_inputs = prompt_inputs.to("cuda") 195 | 196 | # Inference 197 | generated_ids = model.generate(**prompt_inputs, 198 | max_new_tokens=256, 199 | do_sample=args.do_sampling, 200 | temperature=1 201 | ) 202 | generated_ids_trimmed = [ 203 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(prompt_inputs.input_ids, generated_ids) 204 | ] 205 | output_texts = processor.batch_decode( 206 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 207 | ) 208 | 209 | for prompt, output_text, feature in zip(prompts, output_texts, features): 210 | reward = accuracy_reward(output_text, feature['solution']) 211 | rewards.append(reward) 212 | print(f"------------- Sample {feature['sample_id']} -------------\n") 213 | print(f"Question: {prompt[0]['content'][-1]['text']}\n") 214 | print(f"Content: {output_text}\n") 215 | print(f"Solution: {feature['solution']}\n") 216 | print(f"Reward: {reward}\n") 217 | 218 | feature['response'] = output_text 219 | feature['reward'] = reward 220 | fo.write(json.dumps(feature)+"\n") 221 | fo.flush() 222 | 223 | t.update(1) 224 | t.set_postfix(reward_mean=np.mean(rewards)) 225 | 226 | -------------------------------------------------------------------------------- /src/open_r1_egoplan/grpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
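# Module overview: this script wires up GRPO post-training for Qwen2-VL on the
# SEED-Bench-R1 annotations. It defines GRPOScriptArguments (reward function
# names, image pixel bounds, jsonl/data paths, visual-encoder freezing),
# implements the "accuracy" and "format" reward functions, builds a DatasetDict
# from a local jsonl file, and formats each example into a multiple-choice
# conversation with the <think>/<answer> prompt template for the
# Qwen2VLGRPOTrainer imported from open_r1_egoplan.trainer.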
14 | 15 | import os 16 | import re 17 | from dataclasses import dataclass, field 18 | from datetime import datetime 19 | from typing import Optional, List 20 | 21 | from datasets import load_dataset, Dataset, DatasetDict 22 | 23 | from math_verify import parse, verify 24 | from open_r1_egoplan.trainer import Qwen2VLGRPOTrainer 25 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config 26 | import random 27 | from functools import partial 28 | 29 | @dataclass 30 | class GRPOScriptArguments(ScriptArguments): 31 | """ 32 | Script arguments for the GRPO training script. 33 | 34 | Args: 35 | reward_funcs (`List[str]`): 36 | List of reward functions. Possible values: 'accuracy', 'format'. 37 | """ 38 | 39 | reward_funcs: List[str] = field( 40 | default_factory=lambda: ["accuracy",], 41 | metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, 42 | ) 43 | max_pixels: Optional[int] = field( 44 | default=12845056, 45 | metadata={"help": "Maximum number of pixels for the image"}, 46 | ) 47 | min_pixels: Optional[int] = field( 48 | default=3136, 49 | metadata={"help": "Minimum number of pixels for the image"}, 50 | ) 51 | jsonl_path: Optional[str] = field( 52 | default=None, 53 | metadata={"help": "json file path"}, 54 | ) 55 | data_root_dir: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "Root directory to datasets"}, 58 | ) 59 | freeze_visual_encoder: Optional[int] = field( 60 | default=0, 61 | metadata={"help": "Whether to freeze visual encoder"}, 62 | ) 63 | 64 | def accuracy_reward(completions, solution, prompts, **kwargs): 65 | """Reward function that checks if the completion is correct using either symbolic verification or exact string matching.""" 66 | contents = [completion[0]["content"] for completion in completions] 67 | print(f"------------- Sample {kwargs['sample_id'][0]} -------------") 68 | print(prompts[0][0]['content'][-1]['text']+'\n') 69 | print(solution[0]+'\n') 70 | print(contents[:2]) # print online completion 71 | rewards = [] 72 | current_time = datetime.now().strftime("%d-%H-%M-%S-%f") 73 | 74 | for content, sol, prompt, sample_id in zip(contents, solution, prompts, kwargs['sample_id']): 75 | reward = 0.0 76 | # Try symbolic verification first 77 | try: 78 | answer = parse(content) 79 | if float(verify(answer, parse(sol))) > 0: 80 | reward = 1.0 81 | except Exception: 82 | pass # Continue to next verification method if this fails 83 | 84 | # If symbolic verification failed, try string matching 85 | if reward == 0.0: 86 | try: 87 | # Extract answer from solution if it has think/answer tags 88 | sol_match = re.search(r"(.*?)", sol, re.DOTALL) 89 | ground_truth = sol_match.group(1).strip() if sol_match else sol.strip() 90 | 91 | # Extract answer from content if it has think/answer tags 92 | content_match = re.search(r"(.*?)", content, re.DOTALL) 93 | # student_answer = content_match.group(1).strip() if content_match else content.strip() 94 | if content_match: 95 | student_answer = content_match.group(1).strip() 96 | # HACK, if First letter is correct reward 1 97 | # Compare the extracted answers 98 | if student_answer[0] == ground_truth[0]: 99 | reward = 1.0 100 | else: 101 | reward = 0.0 102 | except Exception: 103 | pass # Keep reward as 0.0 if both methods fail 104 | 105 | rewards.append(reward) 106 | if os.getenv("DEBUG_MODE") == "true": 107 | log_path = os.getenv("LOG_PATH") 108 | with open(log_path, "a") as f: 109 | f.write(f"------------- {current_time} Accuracy reward for Sample 
{sample_id}: {reward} -------------\n")
110 |                 f.write(f"Question: {prompt[0]['content'][-1]['text']}\n")
111 |                 f.write(f"Content: {content}\n")
112 |                 f.write(f"Solution: {sol}\n")
113 |     return rewards
114 | 
115 | 
116 | def format_reward(completions, **kwargs):
117 |     """Reward function that checks if the completion has a specific format."""
118 |     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
119 |     completion_contents = [completion[0]["content"] for completion in completions]
120 |     matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
121 |     return [1.0 if match else 0.0 for match in matches]
122 | 
123 | 
124 | reward_funcs_registry = {
125 |     "accuracy": accuracy_reward,
126 |     "format": format_reward,
127 | }
128 | 
129 | 
130 | def create_dataset_from_jsonl_simple(jsonl_path):
131 |     base_dataset = Dataset.from_json(jsonl_path)
132 |     return DatasetDict({
133 |         "train": base_dataset
134 |     })
135 | 
136 | 
137 | def main(script_args, training_args, model_args):
138 |     # Get reward functions
139 |     reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
140 | 
141 |     if script_args.jsonl_path:
142 |         # load dataset from jsonl
143 |         dataset = create_dataset_from_jsonl_simple(script_args.jsonl_path)
144 |     else:
145 |         # Load the dataset
146 |         dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
147 | 
148 |     # Format into conversation
149 |     QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer in <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
150 | 
151 |     def make_conversation_egoplan(example, data_root_dir):
152 |         if 'golden_choice_idx' not in example:
153 |             negative_answers = random.sample(example["negative_answers"], 3)
154 |             options = negative_answers + [example["answer"]]
155 |         else:
156 |             options = [example['choice_a'], example['choice_b'], example['choice_c'], example['choice_d']]
157 | 
158 |         random.shuffle(options)
159 |         answer_index = options.index(example["answer"])
160 |         problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n"
161 |         solution = f"{chr(65 + answer_index)}."
162 | 163 | 164 | content = [] 165 | if len(example['task_progress_metadata']) > 0: 166 | video_path = os.path.join(data_root_dir, 'videos', example['video_source'], example['video_basename']) 167 | content.append({"type": "video", "video": video_path}) 168 | 169 | image_path = os.path.join(data_root_dir, 'images', example['video_source'], example['current_observation_basename']) 170 | content.extend([ 171 | {"type": "image", "image": image_path}, 172 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)}, 173 | ]) 174 | 175 | feature = { 176 | "prompt": [ 177 | { 178 | "role": "user", 179 | "content": content, 180 | }, 181 | ], 182 | 'problem': problem, 183 | 'solution': solution, 184 | } 185 | 186 | return feature 187 | 188 | dataset = dataset.map(partial(make_conversation_egoplan, data_root_dir=script_args.data_root_dir)) 189 | trainer_cls = Qwen2VLGRPOTrainer 190 | 191 | # Initialize the GRPO trainer 192 | trainer = trainer_cls( 193 | model=model_args.model_name_or_path, 194 | reward_funcs=reward_funcs, 195 | args=training_args, 196 | train_dataset=dataset[script_args.dataset_train_split], 197 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 198 | peft_config=get_peft_config(model_args), 199 | attn_implementation=model_args.attn_implementation, 200 | max_pixels=script_args.max_pixels, 201 | min_pixels=script_args.min_pixels, 202 | ) 203 | 204 | # Train and push the model to the Hub 205 | trainer.train() 206 | 207 | # Save and push to hub 208 | trainer.save_model(training_args.output_dir) 209 | if training_args.push_to_hub: 210 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig)) 215 | script_args, training_args, model_args = parser.parse_args_and_config() 216 | training_args.freeze_visual_encoder = script_args.freeze_visual_encoder 217 | main(script_args, training_args, model_args) 218 | -------------------------------------------------------------------------------- /src/open_r1_egoplan/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
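Note: both grpo.py above and sft.py below can read training data from a JSONL file via create_dataset_from_jsonl_simple, one JSON object per line. A hypothetical record is sketched here; the field names are the ones make_conversation_egoplan reads, while every value is invented for illustration.

import json

example_record = {
    "sample_id": "egoplan_000001",                      # used for logging and resuming evaluation
    "question": "What should I do next to make tea?",   # placeholder question text
    "answer": "pick up the kettle",                     # ground-truth next action
    "negative_answers": [                               # three distractors are sampled from this list
        "open the fridge",
        "turn off the tap",
        "put down the cup",
    ],
    "task_progress_metadata": ["<per-step metadata>"],  # only its length is checked: non-empty => attach the history video
    "video_source": "EpicKitchens",                     # directory and file names are placeholders
    "video_basename": "P01_01_clip.mp4",
    "current_observation_basename": "P01_01_frame_000123.jpg",
}

with open("train.jsonl", "a") as f:
    f.write(json.dumps(example_record) + "\n")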
14 | 15 | 16 | import logging 17 | import os 18 | import sys 19 | 20 | import datasets 21 | from dataclasses import dataclass, field 22 | from typing import Optional 23 | import torch 24 | import transformers 25 | from datasets import load_dataset 26 | from transformers import AutoTokenizer, set_seed, AutoProcessor 27 | from transformers.trainer_utils import get_last_checkpoint 28 | import trl 29 | from trl import ( 30 | ModelConfig, 31 | ScriptArguments, 32 | SFTTrainer, 33 | TrlParser, 34 | get_kbit_device_map, 35 | get_peft_config, 36 | get_quantization_config, 37 | ) 38 | from datasets import Dataset, DatasetDict 39 | from qwen_vl_utils import process_vision_info 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | @dataclass 44 | class SFTConfig(trl.SFTConfig): 45 | """ 46 | args for callbacks, benchmarks etc 47 | """ 48 | 49 | benchmarks: list[str] = field( 50 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."} 51 | ) 52 | callbacks: list[str] = field( 53 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."} 54 | ) 55 | system_prompt: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "The optional system prompt to use for benchmarking."}, 58 | ) 59 | hub_model_revision: Optional[str] = field( 60 | default="main", 61 | metadata={"help": "The Hub model branch to push the model to."}, 62 | ) 63 | jsonl_path: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "json file path"}, 66 | ) 67 | data_root_dir: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Root directory to datasets"}, 70 | ) 71 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) 72 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) 73 | 74 | 75 | 76 | processor = None 77 | 78 | 79 | def convert_example(example): 80 | messages = example['prompt'] 81 | messages.append({ 82 | "role": "assistant", 83 | "content": [{"type": "text", "text": example['response']}], 84 | }) 85 | example["messages"] = messages 86 | return example 87 | 88 | 89 | def collate_fn(examples): 90 | for example in examples: 91 | for message in example["prompt"]: 92 | if isinstance(message["content"], list): 93 | new_content = [] 94 | for ele in message["content"]: 95 | new_ele = {k: v for k, v in ele.items() if v is not None} 96 | if 'video' in new_ele: 97 | new_ele['video'] = os.path.join(training_args.data_root_dir, new_ele['video']) 98 | if 'image' in new_ele: 99 | new_ele['image'] = os.path.join(training_args.data_root_dir, new_ele['image']) 100 | new_content.append(new_ele) 101 | message["content"] = new_content 102 | convert_example(example) 103 | 104 | texts = [ 105 | processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False) 106 | for example in examples 107 | ] 108 | image_inputs = [] 109 | video_inputs = [] 110 | for example in examples: 111 | imgs, vids = process_vision_info(example["messages"]) 112 | image_inputs.append(imgs) 113 | video_inputs.append(vids) 114 | batch = processor( 115 | text=texts, 116 | images=image_inputs, 117 | videos=video_inputs, 118 | return_tensors="pt", 119 | padding=True, 120 | ) 121 | 122 | labels = batch["input_ids"].clone() 123 | # labels[labels == processor.tokenizer.pad_token_id] = -100 124 | # image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) 125 | # labels[labels == image_token_id] = -100 126 | 127 | 
im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>") 128 | for i in range(labels.size(0)): 129 | last_index = (labels[i] == im_start_token_id).nonzero(as_tuple=True)[0][-1].item() 130 | labels[i, :last_index+3] = -100 131 | 132 | batch["labels"] = labels 133 | return batch 134 | 135 | 136 | def create_dataset_from_jsonl_simple(jsonl_path): 137 | base_dataset = Dataset.from_json(jsonl_path) 138 | return DatasetDict({ 139 | "train": base_dataset 140 | }) 141 | 142 | def main(script_args, training_args, model_args): 143 | # Set seed for reproducibility 144 | set_seed(training_args.seed) 145 | 146 | ############### 147 | # Setup logging 148 | ############### 149 | logging.basicConfig( 150 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 151 | datefmt="%Y-%m-%d %H:%M:%S", 152 | handlers=[logging.StreamHandler(sys.stdout)], 153 | ) 154 | log_level = training_args.get_process_log_level() 155 | logger.setLevel(log_level) 156 | datasets.utils.logging.set_verbosity(log_level) 157 | transformers.utils.logging.set_verbosity(log_level) 158 | transformers.utils.logging.enable_default_handler() 159 | transformers.utils.logging.enable_explicit_format() 160 | 161 | # Log on each process a small summary 162 | logger.warning( 163 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 164 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 165 | ) 166 | logger.info(f"Model parameters {model_args}") 167 | logger.info(f"Script parameters {script_args}") 168 | logger.info(f"Data parameters {training_args}") 169 | 170 | # Check for last checkpoint 171 | last_checkpoint = None 172 | if os.path.isdir(training_args.output_dir): 173 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 174 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 175 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") 176 | 177 | ################ 178 | # Load datasets 179 | ################ 180 | 181 | if training_args.jsonl_path: 182 | # # load dataset from jsonl 183 | dataset = create_dataset_from_jsonl_simple(training_args.jsonl_path) 184 | else: 185 | # Load the dataset 186 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 187 | 188 | 189 | ################ 190 | # Load tokenizer 191 | ################ 192 | global processor 193 | if "vl" in model_args.model_name_or_path.lower(): 194 | processor = AutoProcessor.from_pretrained( 195 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 196 | ) 197 | logger.info("Using AutoProcessor for vision-language model.") 198 | else: 199 | processor = AutoTokenizer.from_pretrained( 200 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True 201 | ) 202 | logger.info("Using AutoTokenizer for text-only model.") 203 | if hasattr(processor, "pad_token") and processor.pad_token is None: 204 | processor.pad_token = processor.eos_token 205 | elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None: 206 | processor.tokenizer.pad_token = processor.tokenizer.eos_token 207 | 208 | ################### 209 | # Model init kwargs 210 | ################### 211 | logger.info("*** Initializing model kwargs ***") 212 | torch_dtype = ( 213 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 214 | ) 215 | 
quantization_config = get_quantization_config(model_args) 216 | model_kwargs = dict( 217 | revision=model_args.model_revision, 218 | trust_remote_code=model_args.trust_remote_code, 219 | attn_implementation=model_args.attn_implementation, 220 | torch_dtype=torch_dtype, 221 | use_cache=False if training_args.gradient_checkpointing else True, 222 | device_map=get_kbit_device_map() if quantization_config is not None else None, 223 | quantization_config=quantization_config, 224 | ) 225 | # training_args.model_init_kwargs = model_kwargs 226 | from transformers import Qwen2VLForConditionalGeneration 227 | model = Qwen2VLForConditionalGeneration.from_pretrained( 228 | model_args.model_name_or_path, **model_kwargs 229 | ) 230 | ############################ 231 | # Initialize the SFT Trainer 232 | ############################ 233 | training_args.dataset_kwargs = { 234 | "skip_prepare_dataset": True, 235 | } 236 | training_args.remove_unused_columns = False 237 | trainer = SFTTrainer( 238 | model=model, 239 | args=training_args, 240 | train_dataset=dataset[script_args.dataset_train_split], 241 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 242 | processing_class=processor.tokenizer, 243 | data_collator=collate_fn, 244 | peft_config=get_peft_config(model_args) 245 | ) 246 | 247 | ############### 248 | # Training loop 249 | ############### 250 | logger.info("*** Train ***") 251 | checkpoint = None 252 | if training_args.resume_from_checkpoint is not None: 253 | checkpoint = training_args.resume_from_checkpoint 254 | elif last_checkpoint is not None: 255 | checkpoint = last_checkpoint 256 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 257 | metrics = train_result.metrics 258 | metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) 259 | trainer.log_metrics("train", metrics) 260 | trainer.save_metrics("train", metrics) 261 | trainer.save_state() 262 | 263 | ################################## 264 | # Save model and create model card 265 | ################################## 266 | logger.info("*** Save model ***") 267 | trainer.save_model(training_args.output_dir) 268 | processor.save_pretrained(training_args.output_dir) 269 | logger.info(f"Model saved to {training_args.output_dir}") 270 | 271 | # Save everything else on main process 272 | kwargs = { 273 | "dataset_name": script_args.dataset_name, 274 | "tags": ["R1-V"], 275 | } 276 | if trainer.accelerator.is_main_process: 277 | trainer.create_model_card(**kwargs) 278 | # Restore k,v cache for fast inference 279 | trainer.model.config.use_cache = True 280 | trainer.model.config.save_pretrained(training_args.output_dir) 281 | ############# 282 | # push to hub 283 | ############# 284 | 285 | if training_args.push_to_hub: 286 | logger.info("Pushing to hub...") 287 | trainer.push_to_hub(**kwargs) 288 | processor.push_to_hub(training_args.hub_model_id) 289 | 290 | 291 | 292 | 293 | if __name__ == "__main__": 294 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) 295 | script_args, training_args, model_args = parser.parse_args_and_config() 296 | main(script_args, training_args, model_args) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /qwen-vl-utils/src/qwen_vl_utils/vision_process_egoplan.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | import logging 5 | import math 6 | import os 7 | import sys 8 | import time 9 | import warnings 10 | from functools import lru_cache 11 | from io import BytesIO 12 | 13 | import requests 14 | import torch 15 | import torchvision 16 | from packaging import version 17 | from PIL import Image 18 | from torchvision import io, transforms 19 | from torchvision.transforms import InterpolationMode 20 | from typing import Optional 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | IMAGE_FACTOR = 28 26 | MIN_PIXELS = 4 * 28 * 28 27 | MAX_PIXELS = 16384 * 28 * 28 28 | MAX_RATIO = 200 29 | 30 | VIDEO_MIN_PIXELS = 128 * 28 * 28 31 | VIDEO_MAX_PIXELS = 196 * 28 * 28 32 | FRAME_FACTOR = 2 33 | FPS = 2.0 34 | FPS_MIN_FRAMES = 4 35 | FPS_MAX_FRAMES = 16 36 | 37 | # Set the maximum number of video token inputs. 38 | # Here, 128K represents the maximum number of input tokens for the VLLM model. 39 | # Remember to adjust it according to your own configuration. 40 | VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9))) 41 | logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}") 42 | 43 | 44 | def round_by_factor(number: int, factor: int) -> int: 45 | """Returns the closest integer to 'number' that is divisible by 'factor'.""" 46 | return round(number / factor) * factor 47 | 48 | 49 | def ceil_by_factor(number: int, factor: int) -> int: 50 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" 51 | return math.ceil(number / factor) * factor 52 | 53 | 54 | def floor_by_factor(number: int, factor: int) -> int: 55 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" 56 | return math.floor(number / factor) * factor 57 | 58 | 59 | def smart_resize( 60 | height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS 61 | ) -> tuple[int, int]: 62 | """ 63 | Rescales the image so that the following conditions are met: 64 | 65 | 1. 
Both dimensions (height and width) are divisible by 'factor'. 66 | 67 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 68 | 69 | 3. The aspect ratio of the image is maintained as closely as possible. 70 | """ 71 | if max(height, width) / min(height, width) > MAX_RATIO: 72 | raise ValueError( 73 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" 74 | ) 75 | h_bar = max(factor, round_by_factor(height, factor)) 76 | w_bar = max(factor, round_by_factor(width, factor)) 77 | if h_bar * w_bar > max_pixels: 78 | beta = math.sqrt((height * width) / max_pixels) 79 | h_bar = floor_by_factor(height / beta, factor) 80 | w_bar = floor_by_factor(width / beta, factor) 81 | elif h_bar * w_bar < min_pixels: 82 | beta = math.sqrt(min_pixels / (height * width)) 83 | h_bar = ceil_by_factor(height * beta, factor) 84 | w_bar = ceil_by_factor(width * beta, factor) 85 | return h_bar, w_bar 86 | 87 | 88 | def to_rgb(pil_image: Image.Image) -> Image.Image: 89 | if pil_image.mode == 'RGBA': 90 | white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) 91 | white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask 92 | return white_background 93 | else: 94 | return pil_image.convert("RGB") 95 | 96 | 97 | def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: 98 | if "image" in ele: 99 | image = ele["image"] 100 | else: 101 | image = ele["image_url"] 102 | image_obj = None 103 | if isinstance(image, Image.Image): 104 | image_obj = image 105 | elif image.startswith("http://") or image.startswith("https://"): 106 | response = requests.get(image, stream=True) 107 | image_obj = Image.open(BytesIO(response.content)) 108 | elif image.startswith("file://"): 109 | image_obj = Image.open(image[7:]) 110 | elif image.startswith("data:image"): 111 | if "base64," in image: 112 | _, base64_data = image.split("base64,", 1) 113 | data = base64.b64decode(base64_data) 114 | image_obj = Image.open(BytesIO(data)) 115 | else: 116 | image_obj = Image.open(image) 117 | if image_obj is None: 118 | raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") 119 | image = to_rgb(image_obj) 120 | ## resize 121 | if "resized_height" in ele and "resized_width" in ele: 122 | resized_height, resized_width = smart_resize( 123 | ele["resized_height"], 124 | ele["resized_width"], 125 | factor=size_factor, 126 | ) 127 | else: 128 | width, height = image.size 129 | min_pixels = ele.get("min_pixels", MIN_PIXELS) 130 | max_pixels = ele.get("max_pixels", MAX_PIXELS) 131 | resized_height, resized_width = smart_resize( 132 | height, 133 | width, 134 | factor=size_factor, 135 | min_pixels=min_pixels, 136 | max_pixels=max_pixels, 137 | ) 138 | image = image.resize((resized_width, resized_height)) 139 | 140 | return image 141 | 142 | 143 | def smart_nframes( 144 | ele: dict, 145 | total_frames: int, 146 | video_fps: int | float, 147 | ) -> int: 148 | """calculate the number of frames for video used for model inputs. 149 | 150 | Args: 151 | ele (dict): a dict contains the configuration of video. 152 | support either `fps` or `nframes`: 153 | - nframes: the number of frames to extract for model inputs. 154 | - fps: the fps to extract frames for model inputs. 155 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 
156 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 157 | total_frames (int): the original total number of frames of the video. 158 | video_fps (int | float): the original fps of the video. 159 | 160 | Raises: 161 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 162 | 163 | Returns: 164 | int: the number of frames for video used for model inputs. 165 | """ 166 | assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" 167 | if "nframes" in ele: 168 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) 169 | else: 170 | fps = ele.get("fps", FPS) 171 | min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) 172 | max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) 173 | nframes = total_frames / video_fps * fps 174 | if nframes > total_frames: 175 | logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") 176 | nframes = min(min(max(nframes, min_frames), max_frames), total_frames) 177 | nframes = floor_by_factor(nframes, FRAME_FACTOR) 178 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames): 179 | raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") 180 | return nframes 181 | 182 | 183 | def _read_video_torchvision( 184 | ele: dict, 185 | ) -> (torch.Tensor, float): 186 | """read video using torchvision.io.read_video 187 | 188 | Args: 189 | ele (dict): a dict contains the configuration of video. 190 | support keys: 191 | - video: the path of video. support "file://", "http://", "https://" and local path. 192 | - video_start: the start time of video. 193 | - video_end: the end time of video. 194 | Returns: 195 | torch.Tensor: the video tensor with shape (T, C, H, W). 196 | """ 197 | video_path = ele["video"] 198 | if version.parse(torchvision.__version__) < version.parse("0.19.0"): 199 | if "http://" in video_path or "https://" in video_path: 200 | warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") 201 | if "file://" in video_path: 202 | video_path = video_path[7:] 203 | st = time.time() 204 | video, audio, info = io.read_video( 205 | video_path, 206 | start_pts=ele.get("video_start", 0.0), 207 | end_pts=ele.get("video_end", None), 208 | pts_unit="sec", 209 | output_format="TCHW", 210 | ) 211 | total_frames, video_fps = video.size(0), info["video_fps"] 212 | logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") 213 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 214 | idx = torch.linspace(0, total_frames - 1, nframes).round().long() 215 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps 216 | video = video[idx] 217 | return video, sample_fps 218 | 219 | 220 | def is_decord_available() -> bool: 221 | import importlib.util 222 | 223 | return importlib.util.find_spec("decord") is not None 224 | 225 | 226 | def _read_video_decord( 227 | ele: dict, 228 | ) -> (torch.Tensor, float): 229 | """read video using decord.VideoReader 230 | 231 | Args: 232 | ele (dict): a dict contains the configuration of video. 233 | support keys: 234 | - video: the path of video. support "file://", "http://", "https://" and local path. 235 | - video_start: the start time of video. 236 | - video_end: the end time of video. 237 | Returns: 238 | torch.Tensor: the video tensor with shape (T, C, H, W). 
239 | """ 240 | import decord 241 | video_path = ele["video"] 242 | st = time.time() 243 | vr = decord.VideoReader(video_path) 244 | # TODO: support start_pts and end_pts 245 | if 'video_start' in ele or 'video_end' in ele: 246 | raise NotImplementedError("not support start_pts and end_pts in decord for now.") 247 | total_frames, video_fps = len(vr), vr.get_avg_fps() 248 | logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") 249 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 250 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 251 | video = vr.get_batch(idx).asnumpy() 252 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format 253 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps 254 | return video, sample_fps 255 | 256 | 257 | VIDEO_READER_BACKENDS = { 258 | "decord": _read_video_decord, 259 | "torchvision": _read_video_torchvision, 260 | } 261 | 262 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) 263 | 264 | 265 | @lru_cache(maxsize=1) 266 | def get_video_reader_backend() -> str: 267 | if FORCE_QWENVL_VIDEO_READER is not None: 268 | video_reader_backend = FORCE_QWENVL_VIDEO_READER 269 | elif is_decord_available(): 270 | video_reader_backend = "decord" 271 | else: 272 | video_reader_backend = "torchvision" 273 | print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) 274 | return video_reader_backend 275 | 276 | 277 | def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]: 278 | if isinstance(ele["video"], str): 279 | video_reader_backend = get_video_reader_backend() 280 | try: 281 | video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele) 282 | except Exception as e: 283 | logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}") 284 | video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele) 285 | 286 | nframes, _, height, width = video.shape 287 | 288 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 289 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 290 | max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) 291 | max_pixels_supposed = ele.get("max_pixels", max_pixels) 292 | if max_pixels_supposed > max_pixels: 293 | logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") 294 | max_pixels = min(max_pixels_supposed, max_pixels) 295 | if "resized_height" in ele and "resized_width" in ele: 296 | resized_height, resized_width = smart_resize( 297 | ele["resized_height"], 298 | ele["resized_width"], 299 | factor=image_factor, 300 | ) 301 | else: 302 | resized_height, resized_width = smart_resize( 303 | height, 304 | width, 305 | factor=image_factor, 306 | min_pixels=min_pixels, 307 | max_pixels=max_pixels, 308 | ) 309 | video = transforms.functional.resize( 310 | video, 311 | [resized_height, resized_width], 312 | interpolation=InterpolationMode.BICUBIC, 313 | antialias=True, 314 | ).float() 315 | 316 | # print(f'video shape {video.shape}') 317 | if return_video_sample_fps: 318 | return video, sample_fps 319 | return video 320 | else: 321 | assert isinstance(ele["video"], (list, tuple)) 322 | process_info = ele.copy() 323 | process_info.pop("type", None) 324 | process_info.pop("video", None) 325 | images = [ 326 | fetch_image({"image": video_element, 
**process_info}, size_factor=image_factor) 327 | for video_element in ele["video"] 328 | ] 329 | nframes = ceil_by_factor(len(images), FRAME_FACTOR) 330 | if len(images) < nframes: 331 | images.extend([images[-1]] * (nframes - len(images))) 332 | if return_video_sample_fps: 333 | return images, process_info.pop("fps", 2.0) 334 | return images 335 | 336 | 337 | def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: 338 | vision_infos = [] 339 | if isinstance(conversations[0], dict): 340 | conversations = [conversations] 341 | for conversation in conversations: 342 | for message in conversation: 343 | if isinstance(message["content"], list): 344 | for ele in message["content"]: 345 | if ( 346 | "image" in ele 347 | or "image_url" in ele 348 | or "video" in ele 349 | or ele["type"] in ("image", "image_url", "video") 350 | ): 351 | ele['resized_height'] = 256 352 | ele['resized_width'] = 256 353 | vision_infos.append(ele) 354 | return vision_infos 355 | 356 | 357 | def process_vision_info( 358 | conversations: list[dict] | list[list[dict]], 359 | return_video_kwargs: bool = False, 360 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]: 361 | 362 | vision_infos = extract_vision_info(conversations) 363 | ## Read images or videos 364 | image_inputs = [] 365 | video_inputs = [] 366 | video_sample_fps_list = [] 367 | for vision_info in vision_infos: 368 | if "image" in vision_info or "image_url" in vision_info: 369 | image_inputs.append(fetch_image(vision_info)) 370 | elif "video" in vision_info: 371 | video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True) 372 | video_sample_fps_list.append(video_sample_fps) 373 | video_inputs.append(video_input) 374 | else: 375 | raise ValueError("image, image_url or video should in content.") 376 | if len(image_inputs) == 0: 377 | image_inputs = None 378 | if len(video_inputs) == 0: 379 | video_inputs = None 380 | if return_video_kwargs: 381 | return image_inputs, video_inputs, {'fps': video_sample_fps_list} 382 | return image_inputs, video_inputs -------------------------------------------------------------------------------- /src/open_r1_egoplan/trainer/grpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
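Note: grpo_trainer.py below, like infer.py and sft.py, prepares its multimodal batches with process_vision_info from the qwen_vl_utils module printed above. A small usage sketch with placeholder paths and values:

from qwen_vl_utils import process_vision_info

# One user turn with a video plus a text query; the path and fps are placeholders.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "file:///data/videos/example.mp4", "fps": 2.0},
            {"type": "text", "text": "What should I do next?"},
        ],
    }
]

# Two-tuple form, as used in sft.py's collate_fn and in the trainer's compute_loss.
image_inputs, video_inputs = process_vision_info(messages)

# Three-tuple form, as used in infer.py; the extra dict carries the sampled fps per video,
# which can be forwarded to the processor alongside the text, image, and video inputs.
image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
print(video_kwargs)  # e.g. {'fps': [2.0]}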
14 | 15 | import copy 16 | import os 17 | import textwrap 18 | from collections import defaultdict 19 | from typing import Any, Callable, Optional, Union, List, Tuple, Dict 20 | 21 | import torch 22 | import torch.utils.data 23 | import transformers 24 | from datasets import Dataset, IterableDataset 25 | from packaging import version 26 | from transformers import ( 27 | AriaForConditionalGeneration, 28 | AriaProcessor, 29 | AutoModelForCausalLM, 30 | AutoModelForSequenceClassification, 31 | AutoProcessor, 32 | AutoTokenizer, 33 | GenerationConfig, 34 | PreTrainedModel, 35 | PreTrainedTokenizerBase, 36 | Qwen2VLForConditionalGeneration, 37 | Trainer, 38 | TrainerCallback, 39 | is_wandb_available, 40 | ) 41 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 42 | from transformers.utils import is_peft_available 43 | 44 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template 45 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation 46 | from trl.trainer.grpo_config import GRPOConfig 47 | from trl.trainer.utils import generate_model_card, get_comet_experiment_url 48 | 49 | 50 | from qwen_vl_utils import process_vision_info 51 | 52 | if is_peft_available(): 53 | from peft import PeftConfig, get_peft_model 54 | 55 | if is_wandb_available(): 56 | import wandb 57 | 58 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of 59 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. 60 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], List[float]]] 61 | 62 | 63 | class Qwen2VLGRPOTrainer(Trainer): 64 | """ 65 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the 66 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). 67 | 68 | Example: 69 | 70 | ```python 71 | from datasets import load_dataset 72 | from trl import GRPOTrainer 73 | 74 | dataset = load_dataset("trl-lib/tldr", split="train") 75 | 76 | trainer = GRPOTrainer( 77 | model="Qwen/Qwen2-0.5B-Instruct", 78 | reward_funcs="weqweasdas/RM-Gemma-2B", 79 | train_dataset=dataset, 80 | ) 81 | 82 | trainer.train() 83 | ``` 84 | 85 | Args: 86 | model (`Union[str, PreTrainedModel]`): 87 | Model to be trained. Can be either: 88 | 89 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or 90 | a path to a *directory* containing model weights saved using 91 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is 92 | loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments 93 | in `args.model_init_kwargs`. 94 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. 95 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): 96 | Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward 97 | functions with the prompts and completions and sum the rewards. Can be either: 98 | 99 | - A single reward function, such as: 100 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a 101 | path to a *directory* containing model weights saved using 102 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. 
The model is loaded 103 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the 104 | keyword arguments in `args.model_init_kwargs`. 105 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. 106 | - A custom reward function: The function is provided with the prompts and the generated completions, 107 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see 108 | [Using a custom reward function](#using-a-custom-reward-function). 109 | - A list of reward functions, where each item can independently be any of the above types. Mixing different 110 | types within the list (e.g., a string model ID and a custom reward function) is allowed. 111 | args ([`GRPOConfig`], *optional*, defaults to `None`): 112 | Configuration for this trainer. If `None`, a default configuration is used. 113 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): 114 | Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is 115 | ignored. The format of the samples can be either: 116 | 117 | - [Standard](dataset_formats#standard): Each sample contains plain text. 118 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role 119 | and content). 120 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): 121 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. 122 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): 123 | Processing class used to process the data. The padding side must be set to "left". If `None`, the 124 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. 125 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): 126 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: 127 | 128 | - A single processing class: Used when `reward_funcs` contains only one reward function. 129 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. 130 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is 131 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. 132 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), 133 | the corresponding entries in `reward_processing_classes` are ignored. 134 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): 135 | List of callbacks to customize the training loop. Will add those to the list of default callbacks 136 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). 137 | 138 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] 139 | method. 140 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): 141 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your 142 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. 
143 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): 144 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped. 145 | """ 146 | 147 | def __init__( 148 | self, 149 | model: Union[str, PreTrainedModel], 150 | reward_funcs: Union[RewardFunc, List[RewardFunc]], 151 | args: GRPOConfig = None, 152 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None, 153 | eval_dataset: Optional[Union[Dataset, IterableDataset, Dict[str, Union[Dataset, IterableDataset]]]] = None, 154 | processing_class: Optional[PreTrainedTokenizerBase] = None, 155 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, List[PreTrainedTokenizerBase]]] = None, 156 | callbacks: Optional[List[TrainerCallback]] = None, 157 | optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), 158 | peft_config: Optional["PeftConfig"] = None, 159 | max_pixels: Optional[int] = 12845056, 160 | min_pixels: Optional[int] = 3136, 161 | attn_implementation: str = "flash_attention_2", 162 | ): 163 | # Args 164 | if args is None: 165 | model_name = model if isinstance(model, str) else model.config._name_or_path 166 | model_name = model_name.split("/")[-1] 167 | args = GRPOConfig(f"{model_name}-GRPO") 168 | 169 | # Models 170 | # Trained model 171 | model_init_kwargs = args.model_init_kwargs or {} 172 | model_init_kwargs["attn_implementation"] = attn_implementation 173 | if isinstance(model, str): 174 | model_id = model 175 | torch_dtype = model_init_kwargs.get("torch_dtype") 176 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: 177 | pass # torch_dtype is already a torch.dtype or "auto" or None 178 | elif isinstance(torch_dtype, str): # it's a str, but not "auto" 179 | torch_dtype = getattr(torch, torch_dtype) 180 | model_init_kwargs["torch_dtype"] = torch_dtype 181 | else: 182 | raise ValueError( 183 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " 184 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 185 | ) 186 | # Disable caching if gradient checkpointing is enabled (not supported) 187 | model_init_kwargs["use_cache"] = ( 188 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") 189 | ) 190 | if "Qwen2-VL" in model_id: 191 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 192 | elif "Aria" in model_id: 193 | model_init_kwargs.pop("use_cache") 194 | model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 195 | else: 196 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) 197 | else: 198 | model_id = model.config._name_or_path 199 | if args.model_init_kwargs is not None: 200 | raise ValueError( 201 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " 202 | "This argument can only be used when the `model` argument is a string." 
203 | ) 204 | 205 | if args.freeze_visual_encoder: 206 | print("Freeze parameters of visual encoder!") 207 | for n, p in model.visual.named_parameters(): 208 | p.requires_grad_(False) 209 | 210 | if peft_config is not None: 211 | model = get_peft_model(model, peft_config) 212 | 213 | # Reference model 214 | if is_deepspeed_zero3_enabled(): 215 | if "Qwen2-VL" in model_id: 216 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 217 | elif "Aria" in model_id: 218 | self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 219 | else: 220 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) 221 | elif peft_config is None: 222 | # If PEFT configuration is not provided, create a reference model based on the initial model. 223 | self.ref_model = create_reference_model(model) 224 | else: 225 | # If PEFT is used, the reference model is not needed since the adapter can be disabled 226 | # to revert to the initial model. 227 | self.ref_model = None 228 | 229 | # Processing class 230 | if processing_class is None: 231 | if "Qwen2-VL" in model_id or "Aria" in model_id: 232 | processing_class = AutoProcessor.from_pretrained(model_id) 233 | pad_token_id = processing_class.tokenizer.pad_token_id 234 | processing_class.pad_token_id = pad_token_id 235 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id 236 | if "Qwen2-VL" in model_id: 237 | processing_class.image_processor.max_pixels = max_pixels 238 | processing_class.image_processor.min_pixels = min_pixels 239 | else: 240 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") 241 | pad_token_id = processing_class.pad_token_id 242 | 243 | # Reward functions 244 | if not isinstance(reward_funcs, list): 245 | reward_funcs = [reward_funcs] 246 | for i, reward_func in enumerate(reward_funcs): 247 | if isinstance(reward_func, str): 248 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( 249 | reward_func, num_labels=1, **model_init_kwargs 250 | ) 251 | self.reward_funcs = reward_funcs 252 | 253 | # Reward processing class 254 | if reward_processing_classes is None: 255 | reward_processing_classes = [None] * len(reward_funcs) 256 | elif not isinstance(reward_processing_classes, list): 257 | reward_processing_classes = [reward_processing_classes] 258 | else: 259 | if len(reward_processing_classes) != len(reward_funcs): 260 | raise ValueError("The number of reward processing classes must match the number of reward functions.") 261 | 262 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): 263 | if isinstance(reward_func, PreTrainedModel): 264 | if reward_processing_class is None: 265 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) 266 | if reward_processing_class.pad_token_id is None: 267 | reward_processing_class.pad_token = reward_processing_class.eos_token 268 | # The reward model computes the reward for the latest non-padded token in the input sequence. 269 | # So it's important to set the pad token ID to the padding token ID of the processing class. 
270 |                 reward_func.config.pad_token_id = reward_processing_class.pad_token_id
271 |                 reward_processing_classes[i] = reward_processing_class
272 |         self.reward_processing_classes = reward_processing_classes
273 | 
274 |         # Data collator
275 |         def data_collator(features):  # No data collation is needed in GRPO
276 |             return features
277 | 
278 |         # Training arguments
279 |         self.max_prompt_length = args.max_prompt_length
280 |         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
281 |         self.num_generations = args.num_generations  # = G in the GRPO paper
282 |         self.generation_config = GenerationConfig(
283 |             max_new_tokens=self.max_completion_length,
284 |             do_sample=True,
285 |             temperature=1,  # HACK
286 |             num_return_sequences=self.num_generations,
287 |             pad_token_id=pad_token_id,
288 |         )
289 |         self.beta = args.beta
290 | 
291 |         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
292 |         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
293 |         # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
294 |         # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
295 |         # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
296 |         # This acts as a flag to indicate that the warning has already been issued.
297 |         model.warnings_issued["estimate_tokens"] = True
298 | 
299 |         # Initialize the metrics
300 |         self._metrics = defaultdict(list)
301 | 
302 |         super().__init__(
303 |             model=model,
304 |             args=args,
305 |             data_collator=data_collator,
306 |             train_dataset=train_dataset,
307 |             eval_dataset=eval_dataset,
308 |             processing_class=processing_class,
309 |             callbacks=callbacks,
310 |             optimizers=optimizers,
311 |         )
312 | 
313 |         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
314 |         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
315 |         # self.model_accepts_loss_kwargs to False to enable scaling.
316 |         self.model_accepts_loss_kwargs = False
317 | 
318 |         if self.ref_model is not None:
319 |             if self.is_deepspeed_enabled:
320 |                 self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
321 |             else:
322 |                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
323 | 
324 |         for i, reward_func in enumerate(self.reward_funcs):
325 |             if isinstance(reward_func, PreTrainedModel):
326 |                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
327 | 
328 |     def _set_signature_columns_if_needed(self):
329 |         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
330 |         # By default, this method sets `self._signature_columns` to the model's expected inputs.
331 |         # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
332 |         # Instead, we set them to the columns expected by the `training_step` method, hence the override.
333 |         if self._signature_columns is None:
334 |             self._signature_columns = ["prompt"]
335 | 
336 |     # Trainer "prepares" the inputs before calling `compute_loss`. It converts them to tensors and moves them to the device.
337 |     # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
338 |     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
339 |         return inputs
340 | 
341 |     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
342 |         if return_outputs:
343 |             raise ValueError("The GRPOTrainer does not support returning outputs")
344 | 
345 |         # for x in inputs:
346 |         #     msg = x["prompt"]
347 |         #     new_content = []
348 |         #     for ele in msg["content"]:
349 |         #         new_ele = {}
350 |         #         for k, v in ele:
351 |         #             if v is not None:
352 |         #                 new_ele[k] = v
353 |         #         new_content.append(new_ele)
354 |         #     msg["content"] = new_content
355 | 
356 |         for x in inputs:  # drop None-valued fields from multimodal message contents before templating
357 |             for message in x["prompt"]:
358 |                 if isinstance(message["content"], list):
359 |                     new_content = []
360 |                     for ele in message["content"]:
361 |                         new_ele = {k: v for k, v in ele.items() if v is not None}
362 |                         new_content.append(new_ele)
363 |                     message["content"] = new_content
364 | 
365 |         prompts = [x["prompt"] for x in inputs]
366 | 
367 |         # Preparation for batch inference
368 |         prompt_texts = [
369 |             self.processing_class.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
370 |             for msg in prompts
371 |         ]
372 |         image_inputs, video_inputs = process_vision_info(prompts)
373 |         prompt_inputs = self.processing_class(
374 |             text=prompt_texts,
375 |             images=image_inputs,
376 |             videos=video_inputs,
377 |             return_tensors="pt",
378 |             padding=True,
379 |             padding_side="left",
380 |             add_special_tokens=False,
381 |         )
382 |         prompt_inputs = super()._prepare_inputs(prompt_inputs)
383 | 
384 |         if self.max_prompt_length is not None:
385 |             prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
386 |             prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
387 | 
388 |         # Generate completions
389 |         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
390 |             # prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
391 | 
392 |             # Generate num_generations times, one completion per call with temp_generation_config; the outputs are then padded to a common length with the pad token id and concatenated into prompt_completion_ids
393 |             num_generations = self.generation_config.num_return_sequences
394 |             temp_generation_config = copy.deepcopy(self.generation_config)
395 |             temp_generation_config.num_return_sequences = 1
396 | 
397 |             all_completions = []
398 | 
399 |             for i in range(num_generations):  # one completion per prompt on each iteration
400 |                 completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
401 |                 all_completions.append(completion)
402 | 
403 |             # Pad all completions to the same length
404 |             max_length = max(completion.size(1) for completion in all_completions)
405 |             padded_completions = []
406 | 
407 |             for completion in all_completions:
408 |                 if completion.size(1) < max_length:
409 |                     padding = torch.full(
410 |                         (completion.size(0), max_length - completion.size(1)),
411 |                         self.processing_class.tokenizer.pad_token_id,
412 |                         dtype=completion.dtype,
413 |                         device=completion.device,
414 |                     )
415 |                     padded_completion = torch.cat([completion, padding], dim=1)
416 |                 else:
417 |                     padded_completion = completion
418 |                 padded_completions.append(padded_completion)
419 | 
420 |             # Concatenate all padded completions along the batch dimension
421 |             prompt_completion_ids = torch.cat(padded_completions, dim=0)
422 | 
423 |         prompt_length = prompt_inputs["input_ids"].size(1)
424 |         completion_ids = prompt_completion_ids[:, prompt_length:]
425 | 
426 |         # import pdb; pdb.set_trace()
427 | 
428 |         # Get the per-token log probabilities for the completions for the model and the reference model
429 |         def get_per_token_logps(model, input_ids, **kwargs):
430 |             logits = model(input_ids, **kwargs).logits  # (B, L, V)
431 |             logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
432 |             input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
433 |             # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
434 |             per_token_logps = []
435 |             for logits_row, input_ids_row in zip(logits, input_ids):
436 |                 log_probs = logits_row.log_softmax(dim=-1)
437 |                 token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
438 |                 per_token_logps.append(token_log_prob)
439 |             return torch.stack(per_token_logps)
440 | 
441 |         prompt_inputs.pop("input_ids")
442 |         prompt_inputs.pop("attention_mask")
443 |         # # Okay I am assuming that the inputs are Qwen2VL processor
444 |         # # and no video for now, repeat the image for each completion
445 |         # if "image" in inputs[0]:
446 |         #     prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
447 |         #     prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
448 |         # # import pdb; pdb.set_trace()
449 | 
450 |         # # XXX if input video
451 |         # # image_grid_thw is from image_process_qwen2_vl
452 |         # # https://github.com/huggingface/transformers/blob/dd16acb8a3e93b643aa374c9fb80749f5235c1a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L414
453 |         # # automatic process
454 |         # if "video" in inputs[0]:
455 |         #     prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
456 |         #     prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
457 | 
458 |         for k in ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]:
459 |             if k in prompt_inputs:
460 |                 prompt_inputs[k] = prompt_inputs[k].repeat(len(prompt_completion_ids), 1)
461 | 
462 | 
463 |         per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
464 |         # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
465 |         per_token_logps = per_token_logps[:, prompt_length - 1 :]
466 | 
467 |         with torch.inference_mode():
468 |             if self.ref_model is not None:
469 |                 ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)  # pass the vision inputs so the reference model sees the same multimodal context
470 |             else:
471 |                 with self.accelerator.unwrap_model(model).disable_adapter():
472 |                     ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)  # same multimodal inputs, with the adapter disabled
473 |             ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
474 | 
475 |         # Compute the KL divergence between the model and the reference model
476 |         diff = ref_per_token_logps - per_token_logps
477 |         diff = torch.clamp(diff, min=-11.0, max=11.0)  # clamp the log-ratio so exp(diff) below cannot overflow
478 | 
479 |         per_token_kl = torch.exp(diff) - (diff) - 1
480 | 
481 |         # Mask everything after the first EOS token
482 |         is_eos = completion_ids == self.processing_class.eos_token_id
483 |         device = self.accelerator.device
484 |         eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
485 |         eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
486 |         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
487 |         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
488 | 
489 |         # Decode the generated completions
490 |         completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
491 |         if is_conversational(inputs[0]):
492 |             completions = [[{"role": "assistant", "content": completion}] for completion in completions]
493 | 
494 |         # Compute the rewards
495 |         prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
496 | 
497 |         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
498 |         for i, (reward_func, reward_processing_class) in enumerate(
499 |             zip(self.reward_funcs, self.reward_processing_classes)
500 |         ):
501 |             # import pdb; pdb.set_trace()
502 |             if isinstance(reward_func, PreTrainedModel):
503 |                 if is_conversational(inputs[0]):
504 |                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
505 |                     texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
506 |                 else:
507 |                     texts = [p + c for p, c in zip(prompts, completions)]
508 |                 reward_inputs = reward_processing_class(
509 |                     texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
510 |                 )
511 |                 reward_inputs = super()._prepare_inputs(reward_inputs)
512 |                 with torch.inference_mode():
513 |                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
514 |             else:
515 |                 # Repeat all input columns (except "prompt" and "completion") to match the number of generations
516 |                 reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
517 |                 for key in reward_kwargs:
518 |                     for example in inputs:
519 |                         # Repeat each value in the column for `num_generations` times
520 |                         reward_kwargs[key].extend([example[key]] * self.num_generations)
521 |                 output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
522 |                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
523 | 
524 |         # Sum the rewards from all reward functions
525 |         rewards = rewards_per_func.sum(dim=1)  # Shape (B*G,)
526 | 
527 |         # Compute group-wise rewards
528 |         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)  # (B,)
529 |         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)  # (B,)
530 | 
531 |         # Normalize the rewards to compute the advantages
532 |         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)  # (B*G,)
533 |         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)  # (B*G,)
534 |         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
535 | 
536 |         # x - x.detach() allows for preserving gradients from x
537 |         per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
538 |         per_token_loss = -(per_token_loss - self.beta * per_token_kl)  # self.beta defaults to 0.04
539 |         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
540 | 
541 |         # import pdb; pdb.set_trace()
542 | 
543 |         # Log the metrics
544 |         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
545 |         self._metrics["completion_length"].append(completion_length)
546 | 
547 |         reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
548 |         for i, reward_func in enumerate(self.reward_funcs):
549 |             if isinstance(reward_func, PreTrainedModel):
550 |                 reward_func_name = reward_func.config._name_or_path.split("/")[-1]
551 |             else:
552 |                 reward_func_name = reward_func.__name__
553 |             self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
554 | 
555 |         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
556 | 
557 |         self._metrics["advantages"].append(self.accelerator.gather_for_metrics(advantages).mean().item())
558 | 
559 |         self._metrics["reward_mean"].append(self.accelerator.gather_for_metrics(mean_grouped_rewards).mean().item())
560 | 
561 |         self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
562 | 
563 |         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
564 |         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
565 | 
566 |         # import pdb; pdb.set_trace()
567 | 
568 |         return loss
569 | 
570 |     def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
571 |         metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
572 |         logs = {**logs, **metrics}
573 |         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
574 |             super().log(logs, start_time)
575 |         else:  # transformers<=4.46
576 |             super().log(logs)
577 |         self._metrics.clear()
578 | 
579 |     def create_model_card(
580 |         self,
581 |         model_name: Optional[str] = None,
582 |         dataset_name: Optional[str] = None,
583 |         tags: Union[str, List[str], None] = None,
584 |     ):
585 |         """
586 |         Creates a draft of a model card using the information available to the `Trainer`.
587 | 
588 |         Args:
589 |             model_name (`str` or `None`, *optional*, defaults to `None`):
590 |                 Name of the model.
591 |             dataset_name (`str` or `None`, *optional*, defaults to `None`):
592 |                 Name of the dataset used for training.
593 |             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
594 |                 Tags to be associated with the model card.
595 |         """
596 |         if not self.is_world_process_zero():
597 |             return
598 | 
599 |         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
600 |             base_model = self.model.config._name_or_path
601 |         else:
602 |             base_model = None
603 | 
604 |         tags = tags or []
605 |         if isinstance(tags, str):
606 |             tags = [tags]
607 | 
608 |         if hasattr(self.model.config, "unsloth_version"):
609 |             tags.append("unsloth")
610 | 
611 |         citation = textwrap.dedent(
612 |             """\
613 |             @article{zhihong2024deepseekmath,
614 |                 title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
615 |                 author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
616 |                 year = 2024,
617 |                 eprint = {arXiv:2402.03300},
618 |             }
619 |             """
620 |         )
621 | 
622 |         model_card = generate_model_card(
623 |             base_model=base_model,
624 |             model_name=model_name,
625 |             hub_model_id=self.hub_model_id,
626 |             dataset_name=dataset_name,
627 |             tags=tags,
628 |             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
629 |             comet_url=get_comet_experiment_url(),
630 |             trainer_name="GRPO",
631 |             trainer_citation=citation,
632 |             paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
633 |             paper_id="2402.03300",
634 |         )
635 | 
636 |         model_card.save(os.path.join(self.args.output_dir, "README.md"))
--------------------------------------------------------------------------------
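
Note: the snippet below is an illustrative, self-contained sketch (not part of the repository) of the group-normalized advantage and KL-penalized objective that `compute_loss` above assembles from per-token log-probabilities. Tensor names, shapes, and the toy values are assumptions chosen to mirror the trainer's own comments.

# Standalone sketch of the GRPO loss used in compute_loss (hypothetical names).
import torch

def grpo_loss(per_token_logps, ref_per_token_logps, completion_mask, rewards, num_generations, beta=0.04):
    # k3-style per-token KL estimate, with the log-ratio clamped as in the trainer
    diff = torch.clamp(ref_per_token_logps - per_token_logps, min=-11.0, max=11.0)
    per_token_kl = torch.exp(diff) - diff - 1

    # Group-wise reward normalization: one advantage scalar per sampled completion
    mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations)
    std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
    advantages = (rewards - mean_g) / (std_g + 1e-4)

    # exp(logp - logp.detach()) equals 1 in value but keeps gradients flowing through logp
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    per_token_loss = -(per_token_loss - beta * per_token_kl)
    return ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()

if __name__ == "__main__":
    B, G, L = 2, 4, 5  # prompts, generations per prompt, completion length (toy sizes)
    logps = torch.randn(B * G, L, requires_grad=True)
    ref_logps = logps.detach() + 0.1 * torch.randn(B * G, L)
    mask = torch.ones(B * G, L, dtype=torch.int)
    rewards = torch.rand(B * G)
    print(grpo_loss(logps, ref_logps, mask, rewards, G))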
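
A hypothetical usage sketch follows, showing how the trainer's constructor and the custom (non-model) reward-function interface fit together: such reward functions receive `prompts`, `completions`, and any extra dataset columns as keyword arguments and must return one float per completion. The import path, dataset path, `<answer>` rule, and config values are assumptions for illustration, not the repository's actual entry point.

# Hypothetical wiring of the GRPO trainer with a simple rule-based reward.
from datasets import load_dataset
from trl import GRPOConfig
from open_r1_egoplan.trainer import Qwen2VLGRPOTrainer  # assumed import path

def format_reward(prompts, completions, **kwargs):
    # In the conversational case each completion is [{"role": "assistant", "content": ...}];
    # return one scalar reward per completion.
    return [1.0 if "<answer>" in c[0]["content"] else 0.0 for c in completions]

# Each example needs a conversational "prompt" field (assumed data file).
train_dataset = load_dataset("json", data_files="data/annotations/train.jsonl", split="train")

trainer = Qwen2VLGRPOTrainer(
    model="Qwen/Qwen2-VL-7B-Instruct",
    reward_funcs=[format_reward],
    args=GRPOConfig(output_dir="outputs/qwen2vl-grpo", num_generations=8, max_completion_length=512),
    train_dataset=train_dataset,
)
trainer.train()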