├── src
│   ├── open_r1_egoplan
│   │   ├── __init__.py
│   │   ├── trainer
│   │   │   ├── __init__.py
│   │   │   └── grpo_trainer.py
│   │   ├── grpo.py
│   │   └── sft.py
│   └── open_r1.egg-info
│       ├── not-zip-safe
│       ├── dependency_links.txt
│       ├── top_level.txt
│       ├── requires.txt
│       ├── SOURCES.txt
│       └── PKG-INFO
├── qwen-vl-utils
│   ├── .python-version
│   ├── src
│   │   └── qwen_vl_utils
│   │       ├── __init__.py
│   │       └── vision_process_egoplan.py
│   ├── requirements.lock
│   ├── pyproject.toml
│   ├── requirements-dev.lock
│   └── README.md
├── assets
│   ├── teaser.png
│   ├── data_statistics.png
│   ├── evaluation_results.png
│   └── question_examples.png
├── scripts
│   ├── sft.sh
│   ├── zero3.yaml
│   ├── eval.sh
│   ├── qwen2vl_sft_config.yaml
│   ├── zero2.json
│   ├── zero3.json
│   ├── grpo.sh
│   └── zero3_offload.json
├── Makefile
├── setup.cfg
├── setup.py
├── README.md
├── infer.py
└── LICENSE
/src/open_r1_egoplan/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/qwen-vl-utils/.python-version:
--------------------------------------------------------------------------------
1 | 3.8.19
2 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | open_r1_egoplan
2 | open_r1_video
3 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/data_statistics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/data_statistics.png
--------------------------------------------------------------------------------
/assets/evaluation_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/evaluation_results.png
--------------------------------------------------------------------------------
/assets/question_examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TencentARC/SEED-Bench-R1/HEAD/assets/question_examples.png
--------------------------------------------------------------------------------
/src/open_r1_egoplan/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .grpo_trainer import Qwen2VLGRPOTrainer
2 |
3 |
4 | __all__ = ["Qwen2VLGRPOTrainer"]
5 |
--------------------------------------------------------------------------------
/qwen-vl-utils/src/qwen_vl_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .vision_process_egoplan import (
2 | extract_vision_info,
3 | fetch_image,
4 | fetch_video,
5 | process_vision_info,
6 | smart_resize,
7 | )
8 |
--------------------------------------------------------------------------------
/scripts/sft.sh:
--------------------------------------------------------------------------------
1 | export PROJECT_ROOT="/group/40101/milkcychen/SEED-Bench-R1"
2 |
3 | accelerate launch --config_file ${PROJECT_ROOT}/scripts/zero3.yaml \
4 | src/open_r1_egoplan/sft.py \
5 | --config ${PROJECT_ROOT}/scripts/qwen2vl_sft_config.yaml
6 |
7 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: style quality
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := src
7 |
8 | style:
9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | isort $(check_dirs) setup.py
11 |
12 | quality:
13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | isort --check-only $(check_dirs) setup.py
15 | flake8 --max-line-length 119 $(check_dirs) setup.py
16 |
17 |
18 | # Evaluation
19 |
20 | evaluate:
21 |
--------------------------------------------------------------------------------
/scripts/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | accelerate>=1.2.1
2 | bitsandbytes>=0.43.0
3 | einops>=0.8.0
4 | datasets>=3.2.0
5 | deepspeed==0.15.4
6 | hf_transfer>=0.1.4
7 | huggingface-hub[cli]<1.0,>=0.19.2
8 | liger_kernel==0.5.2
9 | packaging>=23.0
10 | safetensors>=0.3.3
11 | sentencepiece>=0.1.99
12 | transformers
13 | trl
14 |
15 | [dev]
16 | black>=24.4.2
17 | isort>=5.12.0
18 | flake8>=6.0.0
19 | pytest
20 | parameterized>=0.9.0
21 | math-verify
22 |
23 | [eval]
24 | math-verify
25 |
26 | [quality]
27 | black>=24.4.2
28 | isort>=5.12.0
29 | flake8>=6.0.0
30 |
31 | [tests]
32 | pytest
33 | parameterized>=0.9.0
34 |
35 | [torch]
36 | torch>=2.5.1
37 |
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | python -u infer.py \
2 | --model_path "checkpoint_path" \
3 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L1.jsonl" \
4 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \
5 | --test_batch_size 1
6 |
7 | python -u infer.py \
8 | --model_path "checkpoint_path" \
9 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L2.jsonl" \
10 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \
11 | --test_batch_size 1
12 |
13 |
14 | python -u infer.py \
15 | --model_path "checkpoint_path" \
16 | --test_file_path "/group/40101/milkcychen/SEED-Bench-R1/data/annotations/validation_L3.jsonl" \
17 | --data_root_dir "/group/40101/milkcychen/SEED-Bench-R1/data" \
18 | --test_batch_size 1
--------------------------------------------------------------------------------
/qwen-vl-utils/requirements.lock:
--------------------------------------------------------------------------------
1 | # generated by rye
2 | # use `rye lock` or `rye sync` to update this lockfile
3 | #
4 | # last locked with the following flags:
5 | # pre: false
6 | # features: ["decord"]
7 | # all-features: false
8 | # with-sources: false
9 | # generate-hashes: false
10 | # universal: false
11 |
12 | -e file:.
13 | av==12.3.0
14 | # via qwen-vl-utils
15 | certifi==2022.12.7
16 | # via requests
17 | charset-normalizer==2.1.1
18 | # via requests
19 | decord==0.6.0
20 | # via qwen-vl-utils
21 | idna==3.4
22 | # via requests
23 | numpy==1.24.4
24 | # via decord
25 | packaging==24.1
26 | # via qwen-vl-utils
27 | pillow==10.2.0
28 | # via qwen-vl-utils
29 | requests==2.28.1
30 | # via qwen-vl-utils
31 | urllib3==1.26.13
32 | # via requests
33 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = open_r1
7 | known_third_party =
8 | transformers
9 | datasets
10 | fugashi
11 | git
12 | h5py
13 | matplotlib
14 | nltk
15 | numpy
16 | packaging
17 | pandas
18 | psutil
19 | pytest
20 | rouge_score
21 | sacrebleu
22 | seqeval
23 | sklearn
24 | streamlit
25 | torch
26 | tqdm
27 |
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 |
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 | # imported but unused
38 | __init__.py: F401
39 |
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/src/open_r1.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | setup.cfg
4 | setup.py
5 | src/open_r1.egg-info/PKG-INFO
6 | src/open_r1.egg-info/SOURCES.txt
7 | src/open_r1.egg-info/dependency_links.txt
8 | src/open_r1.egg-info/not-zip-safe
9 | src/open_r1.egg-info/requires.txt
10 | src/open_r1.egg-info/top_level.txt
11 | src/open_r1_egoplan/__init__.py
12 | src/open_r1_egoplan/evaluate.py
13 | src/open_r1_egoplan/generate.py
14 | src/open_r1_egoplan/grpo.py
15 | src/open_r1_egoplan/sft.py
16 | src/open_r1_egoplan/sft_orig.py
17 | src/open_r1_egoplan/trainer/__init__.py
18 | src/open_r1_egoplan/trainer/grpo_trainer.py
19 | src/open_r1_video/__init__.py
20 | src/open_r1_video/evaluate.py
21 | src/open_r1_video/generate.py
22 | src/open_r1_video/grpo.py
23 | src/open_r1_video/sft.py
24 | src/open_r1_video/trainer/__init__.py
25 | src/open_r1_video/trainer/grpo_trainer.py
26 | src/open_r1_video/trainer/grpo_trainer_w2s.py
--------------------------------------------------------------------------------
/scripts/qwen2vl_sft_config.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: /group/40101/milkcychen/SEED-Bench-R1/pretrained_ckpt/Qwen2-VL-7B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 |
6 | # Data training arguments
7 | dataset_name: xxx
8 | jsonl_path: /group/40101/milkcychen/SEED-Bench-R1/data/annotations/training_with_cot_6k.jsonl
9 | data_root_dir: /group/40101/milkcychen/SEED-Bench-R1/data
10 |
11 | # SFT trainer config
12 | bf16: true
13 | do_eval: true
14 | eval_strategy: "no"
15 | gradient_accumulation_steps: 16
16 | gradient_checkpointing: true
17 | gradient_checkpointing_kwargs:
18 | use_reentrant: false
19 | learning_rate: 2.0e-05
20 | log_level: info
21 | logging_steps: 1
22 | logging_strategy: steps
23 | lr_scheduler_type: cosine
24 | packing: true
25 | max_seq_length: 8192
26 | max_steps: -1
27 | num_train_epochs: 1
28 | output_dir: /group/40101/milkcychen/SEED-Bench-R1/ckpt/Qwen2-VL-7B-Instruct-SFT/training_with_cot_6k
29 | overwrite_output_dir: true
30 | per_device_eval_batch_size: 1
31 | per_device_train_batch_size: 1
32 | report_to:
33 | - wandb
34 | save_strategy: steps
35 | seed: 42
36 | warmup_ratio: 0.1
37 | save_steps: 1
38 | save_only_model: true
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/scripts/grpo.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT=Qwen2-VL-7B-GRPO
2 | export WANDB_NAME=training_6k-remove-formatreward-matchletterreward-f16
3 | export DEBUG_MODE=true
4 | export PROJECT_ROOT=/group/40101/milkcychen/SEED-Bench-R1
5 | export LOG_PATH=${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME/completions_log.txt
6 |
7 | mkdir -p ${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME
8 |
9 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node="8" \
10 | --nnodes="1" \
11 | --node_rank="0" \
12 | --master_addr="127.0.0.1" \
13 | --master_port="12352" \
14 | src/open_r1_egoplan/grpo.py \
15 | --deepspeed scripts/zero3_offload.json \
16 | --output_dir ${PROJECT_ROOT}/output_ckpt/$WANDB_PROJECT/$WANDB_NAME \
17 | --model_name_or_path ${PROJECT_ROOT}/pretrained_ckpt/Qwen2-VL-7B-Instruct \
18 | --dataset_name xxx \
19 | --jsonl_path ${PROJECT_ROOT}/data/annotations/training_6k.jsonl \
20 | --data_root_dir ${PROJECT_ROOT}/data \
21 | --max_prompt_length 8192 \
22 | --learning_rate 1e-6 \
23 | --beta 0.1 \
24 | --per_device_train_batch_size 1 \
25 | --gradient_accumulation_steps 1 \
26 | --logging_steps 1 \
27 | --bf16 \
28 | --torch_dtype bfloat16 \
29 | --data_seed 42 \
30 | --report_to wandb \
31 | --gradient_checkpointing true \
32 | --attn_implementation flash_attention_2 \
33 | --num_train_epochs 1 \
34 | --run_name $WANDB_NAME \
35 | --save_steps 10 \
36 | --save_only_model true
37 |
38 |
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "steps_per_print": 1e5,
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/qwen-vl-utils/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "qwen-vl-utils"
3 | version = "0.0.10"
4 | description = "Qwen Vision Language Model Utils - PyTorch"
5 | authors = [
6 | { name = "Qwen Team", email = "chenkeqin.ckq@alibaba-inc.com" },
7 | ]
8 | dependencies = [
9 | "requests",
10 | "pillow",
11 | "av",
12 | "packaging",
13 | ]
14 | readme = "README.md"
15 | requires-python = ">= 3.8"
16 | license = {text = "Apache-2.0"}
17 | keywords = [
18 | 'large language model',
19 | 'vision language model',
20 | 'qwen-vl',
21 | 'pytorch',
22 | ]
23 | classifiers = [
24 | 'Development Status :: 4 - Beta',
25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 | 'Programming Language :: Python :: 3',
27 | 'License :: OSI Approved :: Apache Software License',
28 | ]
29 |
30 | [project.urls]
31 | Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
32 | Repository = "https://github.com/QwenLM/Qwen2-VL.git"
33 | Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
34 |
35 | [project.optional-dependencies]
36 | decord = [
37 | "decord",
38 | ]
39 |
40 | [build-system]
41 | requires = ["hatchling"]
42 | build-backend = "hatchling.build"
43 |
44 | [tool.rye]
45 | managed = true
46 | dev-dependencies = [
47 | "torch",
48 | "torchvision",
49 | ]
50 |
51 | [tool.hatch.metadata]
52 | allow-direct-references = true
53 |
54 | [tool.hatch.build.targets.wheel]
55 | packages = ["src/qwen_vl_utils"]
56 |
57 | [tool.ruff]
58 | line-length = 119
59 |
60 | [tool.ruff.lint]
61 | ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
62 | select = ["C", "E", "F", "I", "W"]
63 |
64 | [tool.ruff.lint.per-file-ignores]
65 | "__init__.py" = ["E402", "F401", "F403", "F811"]
66 |
67 | [tool.ruff.lint.isort]
68 | lines-after-imports = 2
69 | known-first-party = ["qwen_vl_utils"]
70 |
71 | [tool.ruff.format]
72 | quote-style = "double"
73 | indent-style = "space"
74 | skip-magic-trailing-comma = false
75 | line-ending = "auto"
76 |
--------------------------------------------------------------------------------
/qwen-vl-utils/requirements-dev.lock:
--------------------------------------------------------------------------------
1 | # generated by rye
2 | # use `rye lock` or `rye sync` to update this lockfile
3 | #
4 | # last locked with the following flags:
5 | # pre: false
6 | # features: ["decord"]
7 | # all-features: false
8 | # with-sources: false
9 | # generate-hashes: false
10 | # universal: false
11 |
12 | -e file:.
13 | av==12.3.0
14 | # via qwen-vl-utils
15 | certifi==2022.12.7
16 | # via requests
17 | charset-normalizer==2.1.1
18 | # via requests
19 | decord==0.6.0
20 | # via qwen-vl-utils
21 | filelock==3.13.1
22 | # via torch
23 | # via triton
24 | fsspec==2024.2.0
25 | # via torch
26 | idna==3.4
27 | # via requests
28 | jinja2==3.1.3
29 | # via torch
30 | markupsafe==2.1.5
31 | # via jinja2
32 | mpmath==1.3.0
33 | # via sympy
34 | networkx==3.1
35 | # via torch
36 | numpy==1.24.1
37 | # via decord
38 | # via torchvision
39 | nvidia-cublas-cu12==12.1.3.1
40 | # via nvidia-cudnn-cu12
41 | # via nvidia-cusolver-cu12
42 | # via torch
43 | nvidia-cuda-cupti-cu12==12.1.105
44 | # via torch
45 | nvidia-cuda-nvrtc-cu12==12.1.105
46 | # via torch
47 | nvidia-cuda-runtime-cu12==12.1.105
48 | # via torch
49 | nvidia-cudnn-cu12==9.1.0.70
50 | # via torch
51 | nvidia-cufft-cu12==11.0.2.54
52 | # via torch
53 | nvidia-curand-cu12==10.3.2.106
54 | # via torch
55 | nvidia-cusolver-cu12==11.4.5.107
56 | # via torch
57 | nvidia-cusparse-cu12==12.1.0.106
58 | # via nvidia-cusolver-cu12
59 | # via torch
60 | nvidia-nccl-cu12==2.20.5
61 | # via torch
62 | nvidia-nvjitlink-cu12==12.6.68
63 | # via nvidia-cusolver-cu12
64 | # via nvidia-cusparse-cu12
65 | nvidia-nvtx-cu12==12.1.105
66 | # via torch
67 | packaging==24.1
68 | # via qwen-vl-utils
69 | pillow==10.2.0
70 | # via qwen-vl-utils
71 | # via torchvision
72 | requests==2.28.1
73 | # via qwen-vl-utils
74 | sympy==1.12
75 | # via torch
76 | torch==2.4.0
77 | # via torchvision
78 | torchvision==0.19.0
79 | triton==3.0.0
80 | # via torch
81 | typing-extensions==4.9.0
82 | # via torch
83 | urllib3==1.26.13
84 | # via requests
85 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16 |
17 |
18 | import re
19 | import shutil
20 | from pathlib import Path
21 |
22 | from setuptools import find_packages, setup
23 |
24 |
25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
27 | if stale_egg_info.exists():
28 | print(
29 | (
30 | "Warning: {} exists.\n\n"
31 | "If you recently updated open_r1, this is expected,\n"
32 | "but it may prevent open_r1 from installing in editable mode.\n\n"
33 | "This directory is automatically generated by Python's packaging tools.\n"
34 | "I will remove it now.\n\n"
35 | "See https://github.com/pypa/pip/issues/5466 for details.\n"
36 | ).format(stale_egg_info)
37 | )
38 | shutil.rmtree(stale_egg_info)
39 |
40 |
41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43 | _deps = [
44 | "accelerate>=1.2.1",
45 | "bitsandbytes>=0.43.0",
46 | "black>=24.4.2",
47 | "datasets>=3.2.0",
48 | "deepspeed==0.15.4",
49 | "distilabel[vllm,ray,openai]>=1.5.2",
50 | "einops>=0.8.0",
51 | "flake8>=6.0.0",
52 | "hf_transfer>=0.1.4",
53 | "huggingface-hub[cli]>=0.19.2,<1.0",
54 | "isort>=5.12.0",
55 | "liger_kernel==0.5.2",
56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
57 | "math-verify", # Used for math verification in grpo
58 | "packaging>=23.0",
59 | "parameterized>=0.9.0",
60 | "pytest",
61 | "safetensors>=0.3.3",
62 | "sentencepiece>=0.1.99",
63 | "torch>=2.5.1",
64 | "transformers",
65 | "trl",
66 | "vllm==0.6.6.post1",
67 | "wandb>=0.19.1",
68 | "pillow",
69 | ]
70 |
71 | # this is a lookup table with items like:
72 | #
73 | # tokenizers: "tokenizers==0.9.4"
74 | # packaging: "packaging"
75 | #
76 | # some of the values are versioned whereas others aren't.
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
78 |
79 |
80 | def deps_list(*pkgs):
81 | return [deps[pkg] for pkg in pkgs]
82 |
83 |
84 | extras = {}
85 | extras["tests"] = deps_list("pytest", "parameterized")
86 | extras["torch"] = deps_list("torch")
87 | extras["quality"] = deps_list("black", "isort", "flake8")
88 | extras["eval"] = deps_list("math-verify")
89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
90 |
91 | # core dependencies shared across the whole project - keep this to a bare minimum :)
92 | install_requires = [
93 | deps["accelerate"],
94 | deps["bitsandbytes"],
95 | deps["einops"],
96 | deps["datasets"],
97 | deps["deepspeed"],
98 | deps["hf_transfer"],
99 | deps["huggingface-hub"],
100 | deps["liger_kernel"],
101 | deps["packaging"], # utilities from PyPA to e.g., compare versions
102 | deps["safetensors"],
103 | deps["sentencepiece"],
104 | deps["transformers"],
105 | deps["trl"],
106 | ]
107 |
108 | setup(
109 | name="open-r1",
110 | version="0.1.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
111 | author="The Hugging Face team (past and future)",
112 | author_email="lewis@huggingface.co",
113 | description="Open R1",
114 | long_description=open("README.md", "r", encoding="utf-8").read(),
115 | long_description_content_type="text/markdown",
116 | keywords="llm inference-time compute reasoning",
117 | license="Apache",
118 | url="https://github.com/huggingface/open-r1",
119 | package_dir={"": "src"},
120 | packages=find_packages("src"),
121 | zip_safe=False,
122 | extras_require=extras,
123 | python_requires=">=3.10.9",
124 | install_requires=install_requires,
125 | classifiers=[
126 | "Development Status :: 3 - Alpha",
127 | "Intended Audience :: Developers",
128 | "Intended Audience :: Education",
129 | "Intended Audience :: Science/Research",
130 | "License :: OSI Approved :: Apache Software License",
131 | "Operating System :: OS Independent",
132 | "Programming Language :: Python :: 3",
133 | "Programming Language :: Python :: 3.10",
134 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
135 | ],
136 | )
137 |
--------------------------------------------------------------------------------
/qwen-vl-utils/README.md:
--------------------------------------------------------------------------------
1 | # qwen-vl-utils
2 |
3 | Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.
4 |
5 | ## Install
6 |
7 | ```bash
8 | pip install qwen-vl-utils
9 | ```
10 |
11 | ## Usage
12 |
13 | ### Qwen2VL
14 |
15 | ```python
16 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
17 | from qwen_vl_utils import process_vision_info
18 |
19 |
20 | # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
21 | messages = [
22 | # Image
23 | ## Local file path
24 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
25 | ## Image URL
26 | [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
27 | ## Base64 encoded image
28 | [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
29 | ## PIL.Image.Image
30 | [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
31 | ## Model dynamically adjusts image size, specify dimensions if required.
32 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
33 | # Video
34 | ## Local video path
35 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
36 | ## Local video frames
37 | [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
38 | ## The model dynamically adjusts video nframes, height, and width; specify these args if required.
39 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
40 | ]
41 |
42 | processor = AutoProcessor.from_pretrained(model_path)
43 | model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
44 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45 | images, videos = process_vision_info(messages)
46 | inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
47 | print(inputs)
48 | generated_ids = model.generate(**inputs)
49 | print(generated_ids)
50 | ```
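The snippet above stops at raw token ids. A minimal follow-up sketch (reusing `processor`, `inputs`, and `generated_ids` from the example above) to trim off the prompt tokens and decode only the completion text; the same applies to the Qwen2.5VL example below:

```python
# Drop the prompt portion of each sequence, then decode only the newly generated tokens.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```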
51 |
52 | ### Qwen2.5VL
53 |
54 | ```python
55 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
56 | from qwen_vl_utils import process_vision_info
57 |
58 |
59 | # You can set the maximum tokens for a video through the environment variable VIDEO_MAX_PIXELS
60 | # based on the maximum tokens that the model can accept.
61 | # export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9
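# (Illustrative Python equivalent of the export above, an assumption rather than the documented API:
#  set os.environ["VIDEO_MAX_PIXELS"] = str(int(32000 * 28 * 28 * 0.9)) before importing qwen_vl_utils,
#  since the variable is expected to be read when the module is loaded.)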
62 |
63 |
64 | # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
65 | messages = [
66 | # Image
67 | ## Local file path
68 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
69 | ## Image URL
70 | [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
71 | ## Base64 encoded image
72 | [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
73 | ## PIL.Image.Image
74 | [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
75 | ## Model dynamically adjusts image size, specify dimensions if required.
76 | [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
77 | # Video
78 | ## Local video path
79 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
80 | ## Local video frames
81 | [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
82 | ## Model dynamically adjusts video nframes, video height and width. specify args if required.
83 | [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
84 | ]
85 |
86 | processor = AutoProcessor.from_pretrained(model_path)
87 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
88 | text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
89 | images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
90 | inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
91 | print(inputs)
92 | generated_ids = model.generate(**inputs)
93 | print(generated_ids)
94 | ```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exploring the Effect of Reinforcement Learning on Video Understanding: Insights from SEED-Bench-R1
2 |
18 | ## 🚀Introduction
19 |
20 | Recent advancements in Chain of Thought (COT) generation have significantly improved the reasoning capabilities of Large Language Models (LLMs), with reinforcement learning (RL) emerging as an effective post-training approach. Multimodal Large Language Models (MLLMs) inherit this reasoning potential but remain underexplored in tasks requiring both perception and logical reasoning. To address this, we introduce SEED-Bench-R1, a benchmark designed to systematically evaluate post-training methods for MLLMs in video understanding. It includes intricate real-world videos and complex everyday planning tasks in the format of multiple-choice questions, requiring sophisticated perception and reasoning. SEED-Bench-R1 assesses generalization through a three-level hierarchy: in-distribution, cross-environment, and cross-environment-task scenarios, equipped with a large-scale training dataset with easily verifiable ground-truth answers. Using Qwen2-VL-Instruct-7B as a base model, we compare RL with supervised fine-tuning (SFT), demonstrating RL's data efficiency and superior performance on both in-distribution and out-of-distribution tasks, even outperforming SFT on general video understanding benchmarks like LongVideoBench. Our detailed analysis reveals that RL enhances visual perception but often produces less logically coherent reasoning chains. We identify key limitations such as inconsistent reasoning and overlooked visual cues, and suggest future improvements in base model reasoning, reward modeling, and RL robustness against noisy signals.
21 |
22 | ## 🚩News
23 | - [2025/06/18]💥We release the training code of [GRPO-CARE](https://github.com/TencentARC/GRPO-CARE), a novel consistency-aware RL framework, and the corresponding model checkpoints!
24 | - [2025/03/31] We release the datasets of SEED-Bench-R1 and the training / evaluation codes.
25 |
26 |
27 |
28 | ## 📝Data
29 |
30 | SEED-Bench-R1 consists of a large-scale training set and a hierarchical three-level validation set for in-distribution, cross-environment, and cross-environment-task evaluations. The datasets can be downloaded from [HuggingFace](https://huggingface.co/datasets/TencentARC/SEED-Bench-R1).
31 |
32 | Specifically, SEED-Bench-R1 is built on our prior works, reusing the training and validation data from our [EgoPlan-Bench](https://github.com/ChenYi99/EgoPlan), as well as the test data from our [EgoPlan-Bench2](https://github.com/qiulu66/EgoPlan-Bench2). The validation data from EgoPlan-Bench are used for Level-1 (in-distribution) and Level-2 (OOD, cross-environment) evaluation, while the test data from EgoPlan-Bench2 cover more general domains and are used for Level-3 (OOD, cross-environment-task) evaluation.
33 |
34 |
35 | 
36 |
37 |
38 | Questions from the human-verified validation data are formatted as multiple-choice problems. MLLMs need to select the most reasonable answer from four candidate choices. The primary metric is Accuracy.
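For reference, the loader in [infer.py](infer.py) reads the fields below from each annotation line, so a training/validation record looks roughly like the following sketch (only the field names are taken from the code; all values here are made up):

```python
# Illustrative annotation record (field names follow infer.py; values are invented).
example = {
    "sample_id": 0,
    "question": "What should I do next to accomplish the task goal?",
    "choice_a": "open the fridge",
    "choice_b": "pick up the knife",
    "choice_c": "turn on the tap",
    "choice_d": "close the drawer",
    "golden_choice_idx": "A",                       # letter of the correct option
    "task_progress_metadata": ["..."],              # empty list => no history video clip is attached
    "video_source": "...",                          # subfolder under data/videos and data/images
    "video_basename": "history_clip.mp4",           # clip showing the task progress so far
    "current_observation_basename": "current.jpg",  # image of the current observation
}
```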
39 |
40 |
41 | 
42 |
43 |
44 |
45 |
46 | ## 🔥Training Models
47 |
48 | > [!NOTE]
49 | > The training code is modified from [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video). The training commands below are configured for a node of 8 x A100 (40GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
50 |
51 | ### Set up
52 | ```
53 | git clone https://github.com/TencentARC/SEED-Bench-R1.git
54 | cd SEED-Bench-R1
55 | conda create -n r1 python=3.10
56 | conda activate r1
57 | pip3 install -e ".[dev]"
58 | pip3 install flash_attn --no-build-isolation
59 | cd qwen-vl-utils
60 | pip install -e .
61 | cd ..
62 |
63 | # download data and put in data/
64 | git lfs install
65 | git clone https://huggingface.co/datasets/TencentARC/SEED-Bench-R1
66 |
67 | # download the model checkpoint of Qwen2-VL-7B-Instruct
68 | git clone https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
69 | ```
70 |
71 |
72 | ### Post-training Qwen2-VL-7B-Instruct
73 |
74 | - To run GRPO on Qwen2-VL-7B with [grpo.sh](scripts/grpo.sh):
75 |
76 | ```
77 | bash scripts/grpo.sh
78 | ```
79 |
80 |
81 | - To run SFT on Qwen2-VL-7B with [sft.sh](scripts/sft.sh):
82 |
83 | ```
84 | bash scripts/sft.sh
85 | ```
86 |
87 |
88 | ## 🤖Evaluating Models
89 |
90 | ### Inference
91 |
92 | Run inference with the post-trained models:
93 | ```
94 | bash scripts/eval.sh
95 | ```
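[infer.py](infer.py) appends per-sample results (including a binary `reward`) to a file named `eval_results_for_<annotation_file>` inside the checkpoint directory, so accuracy can also be recomputed offline. A minimal sketch, assuming that naming convention (adjust the path to your checkpoint and split):

```python
import json

# Recompute accuracy from the per-sample results written by infer.py.
results_path = "checkpoint_path/eval_results_for_validation_L1.jsonl"
with open(results_path) as f:
    rewards = [json.loads(line)["reward"] for line in f]
print(f"Accuracy: {sum(rewards) / len(rewards):.4f} ({len(rewards)} samples)")
```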
96 |
97 |
98 | Evaluation results:
99 |
100 | 
101 |
102 |
103 | ## 🙌References & Acknowledgements
104 | We sincerely thank the contributions from the open source community. The related projects are as follows:
105 | - [Open-R1-Video](https://github.com/Wang-Xiaodong1899/Open-R1-Video)
106 | - [EgoPlan](https://github.com/ChenYi99/EgoPlan)
107 | - [EgoPlan-Bench2](https://github.com/qiulu66/EgoPlan-Bench2)
108 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal)
109 | - [lmm-r1](https://github.com/TideDra/lmm-r1)
110 | - [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1)
111 | - [Open-R1](https://github.com/huggingface/open-r1)
112 | - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)
113 |
114 | ## ⭐License
115 | The video samples in SEED-Bench-R1 are collected from [Epic-Kitchens](https://epic-kitchens.github.io/2025) and [Ego4D](https://ego4d-data.org/). Users must follow the related licenses ([Epic-Kitchens](https://creativecommons.org/licenses/by-nc/4.0/) and [Ego4D](https://ego4ddataset.com/ego4d-license/)) to use these video samples for training and validation. SEED-Bench-R1 does not hold the copyright for these videos; the copyright belongs to the original owners of these datasets.
116 |
117 |
118 | ## 📚Citation
119 | If you find our project helpful, we hope you will star our repository and cite our paper as follows:
120 |
121 | ```bibtex
122 | @article{chen2025exploring,
123 | title={Exploring the Effect of Reinforcement Learning on Video Understanding: Insights from SEED-Bench-R1},
124 | author={Chen, Yi and Ge, Yuying and Wang, Rui and Ge, Yixiao and Qiu, Lu and Shan, Ying and Liu, Xihui},
125 | journal={arXiv preprint arXiv:2503.24376},
126 | year={2025}
127 | }
128 |
129 | @article{chen2023egoplan,
130 | title={Egoplan-bench: Benchmarking multimodal large language models for human-level planning},
131 | author={Chen, Yi and Ge, Yuying and Ge, Yixiao and Ding, Mingyu and Li, Bohao and Wang, Rui and Xu, Ruifeng and Shan, Ying and Liu, Xihui},
132 | journal={arXiv preprint arXiv:2312.06722},
133 | year={2023}
134 | }
135 |
136 | @article{qiu2024egoplan,
137 | title={Egoplan-bench2: A benchmark for multimodal large language model planning in real-world scenarios},
138 | author={Qiu, Lu and Ge, Yuying and Chen, Yi and Ge, Yixiao and Shan, Ying and Liu, Xihui},
139 | journal={arXiv preprint arXiv:2412.04447},
140 | year={2024}
141 | }
142 | ```
143 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.2
2 | Name: open-r1
3 | Version: 0.1.0.dev0
4 | Summary: Open R1
5 | Home-page: https://github.com/huggingface/open-r1
6 | Author: The Hugging Face team (past and future)
7 | Author-email: lewis@huggingface.co
8 | License: Apache
9 | Keywords: llm inference-time compute reasoning
10 | Classifier: Development Status :: 3 - Alpha
11 | Classifier: Intended Audience :: Developers
12 | Classifier: Intended Audience :: Education
13 | Classifier: Intended Audience :: Science/Research
14 | Classifier: License :: OSI Approved :: Apache Software License
15 | Classifier: Operating System :: OS Independent
16 | Classifier: Programming Language :: Python :: 3
17 | Classifier: Programming Language :: Python :: 3.10
18 | Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19 | Requires-Python: >=3.10.9
20 | Description-Content-Type: text/markdown
21 | License-File: LICENSE
22 | Requires-Dist: accelerate>=1.2.1
23 | Requires-Dist: bitsandbytes>=0.43.0
24 | Requires-Dist: einops>=0.8.0
25 | Requires-Dist: datasets>=3.2.0
26 | Requires-Dist: deepspeed==0.15.4
27 | Requires-Dist: hf_transfer>=0.1.4
28 | Requires-Dist: huggingface-hub[cli]<1.0,>=0.19.2
29 | Requires-Dist: liger_kernel==0.5.2
30 | Requires-Dist: packaging>=23.0
31 | Requires-Dist: safetensors>=0.3.3
32 | Requires-Dist: sentencepiece>=0.1.99
33 | Requires-Dist: transformers
34 | Requires-Dist: trl
35 | Provides-Extra: tests
36 | Requires-Dist: pytest; extra == "tests"
37 | Requires-Dist: parameterized>=0.9.0; extra == "tests"
38 | Provides-Extra: torch
39 | Requires-Dist: torch>=2.5.1; extra == "torch"
40 | Provides-Extra: quality
41 | Requires-Dist: black>=24.4.2; extra == "quality"
42 | Requires-Dist: isort>=5.12.0; extra == "quality"
43 | Requires-Dist: flake8>=6.0.0; extra == "quality"
44 | Provides-Extra: eval
45 | Requires-Dist: math-verify; extra == "eval"
46 | Provides-Extra: dev
47 | Requires-Dist: black>=24.4.2; extra == "dev"
48 | Requires-Dist: isort>=5.12.0; extra == "dev"
49 | Requires-Dist: flake8>=6.0.0; extra == "dev"
50 | Requires-Dist: pytest; extra == "dev"
51 | Requires-Dist: parameterized>=0.9.0; extra == "dev"
52 | Requires-Dist: math-verify; extra == "dev"
53 | Dynamic: author
54 | Dynamic: author-email
55 | Dynamic: classifier
56 | Dynamic: description
57 | Dynamic: description-content-type
58 | Dynamic: home-page
59 | Dynamic: keywords
60 | Dynamic: license
61 | Dynamic: provides-extra
62 | Dynamic: requires-dist
63 | Dynamic: requires-python
64 | Dynamic: summary
65 |
66 | # Open R1 Video
67 |
68 | We introduce the R1 paradigm to video understanding tasks and open-source the training code and data.
69 |
70 | [🤗 Models](https://huggingface.co/Xiaodong/Open-R1-Video-7B) | [🤗 Datasets](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k) | [Wandb Logs](https://wandb.ai/xiaodongwang/Qwen2-VL-7B-Video-GRPO/runs/mb6ued4m?nw=nwuserxiaodongwang)
71 |
72 | > [!NOTE]
73 | > Although our insights may not be guaranteed to be correct, we commit to sharing them truthfully and honestly. We welcome community feedback and discussions to improve our understanding of multimodal reasoning models.
74 |
75 | ## News
76 | - [2025/02/22] We release a provisional model [Open-R1-Video-7B](https://huggingface.co/Xiaodong/Open-R1-Video-7B), inference scripts, and evaluation results.
77 | - [2025/02/18] We release training code and data of Open-R1-Video!
78 |
79 | ## Our Findings
80 | ### GRPO training that forces thinking can improve video understanding
81 |
82 | 
83 |
84 | We train [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) on the simple video dataset [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k) using 4 x A100 (80G) GPUs, and the training only utilizes the video, the query, and the ground-truth answer (the letter of the correct answer). We used only GRPO (pure reinforcement learning without labeled reasoning trajectories) to train the model and achieved promising rewards during training. We release our [wandb logs](https://wandb.ai/xiaodongwang/Qwen2-VL-7B-Video-GRPO/runs/mb6ued4m?nw=nwuserxiaodongwang) for reference.
85 | 
86 |
87 | **What We Did**
88 | - Introduce R1 to Video-LMM (e.g., Qwen2-VL) based on [huggingface/open-r1](https://github.com/huggingface/open-r1) and [deepseek-ai/DeepSeek-R1](https://github.com/deepseek-ai/DeepSeek-R1).
89 | - Open-sourced the simple training data [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k).
90 | - The simple reformat data is available in [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k).
91 | - The video data is available in [LLaVA-Video-large-swift](https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift).
92 |
93 |
94 |
95 | ## Training Models
96 |
97 | > [!NOTE]
98 | > The training commands below are configured for a node of 4 x A100 (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
99 |
100 | ### Set up
101 | ```
102 | git clone https://github.com/Wang-Xiaodong1899/Open-R1-Video.git
103 | cd Open-R1-Video
104 | conda create -n r1 python=3.10
105 | conda activate r1
106 | pip3 install -e ".[dev]"
107 | pip3 install flash_attn --no-build-isolation
108 | cd qwen-vl-utils
109 | pip install -e .
110 | cd ..
111 |
112 | # download data and put in data/
113 | wget https://huggingface.co/datasets/Xiaodong/open-r1-video-4k/resolve/main/LLaVA-Video-large-swift-origin.jsonl
114 | # like: data/LLaVA-Video-large-swift-origin.jsonl
115 |
116 | # download videos
117 | git lfs install
118 | git clone https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift
119 |
120 | ```
121 |
122 |
123 | ### GRPO on Qwen2-VL/7B
124 | > [!NOTE]
125 | > Our training also supports a single A100 (80G) GPU. Just modify the GPU settings and you’re good to go!
126 |
127 | > We removed the format reward during 7B model training and slightly modified the final-answer matching used to calculate the accuracy reward. See this [commit](https://github.com/Wang-Xiaodong1899/Open-R1-Video/commit/2679e082aaf608fd167a0ad5e6f2afb4f548e25f#diff-d6985fa15a3c7864e723ebd4c04bfdc2f13c5e87af36f87d656e32666f8e0eeb).
128 |
129 | To run GRPO on Qwen2-VL-7B:
130 |
131 | ```
132 | bash qwen-7b.sh
133 | ```
134 |
135 | Please refer to [qwen-7b.sh](qwen-7b.sh) for more details.
136 |
137 |
138 | ## Evaluating models
139 |
140 | ### Inference
141 |
142 | Run inference with the video reasoning model:
143 | ```
144 | python infer.py
145 | ```
146 |
147 |
148 | 
149 |
150 | [Video link](https://youtu.be/2evryGv-oZ4)
151 |
152 | Inference results:
153 |
154 | 
155 |
156 | ### Evaluation
157 |
158 | > [!NOTE]
159 | > We use [Lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) to evaluate models.
160 |
161 |
162 | | Benchmarks | Qwen2-VL-7B-Instruct(w.o reasoning) | Qwen2-VL-7B-Instruct(w. reasoning) | Open-R1-Video-7B(w. reasoning) |
163 | |---------------------------|-------------------------------------|------------------------------------|---------------------------------|
164 | | [LongVideoBench](https://longvideobench.github.io/) (16 frames) | 53.33 | 41.89 | 48.99 (↑7.1) |
165 |
166 |
167 | ## RL Data Reformat
168 |
169 | We provide an easy reformatting method to obtain the data for GRPO training, which only utilizes the video, query, and final answer. Please refer to [format_video_data.py](scripts/format_video_data.py) for more details; a rough illustrative sketch is given below.
170 |
171 | Users can view data in [open-r1-video-4k](https://huggingface.co/datasets/Xiaodong/open-r1-video-4k). The `original question`/`original answer` are from the original dataset.
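As a rough illustration of that reformatting (the output field names here are hypothetical; the actual schema is defined by format_video_data.py and the released dataset), each GRPO training record only needs the video, the multiple-choice query, and the final answer letter:

```python
import json

def reformat_record(video_path, question, options, answer_letter):
    # Hypothetical record layout: video + query + final answer only.
    problem = question + "\n" + "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))
    return {"video": video_path, "problem": problem, "solution": f"{answer_letter}."}

with open("data/grpo_train.jsonl", "w") as fo:
    record = reformat_record("videos/clip_0001.mp4", "What is the person doing?",
                             ["cooking", "cleaning", "reading", "typing"], "A")
    fo.write(json.dumps(record) + "\n")
```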
172 |
173 | ## References & Acknowledgements
174 | We sincerely thank the contributions from the open source community, including the reproduction of [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1), [Open-R1](https://github.com/huggingface/open-r1), and [R1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal), etc.
175 |
176 | The related projects are as follows:
177 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal)
178 | - [lmm-r1](https://github.com/TideDra/lmm-r1)
179 | - [DeepSeek](https://github.com/deepseek-ai/DeepSeek-R1)
180 | - [open-r1](https://github.com/huggingface/open-r1)
181 | - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)
182 | - [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT)
183 | - [LLaVA-Video-large-swift](https://huggingface.co/datasets/malterei/LLaVA-Video-large-swift)
184 |
185 | ## Citation
186 | If you find this useful, please consider citing us.
187 |
188 | ```bibtex
189 | @misc{wang-2025-open-r1-video,
190 | author = {Xiaodong Wang and Peixi Peng},
191 | title = {Open-R1-Video},
192 | year = {2025},
193 | publisher = {GitHub},
194 | journal = {GitHub repository},
195 | howpublished = {\url{https://github.com/Wang-Xiaodong1899/Open-R1-Video}}
196 | }
197 | ```
198 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | from qwen_vl_utils import process_vision_info
2 | import torch
3 | from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4 | import os
5 | import argparse
6 | import json
7 | from tqdm import tqdm
8 | from math_verify import parse, verify
9 | import re
10 | import numpy as np
11 |
12 | QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer in <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
13 |
14 | def make_conversation_egoplan(example, data_root_dir):
15 | options = [
16 | example['choice_a'],
17 | example['choice_b'],
18 | example['choice_c'],
19 | example['choice_d'],
20 | ]
21 |
22 | answer_index = example['golden_choice_idx']
23 |
24 | problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n"
25 | solution = f"{answer_index}."
26 |
27 | content = []
28 | if len(example['task_progress_metadata']) > 0:
29 | video_path = os.path.join(data_root_dir, 'videos', example['video_source'], example['video_basename'])
30 | content.append({"type": "video", "video": video_path})
31 |
32 | image_path = os.path.join(data_root_dir, 'images', example['video_source'], example['current_observation_basename'])
33 | content.extend([
34 | {"type": "image", "image": image_path},
35 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)},
36 | ])
37 |
38 | feature = {
39 | "sample_id": example["sample_id"],
40 | "prompt": [
41 | {
42 | "role": "user",
43 | "content": content,
44 | },
45 | ],
46 | 'problem': problem,
47 | 'solution': solution,
48 | }
49 | return feature
50 |
51 |
52 |
53 | def make_conversation_longvideobench(example, data_root_dir):
54 | options = example['candidates']
55 | answer_index = chr(65 + example['correct_choice'])
56 |
57 | problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n"
58 | solution = f"{answer_index}."
59 | video_path = os.path.join(data_root_dir, "videos", example["video_path"])
60 | content = [
61 | {"type": "video", "video": video_path},
62 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)},
63 | ]
64 |
65 | feature = {
66 | "sample_id": example["sample_id"],
67 | "prompt": [
68 | {
69 | "role": "user",
70 | "content": content,
71 | },
72 | ],
73 | 'problem': problem,
74 | 'solution': solution,
75 | }
76 | return feature
77 |
78 |
79 | def accuracy_reward(content, sol):
80 | reward = 0.0
81 | # Try symbolic verification first
82 | try:
83 | answer = parse(content)
84 | if float(verify(answer, parse(sol))) > 0:
85 | reward = 1.0
86 | except Exception:
87 | pass # Continue to next verification method if this fails
88 |
89 | # If symbolic verification failed, try string matching
90 | if reward == 0.0:
91 | try:
92 | # Extract answer from solution if it has think/answer tags
93 | sol_match = re.search(r"<answer>(.*?)</answer>", sol, re.DOTALL)
94 | ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
95 |
96 | # Extract answer from content if it has think/answer tags
97 | content_match = re.search(r"<answer>(.*?)</answer>", content, re.DOTALL)
98 | # student_answer = content_match.group(1).strip() if content_match else content.strip()
99 | if content_match:
100 | student_answer = content_match.group(1).strip()
101 | # HACK: if the first letter matches, give reward 1
102 | # Compare the extracted answers
103 | if student_answer[0] == ground_truth[0]:
104 | reward = 1.0
105 | else:
106 | reward = 0.0
107 | except Exception:
108 | pass # Keep reward as 0.0 if both methods fail
109 | return reward
110 |
111 |
112 |
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--model_path", default="/group/40101/milkcychen/Open-R1-Video/ckpt/Qwen2-VL-7B-EgoPlan-GRPO/egoplan-it-8k-remove-formatreward-matchletterreward-f16export/checkpoint-850")
118 | parser.add_argument("--test_file_path", default="/group/40121/public_datasets/SEED-Bench-R1/annotations/validation_L1.jsonl")
119 | parser.add_argument("--data_root_dir", default="/group/40121/public_datasets/SEED-Bench-R1")
120 | parser.add_argument("--test_batch_size", type=int, default=1)
121 | parser.add_argument("--do_sampling", type=int, default=1)
122 | # parser.add_argument("--")
123 | args = parser.parse_args()
124 |
125 | # default: Load the model on the available device(s)
126 | model = Qwen2VLForConditionalGeneration.from_pretrained(
127 | args.model_path,
128 | torch_dtype=torch.bfloat16,
129 | attn_implementation="flash_attention_2",
130 | device_map="auto",
131 | )
132 |
133 | processor = AutoProcessor.from_pretrained(args.model_path)
134 |
135 | test_file_basename = os.path.basename(args.test_file_path)
136 | output_eval_results_path = os.path.join(args.model_path, f"eval_results_for_{test_file_basename}")
137 |
138 | rewards = []
139 | evaluated_example_ids = set()
140 | if os.path.exists(output_eval_results_path):
141 | with open(output_eval_results_path) as fi:
142 | for line in tqdm(fi):
143 | example = json.loads(line)
144 | evaluated_example_ids.add(example['sample_id'])
145 | rewards.append(example['reward'])
146 |
147 | test_examples = []
148 | if args.test_file_path.endswith('.jsonl'):
149 | with open(args.test_file_path) as f:
150 | examples = []
151 | for line in f:
152 | example = json.loads(line)
153 | examples.append(example)
154 | else:
155 | with open(args.test_file_path) as f:
156 | examples = json.load(f)
157 |
158 |
159 | for example in examples:
160 | if 'sample_id' not in example:
161 | example['sample_id'] = example['id']
162 | if example['sample_id'] not in evaluated_example_ids:
163 | test_examples.append(example)
164 | else:
165 | print(f"skip {example['sample_id']}")
166 |
167 |
168 | t = tqdm(total=len(test_examples))
169 |
170 | with open(output_eval_results_path, 'a') as fo:
171 | for i in range(0, len(test_examples), args.test_batch_size):
172 | batch = test_examples[i:i+args.test_batch_size]
173 | if 'EgoPlan' in args.data_root_dir:
174 | features = [make_conversation_egoplan(example, args.data_root_dir) for example in batch]
175 | elif 'LongVideoBench' in args.data_root_dir:
176 | features = [make_conversation_longvideobench(example, args.data_root_dir) for example in batch]
177 | else:
178 | raise NotImplementedError
179 | prompts = [feature['prompt'] for feature in features]
180 | prompt_texts = [
181 | processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
182 | for msg in prompts
183 | ]
184 | image_inputs, video_inputs, video_kwargs = process_vision_info(prompts, return_video_kwargs=True)
185 | prompt_inputs = processor(
186 | text=prompt_texts,
187 | images=image_inputs,
188 | videos=video_inputs,
189 | return_tensors="pt",
190 | padding=True,
191 | padding_side="left",
192 | add_special_tokens=False,
193 | )
194 | prompt_inputs = prompt_inputs.to("cuda")
195 |
196 | # Inference
197 | generated_ids = model.generate(**prompt_inputs,
198 | max_new_tokens=256,
199 | do_sample=args.do_sampling,
200 | temperature=1
201 | )
202 | generated_ids_trimmed = [
203 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(prompt_inputs.input_ids, generated_ids)
204 | ]
205 | output_texts = processor.batch_decode(
206 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
207 | )
208 |
209 | for prompt, output_text, feature in zip(prompts, output_texts, features):
210 | reward = accuracy_reward(output_text, feature['solution'])
211 | rewards.append(reward)
212 | print(f"------------- Sample {feature['sample_id']} -------------\n")
213 | print(f"Question: {prompt[0]['content'][-1]['text']}\n")
214 | print(f"Content: {output_text}\n")
215 | print(f"Solution: {feature['solution']}\n")
216 | print(f"Reward: {reward}\n")
217 |
218 | feature['response'] = output_text
219 | feature['reward'] = reward
220 | fo.write(json.dumps(feature)+"\n")
221 | fo.flush()
222 |
223 | t.update(1)
224 | t.set_postfix(reward_mean=np.mean(rewards))
225 |
226 |
--------------------------------------------------------------------------------
/src/open_r1_egoplan/grpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 | from dataclasses import dataclass, field
18 | from datetime import datetime
19 | from typing import Optional, List
20 |
21 | from datasets import load_dataset, Dataset, DatasetDict
22 |
23 | from math_verify import parse, verify
24 | from open_r1_egoplan.trainer import Qwen2VLGRPOTrainer
25 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
26 | import random
27 | from functools import partial
28 |
29 | @dataclass
30 | class GRPOScriptArguments(ScriptArguments):
31 | """
32 | Script arguments for the GRPO training script.
33 |
34 | Args:
35 | reward_funcs (`List[str]`):
36 | List of reward functions. Possible values: 'accuracy', 'format'.
37 | """
38 |
39 | reward_funcs: List[str] = field(
40 | default_factory=lambda: ["accuracy",],
41 | metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
42 | )
43 | max_pixels: Optional[int] = field(
44 | default=12845056,
45 | metadata={"help": "Maximum number of pixels for the image"},
46 | )
47 | min_pixels: Optional[int] = field(
48 | default=3136,
49 | metadata={"help": "Minimum number of pixels for the image"},
50 | )
51 | jsonl_path: Optional[str] = field(
52 | default=None,
53 |         metadata={"help": "Path to a JSONL file with training examples"},
54 | )
55 | data_root_dir: Optional[str] = field(
56 | default=None,
57 |         metadata={"help": "Root directory of the datasets"},
58 | )
59 | freeze_visual_encoder: Optional[int] = field(
60 | default=0,
61 | metadata={"help": "Whether to freeze visual encoder"},
62 | )
63 |
64 | def accuracy_reward(completions, solution, prompts, **kwargs):
65 | """Reward function that checks if the completion is correct using either symbolic verification or exact string matching."""
66 | contents = [completion[0]["content"] for completion in completions]
67 | print(f"------------- Sample {kwargs['sample_id'][0]} -------------")
68 | print(prompts[0][0]['content'][-1]['text']+'\n')
69 | print(solution[0]+'\n')
70 | print(contents[:2]) # print online completion
71 | rewards = []
72 | current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
73 |
74 | for content, sol, prompt, sample_id in zip(contents, solution, prompts, kwargs['sample_id']):
75 | reward = 0.0
76 | # Try symbolic verification first
77 | try:
78 | answer = parse(content)
79 | if float(verify(answer, parse(sol))) > 0:
80 | reward = 1.0
81 | except Exception:
82 | pass # Continue to next verification method if this fails
83 |
84 | # If symbolic verification failed, try string matching
85 | if reward == 0.0:
86 | try:
87 | # Extract answer from solution if it has think/answer tags
88 |                 sol_match = re.search(r"<answer>(.*?)</answer>", sol, re.DOTALL)
89 | ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
90 |
91 | # Extract answer from content if it has think/answer tags
92 |                 content_match = re.search(r"<answer>(.*?)</answer>", content, re.DOTALL)
93 | # student_answer = content_match.group(1).strip() if content_match else content.strip()
94 | if content_match:
95 | student_answer = content_match.group(1).strip()
96 |                     # HACK: compare only the first character (the option letter)
97 |                     # of the extracted answers; reward 1.0 on a match
98 | if student_answer[0] == ground_truth[0]:
99 | reward = 1.0
100 | else:
101 | reward = 0.0
102 | except Exception:
103 | pass # Keep reward as 0.0 if both methods fail
104 |
105 | rewards.append(reward)
106 | if os.getenv("DEBUG_MODE") == "true":
107 | log_path = os.getenv("LOG_PATH")
108 | with open(log_path, "a") as f:
109 | f.write(f"------------- {current_time} Accuracy reward for Sample {sample_id}: {reward} -------------\n")
110 | f.write(f"Question: {prompt[0]['content'][-1]['text']}\n")
111 | f.write(f"Content: {content}\n")
112 | f.write(f"Solution: {sol}\n")
113 | return rewards
114 |
115 |
116 | def format_reward(completions, **kwargs):
117 | """Reward function that checks if the completion has a specific format."""
118 |     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
119 | completion_contents = [completion[0]["content"] for completion in completions]
120 | matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
121 | return [1.0 if match else 0.0 for match in matches]
122 |
123 |
124 | reward_funcs_registry = {
125 | "accuracy": accuracy_reward,
126 | "format": format_reward,
127 | }
128 |
129 |
130 | def create_dataset_from_jsonl_simple(jsonl_path):
131 | base_dataset = Dataset.from_json(jsonl_path)
132 | return DatasetDict({
133 | "train": base_dataset
134 | })
135 |
136 |
137 | def main(script_args, training_args, model_args):
138 | # Get reward functions
139 | reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
140 |
141 | if script_args.jsonl_path:
142 |         # Load the dataset from a JSONL file
143 | dataset = create_dataset_from_jsonl_simple(script_args.jsonl_path)
144 | else:
145 | # Load the dataset
146 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
147 |
148 | # Format into conversation
149 |     QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer in <answer> </answer> tags, i.e., <think> reasoning process here </think> <answer> answer here </answer>. "
150 |
151 | def make_conversation_egoplan(example, data_root_dir):
152 | if 'golden_choice_idx' not in example:
153 | negative_answers = random.sample(example["negative_answers"], 3)
154 | options = negative_answers + [example["answer"]]
155 | else:
156 | options = [example['choice_a'], example['choice_b'], example['choice_c'], example['choice_d']]
157 |
158 | random.shuffle(options)
159 | answer_index = options.index(example["answer"])
160 | problem = f"{example['question']}\n" + "\n".join([f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]) + "\n"
161 | solution = f"{chr(65 + answer_index)}."
162 |
163 |
164 | content = []
165 | if len(example['task_progress_metadata']) > 0:
166 | video_path = os.path.join(data_root_dir, 'videos', example['video_source'], example['video_basename'])
167 | content.append({"type": "video", "video": video_path})
168 |
169 | image_path = os.path.join(data_root_dir, 'images', example['video_source'], example['current_observation_basename'])
170 | content.extend([
171 | {"type": "image", "image": image_path},
172 | {"type": "text", "text": QUESTION_TEMPLATE.format(Question=problem)},
173 | ])
174 |
175 | feature = {
176 | "prompt": [
177 | {
178 | "role": "user",
179 | "content": content,
180 | },
181 | ],
182 | 'problem': problem,
183 | 'solution': solution,
184 | }
185 |
186 | return feature
187 |
188 | dataset = dataset.map(partial(make_conversation_egoplan, data_root_dir=script_args.data_root_dir))
189 | trainer_cls = Qwen2VLGRPOTrainer
190 |
191 | # Initialize the GRPO trainer
192 | trainer = trainer_cls(
193 | model=model_args.model_name_or_path,
194 | reward_funcs=reward_funcs,
195 | args=training_args,
196 | train_dataset=dataset[script_args.dataset_train_split],
197 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
198 | peft_config=get_peft_config(model_args),
199 | attn_implementation=model_args.attn_implementation,
200 | max_pixels=script_args.max_pixels,
201 | min_pixels=script_args.min_pixels,
202 | )
203 |
204 | # Train and push the model to the Hub
205 | trainer.train()
206 |
207 | # Save and push to hub
208 | trainer.save_model(training_args.output_dir)
209 | if training_args.push_to_hub:
210 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
211 |
212 |
213 | if __name__ == "__main__":
214 | parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
215 | script_args, training_args, model_args = parser.parse_args_and_config()
216 | training_args.freeze_visual_encoder = script_args.freeze_visual_encoder
217 | main(script_args, training_args, model_args)
218 |
--------------------------------------------------------------------------------
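
A minimal usage sketch of the accuracy_reward function defined in grpo.py above, assuming the package is installed locally (e.g. `pip install -e .`) and math_verify is available; the question, completion, and sample id below are made-up illustrations:

from open_r1_egoplan.grpo import accuracy_reward

prompts = [[{
    "role": "user",
    "content": [{"type": "text",
                 "text": "What is the next action?\nA. pick up the cup\nB. open the fridge\n"}],
}]]
completions = [[{
    "content": "<think>The hand is already on the fridge handle.</think> <answer>B. open the fridge</answer>"
}]]
solution = ["B."]  # same format as make_conversation_egoplan produces

# Symbolic verification fails here, so the fallback compares the option letters
# extracted from the <answer> tags: 'B' == 'B' -> reward 1.0.
print(accuracy_reward(completions, solution, prompts, sample_id=["demo_0"]))  # [1.0]
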
/src/open_r1_egoplan/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import logging
17 | import os
18 | import sys
19 |
20 | import datasets
21 | from dataclasses import dataclass, field
22 | from typing import Optional
23 | import torch
24 | import transformers
25 | from datasets import load_dataset
26 | from transformers import AutoTokenizer, set_seed, AutoProcessor
27 | from transformers.trainer_utils import get_last_checkpoint
28 | import trl
29 | from trl import (
30 | ModelConfig,
31 | ScriptArguments,
32 | SFTTrainer,
33 | TrlParser,
34 | get_kbit_device_map,
35 | get_peft_config,
36 | get_quantization_config,
37 | )
38 | from datasets import Dataset, DatasetDict
39 | from qwen_vl_utils import process_vision_info
40 | logger = logging.getLogger(__name__)
41 |
42 |
43 | @dataclass
44 | class SFTConfig(trl.SFTConfig):
45 | """
46 |     Arguments for callbacks, benchmarks, etc.
47 | """
48 |
49 | benchmarks: list[str] = field(
50 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
51 | )
52 | callbacks: list[str] = field(
53 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
54 | )
55 | system_prompt: Optional[str] = field(
56 | default=None,
57 | metadata={"help": "The optional system prompt to use for benchmarking."},
58 | )
59 | hub_model_revision: Optional[str] = field(
60 | default="main",
61 | metadata={"help": "The Hub model branch to push the model to."},
62 | )
63 | jsonl_path: Optional[str] = field(
64 | default=None,
65 |         metadata={"help": "Path to a JSONL file with training examples"},
66 | )
67 | data_root_dir: Optional[str] = field(
68 | default=None,
69 |         metadata={"help": "Root directory of the datasets"},
70 | )
71 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
72 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
73 |
74 |
75 |
76 | processor = None
77 |
78 |
79 | def convert_example(example):
80 | messages = example['prompt']
81 | messages.append({
82 | "role": "assistant",
83 | "content": [{"type": "text", "text": example['response']}],
84 | })
85 | example["messages"] = messages
86 | return example
87 |
88 |
89 | def collate_fn(examples):
90 | for example in examples:
91 | for message in example["prompt"]:
92 | if isinstance(message["content"], list):
93 | new_content = []
94 | for ele in message["content"]:
95 | new_ele = {k: v for k, v in ele.items() if v is not None}
96 | if 'video' in new_ele:
97 | new_ele['video'] = os.path.join(training_args.data_root_dir, new_ele['video'])
98 | if 'image' in new_ele:
99 | new_ele['image'] = os.path.join(training_args.data_root_dir, new_ele['image'])
100 | new_content.append(new_ele)
101 | message["content"] = new_content
102 | convert_example(example)
103 |
104 | texts = [
105 | processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False)
106 | for example in examples
107 | ]
108 | image_inputs = []
109 | video_inputs = []
110 | for example in examples:
111 | imgs, vids = process_vision_info(example["messages"])
112 | image_inputs.append(imgs)
113 | video_inputs.append(vids)
114 | batch = processor(
115 | text=texts,
116 | images=image_inputs,
117 | videos=video_inputs,
118 | return_tensors="pt",
119 | padding=True,
120 | )
121 |
122 | labels = batch["input_ids"].clone()
123 | # labels[labels == processor.tokenizer.pad_token_id] = -100
124 | # image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
125 | # labels[labels == image_token_id] = -100
126 |
127 | im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
128 | for i in range(labels.size(0)):
129 | last_index = (labels[i] == im_start_token_id).nonzero(as_tuple=True)[0][-1].item()
130 | labels[i, :last_index+3] = -100
131 |
132 | batch["labels"] = labels
133 | return batch
134 |
135 |
136 | def create_dataset_from_jsonl_simple(jsonl_path):
137 | base_dataset = Dataset.from_json(jsonl_path)
138 | return DatasetDict({
139 | "train": base_dataset
140 | })
141 |
142 | def main(script_args, training_args, model_args):
143 | # Set seed for reproducibility
144 | set_seed(training_args.seed)
145 |
146 | ###############
147 | # Setup logging
148 | ###############
149 | logging.basicConfig(
150 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
151 | datefmt="%Y-%m-%d %H:%M:%S",
152 | handlers=[logging.StreamHandler(sys.stdout)],
153 | )
154 | log_level = training_args.get_process_log_level()
155 | logger.setLevel(log_level)
156 | datasets.utils.logging.set_verbosity(log_level)
157 | transformers.utils.logging.set_verbosity(log_level)
158 | transformers.utils.logging.enable_default_handler()
159 | transformers.utils.logging.enable_explicit_format()
160 |
161 | # Log on each process a small summary
162 | logger.warning(
163 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
164 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
165 | )
166 | logger.info(f"Model parameters {model_args}")
167 | logger.info(f"Script parameters {script_args}")
168 | logger.info(f"Data parameters {training_args}")
169 |
170 | # Check for last checkpoint
171 | last_checkpoint = None
172 | if os.path.isdir(training_args.output_dir):
173 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
174 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
175 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
176 |
177 | ################
178 | # Load datasets
179 | ################
180 |
181 | if training_args.jsonl_path:
182 |         # Load the dataset from a JSONL file
183 | dataset = create_dataset_from_jsonl_simple(training_args.jsonl_path)
184 | else:
185 | # Load the dataset
186 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
187 |
188 |
189 | ################
190 | # Load tokenizer
191 | ################
192 | global processor
193 | if "vl" in model_args.model_name_or_path.lower():
194 | processor = AutoProcessor.from_pretrained(
195 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
196 | )
197 | logger.info("Using AutoProcessor for vision-language model.")
198 | else:
199 | processor = AutoTokenizer.from_pretrained(
200 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
201 | )
202 | logger.info("Using AutoTokenizer for text-only model.")
203 | if hasattr(processor, "pad_token") and processor.pad_token is None:
204 | processor.pad_token = processor.eos_token
205 | elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
206 | processor.tokenizer.pad_token = processor.tokenizer.eos_token
207 |
208 | ###################
209 | # Model init kwargs
210 | ###################
211 | logger.info("*** Initializing model kwargs ***")
212 | torch_dtype = (
213 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
214 | )
215 | quantization_config = get_quantization_config(model_args)
216 | model_kwargs = dict(
217 | revision=model_args.model_revision,
218 | trust_remote_code=model_args.trust_remote_code,
219 | attn_implementation=model_args.attn_implementation,
220 | torch_dtype=torch_dtype,
221 | use_cache=False if training_args.gradient_checkpointing else True,
222 | device_map=get_kbit_device_map() if quantization_config is not None else None,
223 | quantization_config=quantization_config,
224 | )
225 | # training_args.model_init_kwargs = model_kwargs
226 | from transformers import Qwen2VLForConditionalGeneration
227 | model = Qwen2VLForConditionalGeneration.from_pretrained(
228 | model_args.model_name_or_path, **model_kwargs
229 | )
230 | ############################
231 | # Initialize the SFT Trainer
232 | ############################
233 | training_args.dataset_kwargs = {
234 | "skip_prepare_dataset": True,
235 | }
236 | training_args.remove_unused_columns = False
237 | trainer = SFTTrainer(
238 | model=model,
239 | args=training_args,
240 | train_dataset=dataset[script_args.dataset_train_split],
241 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
242 | processing_class=processor.tokenizer,
243 | data_collator=collate_fn,
244 | peft_config=get_peft_config(model_args)
245 | )
246 |
247 | ###############
248 | # Training loop
249 | ###############
250 | logger.info("*** Train ***")
251 | checkpoint = None
252 | if training_args.resume_from_checkpoint is not None:
253 | checkpoint = training_args.resume_from_checkpoint
254 | elif last_checkpoint is not None:
255 | checkpoint = last_checkpoint
256 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
257 | metrics = train_result.metrics
258 | metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
259 | trainer.log_metrics("train", metrics)
260 | trainer.save_metrics("train", metrics)
261 | trainer.save_state()
262 |
263 | ##################################
264 | # Save model and create model card
265 | ##################################
266 | logger.info("*** Save model ***")
267 | trainer.save_model(training_args.output_dir)
268 | processor.save_pretrained(training_args.output_dir)
269 | logger.info(f"Model saved to {training_args.output_dir}")
270 |
271 | # Save everything else on main process
272 | kwargs = {
273 | "dataset_name": script_args.dataset_name,
274 | "tags": ["R1-V"],
275 | }
276 | if trainer.accelerator.is_main_process:
277 | trainer.create_model_card(**kwargs)
278 | # Restore k,v cache for fast inference
279 | trainer.model.config.use_cache = True
280 | trainer.model.config.save_pretrained(training_args.output_dir)
281 | #############
282 | # push to hub
283 | #############
284 |
285 | if training_args.push_to_hub:
286 | logger.info("Pushing to hub...")
287 | trainer.push_to_hub(**kwargs)
288 | processor.push_to_hub(training_args.hub_model_id)
289 |
290 |
291 |
292 |
293 | if __name__ == "__main__":
294 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
295 | script_args, training_args, model_args = parser.parse_args_and_config()
296 | main(script_args, training_args, model_args)
--------------------------------------------------------------------------------
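
A small, self-contained sketch of the label-masking rule used in collate_fn above: every token up to and including the last "<|im_start|>" plus the two tokens that follow it (the assistant role token and the newline) is set to -100, so only the assistant reply is supervised. The token ids below are placeholders, not real tokenizer ids:

import torch

IM_START = 1000  # placeholder id standing in for the tokenizer's "<|im_start|>"

# Two turns: a user turn, then [<|im_start|>, assistant, \n, reply tokens...]
input_ids = torch.tensor([[IM_START, 11, 12, 13, IM_START, 21, 22, 31, 32, 33]])
labels = input_ids.clone()

for i in range(labels.size(0)):
    last_index = (labels[i] == IM_START).nonzero(as_tuple=True)[0][-1].item()
    labels[i, :last_index + 3] = -100  # mask through "<|im_start|>assistant\n"

print(labels)  # tensor([[-100, -100, -100, -100, -100, -100, -100, 31, 32, 33]])
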
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/qwen-vl-utils/src/qwen_vl_utils/vision_process_egoplan.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import base64
4 | import logging
5 | import math
6 | import os
7 | import sys
8 | import time
9 | import warnings
10 | from functools import lru_cache
11 | from io import BytesIO
12 |
13 | import requests
14 | import torch
15 | import torchvision
16 | from packaging import version
17 | from PIL import Image
18 | from torchvision import io, transforms
19 | from torchvision.transforms import InterpolationMode
20 | from typing import Optional
21 |
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | IMAGE_FACTOR = 28
26 | MIN_PIXELS = 4 * 28 * 28
27 | MAX_PIXELS = 16384 * 28 * 28
28 | MAX_RATIO = 200
29 |
30 | VIDEO_MIN_PIXELS = 128 * 28 * 28
31 | VIDEO_MAX_PIXELS = 196 * 28 * 28
32 | FRAME_FACTOR = 2
33 | FPS = 2.0
34 | FPS_MIN_FRAMES = 4
35 | FPS_MAX_FRAMES = 16
36 |
37 | # Set the maximum number of video token inputs.
38 | # Here, 128K represents the maximum number of input tokens for the VLLM model.
39 | # Remember to adjust it according to your own configuration.
40 | VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
41 | logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
42 |
43 |
44 | def round_by_factor(number: int, factor: int) -> int:
45 | """Returns the closest integer to 'number' that is divisible by 'factor'."""
46 | return round(number / factor) * factor
47 |
48 |
49 | def ceil_by_factor(number: int, factor: int) -> int:
50 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
51 | return math.ceil(number / factor) * factor
52 |
53 |
54 | def floor_by_factor(number: int, factor: int) -> int:
55 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
56 | return math.floor(number / factor) * factor
57 |
58 |
59 | def smart_resize(
60 | height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
61 | ) -> tuple[int, int]:
62 | """
63 | Rescales the image so that the following conditions are met:
64 |
65 | 1. Both dimensions (height and width) are divisible by 'factor'.
66 |
67 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
68 |
69 | 3. The aspect ratio of the image is maintained as closely as possible.
70 | """
71 | if max(height, width) / min(height, width) > MAX_RATIO:
72 | raise ValueError(
73 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
74 | )
75 | h_bar = max(factor, round_by_factor(height, factor))
76 | w_bar = max(factor, round_by_factor(width, factor))
77 | if h_bar * w_bar > max_pixels:
78 | beta = math.sqrt((height * width) / max_pixels)
79 | h_bar = floor_by_factor(height / beta, factor)
80 | w_bar = floor_by_factor(width / beta, factor)
81 | elif h_bar * w_bar < min_pixels:
82 | beta = math.sqrt(min_pixels / (height * width))
83 | h_bar = ceil_by_factor(height * beta, factor)
84 | w_bar = ceil_by_factor(width * beta, factor)
85 | return h_bar, w_bar
86 |
87 |
88 | def to_rgb(pil_image: Image.Image) -> Image.Image:
89 | if pil_image.mode == 'RGBA':
90 | white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
91 | white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
92 | return white_background
93 | else:
94 | return pil_image.convert("RGB")
95 |
96 |
97 | def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
98 | if "image" in ele:
99 | image = ele["image"]
100 | else:
101 | image = ele["image_url"]
102 | image_obj = None
103 | if isinstance(image, Image.Image):
104 | image_obj = image
105 | elif image.startswith("http://") or image.startswith("https://"):
106 | response = requests.get(image, stream=True)
107 | image_obj = Image.open(BytesIO(response.content))
108 | elif image.startswith("file://"):
109 | image_obj = Image.open(image[7:])
110 | elif image.startswith("data:image"):
111 | if "base64," in image:
112 | _, base64_data = image.split("base64,", 1)
113 | data = base64.b64decode(base64_data)
114 | image_obj = Image.open(BytesIO(data))
115 | else:
116 | image_obj = Image.open(image)
117 | if image_obj is None:
118 | raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
119 | image = to_rgb(image_obj)
120 | ## resize
121 | if "resized_height" in ele and "resized_width" in ele:
122 | resized_height, resized_width = smart_resize(
123 | ele["resized_height"],
124 | ele["resized_width"],
125 | factor=size_factor,
126 | )
127 | else:
128 | width, height = image.size
129 | min_pixels = ele.get("min_pixels", MIN_PIXELS)
130 | max_pixels = ele.get("max_pixels", MAX_PIXELS)
131 | resized_height, resized_width = smart_resize(
132 | height,
133 | width,
134 | factor=size_factor,
135 | min_pixels=min_pixels,
136 | max_pixels=max_pixels,
137 | )
138 | image = image.resize((resized_width, resized_height))
139 |
140 | return image
141 |
142 |
143 | def smart_nframes(
144 | ele: dict,
145 | total_frames: int,
146 | video_fps: int | float,
147 | ) -> int:
148 | """calculate the number of frames for video used for model inputs.
149 |
150 | Args:
151 | ele (dict): a dict contains the configuration of video.
152 | support either `fps` or `nframes`:
153 | - nframes: the number of frames to extract for model inputs.
154 | - fps: the fps to extract frames for model inputs.
155 | - min_frames: the minimum number of frames of the video, only used when fps is provided.
156 | - max_frames: the maximum number of frames of the video, only used when fps is provided.
157 | total_frames (int): the original total number of frames of the video.
158 | video_fps (int | float): the original fps of the video.
159 |
160 | Raises:
161 |         ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
162 |
163 | Returns:
164 | int: the number of frames for video used for model inputs.
165 | """
166 | assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
167 | if "nframes" in ele:
168 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
169 | else:
170 | fps = ele.get("fps", FPS)
171 | min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
172 | max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
173 | nframes = total_frames / video_fps * fps
174 | if nframes > total_frames:
175 | logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
176 | nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
177 | nframes = floor_by_factor(nframes, FRAME_FACTOR)
178 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
179 |         raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
180 | return nframes
181 |
182 |
183 | def _read_video_torchvision(
184 | ele: dict,
185 | ) -> (torch.Tensor, float):
186 | """read video using torchvision.io.read_video
187 |
188 | Args:
189 | ele (dict): a dict contains the configuration of video.
190 | support keys:
191 | - video: the path of video. support "file://", "http://", "https://" and local path.
192 | - video_start: the start time of video.
193 | - video_end: the end time of video.
194 | Returns:
195 | torch.Tensor: the video tensor with shape (T, C, H, W).
196 | """
197 | video_path = ele["video"]
198 | if version.parse(torchvision.__version__) < version.parse("0.19.0"):
199 | if "http://" in video_path or "https://" in video_path:
200 |             warnings.warn("torchvision < 0.19.0 does not support http/https video paths; please upgrade to 0.19.0.")
201 | if "file://" in video_path:
202 | video_path = video_path[7:]
203 | st = time.time()
204 | video, audio, info = io.read_video(
205 | video_path,
206 | start_pts=ele.get("video_start", 0.0),
207 | end_pts=ele.get("video_end", None),
208 | pts_unit="sec",
209 | output_format="TCHW",
210 | )
211 | total_frames, video_fps = video.size(0), info["video_fps"]
212 | logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
213 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
214 | idx = torch.linspace(0, total_frames - 1, nframes).round().long()
215 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps
216 | video = video[idx]
217 | return video, sample_fps
218 |
219 |
220 | def is_decord_available() -> bool:
221 | import importlib.util
222 |
223 | return importlib.util.find_spec("decord") is not None
224 |
225 |
226 | def _read_video_decord(
227 | ele: dict,
228 | ) -> (torch.Tensor, float):
229 | """read video using decord.VideoReader
230 |
231 | Args:
232 | ele (dict): a dict contains the configuration of video.
233 | support keys:
234 | - video: the path of video. support "file://", "http://", "https://" and local path.
235 | - video_start: the start time of video.
236 | - video_end: the end time of video.
237 | Returns:
238 | torch.Tensor: the video tensor with shape (T, C, H, W).
239 | """
240 | import decord
241 | video_path = ele["video"]
242 | st = time.time()
243 | vr = decord.VideoReader(video_path)
244 | # TODO: support start_pts and end_pts
245 | if 'video_start' in ele or 'video_end' in ele:
246 |         raise NotImplementedError("video_start and video_end are not yet supported with the decord backend.")
247 | total_frames, video_fps = len(vr), vr.get_avg_fps()
248 | logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
249 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
250 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
251 | video = vr.get_batch(idx).asnumpy()
252 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
253 | sample_fps = nframes / max(total_frames, 1e-6) * video_fps
254 | return video, sample_fps
255 |
256 |
257 | VIDEO_READER_BACKENDS = {
258 | "decord": _read_video_decord,
259 | "torchvision": _read_video_torchvision,
260 | }
261 |
262 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
263 |
264 |
265 | @lru_cache(maxsize=1)
266 | def get_video_reader_backend() -> str:
267 | if FORCE_QWENVL_VIDEO_READER is not None:
268 | video_reader_backend = FORCE_QWENVL_VIDEO_READER
269 | elif is_decord_available():
270 | video_reader_backend = "decord"
271 | else:
272 | video_reader_backend = "torchvision"
273 | print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
274 | return video_reader_backend
275 |
276 |
277 | def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
278 | if isinstance(ele["video"], str):
279 | video_reader_backend = get_video_reader_backend()
280 | try:
281 | video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
282 | except Exception as e:
283 | logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
284 | video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
285 |
286 | nframes, _, height, width = video.shape
287 |
288 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
289 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
290 | max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
291 | max_pixels_supposed = ele.get("max_pixels", max_pixels)
292 | if max_pixels_supposed > max_pixels:
293 | logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
294 | max_pixels = min(max_pixels_supposed, max_pixels)
295 | if "resized_height" in ele and "resized_width" in ele:
296 | resized_height, resized_width = smart_resize(
297 | ele["resized_height"],
298 | ele["resized_width"],
299 | factor=image_factor,
300 | )
301 | else:
302 | resized_height, resized_width = smart_resize(
303 | height,
304 | width,
305 | factor=image_factor,
306 | min_pixels=min_pixels,
307 | max_pixels=max_pixels,
308 | )
309 | video = transforms.functional.resize(
310 | video,
311 | [resized_height, resized_width],
312 | interpolation=InterpolationMode.BICUBIC,
313 | antialias=True,
314 | ).float()
315 |
316 | # print(f'video shape {video.shape}')
317 | if return_video_sample_fps:
318 | return video, sample_fps
319 | return video
320 | else:
321 | assert isinstance(ele["video"], (list, tuple))
322 | process_info = ele.copy()
323 | process_info.pop("type", None)
324 | process_info.pop("video", None)
325 | images = [
326 | fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
327 | for video_element in ele["video"]
328 | ]
329 | nframes = ceil_by_factor(len(images), FRAME_FACTOR)
330 | if len(images) < nframes:
331 | images.extend([images[-1]] * (nframes - len(images)))
332 | if return_video_sample_fps:
333 | return images, process_info.pop("fps", 2.0)
334 | return images
335 |
336 |
337 | def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
338 | vision_infos = []
339 | if isinstance(conversations[0], dict):
340 | conversations = [conversations]
341 | for conversation in conversations:
342 | for message in conversation:
343 | if isinstance(message["content"], list):
344 | for ele in message["content"]:
345 | if (
346 | "image" in ele
347 | or "image_url" in ele
348 | or "video" in ele
349 | or ele["type"] in ("image", "image_url", "video")
350 | ):
351 | ele['resized_height'] = 256
352 | ele['resized_width'] = 256
353 | vision_infos.append(ele)
354 | return vision_infos
355 |
356 |
357 | def process_vision_info(
358 | conversations: list[dict] | list[list[dict]],
359 | return_video_kwargs: bool = False,
360 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
361 |
362 | vision_infos = extract_vision_info(conversations)
363 | ## Read images or videos
364 | image_inputs = []
365 | video_inputs = []
366 | video_sample_fps_list = []
367 | for vision_info in vision_infos:
368 | if "image" in vision_info or "image_url" in vision_info:
369 | image_inputs.append(fetch_image(vision_info))
370 | elif "video" in vision_info:
371 | video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
372 | video_sample_fps_list.append(video_sample_fps)
373 | video_inputs.append(video_input)
374 | else:
375 |             raise ValueError("image, image_url or video should be in content.")
376 | if len(image_inputs) == 0:
377 | image_inputs = None
378 | if len(video_inputs) == 0:
379 | video_inputs = None
380 | if return_video_kwargs:
381 | return image_inputs, video_inputs, {'fps': video_sample_fps_list}
382 | return image_inputs, video_inputs
--------------------------------------------------------------------------------
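
A minimal sketch of the two sizing helpers defined above, assuming the local qwen-vl-utils package has been installed (for example with `pip install -e qwen-vl-utils`); the resolution and clip length are arbitrary examples:

from qwen_vl_utils.vision_process_egoplan import smart_nframes, smart_resize

# A 1080p frame snapped to multiples of the factor (28) within the video pixel budget.
h, w = smart_resize(height=1080, width=1920, factor=28,
                    min_pixels=128 * 28 * 28, max_pixels=196 * 28 * 28)
print(h, w, h % 28 == 0 and w % 28 == 0, h * w <= 196 * 28 * 28)

# A 60 s clip at 30 fps sampled at the default 2 fps, capped at FPS_MAX_FRAMES.
print(smart_nframes({"fps": 2.0}, total_frames=1800, video_fps=30))  # 16
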
/src/open_r1_egoplan/trainer/grpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import os
17 | import textwrap
18 | from collections import defaultdict
19 | from typing import Any, Callable, Optional, Union, List, Tuple, Dict
20 |
21 | import torch
22 | import torch.utils.data
23 | import transformers
24 | from datasets import Dataset, IterableDataset
25 | from packaging import version
26 | from transformers import (
27 | AriaForConditionalGeneration,
28 | AriaProcessor,
29 | AutoModelForCausalLM,
30 | AutoModelForSequenceClassification,
31 | AutoProcessor,
32 | AutoTokenizer,
33 | GenerationConfig,
34 | PreTrainedModel,
35 | PreTrainedTokenizerBase,
36 | Qwen2VLForConditionalGeneration,
37 | Trainer,
38 | TrainerCallback,
39 | is_wandb_available,
40 | )
41 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
42 | from transformers.utils import is_peft_available
43 |
44 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
45 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
46 | from trl.trainer.grpo_config import GRPOConfig
47 | from trl.trainer.utils import generate_model_card, get_comet_experiment_url
48 |
49 |
50 | from qwen_vl_utils import process_vision_info
51 |
52 | if is_peft_available():
53 | from peft import PeftConfig, get_peft_model
54 |
55 | if is_wandb_available():
56 | import wandb
57 |
58 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
59 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
60 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], List[float]]]
61 |
62 |
63 | class Qwen2VLGRPOTrainer(Trainer):
64 | """
65 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
66 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
67 |
68 | Example:
69 |
70 | ```python
71 | from datasets import load_dataset
72 | from trl import GRPOTrainer
73 |
74 | dataset = load_dataset("trl-lib/tldr", split="train")
75 |
76 | trainer = GRPOTrainer(
77 | model="Qwen/Qwen2-0.5B-Instruct",
78 | reward_funcs="weqweasdas/RM-Gemma-2B",
79 | train_dataset=dataset,
80 | )
81 |
82 | trainer.train()
83 | ```
84 |
85 | Args:
86 | model (`Union[str, PreTrainedModel]`):
87 | Model to be trained. Can be either:
88 |
89 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
90 | a path to a *directory* containing model weights saved using
91 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
92 |               loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
93 | in `args.model_init_kwargs`.
94 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
95 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
96 | Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
97 | functions with the prompts and completions and sum the rewards. Can be either:
98 |
99 | - A single reward function, such as:
100 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
101 | path to a *directory* containing model weights saved using
102 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
103 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
104 | keyword arguments in `args.model_init_kwargs`.
105 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
106 | - A custom reward function: The function is provided with the prompts and the generated completions,
107 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see
108 | [Using a custom reward function](#using-a-custom-reward-function).
109 | - A list of reward functions, where each item can independently be any of the above types. Mixing different
110 | types within the list (e.g., a string model ID and a custom reward function) is allowed.
111 | args ([`GRPOConfig`], *optional*, defaults to `None`):
112 | Configuration for this trainer. If `None`, a default configuration is used.
113 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
114 |             Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
115 | ignored. The format of the samples can be either:
116 |
117 | - [Standard](dataset_formats#standard): Each sample contains plain text.
118 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
119 | and content).
120 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
121 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
122 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
123 | Processing class used to process the data. The padding side must be set to "left". If `None`, the
124 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
125 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
126 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
127 |
128 | - A single processing class: Used when `reward_funcs` contains only one reward function.
129 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
130 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
131 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
132 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
133 | the corresponding entries in `reward_processing_classes` are ignored.
134 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
135 | List of callbacks to customize the training loop. Will add those to the list of default callbacks
136 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
137 |
138 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
139 | method.
140 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
141 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
142 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
143 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
144 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
145 | """
146 |
147 | def __init__(
148 | self,
149 | model: Union[str, PreTrainedModel],
150 | reward_funcs: Union[RewardFunc, List[RewardFunc]],
151 | args: GRPOConfig = None,
152 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
153 | eval_dataset: Optional[Union[Dataset, IterableDataset, Dict[str, Union[Dataset, IterableDataset]]]] = None,
154 | processing_class: Optional[PreTrainedTokenizerBase] = None,
155 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, List[PreTrainedTokenizerBase]]] = None,
156 | callbacks: Optional[List[TrainerCallback]] = None,
157 | optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
158 | peft_config: Optional["PeftConfig"] = None,
159 | max_pixels: Optional[int] = 12845056,
160 | min_pixels: Optional[int] = 3136,
161 | attn_implementation: str = "flash_attention_2",
162 | ):
163 | # Args
164 | if args is None:
165 | model_name = model if isinstance(model, str) else model.config._name_or_path
166 | model_name = model_name.split("/")[-1]
167 | args = GRPOConfig(f"{model_name}-GRPO")
168 |
169 | # Models
170 | # Trained model
171 | model_init_kwargs = args.model_init_kwargs or {}
172 | model_init_kwargs["attn_implementation"] = attn_implementation
173 | if isinstance(model, str):
174 | model_id = model
175 | torch_dtype = model_init_kwargs.get("torch_dtype")
176 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
177 | pass # torch_dtype is already a torch.dtype or "auto" or None
178 | elif isinstance(torch_dtype, str): # it's a str, but not "auto"
179 | torch_dtype = getattr(torch, torch_dtype)
180 | model_init_kwargs["torch_dtype"] = torch_dtype
181 | else:
182 | raise ValueError(
183 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
184 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
185 | )
186 | # Disable caching if gradient checkpointing is enabled (not supported)
187 | model_init_kwargs["use_cache"] = (
188 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
189 | )
190 | if "Qwen2-VL" in model_id:
191 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
192 | elif "Aria" in model_id:
193 | model_init_kwargs.pop("use_cache")
194 | model = AriaForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
195 | else:
196 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
197 | else:
198 | model_id = model.config._name_or_path
199 | if args.model_init_kwargs is not None:
200 | raise ValueError(
201 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
202 | "This argument can only be used when the `model` argument is a string."
203 | )
204 |
205 |         if args.freeze_visual_encoder:
206 |             print("Freezing the parameters of the visual encoder.")
207 |             for p in model.visual.parameters():
208 |                 p.requires_grad_(False)
209 |
210 | if peft_config is not None:
211 | model = get_peft_model(model, peft_config)
212 |
213 | # Reference model
214 | if is_deepspeed_zero3_enabled():
215 | if "Qwen2-VL" in model_id:
216 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
217 | elif "Aria" in model_id:
218 | self.ref_model = AriaForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
219 | else:
220 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
221 | elif peft_config is None:
222 | # If PEFT configuration is not provided, create a reference model based on the initial model.
223 | self.ref_model = create_reference_model(model)
224 | else:
225 | # If PEFT is used, the reference model is not needed since the adapter can be disabled
226 | # to revert to the initial model.
227 | self.ref_model = None
228 |
229 | # Processing class
230 | if processing_class is None:
231 | if "Qwen2-VL" in model_id or "Aria" in model_id:
232 | processing_class = AutoProcessor.from_pretrained(model_id)
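    |                 # Mirror the tokenizer's pad/eos token ids onto the processor below so that the generation
    |                 # config and the EOS masking in `compute_loss` can read them straight from `processing_class`.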
233 | pad_token_id = processing_class.tokenizer.pad_token_id
234 | processing_class.pad_token_id = pad_token_id
235 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
236 | if "Qwen2-VL" in model_id:
237 | processing_class.image_processor.max_pixels = max_pixels
238 | processing_class.image_processor.min_pixels = min_pixels
239 | else:
240 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
241 | pad_token_id = processing_class.pad_token_id
242 |
243 | # Reward functions
244 | if not isinstance(reward_funcs, list):
245 | reward_funcs = [reward_funcs]
246 | for i, reward_func in enumerate(reward_funcs):
247 | if isinstance(reward_func, str):
248 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
249 | reward_func, num_labels=1, **model_init_kwargs
250 | )
251 | self.reward_funcs = reward_funcs
252 |
253 | # Reward processing class
254 | if reward_processing_classes is None:
255 | reward_processing_classes = [None] * len(reward_funcs)
256 | elif not isinstance(reward_processing_classes, list):
257 | reward_processing_classes = [reward_processing_classes]
258 | else:
259 | if len(reward_processing_classes) != len(reward_funcs):
260 | raise ValueError("The number of reward processing classes must match the number of reward functions.")
261 |
262 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
263 | if isinstance(reward_func, PreTrainedModel):
264 | if reward_processing_class is None:
265 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
266 | if reward_processing_class.pad_token_id is None:
267 | reward_processing_class.pad_token = reward_processing_class.eos_token
268 | # The reward model computes the reward for the latest non-padded token in the input sequence.
269 | # So it's important to set the pad token ID to the padding token ID of the processing class.
270 | reward_func.config.pad_token_id = reward_processing_class.pad_token_id
271 | reward_processing_classes[i] = reward_processing_class
272 | self.reward_processing_classes = reward_processing_classes
273 |
274 | # Data collator
275 | def data_collator(features): # No data collation is needed in GRPO
276 | return features
277 |
278 | # Training arguments
279 | self.max_prompt_length = args.max_prompt_length
280 | self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
281 | self.num_generations = args.num_generations # = G in the GRPO paper
282 | self.generation_config = GenerationConfig(
283 | max_new_tokens=self.max_completion_length,
284 | do_sample=True,
285 |             temperature=1,  # sampling temperature is hard-coded here rather than read from args
286 | num_return_sequences=self.num_generations,
287 | pad_token_id=pad_token_id,
288 | )
289 | self.beta = args.beta
290 |
291 | # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
292 | # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
293 | # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
294 | # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
295 | # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
296 | # This acts as a flag to indicate that the warning has already been issued.
297 | model.warnings_issued["estimate_tokens"] = True
298 |
299 | # Initialize the metrics
300 | self._metrics = defaultdict(list)
301 |
302 | super().__init__(
303 | model=model,
304 | args=args,
305 | data_collator=data_collator,
306 | train_dataset=train_dataset,
307 | eval_dataset=eval_dataset,
308 | processing_class=processing_class,
309 | callbacks=callbacks,
310 | optimizers=optimizers,
311 | )
312 |
313 | # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
314 | # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
315 | # self.model_accepts_loss_kwargs to False to enable scaling.
316 | self.model_accepts_loss_kwargs = False
317 |
318 | if self.ref_model is not None:
319 | if self.is_deepspeed_enabled:
320 | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
321 | else:
322 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
323 |
324 | for i, reward_func in enumerate(self.reward_funcs):
325 | if isinstance(reward_func, PreTrainedModel):
326 | self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
327 |
328 | def _set_signature_columns_if_needed(self):
329 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
330 | # By default, this method sets `self._signature_columns` to the model's expected inputs.
331 | # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
332 | # Instead, we set them to the columns expected by the `training_step` method, hence the override.
333 | if self._signature_columns is None:
334 | self._signature_columns = ["prompt"]
335 |
336 | # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
337 | # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
338 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
339 | return inputs
340 |
341 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
342 | if return_outputs:
343 | raise ValueError("The GRPOTrainer does not support returning outputs")
344 |
356 | for x in inputs:
357 | for message in x["prompt"]:
358 | if isinstance(message["content"], list):
359 | new_content = []
360 | for ele in message["content"]:
361 | new_ele = {k: v for k, v in ele.items() if v is not None}
362 | new_content.append(new_ele)
363 | message["content"] = new_content
364 |
365 | prompts = [x["prompt"] for x in inputs]
366 |
367 | # Preparation for batch inference
368 | prompt_texts = [
369 | self.processing_class.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
370 | for msg in prompts
371 | ]
372 | image_inputs, video_inputs = process_vision_info(prompts)
373 | prompt_inputs = self.processing_class(
374 | text=prompt_texts,
375 | images=image_inputs,
376 | videos=video_inputs,
377 | return_tensors="pt",
378 | padding=True,
379 | padding_side="left",
380 | add_special_tokens=False,
381 | )
382 | prompt_inputs = super()._prepare_inputs(prompt_inputs)
383 |
384 | if self.max_prompt_length is not None:
385 | prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
386 | prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
387 |
388 | # Generate completions
389 | with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
390 |             # Instead of a single generate() call with num_return_sequences=G, sample G times with
391 |             # num_return_sequences=1, then stack the outputs into prompt_completion_ids, padding the
392 |             # shorter sequences on the right with the pad token.
393 | num_generations = self.generation_config.num_return_sequences
394 | temp_generation_config = copy.deepcopy(self.generation_config)
395 | temp_generation_config.num_return_sequences = 1
396 |
397 | all_completions = []
398 |
399 |             for _ in range(num_generations):  # one generate() call per sampled completion
400 | completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
401 | all_completions.append(completion)
402 |
403 | # Stack all completions and pad if needed
404 | max_length = max(completion.size(1) for completion in all_completions)
405 | padded_completions = []
406 |
407 | for completion in all_completions:
408 | if completion.size(1) < max_length:
409 | padding = torch.full(
410 | (completion.size(0), max_length - completion.size(1)),
411 | self.processing_class.tokenizer.pad_token_id,
412 | dtype=completion.dtype,
413 | device=completion.device,
414 | )
415 | padded_completion = torch.cat([completion, padding], dim=1)
416 | else:
417 | padded_completion = completion
418 | padded_completions.append(padded_completion)
419 |
420 | # Stack all padded completions
421 | prompt_completion_ids = torch.cat(padded_completions, dim=0)
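    |             # Rows are ordered generation-major after the concatenation: sample 0 for every prompt,
    |             # then sample 1 for every prompt, and so on.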
422 |
423 | prompt_length = prompt_inputs["input_ids"].size(1)
424 | completion_ids = prompt_completion_ids[:, prompt_length:]
425 |
428 | # Get the per-token log probabilities for the completions for the model and the reference model
429 | def get_per_token_logps(model, input_ids, **kwargs):
430 | logits = model(input_ids, **kwargs).logits # (B, L, V)
431 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
432 | input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
433 | # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
434 | per_token_logps = []
435 | for logits_row, input_ids_row in zip(logits, input_ids):
436 | log_probs = logits_row.log_softmax(dim=-1)
437 | token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
438 | per_token_logps.append(token_log_prob)
439 | return torch.stack(per_token_logps)
440 |
441 | prompt_inputs.pop("input_ids")
442 | prompt_inputs.pop("attention_mask")
443 |         # Repeat the vision inputs (image/video patch tensors and their grid shapes) so that every generated
444 |         # completion row has its own copy when computing per-token log-probs below. image_grid_thw/video_grid_thw
445 |         # come from the Qwen2-VL image processor:
446 |         # https://github.com/huggingface/transformers/blob/dd16acb8a3e93b643aa374c9fb80749f5235c1a6/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L414
458 | for k in ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]:
459 | if k in prompt_inputs:
460 | prompt_inputs[k] = prompt_inputs[k].repeat(len(prompt_completion_ids), 1)
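    |         # NOTE: repeating by len(prompt_completion_ids) copies the patch tensors once per generated row,
    |         # which matches the G completions only when each device holds a single prompt; with more prompts
    |         # per device the vision features of every prompt would be duplicated onto every row.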
461 |
463 | per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
464 | # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
465 | per_token_logps = per_token_logps[:, prompt_length - 1 :]
466 |
467 | with torch.inference_mode():
468 | if self.ref_model is not None:
469 |                 ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)  # vision inputs are forwarded to the reference model as well
470 | else:
471 | with self.accelerator.unwrap_model(model).disable_adapter():
472 |                     ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)  # vision inputs are forwarded here as well
473 | ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
474 |
475 | # Compute the KL divergence between the model and the reference model
476 | diff = ref_per_token_logps - per_token_logps
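    |         # Clamp the per-token log-ratio before exponentiating so exp() cannot overflow on badly
    |         # mismatched tokens (a numerical-stability heuristic; the +/-11 bound is an arbitrary cut-off).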
477 | diff = torch.clamp(diff, min=-11.0, max=11.0)
478 |
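    |         # This is the k3 estimator of KL(pi_theta || pi_ref): with r = pi_ref / pi_theta it equals
    |         # r - log r - 1, which is always non-negative and unbiased in expectation under pi_theta.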
479 | per_token_kl = torch.exp(diff) - (diff) - 1
480 |
481 | # Mask everything after the first EOS token
482 | is_eos = completion_ids == self.processing_class.eos_token_id
483 | device = self.accelerator.device
484 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
485 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
486 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
487 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
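    |         # e.g. completion_ids = [a, b, EOS, pad, pad] gives eos_idx = 2 and completion_mask = [1, 1, 1, 0, 0]
    |         # (the EOS token itself stays unmasked); a completion with no EOS keeps a mask of all ones.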
488 |
489 | # Decode the generated completions
490 | completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
491 | if is_conversational(inputs[0]):
492 | completions = [[{"role": "assistant", "content": completion}] for completion in completions]
493 |
494 | # Compute the rewards
495 | prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
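    |         # NOTE: each prompt is repeated G times back-to-back (prompt-major), whereas prompt_completion_ids
    |         # above was stacked generation-major; the two orderings only coincide when a device processes a
    |         # single prompt per step, so take care before raising the per-device batch size.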
496 |
497 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
498 | for i, (reward_func, reward_processing_class) in enumerate(
499 | zip(self.reward_funcs, self.reward_processing_classes)
500 | ):
502 | if isinstance(reward_func, PreTrainedModel):
503 |                 if is_conversational(inputs[0]):
504 | messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
505 | texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
506 | else:
507 | texts = [p + c for p, c in zip(prompts, completions)]
508 | reward_inputs = reward_processing_class(
509 | texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
510 | )
511 | reward_inputs = super()._prepare_inputs(reward_inputs)
512 | with torch.inference_mode():
513 | rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
514 | else:
515 |                 # Repeat all input columns (except "prompt" and "completion") to match the number of generations
516 | reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
517 | for key in reward_kwargs:
518 | for example in inputs:
519 | # Repeat each value in the column for `num_generations` times
520 | reward_kwargs[key].extend([example[key]] * self.num_generations)
521 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
522 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
523 |
524 | # Sum the rewards from all reward functions
525 |         rewards = rewards_per_func.sum(dim=1)  # Shape (B*G,)
526 |
527 | # Compute grouped-wise rewards
528 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) # (bs,)
529 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) # (bs,)
530 |
531 | # Normalize the rewards to compute the advantages
532 |         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)  # Shape (B*G,)
533 |         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)  # Shape (B*G,)
534 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
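    |         # e.g. with G = 2 and rewards [1.0, 0.0] for one prompt: mean = 0.5, std ~= 0.707, so the
    |         # advantages come out to roughly [+0.707, -0.707].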
535 |
536 | # x - x.detach() allows for preserving gradients from x
537 | per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
538 |         per_token_loss = -(per_token_loss - self.beta * per_token_kl)  # self.beta weights the KL penalty (defaults to 0.04)
539 | loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
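    |         # In the forward pass exp(logp - logp.detach()) is exactly 1, so the reported loss value reduces to
    |         # -(advantage) + beta * KL averaged over unmasked tokens; the exp(...) factor only shapes the backward
    |         # pass, where it yields the standard advantage-weighted grad-log-prob policy-gradient estimator.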
540 |
543 | # Log the metrics
544 | completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
545 | self._metrics["completion_length"].append(completion_length)
546 |
547 | reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
548 | for i, reward_func in enumerate(self.reward_funcs):
549 | if isinstance(reward_func, PreTrainedModel):
550 | reward_func_name = reward_func.config._name_or_path.split("/")[-1]
551 | else:
552 | reward_func_name = reward_func.__name__
553 | self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
554 |
555 | self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
556 |
557 | self._metrics["advantages"].append(self.accelerator.gather_for_metrics(advantages).mean().item())
558 |
559 | self._metrics["reward_mean"].append(self.accelerator.gather_for_metrics(mean_grouped_rewards).mean().item())
560 |
561 | self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
562 |
563 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
564 | self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
565 |
568 | return loss
569 |
570 | def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
571 | metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
572 | logs = {**logs, **metrics}
573 | if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
574 | super().log(logs, start_time)
575 | else: # transformers<=4.46
576 | super().log(logs)
577 | self._metrics.clear()
578 |
579 | def create_model_card(
580 | self,
581 | model_name: Optional[str] = None,
582 | dataset_name: Optional[str] = None,
583 | tags: Union[str, List[str], None] = None,
584 | ):
585 | """
586 | Creates a draft of a model card using the information available to the `Trainer`.
587 |
588 | Args:
589 | model_name (`str` or `None`, *optional*, defaults to `None`):
590 | Name of the model.
591 | dataset_name (`str` or `None`, *optional*, defaults to `None`):
592 | Name of the dataset used for training.
593 | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
594 | Tags to be associated with the model card.
595 | """
596 | if not self.is_world_process_zero():
597 | return
598 |
599 | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
600 | base_model = self.model.config._name_or_path
601 | else:
602 | base_model = None
603 |
604 | tags = tags or []
605 | if isinstance(tags, str):
606 | tags = [tags]
607 |
608 | if hasattr(self.model.config, "unsloth_version"):
609 | tags.append("unsloth")
610 |
611 | citation = textwrap.dedent(
612 | """\
613 | @article{zhihong2024deepseekmath,
614 | title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
615 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
616 | year = 2024,
617 | eprint = {arXiv:2402.03300},
618 | """
619 | )
620 |
621 | model_card = generate_model_card(
622 | base_model=base_model,
623 | model_name=model_name,
624 | hub_model_id=self.hub_model_id,
625 | dataset_name=dataset_name,
626 | tags=tags,
627 | wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
628 | comet_url=get_comet_experiment_url(),
629 | trainer_name="GRPO",
630 | trainer_citation=citation,
631 | paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
632 | paper_id="2402.03300",
633 | )
634 |
635 | model_card.save(os.path.join(self.args.output_dir, "README.md"))
--------------------------------------------------------------------------------