├── src
│   └── active_prm
│       ├── __init__.py
│       ├── eval
│       │   ├── PRMBench
│       │   │   ├── mr_annotate
│       │   │   │   ├── __init__.py
│       │   │   │   └── build_data
│       │   │   │       ├── __init__.py
│       │   │   │       ├── generate_data
│       │   │   │       │   ├── __init__.py
│       │   │   │       │   └── dataset
│       │   │   │       │       └── generate_dataset.py
│       │   │   │       ├── model_inference
│       │   │   │       │   ├── gemini_api
│       │   │   │       │   │   └── gemini_inference.py
│       │   │   │       │   └── qwq
│       │   │   │       │       ├── inferencer
│       │   │   │       │       │   ├── __init__.py
│       │   │   │       │       │   └── qwq_inferencer.py
│       │   │   │       │       └── run_inference.py
│       │   │   │       └── prompts
│       │   │   │           ├── __init__.py
│       │   │   │           ├── classifications
│       │   │   │           │   ├── __init__.py
│       │   │   │           │   ├── confidence.py
│       │   │   │           │   ├── step_contradiction.py
│       │   │   │           │   ├── counterfactual.py
│       │   │   │           │   ├── redundency.py
│       │   │   │           │   └── domain_inconsistency.py
│       │   │   │           ├── prompts_test.txt
│       │   │   │           └── prompt_new.py
│       │   │   ├── mr_eval
│       │   │   │   ├── __init__.py
│       │   │   │   ├── utils
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── log_utils.py
│       │   │   │   │   ├── task_utils.py
│       │   │   │   │   ├── prompts.py
│       │   │   │   │   ├── arguments.py
│       │   │   │   │   ├── model_utils.py
│       │   │   │   │   └── utils.py
│       │   │   │   ├── bon_eval
│       │   │   │   │   └── policy_gen
│       │   │   │   │       ├── __init__.py
│       │   │   │   │       └── dataset
│       │   │   │   │           └── base_dataset.py
│       │   │   │   ├── tasks
│       │   │   │   │   ├── base_dataset
│       │   │   │   │   │   └── __init__.py
│       │   │   │   │   ├── prmbench_bon
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── config.yaml
│       │   │   │   │   ├── prmbench_stem
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── config.yaml
│       │   │   │   │   ├── prmtest_classified
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── config.yaml
│       │   │   │   │   ├── prmbench_bon_subset50
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── config.yaml
│       │   │   │   │   ├── prmtest_classified_subset400
│       │   │   │   │   │   ├── __init__.py
│       │   │   │   │   │   └── config.yaml
│       │   │   │   │   └── __init__.py
│       │   │   │   ├── scripts
│       │   │   │   │   └── examples
│       │   │   │   │       ├── accelerate_configs
│       │   │   │   │       │   ├── cpu.yaml
│       │   │   │   │       │   ├── 1gpu.yaml
│       │   │   │   │       │   ├── 4gpus.yaml
│       │   │   │   │       │   ├── zero3_inference.json
│       │   │   │   │       │   ├── zero3_offload_inference.json
│       │   │   │   │       │   └── 4gpus_deepspeed.yaml
│       │   │   │   │       ├── direct_run.sh
│       │   │   │   │       ├── example_configs
│       │   │   │   │       │   └── initial_test_gpt4o.yaml
│       │   │   │   │       ├── api_eval.sh
│       │   │   │   │       └── local_multi_gpu_eval.sh
│       │   │   │   ├── models
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── abstract_model.py
│       │   │   │   │   ├── qwen_math_rm_fsdp.py
│       │   │   │   │   ├── math_shepherd.py
│       │   │   │   │   ├── mathminos_mistral.py
│       │   │   │   │   ├── llama3_1_8b_prm.py
│       │   │   │   │   ├── pure_prm.py
│       │   │   │   │   ├── skywork_prm.py
│       │   │   │   │   ├── llemma7b_prm.py
│       │   │   │   │   ├── ensemble_prm.py
│       │   │   │   │   ├── qwen_prm.py
│       │   │   │   │   ├── qwen_qwq.py
│       │   │   │   │   └── vllm_models.py
│       │   │   │   ├── __main__.py
│       │   │   │   └── evaluator.py
│       │   │   ├── __main__.py
│       │   │   ├── docs
│       │   │   │   ├── assets
│       │   │   │   │   ├── main_fig.pdf
│       │   │   │   │   └── main_logo.png
│       │   │   │   ├── data_format.md
│       │   │   │   └── document.md
│       │   │   ├── mr_visualize
│       │   │   │   └── eval_res_view
│       │   │   │       ├── draw_figure
│       │   │   │       │   └── res
│       │   │   │       │       ├── step_acc.pdf
│       │   │   │       │       ├── correlation.pdf
│       │   │   │       │       ├── llava_radar.pdf
│       │   │   │       │       └── error_position_distribution.pdf
│       │   │   │       └── draw_tabs
│       │   │   │           └── bias_table.ipynb
│       │   │   ├── setup.py
│       │   │   ├── requirements.txt
│       │   │   ├── vis_res.py
│       │   │   └── .gitignore
│       │   └── processbench.py
│       ├── trainer
│       │   ├── __init__.py
│       │   ├── active_sft_config.py
│       │   └── active_sft_trainer.py
│       └── models
│           ├── __init__.py
│           └── nets.py
├── .gitignore
├── assets
│   ├── algorithm.png
│   ├── figure1.png
│   ├── prmbench.png
│   └── processbench.png
├── pyproject.toml
├── requirements.txt
├── examples
│   ├── ds_config.json
│   ├── scripts
│   │   └── pool_based_active_learning.sh
│   └── py_scripts
│       └── code_snippets.py
├── setup.py
└── readme.md
/src/active_prm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .git
2 | .nvim
3 | .argo
4 | .DS_Store
5 |
6 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/bon_eval/policy_gen/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/base_dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_bon/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_stem/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmtest_classified/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/generate_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_bon_subset50/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmtest_classified_subset400/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/assets/algorithm.png
--------------------------------------------------------------------------------
/assets/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/assets/figure1.png
--------------------------------------------------------------------------------
/assets/prmbench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/assets/prmbench.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length=119
3 |
4 | [tool.ruff]
5 | line-length=119
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | deepspeed
3 | transformers
4 | datasets
5 | trl>=0.15.2
6 | vllm
7 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/model_inference/gemini_api/gemini_inference.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/processbench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/assets/processbench.png
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_new import prompt_dict
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/__main__.py:
--------------------------------------------------------------------------------
1 | from .mr_eval.__main__ import main
2 |
3 | if __name__ == "__main__":
4 | main()
5 |
--------------------------------------------------------------------------------
/src/active_prm/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .active_sft_config import ActiveSFTConfig
2 | from .active_sft_trainer import ActiveSFTTrainer
3 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/docs/assets/main_fig.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/docs/assets/main_fig.pdf
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/docs/assets/main_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/docs/assets/main_logo.png
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/bon_eval/policy_gen/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 |
4 | class BonDataset(Dataset):
5 | pass
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/step_acc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/step_acc.pdf
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/correlation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/correlation.pdf
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/llava_radar.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/llava_radar.pdf
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_stem/config.yaml:
--------------------------------------------------------------------------------
1 | task_name: "prmbench_stem"
2 | dataset_path: data
3 | dataset_type: "dir_of_jsonl"
4 | generation_config:
5 | max_length: 2048
6 | temperature: 0
7 |
8 |
9 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmtest_classified/config.yaml:
--------------------------------------------------------------------------------
1 | task_name: "prmtest_classified"
2 | dataset_path: data
3 | dataset_type: "dir_of_jsonl"
4 | generation_config:
5 | max_length: 2048
6 | temperature: 0
7 |
8 |
9 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/error_position_distribution.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/ActivePRM/HEAD/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_figure/res/error_position_distribution.pdf
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/cpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: "NO"
3 | mixed_precision: "no"
4 | num_processes: 1
5 | machine_rank: 0
6 | main_training_function: main
7 | dynamo_backend: "no"
8 | num_machines: 1
9 | use_cpu: true
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_bon/config.yaml:
--------------------------------------------------------------------------------
1 | task_name: "prmbench_bon"
2 | dataset_path: /mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/mr_eval/bon_eval/scripts/logs/fix_inference_full/eval_res
3 | dataset_type: "dir_of_jsonl"
4 | generation_config:
5 | max_length: 2048
6 | temperature: 0
7 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmbench_bon_subset50/config.yaml:
--------------------------------------------------------------------------------
1 | task_name: "prmbench_bon_subset50"
2 | dataset_path: /mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/mr_eval/bon_eval/scripts/logs/inference_res_subset50
3 | dataset_type: "dir_of_jsonl"
4 | generation_config:
5 | max_length: 2048
6 | temperature: 0
7 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/model_inference/qwq/inferencer/__init__.py:
--------------------------------------------------------------------------------
1 | from .qwq_inferencer_dataset import *
2 | from .qwq_inferencer import *
3 |
4 | inferencer_type_dict = dict(
5 | generate_prm=dict(model=QwQGeneratePRMInferencer, dataset=QwQGeneratePRMDataset),
6 | parallel_generate_prm=dict(model=QwQParallelGeneratePRMInferencer, dataset=QwQGeneratePRMDataset),
7 | )
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/prmtest_classified_subset400/config.yaml:
--------------------------------------------------------------------------------
1 | task_name: "prmtest_classified_subset400"
2 | dataset_path: data
3 | dataset_type: "dir_of_jsonl"
4 | generation_config:
5 | max_length: 2048
6 |
7 | # test_split: validation
8 | # metric_list:
9 | # - metric: hard_acc
10 | # aggregation: mean
11 | # higher_is_better: true
12 | # include: _default_template_docvqa_yaml
13 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/direct_run.sh:
--------------------------------------------------------------------------------
1 | accelerate_config=./mr_eval/scripts/examples/accelerate_configs/1gpu.yaml
2 |
3 | accelerate launch --config_file ${accelerate_config} \
4 | -m mr_eval \
5 | --model pure_prm \
6 | --batch_size 2 \
7 | --num_workers 2 \
8 | --task_name prmtest_classified \
9 | --verbosity INFO \
10 | --output_path ./mr_eval/scripts/logs/prmtest_classified/pure_prm_7b.jsonl
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/1gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: 'NO'
3 | downcast_bf16: 'no'
4 | gpu_ids: all
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: bf16
8 | num_machines: 1
9 | num_processes: 1
10 | rdzv_backend: static
11 | same_network: true
12 | tpu_env: []
13 | tpu_use_cluster: false
14 | tpu_use_sudo: false
15 | use_cpu: false
16 | main_process_port: 39678
17 |
18 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/4gpus.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | downcast_bf16: 'no'
4 | gpu_ids: all
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: bf16
8 | num_machines: 1
9 | num_processes: 4
10 | rdzv_backend: static
11 | same_network: true
12 | tpu_env: []
13 | tpu_use_cluster: false
14 | tpu_use_sudo: false
15 | use_cpu: false
16 | main_process_port: 39678
17 |
18 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/example_configs/initial_test_gpt4o.yaml:
--------------------------------------------------------------------------------
1 | model_args:
2 | model: openai_models
3 | model_args: model_name=gpt-4o,endpoint=https://api.openai.com/v1,api_key=your-api-key,log_save_dir=./mr_eval/scripts/logs/generated/model_generate_logs/gpt4o.jsonl
4 | batch_size: 1 # Must be 1 when calling APIs
5 | task_args:
6 | task_name: prmtest_classified
7 | script_args:
8 | verbosity: INFO
9 | output_path: ./mr_eval/scripts/logs/prmtest_classified/gpt4o.jsonl
--------------------------------------------------------------------------------
/examples/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "offload_optimizer": {
5 | "device": "cpu",
6 | "pin_memory": true
7 | },
8 | "offload_param": {
9 | "device": "cpu",
10 | "pin_memory": true
11 | },
12 | "allgather_partitions": true,
13 | "allgather_bucket_size": 2e8,
14 | "overlap_comm": true,
15 | "reduce_scatter": true,
16 | "reduce_bucket_size": 2e8,
17 | "contiguous_gradients": true
18 | },
19 | "train_micro_batch_size_per_gpu": "auto",
20 | "train_batch_size": "auto",
21 | "gradient_accumulation_steps": "auto",
22 | "gradient_clipping": "auto",
23 | "wall_clock_breakdown": false,
24 | "bf16": {
25 | "enabled": "auto"
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/zero3_inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": false
4 | },
5 | "bf16": {
6 | "enabled": true
7 | },
8 | "zero_optimization": {
9 | "stage": 3,
10 | "offload_param": {
11 | "device": "none"
12 | },
13 | "overlap_comm": true,
14 | "contiguous_gradients": false,
15 | "reduce_scatter": true
16 | },
17 | "activation_checkpointing": {
18 | "partition_activations": false
19 | },
20 | "aio": {
21 | "block_size": 1048576,
22 | "queue_depth": 8,
23 | "single_submit": false,
24 | "overlap_events": true,
25 | "thread_count": 2
26 | },
27 | "wall_clock_breakdown": false
28 | }
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/api_eval.sh:
--------------------------------------------------------------------------------
1 | source ~/.bashrc
2 | source ~/anaconda3/bin/activate smoe
3 |
4 | ## Change Dir to PRMBench
5 |
6 | # environment variables
7 | new_proxy_address="your_http_proxy_address"
8 | export http_proxy=$new_proxy_address
9 | export https_proxy=$new_proxy_address
10 | export HTTP_PROXY=$new_proxy_address
11 | export HTTPS_PROXY=$new_proxy_address
12 |
13 | accelerate_config=./mr_eval/scripts/examples/accelerate_configs/cpu.yaml
14 | config_file=$1
15 |
16 |
17 | gpus=0
18 | cpus=16
19 | quotatype="reserved"
20 | OMP_NUM_THREADS=4 srun --partition=MoE --job-name="eval" --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=${quotatype} \
21 | accelerate launch --config_file ${accelerate_config} \
22 | -m mr_eval --config ${config_file}
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
5 | with open("requirements.txt", "r") as requirements:
6 | setup(
7 | name="active_prm",
8 | version="1.0.0",
9 | install_requires=list(requirements.read().splitlines()),
10 | # packages=find_packages(),
11 | packages=find_packages(where="src"),
12 | package_dir={"": "src"},
13 | description="library for o1 redundancy",
14 | python_requires=">=3.11",
15 | author="Keyu Duan",
16 | author_email="k.duan@sea.com",
17 | classifiers=[
18 | "Programming Language :: Python :: 3",
19 | "License :: OSI Approved :: MIT License",
20 | "Operating System :: OS Independent",
21 | ],
22 | long_description=None,
23 | long_description_content_type="text/markdown",
24 | )
25 |
--------------------------------------------------------------------------------
/src/active_prm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 | from .qwen2_ensemble_prm.configuration_qwen2 import QwenEnPRMConfig
4 | from .qwen2_ensemble_prm.modeling_qwen2 import Qwen2ForEnsemblePRM
5 |
6 |
7 | class AutoModelForEnsemblePRM:
8 | @classmethod
9 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
10 | # if "qwen" in pretrained_model_name_or_path.lower():
11 | # return Qwen2ForEnsemblePRM.from_pretrained(pretrained_model_name_or_path, **kwargs)
12 | # else:
13 | # raise ValueError(f"Model {pretrained_model_name_or_path} not supported")
14 |
15 | # config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
16 | # return Qwen2ForEnsemblePRM.from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)
17 | return Qwen2ForEnsemblePRM.from_pretrained(pretrained_model_name_or_path, **kwargs)
18 |
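19 | # Usage sketch (illustrative; the checkpoint name is taken from the readme):
20 | # model = AutoModelForEnsemblePRM.from_pretrained("sail/ActPRM")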
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='mr_toolkit',  # project name
5 | version='0.1.0',  # project version
6 | description='A short description of your project',  # short project description
7 | author='Mingyang Song',  # author name
8 | author_email='mysong23@m.fudan.edu.cn',  # author email
9 | url='https://github.com/yourusername/your_project',  # project homepage URL
10 | packages=find_packages(),  # automatically discover and include all packages in the project
11 | install_requires=[  # list of project dependencies
12 | 'numpy',
13 | 'requests',
14 | ],
15 | classifiers=[  # classifier tags (improve discoverability)
16 | 'Programming Language :: Python :: 3',
17 | 'Programming Language :: Python :: 3.10',
18 | 'License :: OSI Approved :: MIT License',  # choose an appropriate license
19 | 'Operating System :: OS Independent',
20 | ],
21 | python_requires='>=3.10',  # supported Python versions
22 | include_package_data=True,  # include non-Python files
23 | long_description=open('README.md').read(),  # read the long description from the README
24 | long_description_content_type='text/markdown',  # format of the long description
25 | )
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | classifications = ["circular", "confidence", "counterfactual", "step_contradiction", "domain_inconsistency", "redundency", "missing_condition", "deception"]
4 | few_shots_names = ["fewshot_q1", "fewshot_a1", "fewshot_q2", "fewshot_a2"]
5 | __all__ = []
6 |
7 | for classification in classifications:
8 | for few_shot_name in few_shots_names:
9 | object_name = f"{classification}_{few_shot_name}"
10 | module = importlib.import_module(f".{classification}", package=__name__)  # dynamically import the module
11 | fs_obj = getattr(module, object_name)
12 | globals()[object_name] = fs_obj
13 | __all__.append(object_name)
14 |
15 |
16 | fewshot_dicts = {classification: [(globals()[f"{classification}_fewshot_q1"], globals()[f"{classification}_fewshot_a1"]),
17 | (globals()[f"{classification}_fewshot_q2"], globals()[f"{classification}_fewshot_a2"])] for classification in classifications}
18 | __all__.append("fewshot_dicts")
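19 | # e.g. fewshot_dicts["confidence"] == [(confidence_fewshot_q1, confidence_fewshot_a1), (confidence_fewshot_q2, confidence_fewshot_a2)]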
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/zero3_offload_inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 1,
3 | "gradient_accumulation_steps": 1,
4 | "fp16": {
5 | "enabled": false
6 | },
7 | "bf16": {
8 | "enabled": true
9 | },
10 | "zero_optimization": {
11 | "stage": 3,
12 | "offload_param": {
13 | "device": "cpu",
14 | "pin_memory": true
15 | },
16 | "overlap_comm": true,
17 | "contiguous_gradients": false,
18 | "reduce_scatter": true,
19 | "reduce_bucket_size": 500000000,
20 | "stage3_param_persistence_threshold": 100000
21 | },
22 | "activation_checkpointing": {
23 | "partition_activations": false
24 | },
25 | "aio": {
26 | "block_size": 1048576,
27 | "queue_depth": 8,
28 | "single_submit": false,
29 | "overlap_events": true,
30 | "thread_count": 2
31 | },
32 | "gradient_clipping": 1.0,
33 | "steps_per_print": 2000,
34 | "wall_clock_breakdown": false
35 | }
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/local_multi_gpu_eval.sh:
--------------------------------------------------------------------------------
1 | source ~/.bashrc
2 | source ~/anaconda3/bin/activate reasoneval
3 |
4 | new_proxy_address=your_http_proxy_address
5 | export http_proxy=$new_proxy_address
6 | export https_proxy=$new_proxy_address
7 | export HTTP_PROXY=$new_proxy_address
8 | export HTTPS_PROXY=$new_proxy_address
9 |
10 |
11 | job_id=3962886
12 | export SLURM_JOB_ID=${job_id}
13 |
14 |
15 | accelerate_config=./mr_eval/scripts/examples/accelerate_configs/4gpus.yaml
16 | config_file=$1
17 |
18 | gpus=4
19 | cpus=32
20 | quotatype="reserved"
21 | OMP_NUM_THREADS=4 srun --partition=MoE --jobid=${job_id} --job-name="eval" --mpi=pmi2 --gres=gpu:${gpus} -n1 --ntasks-per-node=1 -c ${cpus} --kill-on-bad-exit=1 --quotatype=${quotatype} \
22 | accelerate launch --config_file ${accelerate_config} \
23 | -m mr_eval --config ${config_file}
24 |
25 |
26 | # salloc --partition=MoE --job-name="eval" --gres=gpu:8 -n1 --ntasks-per-node=1 -c 64 --quotatype="reserved"
27 | # salloc --partition=MoE --job-name="interact" --gres=gpu:1 -n1 --ntasks-per-node=1 -c 16 --quotatype="reserved"
28 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/scripts/examples/accelerate_configs/4gpus_deepspeed.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_config_file: /mnt/petrelfs/songmingyang/.config/accelerate/zero3_inference.json
4 | # deepspeed_config:
5 | # fp16:
6 | # enabled: false
7 | # bf16:
8 | # enabled: true
9 | # zero_optimization:
10 | # stage: 2
11 | # offload_param:
12 | # device: cpu
13 | # pin_memory: true
14 | # overlap_comm: true
15 | # contiguous_gradients: false,
16 | # reduce_scatter: true
17 | # activation_checkpointing:
18 | # partition_activations: false
19 | # aio:
20 | # block_size: 1048576
21 | # queue_depth: 8
22 | # overlap_events: true
23 | # single_submit: false
24 | # thread_count: 2
25 | # wall_clock_breakdown: false
26 | # train_micro_batch_size_per_gpu: 4
27 | distributed_type: DEEPSPEED
28 | downcast_bf16: 'no'
29 | machine_rank: 0
30 | main_training_function: main
31 | main_process_port: 29501
32 | num_machines: 1
33 | num_processes: 4
34 | rdzv_backend: static
35 | same_network: true
36 | tpu_env: []
37 | tpu_use_cluster: false
38 | tpu_use_sudo: false
39 | use_cpu: false
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import sys
4 |
5 | from loguru import logger
6 |
7 | logger.remove()
8 | logger.add(sys.stdout, level="WARNING")
9 |
10 | AVAILABLE_MODELS = {
11 | "reasoneval": "ReasonEval",
12 | "math_shepherd": "MathShepherd",
13 | "llemma7b_prm": "Llemma7bPRM",
14 | "mathminos_mistral": "MathMinos_Mistral",
15 | "openai_models": "OpenaiModels",
16 | "llama3_1_8b_prm": "LLaMA318BPRM",
17 | "skywork_prm": "SkyworkPRM",
18 | "gemini_models": "GeminiModels",
19 | "qwen_qwq": "QwenQwQ",
20 | "qwen_prm": "QwenPRM",
21 | "vllm_models": "VllmModels",
22 | "pure_prm": "PUREPRM",
23 | "ensemble_prm": "EnPRM",
24 | }
25 |
26 |
27 | def get_model(model_name):
28 | if model_name not in AVAILABLE_MODELS:
29 | raise ValueError(f"Model {model_name} not found in available models.")
30 |
31 | model_class = AVAILABLE_MODELS[model_name]
32 | try:
33 | module = __import__(f"active_prm.eval.PRMBench.mr_eval.models.{model_name}", fromlist=[model_class])
34 | return getattr(module, model_class)
35 | except Exception as e:
36 | logger.error(f"Failed to import {model_class} from {model_name}: {e}")
37 | raise
38 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import logging, os
2 | from accelerate import Accelerator
3 | from accelerate.logging import get_logger as get_accelerator_logger
4 |
5 | current_file_path = os.path.dirname(os.path.abspath(__file__))
6 | default_log_dir = f"{current_file_path}/../scripts/logs/generated"
7 |
8 | def get_logger(name, level=logging.INFO, log_dir=default_log_dir):
9 | try:
10 | logger = get_accelerator_logger(name)
11 | except:
12 | print("Accelerator is not available, using the default logger.")
13 | logger = logging.getLogger(name)
14 | logger.setLevel(level)
15 | # create the console handler
16 | console_handler = logging.StreamHandler()
17 | console_handler.setLevel(level)
18 | # create the file handler
19 | file_handler = logging.FileHandler(f"{log_dir}/{name}.log", mode='w')
20 | file_handler.setLevel(logging.DEBUG)
21 |
22 |
23 |
24 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s',
25 | datefmt="%Y-%m-%d %H:%M:%S")
26 | console_handler.setFormatter(formatter)
27 | file_handler.setFormatter(formatter)
28 | try:
29 | logger.logger.addHandler(console_handler)
30 | logger.logger.addHandler(file_handler)
31 | except:
32 | logger.addHandler(console_handler)
33 | logger.addHandler(file_handler)
34 |
35 | return logger
36 |
37 |
38 |
--------------------------------------------------------------------------------
/src/active_prm/trainer/active_sft_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from trl import SFTConfig
18 |
19 |
20 | @dataclass
21 | class ActiveSFTConfig(SFTConfig):
22 | # parameters for active learning
23 | active_learning_pred_threshold: float = field(
24 | default=0.9, metadata={"help": "Prediction threshold for active learning"}
25 | )
26 | active_learning_std_threshold: float = field(
27 | default=0.01,
28 | metadata={"help": "Std threshold for active learning"},
29 | )
30 | active_learning_warmup_steps: int = field(default=-1, metadata={"help": "Warmup steps for active learning"})
31 | random_selection_threshold: float = field(
32 | default=0.5, metadata={"help": "Random selection threshold as a baseline"}
33 | )
34 |
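35 | # Usage sketch (illustrative; `model` and `train_dataset` are placeholders, not part of this file):
36 | # from active_prm.trainer import ActiveSFTConfig, ActiveSFTTrainer
37 | # config = ActiveSFTConfig(
38 | #     output_dir="out/actprm",
39 | #     active_learning_pred_threshold=0.9,
40 | #     active_learning_std_threshold=0.01,
41 | # )
42 | # trainer = ActiveSFTTrainer(model=model, args=config, train_dataset=train_dataset)
43 | # trainer.train()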
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/task_utils.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from box import Box
3 |
4 | def load_task_config(file_path):
5 | task_config = load_yaml_file(file_path)
6 | task_config = Box(task_config)
7 | return task_config
8 |
9 |
10 |
11 | def load_data_function_default(task_config):
12 | return load_jsonl_data_function_default(task_config)
13 |
14 | def load_jsonl_data_function_default(task_config):
15 | task_name = task_config["task_name"]
16 | dataset_type = task_config["dataset_type"]
17 | if dataset_type == "jsonl":
18 | dataset_path = task_config["dataset_path"]
19 | meta_data = process_jsonl(dataset_path)
20 | elif dataset_type == "json":
21 | dataset_path = task_config["dataset_path"]
22 | meta_data = load_json_file(dataset_path)
23 | else:
24 | raise ValueError(f"dataset_type {dataset_type} not supported")
25 | return meta_data
26 |
27 | def load_dir_of_jsonl_data_function_default(task_config):
28 | task_name = task_config["task_name"]
29 | dataset_type = task_config["dataset_type"]
30 | dataset_path = task_config["dataset_path"]
31 | assert dataset_type == "dir_of_jsonl"
32 | assert os.path.isdir(dataset_path)
33 | files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if f.endswith(".jsonl")]
34 | meta_data = []
35 | for file in files:
36 | meta_data.extend(process_jsonl(file))
37 | return meta_data
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import inspect
3 | import logging
4 | import os
5 | import sys
6 | from functools import partial
7 | from typing import Dict, List, Mapping, Optional, Union
8 |
9 | from loguru import logger
10 |
11 | logger.remove()
12 | logger.add(sys.stdout, level="WARNING")
13 |
14 | # AVAILABLE_MODELS = {
15 | # "reasoneval": "ReasonEval",
16 | # "math_shepherd": "MathShepherd",
17 | # }
18 |
19 |
20 | def get_task_object(task_name, object_name):
21 | try:
22 | module = __import__(f"mr_eval.tasks.{task_name}.task", fromlist=[object_name])
23 | return getattr(module, object_name)
24 | except Exception as e:
25 | logger.error(f"Failed to import {object_name} from {task_name}: {e}")
26 | raise
27 |
28 |
29 | def get_task_functions(task_name):
30 | """
31 | return a dictionary of functions from the task module
32 | {
33 | "load_data_function": load_data_function,
34 | "evaluate_function": evaluate_function,
35 | "task_config": task_config
36 | }
37 | """
38 | function_list = ["load_data_function", "evaluate_function", "task_config"]
39 | try:
40 | module = __import__(f"active_prm.eval.PRMBench.mr_eval.tasks.{task_name}.task", fromlist=["*"])
41 | res_dict = {func: getattr(module, func) for func in function_list}
42 | return res_dict
43 |
44 | except Exception as e:
45 | logger.error(f"Failed to import all functions from {task_name}: {e}")
46 | raise
47 |
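48 | # Example (sketch): fetch the registered objects for a task
49 | # funcs = get_task_functions("prmtest_classified")
50 | # data = funcs["load_data_function"](funcs["task_config"])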
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/abstract_model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import hashlib
3 | import json
4 | import os
5 | from typing import List, Optional, Tuple, Type, TypeVar, Union
6 |
7 | from loguru import logger as eval_logger
8 | from tqdm import tqdm
9 |
10 |
11 | T = TypeVar("T", bound="prm")
12 |
13 |
14 | class prm(abc.ABC):
15 | def __init__(
16 | self,
17 | redundancy_threshold = 0.15,
18 | validity_threshold = 0.5,
19 | generation_config = {},
20 | ) -> None:
21 | self.redundancy_threshold = float(redundancy_threshold)
22 | self.validity_threshold = float(validity_threshold)
23 | self.set_generation_config(generation_config)
24 |
25 | def to(self, device: str) -> T:
26 | self.model = self.model.to(device)
27 | return self
28 |
29 | @abc.abstractmethod
30 | def respond(self, dataloader) -> None:
31 | pass
32 |
33 | def set_generation_config(self, generation_configs: dict) -> None:
34 | self.generation_config = generation_configs
35 | self.generation_config["max_length"] = generation_configs.get("max_length", 512)
36 | self.generation_config["temperature"] = generation_configs.get("temperature", 0.0)
37 | self.generation_config["top_k"] = generation_configs.get("top_k", 1)
38 | self.generation_config["top_p"] = generation_configs.get("top_p", 1.0)
39 |
40 |
41 |
42 | def get_generation_config(self) -> dict:
43 | try:
44 | return self.generation_config
45 | except:
46 | return {}
47 |
48 |
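49 | # Subclass sketch (illustrative): a concrete PRM only needs to implement `respond`,
50 | # which consumes the evaluation dataloader and records step-level scores.
51 | # class MyPRM(prm):
52 | #     def respond(self, dataloader) -> None:
53 | #         for batch in dataloader:
54 | #             ...  # score each step; compare against validity/redundancy thresholds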
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/model_inference/qwq/run_inference.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, List, Optional, Sequence
3 |
4 | import requests
5 | import torch
6 | import tqdm
7 | from accelerate import Accelerator, PartialState
8 | from inferencer import inferencer_type_dict
9 | from PIL import Image
10 | from torch.utils.data import DataLoader, Dataset
11 | from tqdm import tqdm
12 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
13 |
14 | if __name__ == "__main__":
15 |
16 | @dataclass
17 | class DataArguments:
18 | input_path: List[str] = field(default_factory=list)
19 | output_path: str = field(default=None)
20 | batch_size: int = field(default=8)
21 | num_workers: int = field(default=2)
22 |
23 | @dataclass
24 | class ModelArguments:
25 | model_path: str = field(
26 | default="/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/policy_models/QwQ-32B-Preview"
27 | )
28 |
29 | @dataclass
30 | class InferenceArguments:
31 | function: str = field(default="generate_prm")
32 | parallel_mode: str = field(default="1gpu")
33 | float_type: str = field(default="float16")
34 |
35 | def __post_init__(self):
36 | self.float_type = getattr(torch, self.float_type)
37 |
38 | parser = HfArgumentParser((InferenceArguments, ModelArguments, DataArguments))
39 |
40 | inference_args, model_args, data_args = parser.parse_args_into_dataclasses()
41 | data_args.input_path = data_args.input_path[0]
42 | function_name = inference_args.function
43 | inference_module = inferencer_type_dict[function_name]["model"](
44 | inference_args=inference_args, model_args=model_args, data_args=data_args
45 | )
46 |
47 | inference_module.inference()
48 |
49 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/generate_data/dataset/generate_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | from mr_eval.utils.utils import *
4 | # Question:
5 |
6 | # If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
7 |
8 | # Answer:
9 |
10 | # Step 1. Let's start with the first equation and see if we can solve for x.
11 |
12 | # Step 2. We can use the quadratic formula to find x.
13 |
14 | # Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
15 |
16 | # Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
17 |
18 | # Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
19 |
20 | # Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
21 |
22 | # Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
23 |
24 | # Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
25 |
26 | # Step 9. Simplify: $x^3 - 2x + 1 = 2$."
27 |
28 | # # Answer
29 |
30 | # 2
31 |
32 | class GenerateDataset(Dataset):
33 | def __init__(self, data_args):
34 | self.data_args = data_args
35 | self.input_path = data_args.input_path
36 | self.output_path = data_args.output_path
37 | self.subject = data_args.subject
38 |
39 | self.load_data()
40 | self.resume_from_ckpt()
41 |
42 | def load_data(self):
43 | raw_data = process_jsonl(self.input_path)
44 | self.meta_data = []
45 | for idx, item in enumerate(raw_data):
46 | item_idx = item["idx"]
47 | question = item["Question"]
48 | options = item["Options"]
49 | question = f"{question} {options}"
50 | steps = item["Model_Solution_Steps"]
51 | for step_idx, step in enumerate(steps):
52 | step_text = step["text"]
53 | question = f"{question} {step_text}"
54 | # collect the flattened record so __len__/__getitem__ work
55 | # (field names are assumed from the loading code above)
56 | self.meta_data.append(dict(idx=item_idx, question=question))
57 |
58 | def resume_from_ckpt(self):
59 | pass
60 |
61 | def __len__(self):
62 | return len(self.meta_data)
63 |
64 | def __getitem__(self, idx):
65 | return self.meta_data[idx]
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Efficient Process Reward Model Training via Active Learning
4 |
5 | The official implementation of the paper "Efficient Process Reward Model Training via Active Learning".
6 |
7 | [](http://arxiv.org/abs/2504.10559) [](https://huggingface.co/collections/sail/active-prm-67fccbc10c94c41a6d6b26d9)
8 |
9 |
10 |
11 | ## 🔥 Updates
12 |
13 | - 16/04/2025: Our paper is available on [arxiv](http://arxiv.org/abs/2504.10559) now!
14 | - 14/04/2025: We release our code, models, and data. The paper will be available soon.
15 | - 14/04/2025: Within 7B PRMs, our model `sail/ActPRM-X` (based on `Qwen/Qwen2.5-Math-PRM-7B`) achieved new SOTA performance on ProcessBench (76.0%) and PRMBench (66.7%).
16 |
17 | ## 🏴 Overview
18 |
19 | **TL;DR: We achieved SOTA performance on [ProcessBench](https://github.com/QwenLM/ProcessBench) (75.0%) and [PRMBench](https://github.com/ssmisya/PRMBench) (65.5%) with merely 5% of the labeling cost of `Qwen/Qwen2.5-Math-PRM-7B`.**
20 |
21 |
22 | ![Method overview](assets/figure1.png)
23 | ![Active learning algorithm](assets/algorithm.png)
24 |
25 | ## 📊 Results
26 |
27 | ProcessBench
28 |
29 | ![ProcessBench results](assets/processbench.png)
30 |
31 | PRMBench
32 |
33 | ![PRMBench results](assets/prmbench.png)
34 |
35 |
36 | ## ⚡️ Quickstart
37 |
38 | ### Installation
39 | ```shell
40 | git clone https://github.com/sail-sg/ActivePRM.git
41 | cd ActivePRM
42 | pip install -e . # tested in conda env where python==3.11
43 | ```
44 |
45 | ### Replication
46 |
47 | - Evaluate our `sail/ActPRM-X` and `sail/ActPRM` on ProcessBench simply by running
48 | ```shell
49 | cd examples
50 | python py_scripts/test_actprm_on_processbench.py
51 | ```
52 |
53 | - Training PRM with Active Learning
54 | ```shell
55 | cd examples
56 | bash scripts/pool_based_active_learning.sh sail/ActPRMData
57 | ```
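58 |
59 | - (Optional) Evaluate on PRMBench with the bundled `mr_eval` toolkit. A minimal sketch following `src/active_prm/eval/PRMBench/mr_eval/scripts/examples/direct_run.sh` (the model name and output path here are illustrative):
60 | ```shell
61 | cd src/active_prm/eval/PRMBench
62 | accelerate launch --config_file ./mr_eval/scripts/examples/accelerate_configs/1gpu.yaml \
63 |     -m mr_eval \
64 |     --model ensemble_prm \
65 |     --task_name prmtest_classified \
66 |     --output_path ./mr_eval/scripts/logs/prmtest_classified/ensemble_prm.jsonl
67 | ```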
58 |
59 | ## Citation
60 | If you find our repo or paper helpful, please cite
61 | ```
62 | @misc{duan2025actprm,
63 | title={Efficient Process Reward Model Training via Active Learning},
64 | author={Keyu Duan and Zichen Liu and Xin Mao and Tianyu Pang and Changyu Chen and Qiguang Chen and Michael Qizhe Shieh and Longxu Dou},
65 | year={2025},
66 | eprint={2504.10559},
67 | archivePrefix={arXiv},
68 | primaryClass={cs.LG},
69 | url={https://arxiv.org/abs/2504.10559},
70 | }
71 | ```
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import json
4 | import logging
5 | import os
6 | import sys
7 | from functools import partial
8 | from typing import Union
9 | import torch
10 | from accelerate import Accelerator
11 |
12 | from .evaluator import MREvaluator
13 | from .utils.arguments import *
14 | from .utils.log_utils import get_logger
15 | from .utils.utils import *
16 |
17 | logger = get_logger(__name__)
18 | import os
19 |
20 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
21 |
22 |
23 | def main():
24 | accelerator = Accelerator()  # Necessary for logging scripts
25 | args_dict = parse_args()
26 | model_args, task_args, script_args = args_dict["model_args"], args_dict["task_args"], args_dict["script_args"]
27 |
28 | config = script_args.config
29 | if config:
30 | if not isinstance(config, list):
31 | config = [config]
32 |
33 | for idx, config_item in enumerate(config):
34 | if model_args.model != "vllm_models":
35 | dist_wait_for_everyone()
36 | gc.collect()
37 | torch.cuda.empty_cache()
38 | logger.info(f"After Cleaning: Memory Allocated: {torch.cuda.memory_allocated()/(1024 ** 3) :.2f} GB")
39 | logger.info(f"Begin evaluating on the No. {idx+1} config, toal {len(config)} configs.")
40 | if isinstance(config_item, dict):
41 | model_args = ModelArguments(**config_item["model_args"])
42 | task_args = TaskArguments(**config_item["task_args"])
43 | script_args = ScriptArguments(**config_item["script_args"])
44 |
45 | task_args.task_name = parse_str_into_list(task_args.task_name)
46 | if isinstance(model_args.model_args, str):
47 | model_args.model_args = parse_str_into_dict(model_args.model_args)
48 | if isinstance(script_args.wandb_args, str):
49 | script_args.wandb_args = parse_str_into_dict(script_args.wandb_args)
50 | else:
51 | assert len(config) == 1, "If config is not a list, it should be a dictionary or NoneType"
52 | raise ValueError("Config should be a list of dictionaries.")
53 |
54 | evaluator = MREvaluator(model_args, task_args, script_args)
55 | evaluator.evaluate()
56 | del evaluator
57 | logger.info(f"Finished evaluating on the No. {idx+1} config, toal {len(config)} configs.")
58 | else:
59 | evaluator = MREvaluator(model_args, task_args, script_args)
60 | evaluator.evaluate()
61 | del evaluator
62 | logger.info("Finished evaluating on the single config.")
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.1.1
2 | aiofiles==23.2.1
3 | altair==5.4.1
4 | annotated-types==0.7.0
5 | anyio==4.6.2.post1
6 | asttokens==2.4.1
7 | attrs==24.2.0
8 | bitsandbytes==0.42.0
9 | Brotli
10 | certifi
11 | charset-normalizer
12 | click==8.1.7
13 | comm==0.2.2
14 | contourpy==1.3.0
15 | cycler==0.12.1
16 | debugpy==1.8.7
17 | decorator==5.1.1
18 | distro==1.9.0
19 | einops==0.6.1
20 | einops-exts==0.0.4
21 | exceptiongroup==1.2.2
22 | executing==2.1.0
23 | fastapi==0.115.4
24 | ffmpy==0.4.0
25 | filelock
26 | fonttools==4.54.1
27 | fsspec==2024.10.0
28 | gmpy2
29 | gradio==4.16.0
30 | gradio_client==0.8.1
31 | h11==0.14.0
32 | httpcore==0.17.3
33 | httpx==0.24.0
34 | huggingface-hub==0.26.2
35 | idna
36 | importlib_resources==6.4.5
37 | ipykernel==6.29.5
38 | ipython==8.29.0
39 | jedi==0.19.1
40 | Jinja2
41 | jiter==0.8.0
42 | joblib==1.4.2
43 | jsonschema==4.23.0
44 | jsonschema-specifications==2024.10.1
45 | jupyter_client==8.6.3
46 | jupyter_core==5.7.2
47 | kiwisolver==1.4.7
48 | latex2mathml==3.77.0
49 | loguru==0.7.2
50 | markdown-it-py==3.0.0
51 | markdown2==2.5.1
52 | MarkupSafe
53 | matplotlib==3.9.2
54 | matplotlib-inline==0.1.7
55 | mdurl==0.1.2
56 | mkl-service==2.4.0
57 | mkl_fft
58 | mkl_random
59 | mpmath
60 | narwhals==1.13.1
61 | nest-asyncio==1.6.0
62 | networkx
63 | numpy==1.26.4
64 | openai==1.55.3
65 | orjson==3.10.11
66 | packaging==24.1
67 | pandas==2.2.3
68 | parso==0.8.4
69 | pexpect==4.9.0
70 | pillow
71 | platformdirs==4.3.6
72 | prompt_toolkit==3.0.48
73 | protobuf==5.28.3
74 | psutil==6.1.0
75 | ptyprocess==0.7.0
76 | pure_eval==0.2.3
77 | pydantic==2.9.2
78 | pydantic_core==2.23.4
79 | pydub==0.25.1
80 | Pygments==2.18.0
81 | pyparsing==3.2.0
82 | PySocks
83 | python-box==7.2.0
84 | python-dateutil==2.9.0.post0
85 | python-multipart==0.0.17
86 | pytz==2024.2
87 | PyYAML
88 | pyzmq==26.2.0
89 | referencing==0.35.1
90 | regex==2024.9.11
91 | requests
92 | rich==13.9.4
93 | rpds-py==0.20.1
94 | ruff==0.7.2
95 | safetensors==0.4.5
96 | scikit-learn==1.2.2
97 | scipy==1.14.1
98 | semantic-version==2.10.0
99 | sentencepiece==0.2.0
100 | shellingham==1.5.4
101 | six==1.16.0
102 | sniffio==1.3.1
103 | stack-data==0.6.3
104 | starlette==0.41.2
105 | svgwrite==1.4.3
106 | sympy
107 | threadpoolctl==3.5.0
108 | timm==0.6.13
109 | tokenizers==0.20.1
110 | tomlkit==0.12.0
111 | torch==2.3.0
112 | torchaudio==2.3.0
113 | torchvision==0.18.0
114 | tornado==6.4.1
115 | tqdm==4.66.6
116 | traitlets==5.14.3
117 | transformers==4.46.0
118 | triton==2.3.0
119 | typer==0.12.5
120 | typing_extensions
121 | tzdata==2024.2
122 | urllib3
123 | uvicorn==0.32.0
124 | wavedrom==2.0.3.post3
125 | wcwidth==0.2.13
126 | websockets==11.0.3
127 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/docs/data_format.md:
--------------------------------------------------------------------------------
1 | # Data Format for $PRMBench$
2 |
3 | In $PRMBench$, our data format can be formulated as follows:
4 |
5 | ```json
6 | {
7 | // Original question
8 | "original_question": "Three pencils and a jumbo eraser cost $1.24. Five pencils and a jumbo eraser cost $1.82. No prices include tax. In cents, what is the cost of a pencil?",
9 | // Modified question --used for the evaluation
10 | "modified_question": "Three pencils and a jumbo eraser cost $1.24. Five pencils and a jumbo eraser cost $1.82. No prices include tax. In cents, what is the cost of a pencil?",
11 | // Original process -- the original solution steps
12 | "original_process": [
13 | "1. Let's call the price of a pencil p and the price of a jumbo eraser e. Then we can write two equations.",
14 | "2. We have $3p+e=1.24$ and $5p+e=1.82$.",
15 | "3. To solve this system, let's subtract the first equation from the second equation. This will eliminate e.",
16 | "4. $5p+e-3p-e=1.82-1.24$.",
17 | "5. This simplifies to $2p=0.58$. So $p=0.29$.",
18 | "6. That means a pencil costs 29 cents."
19 | ],
20 | // Modified process -- used for the evaluation
21 | "modified_process": [
22 | "1. Let's call the price of a pencil p and the price of a jumbo eraser e. Then we can write two equations.",
23 | "2. We have $3p+e=1.24$ and $5p+e=1.82$.",
24 | "3. Assume a pencil costs 29 cents. Then verify this against the original equations.",
25 | "4. If $p=0.29$, substitute into one of the equations, for example, $3p+e=1.24$. It gives $3(0.29)+e=1.24$.",
26 | "5. Solving $0.87+e=1.24$ gives $e=0.37$. Now check $5p+e=1.82$ to confirm.",
27 | "6. Plug $p=0.29$ and $e=0.37$ into $5p+e=1.82$. We get $5(0.29)+0.37=1.82$.",
28 | "7. The check confirms that the assumed price of 29 cents per pencil is correct."
29 | ],
30 | // Modified steps -- the steps that are modified
31 | "modified_steps": [3, 4, 5, 6],
32 | // Error steps -- the steps that contain errors
33 | "error_steps": [3, 4, 5],
34 | // Reason for the error
35 | "reason": "Steps 3, 4, and 5 introduce circular reasoning by assuming the result ($p = 0.29$) and verifying it rather than deriving it through independent calculations. This creates a fallacious reasoning process since the solution is assumed as a premise and then used to confirm itself.",
36 | // idx -- unique identifier for the data instance
37 | "idx": "circular_prm_test_p1_0",
38 | // question -- the original question
39 | "question": "Three pencils and a jumbo eraser cost $\\$1.24$. Five pencils and a jumbo eraser cost $\\$1.82$. No prices include tax. In cents, what is the cost of a pencil?",
40 | // classification -- the classification of the error
41 | "classification": "circular"
42 | }
43 | ```
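44 |
45 | For step-level evaluation, `error_steps` induces per-step binary labels. A minimal sketch (illustrative; the file name is a placeholder):
46 |
47 | ```python
48 | import json
49 |
50 | def step_labels(record):
51 |     # 1 marks a correct step, 0 an erroneous one; error_steps is 1-indexed
52 |     errors = set(record["error_steps"])
53 |     return [0 if i + 1 in errors else 1 for i in range(len(record["modified_process"]))]
54 |
55 | with open("prmbench_data.jsonl") as f:
56 |     for line in f:
57 |         record = json.loads(line)
58 |         print(record["idx"], step_labels(record))
59 | ```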
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/vis_res.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .mr_eval.utils.utils import *
4 |
5 | classification_name_dict = dict(
6 | domain_inconsistency="DC.",
7 | redundency="NR.",
8 | multi_solutions="MS.",
9 | deception="DR.",
10 | confidence="CI.",
11 | step_contradiction="SC.",
12 | circular="NCL.",
13 | missing_condition="PS.",
14 | counterfactual="ES.",
15 | )
16 | classification_parallel_dict = dict(
17 | simplicity=dict(
18 | redundency="NR.",
19 | circular="NCL.",
20 | ),
21 | soundness=dict(
22 | counterfactual="ES.",
23 | step_contradiction="SC.",
24 | domain_inconsistency="DC.",
25 | confidence="CI.",
26 | ),
27 | sensitivity=dict(
28 | missing_condition="PS.",
29 | deception="DR.",
30 | multi_solutions="MS.",
31 | ),
32 | )
33 | classifications = [
34 | "redundency",
35 | "circular",
36 | "counterfactual",
37 | "step_contradiction",
38 | "domain_inconsistency",
39 | "confidence",
40 | "missing_condition",
41 | "deception",
42 | "multi_solutions",
43 | ]
44 | metrics = [
45 | "f1",
46 | "negative_f1",
47 | "total_step_acc",
48 | "correct_step_acc",
49 | "wrong_step_acc",
50 | "first_error_acc",
51 | "similarity",
52 | ]
53 |
54 |
55 | # def main(data_file):
56 | # res_dict = {}
57 | # for model_name, file_path in file_dict.items
58 | def get_prmscore_from_current_res_dict(res_dict, classification=None):
59 | """
60 | Get PRM score from model level dict
61 | """
62 | if not classification:
63 | prm_score = (
64 | res_dict["total_hallucination_results"]["f1"] * 0.5
65 | + res_dict["total_hallucination_results"]["negative_f1"] * 0.5
66 | )
67 | else:
68 | if classification in ["multi_solutions"]:
69 | prm_score = res_dict["hallucination_type_results"]["f1"][classification]
70 | else:
71 | prm_score = (
72 | res_dict["hallucination_type_results"]["f1"][classification] * 0.5
73 | + res_dict["hallucination_type_results"]["negative_f1"][classification] * 0.5
74 | )
75 | return prm_score
76 |
77 |
78 | def main(file_path):
79 | # file_path = "./out/bench/prmbench/enprm_numinamath_ne_8_lr_2e-6/results.jsonl"
80 | res = process_jsonl(file_path)[-1]
81 | prm_score = get_prmscore_from_current_res_dict(res)
82 | print(f"Overall: {prm_score}:.3f")
83 |
84 | for big_classification, current_classification_dict in classification_parallel_dict.items():
85 | print(f"Big Classification: {big_classification}")
86 | avg = []
87 | for classification, prefix in current_classification_dict.items():
88 | prm_score = get_prmscore_from_current_res_dict(res, classification)
89 | print(f"{prefix}: {prm_score:.3f}")
90 | avg += [prm_score]
91 | print(f"Average: {sum(avg) / len(avg):.3f}")
92 |
93 |
94 | if __name__ == "__main__":
95 | import fire
96 |
97 | fire.Fire(main)
98 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_visualize/eval_res_view/draw_tabs/bias_table.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from mr_eval.utils.utils import *\n",
10 | "import os\n",
11 | "from copy import deepcopy\n",
12 | "import numpy as np\n",
13 | "\n",
14 | "data_dir = \"/mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/mr_eval/tasks/prmtest_classified/data\"\n",
15 | "dataset_type = \"dir_of_jsonl\"\n",
16 | "\n",
17 | "#domain_inconsistency\tredundency\tmulti_solutions\tdeception\tconfidence\tstep_contradiction\tcircular\tmissing_condition\tcounterfactual\n",
18 | "classification_name_dict = dict(\n",
19 | " domain_inconsistency=\"DC.\",\n",
20 | " redundency=\"NR.\",\n",
21 | " multi_solutions=\"MS.\",\n",
22 | " deception=\"DR.\",\n",
23 | " confidence=\"CI.\",\n",
24 | " step_contradiction=\"SC.\",\n",
25 | " circular=\"NCL.\",\n",
26 | " missing_condition=\"PS.\",\n",
27 | " counterfactual=\"ES.\"\n",
28 | ")\n",
29 | "classifications = [\"redundency\", \"circular\", \"counterfactual\", \"step_contradiction\", \"domain_inconsistency\", \"confidence\", \"missing_condition\", \"deception\", \"multi_solutions\", ]\n"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 5,
35 | "metadata": {},
36 | "outputs": [
37 | {
38 | "name": "stdout",
39 | "output_type": "stream",
40 | "text": [
41 | "bias: 0.20592848872394862\n"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "data_files = os.listdir(data_dir)\n",
47 | "data_files = [f for f in data_files if f.endswith(\".jsonl\")]\n",
48 | "raw_data = []\n",
49 | "for data_file in data_files:\n",
50 | " raw_data.extend(process_jsonl(os.path.join(data_dir, data_file)))\n",
51 | " \n",
52 | "bias = []\n",
53 | "for item in raw_data:\n",
54 | " error_length = len(item[\"error_steps\"])\n",
55 | " total_length = len(item[\"modified_process\"])\n",
56 | " if error_length <= total_length:\n",
57 | " accuracy = 1 - error_length / total_length\n",
58 | " accuracy_bias = error_length / total_length\n",
59 | " bias.append(accuracy_bias)\n",
60 | "\n",
61 | "print(\"bias: \", np.mean(bias))"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": []
70 | }
71 | ],
72 | "metadata": {
73 | "kernelspec": {
74 | "display_name": "smoe",
75 | "language": "python",
76 | "name": "python3"
77 | },
78 | "language_info": {
79 | "codemirror_mode": {
80 | "name": "ipython",
81 | "version": 3
82 | },
83 | "file_extension": ".py",
84 | "mimetype": "text/x-python",
85 | "name": "python",
86 | "nbconvert_exporter": "python",
87 | "pygments_lexer": "ipython3",
88 | "version": "3.11.8"
89 | }
90 | },
91 | "nbformat": 4,
92 | "nbformat_minor": 2
93 | }
94 |
--------------------------------------------------------------------------------
/examples/scripts/pool_based_active_learning.sh:
--------------------------------------------------------------------------------
1 | conda activate o1
2 |
3 | base_dir=./out/models
4 | model=Qwen/Qwen2.5-Math-7B-Instruct
5 | model_id=$(echo "$model" | cut -d'/' -f2)
6 |
7 | problem_type=single_label_classification
8 | label_type=hard_labels
9 |
10 | num_ensemble=32 # 8, 1
11 | p_threshold=5
12 | std_threshold=-1
13 | lr=1e-5
14 | batch_size=64
15 | dataset=$1
16 |
17 | if [[ $num_ensemble == 1 ]]; then
18 | learning_probability=1.0
19 | regularization_lambda=0.0
20 | else
21 | learning_probability=1.0
22 | regularization_lambda=0.5
23 | fi
24 | freeze_backbone=False
25 | # echo $annotate_model $num_ensemble $learning_probability $freeze_backbone
26 |
27 | echo num_ensemble $num_ensemble
28 | echo learning_probability $learning_probability
29 | echo pred_threshold $p_threshold
30 | echo std_threshold $std_threshold
31 | echo lr $lr
32 | echo batch_size $batch_size
33 | echo dataset $dataset
34 |
35 | exp_name=pool_based_active_learning
36 | model_id=$exp_name
37 | output_dir=$base_dir/${exp_name}
38 | mkdir -p $output_dir
39 |
40 | enable_wandb=True
41 | if [[ $enable_wandb == 'True' ]]; then
42 | wandb online
43 | export WANDB_PROJECT=active_prm
44 | export WANDB_RESUME='allow'
45 | export WANDB_RUN_ID=$exp_name
46 | report_to='wandb'
47 | else
48 | report_to='none'
49 | fi
50 |
51 | # training hps
52 | num_train_epochs=1
53 | per_device_train_batch_size=8
54 | num_gpus=${num_gpus:-4}; gradient_accumulation_steps=$((batch_size / (per_device_train_batch_size * num_gpus)))  # num_gpus was undefined; assume 4 unless exported
55 |
56 | accelerate launch py_scripts/pool_based_active_learning.py \
57 | --deepspeed ds_config.json \
58 | --model_name_or_path $model \
59 | --dataset_name ${dataset} \
60 | --learning_rate $lr \
61 | --num_train_epochs $num_train_epochs \
62 | --per_device_train_batch_size $per_device_train_batch_size \
63 | --gradient_accumulation_steps $gradient_accumulation_steps \
64 | --gradient_checkpointing \
65 | --logging_steps 5 \
66 | --logging_dir $output_dir/logs \
67 | --save_total_limit 1 \
68 | --save_strategy 'steps' \
69 | --save_steps 100 \
70 | --eval_strategy 'steps' \
71 | --eval_steps 500 \
72 | --output_dir $output_dir \
73 | --report_to=$report_to \
74 | --run_name $exp_name \
75 | --max_seq_length 2048 \
76 | --bf16=True \
77 | --torch_dtype auto \
78 | --num_ensemble $num_ensemble \
79 | --label_type $label_type \
80 | --problem_type $problem_type \
81 | --learning_probability $learning_probability \
82 | --regularization_lambda $regularization_lambda \
83 |     --rr_token '<extra_0>' \
84 | --active_learning_pred_threshold $p_threshold \
85 | --active_learning_std_threshold $std_threshold \
86 | --lr_scheduler_type 'linear' \
87 | --warmup_steps 500 \
88 | 2>&1 | tee $output_dir/log.txt
89 |
90 | python -m active_prm.eval.processbench prm ${output_dir}
91 |
92 | accelerate launch -m active_prm.eval.PRMBench \
93 | --model ensemble_prm \
94 | --model_args pretrained=${output_dir} \
95 | --task_name prmtest_classified \
96 | --verbosity INFO \
97 | --output_path ./out/bench/prmbench/${model_id}/results.jsonl
98 |
99 | python -m active_prm.eval.PRMBench.vis_res ./out/bench/prmbench/${exp_name}/results.jsonl
100 |
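101 | # Usage sketch (assumed invocation; ds_config.json and py_scripts/ are resolved
102 | # relative to the working directory, so run from examples/):
103 | #   cd examples && bash scripts/pool_based_active_learning.sh <dataset_name>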
--------------------------------------------------------------------------------
/examples/py_scripts/code_snippets.py:
--------------------------------------------------------------------------------
1 | # modified from https://huggingface.co/Qwen/Qwen2.5-Math-PRM-7B/blob/main/README.md
2 | from transformers import AutoModel, AutoTokenizer
3 |
4 |
5 | def make_step_rewards(logits, token_masks):
6 |     res = []
7 |     for j in range(logits.size(1)):  # use the argument, not the global `outputs`
8 |         step_logits = logits[:, j, token_masks[j]]
9 |         step_logits = step_logits.mean(dim=0)  # average over the ensemble dimension
10 |         res.append(step_logits.tolist())
11 |     return res
12 |
13 |
14 | model_name = "ActPRM/ActPRM"
16 | device = "cuda"
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
18 | model = (
19 | AutoModel.from_pretrained(
20 | model_name,
21 | torch_dtype="auto",
22 | trust_remote_code=True,
23 | )
24 | .eval()
25 |     .to(device)
26 | )
27 |
28 |
29 | data = {
30 | "system": "Please reason step by step, and put your final answer within \\boxed{}.",
31 | "query": "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
32 | "response": [
33 | "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
34 | "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
35 | "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
36 | "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).",
37 | ],
38 | }
39 |
40 | messages = [
41 | {"role": "system", "content": data["system"]},
42 | {"role": "user", "content": data["query"]},
43 | {"role": "assistant", "content": "".join(data["response"]) + ""},
44 | ]
45 | conversation_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
46 |
47 | input_ids = tokenizer.encode(
48 | conversation_str,
49 | return_tensors="pt",
50 | ).to(model.device)
51 |
52 | outputs = model.inference(input_ids=input_ids)
53 |
54 | step_sep_id = tokenizer.encode("<extra_0>")[0]
55 | token_masks = input_ids == step_sep_id
56 | step_reward = make_step_rewards(outputs[0], token_masks)
57 | print(step_reward) # [[0.9993686676025391, 0.2316841036081314, 0.7311716079711914, 0.8314468264579773]]
58 |
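59 | # Shape note (an assumption, based on how make_step_rewards indexes): outputs[0]
60 | # is (num_ensemble, batch, seq_len); the ensemble's scores at every <extra_0>
61 | # separator are averaged, yielding one validity score per reasoning step.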
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import itertools
4 | import json
5 | import logging
6 | import random
7 | import time
8 | from collections import defaultdict
9 | from dataclasses import asdict
10 | from typing import TYPE_CHECKING, List, Optional, Union
11 |
12 | import numpy as np
13 | import torch
14 | import torch.distributed as dist
15 | import yaml
16 | from accelerate.state import AcceleratorState
17 | from torch.utils.data import DataLoader
18 |
19 | from .models import get_model
20 | from .tasks import get_task_functions, get_task_object
21 | from .tasks.base_dataset.base_evaluation_dataset import (
22 | BaseEvalDataset, DataCollatorForSupervisedDataset)
23 | from .utils.arguments import *
24 | from .utils.log_utils import get_logger
25 | from .utils.utils import *
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | class MREvaluator:
31 | def __init__(self, model_args, task_args, script_args):
32 | self.config = script_args.config
33 | self.model_args = model_args
34 | self.task_args = task_args
35 | self.script_args = script_args
36 | self.tasks = self.task_args.task_name
37 | self.model = get_model(self.model_args.model)(**self.model_args.model_args)
38 | try:
39 | self.tokenizer = self.model.tokenizer
40 | except AttributeError:
41 | self.tokenizer = None
42 |
43 | self.state = AcceleratorState()
44 | self.batch_size = asdict(self.model_args).get("batch_size", 1)
45 | if self.state.deepspeed_plugin:
46 | deepspeed_config = self.state.deepspeed_plugin.deepspeed_config
47 |             # modify the config
48 |             deepspeed_config["train_micro_batch_size_per_gpu"] = self.batch_size
49 |             # apply the modification
50 | self.state.deepspeed_plugin.deepspeed_config = deepspeed_config
51 | else:
52 | logger.info("DeepSpeed is not initialized. Skipping DeepSpeed-specific configuration.")
53 |
54 | def evaluate(self):
55 | for task_name in self.tasks:
56 | logger.info(f"evaluating {task_name}")
57 | task_dict = get_task_functions(task_name)
58 | load_data_function, evaluate_function, task_config = (
59 | task_dict["load_data_function"],
60 | task_dict["evaluate_function"],
61 | task_dict["task_config"],
62 | )
63 | self.model.set_generation_config(task_config.generation_config)
64 |
65 | dataset = BaseEvalDataset(
66 | load_data_function=load_data_function,
67 | getitem_function=self.model.getitem_function,
68 | evaluate_function=evaluate_function,
69 | task_config=task_config,
70 | task_args=self.task_args,
71 | model_args=self.model_args,
72 | )
73 | num_workers = self.model_args.num_workers
74 | data_collator = DataCollatorForSupervisedDataset(
75 | tokenizer=self.tokenizer,
76 | max_length=task_config.generation_config.max_length,
77 | padding_side=dataset.padding_side,
78 | )
79 | dataloader = DataLoader(
80 | dataset, batch_size=self.model_args.batch_size, num_workers=num_workers, collate_fn=data_collator
81 | )
82 | self.model.respond(dataloader)
83 | res_log = dataset.evaluate()
84 | if is_main_process():
85 | logger.info(f"evaluation of {task_name} completed")
86 | append_jsonl(res_log, self.script_args.output_path)
87 |
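88 | # Hypothetical driver sketch (the real entry point lives in mr_eval/__main__.py):
89 | #
90 | #   from .utils.arguments import parse_args
91 | #   args = parse_args()  # -> dict(model_args=..., task_args=..., script_args=...)
92 | #   MREvaluator(**args).evaluate()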
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/qwen_math_rm_fsdp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from torch.distributed.fsdp import (
5 | FullyShardedDataParallel as FSDP,
6 | MixedPrecision,
7 | StateDictType,
8 | FullStateDictConfig,
9 | ShardingStrategy,
10 | )
11 | from transformers import AutoModelForCausalLM, AutoTokenizer
12 | from tqdm import tqdm
13 |
14 |
15 |
16 |
17 | class InferenceManager:
18 | def __init__(self, model_name, checkpoint_path=None):
19 | self.setup_dist()
20 | self.model_name = model_name
21 | self.checkpoint_path = checkpoint_path
22 | self.local_rank = int(os.environ["LOCAL_RANK"])
23 |
24 |         # initialize the model and tokenizer
25 | self.init_model_and_tokenizer()
26 |
27 | def setup_dist(self):
28 | dist.init_process_group("nccl")
29 |         torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
30 |
31 | def init_model_and_tokenizer(self):
32 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
33 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
34 |
35 |         # wrap the model with FSDP
36 | self.model = FSDP(
37 | self.model,
38 | sharding_strategy=ShardingStrategy.FULL_SHARD,
39 | device_id=torch.cuda.current_device(),
40 | mixed_precision=MixedPrecision(
41 | param_dtype=torch.float16,
42 | reduce_dtype=torch.float16,
43 | buffer_dtype=torch.float16,
44 | ),
45 | )
46 |
47 |         # load the checkpoint if provided
48 | if self.checkpoint_path and os.path.exists(self.checkpoint_path):
49 | state_dict = torch.load(self.checkpoint_path)
50 | self.model.load_state_dict(state_dict)
51 |
52 | def inference(self, texts, batch_size=4, max_length=100):
53 | """批量推理"""
54 | self.model.eval()
55 | results = []
56 |
57 |         # create a progress bar (main process only)
58 | if self.local_rank == 0:
59 | pbar = tqdm(range(0, len(texts), batch_size), desc="Inferencing")
60 | else:
61 | pbar = range(0, len(texts), batch_size)
62 |
63 | for i in pbar:
64 | batch_texts = texts[i:i + batch_size]
65 |
66 | with torch.no_grad():
67 | inputs = self.tokenizer(
68 | batch_texts,
69 | return_tensors="pt",
70 | padding=True,
71 | truncation=True
72 | ).to(self.model.device)
73 |
74 | outputs = self.model.generate(
75 | **inputs,
76 | max_length=max_length,
77 | num_return_sequences=1,
78 | pad_token_id=self.tokenizer.pad_token_id,
79 | )
80 |
81 | if self.local_rank == 0:
82 | decoded_outputs = [
83 | self.tokenizer.decode(output, skip_special_tokens=True)
84 | for output in outputs
85 | ]
86 | results.extend(decoded_outputs)
87 |
88 | return results if self.local_rank == 0 else None
89 |
90 | def cleanup(self):
91 | dist.destroy_process_group()
92 |
93 | def main():
94 |     # configuration
95 |     model_name = "your_model_name"
96 |     checkpoint_path = "path/to/checkpoint.pt"
97 |
98 |     # create the inference manager
99 |     inference_manager = InferenceManager(model_name, checkpoint_path)
100 |
101 |     # prepare input data
102 |     texts = [
103 |         "First input text",
104 |         "Second input text",
105 |         # ... more texts
106 |     ]
107 |
108 |     # run inference
109 |     results = inference_manager.inference(texts, batch_size=4)
110 |
111 |     # print results (main process only)
112 |     if dist.get_rank() == 0:
113 |         for i, result in enumerate(results):
114 |             print(f"Input {i}:")
115 |             print(f"Generated: {result}\n")
116 |
117 |     # clean up
118 |     inference_manager.cleanup()
119 |
120 | if __name__ == "__main__":
121 | main()
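122 |
123 | # Launch sketch (assumes one process per GPU; torchrun sets LOCAL_RANK):
124 | #   torchrun --nproc_per_node=4 qwen_math_rm_fsdp.py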
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/prompts.py:
--------------------------------------------------------------------------------
1 | policy_model_system_prompt = """
2 | You are a mathematical reasoning evaluator. Your task is to analyze mathematical problem-solving steps and provide structured assessments in JSON format.
3 |
4 | For each solution step, you need to evaluate two aspects:
5 | 1. Validity Score (-1 to +1):
6 | * +1: Completely correct mathematical reasoning
7 | * 0: Partially correct with some mistakes
8 | * -1: Completely incorrect
9 | * Use any value in between to indicate varying degrees of correctness
10 |
11 | 2. Redundancy Score (-1 to +1):
12 | * -1: Critical step, absolutely necessary for the solution
13 | * 0: Moderately important step
14 | * +1: Completely redundant, can be omitted
15 | * Use any value in between to indicate varying degrees of redundancy
16 |
17 | Requirements:
18 | - Evaluate each step independently
19 | - Provide scores as floating-point numbers
20 | - Return results in strict JSON format: {"validity": [scores], "redundancy": [scores]}
21 | - Ensure both arrays have the same length as the number of steps
22 | - Maintain mathematical rigor in your evaluation
23 | - Consider mathematical accuracy, logical coherence, and solution efficiency
24 |
25 | Example output format:
26 | {"validity": [0.8, -0.5, 1.0], "redundancy": [-1.0, 0.3, 0.7]}
27 |
28 | You will be presented with a mathematical problem and its step-by-step solution. Please analyze each step and provide your evaluation in the specified JSON format.
29 |
30 | """
31 |
32 | policy_model_fewshot_q1 = """
33 | Question:
34 |
35 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
36 |
37 | Solution:
38 |
39 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
40 |
41 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
42 |
43 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
44 |
45 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
46 |
47 | Step 5. So x=$1250$.
48 |
49 | Step 6. So x=$1250$.
50 |
51 | Step 7. That's the final answer.
52 |
53 | Step 8. Right. So 1,000,000 lire is equivalent to $\$1250$.
54 |
55 | # Answer
56 |
57 | 1250
58 | """
59 | policy_model_fewshot_a1="{\"validity\": [1.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\"redundancy\": [-1.0, 0.5, -0.5, -0.5, -1.0, 1.0, 0.5, 1.0]}"
60 |
61 | policy_model_fewshot_q2 = """
62 | Question:
63 |
64 | Four points, $A$, $B$, $C$, and $D$, are chosen randomly and independently on the circumference of a circle. What is the probability that segments $AB$ and $CD$ intersect?
65 |
66 | Solution:
67 |
68 | Step 1. Let's think about what it means for the segments to intersect.
69 |
70 | Step 2. Whether they intersect is entirely a function of the order of $B,C,D$ on the circle, when we look at it counterclockwise and consider $A$ to come first.
71 |
72 | Step 3. Exactly. So let's consider the different possible orders of $B,C,D$.
73 |
74 | Step 4. The number of possible orders is $3!=6$.
75 |
76 | Step 5. Yes, that's right. Now, we need to figure out how many of these result in the segments intersecting.
77 |
78 | Step 6. If $C$ is in between $B$ and $D$, then the segments will intersect.
79 |
80 | Step 7. If $B$ is in between $C$ and $D$, then the segments will intersect.
81 |
82 | Step 8. Right. That's two of the possible orders.
83 |
84 | Step 9. So, the probability that the segments intersect is $\dfrac{2}{6}=\dfrac{1}{3}$.
85 |
86 | Step 10. And that's our answer.
87 |
88 | # Answer
89 |
90 | 1/3
91 | """
92 |
93 | policy_model_fewshot_a2="{\"validity\": [1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 0.8, 1.0, 1.0, 1.0],\"redundancy\": [-1.0, -0.5, -0.5, -0.5, -0.5, -0.5, 0.5, -0.5, -1.0, 1.0]}"
94 |
95 |
96 | PROMPT_DICT=dict(
97 | policy_model_as_an_evaluator=dict(
98 | system_prompt=policy_model_system_prompt,
99 | fewshots=[
100 | (policy_model_fewshot_q1, policy_model_fewshot_a1),
101 | (policy_model_fewshot_q2, policy_model_fewshot_a2)
102 | ]
103 | )
104 | )
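105 |
106 |
107 | def build_evaluator_messages(question: str) -> list:
108 |     """Minimal sketch (hypothetical helper, not referenced elsewhere) showing how
109 |     PROMPT_DICT can be unpacked into an OpenAI-style message list."""
110 |     prompt = PROMPT_DICT["policy_model_as_an_evaluator"]
111 |     messages = [{"role": "system", "content": prompt["system_prompt"]}]
112 |     for fewshot_q, fewshot_a in prompt["fewshots"]:
113 |         messages.append({"role": "user", "content": fewshot_q})
114 |         messages.append({"role": "assistant", "content": fewshot_a})
115 |     messages.append({"role": "user", "content": question})
116 |     return messages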
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # scripts
165 | mr_eval/scripts/bash_scripts/
166 | mr_eval/scripts/configs/
167 | mr_eval/scripts/logs/
168 |
169 |
170 | mr_eval/bon_eval/scripts/
171 |
172 | mr_annotate/build_data/selection_of_data/
173 | mr_annotate/build_data/scripts/
174 | mr_annotate/build_data/model_inference/qwq/scripts
175 | # data
176 | mr_annotate/annotation/data/
177 |
178 | # notebooks
179 | **/notebooks/
180 | **/notebook/
181 | test/
182 |
183 | # exclude
184 |
185 | !mr_eval/scripts/logs/generated/loc
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/math_shepherd.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | from typing import List, Optional, Tuple, Type, TypeVar, Union
5 |
6 | from loguru import logger as eval_logger
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | import torch
10 | import torch.nn as nn
11 | from transformers.configuration_utils import PretrainedConfig
12 | from accelerate import Accelerator
13 |
14 |
15 | from .abstract_model import prm
16 | from ..utils.utils import *
17 | class MathShepherd(prm):
18 | def __init__(
19 | self,
20 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/math-shepherd-mistral-7b-prm",
21 | good_token = '+',
22 | bad_token = '-',
23 | step_tag = 'ки',
24 | validity_threshold = 0.5,
25 | ) -> None:
26 |
27 | super().__init__(validity_threshold=validity_threshold)
28 | self.good_token = good_token
29 | self.bad_token = bad_token
30 | self.step_tag = step_tag
31 |
32 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
33 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
34 | self.candidate_tokens = self.tokenizer.encode(f"{good_token} {bad_token}")[1:] # [648, 387]
35 | self.step_tag_id = self.tokenizer.encode(f"{step_tag}")[-1] # 12902
36 | self.model = AutoModelForCausalLM.from_pretrained(pretrained,).eval()
37 |
38 | self.accelerator = Accelerator()
39 |
40 |
41 |
42 |
43 | def getitem_function(self,meta_data,index):
44 | data_idx = meta_data[index]["idx"]
45 | steps = meta_data[index]["steps"]
46 | question = meta_data[index]["question"]
47 |
48 | res = []
49 | for idx,step in enumerate(steps):
50 | if step.strip().startswith("Step") or step.strip().startswith("step"):
51 | res.append(f"{step.strip()} {self.step_tag}\n")
52 | else:
53 | res.append(f"Step {idx+1}: {step.strip()} {self.step_tag}\n")
54 | steps_str = "".join(res)
55 | original_input_for_prm = f"{question} {steps_str}"
56 | input_ids=self.tokenizer.encode(original_input_for_prm, return_tensors='pt')
57 | while input_ids.ndim > 1:
58 | input_ids = input_ids[0]
59 | res = dict(
60 | idx = data_idx,
61 | input_ids = input_ids,
62 | )
63 | return res
64 |
65 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
66 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
67 | self.accelerator.wait_for_everyone()
68 | self.model.eval()
69 | gen_kwargs = dataloader.dataset.gen_kwargs
70 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
71 | if len(dataloader) == 0:
72 | self.accelerator.wait_for_everyone()
73 | return
74 | with torch.no_grad():
75 | for batch in dataloader:
76 | idx = batch['idx']
77 | input_ids = batch['input_ids']
78 | attention_mask = batch['attention_mask']
79 |
80 | original_logits = self.model(
81 | input_ids=input_ids,
82 | attention_mask=attention_mask,
83 | ).logits
84 |
85 | for i in range(len(idx)):
86 | current_logits = original_logits[i][:,self.candidate_tokens]
87 | original_scores = current_logits.softmax(dim=-1)[:,0]
88 | original_step_scores = original_scores[input_ids[i] == self.step_tag_id].tolist()
89 | step_level_validity_labels = [item > self.validity_threshold for item in original_step_scores]
90 | idx_item = idx[i]
91 | score_dict = dict(
92 | step_level_validity_labels = step_level_validity_labels,
93 | step_level_validity_scores = original_step_scores,
94 | )
95 | res = dict(scores=score_dict, idx=idx_item)
96 | dataloader.dataset.store_results(res)
97 | if progress_bar is not None:
98 | progress_bar.update(1)
99 |
100 | self.accelerator.wait_for_everyone()
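101 |
102 | # Scoring sketch: each step is suffixed with the `ки` step tag; at every tag
103 | # position the model's logits over the `+`/`-` candidate tokens are softmaxed,
104 | # and the probability of `+` is taken as that step's validity score.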
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import json
4 | import copy
5 | import random
6 | import logging
7 | import argparse
8 | import numpy as np
9 | from PIL import Image
10 | from argparse import Namespace
11 | from dataclasses import dataclass, field
12 | from typing import Dict, Optional, List, Sequence
13 |
14 | import torch
15 | from torch.utils.data import Dataset
16 |
17 | import transformers
18 | from transformers import TrainerCallback
19 | from transformers import HfArgumentParser, TrainingArguments
20 | from box import Box
21 |
22 | from .utils import *
23 |
24 | @dataclass
25 | class ModelArguments:
26 | model: Optional[str] = field(default="reasoneval")
27 | model_args: Optional[str] = field(default="pretrained=EleutherAI/pythia-160m,dtype=float32")
28 | batch_size: Optional[int] = field(default=1)
29 | num_workers: Optional[int] = field(default=4)
30 |
31 | @dataclass
32 | class TaskArguments:
33 |     task_name: Optional[str] = field(default="reasoneval")
34 |     resume_from_ckpt: Optional[Dict[str, str]] = field(default=None,)
35 |     save_to_ckpt: Optional[Dict[str, str]] = field(default=None,)
36 |
37 |     def __post_init__(self):
38 |         # Wrap a provided dict in a Box; default to an empty Box otherwise.
39 |         if self.resume_from_ckpt is None:
40 |             self.resume_from_ckpt = Box()
41 |         elif isinstance(self.resume_from_ckpt, dict):
42 |             self.resume_from_ckpt = Box(self.resume_from_ckpt)
43 |         else:
44 |             raise ValueError("resume_from_ckpt should be a dictionary.")
45 |
46 |         if self.save_to_ckpt is None:
47 |             self.save_to_ckpt = Box()
48 |         elif isinstance(self.save_to_ckpt, dict):
49 |             self.save_to_ckpt = Box(self.save_to_ckpt)
50 |         else:
51 |             raise ValueError("save_to_ckpt should be a dictionary.")
52 |
53 |
54 |
55 | # Define and parse arguments.
56 | @dataclass
57 | class ScriptArguments:
58 | """
59 | The arguments for the Evaluation script.
60 | """
61 | config: Optional[str] = field(default=None)
62 | verbosity: Optional[str] = field(default="INFO")
63 | wandb_args: Optional[str] = field(default="project=mr_eval,entity=mr_eval")
64 | output_path: Optional[str] = field(default="output")
65 |
66 |
67 | def parse_str_into_dict(args_str: str) -> Dict:
68 | """
69 | Parse a string of comma-separated key-value pairs into a dictionary.
70 | """
71 | args_dict = {}
72 | for arg in args_str.split(","):
73 | key, value = arg.split("=")
74 | args_dict[key] = value
75 | return args_dict
76 |
77 | def parse_str_into_list(args_str: str) -> List:
78 | """
79 | Parse a string of comma-separated values into a list.
80 | """
81 | return args_str.split(",")
82 |
83 | def parse_args():
84 | parser = transformers.HfArgumentParser(
85 | (ModelArguments, TaskArguments, ScriptArguments))
86 | model_args, task_args, script_args = parser.parse_args_into_dataclasses()
87 |
88 | if script_args.config:
89 | if script_args.config.endswith(".json"):
90 | config = load_json_file(script_args.config)
91 | elif script_args.config.endswith(".yaml"):
92 | config = load_yaml_file(script_args.config)
93 | else:
94 | raise ValueError("Config file should be either a json or yaml file.")
95 |
96 | if isinstance(config, dict):
97 | model_args = ModelArguments(**config["model_args"])
98 | task_args = TaskArguments(**config["task_args"])
99 | script_args = ScriptArguments(**config["script_args"])
100 | elif isinstance(config, list):
101 | model_args = ModelArguments(**config[0]["model_args"])
102 | task_args = TaskArguments(**config[0]["task_args"])
103 | script_args = ScriptArguments(**config[0]["script_args"])
104 | else:
105 | raise ValueError("Config file should be either a dict or list of dicts.")
106 | else:
107 | config = None
108 |
109 | script_args.config = config
110 | task_args.task_name = parse_str_into_list(task_args.task_name)
111 | if isinstance(model_args.model_args, str):
112 | model_args.model_args = parse_str_into_dict(model_args.model_args)
113 | if isinstance(script_args.wandb_args, str):
114 | script_args.wandb_args = parse_str_into_dict(script_args.wandb_args)
115 |
116 | return dict(model_args=model_args, task_args=task_args, script_args=script_args)
117 |
118 |
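119 | # Example: parse_str_into_dict("pretrained=EleutherAI/pythia-160m,dtype=float32")
120 | # returns {"pretrained": "EleutherAI/pythia-160m", "dtype": "float32"}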
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/mathminos_mistral.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | from typing import List, Optional, Tuple, Type, TypeVar, Union
5 |
6 | from loguru import logger as eval_logger
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer, MistralForTokenClassification
9 | import torch
10 | import torch.nn as nn
11 | from transformers.configuration_utils import PretrainedConfig
12 | from accelerate import Accelerator
13 | import torch.nn.functional as F
14 |
15 | from .abstract_model import prm
16 | from ..utils.utils import *
17 | from ..utils.model_utils import remove_step_prefix
18 |
19 | class MathMinos_Mistral(prm):
20 | def __init__(
21 | self,
22 | pretrained = "/mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/ref/MATH-Minos/RM/ckpts/minos_mistral",
23 | step_tag = "и",
24 | validity_threshold = 0,
25 | ) -> None:
26 |
27 | super().__init__(validity_threshold=validity_threshold)
28 | self.step_tag = step_tag
29 |
30 | self.model = MistralForTokenClassification.from_pretrained(
31 | pretrained,
32 | )
33 |
34 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
35 | self.step_tag_id = self.tokenizer.encode(f"{step_tag}")[-1]
36 | if self.tokenizer.pad_token_id is None:
37 | print_rank0("Setting pad_token_id to eos_token_id")
38 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
39 | self.accelerator = Accelerator()
40 |
41 |
42 |
43 |
44 | def getitem_function(self,meta_data,index):
45 | data_idx = meta_data[index]["idx"]
46 | steps = meta_data[index]["steps"]
47 | question = meta_data[index]["question"]
48 |
49 | res = []
50 | for idx,step in enumerate(steps):
51 | clean_step = remove_step_prefix(step)
52 | res.append(f"Step {idx+1}: {clean_step} {self.step_tag}\n")
53 |
54 | steps_str = "".join(res)
55 | original_input_for_prm = f"Human: {question}\n\nAssistant:{steps_str}"
56 |
57 | input_ids = self.tokenizer.encode(original_input_for_prm, return_tensors='pt', max_length = self.generation_config.max_length, truncation=True)
58 | while input_ids.ndim > 1:
59 | input_ids = input_ids[0]
60 |
61 | res = dict(
62 | idx = data_idx,
63 | input_ids = input_ids,
64 | )
65 | return res
66 |
67 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
68 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
69 | self.accelerator.wait_for_everyone()
70 | self.model.eval()
71 | gen_kwargs = dataloader.dataset.gen_kwargs
72 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
73 | if len(dataloader) == 0:
74 | self.accelerator.wait_for_everyone()
75 | return
76 | with torch.no_grad():
77 | for batch in dataloader:
78 | idx = batch['idx']
79 | input_ids = batch['input_ids']
80 | attention_mask = batch['attention_mask']
81 |
82 | original_logits = self.model(
83 | input_ids=input_ids,
84 | attention_mask=attention_mask,
85 | ).logits.squeeze(-1)
86 |
87 | for i in range(len(idx)):
88 | current_input_ids = input_ids[i]
89 | current_logits = original_logits[i][current_input_ids == self.step_tag_id]
90 | original_labels = current_logits > self.validity_threshold
91 | if torch.is_tensor(original_labels):
92 | original_labels = original_labels.tolist()
93 | # original_scores = F.normalize(original_scores, p=2, dim=-1)
94 | original_scores = current_logits
95 | idx_item = idx[i]
96 | score_dict = dict(
97 | step_level_validity_labels = original_labels,
98 | step_level_validity_scores = original_scores.tolist(),
99 | )
100 | res = dict(scores=score_dict, idx=idx_item)
101 | dataloader.dataset.store_results(res)
102 |
103 | if progress_bar is not None:
104 | progress_bar.update(1)
105 |
106 | self.accelerator.wait_for_everyone()
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/llama3_1_8b_prm.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | import sys
5 |
6 | from typing import List, Optional, Tuple, Type, TypeVar, Union
7 |
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | import torch
11 | import torch.nn as nn
12 | from transformers.configuration_utils import PretrainedConfig
13 | from accelerate import Accelerator
14 |
15 |
16 | from .abstract_model import prm
17 | from ..utils.utils import *
18 | from ..utils.log_utils import *
19 | from ..utils.model_utils import remove_step_prefix, find_subsequence
20 |
21 | logger = get_logger(__name__)
22 |
23 |
24 |
25 | class LLaMA318BPRM(prm):
26 | def __init__(
27 | self,
28 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/Llama3.1-8B-PRM-Mistral-Data",
29 | positive_token = 10,
30 | pattern = torch.tensor([128006, 78191, 128007, 271, 10, 128009]),
31 | validity_threshold = 0.5,
32 | ) -> None:
33 | super(LLaMA318BPRM, self).__init__(validity_threshold=validity_threshold)
34 |
35 | # pattern
36 | if isinstance(pattern, torch.Tensor):
37 | self.pattern = pattern
38 | elif isinstance(pattern, list):
39 | self.pattern = torch.tensor(pattern)
40 | else:
41 | raise ValueError("pattern should be a list or a torch.Tensor")
42 | self.positive_token = positive_token
43 |
44 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
45 | self.model = AutoModelForCausalLM.from_pretrained(pretrained, device_map="cpu").eval()
46 |
47 | self.accelerator = Accelerator()
48 |
49 |
50 |
51 |
52 | def getitem_function(self,meta_data,index):
53 | data_idx = meta_data[index]["idx"]
54 | steps = meta_data[index]["steps"]
55 | question = meta_data[index]["question"]
56 |
57 | conversation = []
58 | for idx,step in enumerate(steps):
59 | clean_step = remove_step_prefix(step)
60 | if idx == 0:
61 | text = question + " " + clean_step
62 | else:
63 | text = clean_step
64 | conversation.append({"content":text,"role":"user"})
65 | conversation.append({"content":"+","role":"assistant"})
66 | input_ids = self.tokenizer.apply_chat_template(conversation,return_tensors="pt")
67 |
68 | while input_ids.ndim > 1:
69 | input_ids = input_ids[0]
70 | res = dict(
71 | idx = data_idx,
72 | input_ids = input_ids,
73 | )
74 | return res
75 |
76 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
77 | # self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
78 | dataloader = self.accelerator.prepare(dataloader)
79 | self.model = self.model.to(self.accelerator.device)
80 | self.accelerator.wait_for_everyone()
81 | self.model.eval()
82 | gen_kwargs = dataloader.dataset.gen_kwargs
83 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
84 | if len(dataloader) == 0:
85 | self.accelerator.wait_for_everyone()
86 | return
87 | with torch.no_grad():
88 | for batch in dataloader:
89 | idx = batch['idx']
90 | input_ids = batch['input_ids']
91 | logits = self.model(input_ids).logits
92 |
93 | for i in range(len(idx)):
94 | current_input_id = input_ids[i].cpu()
95 | current_logits = logits[i].cpu()
96 | score_locations = find_subsequence(current_input_id, self.pattern)
97 |                     score_locations = [loc + 3 for loc in score_locations]
98 | reward_logits = current_logits[score_locations]
99 | step_level_validity_scores = reward_logits.softmax(dim=-1)[:,self.positive_token].tolist()
100 | step_level_validity_labels = [item > self.validity_threshold for item in step_level_validity_scores]
101 |
102 | idx_item = idx[i]
103 | score_dict = dict(
104 | step_level_validity_labels = step_level_validity_labels,
105 | step_level_validity_scores = step_level_validity_scores,
106 | )
107 | res = dict(scores=score_dict, idx=idx_item)
108 | dataloader.dataset.store_results(res)
109 | if progress_bar is not None:
110 | progress_bar.update(1)
111 |
112 | self.accelerator.wait_for_everyone()
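113 |
114 | # Note: `pattern` appears to encode the assistant turn that carries the "+"
115 | # judgment (Llama-3 header tokens, "\n\n", "+", <|eot_id|>); the +3 offset lands
116 | # on the logits that predict that judgment token, one position per step.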
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 | from copy import deepcopy
6 |
7 | import torch
8 |
9 | from ..utils import *
10 |
11 |
12 | def answer_sequence_to_str(answer_sequence):
13 | res = []
14 | for idx, step in enumerate(answer_sequence):
15 | res.append(f"Step {idx+1}. {step['text']}\n\n")
16 | res_str = "".join(res)
17 | return res_str
18 |
19 |
20 | def answer_sequence_to_shepherd_str(answer_sequence, step_tag="ки"):
21 | res = []
22 | for idx, step in enumerate(answer_sequence):
23 | res.append(f"Step {idx+1}: {step['text']} {step_tag}\n")
24 | res_str = "".join(res)
25 | return res_str
26 |
27 |
28 | def answer_sequence_to_reasoneval_list(answer_sequence):
29 | res = []
30 | for idx, step in enumerate(answer_sequence):
31 | res.append(f"{idx+1}. {step['text']}")
32 | return res
33 |
34 |
35 | def get_best_answer_by_item(item, return_type="shepherd"):
36 | steps = item["label"]["steps"]
37 | best_answers = []
38 | for step in steps:
39 | if step["human_completion"] is not None and step["chosen_completion"] is None:
40 | best_answers.append(step["human_completion"])
41 | elif step["chosen_completion"] is not None:
42 | best_answers.append(step["completions"][step["chosen_completion"]])
43 | else:
44 | print(f"skipped one step")
45 | if return_type == "shepherd":
46 | answer_str = answer_sequence_to_shepherd_str(best_answers)
47 | elif return_type == "str":
48 | answer_str = answer_sequence_to_str(best_answers)
49 | elif return_type == "reasoneval":
50 | answer_str = answer_sequence_to_reasoneval_list(best_answers)
51 | else:
52 | answer_str = best_answers
53 | return answer_str
54 |
55 |
56 | def get_latex_str(question, answer):
57 | res = f"Question:\n\n{question}\n\nAnswer:\n\n{answer}"
58 | return res
59 |
60 |
61 | def score_list_to_str(score_list):
62 | valid2_list = [str(round(i, 2)) for i in score_list]
63 | res = ", ".join(valid2_list)
64 | return res
65 |
66 |
67 | def clean_str(input_str):
68 | res_str = deepcopy(input_str)
69 | res_str = re.sub(r"\\+([^\\\s])", r"\\\\\1", res_str)
70 | res_str = re.sub(r"\\+([\s])", r"\\\\\\\\\1", res_str)
71 | return res_str
72 |
73 |
74 | def remove_comments_from_json(json_string):
75 | """
76 |     Remove single-line and multi-line comments from a JSON string.
77 |     """
78 |
79 |     # match comments starting with // or #, and strip them
80 | return re.sub(r"//.*?$|#.*?$", "", json_string, flags=re.MULTILINE)
81 |
82 |
83 | def extract_nested_json(text):
84 | """
85 | 提取嵌套大括号内的 JSON 数据,移除注释后解析。
86 | Args:
87 | text (str): 包含 JSON 的文本。
88 | Returns:
89 | dict or list or None: 解析成功返回 JSON 数据,失败返回 None。
90 | """
91 | stack = [] # 用来记录大括号的匹配
92 | start = -1
93 | for i, char in enumerate(text):
94 | if char == "{":
95 | if not stack: # 当栈为空时,记录第一个大括号的位置
96 | start = i
97 | stack.append("{") # 压栈
98 | elif char == "}":
99 | stack.pop() # 出栈
100 | if not stack: # 当栈为空时,表示找到完整 JSON
101 | try:
102 | # 提取完整 JSON 字符串
103 | json_str = text[start : i + 1]
104 | # 移除注释
105 | json_cleaned = remove_comments_from_json(json_str)
106 | # 尝试解析为 JSON 对象
107 | return json.loads(json_cleaned)
108 | except json.JSONDecodeError as e:
109 | continue # 如果解析失败,跳过并继续查找
110 | return None # 如果未找到完整 JSON,则返回 None
111 |
112 |
113 | def process_policy_lm_evaluation_response(response):
114 | """process the response STRING from the language model"""
115 | try:
116 | json_object = extract_nested_json(response)
117 | assert json_object is not None
118 | assert "validity" in json_object and "redundancy" in json_object
119 | return json_object
120 | except:
121 | print(f"Invalid JSON Str, response: {response}")
122 | return None
123 |
124 |
125 | def remove_step_prefix(text):
126 | """
127 |     Strip a leading 'Step x. ', 'step x. ', or 'x. ' prefix, where x is a number.
128 | """
129 | text = text.strip()
130 | return re.sub(r"^(Step\s*\d+\.\s*|\d+\.\s*)", "", text, flags=re.IGNORECASE)
131 |
132 |
133 | def find_subsequence(tensor, subsequence):
134 | """
135 |     Locate all occurrences of a subsequence within a 1-D tensor.
136 |
137 |     Args:
138 |         tensor (torch.Tensor): the main tensor.
139 |         subsequence (torch.Tensor): the subsequence tensor.
140 |
141 |     Returns:
142 |         List[int]: start indices of the subsequence within the main tensor.
143 |     """
144 |     main_len = tensor.size(0)  # length of the main (1-D) tensor
145 |     sub_len = subsequence.size(0)  # length of the subsequence
146 |
147 |     positions = []  # start positions of matches
148 |     for i in range(main_len - sub_len + 1):  # sliding-window scan
149 |         # compare the current slice against the subsequence
150 | if torch.equal(tensor[i : i + sub_len], subsequence):
151 | positions.append(i)
152 | return positions
153 |
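154 |
155 | # Example:
156 | #   find_subsequence(torch.tensor([1, 2, 3, 2, 3]), torch.tensor([2, 3]))  # -> [1, 3]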
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/pure_prm.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | from accelerate import Accelerator
5 | from transformers import AutoModelForTokenClassification, AutoTokenizer
6 |
7 | from ..utils.log_utils import get_logger
8 | from ..utils.utils import *
9 | from .abstract_model import prm
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | class PUREPRM(prm):
15 | def __init__(
16 | self,
17 | pretrained = "jinachris/Qwen2.5-Math-7B-PRM800K",
18 | redundancy_threshold = 0.0, # not used?
19 | validity_threshold = 0.0,
20 | ) -> None:
21 | super().__init__(
22 | validity_threshold=validity_threshold,
23 | redundancy_threshold=redundancy_threshold,
24 | )
25 | self.tokenizer = AutoTokenizer.from_pretrained(
26 | pretrained,
27 | trust_remote_code=True,
28 | )
29 | self.model = AutoModelForTokenClassification.from_pretrained(
30 | pretrained,
31 | torch_dtype=torch.bfloat16,
32 | trust_remote_code=True,
33 | ).eval()
34 |
35 | self.accelerator = Accelerator()
36 |
37 | self.step_separator = "\n\n"
38 | self.step_separator_token_id = self.tokenizer(
39 | self.step_separator, add_special_tokens=False, return_tensors='pt')['input_ids']
40 |
41 | def getitem_function(self,meta_data,index):
42 | data_idx = meta_data[index]["idx"]
43 | steps = meta_data[index]["steps"]
44 | question = meta_data[index]["question"]
45 |
46 | ## build model-specialized input
47 | input_ids = self.tokenizer(
48 | question, add_special_tokens=False, return_tensors='pt')['input_ids']
49 | score_ids = []
50 | for step in steps:
51 | step_ids = self.tokenizer(
52 | step, add_special_tokens=False, return_tensors='pt')['input_ids']
53 | input_ids = torch.cat(
54 | [input_ids, step_ids, self.step_separator_token_id], dim=-1)
55 | score_ids.append(input_ids.size(-1) - 1)
56 |
57 | input_ids = input_ids.squeeze()
58 | token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
59 | token_mask[score_ids] = True
60 |
61 | res = dict(
62 | idx = data_idx,
63 | input_ids = input_ids,
64 | token_mask = token_mask,
65 | )
66 | return res
67 |
68 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
69 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
70 | self.accelerator.wait_for_everyone()
71 | self.model.eval()
72 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
73 | if len(dataloader) == 0:
74 | self.accelerator.wait_for_everyone()
75 | return
76 | with torch.no_grad():
77 | for batch_idx, batch in enumerate(dataloader):
78 | idx = batch['idx']
79 | input_ids = batch['input_ids']
80 | attention_mask = batch['attention_mask']
81 | # right pad token mask
82 | token_mask_ = batch['token_mask']
83 | token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
84 | bs = input_ids.size(0)
85 | for i in range(bs):
86 | token_mask[i, attention_mask[i].to(bool)] = token_mask_[i][:attention_mask.size(1)]
87 | assert torch.all(input_ids[token_mask] == self.step_separator_token_id.item())
88 |
89 | scores = self.model(input_ids, attention_mask).logits
90 | step_reward = make_step_rewards(scores, token_mask)
91 |
92 | for i in range(len(idx)):
93 | idx_item = idx[i]
94 | try:
95 | step_level_validity_scores = step_reward[i]
96 | score_dict = dict(
97 | step_level_validity_scores=step_level_validity_scores,
98 | step_level_validity_labels=[item > self.validity_threshold for item in step_level_validity_scores],
99 | )
100 | res = dict(scores=score_dict, idx=idx_item)
101 |                 except Exception:
102 | logger.error(f"Error in processing idx: {idx[i]}")
103 | res = dict(scores=dict(), idx=idx_item,validity=False)
104 |
105 | dataloader.dataset.store_results(res)
106 | if progress_bar is not None:
107 | progress_bar.update(1)
108 |
109 | self.accelerator.wait_for_everyone()
110 |
111 |
112 | def make_step_rewards(logits, token_masks):
113 | all_scores_res = []
114 | for sample, token_mask in zip(logits, token_masks):
115 | probs = sample[token_mask].softmax(dim=-1)
116 | process_reward = probs[:, 1] - probs[:, 0]
117 | all_scores_res.append(process_reward.cpu().tolist())
118 | return all_scores_res
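119 |
120 | # Note: probs[:, 1] - probs[:, 0] maps the two-class softmax onto [-1, 1], so with
121 | # validity_threshold = 0.0 a step is labeled valid when P(correct) > P(incorrect).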
--------------------------------------------------------------------------------
/src/active_prm/trainer/active_sft_trainer.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Optional
3 |
4 | import torch
5 | from transformers import Trainer
6 |
7 |
8 | class ActiveSFTTrainer(Trainer):
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | self._metrics = defaultdict(list)
12 | self._pseudo_labels = defaultdict(list)
13 |
14 | def _get_pseudo_labels(self, preds, labels):
15 | # NOTE: compute pseudo_labels; for labels after first error, set to -100
16 | # labels is only used for mask -100
17 | pseudo_labels = torch.zeros_like(labels, dtype=labels.dtype)
18 | pseudo_labels[preds >= 0.5] = 1
19 | pseudo_labels[labels == -100] = -100
20 | # compute first error step
21 | errors = pseudo_labels == 0
22 | first_error_idx = torch.where(errors.any(dim=1), errors.int().argmax(dim=1), pseudo_labels.size(1))
23 | positions = torch.arange(pseudo_labels.size(1), device=labels.device).unsqueeze(0)
24 | mask = positions >= (first_error_idx.unsqueeze(-1) + 1)
25 | pseudo_labels[mask] = -100
26 | del errors, first_error_idx, positions, mask
27 | return pseudo_labels
28 |
29 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
30 | """
31 | Compute training loss and additionally compute token accuracies
32 | """
33 | if not self.model.training:
34 | outputs = model(**inputs)
35 | outputs.logits = torch.nn.functional.sigmoid(outputs.logits).mean(dim=0)
36 | return (outputs.loss, outputs) if return_outputs else outputs.loss
37 |
38 | else:
39 | labels = inputs.pop("labels", None)
40 | outputs = model(**inputs)
41 |             # compute pseudo labels using active learning
42 | p_threshold = self.args.active_learning_pred_threshold
43 | std_threshold = self.args.active_learning_std_threshold
44 |
45 | prm_score = torch.nn.functional.sigmoid(outputs.logits)
46 | # compute stds
47 | if prm_score.size(0) != 1:
48 | stds = prm_score.std(dim=0) # (batch_size, seq_len)
49 | else:
50 | stds = torch.zeros_like(prm_score[0])
51 | preds = prm_score.mean(dim=0)
52 |
53 | pseudo_labels = self._get_pseudo_labels(preds, labels)
54 |
55 | # compute trust_masks on instance level;
56 | # for all pseudo_labels !=-100, std should <= std_threshold
57 | _trust_condition = (stds <= std_threshold) & ((preds >= p_threshold) | (preds <= 1 - p_threshold))
58 | trust_masks = torch.all(_trust_condition | (pseudo_labels == -100), dim=1)
59 |
60 | # compute the correctness between labels and pseudo_labels before assign
61 | num_correct_pseudo_label = (torch.all(pseudo_labels[trust_masks] == labels[trust_masks], dim=1)).sum()
62 | num_correct_pseudo_label = self.accelerator.gather_for_metrics(num_correct_pseudo_label)
63 | self._metrics["al_num_correct_pseudo_labels"].append(num_correct_pseudo_label.sum().item())
64 |
65 | num_trust_instances = trust_masks.sum()
66 | num_trust_instances = self.accelerator.gather_for_metrics(num_trust_instances)
67 | self._metrics["al_num_pseudo_labels"].append(num_trust_instances.sum().item())
68 |
69 | # # assign pseudo_labels to labels
70 | # labels[trust_masks] = pseudo_labels[trust_masks]
71 | # labels = labels.contiguous()
72 | # do not use confident data to train the model
73 |
74 | loss, reg_loss = model._compute_loss(
75 | outputs.logits[:, ~trust_masks, ...], labels[~trust_masks], return_reg_loss=True
76 | )
77 |
78 | trust_percentage = trust_masks.float().mean()
79 | trust_percentage = self.accelerator.gather_for_metrics(trust_percentage)
80 | self._metrics["trust_percentage"].append(trust_percentage.mean().item())
81 |
82 | # log loss
83 | self._metrics["reg_loss"].append(reg_loss.item())
84 | self._metrics["cls_loss"].append(loss.item() - reg_loss.item())
85 |
86 | return (loss, outputs) if return_outputs else loss
87 |
88 | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
89 | metrics = {
90 | key: (sum(val) / len(val) if not key.startswith("al_") else sum(val)) for key, val in self._metrics.items()
91 | } # average the metrics
92 | if "al_num_correct_pseudo_labels" in metrics.keys():
93 | metrics["pseudo_label_acc"] = metrics["al_num_correct_pseudo_labels"] / (
94 | metrics["al_num_pseudo_labels"] + 1e-6
95 | )
96 |
97 | # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
98 | # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
99 | if next(iter(logs.keys())).startswith("eval_"):
100 | metrics = {f"eval_{key}": val for key, val in metrics.items()}
101 |
102 | logs = {**logs, **metrics}
103 | super().log(logs, start_time)
104 | self._metrics.clear()
105 |
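106 | # Worked toy example (hypothetical tensors) for `_get_pseudo_labels`:
107 | #   preds  = [[0.9, 0.8, 0.2, 0.7]]  -> thresholded at 0.5 to [1, 1, 0, 1]
108 | #   labels = [[1, 1, 0, -100]]       -> the -100 position stays masked
109 | #   the first predicted error is at index 2, so every position after it is also
110 | #   set to -100, giving pseudo_labels = [[1, 1, 0, -100]]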
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/prompts_test.txt:
--------------------------------------------------------------------------------
1 | Improved Prompt
2 | I want to test the ability of process-level reward models to judge whether a reasoning process or step is correct. To do this, please help me build flawed cases by introducing specific types of errors into a given reasoning process.
3 |
4 | You will be provided with:
5 |
6 | 1. A mathematical problem.
7 | 2. Its standard correct answer.
8 | 3. A step-by-step reasoning process used to solve it.
9 | Your task is to modify specific steps or introduce new ones to create reasoning that appears plausible but is incorrect. The goal is to simulate flawed solutions by introducing hallucinations or errors based on the types described below.
10 |
11 | Error Types to Introduce
12 | 1. Redundancy:
13 | Add unnecessary steps that do not affect the correctness of the solution but make the process less compact. For example, if the correct chain is $A \to B$, modify it to $A \to C \to B$, where $C$ is redundant.
14 | 2. Circular Logic:
15 | Introduce a reasoning loop where a step depends on itself indirectly. For example, $A \to B \to C \to A$, forming a circular chain of reasoning.
16 | 3. Counterfactual:
17 | Add a statement that contradicts known ground truth. This could involve using outdated theories, omitting restrictions in a theory, or introducing an incorrect assumption.
18 | 4. Step Contradiction:
19 | Introduce a conflict between two or more steps in the reasoning process. For example, if the reasoning chain is $P = \{S_1, S_2, ..., S_n\}$, a contradiction exists if $S_i \perp S_j$ (where $i \neq j$).
20 | 5. Domain Inconsistency:
21 | Use a statement or theory that is valid in a different domain or context but is incorrect in the current reasoning chain.
22 | 6. Confident Hallucination:
23 | Introduce a false statement expressed in a confident tone, making it appear as ground truth while it contradicts known facts.
24 | 7. Missing Conditions or Prerequisites:
25 | Omit critical premises, assumptions, or conditions necessary for logical reasoning. This creates a logical gap or incomplete conclusion.
26 | 8. Deception or Traps:
27 | Modify a ground-truth statement slightly to make it appear correct while introducing an error. For example, subtly altering a formula or definition.
28 | Output Requirements
29 | After making the modifications, provide the following structured output:
30 |
31 | json
32 |
33 |
34 | {
35 | "origin_process": ["origin_step 1", "origin_step 2", ...],
36 | "modified_process": ["modified_step 1", "modified_step 2", ...],
37 | "modified_steps": [1, 5, 7, ...],
38 | "hallucination_steps": [5, 6, ...],
39 | "hallucination_types": [1, 2, ...],
40 | "reason": "Explanation for the changes."
41 | }
42 | Detailed Requirements:
43 | origin_process: A non-empty list of strings representing the original reasoning steps provided as input.
44 | modified_process: A non-empty list of strings representing the reasoning process after your modifications. Retain the original steps except for those you have changed.
45 | modified_steps: A non-empty list of integers indicating the indexes of all modified steps. Indexing starts at 1.
46 | hallucination_steps: A non-empty list of integers representing the steps that contain hallucinations or errors. These should also be part of modified_steps.
47 | hallucination_types: A list of integers corresponding to the types of hallucinations introduced (1 for redundancy, 2 for circular logic, etc.). Use the numbering provided in the error types section.
48 | reason: A clear explanation of the modifications made, why they were introduced, and how they align with the specified error types.
49 | Formatting Notes:
50 | Ensure all lists are non-empty.
51 | The numbering of steps in origin_process and modified_process must be consistent.
52 | Use LaTeX format for all mathematical symbols (e.g., $x^2$ for $x$ squared). Do not use Unicode symbols such as \u2248 or \u00f7.
53 | Ensure the JSON object is well-formed, with proper escaping for special characters like \n (e.g., use \\n for newlines).
54 | Example Task
55 | Input:
56 | Problem: Solve $x^2 + 2x + 1 = 0$.
57 | Correct Answer: $x = -1$.
58 | Correct Reasoning:
59 | Step 1: Recognize the equation as a perfect square trinomial.
60 | Step 2: Factorize it as $(x+1)^2 = 0$.
61 | Step 3: Solve for $x$, giving $x = -1$.
62 | Output Example:
63 | json
64 |
66 | {
67 | "origin_process": [
68 | "Step 1: Recognize the equation as a perfect square trinomial.",
69 | "Step 2: Factorize it as $(x+1)^2 = 0$.",
70 | "Step 3: Solve for $x$, giving $x = -1$."
71 | ],
72 | "modified_process": [
73 | "Step 1: Recognize the equation as a perfect square trinomial.",
74 | "Step 2: Factorize it as $(x+1)^2 = 0$.",
75 |     "Step 3: As an additional check, assume $x+1 = 1$, leading to $x = 0$.",
76 | "Step 4: Solve for $x$, giving $x = -1$."
77 | ],
78 | "modified_steps": [3],
79 | "hallucination_steps": [3],
80 | "hallucination_types": [1],
81 | "reason": "Step 3 introduces a redundant assumption ($x+1 = 1$) that is not necessary for solving the equation. This creates a hallucination of type 1 (redundancy)."
82 | }
83 | Notes
84 | Ensure the hallucinations you introduce are realistic and align with the error types described.
85 | The reasoning process should remain plausible and easy to follow, even with the introduced errors.
86 | Focus on creating subtle errors that test a model's ability to detect flaws effectively.
87 | By following this structure, you can simulate realistic flawed reasoning processes for testing purposes.
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/skywork_prm.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | import sys
5 |
6 | from typing import List, Optional, Tuple, Type, TypeVar, Union
7 |
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | import torch
11 | import torch.nn as nn
12 | from transformers.configuration_utils import PretrainedConfig
13 | from accelerate import Accelerator
14 |
15 |
16 | from .abstract_model import prm
17 | from ..utils.utils import *
18 | from ..utils.log_utils import *
19 | from ..utils.model_utils import remove_step_prefix
20 |
21 | logger = get_logger(__name__)
22 |
23 | try:
24 |     # Skywork-specific helpers from the skywork-o1-prm-inference repo.
25 |     # Please modify this path if you are evaluating Skywork PRM.
26 |     sys.path.append('/mnt/petrelfs/songmingyang/code/reasoning/MR_Hallucination/ref/skywork-o1-prm-inference')
27 |     from model_utils.prm_model import PRM_MODEL
28 |     from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
29 | except ImportError:
30 |     logger.error("Failed to import Skywork PRM model utils; please specify the path to skywork-o1-prm-inference.")
31 |
32 |
33 |
34 | class SkyworkPRM(prm):
35 | def __init__(
36 | self,
37 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/Skywork-o1-Open-PRM-Qwen-2.5-1.5B",
38 | step_tag = 'ки',
39 | validity_threshold = -0.05,
40 | ) -> None:
41 |
42 | super().__init__(validity_threshold=validity_threshold)
43 | self.step_tag = step_tag
44 |
45 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
46 | # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
47 | self.model = PRM_MODEL.from_pretrained(pretrained, device_map="cpu").eval()
48 | self.accelerator = Accelerator()
49 |         assert self.accelerator.state.distributed_type != 'DEEPSPEED', "DeepSpeed is not supported for Skywork PRM."
50 |
51 | def getitem_function(self,meta_data,index):
52 | data_idx = meta_data[index]["idx"]
53 | steps = meta_data[index]["steps"]
54 | question = meta_data[index]["question"]
55 |
56 | res = []
57 | for idx,step in enumerate(steps):
58 | clean_step = remove_step_prefix(step)
59 | res.append(f"Step {idx+1}: {clean_step.strip()} {self.step_tag}\n")
60 |
61 | steps_str = "".join(res)
62 | prm_input_data = dict(
63 | problem = question,
64 | response = steps_str,
65 | )
66 | input_ids, attention_mask, reward_flags = prepare_input(**prm_input_data,tokenizer=self.tokenizer,step_token=self.step_tag)
67 | input_ids, reward_flags = torch.LongTensor(input_ids), torch.LongTensor(reward_flags)
68 | while input_ids.ndim > 1:
69 | input_ids = input_ids[0]
70 | if input_ids.shape[0] > self.generation_config.max_length:
71 | input_ids = input_ids[:self.generation_config.max_length]
72 | if reward_flags.shape[0] > self.generation_config.max_length:
73 | reward_flags = reward_flags[:self.generation_config.max_length]
74 | res = dict(
75 | idx = data_idx,
76 | input_ids = input_ids,
77 | reward_flags = reward_flags,
78 | )
79 | return res
80 |
81 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
82 | dataloader = self.accelerator.prepare(dataloader)
83 | self.model = self.model.to(self.accelerator.device)
84 | self.accelerator.wait_for_everyone()
85 | self.model.eval()
86 | gen_kwargs = dataloader.dataset.gen_kwargs
87 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
88 | if len(dataloader) == 0:
89 | self.accelerator.wait_for_everyone()
90 | return
91 | with torch.no_grad():
92 | for batch in dataloader:
93 | idx = batch['idx']
94 | input_ids = batch['input_ids']
95 | attention_mask = batch['attention_mask']
96 | reward_flags = batch['reward_flags']
97 |
98 | _, _, rewards = self.model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)
99 | step_rewards = derive_step_rewards(rewards, reward_flags)
100 |
101 |
102 | for i in range(len(idx)):
103 | step_level_validity_scores = step_rewards[i]
104 |                     judge_label_scores = [step_level_validity_scores[0]] + [step_level_validity_scores[j] - step_level_validity_scores[j - 1] for j in range(1, len(step_level_validity_scores))]  # per-step delta of the cumulative rewards; the first step keeps its raw reward
105 | step_level_validity_labels = [item > self.validity_threshold for item in judge_label_scores]
106 |
107 | idx_item = idx[i]
108 | score_dict = dict(
109 | step_level_validity_labels = step_level_validity_labels,
110 | step_level_validity_scores = step_level_validity_scores,
111 | )
112 | res = dict(scores=score_dict, idx=idx_item)
113 | dataloader.dataset.store_results(res)
114 | if progress_bar is not None:
115 | progress_bar.update(1)
116 |
117 | self.accelerator.wait_for_everyone()
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/llemma7b_prm.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | from typing import List, Optional, Tuple, Type, TypeVar, Union
5 |
6 | from loguru import logger as eval_logger
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | import torch
10 | import torch.nn as nn
11 | from transformers.configuration_utils import PretrainedConfig
12 | from accelerate import Accelerator
13 |
14 |
15 | from .abstract_model import prm
16 | from ..utils.utils import *
17 |
18 | class Llemma7bPRM(prm):
19 | def __init__(
20 | self,
21 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/llemma-7b-prm-prm800k-level-1to3-hf",
22 | tokenizer_pretrained = "EleutherAI/llemma_7b",
23 | step_tag = "\n\n",
24 | validity_threshold = 0.5,
25 | ) -> None:
26 |
27 | super().__init__(validity_threshold=validity_threshold)
28 |
29 | self.step_tag = step_tag
30 |
31 | self.model = AutoModelForCausalLM.from_pretrained(
32 | pretrained,
33 | )
34 |
35 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_pretrained)
36 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
37 | self.accelerator = Accelerator()
38 |
39 |
40 |
41 |
42 | def getitem_function(self,meta_data,index):
43 | data_idx = meta_data[index]["idx"]
44 | steps = meta_data[index]["steps"]
45 | question = meta_data[index]["question"]
46 |
47 | res = []
48 | for idx,step in enumerate(steps):
49 | if "\n\n" in step:
50 | step = step.replace("\n\n", "")
51 | res.append(f"{step.strip()}{self.step_tag}")
52 | steps_str = "".join(res)
53 | original_input_for_prm = f"# Question\n\n{question}\n\n# Solution\n\n{steps_str}"
54 | begin_solution_tokens = self.tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
55 | scoring_tokens = self.tokenizer.encode("\n\n", add_special_tokens=False)[1:]
56 | eos_token = self.tokenizer.eos_token_id
57 |
58 | input_ids = self.tokenizer.encode(original_input_for_prm)
59 |
60 | begin_solution_flag = False
61 | candidate_positions = []
62 | for start_idx in range(len(input_ids)):
63 | if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
64 | begin_solution_flag = True
65 |
66 | if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
67 | candidate_positions.append(start_idx)
68 |
69 | if input_ids[start_idx] == eos_token:
70 | candidate_positions.append(start_idx)
71 | break
72 |
73 |         # Drop the first candidate position: it is the "\n\n" right after
74 |         # "# Solution" rather than a real step boundary.
75 | del candidate_positions[0]
76 |
77 | input_ids = torch.tensor(input_ids)
78 | candidate_positions = [i for i in candidate_positions if i < self.generation_config.max_length]
79 | candidate_positions = torch.tensor(candidate_positions)
80 | res = dict(
81 | idx = data_idx,
82 | input_ids = input_ids,
83 | candidate_positions = candidate_positions,
84 | )
85 | return res
86 |
87 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
88 |
89 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
90 | self.accelerator.wait_for_everyone()
91 | self.model.eval()
92 | gen_kwargs = dataloader.dataset.gen_kwargs
93 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
94 | if len(dataloader) == 0:
95 | self.accelerator.wait_for_everyone()
96 | return
97 | with torch.no_grad():
98 | for batch in dataloader:
99 | idx = batch['idx']
100 | input_ids = batch['input_ids']
101 | attention_mask = batch['attention_mask']
102 | candidate_positions = batch['candidate_positions']
103 |
104 | original_logits = self.model(
105 | input_ids=input_ids,
106 | attention_mask=attention_mask,
107 | ).logits
108 | scores = original_logits.mean(dim=-1)
109 |
110 | for i in range(len(idx)):
111 | current_candidate_positions = candidate_positions[i]
112 | current_score = scores[i][current_candidate_positions]
113 |
114 | original_step_scores = torch.sigmoid(current_score).tolist()
115 | step_level_validity_labels = [item > self.validity_threshold for item in original_step_scores]
116 | idx_item = idx[i]
117 | score_dict = dict(
118 | step_level_validity_labels = step_level_validity_labels,
119 | step_level_validity_scores = original_step_scores,
120 | )
121 | res = dict(scores=score_dict, idx=idx_item)
122 | dataloader.dataset.store_results(res)
123 | if progress_bar is not None:
124 | progress_bar.update(1)
125 |
126 | self.accelerator.wait_for_everyone()
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/ensemble_prm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import hashlib
3 | import json
4 | import logging
5 | import os
6 | from typing import List, Optional, Tuple, Type, TypeVar, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from accelerate import Accelerator
12 | from accelerate.logging import get_logger as get_accelerator_logger
13 | from active_prm.models import AutoModelForEnsemblePRM
14 | from loguru import logger as eval_logger
15 | from tqdm import tqdm
16 | from transformers import (AutoModel, AutoTokenizer, LlamaModel,
17 | LlamaPreTrainedModel, MistralModel,
18 | MistralPreTrainedModel)
19 | from transformers.configuration_utils import PretrainedConfig
20 |
21 | from ..utils.log_utils import get_logger
22 | from ..utils.model_utils import remove_step_prefix
23 | from ..utils.utils import *
24 | from .abstract_model import prm
25 |
26 | # accelerate_logger = logging.getLogger("debug")
27 | # accelerate_logger.setLevel(logging.DEBUG)
28 | logger = get_logger(__name__)
29 |
30 |
31 | class EnPRM(prm):
32 | def __init__(
33 | self,
34 | pretrained="/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/Qwen2.5-Math-PRM-7B",
35 | redundancy_threshold=0.15,
36 | validity_threshold=0.5,
37 | ) -> None:
38 | super().__init__(validity_threshold=validity_threshold, redundancy_threshold=redundancy_threshold)
39 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
40 | self.model = AutoModelForEnsemblePRM.from_pretrained(
41 | pretrained,
42 | torch_dtype=torch.bfloat16,
43 | trust_remote_code=True,
44 | ).eval()
45 |
46 | self.accelerator = Accelerator()
47 |
48 |         self.step_separator = "<extra_0>"  # Qwen2.5-Math-PRM step-separator special token
49 | self.step_separator_token_id = self.tokenizer.encode(self.step_separator)[0]
50 |
51 | def getitem_function(self, meta_data, index):
52 | data_idx = meta_data[index]["idx"]
53 | steps = meta_data[index]["steps"]
54 | question = meta_data[index]["question"]
55 |
56 | ## build model-specialized input
57 |         system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
58 | combined_steps = ""
59 | for step in steps:
60 | cleaned_step = remove_step_prefix(step)
61 | combined_steps += cleaned_step + self.step_separator
62 | messages = [
63 | {"role": "system", "content": system_prompt},
64 | {"role": "user", "content": question},
65 | {"role": "assistant", "content": combined_steps},
66 | ]
67 |
68 | conversation_str = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
69 |
70 | input_ids = self.tokenizer.encode(
71 | conversation_str,
72 | return_tensors="pt",
73 | )
74 | while input_ids.ndim > 1:
75 | input_ids = input_ids[0]
76 | res = dict(
77 | idx=data_idx,
78 | input_ids=input_ids,
79 | )
80 | return res
81 |
82 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
83 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
84 | self.accelerator.wait_for_everyone()
85 | self.model.eval()
86 | gen_kwargs = dataloader.dataset.gen_kwargs
87 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
88 | if len(dataloader) == 0:
89 | self.accelerator.wait_for_everyone()
90 | return
91 | with torch.no_grad():
92 | for batch_idx, batch in enumerate(dataloader):
93 | idx = batch["idx"]
94 | input_ids = batch["input_ids"]
95 | attention_mask = batch["attention_mask"]
96 | # print(f"data device: {input_ids.device}, current device: {self.accelerator.device}")
97 | scores = self.model.inference(
98 | input_ids,
99 | attention_mask,
100 | ).logits
101 | token_mask = input_ids == self.step_separator_token_id
102 | step_reward = make_step_rewards(scores, token_mask)
107 |
108 | for i in range(len(idx)):
109 | idx_item = idx[i]
110 | try:
111 | step_level_validity_scores = step_reward[i]
112 | score_dict = dict(
113 | step_level_validity_scores=step_level_validity_scores,
114 | step_level_validity_labels=[
115 | item > self.validity_threshold for item in step_level_validity_scores
116 | ],
117 | )
118 | res = dict(scores=score_dict, idx=idx_item)
119 |                     except Exception:
120 |                         logger.error(f"Error in processing idx: {idx[i]}")
121 |                         res = dict(scores=dict(), idx=idx_item, validity=False)
123 |
124 | dataloader.dataset.store_results(res)
125 | if progress_bar is not None:
126 | progress_bar.update(1)
127 |
128 | self.accelerator.wait_for_everyone()
129 |
130 |
131 | def make_step_rewards(logits, token_masks):
132 | all_scores_res = []
133 |     for i in range(logits.size(1)):  # iterate over the batch dimension of (num_ensemble, batch, seq_len) logits
134 |         _logits = logits[:, i, token_masks[i]]  # scores at the step-separator positions, one row per ensemble member
135 |         _logits = _logits.mean(dim=0)  # average across ensemble members
136 | all_scores_res.append(_logits.cpu().tolist())
137 | return all_scores_res
138 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/utils/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch.distributed as dist
5 | import yaml
6 | from tqdm import tqdm
7 |
8 |
9 | def load_json_file(filepath):
10 | """
11 |     Read a JSON file into a list or dict.
12 | """
13 | with open(filepath, "r", encoding="UTF-8") as file:
14 | data = json.load(file)
15 | return data
16 |
17 |
18 | def write_json_file(data, filepath):
19 | os.makedirs(os.path.dirname(filepath), exist_ok=True)
20 | with open(filepath, "w", encoding="UTF-8") as f:
21 | json.dump(data, f, ensure_ascii=False, indent=4)
22 |
23 |
24 | def process_jsonl(file_path):
25 | """
26 |     Convert a JSONL file into a list of dicts.
27 | """
28 | data = []
29 | with open(file_path, "r", encoding="utf-8") as file:
30 | for line in file:
31 | json_obj = json.loads(line)
32 | data.append(json_obj)
33 | return data
34 |
35 |
36 | def write_jsonl(data, file_path):
37 | """
38 |     Write a list[dict] to a JSONL file.
39 | """
40 | os.makedirs(os.path.dirname(file_path), exist_ok=True)
41 | with open(file_path, "w", encoding="utf-8") as file:
42 | for item in data:
43 | line = json.dumps(item, ensure_ascii=False)
44 | file.write(line + "\n")
45 |
46 |
47 | def merge_jsonl(input_file_dir, output_filepath):
48 | """
49 |     Merge all JSONL files in the input directory into one JSONL file saved at output_filepath.
50 | """
51 | filepaths = [os.path.join(input_file_dir, file) for file in os.listdir(input_file_dir)]
52 | merged_data = []
53 | for filepath in filepaths:
54 | with open(filepath, "r") as file:
55 | for line in file:
56 | data = json.loads(line)
57 | merged_data.append(data)
58 |
59 | with open(output_filepath, "w") as output_file:
60 | for data in merged_data:
61 | output_file.write(json.dumps(data) + "\n")
62 |
63 |
64 | def append_jsonl(data, filename):
65 | os.makedirs(os.path.dirname(filename), exist_ok=True)
66 | with open(filename, "a", encoding="utf-8") as f:
67 | json.dump(data, f)
68 | f.write("\n")
69 |
70 |
71 | def load_txt_file(filepath):
72 | with open(filepath, "r", encoding="utf-8") as f:
73 | data = f.readlines()
74 | data = [line.strip().replace("\n", "") for line in data]
75 | return data
76 |
77 |
78 | def write_txt_file(data, filepath):
79 | os.makedirs(os.path.dirname(filepath), exist_ok=True)
80 |     with open(filepath, "a", encoding="utf-8") as f:
81 |         for item in data:
82 |             f.write(item + "\n")
83 |
84 |
85 | def print_rank0(msg):
86 | if dist.is_available() and dist.is_initialized():
87 | if dist.get_rank() == 0:
88 | print(msg)
89 | else:
90 | print(msg)
91 |
92 |
93 | def str2list(input_str):
94 | if isinstance(input_str, str):
95 | raw_list = input_str.strip().replace("\n", "").split(",")
96 | new_list = []
97 | for item in raw_list:
98 | new_list.append(item.strip())
99 | return new_list
100 | elif isinstance(input_str, list):
101 | return input_str
102 | else:
103 | raise TypeError("input_str should be str or list")
104 |
105 |
106 | def get_two_words(word1, word2):
107 | if word1 < word2:
108 | return f"{word1},{word2}"
109 | else:
110 | return f"{word2},{word1}"
111 |
112 |
113 | def load_yaml_file(filepath):
114 | with open(filepath, "r", encoding="UTF-8") as file:
115 | data = yaml.safe_load(file)
116 | return data
117 |
118 |
119 | def write_yaml_file(data, filepath):
120 | os.makedirs(os.path.dirname(filepath), exist_ok=True)
121 | with open(filepath, "w", encoding="UTF-8") as file:
122 | yaml.dump(data, file, indent=4)
123 |
124 |
125 | def tqdm_rank0(total, desc):
126 | if dist.is_available() and dist.is_initialized():
127 | if dist.get_rank() == 0:
128 | pbar = tqdm(total=total, desc=desc)
129 | return pbar
130 | else:
131 | return None
132 | else:
133 | pbar = tqdm(total=total, desc=desc)
134 | return pbar
135 |
136 |
137 | def is_main_process():
138 | if dist.is_available() and dist.is_initialized():
139 | return dist.get_rank() == 0
140 | else:
141 | return True
142 |
143 |
144 | def dist_wait_for_everyone():
145 | if dist.is_available() and dist.is_initialized():
146 | dist.barrier()
147 |
148 |
149 | def gather_dict_lists(local_dict_list):
150 | """
151 |     Gather every process's data with all_gather_object.
152 | """
153 | if dist.is_available() and dist.is_initialized():
154 |         # Get the total number of processes
155 | world_size = dist.get_world_size()
156 |
157 |         # Prepare a receive list with one None slot per process
158 | gathered_dict_lists = [None for _ in range(world_size)]
159 |
160 |         # Gather data from every process
161 | dist.all_gather_object(gathered_dict_lists, local_dict_list)
162 |
163 |         # Flatten the per-process lists into one merged list
164 | final_merged_list = [item for sublist in gathered_dict_lists for item in sublist]
165 | return final_merged_list
166 | else:
167 | return local_dict_list
168 |
169 |
170 | def setup_proxy():
171 | AD_NAME = "songmingyang"
172 | encrypted_password = "dSpydxsxxhKix63HfIFhjwnZLEInXEDawSoMD35G1IT2CygKnHsJqG9ZHbEP"
173 | new_proxy_address = f"http://{AD_NAME}:{encrypted_password}@10.1.20.50:23128/"
174 |     # Set proxy environment variables
175 | os.environ["http_proxy"] = new_proxy_address
176 | os.environ["https_proxy"] = new_proxy_address
177 | os.environ["HTTP_PROXY"] = new_proxy_address
178 | os.environ["HTTPS_PROXY"] = new_proxy_address
179 |
180 |
181 | def setup_openai_proxy():
182 | new_proxy_address = "http://closeai-proxy.pjlab.org.cn:23128"
183 |     # Set proxy environment variables
184 | os.environ["http_proxy"] = new_proxy_address
185 | os.environ["https_proxy"] = new_proxy_address
186 | os.environ["HTTP_PROXY"] = new_proxy_address
187 | os.environ["HTTPS_PROXY"] = new_proxy_address
188 |
189 |
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_eval/models/qwen_prm.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import hashlib
3 | import json
4 | import os
5 | import logging
6 | from typing import List, Optional, Tuple, Type, TypeVar, Union
7 |
8 | from loguru import logger as eval_logger
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer, AutoModel, LlamaModel, LlamaPreTrainedModel, MistralModel, MistralPreTrainedModel
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from transformers.configuration_utils import PretrainedConfig
15 | from accelerate import Accelerator
16 |
17 | from .abstract_model import prm
18 | from ..utils.utils import *
19 | from ..utils.log_utils import get_logger
20 | from ..utils.model_utils import remove_step_prefix
21 | from accelerate.logging import get_logger as get_accelerator_logger
22 |
23 | # accelerate_logger = logging.getLogger("debug")
24 | # accelerate_logger.setLevel(logging.DEBUG)
25 | logger = get_logger(__name__)
26 | class QwenPRM(prm):
27 | def __init__(
28 | self,
29 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/Qwen2.5-Math-PRM-7B",
30 | redundancy_threshold = 0.15,
31 | validity_threshold = 0.5,
32 | ) -> None:
33 | super().__init__(validity_threshold=validity_threshold, redundancy_threshold=redundancy_threshold)
34 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
35 | self.model = AutoModel.from_pretrained(
36 | pretrained,
37 | torch_dtype=torch.bfloat16,
38 | trust_remote_code=True,
39 | ).eval()
40 |
41 | self.accelerator = Accelerator()
42 |
43 |
44 |         self.step_separator = "<extra_0>"  # Qwen2.5-Math-PRM step-separator special token
45 | self.step_separator_token_id = self.tokenizer.encode(self.step_separator)[0]
46 |
47 | def getitem_function(self,meta_data,index):
48 | data_idx = meta_data[index]["idx"]
49 | steps = meta_data[index]["steps"]
50 | question = meta_data[index]["question"]
51 |
52 |
53 |
54 | ## build model-specialized input
55 |         system_prompt = "Please reason step by step, and put your final answer within \\boxed{}."
56 | combined_steps = ""
57 | for step in steps:
58 | cleaned_step = remove_step_prefix(step)
59 | combined_steps += cleaned_step + self.step_separator
60 | messages = [
61 | {"role": "system", "content": system_prompt},
62 | {"role": "user", "content": question},
63 | {"role": "assistant", "content": combined_steps},
64 | ]
65 |
66 | conversation_str = self.tokenizer.apply_chat_template(
67 | messages,
68 | tokenize=False,
69 | add_generation_prompt=False
70 | )
71 |
72 | input_ids = self.tokenizer.encode(
73 | conversation_str,
74 | return_tensors="pt",
75 | )
76 | while input_ids.ndim > 1:
77 | input_ids = input_ids[0]
78 | res = dict(
79 | idx = data_idx,
80 | input_ids = input_ids,
81 | )
82 | return res
83 |
84 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
85 | self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
86 | self.accelerator.wait_for_everyone()
87 | self.model.eval()
88 | gen_kwargs = dataloader.dataset.gen_kwargs
89 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
90 | if len(dataloader) == 0:
91 | self.accelerator.wait_for_everyone()
92 | return
93 | with torch.no_grad():
94 | for batch_idx, batch in enumerate(dataloader):
95 |
96 | idx = batch['idx']
97 | input_ids = batch['input_ids']
98 | attention_mask = batch['attention_mask']
99 | # print(f"data device: {input_ids.device}, current device: {self.accelerator.device}")
100 |                 scores = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
102 | token_mask = input_ids == self.step_separator_token_id
103 | step_reward = make_step_rewards(scores, token_mask)
109 |
110 | for i in range(len(idx)):
111 | idx_item = idx[i]
112 | try:
113 | step_level_validity_scores = step_reward[i]
114 | score_dict = dict(
115 | step_level_validity_scores=step_level_validity_scores,
116 | step_level_validity_labels=[item > self.validity_threshold for item in step_level_validity_scores],
117 | )
118 | res = dict(scores=score_dict, idx=idx_item)
119 | except:
120 | logger.error(f"Error in processing idx: {idx[i]}")
121 | res = dict(scores=dict(), idx=idx_item,validity=False)
122 |
123 | dataloader.dataset.store_results(res)
124 | if progress_bar is not None:
125 | progress_bar.update(1)
126 |
127 | self.accelerator.wait_for_everyone()
128 |
129 |
130 | def make_step_rewards(logits, token_masks):
131 | probabilities = F.softmax(logits, dim=-1)
132 | probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels
133 |
134 | all_scores_res = []
135 | for i in range(probabilities.size(0)):
136 | sample = probabilities[i] # seq_len, num_labels
137 |         positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # probability of the "valid" label at each step-separator position
138 | non_zero_elements_list = positive_probs.cpu().tolist()
139 | all_scores_res.append(non_zero_elements_list)
140 | return all_scores_res
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/confidence.py:
--------------------------------------------------------------------------------
1 | confidence_fewshot_q1="""
2 | Question:
3 |
4 | If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
5 |
6 | Answer:
7 |
8 | Step 1. Let's start with the first equation and see if we can solve for x.
9 |
10 | Step 2. We can use the quadratic formula to find x.
11 |
12 | Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
13 |
14 | Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
15 |
16 | Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
17 |
18 | Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
19 |
20 | Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
21 |
22 | Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
23 |
24 | Step 9. Simplify: $x^3 - 2x + 1 = 2$.
25 |
26 | # Answer
27 |
28 | 2
29 | """
30 | confidence_fewshot_a1="""
31 | {
32 | "original_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
33 | "modified_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
34 | "original_process": [
35 | "Let's start with the first equation and see if we can solve for x.",
36 | "We can use the quadratic formula to find x.",
37 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
38 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
39 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
40 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
41 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
42 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
43 | "Simplify: $x^3 - 2x + 1 = 2$."
44 | ],
45 | "modified_process": [
46 | "Let's start with the first equation and see if we can solve for x.",
47 | "We can use the quadratic formula to find x.",
48 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
49 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
50 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
51 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
52 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
53 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
54 | "Simplify: $x^3 - 2x + 1 = 3$. Therefore, the final value is 3."
55 | ],
56 | "modified_steps": [9],
57 | "error_steps": [9],
58 | "reason": "Step 9 introduces a confident hallucination by asserting that $x^3 - 2x + 1 = 3$. This contradicts the correct simplification result of $2$ but is stated with undue certainty, simulating overconfidence in an incorrect step."
59 | }
60 | """
61 |
62 | confidence_fewshot_q2="""
63 | Question:
64 |
65 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
66 |
67 | Answer:
68 |
69 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
70 |
71 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
72 |
73 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
74 |
75 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
76 |
77 | Step 5. So x=$1250$.
78 |
79 | Step 6. That's the final answer.
80 |
81 | Step 7. Right. So 1,000,000 lire is equivalent to $\$1250$.
82 |
83 | # Answer
84 |
85 | 1250
86 | """
87 |
88 |
89 | confidence_fewshot_a2="""
90 | {
91 | "original_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
92 | "modified_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
93 | "original_process": [
94 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
95 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
96 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
97 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
98 | "So x=$1250$.",
99 | "That's the final answer.",
100 | "Right. So 1,000,000 lire is equivalent to $1250."
101 | ],
102 | "modified_process": [
103 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
104 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
105 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
106 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
107 | "So x=$1250$.",
108 | "That's the final answer.",
109 | "Actually, after a closer look, $x=0.00125 \\times 1000000$ should be $12500$. I must have miscalculated earlier.",
110 | "So 1,000,000 lire is equivalent to $12500."
111 | ],
112 | "modified_steps": [6, 7],
113 | "error_steps": [7],
114 | "reason": "Step 7 introduces a confident hallucination. It contradicts the correct value of $1250$ by recalculating and asserting that the result is $12500$, which is false. This error is presented with unwarranted certainty to simulate confident misjudgment."
115 | }
116 | """
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/docs/document.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Documentation for PRM Eval ToolKit
5 |
6 | ## 1. Config File Formatting
7 |
8 | **We released some example scripts/configs to demonstrate how to use our toolkit. You can find them in the `mr_eval/scripts` directory.**
9 |
10 |
11 | You can organize your config as a list of dicts or a single dict. It is recommended to use a YAML file.
12 | An example config looks like this:
13 | ```yaml
14 | - model_args:
15 |   # The model series you want to test; must match the file name under mr_eval/models.
16 |   model: reasoneval
17 |   # The arguments you want to pass to the model, separated by commas.
18 | model_args: pretrained=GAIR/ReasonEval-7B,model_size=7B,redundancy_threshold=0.15
19 | # The batch size if you want to use batch inference.
20 | batch_size: 2
21 | task_args:
22 |   # The task names you want to evaluate, separated by commas.
23 | task_name: prmtest_classified
24 |   # Checkpoint settings, organized as a dict of
25 |   # task_name: ckpt_file_path
26 | resume_from_ckpt:
27 | prmtest_classified: ./logs/generated/ckpt/reasoneval7b_prm800k_classified.jsonl
28 | save_to_ckpt:
29 | prmtest_classified: ./logs/generated/ckpt/reasoneval7b_prm800k_classified.jsonl
30 | script_args:
31 | verbosity: INFO
32 | # final result output path
33 | output_path: ./logs/prmtest_classified/reasoneval7b.jsonl
34 | ```
35 | After setting up the config, run `PRMEval` as follows:
36 |
37 | ```bash
38 | accelerate launch --config_file ${accelerate_config} \
39 | -m mr_eval \
40 | --config ${config_file}
41 | ```
42 |
43 | Our batch inference and multi-GPU parallelism are built on Hugging Face Accelerate, so please prepare an Accelerate config and launch with it.
44 | An example Accelerate config is:
45 |
46 | ```yaml
47 | compute_environment: LOCAL_MACHINE
48 | distributed_type: MULTI_GPU
49 | downcast_bf16: 'no'
50 | gpu_ids: all
51 | machine_rank: 0
52 | main_training_function: main
53 | mixed_precision: bf16
54 | num_machines: 1
55 | num_processes: 8
56 | rdzv_backend: static
57 | same_network: true
58 | tpu_env: []
59 | tpu_use_cluster: false
60 | tpu_use_sudo: false
61 | use_cpu: false
62 | ```
63 |
64 | Note that when testing API models (e.g., Gemini and OpenAI series models), the batch size must be set to 1 and multi-process parallelism must not be used.
65 |
66 |
67 | ## 2. Introduction to our basic framework
68 |
69 | Our `PRMEval` framework consists of two key concepts: `task` and `model`. You can add custom tasks or models to build your own evaluation pipeline.
70 | The tasks and models are connected through a PyTorch dataset, whose basic implementation can be found at `mr_eval/tasks/base_dataset`. The **data loading logic** (`load_data_function()`) and **evaluation logic** (`evaluate_function()`) are implemented by the `task`, while the **get-data-instance logic** (`getitem_function(self, meta_data, index)`) is implemented by the `model`.
71 |
72 | Evaluation results are staged in `base_dataset`: call `dataloader.dataset.store_results(res)` to store results temporarily. After the whole run finishes, the evaluator calls `evaluate_function()` to compute the final results.
73 |
74 | ## 3. How to add a new model?
75 |
76 | 1. Implement your model inference script under `mr_eval/models`
77 |
78 | 2. Wrap your model inference code with base class `prm`
79 |
80 | 3. Implement the `getitem_function` and `respond` function in your model class
81 |
82 | 4. Register your model in `AVAILABLE_MODELS` in `mr_eval/models/__init__.py`. The key should match both the `model` field in your config file and the name of your implementation Python file; the value should be the class name of your model.
83 |
84 | Notes:
85 |
86 | 1. When implementing the `getitem_function`, return a dict whose keys and values you decide yourself; design its structure around your model's `respond` logic.
87 |
88 | 2. Generally speaking, the temporary results are formatted as:
89 | ```python
90 | score_dict = dict(step_level_redundancy_scores=step_level_redundancy_scores,
91 |                   step_level_validity_scores=step_level_validity_scores,
92 |                   step_level_redundancy_labels=[item > self.redundancy_threshold for item in step_level_redundancy_scores],
93 |                   step_level_validity_labels=[item > self.validity_threshold for item in step_level_validity_scores],
94 |                   solution_level_redundancy_scores=max(step_level_redundancy_scores),
95 |                   solution_level_validity_scores=min(step_level_validity_scores),
96 |                   solution_level_redundancy_labels=max(step_level_redundancy_scores) > self.redundancy_threshold,
97 |                   solution_level_validity_labels=min(step_level_validity_scores) > self.validity_threshold,
98 |                   )
99 | res = dict(scores=score_dict, idx=idx_item)
100 | ```
101 | This format is a contract between model and task, which means you can customize it. For example, PRMBench only uses `step_level_validity_scores`, `step_level_redundancy_scores`, `step_level_redundancy_labels`, and `step_level_validity_labels` to evaluate the model.
102 |
103 | 3. When one round of batch inference finishes, call `dataloader.dataset.store_results(res)` to store the results in the dataset sequentially.
104 |
105 | 4. An example template is shown below (a fuller `respond` sketch appears in Section 5):
106 |
107 | ```python
108 | from typing import Dict
109 | from .abstract_model import prm
110 |
111 | class GeminiModels(prm):
112 |     def __init__(self, model_name="your-model-name", validity_threshold=0.5, redundancy_threshold=0.15) -> None:
113 |         super().__init__(validity_threshold=validity_threshold, redundancy_threshold=redundancy_threshold)
114 |         # your initialization code
115 |
116 |     def getitem_function(self, meta_data, index) -> Dict:
117 |         pass
118 |
119 |     def respond(self, dataloader) -> None:
120 |         pass
121 |
122 |
123 |
124 | ```
125 | ## 4. How to add a new task?
126 |
127 | 1. Implement your task under `mr_eval/tasks/your-task-name`
128 | 2. Implement your `mr_eval/tasks/your-task-name/config.yaml` and `mr_eval/tasks/your-task-name/task.py`
129 | 3. In `task.py`, implement the `load_data_function` and `evaluate_function` (see the sketch in Section 5)
130 | 4. There is no need to register your task, but make sure the `task_name` in your config file matches the folder name of your task, i.e. `your-task-name` in this demo.
131 |
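132 | ## 5. Example Sketches
133 |
134 | For reference, below is a condensed sketch of the `respond` method pattern used by the bundled models (see `mr_eval/models/qwen_prm.py`). It is a sketch rather than a drop-in implementation: the `score_batch` helper is hypothetical and stands in for your model's forward pass, and the batch keys depend on your `getitem_function`.
135 |
136 | ```python
137 | import torch
138 |
139 | def respond(self, dataloader) -> None:
140 |     # Prepare the model and dataloader for (multi-GPU) inference via Accelerate.
141 |     self.model, dataloader = self.accelerator.prepare(self.model, dataloader)
142 |     self.model.eval()
143 |     with torch.no_grad():
144 |         for batch in dataloader:
145 |             # Hypothetical helper: returns one list of step scores per sample.
146 |             scores = self.score_batch(batch)
147 |             for i in range(len(batch["idx"])):
148 |                 score_dict = dict(
149 |                     step_level_validity_scores=scores[i],
150 |                     step_level_validity_labels=[s > self.validity_threshold for s in scores[i]],
151 |                 )
152 |                 # Stage the result; the task's evaluate_function() consumes it after the run.
153 |                 dataloader.dataset.store_results(dict(scores=score_dict, idx=batch["idx"][i]))
154 |     self.accelerator.wait_for_everyone()
155 | ```
156 |
157 | Similarly, a minimal `task.py` sketch for Section 4 might look like the following. The data path, the field names, and the toy metric are placeholders; follow the bundled tasks under `mr_eval/tasks` for the exact function signatures the framework expects.
158 |
159 | ```python
160 | from ...utils.utils import process_jsonl  # JSONL -> list[dict] helper from mr_eval/utils/utils.py
161 |
162 | def load_data_function():
163 |     # Each item should carry the fields your model's getitem_function reads,
164 |     # e.g. "idx", "question", and "steps". The path below is a placeholder.
165 |     return process_jsonl("./data/your_task_data.jsonl")
166 |
167 | def evaluate_function(results, meta_data):
168 |     # `results` holds every dict staged via dataloader.dataset.store_results(res).
169 |     correct = 0
170 |     for res in results:
171 |         labels = res.get("scores", {}).get("step_level_validity_labels", [])
172 |         correct += int(bool(labels) and all(labels))  # toy metric: every step judged valid
173 |     return {"full_validity_rate": correct / max(len(results), 1)}
174 | ```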
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/step_contradiction.py:
--------------------------------------------------------------------------------
1 | step_contradiction_fewshot_q1="""
2 | Question:
3 |
4 | If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
5 |
6 | Answer:
7 |
8 | Step 1. Let's start with the first equation and see if we can solve for x.
9 |
10 | Step 2. We can use the quadratic formula to find x.
11 |
12 | Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
13 |
14 | Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
15 |
16 | Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
17 |
18 | Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
19 |
20 | Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
21 |
22 | Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
23 |
24 | Step 9. Simplify: $x^3 - 2x + 1 = 2$.
25 |
26 | # Answer
27 |
28 | 2
29 | """
30 | step_contradiction_fewshot_a1="""
31 | {
32 | "original_question": "If x^2 - x - 1 = 0, what is the value of x^3 - 2x + 1?",
33 | "modified_question": "If x^2 - x - 1 = 0, what is the value of x^3 - 2x + 1?",
34 | "original_process": [
35 | "Let's start with the first equation and see if we can solve for x.",
36 | "We can use the quadratic formula to find x.",
37 | "The quadratic formula is x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}.",
38 | "Substituting a = 1, b = -1, and c = -1, we get x = \\frac{1 \\pm \\sqrt{5}}{2}.",
39 | "Since x^2 - x - 1 = 0, we replace x^2 with x + 1 in x^3 - 2x + 1.",
40 | "Substituting, x^3 - 2x + 1 = x(x + 1) - 2x + 1.",
41 | "Simplify: x^3 - 2x + 1 = (x^2 + x) - 2x + 1.",
42 | "Substitute x^2 = x + 1: x^3 - 2x + 1 = (x + 1 + x) - 2x + 1.",
43 | "Simplify: x^3 - 2x + 1 = 2."
44 | ],
45 | "modified_process": [
46 | "Let's start with the first equation and see if we can solve for x.",
47 | "We can use the quadratic formula to find x.",
48 | "The quadratic formula is x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}.",
49 | "Substituting a = 1, b = -1, and c = -1, we get x = \\frac{1 \\pm \\sqrt{5}}{2}.",
50 | "Since x^2 - x - 1 = 0, we replace x^2 with x + 1 in x^3 - 2x + 1.",
51 | "Substituting, x^3 - 2x + 1 = x(x + 1) - 2x + 1.",
52 | "Simplify: x^3 - 2x + 1 = (x^2 + x) - 2x + 1.",
53 | "Incorrectly substitute x^2 = x - 1 instead of x + 1: x^3 - 2x + 1 = (x - 1 + x) - 2x + 1.",
54 | "Simplify: x^3 - 2x + 1 = 0."
55 | ],
56 | "modified_steps": [8, 9],
57 | "error_steps": [8],
58 | "reason": "Step 8 introduces a step contradiction by incorrectly substituting x^2 = x - 1 instead of the correct x^2 = x + 1. This conflicts with the earlier derivation from x^2 - x - 1 = 0, where x^2 = x + 1. As a result, the final answer becomes 0 instead of the correct value, 2."
59 | }
60 | """
61 |
62 | step_contradiction_fewshot_q2="""
63 | Question:
64 |
65 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
66 |
67 | Answer:
68 |
69 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
70 |
71 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
72 |
73 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
74 |
75 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
76 |
77 | Step 5. So x=$1250$.
78 |
79 | Step 6. That's the final answer.
80 |
81 | Step 7. Right. So 1,000,000 lire is equivalent to $\$1250$.
82 |
83 | # Answer
84 |
85 | 1250
86 | """
87 |
88 |
89 | step_contradiction_fewshot_a2="""
90 | {
91 | "original_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
92 | "modified_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
93 | "original_process": [
94 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*\\frac{1.50}{1200}.",
95 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
96 | "I think the second way will be easier. \\frac{1.50}{1200}=0.00125.",
97 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
98 | "So x=1250.",
99 | "That's the final answer.",
100 | "Right. So 1,000,000 lire is equivalent to $1250."
101 | ],
102 | "modified_process": [
103 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*\\frac{1.50}{1200}.",
104 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
105 | "I think the second way will be easier. \\frac{1.50}{1200}=0.00125.",
106 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
107 | "Instead of solving this directly, incorrectly calculate \\frac{1,000,000}{1200}=833.33.",
108 | "Then multiply the result by 1.50: 833.33*1.50=1250.",
109 | "That's the final answer.",
110 | "Right. So 1,000,000 lire is equivalent to $1250."
111 | ],
112 | "modified_steps": [5, 6],
113 | "error_steps": [5],
114 | "reason": "Step 5 introduces a contradiction by switching from the correct direct multiplication approach to an incorrect intermediate calculation (dividing by 1200 instead of multiplying by 0.00125). This causes logical inconsistency within the process, though the final result coincidentally matches the correct answer."
115 | }
116 | """
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/counterfactual.py:
--------------------------------------------------------------------------------
1 | counterfactual_fewshot_q1="""
2 | Question:
3 |
4 | If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
5 |
6 | Answer:
7 |
8 | Step 1. Let's start with the first equation and see if we can solve for x.
9 |
10 | Step 2. We can use the quadratic formula to find x.
11 |
12 | Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
13 |
14 | Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
15 |
16 | Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
17 |
18 | Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
19 |
20 | Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
21 |
22 | Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
23 |
24 | Step 9. Simplify: $x^3 - 2x + 1 = 2$.
25 |
26 | # Answer
27 |
28 | 2
29 | """
30 | counterfactual_fewshot_a1="""
31 | {
32 | "original_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
33 | "modified_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
34 | "original_process": [
35 | "Let's start with the first equation and see if we can solve for x.",
36 | "We can use the quadratic formula to find x.",
37 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
38 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
39 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
40 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
41 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
42 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
43 | "Simplify: $x^3 - 2x + 1 = 2$."
44 | ],
45 | "modified_process": [
46 | "Let's start with the first equation and see if we can solve for x.",
47 | "We can use the quadratic formula to find x.",
48 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
49 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
50 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x - 1$ in $x^3 - 2x + 1$.",
51 | "Substituting, $x^3 - 2x + 1 = x(x - 1) - 2x + 1$.",
52 | "Simplify: $x^3 - 2x + 1 = (x^2 - x) - 2x + 1$.",
53 | "Substitute $x^2 = x - 1$: $x^3 - 2x + 1 = (x - 1 - x) - 2x + 1$.",
54 |     "Simplify: $x^3 - 2x + 1 = -2x$."
55 | ],
56 | "modified_steps": [5, 6, 7, 8, 9],
57 | "error_steps": [5, 6, 8],
58 | "reason": "In step 5, a counterfactual error was introduced by incorrectly substituting $x^2$ with $x - 1$ instead of $x + 1$. This incorrect substitution was propagated in step 6, leading to further incorrect computations. In step 8, the final simplification used the flawed substitution, resulting in an incorrect answer."
59 | }
60 | """
61 |
62 | counterfactual_fewshot_q2="""
63 | Question:
64 |
65 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
66 |
67 | Answer:
68 |
69 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
70 |
71 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
72 |
73 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
74 |
75 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
76 |
77 | Step 5. So x=$1250$.
78 |
79 | Step 6. That's the final answer.
80 |
81 | Step 7. Right. So 1,000,000 lire is equivalent to $\$1250$.
82 |
83 | # Answer
84 |
85 | 1250
86 | """
87 |
88 |
89 | counterfactual_fewshot_a2="""
90 | {
91 | "original_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
92 | "modified_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
93 | "original_process": [
94 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
95 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
96 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
97 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
98 | "So x=$1250$.",
99 | "That's the final answer.",
100 | "Right. So 1,000,000 lire is equivalent to $1250."
101 | ],
102 | "modified_process": [
103 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
104 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
105 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
106 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
107 | "So x=$1250$.",
108 | "But wait, we need to account for transaction fees that add an additional 10%.",
109 | "So x=$1250 + 0.10 * 1250$.",
110 | "That gives x=$1375$.",
111 | "Right. So 1,000,000 lire is equivalent to $1375."
112 | ],
113 | "modified_steps": [6, 7, 8, 9],
114 | "error_steps": [6, 7, 8, 9],
115 | "reason": "A counterfactual step was introduced in step 6 by claiming the need to account for an additional 10% transaction fee, which is not mentioned in the problem. This erroneous assumption leads to the incorrect result of $1375 instead of $1250."
116 | }
117 | """
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/redundency.py:
--------------------------------------------------------------------------------
1 | redundency_fewshot_q1="""
2 | Question:
3 |
4 | If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
5 |
6 | Answer:
7 |
8 | Step 1. Let's start with the first equation and see if we can solve for x.
9 |
10 | Step 2. We can use the quadratic formula to find x.
11 |
12 | Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
13 |
14 | Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
15 |
16 | Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
17 |
18 | Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
19 |
20 | Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
21 |
22 | Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
23 |
24 | Step 9. Simplify: $x^3 - 2x + 1 = 2$.
25 |
26 | # Answer
27 |
28 | 2
29 | """
30 | redundency_fewshot_a1="""
31 | {
32 | "original_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
33 | "modified_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
34 | "original_process": [
35 | "Let's start with the first equation and see if we can solve for x.",
36 | "We can use the quadratic formula to find x.",
37 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
38 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
39 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
40 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
41 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
42 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
43 | "Simplify: $x^3 - 2x + 1 = 2$."
44 | ],
45 | "modified_process": [
46 | "Let's start with the first equation and see if we can solve for x.",
47 | "We can use the quadratic formula to find x.",
48 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
49 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
50 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
51 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
52 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
53 | "Now let’s verify that $x^2 = x + 1$ by re-solving $x^2 - x - 1 = 0$ again.",
54 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
55 | "Simplify: $x^3 - 2x + 1 = 2$."
56 | ],
57 | "modified_steps": [8],
58 | "error_steps": [8],
59 |     "reason": "Step 8 re-solves $x^2 - x - 1 = 0$ unnecessarily, introducing redundancy. This step does not provide new information, as $x^2 = x + 1$ was already established in Step 5. The reasoning is still correct but less efficient."
60 | }
61 | """
62 |
63 | redundency_fewshot_q2="""
64 | Question:
65 |
66 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
67 |
68 | Answer:
69 |
70 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
71 |
72 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
73 |
74 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
75 |
76 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
77 |
78 | Step 5. So x=$1250$.
79 |
80 | Step 6. That's the final answer.
81 |
82 | Step 7. Right. So 1,000,000 lire is equivalent to $\$1250$.
83 |
84 | # Answer
85 |
86 | 1250
87 | """
88 |
89 |
90 | redundency_fewshot_a2="""
91 | {
92 | "original_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
93 | "modified_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
94 | "original_process": [
95 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
96 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
97 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
98 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
99 | "So x=$1250$.",
100 | "That's the final answer.",
101 | "Right. So 1,000,000 lire is equivalent to $\\$1250$."
102 | ],
103 | "modified_process": [
104 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
105 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
106 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
107 | "That's true. But let's confirm the computation again: $\\frac{1.50}{1200}=0.00125$.",
108 | "Now we just have to solve the equation x=1000000*0.00125.",
109 | "So x=$1250$.",
110 | "Let's double-check by performing the multiplication again: 1000000*0.00125=$1250$.",
111 | "That's the final answer.",
112 | "Right. So 1,000,000 lire is equivalent to $\\$1250$."
113 | ],
114 | "modified_steps": [4, 7, 8],
115 | "error_steps": [4, 7, 8],
116 | "reason": "Step 4 unnecessarily repeats the computation of $\\frac{1.50}{1200}$, which was already completed in Step 7 redundantly re-checks the multiplication, adding no new information. These redundant steps make the reasoning process less concise without affecting correctness."
117 | }
118 | """
--------------------------------------------------------------------------------
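The two question/answer string pairs above are consumed as few-shot demonstrations. A minimal sketch of how they would be assembled into chat turns; the actual pairing happens in `classifications/__init__.py` via `fewshot_dicts` (not shown here), so this structure and the system-prompt placeholder are assumptions for illustration:

    # Hypothetical assembly of the few-shot strings defined above into chat messages.
    fewshots = [
        (redundency_fewshot_q1, redundency_fewshot_a1),
        (redundency_fewshot_q2, redundency_fewshot_a2),
    ]
    messages = [{"role": "system", "content": "<redundancy system prompt>"}]
    for question, answer in fewshots:
        messages.append({"role": "user", "content": question})
        messages.append({"role": "assistant", "content": answer})
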
/src/active_prm/eval/PRMBench/mr_eval/models/qwen_qwq.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 | import sys
5 |
6 | from typing import List, Optional, Tuple, Type, TypeVar, Union
7 |
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | import torch
11 | import torch.nn as nn
12 | from transformers.configuration_utils import PretrainedConfig
13 | from accelerate import Accelerator
14 | from copy import deepcopy
15 |
16 |
17 |
18 | from .abstract_model import prm
19 | from ..utils.prompts import PROMPT_DICT
20 | from ..utils.utils import *
21 | from ..utils.log_utils import *
22 | from ..utils.model_utils import remove_step_prefix, process_policy_lm_evaluation_response
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 |
28 | class QwenQwQ(prm):
29 | def __init__(
30 | self,
31 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/policy_models/QwQ-32B-Preview",
32 | validity_threshold = 0,
33 | redundancy_threshold = 0,
34 | log_save_dir = "mr_eval/scripts/logs/generated/qwq.jsonl",
35 | ) -> None:
36 | super().__init__(validity_threshold=validity_threshold, redundancy_threshold=redundancy_threshold)
37 |
38 |
39 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
40 | self.tokenizer.padding_side = "left"
41 |         self.model = AutoModelForCausalLM.from_pretrained(pretrained).eval()
42 | self.log_save_dir = log_save_dir
43 |
44 | self.accelerator = Accelerator()
45 | self.model = self.model.to(self.accelerator.device)
46 | self.prompt = PROMPT_DICT["policy_model_as_an_evaluator"]
47 | self.messages = [
48 | {"role": "system", "content": self.prompt["system_prompt"]},
49 | {"role": "user", "content": self.prompt["fewshots"][0][0]},
50 | {"role": "assistant", "content": self.prompt["fewshots"][0][1]},
51 | {"role": "user", "content": self.prompt["fewshots"][1][0]},
52 | {"role": "assistant", "content": self.prompt["fewshots"][1][1]},
53 | ]
54 |
55 |
56 | def getitem_function(self,meta_data,index):
57 | data_idx = meta_data[index]["idx"]
58 | steps = meta_data[index]["steps"]
59 | question = meta_data[index]["question"]
60 |
61 | res = []
62 | for idx,step in enumerate(steps):
63 | clean_step = remove_step_prefix(step)
64 | res.append(f"Step {idx+1}: {clean_step} \n\n")
65 |
66 | steps_str = "".join(res)
67 | original_input_for_prm = f"Question: {question}\n\n Solutions: {steps_str}"
68 | messages = deepcopy(self.messages)
69 | messages.append({"role": "user", "content": original_input_for_prm})
70 |
71 | input_ids = self.tokenizer.apply_chat_template(
72 | messages,
73 | tokenize=True,
74 | add_generation_prompt=True,
75 | return_tensors="pt"
76 | )
77 | while input_ids.ndim > 1:
78 | input_ids = input_ids[0]
79 |
80 | res = dict(
81 | idx = data_idx,
82 | input_ids = input_ids,
83 | )
84 | return res
85 |
86 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
87 | dataloader = self.accelerator.prepare(dataloader)
88 | self.accelerator.wait_for_everyone()
89 | self.model.eval()
90 | gen_kwargs = dataloader.dataset.gen_kwargs
91 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
92 | if len(dataloader) == 0:
93 | self.accelerator.wait_for_everyone()
94 | return
95 | with torch.no_grad():
96 | for batch in dataloader:
97 | idx = batch['idx']
98 | input_ids = batch['input_ids']
99 | attention_mask = batch['attention_mask']
100 | generated_ids = self.model.generate(
101 | input_ids=input_ids,
102 | attention_mask=attention_mask,
103 | max_new_tokens=1024,
104 | )
105 | generated_ids = [
106 | output_id[len(input_id):] for input_id, output_id in zip(input_ids, generated_ids)
107 | ]
108 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
109 |
110 | for i in range(len(idx)):
111 | try:
112 | current_idx = idx[i]
113 | current_response = response[i]
114 | scores = process_policy_lm_evaluation_response(current_response)
115 | if scores:
116 | score_dict = dict(
117 | step_level_validity_scores = scores["validity"],
118 | step_level_redundancy_scores = scores["redundancy"],
119 |                         step_level_validity_labels = [item > self.validity_threshold for item in scores["validity"]],
120 |                         step_level_redundancy_labels = [item > self.redundancy_threshold for item in scores["redundancy"]],
121 | )
122 | res = dict(scores=score_dict, idx=current_idx, validity=True)
123 | else:
124 | res = dict(validity=False, idx=current_idx)
125 | dataloader.dataset.store_results(res)
126 | log = dict(idx = current_idx, response = current_response, scores = scores, result = res)
127 | dataloader.dataset.save_result_item_into_log(log,self.log_save_dir)
128 |                 except Exception as e:
129 |                     current_idx = idx[i]
130 |                     current_response = response[i]
131 |                     logger.error(f"Error in responding to idx {current_idx}: {e}")
132 | res = dict(validity=False, idx=current_idx)
133 | dataloader.dataset.store_results(res)
134 | log = dict(idx = current_idx, response = current_response, scores = None, result = res)
135 | dataloader.dataset.save_result_item_into_log(log,self.log_save_dir)
136 | if progress_bar is not None:
137 | progress_bar.update(1)
138 |
139 | self.accelerator.wait_for_everyone()
--------------------------------------------------------------------------------
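`respond` strips the prompt from each generation by slicing `output_id[len(input_id):]`. This is only safe because the tokenizer is configured with `padding_side = "left"`, so every row in a batch shares the same total prompt length. A self-contained toy illustration (the token values are made up):

    import torch

    pad = 0
    input_ids = torch.tensor([
        [pad, pad, 5, 6, 7],   # shorter prompt, left-padded
        [1, 2, 3, 4, 5],       # full-length prompt
    ])
    # pretend the model appended two new tokens to each row
    generated = torch.cat([input_ids, torch.tensor([[8, 9], [8, 9]])], dim=1)
    new_tokens = [out[len(inp):] for inp, out in zip(input_ids, generated)]
    assert all(t.tolist() == [8, 9] for t in new_tokens)
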
/src/active_prm/eval/PRMBench/mr_annotate/build_data/model_inference/qwq/inferencer/qwq_inferencer.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import torch
4 |
5 | from accelerate import Accelerator
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 |
10 | from .qwq_inferencer_dataset import DataCollatorQwQdDataset, dataset_dict
11 |
12 |
13 | class BaseQwenInferencer:
14 | def __init__(self, inference_args, data_args, model_args):
15 | self.inference_args, self.data_args, self.model_args = inference_args, data_args, model_args
16 |
17 | self.initialize_model()
18 | self.prepare_dataset()
19 |
20 | def prepare_dataset(self):
21 | raise NotImplementedError
22 |
23 | def initialize_model(self):
24 | raise NotImplementedError
25 |
26 | def inference(self):
27 | raise NotImplementedError
28 |
29 |
30 | class QwQGeneratePRMInferencer(BaseQwenInferencer):
31 | def extract_steps(self, text):
32 | """
33 |         Extract the content of each "Step" from the text and return the steps as an ordered list.
34 |         """
35 |         # Regex: match a chunk beginning with "Step X." and capture everything that follows
36 |         pattern = r"(Step \d+\..*?)(?=Step \d+\.|\Z)"  # match from a "Step" marker up to the next "Step" or end of text
37 |         steps = re.findall(pattern, text, re.DOTALL)  # re.DOTALL lets '.' match newlines
38 | return steps
39 |
40 | def prepare_dataset(self):
41 | self.function = self.inference_args.function
42 | self.dataset = dataset_dict[self.function](self.data_args)
43 |
44 | def initialize_model(self):
45 | self.model_name = self.model_args.model_path
46 |
47 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", device_map="auto")
48 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
49 | self.tokenizer.padding_side = "left"
50 |
51 | def inference(self):
52 | for idx, item in enumerate(tqdm(self.dataset)):
53 | messages = item["messages"]
54 | item_idx = item["idx"]
55 | question = item["question"]
56 | process_list = item["process_list"]
57 | ground_truth = item["ground_truth"]
58 |
59 | # print(messages)
60 | text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61 | model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
62 |
63 | generated_ids = self.model.generate(**model_inputs, max_new_tokens=2048)
64 | generated_ids = [
65 | output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
66 | ]
67 |
68 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
69 |
70 | response_list = self.extract_steps(response)
71 | res_item = dict(
72 | original_question=question,
73 | modified_question=question,
74 | original_process=process_list,
75 | modified_process=response_list,
76 | modified_steps=[],
77 | error_steps=[],
78 | reason="",
79 | ground_truth=ground_truth,
80 | idx=item_idx,
81 | question=question,
82 | classification="one_question_multi_answers",
83 | original_response=response,
84 | )
85 | self.dataset.write_output_item(res_item)
86 |
87 |
88 | class QwQParallelGeneratePRMInferencer(BaseQwenInferencer):
89 | def extract_steps(self, text):
90 | """
91 |         Extract the content of each "Step" from the text and return the steps as an ordered list.
92 |         """
93 |         # Regex: match a chunk beginning with "Step X." and capture everything that follows
94 |         pattern = r"(Step \d+\..*?)(?=Step \d+\.|\Z)"  # match from a "Step" marker up to the next "Step" or end of text
95 |         steps = re.findall(pattern, text, re.DOTALL)  # re.DOTALL lets '.' match newlines
96 | return steps
97 |
98 | def prepare_dataset(self):
99 | self.function = self.inference_args.function
100 | self.dataset = dataset_dict[self.function](self.data_args)
101 | data_collator = DataCollatorQwQdDataset(self.tokenizer)
102 | self.dataloader = DataLoader(
103 | self.dataset,
104 | collate_fn=data_collator,
105 | batch_size=self.data_args.batch_size,
106 | num_workers=self.data_args.num_workers,
107 | )
108 |
109 | def initialize_model(self):
110 | self.model_name = self.model_args.model_path
111 |
112 | self.model = AutoModelForCausalLM.from_pretrained(
113 | self.model_name, torch_dtype=torch.bfloat16, device_map="cpu"
114 | )
115 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
116 | self.tokenizer.padding_side = "left"
117 | self.accelerator = Accelerator()
118 |
119 | def inference(self):
120 | self.model = self.model.to(self.accelerator.device)
121 | self.dataloader = self.accelerator.prepare(self.dataloader)
122 |
123 | for batch in tqdm(self.dataloader):
124 | with torch.no_grad():
125 | model_inputs = batch["model_inputs"].to(self.accelerator.device)
126 | generated_ids = self.model.generate(**model_inputs, max_new_tokens=2048)
127 |
128 | generated_ids = [
129 | output_ids[len(input_ids) :]
130 | for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
131 | ]
132 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
133 | for i in range(len(response)):
134 | response_list = self.extract_steps(response[i])
135 | res_item = dict(
136 | original_question=batch["question"][i],
137 | modified_question=batch["question"][i],
138 | original_process=batch["process_list"][i],
139 | modified_process=response_list,
140 | modified_steps=[],
141 | error_steps=[],
142 | reason="",
143 | ground_truth=batch["ground_truth"][i],
144 | idx=batch["idx"][i],
145 | question=batch["question"][i],
146 | classification="one_question_multi_answers",
147 | original_response=response[i],
148 | )
149 | self.dataset.write_output_item(res_item)
150 |
151 |
--------------------------------------------------------------------------------
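The `extract_steps` regex splits a generation into per-step chunks by matching from one `Step N.` marker up to the next (or the end of text). A quick sanity check:

    import re

    pattern = r"(Step \d+\..*?)(?=Step \d+\.|\Z)"
    text = "Step 1. Solve for x.\nStep 2. Substitute.\nStep 3. Simplify."
    print(re.findall(pattern, text, re.DOTALL))
    # ['Step 1. Solve for x.\n', 'Step 2. Substitute.\n', 'Step 3. Simplify.']
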
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/classifications/domain_inconsistency.py:
--------------------------------------------------------------------------------
1 | domain_inconsistency_fewshot_q1="""
2 | Question:
3 |
4 | If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?
5 |
6 | Answer:
7 |
8 | Step 1. Let's start with the first equation and see if we can solve for x.
9 |
10 | Step 2. We can use the quadratic formula to find x.
11 |
12 | Step 3. The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.
13 |
14 | Step 4. Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.
15 |
16 | Step 5. Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.
17 |
18 | Step 6. Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.
19 |
20 | Step 7. Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.
21 |
22 | Step 8. Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.
23 |
24 | Step 9. Simplify: $x^3 - 2x + 1 = 2$.
25 |
26 | # Answer
27 |
28 | 2
29 | """
30 | domain_inconsistency_fewshot_a1="""
31 | {
32 | "original_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
33 | "modified_question": "If $x^2 - x - 1 = 0$, what is the value of $x^3 - 2x + 1$?",
34 | "original_process": [
35 | "Let's start with the first equation and see if we can solve for x.",
36 | "We can use the quadratic formula to find x.",
37 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
38 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
39 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
40 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
41 | "Simplify: $x^3 - 2x + 1 = (x^2 + x) - 2x + 1$.",
42 | "Substitute $x^2 = x + 1$: $x^3 - 2x + 1 = (x + 1 + x) - 2x + 1$.",
43 | "Simplify: $x^3 - 2x + 1 = 2$."
44 | ],
45 | "modified_process": [
46 | "Let's start with the first equation and see if we can solve for x.",
47 | "We can use the quadratic formula to find x.",
48 | "The quadratic formula is $x = \\frac{-b \\pm \\sqrt{b^2 - 4ac}}{2a}$.",
49 | "Substituting $a = 1$, $b = -1$, and $c = -1$, we get $x = \\frac{1 \\pm \\sqrt{5}}{2}$.",
50 | "Since $x^2 - x - 1 = 0$, we replace $x^2$ with $x + 1$ in $x^3 - 2x + 1$.",
51 | "Substituting, $x^3 - 2x + 1 = x(x + 1) - 2x + 1$.",
52 | "To simplify $x^3$, use the property of cubic roots: $x^3 = \\sqrt[3]{x}$.",
53 | "Substitute $x^3 = \\sqrt[3]{x}$ and $x^2 = x + 1$ into $x^3 - 2x + 1$: $x^3 - 2x + 1 = \\sqrt[3]{x} - 2x + 1$.",
54 | "Simplify further: $x^3 - 2x + 1 = \\sqrt[3]{x} - x + 1$."
55 | ],
56 | "modified_steps": [7, 8, 9],
57 | "error_steps": [7],
58 | "reason": "A domain inconsistency was introduced in Step 7 by incorrectly assuming that $x^3 = \\sqrt[3]{x}$. This assumption is valid for root-based equations but not in the context of solving a polynomial equation. This error propagates through the rest of the reasoning, leading to a flawed solution."
59 | }
60 | """
61 |
62 | domain_inconsistency_fewshot_q2="""
63 | Question:
64 |
65 | In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $\$1.50$ in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?
66 |
67 | Answer:
68 |
69 | Step 1. Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.
70 |
71 | Step 2. There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.
72 |
73 | Step 3. I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.
74 |
75 | Step 4. That's true. Now we just have to solve the equation x=1000000*0.00125.
76 |
77 | Step 5. So x=$1250$.
78 |
79 | Step 6. That's the final answer.
80 |
81 | Step 7. Right. So 1,000,000 lire is equivalent to $\$1250$.
82 |
83 | # Answer
84 |
85 | 1250
86 | """
87 |
88 |
89 | domain_inconsistency_fewshot_a2="""
90 | {
91 | "original_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
92 | "modified_question": "In 1992, a scoop of gelato could be purchased in Italy for 1200 lire. The same gelato would have cost $1.50 in the U.S. At the equivalent exchange rate between the lire and the dollar, how many dollars would be equivalent to 1,000,000 lire?",
93 | "original_process": [
94 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
95 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
96 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
97 | "That's true. Now we just have to solve the equation x=1000000*0.00125.",
98 | "So x=$1250$.",
99 | "That's the final answer.",
100 | "Right. So 1,000,000 lire is equivalent to $1250$."
101 | ],
102 | "modified_process": [
103 | "Let's call the number of dollars x. So, the problem is asking us to solve the equation x=1000000*$\\frac{1.50}{1200}$.",
104 | "There are two ways to solve this equation. We could either divide 1000000 by 1200 and then multiply the result by 1.50 or we could divide 1.50 by 1200 and then multiply the result by 1000000.",
105 | "I think the second way will be easier. $\\frac{1.50}{1200}=0.00125$.",
106 | "Let's consider the equivalent in euros by assuming that $1$ euro is equivalent to $1.20$ at the time.",
107 | "Using this, $x=1000000 * 0.00125 * 1.20$.",
108 | "Now solve the equation to get $x=1500$.",
109 | "So 1,000,000 lire is equivalent to $1500$."
110 | ],
111 | "modified_steps": [4, 5, 6, 7],
112 | "error_steps": [4, 5],
113 | "reason": "A domain inconsistency error was introduced in Step 4 by incorrectly introducing euros into the problem. The original problem strictly deals with the lire-to-dollar conversion, but the modified process adds an exchange rate for euros, which is outside the scope of the problem. This inconsistency changes the result and creates a plausible but incorrect answer."
114 | }
115 | """
--------------------------------------------------------------------------------
/src/active_prm/eval/PRMBench/mr_annotate/build_data/prompts/prompt_new.py:
--------------------------------------------------------------------------------
1 | from .classifications import fewshot_dicts
2 | basic_prompt = """
3 | You are a helpful AI assistant that is very good at reasoning and data construction. Now I want to test the ability of process-level reward models to judge whether a step within a reasoning process is correct. To do this, please help me build flawed cases by introducing specific types of errors into a given reasoning process.
4 |
5 | You will be provided with:
6 |
7 | 1. A mathematical problem.
8 | 2. Its standard correct answer.
9 | 3. A correct step-by-step reasoning process used to solve it.
10 |
11 | Your task is to modify the question, adjust one or more steps, or introduce additional steps into the original process chain to create a reasoning process that appears plausible but is incorrect. The objective is to simulate flawed solutions by incorporating the specified error detailed after '### Error Type to Introduce'.
12 |
13 | ### Error Type to Introduce
14 | """
15 |
16 | redundency = """
17 | Redundancy refers to a process that is not the most concise or efficient, as it includes one or more redundant steps that can be removed without affecting the correctness of the overall solution path. For example, if $ A \\to B $ represents a correct inference chain, your task is to introduce one or more redundant steps $ C = \\{ c \\mid c \\text{ is redundant} \\} $ and reformulate the solution chain as $ A \\to C \\to B $.
18 | """
19 |
20 | circular = """
21 | Circular logic is a specific form of redundancy, characterized by a reasoning chain that starts at a step $ S $, progresses through a sequence of steps, and ultimately loops back to $ S $. Symbolically, this can be expressed as $ S \\to A \\to B \\to S $, where $ S $, $ A $, and $ B $ represent individual reasoning steps. Your task is to modify the reasoning process to introduce such circular logic.
22 | """
23 |
24 | counterfactual="""
25 | A counterfactual step refers to a statement within a reasoning chain that contradicts established ground truth. Such contradictions can arise from relying on outdated theories, omitting critical constraints in a theory, or incorporating erroneous assumptions. Your task is to modify the reasoning process to introduce such counterfactual steps.
26 | """
27 |
28 | step_contradiction="""
29 | Step contradiction refers to a conflict between a specific step and other steps within a reasoning path. Given a reasoning path $ P = \\{S_1, S_2, \\dots, S_n\\} $, a step contradiction exists if $ S_i \\perp S_j $, where $ i, j \\in [1, n] $ and $ i \\neq j $. Your task is to modify the reasoning process to introduce such step contradiction steps.
30 | """
31 |
32 | domain_inconsistency="""
33 | Domain inconsistency is a special type of counterfactual. It refers to a step within the reasoning chain that uses a statement or theory that is valid in other domains or cases but not within the current reasoning chain. Your task is to modify the reasoning process to introduce such domain inconsistency steps.
34 | """
35 |
36 | confidence="""
37 | Confident hallucination is a special type of counterfactual. It refers to a statement within the reasoning chain that contradicts established ground truth and is presented with an overly confident tone. In other words, it involves stating an incorrect statement with unwarranted certainty. Your task is to modify the reasoning process to introduce such confident hallucination steps.
38 | """
39 |
40 | missing_condition="""
41 | Missing condition or prerequisite refers to a flaw in the reasoning chain where critical premises, assumptions, or necessary conditions are absent. This omission results in logical gaps, incomplete reasoning, or biased conclusions. For example, when a missing condition occurs, the model is required to solve the problem through case analysis or further investigation. However, the answer becomes incorrect if the model overlooks the missing condition and proceeds with standard reasoning methods. Your task is to modify the reasoning process to introduce such missing condition error steps.
42 | """
43 |
44 | deception="""
45 | Deception or traps refer to statements that appear to be correct or align with ground truth but are subtly altered to introduce inaccuracies while maintaining the illusion of correctness. Your task is to modify the reasoning process to introduce such deception or trap error steps.
46 | """
47 |
48 |
49 | post_prompt = """
50 | ### Formatting Instructions
51 |
52 | After making the modifications, provide the following structured output:
53 | {
54 | "original_question": "The original mathematical problem.",
55 | "modified_question": "The modified problem or original problem
56 | "original_process": ["original_step 1", "original_step 2", ...],
57 | "modified_process": ["modified_step 1", "modified_step 2", ...],
58 | "modified_steps": [1, 5, 7, ...],
59 | "error_steps": [5, 6, ...],
60 | "reason": "Explanation for the changes."
61 | }
62 |
63 | Detailed Requirements:
64 | 1. original_question: A string representing the original mathematical problem as provided.
65 | 2. modified_question: A string representing the modified problem after your changes. If the problem remains the same, you can copy the original question.
66 | 3. original_process: A non-empty list of strings representing the original reasoning steps provided as input.
67 | 4. modified_process: A non-empty list of strings representing the reasoning process after your modifications.
68 | 5. modified_steps: A non-empty list of integers indicating the indexes of all modified steps. Indexing starts at 1.
69 | 6. error_steps: A non-empty list of integers representing the steps that contain hallucinations or errors. These should also be part of modified_steps.
70 | 7. reason: A clear explanation of the modifications made, why they were introduced, and how they align with the specified error types.
71 |
72 | ### Notes:
73 |
74 | 1. Ensure all lists are non-empty.
75 | 2. Use LaTeX format for all mathematical symbols (e.g., $x^2$ for $x$ squared). Do not use Unicode symbols such as \u2248 or \u00f7.
76 | 3. Ensure the JSON object is well-formed, with proper escaping for special characters (e.g., write \\n for a newline inside a JSON string).
77 | 4. All indexes start from 1, that is, the first step's index is 1, not 0.
78 | 5. You can choose whether to modify the question. If the question remains the same, copy the original question; if it is modified, ensure that the steps are judged against the modified question.
79 | 6. Give the original process exactly as provided in the prompt; do not modify it.
80 | """
81 |
82 | classifications = ["circular", "confidence", "counterfactual", "step_contradiction", "domain_inconsistency", "redundency", "missing_condition", "deception"]
83 |
84 | prompt_dict = {}
85 |
86 | for classification in classifications:
87 | prompt_dict[classification] = dict(system = basic_prompt + eval(classification) + post_prompt, few_shot = fewshot_dicts[classification])
88 |
--------------------------------------------------------------------------------
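The `eval(classification)` call resolves each error-type name to the module-level prompt string defined above; a plain dict lookup mapping names to strings would achieve the same without `eval`. Separately, the "Detailed Requirements" in `post_prompt` translate directly into a structural check. A minimal validator sketch (this helper is illustrative only and not part of the repo):

    def validate_item(item: dict) -> bool:
        """Check a model's structured output against post_prompt's requirements."""
        required = ["original_question", "modified_question", "original_process",
                    "modified_process", "modified_steps", "error_steps", "reason"]
        if any(key not in item for key in required):
            return False
        if not item["original_process"] or not item["modified_process"]:
            return False
        n = len(item["modified_process"])
        # indexes are 1-based and must point into modified_process
        if not all(1 <= i <= n for i in item["modified_steps"]):
            return False
        # every error step must also be a modified step
        return set(item["error_steps"]) <= set(item["modified_steps"])
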
/src/active_prm/eval/PRMBench/mr_eval/models/vllm_models.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 |
5 | from typing import List, Optional, Tuple, Type, TypeVar, Union
6 |
7 | from tqdm import tqdm
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | import torch
10 | import torch.nn as nn
11 | from transformers.configuration_utils import PretrainedConfig
12 | from accelerate import Accelerator
13 | from copy import deepcopy
14 | from vllm import LLM, SamplingParams
15 |
16 |
17 | from .abstract_model import prm
18 | from ..utils.prompts import PROMPT_DICT
19 | from ..utils.utils import *
20 | from ..utils.log_utils import *
21 | from ..utils.model_utils import remove_step_prefix, process_policy_lm_evaluation_response
22 | import torch.multiprocessing as mp
23 |
24 | # set the multiprocessing start method to "spawn"
25 | mp.set_start_method('spawn', force=True)
26 | logger = get_logger(__name__)
27 |
28 |
29 |
30 | class VllmModels(prm):
31 | def __init__(
32 | self,
33 | pretrained = "/mnt/petrelfs/songmingyang/songmingyang/model/reasoning/policy_models/QwQ-32B-Preview",
34 | tensor_parallel: str = "1",
35 | validity_threshold = 0,
36 | redundancy_threshold = 0,
37 | first_round_role = "user",
38 | save_to_ckpt_interval = 1000,
39 | ) -> None:
40 | super().__init__(validity_threshold=validity_threshold, redundancy_threshold=redundancy_threshold)
41 |
42 | self.tensor_parallel = int(tensor_parallel)
43 | self.first_round_role = first_round_role
44 | self.save_to_ckpt_interval = save_to_ckpt_interval
45 | pid = os.getpid()
46 |         # print(f"current process id: {pid}")
47 |
48 | self.model = LLM(model=pretrained, tensor_parallel_size = self.tensor_parallel, trust_remote_code=True)
49 | self.prompt = PROMPT_DICT["policy_model_as_an_evaluator"]
50 | self.messages = [
51 | {"role": self.first_round_role, "content": self.prompt["system_prompt"]},
52 | {"role": "user", "content": self.prompt["fewshots"][0][0]},
53 | {"role": "assistant", "content": self.prompt["fewshots"][0][1]},
54 | {"role": "user", "content": self.prompt["fewshots"][1][0]},
55 | {"role": "assistant", "content": self.prompt["fewshots"][1][1]},
56 | ]
57 |
58 |
59 | def getitem_function(self,meta_data,index):
60 | data_idx = meta_data[index]["idx"]
61 | steps = meta_data[index]["steps"]
62 | question = meta_data[index]["question"]
63 |
64 | res = []
65 | for idx,step in enumerate(steps):
66 | clean_step = remove_step_prefix(step)
67 | res.append(f"Step {idx+1}: {clean_step} \n\n")
68 |
69 | steps_str = "".join(res)
70 | original_input_for_prm = f"Question: {question}\n\n Solutions: {steps_str}"
71 | messages = deepcopy(self.messages)
72 | messages.append({"role": "user", "content": original_input_for_prm})
73 |
74 |
75 | res = dict(
76 | idx = data_idx,
77 | inputs = messages,
78 | model_type = "vllm",
79 | )
80 | return res
81 |
82 | def respond(self, dataloader) -> List[Tuple[float, bool]]:
83 |
84 | # gen_kwargs = dataloader.dataset.gen_kwargs
85 | sampling_params = SamplingParams(
86 | temperature = self.generation_config.get("temperature", 0.0),
87 | top_k = self.generation_config.get("top_k", -1),
88 | top_p = self.generation_config.get("top_p", 1.0),
89 | max_tokens = self.generation_config.get("max_length", 2048),
90 | )
91 |
92 |
93 | progress_bar = tqdm_rank0(len(dataloader), desc="Model Responding")
94 |
95 | dataloader_iter = iter(dataloader)
96 | # import debugpy
97 | # debugpy.listen(address = ('0.0.0.0', 7119))
98 | # debugpy.wait_for_client()
99 |         # breakpoint()  # pause here when debugging
100 | # dist.barrier()
101 | with torch.no_grad():
102 | stop_flag = False
103 | while not stop_flag:
104 | data_batch = []
105 | messages = []
106 | idxs = []
107 | for iter_num in range(self.save_to_ckpt_interval):
108 | try:
109 | current_batch = next(dataloader_iter)
110 | message = current_batch['inputs']
111 | idx = current_batch['idx']
112 | messages.append(message)
113 | idxs.append(idx)
114 | except StopIteration:
115 | stop_flag = True
116 | break
117 |
118 | outputs = self.model.chat(messages, sampling_params = sampling_params)
119 |
120 | for idx, output in zip(idxs, outputs):
121 | response = output.outputs[0].text
122 | try:
123 | scores = process_policy_lm_evaluation_response(response)
124 | if scores:
125 | score_dict = dict(
126 | step_level_validity_scores = scores["validity"],
127 | step_level_redundancy_scores = scores["redundancy"],
128 | step_level_validity_labels = [item > self.validity_threshold for item in scores["validity"]],
129 | step_level_redundancy_labels = [item > self.redundancy_threshold for item in scores["redundancy"]],
130 | )
131 | res = dict(scores=score_dict, idx=idx, validity=True)
132 | else:
133 | res = dict(validity=False, idx=idx, original_response=response)
134 | dataloader.dataset.store_results(res)
135 | # log = dict(idx = current_idx, response = current_response, scores = scores, result = res)
136 | # dataloader.dataset.save_result_item_into_log(log,self.log_save_dir)
137 |                 except Exception as e:
138 |                     current_response = response
139 |                     logger.error(f"Error in responding to idx {idx}: {e}")
140 | res = dict(validity=False, idx=idx, original_response=current_response)
141 | dataloader.dataset.store_results(res)
142 | # log = dict(idx = current_idx, response = current_response, scores = None, result = res)
143 | # dataloader.dataset.save_result_item_into_log(log,self.log_save_dir)
144 | if progress_bar is not None:
145 |                     progress_bar.update(len(idxs))  # count the items actually drained; the last chunk may be smaller than the interval
146 |
--------------------------------------------------------------------------------
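`respond` drains the dataloader in chunks of `save_to_ckpt_interval` items so that vLLM can schedule one large `chat` batch per chunk. The same pattern in isolation (a sketch, not the repo's code):

    from itertools import islice

    def chunks(iterable, interval):
        it = iter(iterable)
        while True:
            chunk = list(islice(it, interval))
            if not chunk:
                return
            yield chunk

    for batch in chunks(range(10), 4):
        print(batch)  # [0, 1, 2, 3] / [4, 5, 6, 7] / [8, 9]
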
/src/active_prm/eval/processbench.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from datasets import load_dataset
7 |
8 | from active_prm.utils.worker import (EnsemblePRMWorker, LLMasJudgerWorker,
9 | QwenPRMWorker)
10 |
11 |
12 | class ProcessBench:
13 | def __init__(self, subset):
14 | self.data = load_dataset("Qwen/processbench")[subset]
15 |
16 | def process_results(self, preds):
17 | labels = np.array(self.data["label"])
18 | preds = np.array(preds)
19 | positive_mask = labels != -1
20 | _pos = np.mean(labels[positive_mask] == preds[positive_mask])
21 | _neg = np.mean(labels[~positive_mask] == preds[~positive_mask])
22 | # print(f"Positive: {_pos:.4f}, Negative: {_neg:.4f}")
23 | return _pos, _neg, 2 * _pos * _neg / (_pos + _neg)
24 |
25 | def process_data(self):
26 | def _process(example):
27 | list_strip = lambda x_list: [x.strip() for x in x_list]
28 |
29 | return {"prompt": [example["problem"], list_strip(example["steps"])]}
30 |
31 | return self.data.map(_process)
32 |
33 |
34 | def predict_with_threshold(rating, threshold):
35 | rating = np.array(rating)
36 | preds = rating < threshold
37 | pos = np.argmax(preds)
38 | if preds.sum() == 0:
39 | pos = -1
40 | return pos
41 |
42 |
43 | def predict_with_advantages(rating, threshold):
44 | rating = np.array(rating)
45 | advs_preds = (rating[1:] - rating[:-1]) < 0
46 | values_preds = rating[1:] < threshold
47 | preds = advs_preds & values_preds
48 | pos = np.argmax(preds)
49 | if preds.sum() == 0:
50 | pos = -1
51 | return pos
52 |
53 |
54 | def judger_entrypoint(reward_model_path, subsets="math,gsm8k,olympiadbench,omnimath", n=1, temperature=0.0):
55 | def _get_list(k):
56 | if isinstance(k, Tuple):
57 | k_list = list(k)
58 | elif isinstance(k, str):
59 | k_list = str(k).split(",")
60 | return k_list
61 |
62 | def _process_data(example):
63 | list_strip = lambda x_list: [x.strip() for x in x_list]
64 | return [example["problem"], list_strip(example["steps"])]
65 |
66 | reasoning_parser = {}
67 | if "qwq" in reward_model_path.lower() or "r1" in reward_model_path.lower():
68 | reasoning_parser = {"reasoning_parser": "deepseek-r1"}
69 |
70 | worker = LLMasJudgerWorker(reward_model_path, **reasoning_parser)
71 |
72 | subsets = _get_list(subsets)
73 | assert all(subset in ["gsm8k", "math", "olympiadbench", "omnimath"] for subset in subsets)
74 | for i, subset in enumerate(subsets):
75 | process_bench = ProcessBench(subset)
76 | # process_bench.data = process_bench.data.shuffle().select(range(3))
77 | prompts = [_process_data(row) for row in process_bench.data]
78 | preds, total_outputs = worker.generate(
79 | prompts,
80 | verbose=True,
81 | max_new_tokens=8192,
82 | temperature=temperature,
83 | return_preds_only=False,
84 | )
85 | # preds = [predict_with_threshold(logit, 0) for logit in logits]
86 | pos, neg, f1 = process_bench.process_results(preds)
87 | print(f"Positive: {pos:.4f}, Negative: {neg:.4f}, F1: {f1:.4f} for {subset}")
88 | data = process_bench.data.to_pandas()
89 | data["labeling_outputs"] = total_outputs
90 | data["preds"] = preds
91 | model_id = reward_model_path.split("/")[-1]
92 | output_dir = f"./out/bench/processbench/{model_id}/{subset}"
93 | os.makedirs(output_dir, exist_ok=True)
94 | data.to_json(os.path.join(output_dir, "outputs.jsonl"), orient="records", lines=True)
95 | print(f"Outputs saved to {output_dir}")
96 |
97 |
98 | def prm_entrypoint(reward_model_path, subsets="gsm8k,math,olympiadbench,omnimath", rating_threshold=None, **kwargs):
99 | def _get_list(k):
100 | if isinstance(k, Tuple):
101 | k_list = list(k)
102 | elif isinstance(k, str):
103 | k_list = str(k).split(",")
104 | elif isinstance(k, float):
105 | k_list = [k]
106 | return k_list
107 |
108 | def _process_data(example):
109 | list_strip = lambda x_list: [x.strip() for x in x_list]
110 | return [example["problem"], list_strip(example["steps"])]
111 |
112 | if reward_model_path in ["Qwen/Qwen2.5-Math-PRM-7B", "Qwen/Qwen2.5-Math-PRM-72B"]:
113 | worker = QwenPRMWorker(reward_model_path, torch_dtype=torch.bfloat16, **kwargs)
114 | worker.model.to("cuda")
115 | else:
116 | worker = EnsemblePRMWorker(reward_model_path, torch_dtype="auto", **kwargs)
117 | worker.model.config.problem_type = "single_label_classification"
118 | worker.model.to("cuda")
119 |
120 | subsets = _get_list(subsets)
121 | if rating_threshold is not None:
122 | thresholds = _get_list(rating_threshold)
123 | else:
124 | thresholds = [0.5]
125 |
126 |     # TODO: THE FOLLOWING CODE IS NOT WORKING FOR JUDGER
127 | best_threshold = None
128 | assert all(subset in ["gsm8k", "math", "olympiadbench", "omnimath"] for subset in subsets)
129 | average_f1 = []
130 | for i, subset in enumerate(subsets):
131 | process_bench = ProcessBench(subset)
132 | prompts = [_process_data(row) for row in process_bench.data]
133 | logits = worker.generate(prompts, batch_size=16, verbose=True)
134 | if i == 0:
135 | f1_list, pos_list, neg_list = [], [], []
136 | for threshold in thresholds:
137 | preds = [predict_with_threshold(logit, float(threshold)) for logit in logits]
138 | pos, neg, f1 = process_bench.process_results(preds)
139 | # print(f"F1 score: {acc:.4f} for {subset}")
140 | f1_list.append(f1)
141 | pos_list.append(pos)
142 | neg_list.append(neg)
143 | best_threshold = thresholds[np.argmax(f1_list)]
144 | print(f"Best F1: {max(f1_list):.4f} for {subset}, use this one {best_threshold} as the best threshold")
145 | print(
146 | f"Positive: {pos_list[np.argmax(f1_list)]:.4f}, Negative: {neg_list[np.argmax(f1_list)]:.4f}, F1: {f1_list[np.argmax(f1_list)]:.4f} for {subset}"
147 | )
148 | average_f1.append(f1_list[np.argmax(f1_list)])
149 | else:
150 | preds = [predict_with_threshold(logit, best_threshold) for logit in logits]
151 | pos, neg, f1 = process_bench.process_results(preds)
152 | print(f"Positive: {pos:.4f}, Negative: {neg:.4f}, F1: {f1:.4f} for {subset}")
153 | average_f1.append(f1)
154 | print(f"Average F1: {np.mean(average_f1):.4f}")
155 |
156 |
157 | if __name__ == "__main__":
158 | import fire
159 |
160 | fire.Fire({"judger": judger_entrypoint, "prm": prm_entrypoint})
161 |
--------------------------------------------------------------------------------
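In ProcessBench, `label` is the index of the earliest erroneous step, or `-1` when every step is correct. `process_results` therefore scores erroneous and fully-correct samples separately and reports their harmonic mean (F1). A worked toy example with made-up labels:

    import numpy as np

    labels = np.array([2, 0, -1, -1])   # two erroneous samples, two fully correct
    preds  = np.array([2, 1, -1,  0])
    positive_mask = labels != -1
    pos = np.mean(labels[positive_mask] == preds[positive_mask])    # 0.5
    neg = np.mean(labels[~positive_mask] == preds[~positive_mask])  # 0.5
    f1 = 2 * pos * neg / (pos + neg)                                # 0.5
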
/src/active_prm/models/nets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Garena Online Private Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Deep networks."""
16 |
17 | from copy import deepcopy
18 |
19 | import numpy as np
20 | import torch
21 | import torch.nn.functional as F
22 | from torch import nn
23 |
24 |
25 | def init_weights(m):
26 | @torch.no_grad()
27 | def truncated_normal_init(t, mean=0.0, std=0.01):
28 | # torch.nn.init.normal_(t, mean=mean, std=std)
29 | t.data.normal_(mean, std)
30 | while True:
31 | cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
32 | if not torch.sum(cond):
33 | break
34 | w = torch.empty(t.shape, device=t.device, dtype=t.dtype)
35 | # torch.nn.init.normal_(w, mean=mean, std=std)
36 | w.data.normal_(mean, std)
37 |             t.data.copy_(torch.where(cond, w, t))  # write back in place so the parameter tensor is actually truncated
38 | return t
39 |
40 | if type(m) is nn.Linear or isinstance(m, EnsembleFC):
41 | truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
42 | if m.bias is not None:
43 | m.bias.data.fill_(0.0)
44 |
45 |
46 | def init_weights_uniform(m):
47 | input_dim = m.in_features
48 |     torch.nn.init.uniform_(m.weight, -1 / np.sqrt(input_dim), 1 / np.sqrt(input_dim))
49 | if m.bias is not None:
50 | m.bias.data.fill_(0.0)
51 |
52 |
53 | class Swish(nn.Module):
54 | def __init__(self):
55 | super(Swish, self).__init__()
56 |
57 | def forward(self, x):
58 |         x = x * torch.sigmoid(x)  # swish(x) = x * sigmoid(x)
59 | return x
60 |
61 |
62 | class MLPModel(nn.Module):
63 | def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
64 | super(MLPModel, self).__init__()
65 | self.hidden_size = hidden_dim
66 | self.output_dim = 1
67 |
68 | self.nn1 = nn.Linear(encoding_dim, hidden_dim)
69 | self.nn2 = nn.Linear(hidden_dim, hidden_dim)
70 | self.nn_out = nn.Linear(hidden_dim, self.output_dim)
71 |
72 | self.apply(init_weights)
73 |
74 | if activation == "swish":
75 | self.activation = Swish()
76 | elif activation == "relu":
77 | self.activation = nn.ReLU()
78 | else:
79 | raise ValueError(f"Unknown activation {activation}")
80 |
81 | def get_params(self) -> torch.Tensor:
82 | params = []
83 | for pp in list(self.parameters()):
84 | params.append(pp.view(-1))
85 | return torch.cat(params)
86 |
87 | def forward(self, encoding: torch.Tensor) -> torch.Tensor:
88 | x = self.activation(self.nn1(encoding))
89 | x = self.activation(self.nn2(x))
90 | score = self.nn_out(x)
91 | return score
92 |
93 | def init(self):
94 | self.init_params = self.get_params().data.clone()
95 | if torch.cuda.is_available():
96 | self.init_params = self.init_params.cuda()
97 |
98 | def regularization(self):
99 | """Prior towards independent initialization."""
100 | return ((self.get_params() - self.init_params) ** 2).mean()
101 |
102 |
103 | class EnsembleFC(nn.Module):
104 | __constants__ = ["in_features", "out_features"]
105 | in_features: int
106 | out_features: int
107 | ensemble_size: int
108 | weight: torch.Tensor
109 |
110 | def __init__(
111 | self,
112 | in_features: int,
113 | out_features: int,
114 | ensemble_size: int,
115 | bias: bool = True,
116 | dtype=torch.float32,
117 | ) -> None:
118 | super(EnsembleFC, self).__init__()
119 | self.in_features = in_features
120 | self.out_features = out_features
121 | self.ensemble_size = ensemble_size
122 | # init immediately to avoid error
123 | self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features, dtype=dtype))
124 | if bias:
125 | self.bias = nn.Parameter(torch.empty(ensemble_size, out_features, dtype=dtype))
126 | else:
127 | self.register_parameter("bias", None)
128 |
129 | def forward(self, input: torch.Tensor) -> torch.Tensor:
130 | input = input.to(self.weight.dtype)
131 | wx = torch.einsum("eblh,ehm->eblm", input, self.weight)
132 |
133 |         return wx if self.bias is None else torch.add(wx, self.bias[:, None, None, :])  # w @ x + b, guarding the bias=False case
134 |
135 |
136 | def get_params(model):
137 | return torch.cat([p.view(-1) for p in model.parameters()])
138 |
139 |
140 | class _EnsembleModel(nn.Module):
141 | def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
142 | # super().__init__(encoding_dim, hidden_dim, activation)
143 | super(_EnsembleModel, self).__init__()
144 | self.num_ensemble = num_ensemble
145 | self.hidden_dim = hidden_dim
146 | self.output_dim = 1
147 |
148 | self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
149 | self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
150 | self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)
151 |
152 | self.apply(init_weights)
153 |
154 | if activation == "swish":
155 | self.activation = Swish()
156 | elif activation == "relu":
157 | self.activation = nn.ReLU()
158 | else:
159 | raise ValueError(f"Unknown activation {activation}")
160 |
161 | def forward(self, encoding: torch.Tensor) -> torch.Tensor:
162 | x = self.activation(self.nn1(encoding))
163 | x = self.activation(self.nn2(x))
164 | score = self.nn_out(x)
165 | return score
166 |
167 |     def regularization(self):
168 |         """Prior towards independent initialization (handled by the EnsembleModel wrapper below, which keeps a frozen copy of the initial parameters)."""
169 |         raise NotImplementedError("use EnsembleModel.regularization instead")
170 |
171 |
172 | class EnsembleModel(nn.Module):
173 | def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
174 | super(EnsembleModel, self).__init__()
175 | self.encoding_dim = encoding_dim
176 | self.num_ensemble = num_ensemble
177 | self.hidden_dim = hidden_dim
178 | self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)
179 | self.reg_model = deepcopy(self.model) # only used for regularization
180 | # freeze the reg model
181 | for param in self.reg_model.parameters():
182 | param.requires_grad = False
183 |
184 | def forward(self, encoding: torch.Tensor) -> torch.Tensor:
185 | return self.model(encoding)
186 |
187 | def regularization(self):
188 | """Prior towards independent initialization."""
189 | model_params = get_params(self.model)
190 | reg_params = get_params(self.reg_model).detach()
191 | return ((model_params - reg_params) ** 2).mean()
192 |
--------------------------------------------------------------------------------
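The einsum in `EnsembleFC.forward` applies a separate linear layer per ensemble member to a shared batch: input `(E, B, L, H)` times weight `(E, H, M)` yields `(E, B, L, M)`. A quick shape check:

    import torch

    E, B, L, H, M = 4, 2, 3, 8, 1  # ensemble, batch, length, hidden, output
    x = torch.randn(E, B, L, H)
    w = torch.randn(E, H, M)
    out = torch.einsum("eblh,ehm->eblm", x, w)
    assert out.shape == (E, B, L, M)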